diff --git a/pytensor/graph/basic.py b/pytensor/graph/basic.py index a11aa57bdf..9b1399b72f 100644 --- a/pytensor/graph/basic.py +++ b/pytensor/graph/basic.py @@ -1777,6 +1777,7 @@ def equal_computations( ys: list[Union[np.ndarray, Variable]], in_xs: Optional[list[Variable]] = None, in_ys: Optional[list[Variable]] = None, + strict_dtype=True, ) -> bool: """Checks if PyTensor graphs represent the same computations. @@ -1908,7 +1909,10 @@ def compare_nodes(nd_x, nd_y, common, different): if dx != dy: if isinstance(dx, Constant) and isinstance(dy, Constant): if not dx.equals(dy): - return False + if strict_dtype: + return False + elif not np.array_equal(dx.data, dy.data): + return False else: return False diff --git a/pytensor/link/jax/dispatch/nlinalg.py b/pytensor/link/jax/dispatch/nlinalg.py index 6f6467cff7..567b4407ab 100644 --- a/pytensor/link/jax/dispatch/nlinalg.py +++ b/pytensor/link/jax/dispatch/nlinalg.py @@ -99,9 +99,7 @@ def jax_funcify_BatchedDot(op, **kwargs): def batched_dot(a, b): if a.shape[0] != b.shape[0]: raise TypeError("Shapes must match in the 0-th dimension") - if a.ndim == 2 or b.ndim == 2: - return jnp.einsum("n...j,nj...->n...", a, b) - return jnp.einsum("nij,njk->nik", a, b) + return jnp.matmul(a, b) return batched_dot diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index ab7054ccaf..9c9c800b92 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -895,6 +895,8 @@ def numba_funcify_BatchedDot(op, node, **kwargs): @numba_njit def batched_dot(x, y): + # Numba does not support 3D matmul + # https://github.com/numba/numba/issues/3804 shape = x.shape[:-1] + y.shape[2:] z0 = np.empty(shape, dtype=dtype) for i in range(z0.shape[0]): diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 434f8b85e7..946660e431 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -26,6 +26,7 @@ from pytensor.graph.basic import Apply, Constant, Variable from pytensor.graph.fg import FunctionGraph from pytensor.graph.op import Op +from pytensor.graph.replace import _vectorize_node from pytensor.graph.rewriting.db import EquilibriumDB from pytensor.graph.type import HasShape, Type from pytensor.link.c.op import COp @@ -41,6 +42,7 @@ as_tensor_variable, get_vector_length, ) +from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.elemwise import DimShuffle, Elemwise, scalar_elemwise from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.shape import ( @@ -1657,16 +1659,22 @@ def do_constant_folding(self, fgraph, node): if not clients: return False - for client in clients: - if client[0] == "output": + for client, idx in clients: + if client == "output": # If the output is a constant, it will have to be deepcopied # each time the function is called. So we do not fold. return False + # Allow alloc to be lifted out of Elemwise before constant folding it + elif isinstance(client.op, Elemwise): + return None + # Same for Blockwise, unless it has no batch_dims + elif isinstance(client.op, Blockwise) and client.op.batch_ndim(client): + return None elif ( # The following ops work inplace of their input id 0. - client[1] == 0 + idx == 0 and isinstance( - client[0].op, + client.op, ( # Ops that will work inplace on the Alloc. 
So if they # get constant_folded, they would copy the @@ -3497,10 +3505,17 @@ def make_node(self, x): if x.ndim < 2: raise ValueError("ExtractDiag needs an input with 2 or more dimensions", x) + + out_shape = [ + st_dim + for i, st_dim in enumerate(x.type.shape) + if i not in (self.axis1, self.axis2) + ] + [None] + return Apply( self, [x], - [x.type.clone(dtype=x.dtype, shape=(None,) * (x.ndim - 1))()], + [x.type.clone(dtype=x.dtype, shape=tuple(out_shape))()], ) def perform(self, node, inputs, outputs): @@ -3601,6 +3616,17 @@ def diagonal(a, offset=0, axis1=0, axis2=1): return ExtractDiag(offset, axis1, axis2)(a) +@_vectorize_node.register(ExtractDiag) +def vectorize_extract_diag(op: ExtractDiag, node, batched_x): + batched_ndims = batched_x.type.ndim - node.inputs[0].type.ndim + return diagonal( + batched_x, + offset=op.offset, + axis1=op.axis1 + batched_ndims, + axis2=op.axis2 + batched_ndims, + ).owner + + def trace(a, offset=0, axis1=0, axis2=1): """ Returns the sum along diagonals of the array. diff --git a/pytensor/tensor/blas.py b/pytensor/tensor/blas.py index 78a80bd323..301cc5d199 100644 --- a/pytensor/tensor/blas.py +++ b/pytensor/tensor/blas.py @@ -98,10 +98,11 @@ from pytensor.printing import FunctionPrinter, pprint from pytensor.scalar import bool as bool_t from pytensor.tensor import basic as at +from pytensor.tensor.basic import expand_dims from pytensor.tensor.blas_headers import blas_header_text, blas_header_version from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.math import add, mul, neg, sub -from pytensor.tensor.shape import specify_broadcastable +from pytensor.tensor.shape import shape_padright, specify_broadcastable from pytensor.tensor.type import DenseTensorType, TensorType, integer_dtypes, tensor from pytensor.utils import memoize @@ -1637,48 +1638,53 @@ def c_code_cache_version(self): class BatchedDot(COp): """ - Computes the batched dot product of two variables: + Computes a batch matrix-matrix dot with tensor3 variables batched_dot(a, b)[i] = dot(a[i], b[i]) """ __props__ = () + gufunc_signature = "(b,m,k),(b,k,n)->(b,m,n)" - def make_node(self, *inputs): - inputs = list(map(at.as_tensor_variable, inputs)) + def make_node(self, x, y): + x = at.as_tensor_variable(x) + y = at.as_tensor_variable(y) - if any(not isinstance(i.type, DenseTensorType) for i in inputs): + if not ( + isinstance(x.type, DenseTensorType) and isinstance(y.type, DenseTensorType) + ): raise NotImplementedError("Only dense tensor types are supported") - if len(inputs) != 2: - raise TypeError(f"Two arguments required, but {len(inputs)} given.") - if inputs[0].ndim not in (2, 3): + if not (x.type.ndim == 3 and y.type.ndim == 3): raise TypeError( - "Input 0 (0-indexed)" - f" must have ndim of 2 or 3, {int(inputs[0].ndim)} given. Consider" - " calling batched_dot instead." - ) - if inputs[1].ndim not in (2, 3): - raise TypeError( - "Input 1 (0-indexed)" - f" must have ndim of 2 or 3, {int(inputs[1].ndim)} given. Consider" - " calling batched_dot instead." + f"Inputs must have 3 ndim, but got {x.type.ndim} and {y.type.ndim}. " + "Consider calling batched_dot instead." 
) - dtype = pytensor.scalar.upcast(*[input.type.dtype for input in inputs]) - # upcast inputs to common dtype if needed - upcasted_inputs = [at.cast(input, dtype) for input in inputs] - out_shape = ( - ( - 1 - if inputs[0].type.shape[0] == 1 or inputs[1].type.shape[0] == 1 - else None, - ) - + inputs[0].type.shape[1:-1] - + inputs[1].type.shape[2:] - ) - out_shape = tuple(1 if s == 1 else None for s in out_shape) - return Apply(self, upcasted_inputs, [tensor(dtype=dtype, shape=out_shape)]) + def extract_static_dim(dim_x, dim_y): + dims = {dim_x, dim_y} - {None} + if len(dims) > 1: + # BatchedDot doesn't allow broadcasting + raise ValueError( + f"Static dimensions of BatchedDot don't match, got {x.type.shape} and {y.type.shape}" + ) + elif not dims: + return None + else: + return dims.pop() + + x_batch_dim, x_row_dim, x_sum_dim = x.type.shape + y_batch_dim, y_sum_dim, y_col_dim = y.type.shape + batch_dim = extract_static_dim(x_batch_dim, y_batch_dim) + # Raise if static sum dimensions do not match + _ = extract_static_dim(x_sum_dim, y_sum_dim) + out_shape = (batch_dim, x_row_dim, y_col_dim) + + # Change dtype if needed + dtype = pytensor.scalar.upcast(x.type.dtype, y.type.dtype) + x, y = at.cast(x, dtype), at.cast(y, dtype) + out = tensor(dtype=dtype, shape=out_shape) + return Apply(self, [x, y], [out]) def perform(self, node, inp, out): x, y = inp @@ -1690,11 +1696,7 @@ def perform(self, node, inp, out): f" same size in axis 0, but have sizes [{', '.join([str(i.shape[0]) for i in inp])}]." ) - shape = self.infer_shape(None, node, [i.shape for i in inp])[0] - dtype = node.outputs[0].dtype - z0 = z[0] = np.empty(shape, dtype=dtype) - for i in range(z0.shape[0]): - z0[i] = np.dot(x[i], y[i]) + z[0] = np.matmul(x, y) def c_support_code(self, **kwargs): batch_gemm_defn = """ @@ -1792,14 +1794,6 @@ def c_lib_dirs(self, **kwargs): def c_header_dirs(self, **kwargs): return ldflags(libs=False, include_dir=True) - def c_code_cleanup(self, node, name, inputs, outputs, sub): - return """ - // clean up views - Py_XDECREF(xs); xs = 0; - Py_XDECREF(ys); ys = 0; - Py_XDECREF(zs); zs = 0; - """ - def c_code(self, node, name, inp, out, sub): _x, _y = inp (_z,) = out @@ -1832,12 +1826,11 @@ def contiguous(var, ndim): ) # generate code to allocate output based on runtime input shapes - z_dims = [f"PyArray_DIMS({_x})[0]"] - if x_ndim == 3: - z_dims.append(f"PyArray_DIMS({_x})[1]") - if y_ndim == 3: - z_dims.append(f"PyArray_DIMS({_y})[2]") - assert len(z_dims) == z_ndim + z_dims = [ + f"PyArray_DIMS({_x})[0]", + f"PyArray_DIMS({_x})[1]", + f"PyArray_DIMS({_y})[2]", + ] z_shape_correct = " && ".join( "PyArray_DIMS(%s)[%i] == %s" % (_z, i, dim) for i, dim in enumerate(z_dims) @@ -1880,76 +1873,26 @@ def contiguous(var, ndim): ) contiguate = "\n".join(contiguate) - def c_dimshuffle(newname, oldname, shape): - _fail = fail - _shape = ", ".join( - "1" if axis is None else "PyArray_DIMS(%s)[%i]" % (oldname, axis) - for axis in shape - ) - return ( - """{ - npy_intp dims[3] = {%(_shape)s}; - PyArray_Dims newshape = {dims, 3}; - %(newname)s = (PyArrayObject*)PyArray_Newshape(%(oldname)s, &newshape, NPY_ANYORDER); - if (!%(newname)s) - %(_fail)s - // make sure we didn't accidentally copy - assert(PyArray_DATA(%(oldname)s) == PyArray_DATA(%(newname)s)); - }""" - % locals() - ) - - # create tensor3 views for any of x, y, z that are not tensor3, so that - # we only need to implement the tensor3-tensor3 batched dot product. - # xs, ys and zs will point to these views, or to the original array if - # it was already tensor3. 
- # in the latter case, we artificially increase the reference count of - # the original array so that the c_code_cleanup method can decref them - # all indiscriminately. - upcast = [] - if x_ndim == 3: - upcast.append("xs = %(_x)s; Py_XINCREF(xs);") - elif x_ndim == 2: - upcast.append(c_dimshuffle("xs", _x, (0, None, 1))) - if y_ndim == 3: - upcast.append("ys = %(_y)s; Py_XINCREF(ys);") - elif y_ndim == 2: - upcast.append(c_dimshuffle("ys", _y, (0, 1, None))) - if z_ndim == 3: - upcast.append("zs = %(_z)s; Py_XINCREF(zs);") - else: - upcast.append( - c_dimshuffle( - "zs", - _z, - (0, None if x_ndim == 2 else 1, None if y_ndim == 2 else 1), - ) - ) - upcast = "\n".join(upcast) % locals() - return ( """ int type_num = PyArray_DESCR(%(_x)s)->type_num; int type_size = PyArray_DESCR(%(_x)s)->elsize; // in bytes - // xs, ys, zs will point to views onto %(_x)s, %(_y)s, %(_z)s - PyArrayObject *xs = 0, *ys = 0, *zs = 0; - - if (PyArray_NDIM(%(_x)s) != %(x_ndim)s) { + if (PyArray_NDIM(%(_x)s) != 3) { PyErr_Format(PyExc_NotImplementedError, - "rank(x) != %(x_ndim)s. rank(x) is %%d.", + "rank(x) != 3. rank(x) is %%d.", PyArray_NDIM(%(_x)s)); %(fail)s; } - if (PyArray_NDIM(%(_y)s) != %(y_ndim)s) { + if (PyArray_NDIM(%(_y)s) != 3) { PyErr_Format(PyExc_NotImplementedError, - "rank(y) != %(y_ndim)s. rank(y) is %%d.", + "rank(y) != 3. rank(y) is %%d.", PyArray_NDIM(%(_y)s)); %(fail)s; } - if (%(_z)s && PyArray_NDIM(%(_z)s) != %(z_ndim)s) { + if (%(_z)s && PyArray_NDIM(%(_z)s) != 3) { PyErr_Format(PyExc_NotImplementedError, - "rank(z) != %(z_ndim)s. rank(z) is %%d.", + "rank(z) != 3. rank(z) is %%d.", PyArray_NDIM(%(_z)s)); %(fail)s; } @@ -1958,36 +1901,32 @@ def c_dimshuffle(newname, oldname, shape): %(allocate)s // reallocate any noncontiguous arrays or arrays with invalid strides %(contiguate)s - // add dims to make sure everything is tensor3 - %(upcast)s - // from here on, use xs, ys and zs as they are tensor3 and share memory - // with the original %(_x)s, %(_y)s and %(_z)s arrays. 
- if ((PyArray_DESCR(xs)->type_num != NPY_DOUBLE) - && (PyArray_DESCR(xs)->type_num != NPY_FLOAT)) + if ((PyArray_DESCR(%(_x)s)->type_num != NPY_DOUBLE) + && (PyArray_DESCR(%(_x)s)->type_num != NPY_FLOAT)) {PyErr_SetString(PyExc_NotImplementedError, "type(x) is not double or float"); %(fail)s;} - if ((PyArray_DESCR(ys)->type_num != NPY_DOUBLE) - && (PyArray_DESCR(ys)->type_num != NPY_FLOAT)) + if ((PyArray_DESCR(%(_y)s)->type_num != NPY_DOUBLE) + && (PyArray_DESCR(%(_y)s)->type_num != NPY_FLOAT)) {PyErr_SetString(PyExc_NotImplementedError, "type(y) is not double or float"); %(fail)s;} - if ((PyArray_DESCR(zs)->type_num != NPY_DOUBLE) - && (PyArray_DESCR(zs)->type_num != NPY_FLOAT)) + if ((PyArray_DESCR(%(_z)s)->type_num != NPY_DOUBLE) + && (PyArray_DESCR(%(_z)s)->type_num != NPY_FLOAT)) {PyErr_SetString(PyExc_NotImplementedError, "type(z) is not double or float"); %(fail)s;} - if ((PyArray_DESCR(xs)->type_num != PyArray_DESCR(ys)->type_num) - ||(PyArray_DESCR(xs)->type_num != PyArray_DESCR(zs)->type_num)) + if ((PyArray_DESCR(%(_x)s)->type_num != PyArray_DESCR(%(_y)s)->type_num) + ||(PyArray_DESCR(%(_x)s)->type_num != PyArray_DESCR(%(_z)s)->type_num)) { PyErr_SetString(PyExc_NotImplementedError, "type(x), type(y), type(z) are not all the same"); %(fail)s; } switch (type_num) { case NPY_FLOAT: - if (batch_gemm(sgemm_, type_size, xs, ys, zs)) { + if (batch_gemm(sgemm_, type_size, %(_x)s, %(_y)s, %(_z)s)) { %(fail)s; } break; case NPY_DOUBLE: - if (batch_gemm(dgemm_, type_size, xs, ys, zs)) { + if (batch_gemm(dgemm_, type_size, %(_x)s, %(_y)s, %(_z)s)) { %(fail)s; } break; @@ -1999,32 +1938,14 @@ def c_dimshuffle(newname, oldname, shape): def c_code_cache_version(self): from pytensor.tensor.blas_headers import blas_header_version - return (4, blas_header_version()) + return (5, blas_header_version()) def grad(self, inp, grads): x, y = inp (gz,) = grads - xdim, ydim, gdim = x.type.ndim, y.type.ndim, gz.type.ndim - # grad is a vector, so x is a matrix and y is a matrix - if gdim == 1: - xgrad = gz.dimshuffle(0, "x") * y - ygrad = gz.dimshuffle(0, "x") * x - - # x is a matrix, y is a tensor3, grad is a matrix - elif xdim == 2 and ydim == 3: - xgrad = batched_dot(gz, y.dimshuffle(0, 2, 1)) - ygrad = x.dimshuffle(0, 1, "x") * gz.dimshuffle(0, "x", 1) - - # x is a tensor3, y is a matrix, grad is a matrix - elif xdim == 3 and ydim == 2: - xgrad = gz.dimshuffle(0, 1, "x") * y.dimshuffle(0, "x", 1) - ygrad = batched_dot(x.dimshuffle(0, 2, 1), gz) - - # x is a tensor3, y is a tensor3, grad is a tensor3 - elif xdim == ydim == 3: - xgrad = batched_dot(gz, y.dimshuffle(0, 2, 1)) - ygrad = batched_dot(x.dimshuffle(0, 2, 1), gz) + xgrad = batched_dot(gz, y.dimshuffle(0, 2, 1)) + ygrad = batched_dot(x.dimshuffle(0, 2, 1), gz) # If x or y contain broadcastable dimensions but only one of # them know that a matching dimensions is broadcastable, the @@ -2105,6 +2026,7 @@ def R_op(self, inputs, eval_points): + " to BatchedDot.R_op should have the same shape, but " f"their shapes are {input_values[i].shape} and {eval_point_values[i].shape}, respectively" ) + if eval_points[0]: t1 = self(eval_points[0], inputs[1]) if eval_points[1]: @@ -2118,9 +2040,6 @@ def R_op(self, inputs, eval_points): return [t2] def infer_shape(self, fgraph, node, shapes): - for shape_ in shapes: - if len(shape_) not in (2, 3): - raise NotImplementedError() xshp, yshp = shapes return [xshp[:-1] + yshp[2:]] @@ -2157,14 +2076,24 @@ def batched_dot(a, b): elif b.ndim == 0: raise TypeError("b must have at least one (batch) axis") elif a.ndim == 
1:
-        return a.dimshuffle(*([0] + ["x"] * (b.ndim - 1))) * b
+        return shape_padright(a, (b.ndim - 1)) * b
     elif b.ndim == 1:
-        return a * b.dimshuffle(*([0] + ["x"] * (a.ndim - 1)))
+        return a * shape_padright(b, (a.ndim - 1))
     elif a.ndim > 3 or b.ndim > 3:
         return batched_tensordot(a, b, [[a.ndim - 1], [np.maximum(1, b.ndim - 2)]])
     else:
-        # avoid circular import
-        return _batched_dot(a, b)
+        # If either a or b is a batched vector, expand dims and later squeeze them
+        expanded_axis = []
+        if a.ndim == 2:
+            a = expand_dims(a, axis=1)
+            expanded_axis.append(1)
+        if b.ndim == 2:
+            b = expand_dims(b, axis=2)
+            expanded_axis.append(2)
+        out = _batched_dot(a, b)
+        if expanded_axis:
+            out = out.squeeze(axis=expanded_axis)
+        return out
 
 
 def batched_tensordot(x, y, axes=2):
diff --git a/pytensor/tensor/blockwise.py b/pytensor/tensor/blockwise.py
index 96357f59f8..d21af2f651 100644
--- a/pytensor/tensor/blockwise.py
+++ b/pytensor/tensor/blockwise.py
@@ -1,4 +1,5 @@
 from collections.abc import Sequence
+from copy import copy
 from typing import Any, Optional, cast
 
 import numpy as np
@@ -13,9 +14,10 @@
     _vectorize_not_needed,
     vectorize_graph,
 )
+from pytensor.scalar import ScalarType
 from pytensor.tensor import as_tensor_variable
 from pytensor.tensor.shape import shape_padleft
-from pytensor.tensor.type import continuous_dtypes, discrete_dtypes, tensor
+from pytensor.tensor.type import TensorType, continuous_dtypes, discrete_dtypes, tensor
 from pytensor.tensor.utils import (
     _parse_gufunc_signature,
     broadcast_static_dim_lengths,
@@ -57,6 +59,7 @@ def __init__(
         core_op: Op,
         signature: Optional[str] = None,
         name: Optional[str] = None,
+        gufunc_spec: Optional[tuple[str, int, int]] = None,
         **kwargs,
     ):
         """
@@ -68,7 +71,12 @@ def __init__(
         signature
             Generalized universal function signature,
             e.g., (m,n),(n)->(m) for vectorized matrix-vector multiplication
-
+        gufunc_spec: tuple, optional
+            Tuple containing:
+                1. String import path for a numpy/scipy function (e.g., "numpy.matmul", "scipy.special.softmax")
+                   that implements the blockwised operation of the core op.
+                2. Number of inputs of the function.
+                3. Number of outputs of the function.
         """
         if isinstance(core_op, Blockwise):
             raise TypeError("Core Op is already a Blockwise")
@@ -84,9 +92,15 @@ def __init__(
         self.signature = signature
         self.name = name
         self.inputs_sig, self.outputs_sig = _parse_gufunc_signature(signature)
+        self.gufunc_spec = gufunc_spec
         self._gufunc = None
         super().__init__(**kwargs)
 
+    def __getstate__(self):
+        d = copy(self.__dict__)
+        d["_gufunc"] = None
+        return d
+
     def _create_dummy_core_node(self, inputs: Sequence[TensorVariable]) -> Apply:
         core_input_types = []
         for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig)):
@@ -157,8 +171,8 @@ def make_node(self, *inputs):
 
         return Apply(self, batched_inputs, batched_outputs)
 
-    def _batch_ndim_from_outputs(self, outputs: Sequence[TensorVariable]) -> int:
-        return cast(int, outputs[0].type.ndim - len(self.outputs_sig[0]))
+    def batch_ndim(self, node: Apply) -> int:
+        return cast(int, node.outputs[0].type.ndim - len(self.outputs_sig[0]))
 
     def infer_shape(
         self, fgraph, node, input_shapes
@@ -166,7 +180,7 @@ def infer_shape(
         from pytensor.tensor import broadcast_shape
         from pytensor.tensor.shape import Shape_i
 
-        batch_ndims = self._batch_ndim_from_outputs(node.outputs)
+        batch_ndims = self.batch_ndim(node)
         core_dims: dict[str, Any] = {}
         batch_shapes = []
         for input_shape, sig in zip(input_shapes, self.inputs_sig):
@@ -272,7 +286,7 @@ def L_op(self, inputs, outs, ograds):
             return new_rval
 
         # Sum out the broadcasted dimensions
-        batch_ndims = self._batch_ndim_from_outputs(outs)
+        batch_ndims = self.batch_ndim(outs[0].owner)
         batch_shape = outs[0].type.shape[:batch_ndims]
         for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig)):
             if isinstance(rval[i].type, (NullType, DisconnectedType)):
@@ -291,10 +305,14 @@ def L_op(self, inputs, outs, ograds):
         return rval
 
     def _create_gufunc(self, node):
-        if hasattr(self.core_op, "gufunc_spec"):
-            self._gufunc = import_func_from_string(self.core_op.gufunc_spec[0])
+        gufunc_spec = self.gufunc_spec or getattr(self.core_op, "gufunc_spec", None)
+
+        if gufunc_spec is not None:
+            self._gufunc = import_func_from_string(gufunc_spec[0])
             if self._gufunc:
                 return self._gufunc
+            else:
+                raise ValueError(f"Could not import gufunc {gufunc_spec[0]} for {self}")
 
         n_outs = len(self.outputs_sig)
         core_node = self._create_dummy_core_node(node.inputs)
@@ -314,7 +332,7 @@ def core_func(*inner_inputs):
         return self._gufunc
 
     def _check_runtime_broadcast(self, node, inputs):
-        batch_ndim = self._batch_ndim_from_outputs(node.outputs)
+        batch_ndim = self.batch_ndim(node)
 
         for dims_and_bcast in zip(
             *[
@@ -356,6 +374,12 @@ def __str__(self):
 
 @_vectorize_node.register(Op)
 def vectorize_node_fallback(op: Op, node: Apply, *bached_inputs) -> Apply:
+    for inp in node.inputs:
+        if not isinstance(inp.type, (TensorType, ScalarType)):
+            raise NotImplementedError(
+                f"Cannot vectorize node {node} with input {inp} of type {inp.type}"
+            )
+
     if hasattr(op, "gufunc_signature"):
         signature = op.gufunc_signature
     else:
diff --git a/pytensor/tensor/extra_ops.py b/pytensor/tensor/extra_ops.py
index 8f4158696a..b9bcceb2db 100644
--- a/pytensor/tensor/extra_ops.py
+++ b/pytensor/tensor/extra_ops.py
@@ -603,6 +603,10 @@ def squeeze(x, axis=None):
     except np.AxisError:
         raise np.AxisError(axis, ndim=_x.ndim)
 
+    if not axis:
+        # Nothing to do
+        return _x
+
     return _x.dimshuffle([i for i in range(_x.ndim) if i not in axis])
 
 
diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py
index 0f035272af..7d1e32ba21 100644
--- a/pytensor/tensor/math.py
+++ 
b/pytensor/tensor/math.py @@ -9,6 +9,7 @@ from pytensor.gradient import DisconnectedType from pytensor.graph.basic import Apply, Variable from pytensor.graph.op import Op +from pytensor.graph.replace import _vectorize_node from pytensor.link.c.op import COp from pytensor.link.c.params_type import ParamsType from pytensor.link.c.type import Generic @@ -25,7 +26,7 @@ stack, switch, ) -from pytensor.tensor.blockwise import Blockwise +from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise, scalar_elemwise from pytensor.tensor.shape import shape, specify_broadcastable from pytensor.tensor.type import ( @@ -2873,7 +2874,11 @@ def logsumexp(x, axis=None, keepdims=False): return log(sum(exp(x), axis=axis, keepdims=keepdims)) -_matrix_matrix_matmul = Blockwise(_dot, signature="(n,k),(k,m)->(n,m)") +_matrix_matrix_matmul = Blockwise( + _dot, + signature="(m,k),(k,n)->(m,n)", + gufunc_spec=("numpy.matmul", 2, 1), +) def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None): @@ -2937,6 +2942,15 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None return out +@_vectorize_node.register(Dot) +def vectorize_node_to_matmul(op, node, batched_x, batched_y): + old_x, old_y = node.inputs + if old_x.type.ndim == 2 and old_y.type.ndim == 2: + return matmul(batched_x, batched_y).owner + else: + return vectorize_node_fallback(op, node, batched_x, batched_y) + + __all__ = [ "max_and_argmax", "max", diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py index 021660d8e0..98f6d68dab 100644 --- a/pytensor/tensor/rewriting/basic.py +++ b/pytensor/tensor/rewriting/basic.py @@ -67,9 +67,7 @@ from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.extra_ops import broadcast_arrays -from pytensor.tensor.math import Sum, add -from pytensor.tensor.math import all as at_all -from pytensor.tensor.math import eq +from pytensor.tensor.math import Sum, add, eq from pytensor.tensor.shape import Shape_i, shape_padleft from pytensor.tensor.sort import TopKOp from pytensor.tensor.type import DenseTensorType, TensorType @@ -266,6 +264,7 @@ def local_elemwise_alloc(fgraph, node): introduces them as a canonicalization of `Alloc`'s with leading broadcastable dimensions. 
""" + # This is handled by local_alloc_unary if len(node.inputs) == 1: return None @@ -465,14 +464,7 @@ def local_useless_alloc(fgraph, node): inp.type.dtype == output.type.dtype and inp.type.broadcastable == output.type.broadcastable ): - if inp.ndim == 0: - return [inp] - else: - return [ - Assert("Shapes must be equal")( - inp, at_all(eq(inp.shape, node.inputs[1:])) - ) - ] + return [inp] @register_specialize diff --git a/pytensor/tensor/rewriting/blas.py b/pytensor/tensor/rewriting/blas.py index a310cb5837..7434fd7e1c 100644 --- a/pytensor/tensor/rewriting/blas.py +++ b/pytensor/tensor/rewriting/blas.py @@ -59,6 +59,8 @@ import numpy as np +from pytensor.tensor.rewriting.basic import register_specialize + try: import numpy.__config__ # noqa @@ -79,12 +81,12 @@ ) from pytensor.graph.rewriting.db import SequenceDB from pytensor.graph.utils import InconsistencyError -from pytensor.printing import debugprint from pytensor.tensor import basic as at from pytensor.tensor.blas import ( Dot22, _dot22, _dot22scalar, + batched_dot, gemm_inplace, gemm_no_inplace, gemv_inplace, @@ -94,7 +96,7 @@ ) from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.exceptions import NotScalarConstantError -from pytensor.tensor.math import Dot, add, mul, neg, sub +from pytensor.tensor.math import Dot, _matrix_matrix_matmul, add, mul, neg, sub from pytensor.tensor.rewriting.elemwise import local_dimshuffle_lift from pytensor.tensor.type import ( DenseTensorType, @@ -899,9 +901,32 @@ def local_dot22_to_dot22scalar(fgraph, node): ) -# from opt import register_specialize, register_canonicalize -# @register_specialize -@node_rewriter([sub, add]) -def local_print_as_we_go_along(fgraph, node): - if node.op in (sub, add): - debugprint(node) +@register_specialize +@node_rewriter([_matrix_matrix_matmul]) +def specialize_matmul_to_batched_dot(fgraph, node): + """Rewrite Matmul (Blockwise matrix-matrix) without implicit broadcasted batched dimension as BatchedDot. 
+ + TODO: Do the same for Blockwise BatchedDot + """ + x, y = node.inputs + + # BatchedDot does not allow implicit broadcasting of the batch dimensions + # We do not want to explicitly broadcast as it may result in huge arrays + if x.type.broadcastable[:-2] != y.type.broadcastable[:-2]: + return None + + x_shape = tuple(x.shape) + y_shape = tuple(y.shape) + if len(x_shape) > 3: + # If we have more than one batch dim, ravel it + x = x.reshape((-1, x_shape[-2], x_shape[-1])) + y = y.reshape((-1, y_shape[-2], y_shape[-1])) + + new_out = batched_dot(x, y) + + if len(x_shape) > 3: + # And then unravel it + new_out = new_out.reshape((*x_shape[:-2], x_shape[-2], y_shape[-1])) + + copy_stack_trace(node.outputs, [new_out]) + return [new_out] diff --git a/pytensor/tensor/rewriting/blockwise.py b/pytensor/tensor/rewriting/blockwise.py index 101aeec368..4cbfcdaa32 100644 --- a/pytensor/tensor/rewriting/blockwise.py +++ b/pytensor/tensor/rewriting/blockwise.py @@ -1,10 +1,18 @@ +from typing import Optional + from pytensor.compile.mode import optdb -from pytensor.graph import node_rewriter +from pytensor.graph import Constant, node_rewriter from pytensor.graph.replace import vectorize_node from pytensor.graph.rewriting.basic import copy_stack_trace, out2in +from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft from pytensor.tensor.blockwise import Blockwise -from pytensor.tensor.math import _matrix_matrix_matmul -from pytensor.tensor.rewriting.basic import register_canonicalize +from pytensor.tensor.math import Dot +from pytensor.tensor.rewriting.basic import ( + register_canonicalize, + register_specialize, + register_stabilize, +) +from pytensor.tensor.subtensor import AdvancedIncSubtensor, AdvancedSubtensor, Subtensor @node_rewriter([Blockwise]) @@ -29,8 +37,17 @@ def local_useless_unbatched_blockwise(fgraph, node): op = node.op inputs = node.inputs - if max(inp.type.ndim - len(sig) for inp, sig in zip(inputs, op.inputs_sig)) == 0: - return copy_stack_trace(node.outputs, op.core_op.make_node(*inputs).outputs) + batch_ndims = node.op.batch_ndim(node) + if all(all(inp.type.broadcastable[:batch_ndims]) for inp in inputs): + if batch_ndims: + # Remove dummy batch dims + axis = tuple(range(batch_ndims)) + inputs = [inp.squeeze(axis) for inp in inputs] + new_outs = op.core_op.make_node(*inputs).outputs + if batch_ndims: + # Reintroduce dummy batch dims + new_outs = [shape_padleft(out, batch_ndims) for out in new_outs] + return copy_stack_trace(node.outputs, new_outs) # We register this rewrite late, so that other rewrites need only target Blockwise Ops @@ -46,6 +63,139 @@ def local_useless_unbatched_blockwise(fgraph, node): # Avoid redundant cases early on for Ops whose default form is not Blockwised @register_canonicalize -@node_rewriter(tracks=[_matrix_matrix_matmul]) +@register_stabilize +@register_specialize +@node_rewriter(tracks=[Blockwise]) def local_eager_useless_unbatched_blockwise(fgraph, node): - return local_useless_unbatched_blockwise.fn(fgraph, node) + if isinstance( + node.op.core_op, + ( + # Many Dot-related rewrites (e.g., all of BlasOpt) happen before specialize + Dot, + # These Ops can't always be trivially vectorized at runtime, + # Since their inputs may imply non-rectangular shapes. 
+            Alloc,
+            ARange,
+            Subtensor,
+            AdvancedSubtensor,
+            AdvancedIncSubtensor,
+        ),
+    ):
+        return local_useless_unbatched_blockwise.fn(fgraph, node)
+
+
+def _squeeze_left(x, stop_at_dim: Optional[int] = None):
+    """Squeeze any leading dims of `x` until a real dim or `stop_at_dim` (if not None) is reached."""
+    x_dims = x.type.broadcastable
+    squeeze_ndim = len(x_dims) if all(x_dims) else x_dims.index(False)
+    if stop_at_dim is not None:
+        squeeze_ndim = min(squeeze_ndim, stop_at_dim)
+    if squeeze_ndim == 0:
+        return x
+    return x.squeeze(axis=tuple(range(squeeze_ndim)))
+
+
+@register_specialize("shape_unsafe")
+@node_rewriter([Blockwise])
+def local_blockwise_alloc(fgraph, node):
+    """Push Allocs from the inputs to the output of Blockwise Ops.
+
+    BOp = Blockwise(Op, signature="(x),(x)->(x)")
+    BOp(vector, alloc(vector, 10, 5)) -> alloc(BOp(vector, vector), 10, 5)
+    BOp(vector, alloc(scalar, 10, 5)) -> alloc(BOp(vector, alloc(scalar, 5)), 10, 5)
+    BOp(matrix, alloc(vector, 10, 5)) -> BOp(matrix, vector)
+    """
+
+    if not any(isinstance(inp.owner.op, Alloc) for inp in node.inputs if inp.owner):
+        return None
+
+    op: Blockwise = node.op  # type: ignore
+
+    batch_ndim = op.batch_ndim(node)
+    if not batch_ndim:
+        return None
+
+    new_inputs = []
+    batch_shapes = []
+    can_push_any_alloc = False
+    for inp, inp_sig in zip(node.inputs, op.inputs_sig):
+        if inp.owner and isinstance(inp.owner.op, Alloc):
+            # Push batch dims from Alloc
+            value, *shape = inp.owner.inputs
+
+            # Check what to do with the value of the Alloc
+            squeezed_value = _squeeze_left(value, batch_ndim)
+            missing_ndim = len(shape) - value.type.ndim
+            if (
+                ((1,) * missing_ndim + value.type.broadcastable)[batch_ndim:]
+            ) != inp.type.broadcastable[batch_ndim:]:
+                # We still need an Alloc for the core dims
+                core_shape = shape[batch_ndim:]
+                # And the batch dims of the squeezed value
+                squeezed_value_batch_ndim = squeezed_value.type.ndim - len(core_shape)
+                batch_shape = [
+                    1 if broadcastable else dim
+                    for broadcastable, dim in zip(
+                        squeezed_value.type.broadcastable[:squeezed_value_batch_ndim],
+                        tuple(squeezed_value.shape)[:squeezed_value_batch_ndim],
+                    )
+                ]
+                squeezed_value = alloc(squeezed_value, *batch_shape, *core_shape)
+                if squeezed_value.type.broadcastable == inp.type.broadcastable:
+                    # We can't change anything about this Alloc input
+                    new_inputs.append(inp)
+                    continue
+
+            # We can push batch dims of this Alloc input
+            batch_shapes.append(
+                tuple(
+                    1 if broadcastable else dim
+                    for broadcastable, dim in zip(
+                        inp.type.broadcastable, shape[:batch_ndim]
+                    )
+                )
+            )
+            new_inputs.append(squeezed_value)
+            can_push_any_alloc = True
+
+        else:
+            # Nothing to do with this input other than removing dummy batch dims
+            new_inputs.append(_squeeze_left(inp, batch_ndim))
+
+    if not can_push_any_alloc:
+        return None
+
+    new_outs = node.op.make_node(*new_inputs).outputs
+
+    new_out_type = new_outs[0].type
+    old_out_type = node.outputs[0].type
+    if new_out_type.broadcastable != old_out_type.broadcastable:
+        # An Alloc is still needed to broadcast the new output to the original shape
+        # We pick the most parsimonious batch dim from the pushed Alloc
+        missing_ndim = old_out_type.ndim - new_out_type.ndim
+        batch_shape = ([1] * missing_ndim + list(new_outs[0].shape))[:batch_ndim]
+        for i, batch_dims in enumerate(zip(*batch_shapes)):  # Transpose shape tuples
+            for batch_dim in batch_dims:
+                if batch_dim == 1:
+                    continue
+                if isinstance(batch_dim, Constant):
+                    # Give preference to Constants
+                    batch_shape[i] = batch_dim
+                    break
+                elif
old_out_type.broadcastable[i]: + # Only use non Constant shapes if absolutely necessary + # Otherwise, we use the shape of the non-alloc output + batch_shape[i] = batch_dim + + copy_stack_trace(node.outputs, new_outs) + new_outs = [ + alloc( + new_out, + *batch_shape, + *tuple(new_out.shape)[batch_ndim - missing_ndim :], + ) + for new_out in new_outs + ] + assert new_outs[0].type.broadcastable == old_out_type.broadcastable + copy_stack_trace(node.outputs, new_outs) + return new_outs diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index 475e454037..bc3eef6fca 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -138,14 +138,12 @@ def generic_solve_to_solve_triangular(fgraph, node): ] -@register_stabilize @register_specialize @node_rewriter([Blockwise]) def batched_vector_b_solve_to_matrix_b_solve(fgraph, node): """Replace a batched Solve(a, b, b_ndim=1) by Solve(a, b.T, b_ndim=2).T `a` must have no batched dimensions, while `b` can have arbitrary batched dimensions. - Only the last two dimensions of `b` and the output are swapped. """ core_op = node.op.core_op @@ -175,8 +173,17 @@ def batched_vector_b_solve_to_matrix_b_solve(fgraph, node): new_core_op = type(core_op)(**props) matrix_b_solve = Blockwise(new_core_op) + # Ravel any batched dims + original_b_shape = tuple(b.shape) + if len(original_b_shape) > 2: + b = b.reshape((-1, original_b_shape[-1])) + # Apply the rewrite - new_solve = _T(matrix_b_solve(a, _T(b))) + new_solve = matrix_b_solve(a, b.T).T + + # Unravel any batched dims + if len(original_b_shape) > 2: + new_solve = new_solve.reshape(original_b_shape) old_solve = node.outputs[0] copy_stack_trace(old_solve, new_solve) diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py index 4e80d3bb30..e860034235 100644 --- a/pytensor/tensor/rewriting/subtensor.py +++ b/pytensor/tensor/rewriting/subtensor.py @@ -29,6 +29,7 @@ register_infer_shape, switch, ) +from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.math import Dot, add @@ -336,35 +337,46 @@ def local_subtensor_of_dot(fgraph, node): @node_rewriter([Subtensor]) def local_useless_slice(fgraph, node): """ - Remove Subtensor of the form X[0, :] -> X[0] + Remove Subtensor of the form: + 1. X[0, :] -> X[0] + 2. 
X[:] -> X
+
     """
-    if isinstance(node.op, Subtensor):
-        slices = get_idx_list(node.inputs, node.op.idx_list)
-        last_slice = len(slices)
-        for s in slices[::-1]:
-            # check if slice and then check slice indices
-            if (
-                isinstance(s, slice)
-                and s.start is None
-                and s.stop is None
-                and (
-                    s.step is None
-                    or extract_constant(s.step, only_process_constants=True) == 1
-                )
-            ):
-                last_slice -= 1
-            else:
-                break
-        # check if we removed something
-        if last_slice < len(slices):
-            subtens = Subtensor(slices[:last_slice])
-            sl_ins = get_slice_elements(
-                slices[:last_slice], lambda x: isinstance(x, Variable)
+    idxs = get_idx_list(node.inputs, node.op.idx_list)
+
+    if not idxs:
+        return [node.inputs[0]]
+
+    last_useless_slice = len(idxs)
+    for s in idxs[::-1]:
+        # check if slice and then check slice indices
+        if (
+            isinstance(s, slice)
+            and s.start is None
+            and s.stop is None
+            and (
+                s.step is None
+                or extract_constant(s.step, only_process_constants=True) == 1
             )
-            out = subtens(node.inputs[0], *sl_ins)
+        ):
+            last_useless_slice -= 1
+        else:
+            break
+    # check if we removed something
+    if last_useless_slice < len(idxs):
+        new_idxs = idxs[:last_useless_slice]
+        if new_idxs:
+            new_subtensor = Subtensor(new_idxs)
+            new_subtensor_inputs = get_slice_elements(
+                new_idxs, lambda x: isinstance(x, Variable)
             )
+            out = new_subtensor(node.inputs[0], *new_subtensor_inputs)
             # Copy over previous output stacktrace
             copy_stack_trace(node.outputs, out)
             return [out]
+        else:
+            # Subtensor is not needed at all
+            return [node.inputs[0]]
 
 
 # fast_compile to allow opt subtensor(cast{float32}(make_vector))
@@ -747,7 +759,13 @@ def local_subtensor_make_vector(fgraph, node):
     make_vector_op = x.owner.op
 
     if isinstance(node.op, Subtensor):
-        (idx,) = node.op.idx_list
+        idxs = node.op.idx_list
+
+        # Subtensor has no indexes, return make_vector
+        if not idxs:
+            return [x]
+
+        (idx,) = idxs
 
         if isinstance(idx, (aes.ScalarType, TensorType)):
             old_idx, idx = idx, node.inputs[1]
@@ -903,7 +921,11 @@ def local_set_to_inc_subtensor(fgraph, node):
 @node_rewriter([Subtensor])
 def local_useless_subtensor(fgraph, node):
     """Remove `Subtensor` if it takes the full input."""
-    # This optimization needs ShapeOpt and fgraph.shape_feature
+
+    if not node.op.idx_list:
+        return [node.inputs[0]]
+
+    # The more elaborate optimization needs ShapeOpt and fgraph.shape_feature
     if not hasattr(fgraph, "shape_feature"):
         return
 
@@ -1859,3 +1881,58 @@ def local_uint_constant_indices(fgraph, node):
 
     copy_stack_trace(node.outputs, new_outs)
     return new_outs
+
+
+@register_canonicalize("shape_unsafe")
+@register_stabilize("shape_unsafe")
+@register_specialize("shape_unsafe")
+@node_rewriter([Blockwise])
+def local_blockwise_advanced_inc_subtensor(fgraph, node):
+    """Rewrite blockwise advanced inc_subtensor without batched indexes as an inc_subtensor with prepended empty slices."""
+    if not isinstance(node.op.core_op, AdvancedIncSubtensor):
+        return None
+
+    x, y, *idxs = node.inputs
+
+    # It is currently not possible to vectorize such AdvancedIncSubtensor, but we check again just in case
+    if any(
+        (
+            isinstance(idx, (SliceType, NoneTypeT))
+            or (idx.type.dtype == "bool" and idx.type.ndim > 0)
+        )
+        for idx in idxs
+    ):
+        return None
+
+    op: Blockwise = node.op  # type: ignore
+    batch_ndim = op.batch_ndim(node)
+
+    new_idxs = []
+    for idx in idxs:
+        if all(idx.type.broadcastable[:batch_ndim]):
+            new_idxs.append(idx.squeeze(tuple(range(batch_ndim))))
+        else:
+            # Rewrite does not apply
+            return None
+
+    x_batch_bcast = x.type.broadcastable[:batch_ndim]
+
y_batch_bcast = y.type.broadcastable[:batch_ndim] + if any(xb and not yb for xb, yb in zip(x_batch_bcast, y_batch_bcast)): + # Need to broadcast batch x dims + batch_shape = tuple( + x_dim if (not xb or yb) else y_dim + for xb, x_dim, yb, y_dim in zip( + x_batch_bcast, + tuple(x.shape)[:batch_ndim], + y_batch_bcast, + tuple(y.shape)[:batch_ndim], + ) + ) + core_shape = tuple(x.shape)[batch_ndim:] + x = alloc(x, *batch_shape, *core_shape) + + new_idxs = [slice(None)] * batch_ndim + new_idxs + symbolic_idxs = x[tuple(new_idxs)].owner.inputs[1:] + new_out = op.core_op.make_node(x, y, *symbolic_idxs).outputs + copy_stack_trace(node.outputs, new_out) + return new_out diff --git a/pytensor/tensor/shape.py b/pytensor/tensor/shape.py index 1d8efa02c5..0d8dea8a2e 100644 --- a/pytensor/tensor/shape.py +++ b/pytensor/tensor/shape.py @@ -868,7 +868,8 @@ def shape_padleft(t, n_ones=1): """ _t = at.as_tensor_variable(t) - + if n_ones == 0: + return _t pattern = ["x"] * n_ones + list(range(_t.type.ndim)) return _t.dimshuffle(pattern) @@ -884,7 +885,8 @@ def shape_padright(t, n_ones=1): """ _t = at.as_tensor_variable(t) - + if n_ones == 0: + return _t pattern = list(range(_t.type.ndim)) + ["x"] * n_ones return _t.dimshuffle(pattern) diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py index c05e965bf8..de0862f443 100644 --- a/pytensor/tensor/subtensor.py +++ b/pytensor/tensor/subtensor.py @@ -13,6 +13,7 @@ from pytensor.gradient import DisconnectedType from pytensor.graph.basic import Apply, Constant, Variable from pytensor.graph.op import Op +from pytensor.graph.replace import _vectorize_node from pytensor.graph.type import Type from pytensor.graph.utils import MethodNotDefined from pytensor.link.c.op import COp @@ -22,6 +23,7 @@ from pytensor.scalar.basic import ScalarConstant from pytensor.tensor import _get_vector_length, as_tensor_variable, get_vector_length from pytensor.tensor.basic import alloc, get_underlying_scalar_constant_value, nonzero +from pytensor.tensor.blockwise import vectorize_node_fallback from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.exceptions import AdvancedIndexingError, NotScalarConstantError from pytensor.tensor.math import clip @@ -1283,6 +1285,21 @@ def _process(self, idxs, op_inputs, pstate): pprint.assign(Subtensor, SubtensorPrinter()) +# TODO: Implement similar vectorize for Inc/SetSubtensor +@_vectorize_node.register(Subtensor) +def vectorize_subtensor(op: Subtensor, node, batch_x, *batch_idxs): + """Rewrite subtensor with non-batched indexes as another Subtensor with prepended empty slices.""" + + # TODO: Vectorize Subtensor with non-slice batched indexes as AdvancedSubtensor + if any(batch_inp.type.ndim > 0 for batch_inp in batch_idxs): + return vectorize_node_fallback(op, node, batch_x, *batch_idxs) + + old_x, *_ = node.inputs + batch_ndims = batch_x.type.ndim - old_x.type.ndim + new_idx_list = (slice(None),) * batch_ndims + op.idx_list + return Subtensor(new_idx_list).make_node(batch_x, *batch_idxs) + + def set_subtensor(x, y, inplace=False, tolerate_inplace_aliasing=False): """ Return x with the given subtensor overwritten by y. diff --git a/tests/link/jax/test_nlinalg.py b/tests/link/jax/test_nlinalg.py index 98bfbb610c..5aec8f88c2 100644 --- a/tests/link/jax/test_nlinalg.py +++ b/tests/link/jax/test_nlinalg.py @@ -43,15 +43,6 @@ def test_jax_BatchedDot(): with pytest.raises(TypeError): pytensor_jax_fn(*inputs) - # matrix . 
matrix - a = matrix("a") - a.tag.test_value = np.linspace(-1, 1, 5 * 3).astype(config.floatX).reshape((5, 3)) - b = matrix("b") - b.tag.test_value = np.linspace(1, -1, 5 * 3).astype(config.floatX).reshape((5, 3)) - out = at_blas.BatchedDot()(a, b) - fgraph = FunctionGraph([a, b], [out]) - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) - def test_jax_basic_multiout(): rng = np.random.default_rng(213234) diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py index 7632aa6d33..92ab879e5c 100644 --- a/tests/link/numba/test_basic.py +++ b/tests/link/numba/test_basic.py @@ -843,23 +843,23 @@ def test_Softplus(x, exc): [ ( set_test_value( - at.dmatrix(), - rng.random(size=(3, 3)).astype("float64"), + at.dtensor3(), + rng.random(size=(2, 3, 3)).astype("float64"), ), set_test_value( - at.dmatrix(), - rng.random(size=(3, 3)).astype("float64"), + at.dtensor3(), + rng.random(size=(2, 3, 3)).astype("float64"), ), None, ), ( set_test_value( - at.dmatrix(), - rng.random(size=(3, 3)).astype("float64"), + at.dtensor3(), + rng.random(size=(2, 3, 3)).astype("float64"), ), set_test_value( - at.lmatrix(), - rng.poisson(size=(3, 3)).astype("int64"), + at.ltensor3(), + rng.poisson(size=(2, 3, 3)).astype("int64"), ), None, ), diff --git a/tests/tensor/rewriting/test_basic.py b/tests/tensor/rewriting/test_basic.py index 0e5c618ba0..cd5d3cc255 100644 --- a/tests/tensor/rewriting/test_basic.py +++ b/tests/tensor/rewriting/test_basic.py @@ -272,21 +272,36 @@ class TestLocalCanonicalizeAlloc: def setup_method(self): self.rng = np.random.default_rng(utt.fetch_seed()) - def test_inconsistent_shared(self): + @pytest.mark.parametrize("shape_unsafe", (True, False)) + def test_inconsistent_shared(self, shape_unsafe): # These shapes don't match! 
x = shared(self.rng.standard_normal((3, 7)))
         a = at.alloc(x, 6, 7)
 
         assert a.owner and isinstance(a.owner.op, Alloc)
 
-        f = function([], a, mode=rewrite_mode)
+        mode = rewrite_mode if shape_unsafe else rewrite_mode.excluding("shape_unsafe")
+        f = function([], a, mode=mode)
 
-        # The rewrite should then be applied, and remove Alloc
-        assert not any(isinstance(node.op, Alloc) for node in f.maker.fgraph.toposort())
-        assert any(isinstance(node.op, Assert) for node in f.maker.fgraph.toposort())
-
-        with pytest.raises(AssertionError):
-            f()
+        has_alloc = any(
+            isinstance(node.op, Alloc) for node in f.maker.fgraph.toposort()
+        )
+        if shape_unsafe:
+            assert not has_alloc
+            # Error raised by SpecifyShape that is introduced due to static shape inference
+            with pytest.raises(
+                AssertionError,
+                match="SpecifyShape: dim 0 of input has shape 3, expected 6.",
+            ):
+                f()
+        else:
+            assert has_alloc
+            # Error raised by Alloc Op
+            with pytest.raises(
+                ValueError,
+                match=r"could not broadcast input array from shape \(3,7\) into shape \(6,7\)",
+            ):
+                f()
 
         good_x_val = self.rng.standard_normal((6, 7))
         x.set_value(good_x_val)
diff --git a/tests/tensor/rewriting/test_blas.py b/tests/tensor/rewriting/test_blas.py
new file mode 100644
index 0000000000..efd18c3831
--- /dev/null
+++ b/tests/tensor/rewriting/test_blas.py
@@ -0,0 +1,48 @@
+import numpy as np
+import pytest
+
+from pytensor import function
+from pytensor.compile import get_default_mode
+from pytensor.tensor import matmul, tensor, vectorize
+from pytensor.tensor.blas import BatchedDot
+from pytensor.tensor.blockwise import Blockwise
+from pytensor.tensor.rewriting.blas import specialize_matmul_to_batched_dot
+
+
+@pytest.mark.parametrize("valid_case", (True, False))
+def test_specialize_matmul_to_batched_dot(valid_case):
+    signature = BatchedDot.gufunc_signature
+    rewrite = specialize_matmul_to_batched_dot.__name__
+
+    def core_pt(x, y):
+        return matmul(x, y)
+
+    def core_np(x, y):
+        return np.matmul(x, y)
+
+    x = tensor(shape=(7, 5, 3, 3))
+    if valid_case:
+        y = tensor(shape=(7, 5, 3, 3))
+    else:
+        y = tensor(shape=(5, 3, 3))
+
+    vectorize_pt = function(
+        [x, y],
+        vectorize(core_pt, signature=signature)(x, y),
+        mode=get_default_mode().including(rewrite),
+    )
+    blockwise_node = any(
+        isinstance(node.op, Blockwise) for node in vectorize_pt.maker.fgraph.apply_nodes
+    )
+    if valid_case:
+        assert not blockwise_node
+    else:
+        assert blockwise_node
+
+    x_test = np.random.normal(size=x.type.shape).astype(x.type.dtype)
+    y_test = np.random.normal(size=y.type.shape).astype(y.type.dtype)
+    vectorize_np = np.vectorize(core_np, signature=signature)
+    np.testing.assert_allclose(
+        vectorize_pt(x_test, y_test),
+        vectorize_np(x_test, y_test),
+    )
diff --git a/tests/tensor/rewriting/test_blockwise.py b/tests/tensor/rewriting/test_blockwise.py
index 0b67eba197..d5ea6e2b4e 100644
--- a/tests/tensor/rewriting/test_blockwise.py
+++ b/tests/tensor/rewriting/test_blockwise.py
@@ -1,7 +1,10 @@
+from functools import partial
+
 from pytensor import function
-from pytensor.graph import FunctionGraph
+from pytensor.graph import FunctionGraph, rewrite_graph
+from pytensor.graph.basic import equal_computations
 from pytensor.scalar import log as scalar_log
-from pytensor.tensor import matrix, tensor3
+from pytensor.tensor import add, alloc, matrix, tensor, tensor3
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.nlinalg import MatrixPinv
@@ -36,3 +39,82 @@ def test_useless_unbatched_blockwise():
     fn =
function([x], out, mode="FAST_COMPILE") assert isinstance(fn.maker.fgraph.outputs[0].owner.op, Blockwise) assert isinstance(fn.maker.fgraph.outputs[0].owner.op.core_op, MatrixPinv) + + +def test_blockwise_alloc(): + rewrite = partial( + rewrite_graph, + include=("ShapeOpt", "specialize"), + exclude=("local_useless_unbatched_blockwise",), + ) + + vector_add = Blockwise(core_op=add, signature="(x),(x)->(x)") + + # Depending on the rewrites the Alloc shape may be upcast to int64 or not + # We do not care about that for the purposes of this test + equal = partial(equal_computations, strict_dtype=False) + + # Case where Alloc is not necessary + x = tensor("x", shape=(7, 5)) + y = tensor("y", shape=(5,)) + out = vector_add(x, alloc(y, 7, 5)) + expected_out = vector_add(x, y) + assert equal([rewrite(out)], [expected_out]) + + # Cases where Alloc can be fully pushed + x = tensor("x", shape=(5,)) + y = tensor("y", shape=(5,)) + out = vector_add(x, alloc(y, 7, 5)) + expected_out = alloc(vector_add(x, y), 7, 5) + assert equal([rewrite(out)], [expected_out]) + + x = tensor("x", shape=(1, 5)) + y = tensor("y", shape=(5,)) + out = vector_add(x, alloc(y, 7, 5)) + expected_out = alloc(vector_add(x.squeeze(0), y), 7, 5) + assert equal([rewrite(out)], [expected_out]) + + x = tensor("x", shape=(7, 5)) + y = tensor("y", shape=(7, 5)) + out = vector_add(x, alloc(y, 3, 7, 5)) + expected_out = alloc(vector_add(x, y), 3, 7, 5) + assert equal([rewrite(out)], [expected_out]) + + x = tensor("x", shape=(5,)) + y = tensor("y", shape=(7, 1, 5)) + out = vector_add(x, alloc(y, 7, 2, 5)) + expected_out = alloc(vector_add(x, y), 7, 2, 5) + assert equal([rewrite(out)], [expected_out]) + + # Case where Alloc can be partially pushed + x = tensor("x", shape=(5,)) + y = tensor("y", shape=()) + out = vector_add(x, alloc(y, 7, 5)) + expected_out = alloc(vector_add(x, alloc(y, 5)), 7, 5) + assert equal([rewrite(out)], [expected_out]) + + x = tensor("x", shape=(5,)) + y = tensor("y", shape=(7, 1, 1)) + out = vector_add(x, alloc(y, 7, 2, 5)) + expected_out = alloc(vector_add(x, alloc(y, 7, 1, 5)), 7, 2, 5) + assert equal([rewrite(out)], [expected_out], strict_dtype=False) + + # Cases involving multiple Allocs being pushed + x = tensor("x", shape=()) + y = tensor("y", shape=()) + out = vector_add(alloc(x, 3, 1, 5), alloc(y, 7, 5)) + expected_out = alloc(vector_add(alloc(x, 5), alloc(y, 5)), 3, 7, 5) + assert equal([rewrite(out)], [expected_out]) + + x = tensor("x", shape=(5,)) + y = tensor("y", shape=()) + out = vector_add(alloc(x, 3, 1, 5), alloc(y, 7, 5)) + expected_out = alloc(vector_add(x, alloc(y, 5)), 3, 7, 5) + assert equal([rewrite(out)], [expected_out]) + + # Case where Alloc cannot be pushed + x = tensor("x", shape=(5,)) + y = tensor("y", shape=(1,)) + out = vector_add(x, alloc(y, 5)) + expected_out = out + assert equal([rewrite(out)], [expected_out]) diff --git a/tests/tensor/rewriting/test_subtensor.py b/tests/tensor/rewriting/test_subtensor.py index a5a643d0da..b77cdbe315 100644 --- a/tests/tensor/rewriting/test_subtensor.py +++ b/tests/tensor/rewriting/test_subtensor.py @@ -9,6 +9,7 @@ from pytensor.compile.mode import Mode, get_default_mode, get_mode from pytensor.compile.ops import DeepCopyOp from pytensor.configdefaults import config +from pytensor.graph import FunctionGraph, vectorize_graph from pytensor.graph.basic import Constant, Variable, ancestors from pytensor.graph.rewriting.basic import check_stack_trace from pytensor.graph.rewriting.db import RewriteDatabaseQuery @@ -17,10 +18,12 @@ from pytensor.raise_op 
import Assert from pytensor.tensor import inplace from pytensor.tensor.basic import Alloc, MakeVector, _convert_to_int8, make_vector +from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.math import Dot, add, dot, exp, sqr from pytensor.tensor.rewriting.subtensor import ( local_replace_AdvancedSubtensor, + local_subtensor_make_vector, local_subtensor_shape_constant, ) from pytensor.tensor.shape import ( @@ -764,6 +767,17 @@ def test_stack_trace(self): f = function([x, y, z], v_subtensor, mode=mode) assert check_stack_trace(f, ops_to_check="all") + def test_empty_subtensor(self): + x, y = lscalars("xy") + v = make_vector(x, y) + out = v[()] + + fgraph = FunctionGraph(outputs=[out], clone=False) + node = fgraph.outputs[0].owner + assert isinstance(node.op, Subtensor) + + assert local_subtensor_make_vector.transform(fgraph, node) == [v] + class TestLocalSubtensorLift: def test_basic(self): @@ -2301,3 +2315,98 @@ def test_local_uint_constant_indices(): new_index = subtensor_node.inputs[1] assert isinstance(new_index, Constant) assert new_index.type.dtype == "uint8" + + +@pytest.mark.parametrize("set_instead_of_inc", (True, False)) +def test_local_blockwise_advanced_inc_subtensor(set_instead_of_inc): + core_x = tensor("x", shape=(6,)) + core_y = tensor("y", shape=(3,)) + core_idxs = [0, 2, 4] + if set_instead_of_inc: + core_graph = set_subtensor(core_x[core_idxs], core_y) + else: + core_graph = inc_subtensor(core_x[core_idxs], core_y) + + # Only x is batched + x = tensor("x", shape=(5, 2, 6)) + y = tensor("y", shape=(3,)) + out = vectorize_graph(core_graph, replace={core_x: x, core_y: y}) + assert isinstance(out.owner.op, Blockwise) + + fn = pytensor.function([x, y], out, mode="FAST_RUN") + assert not any( + isinstance(node.op, Blockwise) for node in fn.maker.fgraph.apply_nodes + ) + + test_x = np.ones(x.type.shape, dtype=x.type.dtype) + test_y = np.array([5, 6, 7]).astype(dtype=core_y.type.dtype) + expected_out = test_x.copy() + if set_instead_of_inc: + expected_out[:, :, core_idxs] = test_y + else: + expected_out[:, :, core_idxs] += test_y + np.testing.assert_allclose(fn(test_x, test_y), expected_out) + + # Only y is batched + x = tensor("y", shape=(6,)) + y = tensor("y", shape=(2, 3)) + out = vectorize_graph(core_graph, replace={core_x: x, core_y: y}) + assert isinstance(out.owner.op, Blockwise) + + fn = pytensor.function([x, y], out, mode="FAST_RUN") + assert not any( + isinstance(node.op, Blockwise) for node in fn.maker.fgraph.apply_nodes + ) + + test_x = np.ones(x.type.shape, dtype=x.type.dtype) + test_y = np.array([[3, 3, 3], [5, 6, 7]]).astype(dtype=core_y.type.dtype) + expected_out = np.ones((2, *x.type.shape)) + if set_instead_of_inc: + expected_out[:, core_idxs] = test_y + else: + expected_out[:, core_idxs] += test_y + np.testing.assert_allclose(fn(test_x, test_y), expected_out) + + # Both x and y are batched, and do not need to be broadcasted + x = tensor("y", shape=(2, 6)) + y = tensor("y", shape=(2, 3)) + out = vectorize_graph(core_graph, replace={core_x: x, core_y: y}) + assert isinstance(out.owner.op, Blockwise) + + fn = pytensor.function([x, y], out, mode="FAST_RUN") + assert not any( + isinstance(node.op, Blockwise) for node in fn.maker.fgraph.apply_nodes + ) + + test_x = np.ones(x.type.shape, dtype=x.type.dtype) + test_y = np.array([[5, 6, 7], [3, 3, 3]]).astype(dtype=core_y.type.dtype) + expected_out = test_x.copy() + if set_instead_of_inc: + expected_out[:, core_idxs] = test_y + else: + 
expected_out[:, core_idxs] += test_y + np.testing.assert_allclose(fn(test_x, test_y), expected_out) + + # Both x and y are batched, but must be broadcasted + x = tensor("y", shape=(5, 1, 6)) + y = tensor("y", shape=(1, 2, 3)) + out = vectorize_graph(core_graph, replace={core_x: x, core_y: y}) + assert isinstance(out.owner.op, Blockwise) + + fn = pytensor.function([x, y], out, mode="FAST_RUN") + assert not any( + isinstance(node.op, Blockwise) for node in fn.maker.fgraph.apply_nodes + ) + + test_x = np.ones(x.type.shape, dtype=x.type.dtype) + test_y = np.array([[[5, 6, 7], [3, 3, 3]]]).astype(dtype=core_y.type.dtype) + final_shape = ( + *np.broadcast_shapes(x.type.shape[:-1], y.type.shape[:-1]), + x.type.shape[-1], + ) + expected_out = np.broadcast_to(test_x, final_shape).copy() + if set_instead_of_inc: + expected_out[:, :, core_idxs] = test_y + else: + expected_out[:, :, core_idxs] += test_y + np.testing.assert_allclose(fn(test_x, test_y), expected_out) diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index 2c2b82d1b5..3ce5ffce63 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -20,7 +20,7 @@ from pytensor.misc.safe_asarray import _asarray from pytensor.raise_op import Assert from pytensor.scalar import autocast_float, autocast_float_as -from pytensor.tensor import NoneConst +from pytensor.tensor import NoneConst, vectorize from pytensor.tensor.basic import ( Alloc, AllocEmpty, @@ -88,6 +88,7 @@ vertical_stack, zeros_like, ) +from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.math import dense_dot @@ -4517,3 +4518,26 @@ def test_trace(): trace(x, offset=-1, axis1=0, axis2=-1).eval(), np.trace(x_val, offset=-1, axis1=0, axis2=-1), ) + + +def test_vectorize_extract_diag(): + signature = "(a1,b,a2)->(b,a)" + + def core_pt(x): + return at.diagonal(x, offset=1, axis1=0, axis2=2) + + def core_np(x): + return np.diagonal(x, offset=1, axis1=0, axis2=2) + + x = tensor(shape=(5, 5, 5, 5)) + vectorize_pt = function([x], vectorize(core_pt, signature=signature)(x)) + assert not any( + isinstance(node.op, Blockwise) for node in vectorize_pt.maker.fgraph.apply_nodes + ) + + x_test = np.random.normal(size=x.type.shape).astype(x.type.dtype) + vectorize_np = np.vectorize(core_np, signature=signature) + np.testing.assert_allclose( + vectorize_pt(x_test), + vectorize_np(x_test), + ) diff --git a/tests/tensor/test_blockwise.py b/tests/tensor/test_blockwise.py index a0533143ed..ac0a3c542e 100644 --- a/tests/tensor/test_blockwise.py +++ b/tests/tensor/test_blockwise.py @@ -1,3 +1,4 @@ +import re from itertools import product from typing import Optional, Union @@ -6,13 +7,15 @@ import pytensor from pytensor import config, function +from pytensor.compile import get_mode from pytensor.gradient import grad from pytensor.graph import Apply, Op from pytensor.graph.replace import vectorize_node from pytensor.raise_op import assert_op from pytensor.tensor import diagonal, log, tensor -from pytensor.tensor.blockwise import Blockwise +from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback from pytensor.tensor.nlinalg import MatrixInverse +from pytensor.tensor.rewriting.blas import specialize_matmul_to_batched_dot from pytensor.tensor.slinalg import Cholesky, Solve, cholesky, solve_triangular from pytensor.tensor.utils import _parse_gufunc_signature @@ -40,12 +43,29 @@ def test_vectorize_blockwise(): assert new_vect_node.inputs[0] is 
tns4 +def test_vectorize_node_fallback_unsupported_type(): + x = tensor("x", shape=(2, 6)) + node = x[:, [0, 2, 4]].owner + + with pytest.raises( + NotImplementedError, + match=re.escape( + "Cannot vectorize node AdvancedSubtensor(x, MakeSlice.0, [0 2 4]) with input MakeSlice.0 of type slice" + ), + ): + vectorize_node_fallback(node.op, node, node.inputs) + + def check_blockwise_runtime_broadcasting(mode): a = tensor("a", shape=(None, 3, 5)) b = tensor("b", shape=(None, 5, 3)) out = a @ b - fn = function([a, b], out, mode=mode) + fn = function( + [a, b], + out, + mode=get_mode(mode).excluding(specialize_matmul_to_batched_dot.__name__), + ) assert isinstance(fn.maker.fgraph.outputs[0].owner.op, Blockwise) for valid_test_values in [ @@ -293,7 +313,7 @@ def test_grad(self): pt_out, np_out, rtol=1e-7 if config.floatX == "float64" else 1e-5, - atol=1e-6 if config.floatX == "float64" else 1e-5, + atol=1e-6 if config.floatX == "float64" else 1e-4, ) diff --git a/tests/tensor/test_subtensor.py b/tests/tensor/test_subtensor.py index 188c959bbc..9ee39a4a98 100644 --- a/tests/tensor/test_subtensor.py +++ b/tests/tensor/test_subtensor.py @@ -9,6 +9,7 @@ import pytensor import pytensor.scalar as scal import pytensor.tensor.basic as at +from pytensor import function from pytensor.compile import DeepCopyOp, shared from pytensor.compile.io import In from pytensor.configdefaults import config @@ -16,7 +17,8 @@ from pytensor.graph.rewriting.utils import is_same_graph from pytensor.printing import pprint from pytensor.scalar.basic import as_scalar -from pytensor.tensor import get_vector_length +from pytensor.tensor import get_vector_length, vectorize +from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.math import exp, isinf from pytensor.tensor.math import sum as at_sum @@ -389,7 +391,8 @@ def test_0_dims(self): t = Subtensor([])(n) assert isinstance(t.owner.op, Subtensor) self.eval_output_and_check( - t, mode=self.mode.excluding("local_useless_subtensor") + t, + mode=self.mode.excluding("local_useless_subtensor", "local_useless_slice"), ) def test_err_invalid_2(self): @@ -2708,3 +2711,43 @@ def test_static_shapes(x_shape, indices, expected): x = at.tensor(dtype="float64", shape=x_shape) y = x[indices] assert y.type.shape == expected + + +def test_vectorize_subtensor_without_batch_indices(): + signature = "(t1,t2,t3),()->(t1,t3)" + + def core_fn(x, start): + return x[:, start, :] + + x = tensor(shape=(11, 7, 5, 3)) + start = tensor(shape=(), dtype="int") + vectorize_pt = function( + [x, start], vectorize(core_fn, signature=signature)(x, start) + ) + assert not any( + isinstance(node.op, Blockwise) for node in vectorize_pt.maker.fgraph.apply_nodes + ) + x_test = np.random.normal(size=x.type.shape).astype(x.type.dtype) + start_test = np.random.randint(0, x.type.shape[-2]) + vectorize_np = np.vectorize(core_fn, signature=signature) + np.testing.assert_allclose( + vectorize_pt(x_test, start_test), + vectorize_np(x_test, start_test), + ) + + # If we vectorize start, we should get a Blockwise that still works + x = tensor(shape=(11, 7, 5, 3)) + start = tensor(shape=(11,), dtype="int") + vectorize_pt = function( + [x, start], vectorize(core_fn, signature=signature)(x, start) + ) + assert any( + isinstance(node.op, Blockwise) for node in vectorize_pt.maker.fgraph.apply_nodes + ) + x_test = np.random.normal(size=x.type.shape).astype(x.type.dtype) + start_test = np.random.randint(0, x.type.shape[-2], size=start.type.shape[0]) + vectorize_np = 
np.vectorize(core_fn, signature=signature) + np.testing.assert_allclose( + vectorize_pt(x_test, start_test), + vectorize_np(x_test, start_test), + )
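
A minimal usage sketch of the `batched_dot` behavior this patch settles on (illustrative only, not part of the diff; variable names are hypothetical and it assumes a PyTensor build with this patch applied). `BatchedDot` itself now only accepts tensor3 inputs, while the `batched_dot` helper pads 2D operands with `expand_dims` and squeezes the result, so the old matrix cases keep working:

    import numpy as np

    import pytensor
    import pytensor.tensor as at
    from pytensor.tensor.blas import batched_dot

    # A stack of matrices times a stack of vectors: batched_dot expands b to
    # (batch, k, 1), runs the tensor3-only BatchedDot, then squeezes the result.
    a = at.tensor3("a")  # (batch, m, k)
    b = at.matrix("b")   # (batch, k)
    out = batched_dot(a, b)  # (batch, m)

    fn = pytensor.function([a, b], out)
    a_val = np.random.normal(size=(2, 3, 4)).astype(a.type.dtype)
    b_val = np.random.normal(size=(2, 4)).astype(b.type.dtype)
    np.testing.assert_allclose(
        fn(a_val, b_val),
        np.matmul(a_val, b_val[..., None]).squeeze(-1),
    )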