Fix bugs in shape inference #546

Merged 2 commits on Dec 11, 2023
7 changes: 2 additions & 5 deletions pytensor/tensor/basic.py
@@ -1406,11 +1406,8 @@ def infer_static_shape(
     `shape` will be validated and constant folded. As a result, this function
     can be expensive and shouldn't be used unless absolutely necessary.
 
-    It mostly exists as a hold-over from pre-static shape times, when it was
-    required in order to produce correct broadcastable arrays and prevent
-    some graphs from being unusable. Now, it is no longer strictly required,
-    so don't use it unless you want the same shape graphs to be rewritten
-    multiple times during graph construction.
+    It is often needed for `Op`s whose static shape and broadcastable flags
+    depend on the values of their inputs, such as `Alloc` and `RandomVariable`.
 
     Returns
     -------
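For orientation, here is a minimal sketch of the pair `infer_static_shape` returns, mirroring the tests touched further down in this PR. The constant inputs are illustrative assumptions, not part of the diff:

```python
import pytensor.tensor as pt
from pytensor.tensor.basic import infer_static_shape

# Shape entries must be scalar integers (the tests below pass a list of
# them); constant entries fold to plain Python ints in the static shape.
sh, static_shape = infer_static_shape((pt.constant(2), pt.constant(3)))
assert static_shape == (2, 3)
```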
44 changes: 16 additions & 28 deletions pytensor/tensor/random/op.py
@@ -20,11 +20,7 @@
     infer_static_shape,
 )
 from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType, RandomType
-from pytensor.tensor.random.utils import (
-    broadcast_params,
-    normalize_size_param,
-    params_broadcast_shapes,
-)
+from pytensor.tensor.random.utils import broadcast_params, normalize_size_param
 from pytensor.tensor.shape import shape_tuple
 from pytensor.tensor.type import TensorType, all_dtypes
 from pytensor.tensor.type_other import NoneConst
@@ -156,6 +152,13 @@ def _infer_shape(

         from pytensor.tensor.extra_ops import broadcast_shape_iter
 
+        if self.ndim_supp == 0:
+            supp_shape = ()
+        else:
+            supp_shape = tuple(
+                self._supp_shape_from_params(dist_params, param_shapes=param_shapes)
+            )
+
         size_len = get_vector_length(size)
 
         if size_len > 0:
@@ -171,30 +174,22 @@
f"Size length must be 0 or >= {param_batched_dims}"
)

if self.ndim_supp == 0:
return size
else:
supp_shape = self._supp_shape_from_params(
dist_params, param_shapes=param_shapes
)
return tuple(size) + tuple(supp_shape)

# Broadcast the parameters
param_shapes = params_broadcast_shapes(
param_shapes or [shape_tuple(p) for p in dist_params],
self.ndims_params,
)
return tuple(size) + supp_shape

# Size was not provided, we must infer it from the shape of the parameters
if param_shapes is None:
param_shapes = [shape_tuple(p) for p in dist_params]

def extract_batch_shape(p, ps, n):
shape = tuple(ps)

if n == 0:
return shape

batch_shape = [
batch_shape = tuple(
s if not b else constant(1, "int64")
for s, b in zip(shape[:-n], p.type.broadcastable[:-n])
]
)
return batch_shape

# These are versions of our actual parameters with the anticipated
@@ -218,15 +213,8 @@ def extract_batch_shape(p, ps, n):
             # Distribution has no parameters
             batch_shape = ()
 
-        if self.ndim_supp == 0:
-            supp_shape = ()
-        else:
-            supp_shape = self._supp_shape_from_params(
-                dist_params,
-                param_shapes=param_shapes,
-            )
-        shape = batch_shape + supp_shape
-
+        shape = tuple(batch_shape) + tuple(supp_shape)
         if not shape:
             shape = constant([], dtype="int64")
 
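In plain numbers, the refactored `_infer_shape` splits each parameter's shape into leading batch dims and trailing core dims (the last `ndims_params[i]` axes), broadcasts the batch parts together, and appends the support shape. A minimal NumPy sketch of that rule; `infer_rv_shape` is a hypothetical helper for illustration, not part of the PR:

```python
import numpy as np

# Hypothetical illustration of the batch/support split: each parameter's
# trailing `ndims_params[i]` axes are core dims; the leading axes are
# batch dims that broadcast together, and the support shape is appended.
def infer_rv_shape(param_shapes, ndims_params, supp_shape):
    batch_shapes = [
        tuple(ps)[: len(ps) - n] if n else tuple(ps)
        for ps, n in zip(param_shapes, ndims_params)
    ]
    return np.broadcast_shapes(*batch_shapes) + tuple(supp_shape)

# Mirrors the multivariate test below: params (10, 1, 3) and (2, 3, 3)
# with ndims_params (1, 2) give batch (10, 2) plus support (3,).
assert infer_rv_shape([(10, 1, 3), (2, 3, 3)], (1, 2), (3,)) == (10, 2, 3)
```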
5 changes: 5 additions & 0 deletions pytensor/tensor/rewriting/shape.py
@@ -992,12 +992,17 @@ def local_merge_consecutive_specify_shape(fgraph, node):
     return [specify_shape(inner_obj, shape)]
 
 
+_empty_shape = constant([], dtype="int64")
+
+
 @register_infer_shape
 @node_rewriter([Shape])
 def local_shape_ground(fgraph, node):
     """Rewrite shape(x) -> make_vector(x.type.shape) when this is constant."""
     [x] = node.inputs
     static_shape = x.type.shape
+    if len(static_shape) == 0:
+        return [_empty_shape]
     if not any(dim is None for dim in static_shape):
         return [stack([constant(dim, dtype="int64") for dim in static_shape])]
 
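The new early return matters for 0-d inputs, where there are no entries to `stack`. A minimal sketch of the user-visible behavior, under the assumption that the rewrite fires during function compilation:

```python
import pytensor
import pytensor.tensor as pt

x = pt.scalar("x")
# For a 0-d variable, shape(x) is grounded to an empty int64 constant
# rather than stacking zero constant entries.
f = pytensor.function([x], x.shape)
assert tuple(f(0.0)) == ()
```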
36 changes: 36 additions & 0 deletions tests/tensor/random/test_op.py
@@ -206,6 +206,42 @@ def test_RandomVariable_incompatible_size():
     rv_op(np.zeros((2, 4, 3)), 1, size=(4,))
 
 
+class MultivariateRandomVariable(RandomVariable):
+    name = "MultivariateRandomVariable"
+    ndim_supp = 1
+    ndims_params = (1, 2)
+    dtype = "floatX"
+
+    def _supp_shape_from_params(self, dist_params, param_shapes=None):
+        return [dist_params[0].shape[-1]]
+
+
+@config.change_flags(compute_test_value="off")
+def test_multivariate_rv_infer_static_shape():
+    """Test that infer shape for multivariate random variable works when a parameter must be broadcasted."""
+    mv_op = MultivariateRandomVariable()
+
+    param1 = tensor(shape=(10, 2, 3))
+    param2 = tensor(shape=(10, 2, 3, 3))
+    assert mv_op(param1, param2).type.shape == (10, 2, 3)
+
+    param1 = tensor(shape=(2, 3))
+    param2 = tensor(shape=(10, 2, 3, 3))
+    assert mv_op(param1, param2).type.shape == (10, 2, 3)
+
+    param1 = tensor(shape=(10, 2, 3))
+    param2 = tensor(shape=(2, 3, 3))
+    assert mv_op(param1, param2).type.shape == (10, 2, 3)
+
+    param1 = tensor(shape=(10, 1, 3))
+    param2 = tensor(shape=(2, 3, 3))
+    assert mv_op(param1, param2).type.shape == (10, 2, 3)
+
+    param1 = tensor(shape=(2, 3))
+    param2 = tensor(shape=(2, 3, 3))
+    assert mv_op(param1, param2, size=(10, 2)).type.shape == (10, 2, 3)
+
+
 def test_vectorize_node():
     vec = tensor(shape=(None,))
     vec.tag.test_value = [0, 0, 0]
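The `(10, 1, 3)` vs `(2, 3, 3)` case is the interesting one: the broadcastable middle dim of `param1` must be expanded against `param2`'s batch dim. A minimal sketch reusing the test's `MultivariateRandomVariable`, assuming the same imports as the test module; no `rng_fn` is needed just to build the node and inspect its static type:

```python
import pytensor.tensor as pt

mv_op = MultivariateRandomVariable()
mean = pt.tensor(shape=(10, 1, 3))  # batch (10, 1), core (3,)
cov = pt.tensor(shape=(2, 3, 3))    # batch (2,),   core (3, 3)
# Broadcastable batch dims are pinned to constant 1, so (10, 1) and
# (2,) broadcast statically to (10, 2); support shape (3,) is appended.
assert mv_op(mean, cov).type.shape == (10, 2, 3)
```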
6 changes: 5 additions & 1 deletion tests/tensor/test_basic.py
@@ -908,7 +908,7 @@ def test_runtime_broadcast(self, mode):
         self.check_runtime_broadcast(mode)
 
 
-def test_infer_shape():
+def test_infer_static_shape():
     with pytest.raises(TypeError, match="^Shapes must be scalar integers.*"):
         infer_static_shape([constant(1.0)])
 
@@ -925,6 +925,10 @@ def test_infer_shape():
     sh, static_shape = infer_static_shape(specify_size)
     assert static_shape == (1,)
 
+    x = scalar("x")
+    sh, static_shape = infer_static_shape([x.size])
+    assert static_shape == (1,)
+
 
 # This is slow for the ('int8', 3) version.
 def test_eye():