|
48 | 48 | )
|
49 | 49 | from pytensor.tensor.subtensor import Subtensor, get_idx_list
|
50 | 50 | from pytensor.tensor.type import TensorType, discrete_dtypes, integer_dtypes
|
51 |
| -from pytensor.tensor.type_other import NoneConst |
| 51 | +from pytensor.tensor.type_other import NoneConst, NoneTypeT |
52 | 52 |
|
53 | 53 |
|
54 | 54 | class ShapeFeature(Feature):
|
@@ -974,6 +974,35 @@ def local_reshape_lift(fgraph, node):
|
974 | 974 | return [e]
|
975 | 975 |
|
976 | 976 |
|
| 977 | +@register_useless |
| 978 | +@register_canonicalize |
| 979 | +@register_stabilize |
| 980 | +@register_specialize |
| 981 | +@node_rewriter([SpecifyShape]) |
| 982 | +def local_useless_specify_shape(fgraph, node): |
| 983 | + """Remove SpecifyShape when the asserted shapes are already encoded in the static type of the input.""" |
| 984 | + x, *shape = node.inputs |
| 985 | + for static_dim, specified_dim in zip(x.type.shape, shape, strict=True): |
| 986 | + if isinstance(specified_dim.type, NoneTypeT): |
| 987 | + continue |
| 988 | + if static_dim is None: |
| 989 | + # There is an unknown static dimension that is being specified |
| 990 | + return None |
| 991 | + if not ( |
| 992 | + isinstance(specified_dim, Constant) and specified_dim.data == static_dim |
| 993 | + ): |
| 994 | + # The specified dim is either: |
| 995 | + # 1. Not constant or |
| 996 | + # 2. Constant that does not match the static dim |
| 997 | + # Either way, we must keep the SpecifyShape |
| 998 | + return None |
| 999 | + |
| 1000 | + # If we arrived here, it means SpecifyShape was already encoded in the static shape |
| 1001 | + # We don't need it |
| 1002 | + copy_stack_trace(node.outputs[0], x) |
| 1003 | + return [x] |
| 1004 | + |
| 1005 | + |
977 | 1006 | @register_infer_shape
|
978 | 1007 | @register_useless
|
979 | 1008 | @register_canonicalize
|
@@ -1189,10 +1218,7 @@ def local_useless_dimshuffle_in_reshape(fgraph, node):
|
1189 | 1218 | @register_specialize
|
1190 | 1219 | @node_rewriter([Unbroadcast])
|
1191 | 1220 | def local_useless_unbroadcast(fgraph, node):
|
1192 |
| - """Remove `Unbroadcast` if it does not actually change the broadcasting pattern. |
1193 |
| -
|
1194 |
| - TODO: Implement equivalent rewrite for SpecifyShape |
1195 |
| - """ |
| 1221 | + """Remove `Unbroadcast` if it does not actually change the broadcasting pattern.""" |
1196 | 1222 | if isinstance(node.op, Unbroadcast):
|
1197 | 1223 | x = node.inputs[0]
|
1198 | 1224 | if x.type.ndim == node.outputs[0].type.ndim and all(
|
|
0 commit comments