diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py
index b6f806bb4c..b39383a2bd 100644
--- a/pytensor/link/numba/dispatch/elemwise.py
+++ b/pytensor/link/numba/dispatch/elemwise.py
@@ -44,7 +44,18 @@
 )
 from pytensor.scalar.basic import add as add_as
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
-from pytensor.tensor.math import Argmax, MulWithoutZeros, Sum
+from pytensor.tensor.math import (
+    All,
+    Any,
+    Argmax,
+    Max,
+    Mean,
+    Min,
+    MulWithoutZeros,
+    Prod,
+    ProdWithoutZeros,
+    Sum,
+)
 from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
 from pytensor.tensor.type import scalar
 
@@ -546,37 +555,52 @@ def ov_elemwise(*inputs):
 
 
 @numba_funcify.register(Sum)
-def numba_funcify_Sum(op, node, **kwargs):
+@numba_funcify.register(Prod)
+@numba_funcify.register(ProdWithoutZeros)
+@numba_funcify.register(Max)
+@numba_funcify.register(Min)
+@numba_funcify.register(All)
+@numba_funcify.register(Any)
+def numba_funcify_CAReduce_specialized(op, node, **kwargs):
+    if isinstance(op, ProdWithoutZeros):
+        # ProdWithoutZeros is the same as Prod, but the gradient can assume no zeros
+        np_op = np.prod
+    else:
+        np_op = getattr(np, op.__class__.__name__.lower())
+
     axes = op.axis
     if axes is None:
         axes = list(range(node.inputs[0].ndim))
 
-    axes = tuple(axes)
+    axes = tuple(sorted(axes))
 
     ndim_input = node.inputs[0].ndim
+    out_dtype = np.dtype(node.outputs[0].dtype)
 
-    if hasattr(op, "acc_dtype") and op.acc_dtype is not None:
-        acc_dtype = op.acc_dtype
-    else:
-        acc_dtype = node.outputs[0].type.dtype
-
-    np_acc_dtype = np.dtype(acc_dtype)
+    if len(axes) == 0:
 
-    out_dtype = np.dtype(node.outputs[0].dtype)
+        @numba_njit(fastmath=True)
+        def impl_sum(array):
+            return np.asarray(array, dtype=out_dtype)
 
-    if ndim_input == len(axes):
+    elif (
+        len(axes) == 1
+        # Some Ops don't support the axis argument in Numba
+        and not isinstance(op, Prod | ProdWithoutZeros | All | Any | Mean | Max | Min)
+    ):
 
         @numba_njit(fastmath=True)
         def impl_sum(array):
-            return np.asarray(array.sum(), dtype=np_acc_dtype).astype(out_dtype)
+            return np.asarray(np_op(array, axis=axes[0])).astype(out_dtype)
 
-    elif len(axes) == 0:
+    elif len(axes) == ndim_input:
 
         @numba_njit(fastmath=True)
        def impl_sum(array):
-            return np.asarray(array, dtype=out_dtype)
+            return np.asarray(np_op(array)).astype(out_dtype)
 
     else:
+        # Slow path: fall back to the generic CAReduce implementation
         impl_sum = numba_funcify_CAReduce(op, node, **kwargs)
 
     return impl_sum
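
Not part of the patch: a minimal sketch of how the specialized dispatch above can be exercised end to end through the Numba backend; the graph, values, and shapes are illustrative only.

```python
# Illustrative only: compile a single-axis Sum through the Numba backend and
# check it against NumPy. A reduction like this should hit the axis-aware
# fast path added in numba_funcify_CAReduce_specialized.
import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
out = pt.sum(x, axis=0)

fn = pytensor.function([x], out, mode="NUMBA")
x_val = np.arange(6, dtype=pytensor.config.floatX).reshape(3, 2)
np.testing.assert_allclose(fn(x_val), x_val.sum(axis=0))
```
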
diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py
index 8619b124be..410ba18e87 100644
--- a/pytensor/tensor/math.py
+++ b/pytensor/tensor/math.py
@@ -1304,12 +1304,18 @@ def complex_from_polar(abs, angle):
 
 
 class Mean(FixedOpCAReduce):
+    # FIXME: Mean is not a true CAReduce in the PyTensor sense, because it needs to keep
+    # track of the number of elements already reduced in order to work iteratively.
+    # This should subclass a `ReduceOp`, which `CAReduce` could also inherit from.
     __props__ = ("axis",)
     nfunc_spec = ("mean", 1, 1)
 
     def __init__(self, axis=None):
         super().__init__(ps.mean, axis)
-        assert self.axis is None or len(self.axis) == 1
+        if not (self.axis is None or len(self.axis) == 1):
+            raise NotImplementedError(
+                "Mean Op only supports axis=None or a single axis. Use the `mean` function instead."
+            )
 
     def __str__(self):
         if self.axis is not None:
diff --git a/tests/link/numba/test_elemwise.py b/tests/link/numba/test_elemwise.py
index 8bbbe164fc..726893089f 100644
--- a/tests/link/numba/test_elemwise.py
+++ b/tests/link/numba/test_elemwise.py
@@ -236,157 +236,21 @@ def test_Dimshuffle_non_contiguous():
     assert func(np.zeros(3), np.array([1])).ndim == 0
 
 
-@pytest.mark.parametrize(
-    "careduce_fn, axis, v",
-    [
-        (
-            lambda x, axis=None, dtype=None, acc_dtype=None: Sum(
-                axis=axis, dtype=dtype, acc_dtype=acc_dtype
-            )(x),
-            0,
-            set_test_value(pt.vector(), np.arange(3, dtype=config.floatX)),
-        ),
-        (
-            lambda x, axis=None, dtype=None, acc_dtype=None: All(axis)(x),
-            0,
-            set_test_value(pt.vector(), np.arange(3, dtype=config.floatX)),
-        ),
-        (
-            lambda x, axis=None, dtype=None, acc_dtype=None: Any(axis)(x),
-            0,
-            set_test_value(pt.vector(), np.arange(3, dtype=config.floatX)),
-        ),
-        (
-            lambda x, axis=None, dtype=None, acc_dtype=None: Mean(axis)(x),
-            0,
-            set_test_value(pt.vector(), np.arange(3, dtype=config.floatX)),
-        ),
-        (
-            lambda x, axis=None, dtype=None, acc_dtype=None: Mean(axis)(x),
-            0,
-            set_test_value(
-                pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
-            ),
-        ),
-        (
-            lambda x, axis=None, dtype=None, acc_dtype=None: Sum(
-                axis=axis, dtype=dtype, acc_dtype=acc_dtype
-            )(x),
-            0,
-            set_test_value(
-                pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
-            ),
-        ),
-        (
-            lambda x, axis=None, dtype=None, acc_dtype=None: Sum(
-                axis=axis, dtype=dtype, acc_dtype=acc_dtype
-            )(x),
-            (0, 1),
-            set_test_value(
-                pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
-            ),
-        ),
-        (
-            lambda x, axis=None, dtype=None, acc_dtype=None: Sum(
-                axis=axis, dtype=dtype, acc_dtype=acc_dtype
-            )(x),
-            (1, 0),
-            set_test_value(
-                pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
-            ),
-        ),
-        (
-            lambda x, axis=None, dtype=None, acc_dtype=None: Sum(
-                axis=axis, dtype=dtype, acc_dtype=acc_dtype
-            )(x),
-            None,
-            set_test_value(
-                pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
-            ),
-        ),
-        (
-            lambda x, axis=None, dtype=None, acc_dtype=None: Sum(
-                axis=axis, dtype=dtype, acc_dtype=acc_dtype
-            )(x),
-            1,
-            set_test_value(
-                pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
-            ),
-        ),
-        (
-            lambda x, axis=None, dtype=None, acc_dtype=None: Prod(
-                axis=axis, dtype=dtype, acc_dtype=acc_dtype
-            )(x),
-            0,
-            set_test_value(pt.vector(), np.arange(3, dtype=config.floatX)),
-        ),
-        (
-            lambda x, axis=None, dtype=None, acc_dtype=None: ProdWithoutZeros(
-                axis=axis, dtype=dtype, acc_dtype=acc_dtype
-            )(x),
-            0,
-            set_test_value(pt.vector(), np.arange(3, dtype=config.floatX)),
-        ),
-        (
-            lambda x, axis=None, dtype=None, acc_dtype=None: Prod(
-                axis=axis, dtype=dtype, acc_dtype=acc_dtype
-            )(x),
-            0,
-            set_test_value(
-                pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
-            ),
-        ),
-        (
-            lambda x, axis=None, dtype=None, acc_dtype=None: Prod(
-                axis=axis, dtype=dtype, acc_dtype=acc_dtype
-            )(x),
-            1,
-            set_test_value(
-                pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
-            ),
-        ),
-        (
-            lambda x, axis=None, dtype=None, acc_dtype=None: Max(axis)(x),
-            None,
-            set_test_value(
-                pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
-            ),
-        ),
-        (
-            lambda x, axis=None, dtype=None, acc_dtype=None: Max(axis)(x),
-            None,
-            set_test_value(
-                pt.lmatrix(), np.arange(3 * 2, dtype=np.int64).reshape((3, 2))
-            ),
-        ),
-        (
-            lambda x, axis=None, dtype=None, acc_dtype=None: Min(axis)(x),
-            None,
-            set_test_value(
-                pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
-            ),
-        ),
-        (
-            lambda x, axis=None, dtype=None, acc_dtype=None: Min(axis)(x),
-            None,
-            set_test_value(
-                pt.lmatrix(), np.arange(3 * 2, dtype=np.int64).reshape((3, 2))
-            ),
-        ),
-    ],
-)
-def test_CAReduce(careduce_fn, axis, v):
-    g = careduce_fn(v, axis=axis)
-    g_fg = FunctionGraph(outputs=[g])
+@pytest.mark.parametrize("axis", [0, None, (0, 1)])
+@pytest.mark.parametrize("op", [Sum, Prod, ProdWithoutZeros, All, Any, Mean, Max, Min])
+def test_CAReduce(op, axis):
+    if op == Mean and isinstance(axis, tuple) and len(axis) > 1:
+        pytest.xfail("Mean does not support multiple partial axes")
 
-    compare_numba_and_py(
-        g_fg,
-        [
-            i.tag.test_value
-            for i in g_fg.inputs
-            if not isinstance(i, SharedVariable | Constant)
-        ],
-    )
+    bool_reduction = op in (All, Any)
+    x = pt.tensor3("x", dtype=bool if bool_reduction else config.floatX)
+    g = op(axis=axis)(x)
+    g_fg = FunctionGraph([x], [g])
+
+    x_test = np.random.normal(size=(2, 3, 4)).astype(config.floatX)
+    if bool_reduction:
+        x_test = x_test > 0
+    compare_numba_and_py(g_fg, [x_test])
 
 
 def test_scalar_Elemwise_Clip():
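
As context for the `Mean` change above, a hedged sketch of the behaviour the new error message points at; it assumes the `pt.mean` helper (rather than the `Mean` Op itself) remains the supported route for reducing over several axes.

```python
# Illustrative only: after this change the Mean Op accepts at most one axis,
# while the `mean` helper still covers multi-axis reductions.
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.math import Mean

try:
    Mean(axis=(0, 1))  # rejected by the new __init__ check
except NotImplementedError as err:
    print(err)

x = pt.tensor3("x")
out = pt.mean(x, axis=(0, 1))  # helper-level multi-axis mean
fn = pytensor.function([x], out, mode="NUMBA")
x_val = np.random.default_rng(0).normal(size=(2, 3, 4)).astype(pytensor.config.floatX)
np.testing.assert_allclose(fn(x_val), x_val.mean(axis=(0, 1)), rtol=1e-5)
```
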
@@ -665,3 +529,27 @@ def test_elemwise_out_type():
     x_val = np.broadcast_to(np.zeros((3,)), (6, 3))
 
     assert func(x_val).shape == (18,)
+
+
+@pytest.mark.parametrize("axis", [0, 2, (0, 2), None])
+@pytest.mark.parametrize("op", [Sum, Max, Any])
+def test_careduce_benchmark(benchmark, op, axis):
+    rng = np.random.default_rng(123)
+    N = 256
+    if op == Any:
+        # Sparse boolean tensor: mostly False with a few True entries
+        value = np.zeros((N, N, N), dtype="bool")
+        true_arrays = rng.choice(N, size=N // 2, replace=False)
+        true_rows = rng.choice(N, size=N // 2, replace=False)
+        true_cols = rng.choice(N, size=N // 2, replace=False)
+        value[true_arrays, true_rows, true_cols] = True
+    else:
+        value = rng.normal(size=(N, N, N))
+
+    x = pytensor.shared(value, name="x")
+    out = op(axis=axis)(x)
+
+    func = pytensor.function([], [out], mode="NUMBA")
+    # Trigger JIT compilation so it is not measured by the benchmark
+    func()
+    benchmark(func)
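
Finally, a hedged sketch of timing the same kind of reduction outside pytest-benchmark (the `benchmark` fixture above comes from the pytest-benchmark plugin); the shape and repeat count are arbitrary.

```python
# Illustrative only: time a compiled reduction the same way the new benchmark
# does, but with a plain timing loop instead of the pytest-benchmark fixture.
import time
import numpy as np
import pytensor
import pytensor.tensor as pt

rng = np.random.default_rng(123)
x = pytensor.shared(rng.normal(size=(256, 256, 256)), name="x")
out = pt.max(x, axis=(0, 2))

fn = pytensor.function([], [out], mode="NUMBA")
fn()  # first call pays the JIT compilation cost

start = time.perf_counter()
for _ in range(10):
    fn()
print(f"average runtime: {(time.perf_counter() - start) / 10:.4f}s")
```
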