From 77e597d4020876d1bae153d640535bd425403de6 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Fri, 24 Nov 2023 20:37:01 +0100
Subject: [PATCH] Faster infer_static_shape

---
 pytensor/tensor/basic.py               | 58 ++++++++++++++++++++++----
 pytensor/tensor/rewriting/basic.py     |  9 ++++
 pytensor/tensor/rewriting/math.py      |  2 +
 pytensor/tensor/rewriting/shape.py     | 16 +++++++
 pytensor/tensor/rewriting/subtensor.py |  6 +++
 pytensor/tensor/shape.py               | 29 ++++++-------
 scripts/mypy-failing.txt               |  1 -
 7 files changed, 98 insertions(+), 23 deletions(-)

diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py
index fbde5d17dd..434f8b85e7 100644
--- a/pytensor/tensor/basic.py
+++ b/pytensor/tensor/basic.py
@@ -22,10 +22,11 @@
 from pytensor import compile, config, printing
 from pytensor import scalar as aes
 from pytensor.gradient import DisconnectedType, grad_undefined
+from pytensor.graph import RewriteDatabaseQuery
 from pytensor.graph.basic import Apply, Constant, Variable
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import Op
-from pytensor.graph.rewriting.utils import rewrite_graph
+from pytensor.graph.rewriting.db import EquilibriumDB
 from pytensor.graph.type import HasShape, Type
 from pytensor.link.c.op import COp
 from pytensor.link.c.params_type import ParamsType
@@ -1356,6 +1357,45 @@ def identity_like(x, dtype: Optional[Union[str, np.generic, np.dtype]] = None):
     return eye(_x.shape[0], _x.shape[1], k=0, dtype=dtype)
 
 
+class CachedEquilibrimDB(EquilibriumDB):
+    """A subclass of EquilibriumDB that allows caching of a default query for faster reuse."""
+
+    def __init__(self, default_query):
+        super().__init__()
+        self._default_query = default_query
+        self._cached_default_query = None
+
+    def register(self, *args, **kwargs):
+        # If new rewrites are registered, the cached default query is void
+        self._cached_default_query = None
+        super().register(*args, **kwargs)
+
+    @property
+    def default_query(self):
+        if self._cached_default_query is None:
+            self._cached_default_query = self.query(self._default_query)
+        return self._cached_default_query
+
+
+infer_shape_db = CachedEquilibrimDB(
+    default_query=RewriteDatabaseQuery(include=("infer_shape",))
+)
+
+
+def register_infer_shape(rewrite, *tags, **kwargs):
+    if isinstance(rewrite, str):
+
+        def register(inner_lopt):
+            return register_infer_shape(inner_lopt, rewrite, *tags, **kwargs)
+
+        return register
+    else:
+        name = kwargs.pop("name", None) or rewrite.__name__
+
+        infer_shape_db.register(name, rewrite, *tags, "infer_shape", **kwargs)
+        return rewrite
+
+
 def infer_static_shape(
     shape: Union[Variable, Sequence[Union[Variable, int]]]
 ) -> tuple[Sequence["TensorLike"], Sequence[Optional[int]]]:
@@ -1390,14 +1430,16 @@ def check_type(s):
         raise TypeError(f"Shapes must be scalar integers; got {s_as_str}")
 
-    sh = [check_type(as_tensor_variable(s, ndim=0)) for s in shape]
+    sh = folded_shape = [check_type(as_tensor_variable(s, ndim=0)) for s in shape]
+
+    if not all(isinstance(s, Constant) for s in folded_shape):
+        shape_fg = FunctionGraph(outputs=sh, features=[ShapeFeature()], clone=True)
+        with config.change_flags(optdb__max_use_ratio=10, cxx=""):
+            infer_shape_db.default_query.rewrite(shape_fg)
+            if not all(isinstance(s, Constant) for s in shape_fg.outputs):
+                topo_constant_folding.rewrite(shape_fg)
+        folded_shape = shape_fg.outputs
 
-    shape_fg = FunctionGraph(
-        outputs=sh,
-        features=[ShapeFeature()],
-        clone=True,
-    )
-    folded_shape = rewrite_graph(shape_fg, custom_rewrite=topo_constant_folding).outputs
     static_shape = tuple(
         s.data.item() if isinstance(s, Constant) else None for s in folded_shape
     )
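The hunk above is the core of the change: `infer_static_shape` now runs only the small, cached set of "infer_shape"-tagged rewrites (with `topo_constant_folding` as a fallback) instead of a full `rewrite_graph` pass over the whole optimization database. A minimal sketch of the intended behavior; the tensor construction style and the expected printed value are assumptions about how the tagged rewrites fold this pattern, not part of the patch:

```python
import pytensor.tensor as pt
from pytensor.tensor.basic import infer_shape_db, infer_static_shape

# x has a fully known static shape, so a shape expression built from it
# should fold to constants without running the full rewrite pipeline.
x = pt.tensor(shape=(3, 4), dtype="float64")
_, static_shape = infer_static_shape((x.shape[0] + 1, x.shape[1]))
print(static_shape)  # expected: (4, 4)

# The default query is compiled once and reused across calls; registering
# another "infer_shape" rewrite voids the cache so the query is rebuilt.
assert infer_shape_db.default_query is infer_shape_db.default_query
```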
diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py
index d246cf738a..021660d8e0 100644
--- a/pytensor/tensor/rewriting/basic.py
+++ b/pytensor/tensor/rewriting/basic.py
@@ -58,6 +58,7 @@
     get_underlying_scalar_constant_value,
     join,
     ones_like,
+    register_infer_shape,
     switch,
     tensor_copy,
     zeros,
@@ -420,6 +421,7 @@ def local_fill_to_alloc(fgraph, node):
 )
 
 
+@register_infer_shape
 @register_canonicalize("fast_compile", "shape_unsafe")
 @register_useless("shape_unsafe")
 @node_rewriter([fill])
@@ -441,6 +443,7 @@ def local_useless_fill(fgraph, node):
         return [v]
 
 
+@register_infer_shape
 @register_specialize("shape_unsafe")
 @register_stabilize("shape_unsafe")
 @register_canonicalize("shape_unsafe")
@@ -530,6 +533,7 @@ def local_alloc_empty_to_zeros(fgraph, node):
 )
 
 
+@register_infer_shape
 @register_useless
 @register_canonicalize("fast_compile")
 @register_specialize
@@ -806,6 +810,7 @@ def local_remove_all_assert(fgraph, node):
 )
 
 
+@register_infer_shape
 @register_specialize
 @register_canonicalize
 @register_useless
@@ -826,6 +831,7 @@ def local_join_1(fgraph, node):
 
 
 # TODO: merge in local_useless_join
+@register_infer_shape
 @register_useless
 @register_specialize
 @register_canonicalize
@@ -1066,6 +1072,7 @@ def local_merge_switch_same_cond(fgraph, node):
     ]
 
 
+@register_infer_shape
 @register_useless
 @register_canonicalize
 @register_specialize
@@ -1149,6 +1156,7 @@ def constant_folding(fgraph, node):
 register_specialize(topo_constant_folding, "fast_compile", final_rewriter=True)
 
 
+@register_infer_shape
 @register_canonicalize("fast_compile")
 @register_useless("fast_compile")
 @node_rewriter(None)
@@ -1157,6 +1165,7 @@ def local_view_op(fgraph, node):
     return node.inputs
 
 
+@register_infer_shape
 @register_useless
 @register_canonicalize
 @register_stabilize
diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py
index 97c7138d3a..a814ffdf69 100644
--- a/pytensor/tensor/rewriting/math.py
+++ b/pytensor/tensor/rewriting/math.py
@@ -32,6 +32,7 @@
     extract_constant,
     get_underlying_scalar_constant_value,
     ones_like,
+    register_infer_shape,
     switch,
     zeros_like,
 )
@@ -1745,6 +1746,7 @@ def local_reduce_join(fgraph, node):
     return [ret]
 
 
+@register_infer_shape
 @register_canonicalize("fast_compile", "local_cut_useless_reduce")
 @register_useless("local_cut_useless_reduce")
 @node_rewriter(ALL_REDUCE)
diff --git a/pytensor/tensor/rewriting/shape.py b/pytensor/tensor/rewriting/shape.py
index e0d4963388..aa24f217bd 100644
--- a/pytensor/tensor/rewriting/shape.py
+++ b/pytensor/tensor/rewriting/shape.py
@@ -25,6 +25,7 @@
     constant,
     extract_constant,
     get_underlying_scalar_constant_value,
+    register_infer_shape,
     stack,
 )
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
@@ -964,6 +965,7 @@ def local_reshape_lift(fgraph, node):
         return [e]
 
 
+@register_infer_shape
 @register_useless
 @register_canonicalize
 @node_rewriter([SpecifyShape])
@@ -990,6 +992,17 @@ def local_merge_consecutive_specify_shape(fgraph, node):
     return [specify_shape(inner_obj, shape)]
 
 
+@register_infer_shape
+@node_rewriter([Shape])
+def local_shape_ground(fgraph, node):
+    """Rewrite shape(x) -> make_vector(x.type.shape) when the static shape is fully known."""
+    [x] = node.inputs
+    static_shape = x.type.shape
+    if not any(dim is None for dim in static_shape):
+        return [stack([constant(dim, dtype="int64") for dim in static_shape])]
+
+
+@register_infer_shape
 @register_useless
 @register_canonicalize
 @node_rewriter([Shape])
@@ -1014,6 +1027,7 @@ def local_Shape_of_SpecifyShape(fgraph, node):
     return [stack(shape).astype(np.int64)]
 
 
+@register_infer_shape
 @register_canonicalize
 @register_specialize
 @node_rewriter([SpecifyShape])
@@ -1060,6 +1074,7 @@ def local_specify_shape_lift(fgraph, node):
     return new_out
 
 
+@register_infer_shape
 @register_useless
 @register_canonicalize
 @node_rewriter([Shape_i])
@@ -1079,6 +1094,7 @@ def local_Shape_i_ground(fgraph, node):
         return [as_tensor_variable(s_val, dtype=np.int64)]
 
 
+@register_infer_shape
 @register_specialize
 @register_canonicalize
 @node_rewriter([Shape])
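`local_shape_ground` is the only genuinely new rewrite in the patch; every other hunk in these files just tags an existing rewrite into the new `infer_shape_db`. A rough sketch of what it does in isolation, assuming the rewrite fires as registered (the printed graph is an illustrative expectation, not output from the patch):

```python
import pytensor
import pytensor.tensor as pt
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor.basic import infer_shape_db

# shape(x) for a fully static x should be replaced by a constant vector
# that no longer references x at all.
x = pt.tensor(shape=(2, 3), dtype="float64")
fg = FunctionGraph(outputs=[x.shape], clone=True)
infer_shape_db.default_query.rewrite(fg)
pytensor.dprint(fg.outputs[0])  # expected: MakeVector{dtype='int64'}(2, 3)
```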
diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py
index a84767613e..4e80d3bb30 100644
--- a/pytensor/tensor/rewriting/subtensor.py
+++ b/pytensor/tensor/rewriting/subtensor.py
@@ -26,6 +26,7 @@
     concatenate,
     extract_constant,
     get_underlying_scalar_constant_value,
+    register_infer_shape,
     switch,
 )
 from pytensor.tensor.elemwise import Elemwise
@@ -328,6 +329,7 @@ def local_subtensor_of_dot(fgraph, node):
     return [r]
 
 
+@register_infer_shape
 @register_useless
 @register_canonicalize
 @register_specialize
@@ -599,6 +601,7 @@ def local_subtensor_remove_broadcastable_index(fgraph, node):
     return [node.inputs[0].dimshuffle(tuple(remain_dim))]
 
 
+@register_infer_shape
 @register_useless
 @register_canonicalize
 @register_specialize
@@ -707,6 +710,7 @@ def local_subtensor_inc_subtensor(fgraph, node):
         return
 
 
+@register_infer_shape
 @register_specialize
 @register_canonicalize("fast_compile")
 @register_useless
@@ -785,6 +789,7 @@ def local_subtensor_make_vector(fgraph, node):
         pass
 
 
+@register_infer_shape
 @register_useless
 @register_canonicalize
 @register_specialize
@@ -1461,6 +1466,7 @@ def local_adv_sub1_adv_inc_sub1(fgraph, node):
     return [r2]
 
 
+@register_infer_shape
 @register_specialize
 @register_stabilize
 @register_canonicalize
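The subtensor rewrites matter here because shape graphs constantly index into shape vectors: patterns like `Shape(x)[i]` or `MakeVector(...)[i]` appear as soon as a user writes `x.shape[i]`. A hedged example of the combined effect, assuming the tagged rewrites and the constant-folding fallback fire as described above:

```python
import pytensor.tensor as pt
from pytensor.tensor.basic import infer_static_shape

# Only the second dimension is static; indexing into the shape vector and
# the arithmetic on top of it should still fold where possible.
x = pt.tensor(shape=(None, 3), dtype="float64")
_, static_shape = infer_static_shape((x.shape[0], x.shape[1] * 2))
print(static_shape)  # expected: (None, 6)
```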
diff --git a/pytensor/tensor/shape.py b/pytensor/tensor/shape.py
index 969fce038a..1d8efa02c5 100644
--- a/pytensor/tensor/shape.py
+++ b/pytensor/tensor/shape.py
@@ -1,12 +1,13 @@
 import warnings
 from numbers import Number
 from textwrap import dedent
-from typing import Union
+from typing import Union, cast
 
 import numpy as np
 
 import pytensor
 from pytensor.gradient import DisconnectedType
+from pytensor.graph import Op
 from pytensor.graph.basic import Apply, Variable
 from pytensor.graph.replace import _vectorize_node
 from pytensor.graph.type import HasShape
@@ -145,14 +146,14 @@ def c_code_cache_version(self):
 def shape(x: Union[np.ndarray, Number, Variable]) -> Variable:
     """Return the shape of `x`."""
     if not isinstance(x, Variable):
-        x = at.as_tensor_variable(x)
+        x = at.as_tensor_variable(x)  # type: ignore
 
-    return _shape(x)
+    return cast(Variable, _shape(x))
 
 
-@_get_vector_length.register(Shape)
-def _get_vector_length_Shape(op, var):
-    return var.owner.inputs[0].type.ndim
+@_get_vector_length.register(Shape)  # type: ignore
+def _get_vector_length_Shape(op: Op, var: TensorVariable) -> int:
+    return cast(int, var.owner.inputs[0].type.ndim)
 
 
 @_vectorize_node.register(Shape)
@@ -181,7 +182,7 @@ def shape_tuple(x: TensorVariable) -> tuple[Variable, ...]:
         # We assume/call it a scalar
         return ()
 
-    res = ()
+    res: tuple[Variable, ...] = ()
     symbolic_shape = shape(x)
     static_shape = x.type.shape
     for i in range(x.type.ndim):
@@ -191,7 +192,7 @@ def shape_tuple(x: TensorVariable) -> tuple[Variable, ...]:
             # TODO: Why not use uint64?
             res += (pytensor.scalar.ScalarConstant(pytensor.scalar.int64, shape_val),)
         else:
-            res += (symbolic_shape[i],)
+            res += (symbolic_shape[i],)  # type: ignore
 
     return res
 
@@ -366,7 +367,7 @@ def shape_i_op(i):
     return shape_i_op.cache[key]
 
 
-shape_i_op.cache = {}
+shape_i_op.cache = {}  # type: ignore
 
 
 def register_shape_i_c_code(typ, code, check_input, version=()):
@@ -578,7 +579,7 @@ def specify_shape(
     # If the specified shape is already encoded in the input static shape, do nothing
     # This ignores PyTensor constants in shape
-    x = at.as_tensor_variable(x)
+    x = at.as_tensor_variable(x)  # type: ignore
     new_shape_info = any(
         s != xts for (s, xts) in zip(shape, x.type.shape) if s is not None
     )
 
@@ -589,10 +590,10 @@ def specify_shape(
     return _specify_shape(x, *shape)
 
 
-@_get_vector_length.register(SpecifyShape)
-def _get_vector_length_SpecifyShape(op, var):
+@_get_vector_length.register(SpecifyShape)  # type: ignore
+def _get_vector_length_SpecifyShape(op: Op, var: TensorVariable) -> int:
     try:
-        return at.get_underlying_scalar_constant_value(var.owner.inputs[1]).item()
+        return int(at.get_underlying_scalar_constant_value(var.owner.inputs[1]).item())
     except NotScalarConstantError:
         raise ValueError(f"Length of {var} cannot be determined")
 
@@ -1104,4 +1105,4 @@ def _vectorize_unbroadcast(op: Unbroadcast, node: Apply, x: TensorVariable) -> A
     batched_ndims = x.type.ndim - node.inputs[0].type.ndim
     old_axes = op.axes
     new_axes = (old_axis + batched_ndims for old_axis in old_axes)
-    return unbroadcast(x, *new_axes).owner
+    return cast(Apply, unbroadcast(x, *new_axes).owner)
diff --git a/scripts/mypy-failing.txt b/scripts/mypy-failing.txt
index 1cae4d9152..4b32536bec 100644
--- a/scripts/mypy-failing.txt
+++ b/scripts/mypy-failing.txt
@@ -27,7 +27,6 @@ pytensor/tensor/random/basic.py
 pytensor/tensor/random/op.py
 pytensor/tensor/random/utils.py
 pytensor/tensor/rewriting/basic.py
-pytensor/tensor/shape.py
 pytensor/tensor/slinalg.py
 pytensor/tensor/subtensor.py
 pytensor/tensor/type.py
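The `pytensor/tensor/shape.py` hunks are typing fixes (casts, `# type: ignore` markers, annotated dispatch functions) whose payoff is the final hunk: the file drops out of `scripts/mypy-failing.txt`. One behavior the new annotations pin down, sketched under the assumption that `get_vector_length` resolves the length through the `SpecifyShape` dispatch shown above:

```python
import pytensor.tensor as pt
from pytensor.tensor import get_vector_length, specify_shape

# _get_vector_length_SpecifyShape now returns a plain int extracted from
# the constant length argument of SpecifyShape.
v = pt.vector("v")
assert get_vector_length(specify_shape(v, (3,))) == 3
```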