Skip to content

Commit e3d2750

Browse files
committed
Remove useless SpecifyShape
1 parent f4e249d commit e3d2750

File tree

2 files changed

+56
-5
lines changed

2 files changed

+56
-5
lines changed

pytensor/tensor/rewriting/shape.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
)
4949
from pytensor.tensor.subtensor import Subtensor, get_idx_list
5050
from pytensor.tensor.type import TensorType, discrete_dtypes, integer_dtypes
51-
from pytensor.tensor.type_other import NoneConst
51+
from pytensor.tensor.type_other import NoneConst, NoneTypeT
5252

5353

5454
class ShapeFeature(Feature):
@@ -974,6 +974,35 @@ def local_reshape_lift(fgraph, node):
974974
return [e]
975975

976976

977+
@register_useless
978+
@register_canonicalize
979+
@register_stabilize
980+
@register_specialize
981+
@node_rewriter([SpecifyShape])
982+
def local_useless_specify_shape(fgraph, node):
983+
"""Remove SpecifyShape when the asserted shapes are already encoded in the static type of the input."""
984+
x, *shape = node.inputs
985+
for static_dim, specified_dim in zip(x.type.shape, shape, strict=True):
986+
if isinstance(specified_dim.type, NoneTypeT):
987+
continue
988+
if static_dim is None:
989+
# There is an unknown static dimension that is being specified
990+
return None
991+
if not (
992+
isinstance(specified_dim, Constant) and specified_dim.data == static_dim
993+
):
994+
# The specified dim is either:
995+
# 1. Not constant or
996+
# 2. Constant that does not match the static dim
997+
# Either way, we must keep the SpecifyShape
998+
return None
999+
1000+
# If we arrived here, it means SpecifyShape was already encoded in the static shape
1001+
# We don't need it
1002+
copy_stack_trace(node.outputs[0], x)
1003+
return [x]
1004+
1005+
9771006
@register_infer_shape
9781007
@register_useless
9791008
@register_canonicalize
@@ -1189,10 +1218,7 @@ def local_useless_dimshuffle_in_reshape(fgraph, node):
11891218
@register_specialize
11901219
@node_rewriter([Unbroadcast])
11911220
def local_useless_unbroadcast(fgraph, node):
1192-
"""Remove `Unbroadcast` if it does not actually change the broadcasting pattern.
1193-
1194-
TODO: Implement equivalent rewrite for SpecifyShape
1195-
"""
1221+
"""Remove `Unbroadcast` if it does not actually change the broadcasting pattern."""
11961222
if isinstance(node.op, Unbroadcast):
11971223
x = node.inputs[0]
11981224
if x.type.ndim == node.outputs[0].type.ndim and all(

tests/tensor/rewriting/test_shape.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
ShapeFeature,
2424
local_reshape_to_dimshuffle,
2525
local_useless_reshape,
26+
local_useless_specify_shape,
2627
)
2728
from pytensor.tensor.shape import (
2829
Reshape,
@@ -476,6 +477,30 @@ def test_vector_dim_err(self):
476477
shape_feature.same_shape(x, o, 0, 1)
477478

478479

480+
def test_useless_specify_shape():
481+
x = tensor("x", shape=(None, 5, 3))
482+
483+
# We avoid the helper specify_shape that optimizes some (but not all) cases eagerly
484+
ss = SpecifyShape()
485+
486+
out = ss(x, None, 5, None)
487+
assert isinstance(out.owner.op, SpecifyShape)
488+
ret = local_useless_specify_shape.transform(None, out.owner)
489+
assert ret == [x]
490+
491+
# SpecifyShape is needed to enfore unknown dim is 3
492+
out = ss(x, 3, 5, None)
493+
assert isinstance(out.owner.op, SpecifyShape)
494+
ret = local_useless_specify_shape.transform(None, out.owner)
495+
assert ret is None
496+
497+
# SpecifyShape is needed to raise mismatch between static and specified dim
498+
out = ss(x, None, 5, 4)
499+
assert isinstance(out.owner.op, SpecifyShape)
500+
ret = local_useless_specify_shape.transform(None, out.owner)
501+
assert ret is None
502+
503+
479504
@pytest.mark.parametrize(
480505
"shape",
481506
[lscalar(), iscalar()],

0 commit comments

Comments
 (0)