Skip to content

Commit 205da7f

Browse files
committed
Fix bug in infer_static_shape of graphs involving the shape of scalars
1 parent 2cef9c0 commit 205da7f

File tree

3 files changed

+12
-6
lines changed

3 files changed

+12
-6
lines changed

pytensor/tensor/basic.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1406,11 +1406,8 @@ def infer_static_shape(
14061406
`shape` will be validated and constant folded. As a result, this function
14071407
can be expensive and shouldn't be used unless absolutely necessary.
14081408
1409-
It mostly exists as a hold-over from pre-static shape times, when it was
1410-
required in order to produce correct broadcastable arrays and prevent
1411-
some graphs from being unusable. Now, it is no longer strictly required,
1412-
so don't use it unless you want the same shape graphs to be rewritten
1413-
multiple times during graph construction.
1409+
It is often needed for `Op`s whose static shape and broadcastable flags
1410+
depend on the values of their inputs, such as `Alloc` and `RandomVariable`.
14141411
14151412
Returns
14161413
-------

pytensor/tensor/rewriting/shape.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -992,12 +992,17 @@ def local_merge_consecutive_specify_shape(fgraph, node):
992992
return [specify_shape(inner_obj, shape)]
993993

994994

995+
_empty_shape = constant([], dtype="int64")
996+
997+
995998
@register_infer_shape
996999
@node_rewriter([Shape])
9971000
def local_shape_ground(fgraph, node):
9981001
"""Rewrite shape(x) -> make_vector(x.type.shape) when this is constant."""
9991002
[x] = node.inputs
10001003
static_shape = x.type.shape
1004+
if len(static_shape) == 0:
1005+
return [_empty_shape]
10011006
if not any(dim is None for dim in static_shape):
10021007
return [stack([constant(dim, dtype="int64") for dim in static_shape])]
10031008

tests/tensor/test_basic.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -908,7 +908,7 @@ def test_runtime_broadcast(self, mode):
908908
self.check_runtime_broadcast(mode)
909909

910910

911-
def test_infer_shape():
911+
def test_infer_static_shape():
912912
with pytest.raises(TypeError, match="^Shapes must be scalar integers.*"):
913913
infer_static_shape([constant(1.0)])
914914

@@ -925,6 +925,10 @@ def test_infer_shape():
925925
sh, static_shape = infer_static_shape(specify_size)
926926
assert static_shape == (1,)
927927

928+
x = scalar("x")
929+
sh, static_shape = infer_static_shape([x.size])
930+
assert static_shape == (1,)
931+
928932

929933
# This is slow for the ('int8', 3) version.
930934
def test_eye():

0 commit comments

Comments
 (0)