From 4c0354d4b874dc133e457c263a123a18920c9082 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 5 Jan 2024 12:24:47 +0100 Subject: [PATCH 1/3] Fix vectorize_node function name --- pytensor/tensor/math.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py index 45c926f501..6abb7f2a5f 100644 --- a/pytensor/tensor/math.py +++ b/pytensor/tensor/math.py @@ -2948,9 +2948,11 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None @_vectorize_node.register(Dot) -def vectorize_node_to_matmul(op, node, batched_x, batched_y): +def vectorize_node_dot_to_matmul(op, node, batched_x, batched_y): old_x, old_y = node.inputs if old_x.type.ndim == 2 and old_y.type.ndim == 2: + # If original input is equivalent to a matrix-matrix product, + # return specialized Matmul Op to avoid unnecessary new Ops. return matmul(batched_x, batched_y).owner else: return vectorize_node_fallback(op, node, batched_x, batched_y) From a91908174603df781a61d746fa3293044e1c1693 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 5 Jan 2024 12:25:29 +0100 Subject: [PATCH 2/3] Remove useless try/except All types of axis arguments are supported by max_and_argmax --- pytensor/tensor/math.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py index 6abb7f2a5f..a356896e8c 100644 --- a/pytensor/tensor/math.py +++ b/pytensor/tensor/math.py @@ -671,13 +671,7 @@ def max(x, axis=None, keepdims=False): # thing is supporting all user interface features, not speed. # Some cases can be implemented only with CAReduce. - # We thus prefer to use MaxAndArgmax, if possible. It does not - # support all axis arguments, so we may need to fall back to CAReduce. 
- - try: - out = max_and_argmax(x, axis)[0] - except Exception: - out = Max(axis)(x) + out = max_and_argmax(x, axis)[0] if keepdims: out = makeKeepDims(x, out, axis) From 1606171fbcc35ce9a808f39aac5835d84808293c Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 5 Jan 2024 12:26:31 +0100 Subject: [PATCH 3/3] Implement vectorize_node for Softmax and Argmax Ops Also refactors shared logic for other batch axed Ops --- pytensor/tensor/basic.py | 22 +++++++++++----- pytensor/tensor/elemwise.py | 50 +++++++++++++++++++++++++----------- pytensor/tensor/math.py | 23 +++++++++++++++-- pytensor/tensor/shape.py | 13 ++++++---- pytensor/tensor/special.py | 28 ++++++++++++++++++++ tests/tensor/test_math.py | 30 ++++++++++++++++++++++ tests/tensor/test_special.py | 32 ++++++++++++++++++++++- 7 files changed, 169 insertions(+), 29 deletions(-) diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 207fd4909a..2a39eca520 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -43,7 +43,12 @@ get_vector_length, ) from pytensor.tensor.blockwise import Blockwise -from pytensor.tensor.elemwise import DimShuffle, Elemwise, scalar_elemwise +from pytensor.tensor.elemwise import ( + DimShuffle, + Elemwise, + get_normalized_batch_axes, + scalar_elemwise, +) from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.shape import ( Shape, @@ -3614,13 +3619,18 @@ def diagonal(a, offset=0, axis1=0, axis2=1): @_vectorize_node.register(ExtractDiag) -def vectorize_extract_diag(op: ExtractDiag, node, batched_x): - batched_ndims = batched_x.type.ndim - node.inputs[0].type.ndim +def vectorize_extract_diag(op: ExtractDiag, node, batch_x): + core_ndim = node.inputs[0].type.ndim + batch_ndim = batch_x.type.ndim - core_ndim + batch_axis1, batch_axis2 = get_normalized_batch_axes( + (op.axis1, op.axis2), core_ndim, batch_ndim + ) + return diagonal( - batched_x, + batch_x, offset=op.offset, - axis1=op.axis1 + batched_ndims, - axis2=op.axis2 + batched_ndims, + axis1=batch_axis1, + axis2=batch_axis2, ).owner diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py index 6c01d574d8..42d1ae8768 100644 --- a/pytensor/tensor/elemwise.py +++ b/pytensor/tensor/elemwise.py @@ -1,6 +1,8 @@ from copy import copy +from typing import Union import numpy as np +from numpy.core.numeric import normalize_axis_tuple import pytensor.tensor.basic from pytensor.configdefaults import config @@ -1399,7 +1401,7 @@ def make_node(self, input): # scalar inputs are treated as 1D regarding axis in this `Op` if axis is not None: try: - axis = np.core.numeric.normalize_axis_tuple(axis, ndim=max(1, inp_dims)) + axis = normalize_axis_tuple(axis, ndim=max(1, inp_dims)) except np.AxisError: raise np.AxisError(axis, ndim=inp_dims) @@ -1757,18 +1759,36 @@ def vectorize_dimshuffle(op: DimShuffle, node: Apply, x: TensorVariable) -> Appl return DimShuffle(input_broadcastable, new_order).make_node(x) -@_vectorize_node.register(CAReduce) -def vectorize_careduce(op: CAReduce, node: Apply, x: TensorVariable) -> Apply: - batched_ndims = x.type.ndim - node.inputs[0].type.ndim - if not batched_ndims: - return node.op.make_node(x) - axes = op.axis - # e.g., sum(matrix, axis=None) -> sum(tensor4, axis=(2, 3)) - # e.g., sum(matrix, axis=0) -> sum(tensor4, axis=(2,)) - if axes is None: - axes = list(range(node.inputs[0].type.ndim)) +def get_normalized_batch_axes( + core_axes: Union[None, int, tuple[int, ...]], + core_ndim: int, + batch_ndim: int, +) -> tuple[int, ...]: + """Compute batch axes for a batched 
operation, from the core input ndim and axes. + + e.g., sum(matrix, axis=None) -> sum(tensor4, axis=(2, 3)) + batch_axes(None, 2, 4) -> (2, 3) + + e.g., sum(matrix, axis=0) -> sum(tensor4, axis=(2,)) + batch_axes(0, 2, 4) -> (2,) + + e.g., sum(tensor3, axis=(0, -1)) -> sum(tensor4, axis=(1, 3)) + batch_axes((0, -1), 3, 4) -> (1, 3) + """ + if core_axes is None: + core_axes = tuple(range(core_ndim)) else: - axes = list(axes) - new_axes = [axis + batched_ndims for axis in axes] - new_op = op.clone(axis=new_axes) - return new_op.make_node(x) + core_axes = normalize_axis_tuple(core_axes, core_ndim) + return tuple(core_axis + batch_ndim for core_axis in core_axes) + + +@_vectorize_node.register(CAReduce) +def vectorize_careduce(op: CAReduce, node: Apply, batch_x: TensorVariable) -> Apply: + core_ndim = node.inputs[0].type.ndim + batch_ndim = batch_x.type.ndim - core_ndim + + if not batch_ndim: + return node.op.make_node(batch_x) + + batch_axes = get_normalized_batch_axes(op.axis, core_ndim, batch_ndim) + return op.clone(axis=batch_axes).make_node(batch_x) diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py index a356896e8c..a1774d9dac 100644 --- a/pytensor/tensor/math.py +++ b/pytensor/tensor/math.py @@ -27,7 +27,13 @@ switch, ) from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback -from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise, scalar_elemwise +from pytensor.tensor.elemwise import ( + CAReduce, + DimShuffle, + Elemwise, + get_normalized_batch_axes, + scalar_elemwise, +) from pytensor.tensor.shape import shape, specify_broadcastable from pytensor.tensor.type import ( DenseTensorType, @@ -134,7 +140,7 @@ class MaxAndArgmax(COp): _f16_ok = True def __init__(self, axis): - assert isinstance(axis, list) + assert isinstance(axis, (tuple, list)) self.axis = tuple(axis) def get_params(self, node): @@ -465,6 +471,19 @@ def grad(self, inp, grads): return [x.zeros_like()] +@_vectorize_node.register(Argmax) +@_vectorize_node.register(MaxAndArgmax) +def vectorize_argmax_node(op, node, batch_x): + core_ndim = node.inputs[0].type.ndim + batch_ndim = batch_x.type.ndim - core_ndim + + if not batch_ndim: + return node.op.make_node(batch_x) + + batch_axes = get_normalized_batch_axes(op.axis, core_ndim, batch_ndim) + return type(op)(axis=batch_axes).make_node(batch_x) + + def makeKeepDims(x, y, axis): """ Reintroduces in y with length one the axes of x which have been left out diff --git a/pytensor/tensor/shape.py b/pytensor/tensor/shape.py index 1a83a41122..a76003c90c 100644 --- a/pytensor/tensor/shape.py +++ b/pytensor/tensor/shape.py @@ -18,6 +18,7 @@ from pytensor.tensor import _get_vector_length, as_tensor_variable from pytensor.tensor import basic as ptb from pytensor.tensor import get_vector_length +from pytensor.tensor.elemwise import get_normalized_batch_axes from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.type import DenseTensorType, TensorType, int_dtypes, tensor from pytensor.tensor.type_other import NoneConst @@ -1103,8 +1104,10 @@ def unbroadcast(x, *axes): @_vectorize_node.register(Unbroadcast) -def _vectorize_unbroadcast(op: Unbroadcast, node: Apply, x: TensorVariable) -> Apply: - batched_ndims = x.type.ndim - node.inputs[0].type.ndim - old_axes = op.axes - new_axes = (old_axis + batched_ndims for old_axis in old_axes) - return cast(Apply, unbroadcast(x, *new_axes).owner) +def _vectorize_unbroadcast( + op: Unbroadcast, node: Apply, batch_x: TensorVariable +) -> Apply: + core_ndim = node.inputs[0].type.ndim + 
batch_ndim = batch_x.type.ndim - core_ndim + batch_axes = get_normalized_batch_axes(op.axes, core_ndim, batch_ndim) + return cast(Apply, unbroadcast(batch_x, *batch_axes).owner) diff --git a/pytensor/tensor/special.py b/pytensor/tensor/special.py index 7b5e52d637..f9d37a3a3f 100644 --- a/pytensor/tensor/special.py +++ b/pytensor/tensor/special.py @@ -4,8 +4,10 @@ import scipy from pytensor.graph.basic import Apply +from pytensor.graph.replace import _vectorize_node from pytensor.link.c.op import COp from pytensor.tensor.basic import as_tensor_variable +from pytensor.tensor.elemwise import get_normalized_batch_axes from pytensor.tensor.math import gamma, gammaln, neg, sum @@ -736,6 +738,32 @@ def log_softmax(c, axis=None): return LogSoftmax(axis=axis)(c) +@_vectorize_node.register(Softmax) +@_vectorize_node.register(LogSoftmax) +def vectorize_softmax_node(op, node, batched_x): + """ + Vectorize Softmax and LogSoftmax nodes. + + """ + core_ndim = node.inputs[0].type.ndim + batch_ndim = batched_x.type.ndim - core_ndim + + if not batch_ndim: + return op.make_node(batched_x) + + batch_axes = get_normalized_batch_axes(op.axis, core_ndim, batch_ndim) + + if len(batch_axes) > 1: + from pytensor.tensor.blockwise import vectorize_node_fallback + + # The softmax Ops only allow a specific axis (integer) or all axis (None). + # If the vectorized operation requires more than one axis we have to default to a Blockwise + return vectorize_node_fallback(op, node, batched_x) + + [batch_axis] = batch_axes + return type(op)(axis=batch_axis).make_node(batched_x) + + def poch(z, m): """ Pochhammer symbol (rising factorial) function. diff --git a/tests/tensor/test_math.py b/tests/tensor/test_math.py index d543019f8d..a3a6be4235 100644 --- a/tests/tensor/test_math.py +++ b/tests/tensor/test_math.py @@ -20,6 +20,7 @@ from pytensor.gradient import NullTypeGradError, grad, numeric_grad from pytensor.graph.basic import Variable, applys_between from pytensor.graph.fg import FunctionGraph +from pytensor.graph.replace import vectorize_node from pytensor.link.c.basic import DualLinker from pytensor.misc.safe_asarray import _asarray from pytensor.printing import pprint @@ -1010,6 +1011,35 @@ def test_numpy_input(self): assert max_pt.eval() == 3 assert argmax_pt.eval() == 2 + @pytest.mark.parametrize( + "core_axis, batch_axis", + [ + (None, (1, 2, 3, 4)), + (0, (1,)), + ((1, -1), (2, 4)), + ], + ) + def test_vectorize(self, core_axis, batch_axis): + x = tensor(shape=(5, 5, 5, 5)) + batch_x = tensor(shape=(3, 5, 5, 5, 5)) + + # Test MaxAndArgmax + max_x, argmax_x = max_and_argmax(x, axis=core_axis) + node = max_x.owner + assert isinstance(node.op, MaxAndArgmax) + + new_node = vectorize_node(node, batch_x) + assert isinstance(new_node.op, MaxAndArgmax) + assert new_node.op.axis == batch_axis + + # Test Argmax + # Argmax is not user-facing, so we have to create it manually + node = Argmax(axis=node.op.axis).make_node(x) + + new_node = vectorize_node(node, batch_x) + assert isinstance(new_node.op, Argmax) + assert new_node.op.axis == batch_axis + class TestArgminArgmax: def setup_method(self): diff --git a/tests/tensor/test_special.py b/tests/tensor/test_special.py index a7448f1d86..298df728ca 100644 --- a/tests/tensor/test_special.py +++ b/tests/tensor/test_special.py @@ -8,6 +8,8 @@ from pytensor.compile.function import function from pytensor.configdefaults import config +from pytensor.graph.replace import vectorize_node +from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.special import ( LogSoftmax, 
Softmax, @@ -19,7 +21,7 @@ poch, softmax, ) -from pytensor.tensor.type import matrix, tensor3, tensor4, vector, vectors +from pytensor.tensor.type import matrix, tensor, tensor3, tensor4, vector, vectors from tests import unittest_tools as utt from tests.tensor.utils import random_ranged @@ -150,6 +152,34 @@ def test_valid_axis(self): SoftmaxGrad(-4)(*x) +@pytest.mark.parametrize( + "core_axis, batch_axis", + [ + (None, (1, 2, 3, 4)), + (0, (1,)), + ], +) +@pytest.mark.parametrize( + "op, constructor", [(Softmax, softmax), (LogSoftmax, log_softmax)] +) +def test_vectorize_softmax(op, constructor, core_axis, batch_axis): + x = tensor(shape=(5, 5, 5, 5)) + batch_x = tensor(shape=(3, 5, 5, 5, 5)) + + node = constructor(x, axis=core_axis).owner + assert isinstance(node.op, op) + + new_node = vectorize_node(node, batch_x) + if len(batch_axis) == 1: + assert isinstance(new_node.op, op) + assert (new_node.op.axis,) == batch_axis + else: + assert isinstance(new_node.op, Blockwise) and isinstance( + new_node.op.core_op, op + ) + assert new_node.op.core_op.axis == core_axis + + def test_poch(): _z, _m = vectors("z", "m") actual_fn = function([_z, _m], poch(_z, _m))
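
A minimal usage sketch (not part of the patch series) of the behaviour added in PATCH 3/3, assuming all three patches are applied. It only exercises the public vectorize_node helper that the new tests already use; the shapes and expected axes are taken from tests/tensor/test_special.py and tests/tensor/test_math.py.

from pytensor.graph.replace import vectorize_node
from pytensor.tensor.special import softmax
from pytensor.tensor.type import tensor

# Core graph: Softmax along axis 0 of a 4d input.
x = tensor(shape=(5, 5, 5, 5))
batch_x = tensor(shape=(3, 5, 5, 5, 5))  # same input with one leading batch dim

node = softmax(x, axis=0).owner
new_node = vectorize_node(node, batch_x)

# get_normalized_batch_axes shifts the core axis by the number of batch
# dimensions: axis 0 of the 4d core input becomes axis 1 of the 5d batched
# input, so the result is a plain Softmax(axis=1) rather than a Blockwise
# fallback (the fallback is only used when more than one axis would be needed).
print(type(new_node.op).__name__, new_node.op.axis)  # -> Softmax 1

# The same shift applies to MaxAndArgmax/Argmax: axis=(1, -1) on the 4d core
# input becomes axis=(2, 4) on the batched input (see tests/tensor/test_math.py).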