
Faster infer_static_shape #521

Merged
merged 1 commit on Nov 27, 2023

58 changes: 50 additions & 8 deletions pytensor/tensor/basic.py
@@ -22,10 +22,11 @@
 from pytensor import compile, config, printing
 from pytensor import scalar as aes
 from pytensor.gradient import DisconnectedType, grad_undefined
+from pytensor.graph import RewriteDatabaseQuery
 from pytensor.graph.basic import Apply, Constant, Variable
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import Op
-from pytensor.graph.rewriting.utils import rewrite_graph
+from pytensor.graph.rewriting.db import EquilibriumDB
 from pytensor.graph.type import HasShape, Type
 from pytensor.link.c.op import COp
 from pytensor.link.c.params_type import ParamsType
@@ -1356,6 +1357,45 @@ def identity_like(x, dtype: Optional[Union[str, np.generic, np.dtype]] = None):
     return eye(_x.shape[0], _x.shape[1], k=0, dtype=dtype)
 
 
+class CachedEquilibriumDB(EquilibriumDB):
+    """A subclass of EquilibriumDB that allows caching of a default query for faster reuse."""
+
+    def __init__(self, default_query):
+        super().__init__()
+        self._default_query = default_query
+        self._cached_default_query = None
+
+    def register(self, *args, **kwargs):
+        # If new rewrites are registered, the cached default query is invalidated
+        self._cached_default_query = None
+        super().register(*args, **kwargs)
+
+    @property
+    def default_query(self):
+        if self._cached_default_query is None:
+            self._cached_default_query = self.query(self._default_query)
+        return self._cached_default_query
+
+
+infer_shape_db = CachedEquilibriumDB(
+    default_query=RewriteDatabaseQuery(include=("infer_shape",))
+)
+
+
+def register_infer_shape(rewrite, *tags, **kwargs):
+    if isinstance(rewrite, str):
+
+        def register(inner_lopt):
+            return register_infer_shape(inner_lopt, rewrite, *tags, **kwargs)
+
+        return register
+    else:
+        name = kwargs.pop("name", None) or rewrite.__name__
+
+        infer_shape_db.register(name, rewrite, *tags, "infer_shape", **kwargs)
+        return rewrite


 def infer_static_shape(
     shape: Union[Variable, Sequence[Union[Variable, int]]]
 ) -> tuple[Sequence["TensorLike"], Sequence[Optional[int]]]:
Expand Down Expand Up @@ -1390,14 +1430,16 @@ def check_type(s):

raise TypeError(f"Shapes must be scalar integers; got {s_as_str}")

sh = [check_type(as_tensor_variable(s, ndim=0)) for s in shape]
sh = folded_shape = [check_type(as_tensor_variable(s, ndim=0)) for s in shape]

if not all(isinstance(s, Constant) for s in folded_shape):
shape_fg = FunctionGraph(outputs=sh, features=[ShapeFeature()], clone=True)
with config.change_flags(optdb__max_use_ratio=10, cxx=""):
infer_shape_db.default_query.rewrite(shape_fg)
if not all(isinstance(s, Constant) for s in shape_fg.outputs):
topo_constant_folding.rewrite(shape_fg)
folded_shape = shape_fg.outputs

shape_fg = FunctionGraph(
outputs=sh,
features=[ShapeFeature()],
clone=True,
)
folded_shape = rewrite_graph(shape_fg, custom_rewrite=topo_constant_folding).outputs
static_shape = tuple(
s.data.item() if isinstance(s, Constant) else None for s in folded_shape
)
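The payoff, as a usage sketch (shapes are illustrative): literal or statically known entries fold to Python ints, while symbolic ones stay `None`:

import pytensor.tensor as pt
from pytensor.tensor.basic import infer_static_shape


x = pt.vector("x")  # length unknown at graph-construction time
folded, static_shape = infer_static_shape((x.shape[0], 5))
print(static_shape)  # (None, 5): the literal folds, the symbolic dim stays None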
9 changes: 9 additions & 0 deletions pytensor/tensor/rewriting/basic.py
@@ -58,6 +58,7 @@
     get_underlying_scalar_constant_value,
     join,
     ones_like,
+    register_infer_shape,
     switch,
     tensor_copy,
     zeros,
@@ -420,6 +421,7 @@ def local_fill_to_alloc(fgraph, node):
 )
 
 
+@register_infer_shape
 @register_canonicalize("fast_compile", "shape_unsafe")
 @register_useless("shape_unsafe")
 @node_rewriter([fill])
@@ -441,6 +443,7 @@ def local_useless_fill(fgraph, node):
     return [v]
 
 
+@register_infer_shape
 @register_specialize("shape_unsafe")
 @register_stabilize("shape_unsafe")
 @register_canonicalize("shape_unsafe")
@@ -530,6 +533,7 @@ def local_alloc_empty_to_zeros(fgraph, node):
 )
 
 
+@register_infer_shape
 @register_useless
 @register_canonicalize("fast_compile")
 @register_specialize
@@ -806,6 +810,7 @@ def local_remove_all_assert(fgraph, node):
 )
 
 
+@register_infer_shape
 @register_specialize
 @register_canonicalize
 @register_useless
@@ -826,6 +831,7 @@ def local_join_1(fgraph, node):
 
 
 # TODO: merge in local_useless_join
+@register_infer_shape
 @register_useless
 @register_specialize
 @register_canonicalize
@@ -1066,6 +1072,7 @@ def local_merge_switch_same_cond(fgraph, node):
 ]
 
 
+@register_infer_shape
 @register_useless
 @register_canonicalize
 @register_specialize
@@ -1149,6 +1156,7 @@ def constant_folding(fgraph, node):
 register_specialize(topo_constant_folding, "fast_compile", final_rewriter=True)
 
 
+@register_infer_shape
 @register_canonicalize("fast_compile")
 @register_useless("fast_compile")
 @node_rewriter(None)
@@ -1157,6 +1165,7 @@ def local_view_op(fgraph, node):
     return node.inputs
 
 
+@register_infer_shape
 @register_useless
 @register_canonicalize
 @register_stabilize
2 changes: 2 additions & 0 deletions pytensor/tensor/rewriting/math.py
@@ -32,6 +32,7 @@
     extract_constant,
     get_underlying_scalar_constant_value,
     ones_like,
+    register_infer_shape,
     switch,
     zeros_like,
 )
@@ -1745,6 +1746,7 @@ def local_reduce_join(fgraph, node):
     return [ret]
 
 
+@register_infer_shape
 @register_canonicalize("fast_compile", "local_cut_useless_reduce")
 @register_useless("local_cut_useless_reduce")
 @node_rewriter(ALL_REDUCE)
16 changes: 16 additions & 0 deletions pytensor/tensor/rewriting/shape.py
@@ -25,6 +25,7 @@
     constant,
     extract_constant,
     get_underlying_scalar_constant_value,
+    register_infer_shape,
     stack,
 )
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
@@ -964,6 +965,7 @@ def local_reshape_lift(fgraph, node):
     return [e]
 
 
+@register_infer_shape
 @register_useless
 @register_canonicalize
 @node_rewriter([SpecifyShape])
@@ -990,6 +992,17 @@ def local_merge_consecutive_specify_shape(fgraph, node):
     return [specify_shape(inner_obj, shape)]
 
 
+@register_infer_shape
+@node_rewriter([Shape])
+def local_shape_ground(fgraph, node):
+    """Rewrite shape(x) -> make_vector(x.type.shape) when the shape is fully static."""
+    [x] = node.inputs
+    static_shape = x.type.shape
+    if not any(dim is None for dim in static_shape):
+        return [stack([constant(dim, dtype="int64") for dim in static_shape])]


+@register_infer_shape
 @register_useless
 @register_canonicalize
 @node_rewriter([Shape])
@@ -1014,6 +1027,7 @@ def local_Shape_of_SpecifyShape(fgraph, node):
     return [stack(shape).astype(np.int64)]
 
 
+@register_infer_shape
 @register_canonicalize
 @register_specialize
 @node_rewriter([SpecifyShape])
@@ -1060,6 +1074,7 @@ def local_specify_shape_lift(fgraph, node):
     return new_out
 
 
+@register_infer_shape
 @register_useless
 @register_canonicalize
 @node_rewriter([Shape_i])
@@ -1079,6 +1094,7 @@ def local_Shape_i_ground(fgraph, node):
     return [as_tensor_variable(s_val, dtype=np.int64)]
 
 
+@register_infer_shape
 @register_specialize
 @register_canonicalize
 @node_rewriter([Shape])
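`local_shape_ground` is the key addition in this file: when every dimension of `x` is statically known, `shape(x)` is replaced with a vector of `int64` constants, which the other `infer_shape` rewrites can then fold further. Roughly, as a sketch:

import pytensor.tensor as pt
from pytensor.tensor.basic import infer_static_shape


x = pt.tensor("x", shape=(2, 3))
# shape(x) grounds to constants [2, 3], so both entries fold.
_, static_shape = infer_static_shape((x.shape[0], x.shape[1]))
print(static_shape)  # (2, 3)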
6 changes: 6 additions & 0 deletions pytensor/tensor/rewriting/subtensor.py
@@ -26,6 +26,7 @@
     concatenate,
     extract_constant,
     get_underlying_scalar_constant_value,
+    register_infer_shape,
     switch,
 )
 from pytensor.tensor.elemwise import Elemwise
@@ -328,6 +329,7 @@ def local_subtensor_of_dot(fgraph, node):
     return [r]
 
 
+@register_infer_shape
 @register_useless
 @register_canonicalize
 @register_specialize
@@ -599,6 +601,7 @@ def local_subtensor_remove_broadcastable_index(fgraph, node):
     return [node.inputs[0].dimshuffle(tuple(remain_dim))]
 
 
+@register_infer_shape
 @register_useless
 @register_canonicalize
 @register_specialize
@@ -707,6 +710,7 @@ def local_subtensor_inc_subtensor(fgraph, node):
         return
 
 
+@register_infer_shape
 @register_specialize
 @register_canonicalize("fast_compile")
 @register_useless
@@ -785,6 +789,7 @@ def local_subtensor_make_vector(fgraph, node):
         pass
 
 
+@register_infer_shape
 @register_useless
 @register_canonicalize
 @register_specialize
@@ -1461,6 +1466,7 @@ def local_adv_sub1_adv_inc_sub1(fgraph, node):
     return [r2]
 
 
+@register_infer_shape
 @register_specialize
 @register_stabilize
 @register_canonicalize
29 changes: 15 additions & 14 deletions pytensor/tensor/shape.py
@@ -1,12 +1,13 @@
 import warnings
 from numbers import Number
 from textwrap import dedent
-from typing import Union
+from typing import Union, cast
 
 import numpy as np
 
 import pytensor
 from pytensor.gradient import DisconnectedType
+from pytensor.graph import Op
 from pytensor.graph.basic import Apply, Variable
 from pytensor.graph.replace import _vectorize_node
 from pytensor.graph.type import HasShape
@@ -145,14 +146,14 @@ def c_code_cache_version(self):
 def shape(x: Union[np.ndarray, Number, Variable]) -> Variable:
     """Return the shape of `x`."""
     if not isinstance(x, Variable):
-        x = at.as_tensor_variable(x)
+        x = at.as_tensor_variable(x)  # type: ignore
 
-    return _shape(x)
+    return cast(Variable, _shape(x))
 
 
-@_get_vector_length.register(Shape)
-def _get_vector_length_Shape(op, var):
-    return var.owner.inputs[0].type.ndim
+@_get_vector_length.register(Shape)  # type: ignore
+def _get_vector_length_Shape(op: Op, var: TensorVariable) -> int:
+    return cast(int, var.owner.inputs[0].type.ndim)
 
 
 @_vectorize_node.register(Shape)
@@ -181,7 +182,7 @@ def shape_tuple(x: TensorVariable) -> tuple[Variable, ...]:
         # We assume/call it a scalar
         return ()
 
-    res = ()
+    res: tuple[Variable, ...] = ()
     symbolic_shape = shape(x)
     static_shape = x.type.shape
     for i in range(x.type.ndim):
@@ -191,7 +192,7 @@ def shape_tuple(x: TensorVariable) -> tuple[Variable, ...]:
             # TODO: Why not use uint64?
             res += (pytensor.scalar.ScalarConstant(pytensor.scalar.int64, shape_val),)
         else:
-            res += (symbolic_shape[i],)
+            res += (symbolic_shape[i],)  # type: ignore
 
     return res

@@ -366,7 +367,7 @@ def shape_i_op(i):
     return shape_i_op.cache[key]
 
 
-shape_i_op.cache = {}
+shape_i_op.cache = {}  # type: ignore
 
 
 def register_shape_i_c_code(typ, code, check_input, version=()):
@@ -578,7 +579,7 @@ def specify_shape(
 
     # If the specified shape is already encoded in the input static shape, do nothing
     # This ignores PyTensor constants in shape
-    x = at.as_tensor_variable(x)
+    x = at.as_tensor_variable(x)  # type: ignore
     new_shape_info = any(
         s != xts for (s, xts) in zip(shape, x.type.shape) if s is not None
     )
@@ -589,10 +590,10 @@
     return _specify_shape(x, *shape)


-@_get_vector_length.register(SpecifyShape)
-def _get_vector_length_SpecifyShape(op, var):
+@_get_vector_length.register(SpecifyShape)  # type: ignore
+def _get_vector_length_SpecifyShape(op: Op, var: TensorVariable) -> int:
     try:
-        return at.get_underlying_scalar_constant_value(var.owner.inputs[1]).item()
+        return int(at.get_underlying_scalar_constant_value(var.owner.inputs[1]).item())
     except NotScalarConstantError:
         raise ValueError(f"Length of {var} cannot be determined")

@@ -1104,4 +1105,4 @@ def _vectorize_unbroadcast(op: Unbroadcast, node: Apply, x: TensorVariable) -> Apply:
     batched_ndims = x.type.ndim - node.inputs[0].type.ndim
     old_axes = op.axes
     new_axes = (old_axis + batched_ndims for old_axis in old_axes)
-    return unbroadcast(x, *new_axes).owner
+    return cast(Apply, unbroadcast(x, *new_axes).owner)
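As the `specify_shape` comment above notes, dims already encoded in the input's static type add no information and the input is returned unchanged; only genuinely new dims insert a `SpecifyShape` node. A quick sketch:

import pytensor.tensor as pt
from pytensor.tensor.shape import specify_shape


x = pt.tensor("x", shape=(None, 5))
assert specify_shape(x, (None, 5)) is x  # nothing new to encode
y = specify_shape(x, (7, 5))  # dim 0 is new information
print(y.type.shape)  # (7, 5)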
1 change: 0 additions & 1 deletion scripts/mypy-failing.txt
@@ -27,7 +27,6 @@ pytensor/tensor/random/basic.py
 pytensor/tensor/random/op.py
 pytensor/tensor/random/utils.py
 pytensor/tensor/rewriting/basic.py
-pytensor/tensor/shape.py
 pytensor/tensor/slinalg.py
 pytensor/tensor/subtensor.py
 pytensor/tensor/type.py