diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py
index 946660e431..9ab328266a 100644
--- a/pytensor/tensor/basic.py
+++ b/pytensor/tensor/basic.py
@@ -1406,11 +1406,8 @@ def infer_static_shape(
     `shape` will be validated and constant folded. As a result, this
     function can be expensive and shouldn't be used unless absolutely
     necessary.
 
-    It mostly exists as a hold-over from pre-static shape times, when it was
-    required in order to produce correct broadcastable arrays and prevent
-    some graphs from being unusable. Now, it is no longer strictly required,
-    so don't use it unless you want the same shape graphs to be rewritten
-    multiple times during graph construction.
+    It is often needed for `Op`s whose static shape and broadcastable flags
+    depend on the values of their inputs, such as `Alloc` and `RandomVariable`.
 
     Returns
     -------
diff --git a/pytensor/tensor/random/op.py b/pytensor/tensor/random/op.py
index f6cb496d0a..e1922143b0 100644
--- a/pytensor/tensor/random/op.py
+++ b/pytensor/tensor/random/op.py
@@ -20,11 +20,7 @@
     infer_static_shape,
 )
 from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType, RandomType
-from pytensor.tensor.random.utils import (
-    broadcast_params,
-    normalize_size_param,
-    params_broadcast_shapes,
-)
+from pytensor.tensor.random.utils import broadcast_params, normalize_size_param
 from pytensor.tensor.shape import shape_tuple
 from pytensor.tensor.type import TensorType, all_dtypes
 from pytensor.tensor.type_other import NoneConst
@@ -156,6 +152,13 @@ def _infer_shape(
 
         from pytensor.tensor.extra_ops import broadcast_shape_iter
 
+        if self.ndim_supp == 0:
+            supp_shape = ()
+        else:
+            supp_shape = tuple(
+                self._supp_shape_from_params(dist_params, param_shapes=param_shapes)
+            )
+
         size_len = get_vector_length(size)
 
         if size_len > 0:
@@ -171,19 +174,11 @@ def _infer_shape(
                     f"Size length must be 0 or >= {param_batched_dims}"
                 )
 
-            if self.ndim_supp == 0:
-                return size
-            else:
-                supp_shape = self._supp_shape_from_params(
-                    dist_params, param_shapes=param_shapes
-                )
-                return tuple(size) + tuple(supp_shape)
-
-        # Broadcast the parameters
-        param_shapes = params_broadcast_shapes(
-            param_shapes or [shape_tuple(p) for p in dist_params],
-            self.ndims_params,
-        )
+            return tuple(size) + supp_shape
+
+        # Size was not provided; infer it from the shapes of the parameters
+        if param_shapes is None:
+            param_shapes = [shape_tuple(p) for p in dist_params]
 
         def extract_batch_shape(p, ps, n):
             shape = tuple(ps)
@@ -191,10 +186,10 @@ def extract_batch_shape(p, ps, n):
             if n == 0:
                 return shape
 
-            batch_shape = [
+            batch_shape = tuple(
                 s if not b else constant(1, "int64")
                 for s, b in zip(shape[:-n], p.type.broadcastable[:-n])
-            ]
+            )
             return batch_shape
 
         # These are versions of our actual parameters with the anticipated
@@ -218,15 +213,8 @@ def extract_batch_shape(p, ps, n):
             # Distribution has no parameters
             batch_shape = ()
 
-        if self.ndim_supp == 0:
-            supp_shape = ()
-        else:
-            supp_shape = self._supp_shape_from_params(
-                dist_params,
-                param_shapes=param_shapes,
-            )
+        shape = batch_shape + supp_shape
 
-        shape = tuple(batch_shape) + tuple(supp_shape)
 
         if not shape:
             shape = constant([], dtype="int64")
diff --git a/pytensor/tensor/rewriting/shape.py b/pytensor/tensor/rewriting/shape.py
index aa24f217bd..39c149ad87 100644
--- a/pytensor/tensor/rewriting/shape.py
+++ b/pytensor/tensor/rewriting/shape.py
@@ -992,12 +992,17 @@ def local_merge_consecutive_specify_shape(fgraph, node):
     return [specify_shape(inner_obj, shape)]
 
 
+_empty_shape = constant([], dtype="int64")
+
+
 @register_infer_shape
 @node_rewriter([Shape])
 def local_shape_ground(fgraph, node):
     """Rewrite shape(x) -> make_vector(x.type.shape) when this is constant."""
     [x] = node.inputs
     static_shape = x.type.shape
+    if len(static_shape) == 0:
+        return [_empty_shape]
     if not any(dim is None for dim in static_shape):
         return [stack([constant(dim, dtype="int64") for dim in static_shape])]
diff --git a/tests/tensor/random/test_op.py b/tests/tensor/random/test_op.py
index 4a389811e1..63661bd177 100644
--- a/tests/tensor/random/test_op.py
+++ b/tests/tensor/random/test_op.py
@@ -206,6 +206,42 @@ def test_RandomVariable_incompatible_size():
         rv_op(np.zeros((2, 4, 3)), 1, size=(4,))
 
 
+class MultivariateRandomVariable(RandomVariable):
+    name = "MultivariateRandomVariable"
+    ndim_supp = 1
+    ndims_params = (1, 2)
+    dtype = "floatX"
+
+    def _supp_shape_from_params(self, dist_params, param_shapes=None):
+        return [dist_params[0].shape[-1]]
+
+
+@config.change_flags(compute_test_value="off")
+def test_multivariate_rv_infer_static_shape():
+    """Test that static shape inference for a multivariate RandomVariable works when a parameter must be broadcasted."""
+    mv_op = MultivariateRandomVariable()
+
+    param1 = tensor(shape=(10, 2, 3))
+    param2 = tensor(shape=(10, 2, 3, 3))
+    assert mv_op(param1, param2).type.shape == (10, 2, 3)
+
+    param1 = tensor(shape=(2, 3))
+    param2 = tensor(shape=(10, 2, 3, 3))
+    assert mv_op(param1, param2).type.shape == (10, 2, 3)
+
+    param1 = tensor(shape=(10, 2, 3))
+    param2 = tensor(shape=(2, 3, 3))
+    assert mv_op(param1, param2).type.shape == (10, 2, 3)
+
+    param1 = tensor(shape=(10, 1, 3))
+    param2 = tensor(shape=(2, 3, 3))
+    assert mv_op(param1, param2).type.shape == (10, 2, 3)
+
+    param1 = tensor(shape=(2, 3))
+    param2 = tensor(shape=(2, 3, 3))
+    assert mv_op(param1, param2, size=(10, 2)).type.shape == (10, 2, 3)
+
+
 def test_vectorize_node():
     vec = tensor(shape=(None,))
     vec.tag.test_value = [0, 0, 0]
diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py
index 3ce5ffce63..81dc14ef66 100644
--- a/tests/tensor/test_basic.py
+++ b/tests/tensor/test_basic.py
@@ -908,7 +908,7 @@ def test_runtime_broadcast(self, mode):
         self.check_runtime_broadcast(mode)
 
 
-def test_infer_shape():
+def test_infer_static_shape():
     with pytest.raises(TypeError, match="^Shapes must be scalar integers.*"):
         infer_static_shape([constant(1.0)])
 
@@ -925,6 +925,10 @@ def test_infer_shape():
     sh, static_shape = infer_static_shape(specify_size)
     assert static_shape == (1,)
 
+    x = scalar("x")
+    sh, static_shape = infer_static_shape([x.size])
+    assert static_shape == (1,)
+
 
 # This is slow for the ('int8', 3) version.
 def test_eye():
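
For reviewers, a minimal standalone sketch (not part of the patch) of the behavior the new `_infer_shape` logic enables: parameter batch shapes are broadcast together, then the support shape returned by `_supp_shape_from_params` is appended. The `MvSketchRV` subclass below is hypothetical, defined only for illustration; it mirrors the `MultivariateRandomVariable` used in the new test.

    from pytensor.tensor import tensor
    from pytensor.tensor.random.op import RandomVariable


    class MvSketchRV(RandomVariable):
        """Hypothetical multivariate RV, for illustration only."""

        name = "mv_sketch"
        ndim_supp = 1          # one support (core) dimension per draw
        ndims_params = (1, 2)  # a vector parameter and a matrix parameter
        dtype = "floatX"

        def _supp_shape_from_params(self, dist_params, param_shapes=None):
            # Support shape is the trailing dimension of the first parameter.
            return [dist_params[0].shape[-1]]


    mu = tensor(shape=(2, 3))          # batch shape (2,), support shape (3,)
    cov = tensor(shape=(10, 2, 3, 3))  # batch shape (10, 2), support shape (3, 3)

    # Without `size`, the batch shapes broadcast to (10, 2) and the support
    # shape (3,) is appended, all at graph-construction time:
    assert MvSketchRV()(mu, cov).type.shape == (10, 2, 3)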