Skip to content

Commit 7b86899

Browse files
committed
Faster infer_shape
1 parent 7ecb9f8 commit 7b86899

File tree

6 files changed

+73
-12
lines changed

6 files changed

+73
-12
lines changed

pytensor/tensor/basic.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,11 @@
2222
from pytensor import compile, config, printing
2323
from pytensor import scalar as aes
2424
from pytensor.gradient import DisconnectedType, grad_undefined
25+
from pytensor.graph import RewriteDatabaseQuery
2526
from pytensor.graph.basic import Apply, Constant, Variable
2627
from pytensor.graph.fg import FunctionGraph
2728
from pytensor.graph.op import Op
29+
from pytensor.graph.rewriting.db import TopoDB, EquilibriumDB
2830
from pytensor.graph.rewriting.utils import rewrite_graph
2931
from pytensor.graph.type import HasShape, Type
3032
from pytensor.link.c.op import COp
@@ -1356,6 +1358,25 @@ def identity_like(x, dtype: Optional[Union[str, np.generic, np.dtype]] = None):
13561358
return eye(_x.shape[0], _x.shape[1], k=0, dtype=dtype)
13571359

13581360

1361+
infer_shape_db = EquilibriumDB()
1362+
1363+
1364+
def register_infer_shape(rewrite, *tags, **kwargs):
1365+
if isinstance(rewrite, str):
1366+
1367+
def register(inner_lopt):
1368+
return register_infer_shape(inner_lopt, rewrite, *tags, **kwargs)
1369+
1370+
return register
1371+
else:
1372+
name = kwargs.pop("name", None) or rewrite.__name__
1373+
1374+
infer_shape_db.register(
1375+
name, rewrite, *tags, "infer_shape", **kwargs
1376+
)
1377+
return rewrite
1378+
1379+
13591380
def infer_static_shape(
13601381
shape: Union[Variable, Sequence[Union[Variable, int]]]
13611382
) -> tuple[Sequence["TensorLike"], Sequence[Optional[int]]]:
@@ -1391,13 +1412,13 @@ def check_type(s):
13911412
raise TypeError(f"Shapes must be scalar integers; got {s_as_str}")
13921413

13931414
sh = [check_type(as_tensor_variable(s, ndim=0)) for s in shape]
1394-
1395-
shape_fg = FunctionGraph(
1396-
outputs=sh,
1397-
features=[ShapeFeature()],
1398-
clone=True,
1399-
)
1400-
folded_shape = rewrite_graph(shape_fg, custom_rewrite=topo_constant_folding).outputs
1415+
shape_fg = FunctionGraph(outputs=sh, features=[ShapeFeature()], clone=True)
1416+
with config.change_flags(optdb__max_use_ratio=10, cxx=""):
1417+
query_rewrites = infer_shape_db.query(RewriteDatabaseQuery(include=("infer_shape",)))
1418+
query_rewrites.rewrite(shape_fg)
1419+
if not all(isinstance(s, Constant) for s in shape_fg.outputs):
1420+
topo_constant_folding.rewrite(shape_fg)
1421+
folded_shape = shape_fg.outputs
14011422
static_shape = tuple(
14021423
s.data.item() if isinstance(s, Constant) else None for s in folded_shape
14031424
)

pytensor/tensor/rewriting/basic.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@
6161
switch,
6262
tensor_copy,
6363
zeros,
64-
zeros_like,
64+
zeros_like, register_infer_shape,
6565
)
6666
from pytensor.tensor.elemwise import DimShuffle, Elemwise
6767
from pytensor.tensor.exceptions import NotScalarConstantError
@@ -420,6 +420,7 @@ def local_fill_to_alloc(fgraph, node):
420420
)
421421

422422

423+
@register_infer_shape
423424
@register_canonicalize("fast_compile", "shape_unsafe")
424425
@register_useless("shape_unsafe")
425426
@node_rewriter([fill])
@@ -441,6 +442,7 @@ def local_useless_fill(fgraph, node):
441442
return [v]
442443

443444

445+
@register_infer_shape
444446
@register_specialize("shape_unsafe")
445447
@register_stabilize("shape_unsafe")
446448
@register_canonicalize("shape_unsafe")
@@ -530,6 +532,7 @@ def local_alloc_empty_to_zeros(fgraph, node):
530532
)
531533

532534

535+
@register_infer_shape
533536
@register_useless
534537
@register_canonicalize("fast_compile")
535538
@register_specialize
@@ -806,6 +809,7 @@ def local_remove_all_assert(fgraph, node):
806809
)
807810

808811

812+
@register_infer_shape
809813
@register_specialize
810814
@register_canonicalize
811815
@register_useless
@@ -826,6 +830,7 @@ def local_join_1(fgraph, node):
826830

827831

828832
# TODO: merge in local_useless_join
833+
@register_infer_shape
829834
@register_useless
830835
@register_specialize
831836
@register_canonicalize
@@ -1066,6 +1071,7 @@ def local_merge_switch_same_cond(fgraph, node):
10661071
]
10671072

10681073

1074+
@register_infer_shape
10691075
@register_useless
10701076
@register_canonicalize
10711077
@register_specialize
@@ -1149,6 +1155,7 @@ def constant_folding(fgraph, node):
11491155
register_specialize(topo_constant_folding, "fast_compile", final_rewriter=True)
11501156

11511157

1158+
@register_infer_shape
11521159
@register_canonicalize("fast_compile")
11531160
@register_useless("fast_compile")
11541161
@node_rewriter(None)
@@ -1157,6 +1164,7 @@ def local_view_op(fgraph, node):
11571164
return node.inputs
11581165

11591166

1167+
@register_infer_shape
11601168
@register_useless
11611169
@register_canonicalize
11621170
@register_stabilize

pytensor/tensor/rewriting/math.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
get_underlying_scalar_constant_value,
3434
ones_like,
3535
switch,
36-
zeros_like,
36+
zeros_like, register_infer_shape,
3737
)
3838
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
3939
from pytensor.tensor.exceptions import NotScalarConstantError
@@ -1745,6 +1745,7 @@ def local_reduce_join(fgraph, node):
17451745
return [ret]
17461746

17471747

1748+
@register_infer_shape
17481749
@register_canonicalize("fast_compile", "local_cut_useless_reduce")
17491750
@register_useless("local_cut_useless_reduce")
17501751
@node_rewriter(ALL_REDUCE)

pytensor/tensor/rewriting/shape.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
copy_stack_trace,
1818
node_rewriter,
1919
)
20+
from pytensor.graph.rewriting.db import TopoDB
2021
from pytensor.graph.utils import InconsistencyError, get_variable_trace_string
2122
from pytensor.tensor.basic import (
2223
MakeVector,
@@ -25,7 +26,7 @@
2526
constant,
2627
extract_constant,
2728
get_underlying_scalar_constant_value,
28-
stack,
29+
stack, register_infer_shape,
2930
)
3031
from pytensor.tensor.elemwise import DimShuffle, Elemwise
3132
from pytensor.tensor.exceptions import NotScalarConstantError, ShapeError
@@ -964,6 +965,7 @@ def local_reshape_lift(fgraph, node):
964965
return [e]
965966

966967

968+
@register_infer_shape
967969
@register_useless
968970
@register_canonicalize
969971
@node_rewriter([SpecifyShape])
@@ -990,6 +992,25 @@ def local_merge_consecutive_specify_shape(fgraph, node):
990992
return [specify_shape(inner_obj, shape)]
991993

992994

995+
@register_infer_shape
996+
@node_rewriter([Shape])
997+
def local_shape_ground(fgraph, node):
998+
"""Rewrite shape(x) -> make_vector(x.type.shape) when this is constant."""
999+
[x] = node.inputs
1000+
static_shape = x.type.shape
1001+
if not any(dim is None for dim in static_shape):
1002+
return [stack([constant(dim, dtype="int64") for dim in static_shape])]
1003+
1004+
1005+
# Don't register in FAST_RUN, as we don't want to get rid of the shape before putting in Shape_i
1006+
pytensor.compile.mode.optdb["useless"].register(
1007+
"local_shape_ground",
1008+
local_shape_ground,
1009+
position="last",
1010+
)
1011+
1012+
1013+
@register_infer_shape
9931014
@register_useless
9941015
@register_canonicalize
9951016
@node_rewriter([Shape])
@@ -1014,6 +1035,7 @@ def local_Shape_of_SpecifyShape(fgraph, node):
10141035
return [stack(shape).astype(np.int64)]
10151036

10161037

1038+
@register_infer_shape
10171039
@register_canonicalize
10181040
@register_specialize
10191041
@node_rewriter([SpecifyShape])
@@ -1060,6 +1082,7 @@ def local_specify_shape_lift(fgraph, node):
10601082
return new_out
10611083

10621084

1085+
@register_infer_shape
10631086
@register_useless
10641087
@register_canonicalize
10651088
@node_rewriter([Shape_i])
@@ -1079,6 +1102,7 @@ def local_Shape_i_ground(fgraph, node):
10791102
return [as_tensor_variable(s_val, dtype=np.int64)]
10801103

10811104

1105+
@register_infer_shape
10821106
@register_specialize
10831107
@register_canonicalize
10841108
@node_rewriter([Shape])

pytensor/tensor/rewriting/subtensor.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
concatenate,
2727
extract_constant,
2828
get_underlying_scalar_constant_value,
29-
switch,
29+
switch, register_infer_shape,
3030
)
3131
from pytensor.tensor.elemwise import Elemwise
3232
from pytensor.tensor.exceptions import NotScalarConstantError
@@ -328,6 +328,7 @@ def local_subtensor_of_dot(fgraph, node):
328328
return [r]
329329

330330

331+
@register_infer_shape
331332
@register_useless
332333
@register_canonicalize
333334
@register_specialize
@@ -599,6 +600,7 @@ def local_subtensor_remove_broadcastable_index(fgraph, node):
599600
return [node.inputs[0].dimshuffle(tuple(remain_dim))]
600601

601602

603+
@register_infer_shape
602604
@register_useless
603605
@register_canonicalize
604606
@register_specialize
@@ -707,6 +709,7 @@ def local_subtensor_inc_subtensor(fgraph, node):
707709
return
708710

709711

712+
@register_infer_shape
710713
@register_specialize
711714
@register_canonicalize("fast_compile")
712715
@register_useless
@@ -785,6 +788,7 @@ def local_subtensor_make_vector(fgraph, node):
785788
pass
786789

787790

791+
@register_infer_shape
788792
@register_useless
789793
@register_canonicalize
790794
@register_specialize
@@ -1461,6 +1465,7 @@ def local_adv_sub1_adv_inc_sub1(fgraph, node):
14611465
return [r2]
14621466

14631467

1468+
@register_infer_shape
14641469
@register_specialize
14651470
@register_stabilize
14661471
@register_canonicalize

tests/tensor/test_blockwise.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -358,8 +358,10 @@ def test_batched_mvnormal_logp_and_dlogp(mu_batch_shape, cov_batch_shape, benchm
358358
logp = norm - 0.5 * quaddist - logdet
359359
dlogp = grad(logp.sum(), wrt=[value, mu, cov])
360360

361+
# pytensor.dprint([logp, *dlogp], print_type=True)
361362
fn = pytensor.function([value, mu, cov], [logp, *dlogp])
362-
benchmark(fn, *test_values)
363+
pytensor.dprint(fn, print_type=True)
364+
# benchmark(fn, *test_values)
363365

364366

365367
def test_cop_with_params():

0 commit comments

Comments
 (0)