From 9146163359b926d533b8d600fad2b8f2b4868556 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 11 Dec 2023 16:33:53 +0100 Subject: [PATCH 1/2] Fix bug in infer_static_shape of graphs involving the shape of scalars --- pytensor/tensor/basic.py | 7 ++----- pytensor/tensor/rewriting/shape.py | 5 +++++ tests/tensor/test_basic.py | 6 +++++- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 946660e431..9ab328266a 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -1406,11 +1406,8 @@ def infer_static_shape( `shape` will be validated and constant folded. As a result, this function can be expensive and shouldn't be used unless absolutely necessary. - It mostly exists as a hold-over from pre-static shape times, when it was - required in order to produce correct broadcastable arrays and prevent - some graphs from being unusable. Now, it is no longer strictly required, - so don't use it unless you want the same shape graphs to be rewritten - multiple times during graph construction. + It is often needed for `Op`s whose static shape and broadcastable flags + depend on the values of their inputs, such as `Alloc` and `RandomVariable`. Returns ------- diff --git a/pytensor/tensor/rewriting/shape.py b/pytensor/tensor/rewriting/shape.py index aa24f217bd..39c149ad87 100644 --- a/pytensor/tensor/rewriting/shape.py +++ b/pytensor/tensor/rewriting/shape.py @@ -992,12 +992,17 @@ def local_merge_consecutive_specify_shape(fgraph, node): return [specify_shape(inner_obj, shape)] +_empty_shape = constant([], dtype="int64") + + @register_infer_shape @node_rewriter([Shape]) def local_shape_ground(fgraph, node): """Rewrite shape(x) -> make_vector(x.type.shape) when this is constant.""" [x] = node.inputs static_shape = x.type.shape + if len(static_shape) == 0: + return [_empty_shape] if not any(dim is None for dim in static_shape): return [stack([constant(dim, dtype="int64") for dim in static_shape])] diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index 3ce5ffce63..81dc14ef66 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -908,7 +908,7 @@ def test_runtime_broadcast(self, mode): self.check_runtime_broadcast(mode) -def test_infer_shape(): +def test_infer_static_shape(): with pytest.raises(TypeError, match="^Shapes must be scalar integers.*"): infer_static_shape([constant(1.0)]) @@ -925,6 +925,10 @@ def test_infer_shape(): sh, static_shape = infer_static_shape(specify_size) assert static_shape == (1,) + x = scalar("x") + sh, static_shape = infer_static_shape([x.size]) + assert static_shape == (1,) + # This is slow for the ('int8', 3) version. def test_eye(): From 4dbf2e2cfff2ca5d3e02ffe2f566826ca650e29a Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 11 Dec 2023 17:28:13 +0100 Subject: [PATCH 2/2] Fix shape inference for multivariate random Ops When size is not provided, the batch shapes of the parameters were being broadcasted twice, and the second time, wrongly, due to mixing static shape of the original parameters and the potentially larger shape of the just broadcasted parameters. --- pytensor/tensor/random/op.py | 44 +++++++++++++--------------------- tests/tensor/random/test_op.py | 36 ++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 28 deletions(-) diff --git a/pytensor/tensor/random/op.py b/pytensor/tensor/random/op.py index f6cb496d0a..e1922143b0 100644 --- a/pytensor/tensor/random/op.py +++ b/pytensor/tensor/random/op.py @@ -20,11 +20,7 @@ infer_static_shape, ) from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType, RandomType -from pytensor.tensor.random.utils import ( - broadcast_params, - normalize_size_param, - params_broadcast_shapes, -) +from pytensor.tensor.random.utils import broadcast_params, normalize_size_param from pytensor.tensor.shape import shape_tuple from pytensor.tensor.type import TensorType, all_dtypes from pytensor.tensor.type_other import NoneConst @@ -156,6 +152,13 @@ def _infer_shape( from pytensor.tensor.extra_ops import broadcast_shape_iter + if self.ndim_supp == 0: + supp_shape = () + else: + supp_shape = tuple( + self._supp_shape_from_params(dist_params, param_shapes=param_shapes) + ) + size_len = get_vector_length(size) if size_len > 0: @@ -171,19 +174,11 @@ def _infer_shape( f"Size length must be 0 or >= {param_batched_dims}" ) - if self.ndim_supp == 0: - return size - else: - supp_shape = self._supp_shape_from_params( - dist_params, param_shapes=param_shapes - ) - return tuple(size) + tuple(supp_shape) - - # Broadcast the parameters - param_shapes = params_broadcast_shapes( - param_shapes or [shape_tuple(p) for p in dist_params], - self.ndims_params, - ) + return tuple(size) + supp_shape + + # Size was not provided, we must infer it from the shape of the parameters + if param_shapes is None: + param_shapes = [shape_tuple(p) for p in dist_params] def extract_batch_shape(p, ps, n): shape = tuple(ps) @@ -191,10 +186,10 @@ def extract_batch_shape(p, ps, n): if n == 0: return shape - batch_shape = [ + batch_shape = tuple( s if not b else constant(1, "int64") for s, b in zip(shape[:-n], p.type.broadcastable[:-n]) - ] + ) return batch_shape # These are versions of our actual parameters with the anticipated @@ -218,15 +213,8 @@ def extract_batch_shape(p, ps, n): # Distribution has no parameters batch_shape = () - if self.ndim_supp == 0: - supp_shape = () - else: - supp_shape = self._supp_shape_from_params( - dist_params, - param_shapes=param_shapes, - ) + shape = batch_shape + supp_shape - shape = tuple(batch_shape) + tuple(supp_shape) if not shape: shape = constant([], dtype="int64") diff --git a/tests/tensor/random/test_op.py b/tests/tensor/random/test_op.py index 4a389811e1..63661bd177 100644 --- a/tests/tensor/random/test_op.py +++ b/tests/tensor/random/test_op.py @@ -206,6 +206,42 @@ def test_RandomVariable_incompatible_size(): rv_op(np.zeros((2, 4, 3)), 1, size=(4,)) +class MultivariateRandomVariable(RandomVariable): + name = "MultivariateRandomVariable" + ndim_supp = 1 + ndims_params = (1, 2) + dtype = "floatX" + + def _supp_shape_from_params(self, dist_params, param_shapes=None): + return [dist_params[0].shape[-1]] + + +@config.change_flags(compute_test_value="off") +def test_multivariate_rv_infer_static_shape(): + """Test that infer shape for multivariate random variable works when a parameter must be broadcasted.""" + mv_op = MultivariateRandomVariable() + + param1 = tensor(shape=(10, 2, 3)) + param2 = tensor(shape=(10, 2, 3, 3)) + assert mv_op(param1, param2).type.shape == (10, 2, 3) + + param1 = tensor(shape=(2, 3)) + param2 = tensor(shape=(10, 2, 3, 3)) + assert mv_op(param1, param2).type.shape == (10, 2, 3) + + param1 = tensor(shape=(10, 2, 3)) + param2 = tensor(shape=(2, 3, 3)) + assert mv_op(param1, param2).type.shape == (10, 2, 3) + + param1 = tensor(shape=(10, 1, 3)) + param2 = tensor(shape=(2, 3, 3)) + assert mv_op(param1, param2).type.shape == (10, 2, 3) + + param1 = tensor(shape=(2, 3)) + param2 = tensor(shape=(2, 3, 3)) + assert mv_op(param1, param2, size=(10, 2)).type.shape == (10, 2, 3) + + def test_vectorize_node(): vec = tensor(shape=(None,)) vec.tag.test_value = [0, 0, 0]