From 22fbf9dcde5df6a86586787a07254f5f5d8ef83f Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi Date: Sat, 27 Apr 2024 07:38:52 +0530 Subject: [PATCH 1/4] Break MaxandArgmax to TensorMax and Argmax seperately --- pytensor/ifelse.py | 3 +- pytensor/tensor/math.py | 218 +++++- pytensor/tensor/rewriting/uncanonicalize.py | 49 +- tests/tensor/rewriting/test_uncanonicalize.py | 109 +-- tests/tensor/test_math.py | 2 +- tests/tensor/test_max_argmax.py | 693 ++++++++++++++++++ 6 files changed, 990 insertions(+), 84 deletions(-) create mode 100644 tests/tensor/test_max_argmax.py diff --git a/pytensor/ifelse.py b/pytensor/ifelse.py index 1383cea263..6aea34f262 100644 --- a/pytensor/ifelse.py +++ b/pytensor/ifelse.py @@ -477,7 +477,8 @@ def cond_make_inplace(fgraph, node): Reshape, Unbroadcast, pt.math.Dot, - pt.math.MaxAndArgmax, + pt.math.TensorMax, + pt.math.Argmax, pt.subtensor.Subtensor, pt.subtensor.IncSubtensor, pt.basic.Alloc, diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py index 63a943e1f1..706db19702 100644 --- a/pytensor/tensor/math.py +++ b/pytensor/tensor/math.py @@ -149,6 +149,8 @@ class MaxAndArgmax(COp): def __init__(self, axis): assert isinstance(axis, tuple | list) + # print(axis) + # assert 0 self.axis = tuple(axis) def get_params(self, node): @@ -343,6 +345,208 @@ def grad(self, inp, grads): return (g_x,) +class TensorMax(COp): + """ + Calculate the max over a given axis or over all axes. + + """ + + nin = 2 # tensor, axis + nout = 1 # max val + E_axis = "invalid axis" + params_type = Generic() + __props__ = ("axis",) + _f16_ok = True + + def __init__(self, axis): + assert isinstance(axis, tuple | list) + self.axis = tuple(axis) + + def get_params(self, node): + return self.axis + + def make_node(self, x): + x = as_tensor_variable(x) + + # Keep the original shapes for axes on which we do not perform the max/argmax. + all_axes = set(self.axis) + inputs = [x] + out_shape = tuple(s for i, s in enumerate(x.type.shape) if i not in all_axes) + outputs = [ + tensor(dtype=x.type.dtype, shape=out_shape, name="max"), + ] + return Apply(self, inputs, outputs) + + def prepare_node(self, node, storage_map, compute_map, impl): + if len(node.inputs) == 2: + raise ValueError( + "You are trying to compile a graph with an old Argmax node. Either reoptimize your graph or rebuild it to get the new node format." + ) + + def perform(self, node, inp, outs): + x = inp[0] + axes = self.axis + # max, max_idx = outs + (max,) = outs + if axes is None: + axes = tuple(range(x.ndim)) + else: + axes = tuple(int(ax) for ax in axes) + max[0] = _asarray(np.max(x, axes), dtype=node.outputs[0].dtype) + # # Numpy does not support multiple axes for argmax + # # Work around + # keep_axes = np.array([i for i in range(x.ndim) if i not in axes], dtype="int64") + # # Not-reduced axes in front + # transposed_x = np.transpose(x, np.concatenate((keep_axes, axes))) + # kept_shape = transposed_x.shape[: len(keep_axes)] + # reduced_shape = transposed_x.shape[len(keep_axes) :] + + # # Numpy.prod returns 1.0 when arg is empty, so we cast it to int64 + # # Otherwise reshape would complain citing float arg + # new_shape = (*kept_shape, np.prod(reduced_shape, dtype="int64")) + # reshaped_x = transposed_x.reshape(new_shape) + + # max_idx[0] = _asarray(np.argmax(reshaped_x, axis=-1), dtype="int64") + + def c_code(self, node, name, inp, out, sub): + if len(self.axis) != 1 and len(self.axis) != node.inputs[0].ndim: + raise NotImplementedError( + "NumPy C-API can compute max only for 1 axis or for all axes." + ) + x = inp[0] + axis = sub["params"] + # max, argmax = out + (max,) = out + fail = sub["fail"] + ret = """ + #if PY_MAJOR_VERSION >= 3 + #ifndef PyInt_AS_LONG + #define PyInt_AS_LONG PyLong_AS_LONG + #endif + #endif + + int axis; + + if (PyTuple_GET_SIZE(%(axis)s) == PyArray_NDIM(%(x)s)) { + axis = NPY_MAXDIMS; + } else if(PyTuple_GET_SIZE(%(axis)s) == 1) { + PyObject* axis_object = PyTuple_GET_ITEM(%(axis)s, 0); + axis = (int)PyInt_AS_LONG(axis_object); + if (axis > PyArray_NDIM(%(x)s)-1 || axis < -PyArray_NDIM(%(x)s)) { + PyErr_SetString(PyExc_ValueError, + "TensorMax: bad axis argument"); + %(fail)s + } + } else { + PyErr_SetString(PyExc_NotImplementedError, + "TensorMax: NumPy C-API can compute max only for 1 axis or for all axes."); + %(fail)s + } + + Py_CLEAR(%(max)s); + + %(max)s = (PyArrayObject*)PyArray_Max(%(x)s, axis, NULL); + if (%(max)s == NULL) { + %(fail)s; + } + if (!PyArray_CheckExact(%(max)s)) { + %(max)s = (PyArrayObject*)PyArray_FromAny((PyObject*)%(max)s, NULL, 0, 0, NPY_ARRAY_ENSUREARRAY, NULL); + if(%(max)s == NULL){ + %(fail)s; + } + } + """ + return ret % locals() + + def c_code_cache_version(self): + return (5,) + + def infer_shape(self, fgraph, node, shapes): + ishape = shapes[0] + rval = tuple( + ishape[i] + for (i, b) in enumerate(node.inputs[0].type.broadcastable) + if i not in self.axis + ) + return [rval] + + def R_op(self, inputs, eval_points): + if eval_points[0] is None: + return [None, None] + + if len(self.axis) != 1: + raise ValueError("R_op supported for arg_max only for one axis!") + if self.axis[0] > 1: + raise ValueError("R_op supported for arg_max only when axis is 0 or 1") + if inputs[0].ndim != 2: + raise ValueError("R_op supported for arg_max only when input is a matrix") + # max_vals, max_pos = self.make_node(*inputs).outputs + # max_vals = self.make_node(*inputs).outputs + if self.axis[0] == 0: + return [eval_points[0][arange(eval_points[0].shape[1])], None] + else: + return [eval_points[0][arange(eval_points[0].shape[0])], None] + + def grad(self, inp, grads): + # The strict sense mathematical gradient of the maximum function is + # not calculated here for it is not defined at every point where some + # coordinates are identical. However, since the latter set has null + # Lebesgue measure, the result may be interpreted as weak gradient. + + # @note: This function should work correctly for L{vector}s. + # (x, y), (gz, gw) + # gz*dz/dx + gw*dw/dx, gz*dz/dy + gw*dw/dy + # gMax * dMax/dx + gArgMax * dArgMax/dx, + # gMax * dMax/daxis + gArgMax * dArgMax/daxis + # g_max has one less dimension than x, so you need to complete + # g_max to x's shape when axis=0 the broadcasting mechanism + # does it automatically + x = inp[0] + axis = as_tensor_variable(self.axis) + # g_max, g_max_idx = grads + (g_max,) = grads + + g_max_disconnected = isinstance(g_max.type, DisconnectedType) + # g_max_idx_disconnected = isinstance(g_max_idx.type, DisconnectedType) + + # # if the op is totally disconnected, so are its inputs + # if g_max_disconnected and g_max_idx_disconnected: + # return [DisconnectedType()(), DisconnectedType()()] + + # if the op is totally disconnected, so are its inputs + if g_max_disconnected: + return [DisconnectedType()()] + + # if the max is disconnected but the argmax is not, + # the gradient on its inputs is zero + # if g_max_disconnected: + # return [x.zeros_like()] + if NoneConst.equals(axis): + axis_ = list(range(x.ndim)) + else: + axis_ = axis + xmax = max(x, axis_) + + # Raise the g_max and xmax to the same number of dim as the input. + pattern = [] + out_dim = 0 + if NoneConst.equals(axis): + # We are taking the max/argmax over all dimensions. + axis = None + for i in range(x.ndim): + if axis is None or i in axis.data: + pattern.append("x") + else: + pattern.append(out_dim) + out_dim += 1 + g_max_pad = DimShuffle(g_max.broadcastable, pattern)(g_max) + xmax_pad = DimShuffle(xmax.broadcastable, pattern)(xmax) + + # Set the grad to the correct position. + g_x = eq(xmax_pad, x) * g_max_pad + return (g_x,) + + class Argmax(COp): """ Calculate the argmax over a given axis or over all axes. @@ -357,8 +561,10 @@ class Argmax(COp): params_type = ParamsType(c_axis=ps.int64) def __init__(self, axis): - if axis is not None: - axis = tuple(axis) + # if axis is not None: + # axis = tuple(axis) + assert isinstance(axis, tuple | list) + # print(axis) self.axis = tuple(axis) def get_params(self, node): @@ -395,6 +601,8 @@ def perform(self, node, inp, outs): (max_idx,) = outs if axes is None: axes = tuple(range(x.ndim)) + else: + axes = tuple(int(ax) for ax in axes) # Numpy does not support multiple axes for argmax # Work around @@ -477,7 +685,7 @@ def grad(self, inp, grads): @_vectorize_node.register(Argmax) -@_vectorize_node.register(MaxAndArgmax) +# @_vectorize_node.register(MaxAndArgmax) def vectorize_argmax_node(op, node, batch_x): core_ndim = node.inputs[0].type.ndim batch_ndim = batch_x.type.ndim - core_ndim @@ -600,7 +808,9 @@ def max_and_argmax(a, axis=None, keepdims=False): axis = check_and_normalize_axes(a, axis) if len(axis) == 0: axis = list(range(a.type.ndim)) - out, argout = MaxAndArgmax(axis)(a) + out = TensorMax(axis)(a) + argout = Argmax(axis)(a) + # out, argout = MaxAndArgmax(axis)(a) if keepdims: out = makeKeepDims(a, out, axis) diff --git a/pytensor/tensor/rewriting/uncanonicalize.py b/pytensor/tensor/rewriting/uncanonicalize.py index 15a316c5a0..624353c253 100644 --- a/pytensor/tensor/rewriting/uncanonicalize.py +++ b/pytensor/tensor/rewriting/uncanonicalize.py @@ -31,33 +31,34 @@ """ -from pytensor import scalar as ps from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter from pytensor.tensor.basic import Alloc, alloc, constant -from pytensor.tensor.elemwise import CAReduce, DimShuffle -from pytensor.tensor.math import Argmax, Max, MaxAndArgmax, Min, neg +from pytensor.tensor.elemwise import DimShuffle + +# from pytensor.tensor.math import Argmax, Max, MaxAndArgmax, Min, neg +from pytensor.tensor.math import Min, TensorMax, neg from pytensor.tensor.rewriting.basic import register_uncanonicalize from pytensor.tensor.shape import Reshape, reshape from pytensor.tensor.subtensor import Subtensor -@register_uncanonicalize -@node_rewriter([MaxAndArgmax]) -def local_max_and_argmax(fgraph, node): - """ - If we don't use the argmax, change it to a max only. - """ - if isinstance(node.op, MaxAndArgmax): - axis = node.op.axis - if len(fgraph.clients[node.outputs[1]]) == 0: - new = Max(axis)(node.inputs[0]) - copy_stack_trace(node.outputs[0], new) - return [new, None] +# @register_uncanonicalize +# @node_rewriter([MaxAndArgmax]) +# def local_max_and_argmax(fgraph, node): +# """ +# If we don't use the argmax, change it to a max only. +# """ +# if isinstance(node.op, MaxAndArgmax): +# axis = node.op.axis +# if len(fgraph.clients[node.outputs[1]]) == 0: +# new = Max(axis)(node.inputs[0]) +# copy_stack_trace(node.outputs[0], new) +# return [new, None] - if len(fgraph.clients[node.outputs[0]]) == 0: - new = Argmax(axis)(node.inputs[0]) - copy_stack_trace(node.outputs[0], new) - return [None, new] +# if len(fgraph.clients[node.outputs[0]]) == 0: +# new = Argmax(axis)(node.inputs[0]) +# copy_stack_trace(node.outputs[0], new) +# return [None, new] @register_uncanonicalize @@ -74,13 +75,13 @@ def local_max_to_min(fgraph, node): the interface put only MaxAndArgmax into the graph. """ + # pytensor.dprint(node) + # print() + # print(node.op == neg) if node.op == neg and node.inputs[0].owner: max = node.inputs[0] - if ( - max.owner - and isinstance(max.owner.op, CAReduce) - and max.owner.op.scalar_op == ps.scalar_maximum - ): + # print(max.owner.op.scalar_op) + if max.owner and isinstance(max.owner.op, TensorMax): neg_node = max.owner.inputs[0] if neg_node.owner and neg_node.owner.op == neg: new = Min(max.owner.op.axis)(neg_node.owner.inputs[0]) diff --git a/tests/tensor/rewriting/test_uncanonicalize.py b/tests/tensor/rewriting/test_uncanonicalize.py index d36447ac20..32c141cab3 100644 --- a/tests/tensor/rewriting/test_uncanonicalize.py +++ b/tests/tensor/rewriting/test_uncanonicalize.py @@ -9,8 +9,7 @@ from pytensor.graph.rewriting.basic import out2in from pytensor.link.basic import PerformLinker from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise -from pytensor.tensor.math import MaxAndArgmax, max_and_argmax -from pytensor.tensor.math import max as pt_max +from pytensor.tensor.math import TensorMax from pytensor.tensor.math import min as pt_min from pytensor.tensor.rewriting.uncanonicalize import ( local_alloc_dimshuffle, @@ -23,26 +22,26 @@ from tests.link.test_link import make_function -class TestMaxAndArgmax: - def test_optimization(self): - # If we use only the max output, we should replace this op with - # a faster one. - mode = pytensor.compile.mode.get_default_mode().including( - "canonicalize", "fast_run" - ) +# class TestMaxAndArgmax: +# def test_optimization(self): +# # If we use only the max output, we should replace this op with +# # a faster one. +# mode = pytensor.compile.mode.get_default_mode().including( +# "canonicalize", "fast_run" +# ) - for axis in [0, 1, -1]: - n = matrix() +# for axis in [0, 1, -1]: +# n = matrix() - f = function([n], max_and_argmax(n, axis)[0], mode=mode) - topo = f.maker.fgraph.toposort() - assert len(topo) == 1 - assert isinstance(topo[0].op, CAReduce) +# f = function([n], max_and_argmax(n, axis)[0], mode=mode) +# topo = f.maker.fgraph.toposort() +# assert len(topo) == 1 +# assert isinstance(topo[0].op, CAReduce) - f = function([n], max_and_argmax(n, axis), mode=mode) - topo = f.maker.fgraph.toposort() - assert len(topo) == 1 - assert isinstance(topo[0].op, MaxAndArgmax) +# f = function([n], max_and_argmax(n, axis), mode=mode) +# topo = f.maker.fgraph.toposort() +# assert len(topo) == 1 +# assert isinstance(topo[0].op, MaxAndArgmax) class TestMinMax: @@ -51,38 +50,40 @@ def setup_method(self): "canonicalize", "fast_run" ) - def test_optimization_max(self): - data = np.asarray(np.random.random((2, 3)), dtype=config.floatX) - n = matrix() - - for axis in [0, 1, -1]: - f = function([n], pt_max(n, axis), mode=self.mode) - topo = f.maker.fgraph.toposort() - assert len(topo) == 1 - assert isinstance(topo[0].op, CAReduce) - f(data) - - f = function([n], pt_max(-n, axis), mode=self.mode) - topo = f.maker.fgraph.toposort() - assert len(topo) == 2 - assert isinstance(topo[0].op, Elemwise) - assert isinstance(topo[0].op.scalar_op, ps.Neg) - assert isinstance(topo[1].op, CAReduce) - f(data) - - f = function([n], -pt_max(n, axis), mode=self.mode) - topo = f.maker.fgraph.toposort() - assert len(topo) == 2 - assert isinstance(topo[0].op, CAReduce) - assert isinstance(topo[1].op, Elemwise) - assert isinstance(topo[1].op.scalar_op, ps.Neg) - f(data) - - f = function([n], -pt_max(-n, axis), mode=self.mode) - topo = f.maker.fgraph.toposort() - assert len(topo) == 1 - assert isinstance(topo[0].op, CAReduce) # min - f(data) + # def test_optimization_max(self): + # data = np.asarray(np.random.random((2, 3)), dtype=config.floatX) + # n = matrix() + + # for axis in [0, 1, -1]: + # f = function([n], pt_max(n, axis), mode=self.mode) + # topo = f.maker.fgraph.toposort() + # assert len(topo) == 1 + # # assert isinstance(topo[0].op, CAReduce) + # f(data) + + # f = function([n], pt_max(-n, axis), mode=self.mode) + # topo = f.maker.fgraph.toposort() + # import pytensor + # pytensor.dprint(topo) + # assert len(topo) == 2 + # assert isinstance(topo[0].op, Elemwise) + # assert isinstance(topo[0].op.scalar_op, ps.Neg) + # assert isinstance(topo[1].op, CAReduce) + # f(data) + + # f = function([n], -pt_max(n, axis), mode=self.mode) + # topo = f.maker.fgraph.toposort() + # assert len(topo) == 2 + # assert isinstance(topo[0].op, CAReduce) + # assert isinstance(topo[1].op, Elemwise) + # assert isinstance(topo[1].op.scalar_op, ps.Neg) + # f(data) + + # f = function([n], -pt_max(-n, axis), mode=self.mode) + # topo = f.maker.fgraph.toposort() + # assert len(topo) == 1 + # assert isinstance(topo[0].op, CAReduce) # min + # f(data) def test_optimization_min(self): data = np.asarray(np.random.random((2, 3)), dtype=config.floatX) @@ -99,7 +100,7 @@ def test_optimization_min(self): f = function([n], pt_min(-n, axis), mode=self.mode) topo = f.maker.fgraph.toposort() assert len(topo) == 2 - assert isinstance(topo[0].op, CAReduce) # max + assert isinstance(topo[0].op, TensorMax) # max assert isinstance(topo[1].op, Elemwise) assert isinstance(topo[1].op.scalar_op, ps.Neg) f(data) @@ -109,13 +110,13 @@ def test_optimization_min(self): assert len(topo) == 2 assert isinstance(topo[0].op, Elemwise) assert isinstance(topo[0].op.scalar_op, ps.Neg) - assert isinstance(topo[1].op, CAReduce) # max + assert isinstance(topo[1].op, TensorMax) # max f(data) f = function([n], -pt_min(-n, axis), mode=self.mode) topo = f.maker.fgraph.toposort() assert len(topo) == 1 - assert isinstance(topo[0].op, CAReduce) # max + assert isinstance(topo[0].op, TensorMax) # max f(data) diff --git a/tests/tensor/test_math.py b/tests/tensor/test_math.py index e346348406..bb460d1fa6 100644 --- a/tests/tensor/test_math.py +++ b/tests/tensor/test_math.py @@ -763,7 +763,7 @@ def setup_method(self): MaxAndArgmax.debug = 0 def test_basic(self): - n = as_tensor_variable(5.0) + n = as_tensor_variable([5.0]) v, i = eval_outputs(max_and_argmax(n)) assert v == 5.0 assert i == 0 diff --git a/tests/tensor/test_max_argmax.py b/tests/tensor/test_max_argmax.py new file mode 100644 index 0000000000..d9b263e57c --- /dev/null +++ b/tests/tensor/test_max_argmax.py @@ -0,0 +1,693 @@ +import builtins + +import numpy as np +import pytest + +from pytensor.compile.function import function +from pytensor.compile.mode import get_default_mode +from pytensor.compile.sharedvalue import shared +from pytensor.configdefaults import config +from pytensor.gradient import grad, numeric_grad +from pytensor.graph.replace import vectorize_node +from pytensor.tensor.basic import ( + as_tensor_variable, + constant, + get_underlying_scalar_constant_value, +) +from pytensor.tensor.math import ( + Argmax, + TensorMax, + argmax, + argmin, + max, + max_and_argmax, + min, +) +from pytensor.tensor.type import ( + matrix, + tensor, +) +from pytensor.tensor.type_other import NoneConst +from tests import unittest_tools as utt +from tests.tensor.utils import ( + eval_outputs, + random, +) + + +if config.mode == "FAST_COMPILE": + mode_opt = "FAST_RUN" +else: + mode_opt = get_default_mode() + + +class TestMaxAndArgmax: + def setup_method(self): + TensorMax.debug = 0 + + def test_basic(self): + # dbt: for some reason, Argmax does not work when I pass: n = as_tensor_variable(5.0) + n = as_tensor_variable([5.0]) + v, i = eval_outputs(max_and_argmax(n)) + assert v == 5.0 + assert i == 0 + assert i.dtype == "int64" + v = eval_outputs(max_and_argmax(n)[0].shape) + assert len(v) == 0 + v = eval_outputs(max_and_argmax(n)[1].shape) + assert len(v) == 0 + + def test_basic_1(self): + n = as_tensor_variable([1, 2, 3, 2, -6]) + v, i = eval_outputs(max_and_argmax(n)) + assert v == 3 + assert i == 2 + assert i.dtype == "int64" + v = eval_outputs(max_and_argmax(n)[0].shape) + assert len(v) == 0 + + @pytest.mark.parametrize( + "axis,np_axis", + [ + (-1, -1), + (0, 0), + (1, 1), + (None, None), + ([0, 1], None), + ([1, 0], None), + (NoneConst.clone(), None), + (constant(0), 0), + ], + ) + def test_basic_2(self, axis, np_axis): + data = random(2, 3) + n = as_tensor_variable(data) + # Test shape propagates (static & eval) + vt, it = max_and_argmax(n, axis) + np_max, np_argm = np.max(data, np_axis), np.argmax(data, np_axis) + assert vt.type.shape == np_max.shape + assert it.type.shape == np_argm.shape + v_shape, i_shape = eval_outputs([vt.shape, it.shape]) + assert tuple(v_shape) == vt.type.shape + assert tuple(i_shape) == it.type.shape + # Test values + v, i = eval_outputs([vt, it]) + assert i.dtype == "int64" + assert np.all(v == np_max) + assert np.all(i == np_argm) + + @pytest.mark.parametrize( + "axis,np_axis", + [ + (-1, -1), + (0, 0), + (1, 1), + (None, None), + ([0, 1], None), + ([1, 0], None), + (NoneConst.clone(), None), + (constant(0), 0), + ], + ) + def test_basic_2_float16(self, axis, np_axis): + # Test negative values and bigger range to make sure numpy don't do the argmax as on uint16 + data = (random(20, 30).astype("float16") - 0.5) * 20 + n = as_tensor_variable(data) + # Test shape propagates (static & eval) + vt, it = max_and_argmax(n, axis) + np_max, np_argm = np.max(data, np_axis), np.argmax(data, np_axis) + assert vt.type.shape == np_max.shape + assert it.type.shape == np_argm.shape + v_shape, i_shape = eval_outputs([vt.shape, it.shape]) + assert tuple(v_shape) == vt.type.shape + assert tuple(i_shape) == it.type.shape + # Test values + v, i = eval_outputs([vt, it]) + assert i.dtype == "int64" + assert np.all(v == np_max) + assert np.all(i == np_argm) + + def test_basic_2_invalid(self): + n = as_tensor_variable(random(2, 3)) + with pytest.raises(ValueError): + eval_outputs(max_and_argmax(n, 3)) + + n = as_tensor_variable(random(2, 3)) + with pytest.raises(ValueError): + eval_outputs(max_and_argmax(n, -3)) + + def test_basic_2_valid_neg(self): + n = as_tensor_variable(random(2, 3)) + v, i = eval_outputs(max_and_argmax(n, -1)) + assert i.dtype == "int64" + assert v.shape == (2,) + assert i.shape == (2,) + assert np.all(v == np.max(n.value, -1)) + assert np.all(i == np.argmax(n.value, -1)) + v, i = eval_outputs(max_and_argmax(n, -2)) + assert i.dtype == "int64" + assert v.shape == (3,) + assert i.shape == (3,) + assert np.all(v == np.max(n.value, -2)) + assert np.all(i == np.argmax(n.value, -2)) + v = eval_outputs(max_and_argmax(n, -1)[0].shape) + assert v == (2) + v = eval_outputs(max_and_argmax(n, -2)[0].shape) + assert v == (3) + + @pytest.mark.parametrize( + "axis,np_axis", + [ + (-1, -1), + (0, 0), + (1, 1), + (None, None), + ([0, 1, 2], None), + ([1, 2, 0], None), + ], + ) + def test_basic_3(self, axis, np_axis): + data = random(2, 3, 4) + n = as_tensor_variable(data) + # Test shape propagates (static & eval) + vt, it = max_and_argmax(n, axis) + np_max, np_argm = np.max(data, np_axis), np.argmax(data, np_axis) + assert vt.type.shape == np_max.shape + assert it.type.shape == np_argm.shape + v_shape, i_shape = eval_outputs([vt.shape, it.shape]) + assert tuple(v_shape) == vt.type.shape + assert tuple(i_shape) == it.type.shape + # Test values + v, i = eval_outputs([vt, it]) + assert i.dtype == "int64" + assert np.all(v == np_max) + assert np.all(i == np_argm) + + def test_arg_grad(self): + # The test checks that the gradient of argmax(x).sum() is 0 + + x = matrix() + cost = argmax(x, axis=0).sum() + gx = grad(cost, x) + val = get_underlying_scalar_constant_value(gx) + assert val == 0.0 + + def test_grad(self): + data = random(2, 3) + n = as_tensor_variable(data) + + def safe_verify_grad(func, data): + # Wrapper around 'verify_grad' that picks a proper value for epsilon. + # + # This is needed because 'verify_grad' may fail when its epsilon is + # too large, due to the fact the argmax is not continuous. + # We make sure epsilon is less than the minimum absolute value found + # in the matrix of pairwise differences between all elements in the + # data. This way, the argmax will not change when adding epsilon. + + # 'data' is a one-element list. + (data_tensor,) = data + # Flatten it into a 1D vector. + data_vector = data_tensor.flatten() + # Compute pairwise absolute differences. + diff = np.abs(data_vector.reshape((-1, 1)) - data_vector) + # Alter the diagonal to avoid a zero minimum. + for i in range(len(diff)): + diff[i, i] = 1 + # Find an appropriate epsilon. + eps = builtins.min(numeric_grad.type_eps[config.floatX], diff.min() / 2) + # Run gradient verification. + utt.verify_grad(func, data, eps=eps) + + def check_grad_max(data, max_grad_data, axis=None): + # Why this is needed? verify_grad is not enough? + # This works only for axis in [0, None]. + assert axis in [0, None] + z = np.zeros_like(data) + z = z.flatten() + argmax = np.argmax(data, axis=axis) + if argmax.ndim == 0: + z[argmax] += 1 + else: + for id, v in enumerate(argmax): + z[v * np.prod(data.shape[data.ndim - 1 : axis : -1]) + id] += 1 + + z = z.reshape(data.shape) + assert np.all(max_grad_data == z) + + for axis in (-1, 0, 1, None): + for j in range(2): + safe_verify_grad(lambda v: max_and_argmax(v, axis=axis)[j], [data]) + if axis != 1: + safe_verify_grad( + lambda v: max_and_argmax(v.flatten(), axis=axis)[j], [data] + ) + if axis in (0, None): + check_grad_max( + data, + eval_outputs(grad(max_and_argmax(n, axis=axis)[0].sum(), n)), + axis=axis, + ) + check_grad_max(data, eval_outputs(grad(max_and_argmax(n.flatten())[0], n))) + + # Test 3d inner dimensions + data = random(3, 4, 5) + + for i in [0, 1, 2]: + safe_verify_grad(lambda v: max_and_argmax(v, axis=[i])[0], [data]) + safe_verify_grad(lambda v: max_and_argmax(v, axis=[i])[1], [data]) + + # Test 4d inner dimensions + data = random(2, 3, 4, 5) + + for i in [0, 1, 2, 3]: + safe_verify_grad(lambda v: max_and_argmax(v, axis=[i])[0], [data]) + safe_verify_grad(lambda v: max_and_argmax(v, axis=[i])[1], [data]) + + # Test grad with multiple axes + for i in [[0, 1], [0, 0]]: + safe_verify_grad(lambda v: max_and_argmax(v, axis=i)[0], [data]) + safe_verify_grad(lambda v: max_and_argmax(v, axis=i)[1], [data]) + + def test_preserve_broadcastable(self): + # Ensure the original broadcastable flags are preserved by Max/Argmax. + x = matrix().dimshuffle("x", 0, "x", 1, "x") + y = x.max(axis=1) + assert y.type.shape == (1, 1, None, 1) + assert y.type.broadcastable == (True, True, False, True) + + def test_multiple_axes(self): + data = np.arange(24).reshape(3, 2, 4) + x = as_tensor_variable(data) + vt, it = max_and_argmax(x, [1, -1]) + assert vt.type.shape == it.type.shape == (3,) + v, i = eval_outputs([vt, it]) + assert np.all(v == np.array([7, 15, 23])) + assert np.all(i == np.array([7, 7, 7])) + v = eval_outputs(vt.shape) + assert tuple(v) == vt.type.shape + + def test_zero_shape(self): + x = matrix() + m, i = max_and_argmax(x, axis=1) + f = function([x], [m, i]) + xv = np.zeros((0, 4), dtype=config.floatX) + mv, iv = f(xv) + assert mv.shape == (0,) + assert iv.shape == (0,) + + def test_numpy_input(self): + ar = np.array([1, 2, 3]) + max_pt, argmax_pt = max_and_argmax(ar, axis=None) + assert max_pt.eval() == 3 + assert argmax_pt.eval() == 2 + + @pytest.mark.parametrize( + "core_axis, batch_axis", + [ + (None, (1, 2, 3, 4)), + (0, (1,)), + ((1, -1), (2, 4)), + ], + ) + def test_vectorize(self, core_axis, batch_axis): + x = tensor(shape=(5, 5, 5, 5)) + batch_x = tensor(shape=(3, 5, 5, 5, 5)) + + # Test MaxAndArgmax + max_x, argmax_x = max_and_argmax(x, axis=core_axis) + node = max_x.owner + assert isinstance(node.op, TensorMax) + + # dbt: how to make Argmax user facing? + # new_node = vectorize_node(node, batch_x) + # pytensor.dprint(new_node) + # assert isinstance(new_node.op, Argmax) + # assert new_node.op.axis == batch_axis + + # Test Argmax + # Argmax is not user-facing, so we have to create it manually + node = Argmax(axis=node.op.axis).make_node(x) + + new_node = vectorize_node(node, batch_x) + # print() + # pytensor.dprint(new_node) + # print() + assert isinstance(new_node.op, Argmax) + assert new_node.op.axis == batch_axis + + +class TestArgminArgmax: + def setup_method(self): + TensorMax.debug = 0 + + def test_scalar(self): + for fct in [argmin, argmax]: + n = as_tensor_variable([5.0]) + i = eval_outputs(fct(n)) + assert i == 0 + v = eval_outputs(fct(n).shape) + assert len(v) == 0 + + def test_list(self): + n = as_tensor_variable([1, 2, 3, 2, -6]) + i = eval_outputs(argmin(n)) + assert i == 4 + v = eval_outputs(argmin(n).shape) + assert len(v) == 0 + + n = as_tensor_variable([1, 2, 3, 2, -6]) + i = eval_outputs(argmax(n)) + assert i == 2 + v = eval_outputs(argmax(n).shape) + assert len(v) == 0 + + def test2(self): + data = random(2, 3) + n = as_tensor_variable(data) + for fct, nfct in [(argmax, np.argmax), (argmin, np.argmin)]: + for axis, np_axis in [ + (-1, -1), + (0, 0), + (1, 1), + (None, None), + ([0, 1], None), + ([1, 0], None), + ]: + v = eval_outputs(fct(n, axis)) + assert np.all(v == nfct(data, np_axis)) + v_shape = eval_outputs(fct(n, axis).shape) + assert tuple(v_shape) == nfct(data, np_axis).shape + + def test2_float16(self): + # Test negative values and bigger range to make sure numpy don't do the argmax as on uint16 + data = (random(20, 30).astype("float16") - 0.5) * 20 + n = shared(data) + mode = get_default_mode().including("local_max_and_argmax", "uncanonicalize") + for fct, nfct in [(argmax, np.argmax), (argmin, np.argmin)]: + for axis, np_axis in [ + (-1, -1), + (0, 0), + (1, 1), + (None, None), + ([0, 1], None), + ([1, 0], None), + ]: + v = eval_outputs(fct(n, axis), (Argmax,), mode=mode) + assert np.all(v == nfct(data, np_axis)) + v_shape = eval_outputs(fct(n, axis).shape, mode=mode) + assert tuple(v_shape) == nfct(data, np_axis).shape + + def test2_invalid(self): + for fct, nfct in [(argmax, np.argmax), (argmin, np.argmin)]: + n = as_tensor_variable(random(2, 3)) + with pytest.raises(ValueError): + eval_outputs(fct(n, 3)) + with pytest.raises(ValueError): + eval_outputs(fct(n, -3)) + + def test2_valid_neg(self): + for fct, nfct in [(argmax, np.argmax), (argmin, np.argmin)]: + n = as_tensor_variable(random(2, 3)) + i = eval_outputs(fct(n, -1)) + assert i.shape == (2,) + assert np.all(i == nfct(n.value, -1)) + i = eval_outputs(fct(n, -2)) + assert i.shape == (3,) + assert np.all(i == nfct(n.value, -2)) + + v = eval_outputs(fct(n, -1).shape) + assert v == (2) + v = eval_outputs(fct(n, -2).shape) + assert v == (3) + + def test3(self): + data = random(2, 3, 4) + n = as_tensor_variable(data) + for fct, nfct in [(argmax, np.argmax), (argmin, np.argmin)]: + for axis, np_axis in [ + (-1, -1), + (0, 0), + (1, 1), + (2, 2), + (None, None), + ([0, 1, 2], None), + ([1, 0, 2], None), + ]: + v = eval_outputs(fct(n, axis)) + assert np.all(v == nfct(data, np_axis)) + v_shape = eval_outputs(fct(n, axis).shape) + assert tuple(v_shape) == nfct(data, np_axis).shape + + def test_grad_argmin(self): + data = random(2, 3) + n = as_tensor_variable(data) + n.name = "n" + + # test grad of argmin + utt.verify_grad(lambda v: argmin(v, axis=-1), [data]) + + utt.verify_grad(lambda v: argmin(v, axis=[0]), [data]) + + utt.verify_grad(lambda v: argmin(v, axis=[1]), [data]) + + utt.verify_grad(lambda v: argmin(v.flatten()), [data]) + + try: + cost = argmin(n, axis=-1) + cost.name = None + grad(cost, n) + raise Exception("Expected an error") + except TypeError: + pass + + def test_grad_argmax(self): + data = random(2, 3) + n = as_tensor_variable(data) + + # test grad of argmax + utt.verify_grad(lambda v: argmax(v, axis=-1), [data]) + + utt.verify_grad(lambda v: argmax(v, axis=[0]), [data]) + + utt.verify_grad(lambda v: argmax(v, axis=[1]), [data]) + + utt.verify_grad(lambda v: argmax(v.flatten()), [data]) + + try: + grad(argmax(n, axis=-1), n) + raise Exception("Expected an error") + except TypeError: + pass + + def test_uint(self): + for dtype in ("uint8", "uint16", "uint32", "uint64"): + itype = np.iinfo(dtype) + data = np.array([itype.min + 3, itype.min, itype.max - 5, itype.max], dtype) + n = as_tensor_variable(data) + i = eval_outputs(argmin(n)) + assert i == 1 + i = eval_outputs(argmax(n)) + assert i == 3 + + def test_bool(self): + data = np.array([True, False], "bool") + n = as_tensor_variable(data) + i = eval_outputs(argmin(n)) + assert i == 1 + i = eval_outputs(argmax(n)) + assert i == 0 + + +class TestMinMax: + def setup_method(self): + TensorMax.debug = 0 + + def test_scalar(self): + for fct in [max, min]: + n = as_tensor_variable(5.0) + v = eval_outputs(fct(n)) + assert v == 5.0 + + v = eval_outputs(fct(n).shape) + assert len(v) == 0 + + def test_list(self): + for fct, nfct in [(max, np.max), (min, np.min)]: + n = as_tensor_variable([1, 2, 3, 2, -6]) + v = eval_outputs([fct(n)]) + assert v == nfct(n.value) + + v = eval_outputs(fct(n).shape) + assert len(v) == 0 + + def test2(self): + data = random(2, 3) + n = as_tensor_variable(data) + for fct, nfct in [(max, np.max), (min, np.min)]: + for axis, np_axis in [ + (-1, -1), + (0, 0), + (1, 1), + (None, None), + ([0, 1], None), + ([1, 0], None), + ]: + v = eval_outputs(fct(n, axis)) + assert np.all(v == nfct(data, np_axis)) + v_shape = eval_outputs(fct(n, axis).shape) + assert tuple(v_shape) == nfct(data, np_axis).shape + + def test2_invalid(self): + for fct in [max, min]: + n = as_tensor_variable(random(2, 3)) + with pytest.raises(ValueError): + eval_outputs(fct(n, 3)) + with pytest.raises(ValueError): + eval_outputs(fct(n, -3)) + + def test2_valid_neg(self): + for fct, nfct in [(max, np.max), (min, np.min)]: + n = as_tensor_variable(random(2, 3)) + v = eval_outputs(fct(n, -1)) + assert v.shape == (2,) + assert np.all(v == nfct(n.value, -1)) + v = eval_outputs(fct(n, -2)) + assert v.shape == (3,) + assert np.all(v == nfct(n.value, -2)) + + v = eval_outputs(fct(n, -1).shape) + assert v == (2) + v = eval_outputs(fct(n, -2).shape) + assert v == (3) + + def test3(self): + # Test with 1 axis or all axis out of 3 dims + data = random(2, 3, 4) + n = as_tensor_variable(data) + for fct, nfct in [(max, np.max), (min, np.min)]: + for axis, np_axis in [ + (-1, -1), + (0, 0), + (1, 1), + (2, 2), + (None, None), + ([0, 1, 2], None), + ([1, 0, 2], None), + ]: + v = eval_outputs(fct(n, axis)) + assert np.all(v == nfct(data, np_axis)) + v_shape = eval_outputs(fct(n, axis).shape) + assert tuple(v_shape) == nfct(data, np_axis).shape + + def test3b(self): + # Test with 2 axis out of 3 dims + data = random(2, 3, 4) + n = as_tensor_variable(data) + for fct, nfct in [(max, np.max), (min, np.min)]: + for axis in [[0, 1], [1, 2], [0, 2]]: + v = eval_outputs(fct(n, axis)) + np_v = nfct(nfct(data, axis[1]), axis[0]) + assert np.all(v == np_v) + v_shape = eval_outputs(fct(n, axis).shape) + assert tuple(v_shape) == np_v.shape + + def test_grad_max(self): + data = random(2, 3) + n = as_tensor_variable(data) + + def check_grad_max(data, max_grad_data, axis=None): + # This work only for axis in [0,None] + assert axis in [0, None] + z = np.zeros_like(data) + z = z.flatten() + argmax = np.argmax(data, axis=axis) + if argmax.ndim == 0: + z[np.argmax(data, axis=axis)] += 1 + else: + for id, v in enumerate(argmax): + z[v * np.prod(data.shape[data.ndim - 1 : axis : -1]) + id] += 1 + + z = z.reshape(data.shape) + assert np.all(max_grad_data == z) + + # test grad of max + # axis is the last one + utt.verify_grad(lambda v: max(v, axis=-1), [data]) + + utt.verify_grad(lambda v: max(v, axis=[0]), [data]) + check_grad_max(data, eval_outputs(grad(max(n, axis=0).sum(), n)), axis=0) + + utt.verify_grad(lambda v: max(v, axis=[1]), [data]) + # check_grad_max(data,eval_outputs(grad(max(n,axis=1),n)),axis=1) + + utt.verify_grad(lambda v: max(v.flatten()), [data]) + check_grad_max(data, eval_outputs(grad(max(n.flatten()), n))) + + def test_grad_min(self): + data = random(2, 3) + n = as_tensor_variable(data) + + def check_grad_min(data, min_grad_data, axis=None): + # This work only for axis in [0, None] + assert axis in [0, None] + z = np.zeros_like(data) + z = z.flatten() + argmin = np.argmin(data, axis=axis) + if argmin.ndim == 0: + z[np.argmin(data, axis=axis)] += 1 + else: + for id, v in enumerate(argmin): + z[v * np.prod(data.shape[data.ndim - 1 : axis : -1]) + id] += 1 + + z = z.reshape(data.shape) + assert np.all(min_grad_data == z) + + # test grad of min + # axis is the last one + utt.verify_grad(lambda v: min(v, axis=-1), [data]) + + utt.verify_grad(lambda v: min(v, axis=[0]), [data]) + check_grad_min(data, eval_outputs(grad(min(n, axis=0).sum(), n)), axis=0) + + utt.verify_grad(lambda v: min(v, axis=[1]), [data]) + # check_grad_min(data,eval_outputs(grad(min(n,axis=1),n)),axis=1) + + utt.verify_grad(lambda v: min(v.flatten()), [data]) + check_grad_min(data, eval_outputs(grad(min(n.flatten()), n))) + + def _grad_list(self): + # Test the gradient when we have multiple axis at the same time. + # + # This not implemented, so we disable the test. See ticket: + # http://www.assembla.com/spaces/pytensor/tickets/511 + data = random(2, 3) + for fct in [max_and_argmax, max, min]: + utt.verify_grad(lambda v: fct(v, axis=[0, 1]), [data]) + # n = as_tensor_variable(data) + # check_grad_max(data, eval_outputs(grad(max_and_argmax(n, + # axis=1)[0], n)),axis=1) + + def test_uint(self): + for dtype in ("uint8", "uint16", "uint32", "uint64"): + itype = np.iinfo(dtype) + data = np.array([itype.min + 3, itype.min, itype.max - 5, itype.max], dtype) + n = as_tensor_variable(data) + assert min(n).dtype == dtype + i = eval_outputs(min(n)) + assert i == itype.min + assert max(n).dtype == dtype + i = eval_outputs(max(n)) + assert i == itype.max + + def test_bool(self): + data = np.array([True, False], "bool") + n = as_tensor_variable(data) + assert min(n).dtype == "bool" + i = eval_outputs(min(n)) + assert i.ndim == 0 + assert not np.any(i) + assert max(n).dtype == "bool" + i = eval_outputs(max(n)) + assert i.ndim == 0 + assert np.all(i) From 8c2931422bbc8cc2a6f694697ff25b926d5cbe70 Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi Date: Sat, 27 Apr 2024 07:42:36 +0530 Subject: [PATCH 2/4] Added changes for seperating MaxandArgmax Op Scalar problem solved Finalise changes to seperate MaxAndArgmax Op --- pytensor/compile/function/__init__.py | 1 + pytensor/compile/function/types.py | 1 + pytensor/graph/op.py | 1 + pytensor/ifelse.py | 2 +- pytensor/link/jax/dispatch/nlinalg.py | 71 +- pytensor/link/numba/dispatch/elemwise.py | 32 +- pytensor/tensor/math.py | 522 +++---------- pytensor/tensor/rewriting/uncanonicalize.py | 35 +- tests/link/numba/test_basic.py | 2 + tests/link/numba/test_elemwise.py | 49 +- tests/tensor/rewriting/test_math.py | 6 +- tests/tensor/rewriting/test_uncanonicalize.py | 64 +- tests/tensor/test_math.py | 99 ++- tests/tensor/test_max_argmax.py | 693 ------------------ tests/tensor/utils.py | 1 + 15 files changed, 307 insertions(+), 1272 deletions(-) delete mode 100644 tests/tensor/test_max_argmax.py diff --git a/pytensor/compile/function/__init__.py b/pytensor/compile/function/__init__.py index b7ebba01e8..f020e84376 100644 --- a/pytensor/compile/function/__init__.py +++ b/pytensor/compile/function/__init__.py @@ -312,6 +312,7 @@ def opt_log1p(node): else: # note: pfunc will also call orig_function -- orig_function is # a choke point that all compilation must pass through + fn = pfunc( params=inputs, outputs=outputs, diff --git a/pytensor/compile/function/types.py b/pytensor/compile/function/types.py index c221d7cf41..959a0cb14c 100644 --- a/pytensor/compile/function/types.py +++ b/pytensor/compile/function/types.py @@ -1758,6 +1758,7 @@ def orig_function( name=name, fgraph=fgraph, ) + print(m) with config.change_flags(compute_test_value="off"): fn = m.create(defaults) finally: diff --git a/pytensor/graph/op.py b/pytensor/graph/op.py index 160a65dd7a..a70c2b9730 100644 --- a/pytensor/graph/op.py +++ b/pytensor/graph/op.py @@ -291,6 +291,7 @@ def __call__( """ node = self.make_node(*inputs, **kwargs) + if name is not None: if len(node.outputs) == 1: node.outputs[0].name = name diff --git a/pytensor/ifelse.py b/pytensor/ifelse.py index 6aea34f262..ecc00dbe1a 100644 --- a/pytensor/ifelse.py +++ b/pytensor/ifelse.py @@ -477,7 +477,7 @@ def cond_make_inplace(fgraph, node): Reshape, Unbroadcast, pt.math.Dot, - pt.math.TensorMax, + pt.math.Max, pt.math.Argmax, pt.subtensor.Subtensor, pt.subtensor.IncSubtensor, diff --git a/pytensor/link/jax/dispatch/nlinalg.py b/pytensor/link/jax/dispatch/nlinalg.py index 2f92364379..e76781afa2 100644 --- a/pytensor/link/jax/dispatch/nlinalg.py +++ b/pytensor/link/jax/dispatch/nlinalg.py @@ -2,7 +2,7 @@ from pytensor.link.jax.dispatch import jax_funcify from pytensor.tensor.blas import BatchedDot -from pytensor.tensor.math import Dot, MaxAndArgmax +from pytensor.tensor.math import Argmax, Dot, Max from pytensor.tensor.nlinalg import ( SVD, Det, @@ -104,18 +104,73 @@ def batched_dot(a, b): return batched_dot -@jax_funcify.register(MaxAndArgmax) -def jax_funcify_MaxAndArgmax(op, **kwargs): +# @jax_funcify.register(Max) +# @jax_funcify.register(Argmax) +# def jax_funcify_MaxAndArgmax(op, **kwargs): +# axis = op.axis + +# def maxandargmax(x, axis=axis): +# if axis is None: +# axes = tuple(range(x.ndim)) +# else: +# axes = tuple(int(ax) for ax in axis) + +# max_res = jnp.max(x, axis) + +# # NumPy does not support multiple axes for argmax; this is a +# # work-around +# keep_axes = jnp.array( +# [i for i in range(x.ndim) if i not in axes], dtype="int64" +# ) +# # Not-reduced axes in front +# transposed_x = jnp.transpose( +# x, jnp.concatenate((keep_axes, jnp.array(axes, dtype="int64"))) +# ) +# kept_shape = transposed_x.shape[: len(keep_axes)] +# reduced_shape = transposed_x.shape[len(keep_axes) :] + +# # Numpy.prod returns 1.0 when arg is empty, so we cast it to int64 +# # Otherwise reshape would complain citing float arg +# new_shape = ( +# *kept_shape, +# jnp.prod(jnp.array(reduced_shape, dtype="int64"), dtype="int64"), +# ) +# reshaped_x = transposed_x.reshape(new_shape) + +# max_idx_res = jnp.argmax(reshaped_x, axis=-1).astype("int64") + +# return max_res, max_idx_res + +# return maxandargmax + + +@jax_funcify.register(Max) +def jax_funcify_Max(op, **kwargs): axis = op.axis - def maxandargmax(x, axis=axis): + def max(x, axis=axis): + # if axis is None: + # axes = tuple(range(x.ndim)) + # else: + # axes = tuple(int(ax) for ax in axis) + + max_res = jnp.max(x, axis) + + return max_res + + return max + + +@jax_funcify.register(Argmax) +def jax_funcify_Argmax(op, **kwargs): + axis = op.axis + + def argmax(x, axis=axis): if axis is None: axes = tuple(range(x.ndim)) else: axes = tuple(int(ax) for ax in axis) - max_res = jnp.max(x, axis) - # NumPy does not support multiple axes for argmax; this is a # work-around keep_axes = jnp.array( @@ -138,6 +193,6 @@ def maxandargmax(x, axis=axis): max_idx_res = jnp.argmax(reshaped_x, axis=-1).astype("int64") - return max_res, max_idx_res + return max_idx_res - return maxandargmax + return argmax diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py index 207ebd5cf2..fbbec2587c 100644 --- a/pytensor/link/numba/dispatch/elemwise.py +++ b/pytensor/link/numba/dispatch/elemwise.py @@ -44,7 +44,7 @@ ) from pytensor.scalar.basic import add as add_as from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise -from pytensor.tensor.math import MaxAndArgmax, MulWithoutZeros, Sum +from pytensor.tensor.math import Argmax, MulWithoutZeros, Sum from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad from pytensor.tensor.type import scalar @@ -985,8 +985,8 @@ def log_softmax_py_fn(x): return log_softmax -@numba_funcify.register(MaxAndArgmax) -def numba_funcify_MaxAndArgmax(op, node, **kwargs): +@numba_funcify.register(Argmax) +def numba_funcify_Argmax(op, node, **kwargs): axis = op.axis x_at = node.inputs[0] x_dtype = x_at.type.numpy_dtype @@ -996,8 +996,8 @@ def numba_funcify_MaxAndArgmax(op, node, **kwargs): if x_ndim == 0: @numba_basic.numba_njit(inline="always") - def maxandargmax(x): - return x, 0 + def argmax(x): + return 0 else: axes = tuple(int(ax) for ax in axis) @@ -1006,20 +1006,6 @@ def maxandargmax(x): # work-around keep_axes = tuple(i for i in range(x_ndim) if i not in axes) - reduce_max_py_fn = create_multiaxis_reducer( - scalar_maximum, - -np.inf, - axes, - x_ndim, - x_dtype, - return_scalar=False, - ) - reduce_max = jit_compile_reducer( - Apply(node.op, node.inputs, [node.outputs[0].clone()]), - reduce_max_py_fn, - reduce_to_scalar=False, - ) - reduced_x_ndim = x_ndim - len(axes) + 1 argmax_axis = create_axis_apply_fn( np.argmax, reduced_x_ndim - 1, reduced_x_ndim, np.int64 @@ -1030,9 +1016,7 @@ def maxandargmax(x): sl2 = slice(len(keep_axes), None) @numba_basic.numba_njit - def maxandargmax(x): - max_res = reduce_max(x) - + def argmax(x): # Not-reduced axes in front transposed_x = np.ascontiguousarray(np.transpose(x, reaxis_order)) kept_shape = transposed_x.shape[sl1] @@ -1048,6 +1032,6 @@ def maxandargmax(x): max_idx_res = argmax_axis(reshaped_x) - return max_res, max_idx_res + return max_idx_res - return maxandargmax + return argmax diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py index 706db19702..da872bc791 100644 --- a/pytensor/tensor/math.py +++ b/pytensor/tensor/math.py @@ -14,7 +14,6 @@ from pytensor.graph.replace import _vectorize_node from pytensor.link.c.op import COp from pytensor.link.c.params_type import ParamsType -from pytensor.link.c.type import Generic from pytensor.misc.safe_asarray import _asarray from pytensor.printing import pprint from pytensor.raise_op import Assert @@ -29,6 +28,7 @@ constant, stack, switch, + zeros_like, ) from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback from pytensor.tensor.elemwise import ( @@ -134,417 +134,13 @@ def _allclose(a, b, rtol=None, atol=None): return np.allclose(a, b, atol=atol_, rtol=rtol_) -class MaxAndArgmax(COp): - """ - Calculate the max and argmax over a given axis or over all axes. - - """ - - nin = 2 # tensor, axis - nout = 2 # max val, max idx - E_axis = "invalid axis" - params_type = Generic() - __props__ = ("axis",) - _f16_ok = True - - def __init__(self, axis): - assert isinstance(axis, tuple | list) - # print(axis) - # assert 0 - self.axis = tuple(axis) - - def get_params(self, node): - return self.axis - - def make_node(self, x): - x = as_tensor_variable(x) - - # Keep the original shapes for axes on which we do not perform the max/argmax. - all_axes = set(self.axis) - inputs = [x] - out_shape = tuple(s for i, s in enumerate(x.type.shape) if i not in all_axes) - outputs = [ - tensor(dtype=x.type.dtype, shape=out_shape, name="max"), - tensor(dtype="int64", shape=out_shape, name="argmax"), - ] - return Apply(self, inputs, outputs) - - def perform(self, node, inp, outs): - x = inp[0] - axes = self.axis - max, max_idx = outs - if axes is None: - axes = tuple(range(x.ndim)) - else: - axes = tuple(int(ax) for ax in axes) - max[0] = _asarray(np.max(x, axes), dtype=node.outputs[0].dtype) - # Numpy does not support multiple axes for argmax - # Work around - keep_axes = np.array([i for i in range(x.ndim) if i not in axes], dtype="int64") - # Not-reduced axes in front - transposed_x = np.transpose(x, np.concatenate((keep_axes, axes))) - kept_shape = transposed_x.shape[: len(keep_axes)] - reduced_shape = transposed_x.shape[len(keep_axes) :] - - # Numpy.prod returns 1.0 when arg is empty, so we cast it to int64 - # Otherwise reshape would complain citing float arg - new_shape = (*kept_shape, np.prod(reduced_shape, dtype="int64")) - reshaped_x = transposed_x.reshape(new_shape) - - max_idx[0] = _asarray(np.argmax(reshaped_x, axis=-1), dtype="int64") - - def c_code(self, node, name, inp, out, sub): - if len(self.axis) != 1 and len(self.axis) != node.inputs[0].ndim: - raise NotImplementedError( - "NumPy C-API can compute max and argmax only for 1 axis or for all axes." - ) - x = inp[0] - axis = sub["params"] - max, argmax = out - fail = sub["fail"] - ret = """ - #if PY_MAJOR_VERSION >= 3 - #ifndef PyInt_AS_LONG - #define PyInt_AS_LONG PyLong_AS_LONG - #endif - #endif - - int axis; - - if (PyTuple_GET_SIZE(%(axis)s) == PyArray_NDIM(%(x)s)) { - axis = NPY_MAXDIMS; - } else if(PyTuple_GET_SIZE(%(axis)s) == 1) { - PyObject* axis_object = PyTuple_GET_ITEM(%(axis)s, 0); - axis = (int)PyInt_AS_LONG(axis_object); - if (axis > PyArray_NDIM(%(x)s)-1 || axis < -PyArray_NDIM(%(x)s)) { - PyErr_SetString(PyExc_ValueError, - "MaxAndArgmax: bad axis argument"); - %(fail)s - } - } else { - PyErr_SetString(PyExc_NotImplementedError, - "MaxAndArgmax: NumPy C-API can compute max and argmax only for 1 axis or for all axes."); - %(fail)s - } - - Py_CLEAR(%(max)s); - Py_CLEAR(%(argmax)s);//todo pass them as out parameter. - - %(max)s = (PyArrayObject*)PyArray_Max(%(x)s, axis, NULL); - if (%(max)s == NULL) { - %(fail)s; - } - if (!PyArray_CheckExact(%(max)s)) { - %(max)s = (PyArrayObject*)PyArray_FromAny((PyObject*)%(max)s, NULL, 0, 0, NPY_ARRAY_ENSUREARRAY, NULL); - if(%(max)s == NULL){ - %(fail)s; - } - } - - %(argmax)s = (PyArrayObject*)PyArray_ArgMax(%(x)s, axis, NULL); - if (%(argmax)s == NULL) { - Py_CLEAR(%(max)s); - %(fail)s; - } - if (!PyArray_CheckExact(%(argmax)s)) { - %(argmax)s = (PyArrayObject*)PyArray_FromAny((PyObject*)%(argmax)s, NULL, 0, 0, NPY_ARRAY_ENSUREARRAY, NULL); - if(%(argmax)s == NULL){ - %(fail)s; - } - } - if (PyArray_TYPE(%(argmax)s) != NPY_INT64) { - PyObject * tmp = PyArray_Cast(%(argmax)s, NPY_INT64); - if (NULL == tmp){ - %(fail)s; - } - Py_DECREF(%(argmax)s); - %(argmax)s = (PyArrayObject*)tmp; - } - """ - return ret % locals() - - def c_code_cache_version(self): - return (5,) - - def infer_shape(self, fgraph, node, shapes): - ishape = shapes[0] - rval = tuple( - ishape[i] - for (i, b) in enumerate(node.inputs[0].type.broadcastable) - if i not in self.axis - ) - return [rval, rval] - - def R_op(self, inputs, eval_points): - if eval_points[0] is None: - return [None, None] - if len(self.axis) != 1: - raise ValueError("R_op supported for arg_max only for one axis!") - if self.axis[0] > 1: - raise ValueError("R_op supported for arg_max only when axis is 0 or 1") - if inputs[0].ndim != 2: - raise ValueError("R_op supported for arg_max only when input is a matrix") - max_vals, max_pos = self.make_node(*inputs).outputs - if self.axis[0] == 0: - return [eval_points[0][max_pos, arange(eval_points[0].shape[1])], None] - else: - return [eval_points[0][arange(eval_points[0].shape[0]), max_pos], None] - - def grad(self, inp, grads): - # The strict sense mathematical gradient of the maximum function is - # not calculated here for it is not defined at every point where some - # coordinates are identical. However, since the latter set has null - # Lebesgue measure, the result may be interpreted as weak gradient. - - # @note: This function should work correctly for L{vector}s. - # (x, y), (gz, gw) - # gz*dz/dx + gw*dw/dx, gz*dz/dy + gw*dw/dy - # gMax * dMax/dx + gArgMax * dArgMax/dx, - # gMax * dMax/daxis + gArgMax * dArgMax/daxis - # g_max has one less dimension than x, so you need to complete - # g_max to x's shape when axis=0 the broadcasting mechanism - # does it automatically - x = inp[0] - axis = as_tensor_variable(self.axis) - g_max, g_max_idx = grads - - g_max_disconnected = isinstance(g_max.type, DisconnectedType) - g_max_idx_disconnected = isinstance(g_max_idx.type, DisconnectedType) - - # if the op is totally disconnected, so are its inputs - if g_max_disconnected and g_max_idx_disconnected: - return [DisconnectedType()(), DisconnectedType()()] - - # if the max is disconnected but the argmax is not, - # the gradient on its inputs is zero - if g_max_disconnected: - return [x.zeros_like()] - if NoneConst.equals(axis): - axis_ = list(range(x.ndim)) - else: - axis_ = axis - xmax = max(x, axis_) - - # Raise the g_max and xmax to the same number of dim as the input. - pattern = [] - out_dim = 0 - if NoneConst.equals(axis): - # We are taking the max/argmax over all dimensions. - axis = None - for i in range(x.ndim): - if axis is None or i in axis.data: - pattern.append("x") - else: - pattern.append(out_dim) - out_dim += 1 - g_max_pad = DimShuffle(g_max.broadcastable, pattern)(g_max) - xmax_pad = DimShuffle(xmax.broadcastable, pattern)(xmax) - - # Set the grad to the correct position. - g_x = eq(xmax_pad, x) * g_max_pad - return (g_x,) - - -class TensorMax(COp): - """ - Calculate the max over a given axis or over all axes. - - """ - - nin = 2 # tensor, axis - nout = 1 # max val - E_axis = "invalid axis" - params_type = Generic() - __props__ = ("axis",) - _f16_ok = True - - def __init__(self, axis): - assert isinstance(axis, tuple | list) - self.axis = tuple(axis) - - def get_params(self, node): - return self.axis - - def make_node(self, x): - x = as_tensor_variable(x) - - # Keep the original shapes for axes on which we do not perform the max/argmax. - all_axes = set(self.axis) - inputs = [x] - out_shape = tuple(s for i, s in enumerate(x.type.shape) if i not in all_axes) - outputs = [ - tensor(dtype=x.type.dtype, shape=out_shape, name="max"), - ] - return Apply(self, inputs, outputs) - - def prepare_node(self, node, storage_map, compute_map, impl): - if len(node.inputs) == 2: - raise ValueError( - "You are trying to compile a graph with an old Argmax node. Either reoptimize your graph or rebuild it to get the new node format." - ) - - def perform(self, node, inp, outs): - x = inp[0] - axes = self.axis - # max, max_idx = outs - (max,) = outs - if axes is None: - axes = tuple(range(x.ndim)) - else: - axes = tuple(int(ax) for ax in axes) - max[0] = _asarray(np.max(x, axes), dtype=node.outputs[0].dtype) - # # Numpy does not support multiple axes for argmax - # # Work around - # keep_axes = np.array([i for i in range(x.ndim) if i not in axes], dtype="int64") - # # Not-reduced axes in front - # transposed_x = np.transpose(x, np.concatenate((keep_axes, axes))) - # kept_shape = transposed_x.shape[: len(keep_axes)] - # reduced_shape = transposed_x.shape[len(keep_axes) :] - - # # Numpy.prod returns 1.0 when arg is empty, so we cast it to int64 - # # Otherwise reshape would complain citing float arg - # new_shape = (*kept_shape, np.prod(reduced_shape, dtype="int64")) - # reshaped_x = transposed_x.reshape(new_shape) - - # max_idx[0] = _asarray(np.argmax(reshaped_x, axis=-1), dtype="int64") - - def c_code(self, node, name, inp, out, sub): - if len(self.axis) != 1 and len(self.axis) != node.inputs[0].ndim: - raise NotImplementedError( - "NumPy C-API can compute max only for 1 axis or for all axes." - ) - x = inp[0] - axis = sub["params"] - # max, argmax = out - (max,) = out - fail = sub["fail"] - ret = """ - #if PY_MAJOR_VERSION >= 3 - #ifndef PyInt_AS_LONG - #define PyInt_AS_LONG PyLong_AS_LONG - #endif - #endif - - int axis; - - if (PyTuple_GET_SIZE(%(axis)s) == PyArray_NDIM(%(x)s)) { - axis = NPY_MAXDIMS; - } else if(PyTuple_GET_SIZE(%(axis)s) == 1) { - PyObject* axis_object = PyTuple_GET_ITEM(%(axis)s, 0); - axis = (int)PyInt_AS_LONG(axis_object); - if (axis > PyArray_NDIM(%(x)s)-1 || axis < -PyArray_NDIM(%(x)s)) { - PyErr_SetString(PyExc_ValueError, - "TensorMax: bad axis argument"); - %(fail)s - } - } else { - PyErr_SetString(PyExc_NotImplementedError, - "TensorMax: NumPy C-API can compute max only for 1 axis or for all axes."); - %(fail)s - } - - Py_CLEAR(%(max)s); - - %(max)s = (PyArrayObject*)PyArray_Max(%(x)s, axis, NULL); - if (%(max)s == NULL) { - %(fail)s; - } - if (!PyArray_CheckExact(%(max)s)) { - %(max)s = (PyArrayObject*)PyArray_FromAny((PyObject*)%(max)s, NULL, 0, 0, NPY_ARRAY_ENSUREARRAY, NULL); - if(%(max)s == NULL){ - %(fail)s; - } - } - """ - return ret % locals() - - def c_code_cache_version(self): - return (5,) - - def infer_shape(self, fgraph, node, shapes): - ishape = shapes[0] - rval = tuple( - ishape[i] - for (i, b) in enumerate(node.inputs[0].type.broadcastable) - if i not in self.axis +def __getattr__(name): + if name == "MaxandArgmax": + warnings.warn( + "The class `MaxandArgmax` has been deprecated. " + "Call `Max` and `Argmax` seperately as an alternative.", + FutureWarning, ) - return [rval] - - def R_op(self, inputs, eval_points): - if eval_points[0] is None: - return [None, None] - - if len(self.axis) != 1: - raise ValueError("R_op supported for arg_max only for one axis!") - if self.axis[0] > 1: - raise ValueError("R_op supported for arg_max only when axis is 0 or 1") - if inputs[0].ndim != 2: - raise ValueError("R_op supported for arg_max only when input is a matrix") - # max_vals, max_pos = self.make_node(*inputs).outputs - # max_vals = self.make_node(*inputs).outputs - if self.axis[0] == 0: - return [eval_points[0][arange(eval_points[0].shape[1])], None] - else: - return [eval_points[0][arange(eval_points[0].shape[0])], None] - - def grad(self, inp, grads): - # The strict sense mathematical gradient of the maximum function is - # not calculated here for it is not defined at every point where some - # coordinates are identical. However, since the latter set has null - # Lebesgue measure, the result may be interpreted as weak gradient. - - # @note: This function should work correctly for L{vector}s. - # (x, y), (gz, gw) - # gz*dz/dx + gw*dw/dx, gz*dz/dy + gw*dw/dy - # gMax * dMax/dx + gArgMax * dArgMax/dx, - # gMax * dMax/daxis + gArgMax * dArgMax/daxis - # g_max has one less dimension than x, so you need to complete - # g_max to x's shape when axis=0 the broadcasting mechanism - # does it automatically - x = inp[0] - axis = as_tensor_variable(self.axis) - # g_max, g_max_idx = grads - (g_max,) = grads - - g_max_disconnected = isinstance(g_max.type, DisconnectedType) - # g_max_idx_disconnected = isinstance(g_max_idx.type, DisconnectedType) - - # # if the op is totally disconnected, so are its inputs - # if g_max_disconnected and g_max_idx_disconnected: - # return [DisconnectedType()(), DisconnectedType()()] - - # if the op is totally disconnected, so are its inputs - if g_max_disconnected: - return [DisconnectedType()()] - - # if the max is disconnected but the argmax is not, - # the gradient on its inputs is zero - # if g_max_disconnected: - # return [x.zeros_like()] - if NoneConst.equals(axis): - axis_ = list(range(x.ndim)) - else: - axis_ = axis - xmax = max(x, axis_) - - # Raise the g_max and xmax to the same number of dim as the input. - pattern = [] - out_dim = 0 - if NoneConst.equals(axis): - # We are taking the max/argmax over all dimensions. - axis = None - for i in range(x.ndim): - if axis is None or i in axis.data: - pattern.append("x") - else: - pattern.append(out_dim) - out_dim += 1 - g_max_pad = DimShuffle(g_max.broadcastable, pattern)(g_max) - xmax_pad = DimShuffle(xmax.broadcastable, pattern)(xmax) - - # Set the grad to the correct position. - g_x = eq(xmax_pad, x) * g_max_pad - return (g_x,) class Argmax(COp): @@ -561,11 +157,9 @@ class Argmax(COp): params_type = ParamsType(c_axis=ps.int64) def __init__(self, axis): - # if axis is not None: - # axis = tuple(axis) - assert isinstance(axis, tuple | list) - # print(axis) - self.axis = tuple(axis) + if axis is not None: + axis = tuple(axis) + self.axis = axis def get_params(self, node): if self.axis is not None and len(self.axis) == 1: @@ -599,10 +193,9 @@ def perform(self, node, inp, outs): (x,) = inp axes = self.axis (max_idx,) = outs + if axes is None: axes = tuple(range(x.ndim)) - else: - axes = tuple(int(ax) for ax in axes) # Numpy does not support multiple axes for argmax # Work around @@ -611,7 +204,7 @@ def perform(self, node, inp, outs): transposed_x = np.transpose(x, np.concatenate((keep_axes, axes))) kept_shape = transposed_x.shape[: len(keep_axes)] reduced_shape = transposed_x.shape[len(keep_axes) :] - new_shape = (*kept_shape, np.prod(reduced_shape)) + new_shape = (*kept_shape, np.prod(reduced_shape, dtype="int64")) reshaped_x = transposed_x.reshape(new_shape) max_idx[0] = _asarray(np.argmax(reshaped_x, axis=-1), dtype="int64") @@ -678,6 +271,9 @@ def infer_shape(self, fgraph, node, shapes): ) return [rval] + def R_op(self, inputs, eval_points): + raise ValueError("Argmax is non-diifferentiable") + def grad(self, inp, grads): (x,) = inp @@ -685,7 +281,6 @@ def grad(self, inp, grads): @_vectorize_node.register(Argmax) -# @_vectorize_node.register(MaxAndArgmax) def vectorize_argmax_node(op, node, batch_x): core_ndim = node.inputs[0].type.ndim batch_ndim = batch_x.type.ndim - core_ndim @@ -805,12 +400,22 @@ def max_and_argmax(a, axis=None, keepdims=False): # Check axis and convert it to a Python list of integers. # Axis will be used as an op param of MaxAndArgmax. a = as_tensor_variable(a) + + flag = False + if axis == (): + flag = True + axis = check_and_normalize_axes(a, axis) - if len(axis) == 0: - axis = list(range(a.type.ndim)) - out = TensorMax(axis)(a) - argout = Argmax(axis)(a) - # out, argout = MaxAndArgmax(axis)(a) + + if len(axis) == 0 and not flag: + axis = None + + out = Max(axis)(a) + + if not flag: + argout = Argmax(axis)(a) + else: + argout = zeros_like(a, dtype="int64") if keepdims: out = makeKeepDims(a, out, axis) @@ -864,6 +469,73 @@ def clone(self, **kwargs): axis = kwargs.get("axis", self.axis) return type(self)(axis=axis) + def grad(self, inp, grads): + # The strict sense mathematical gradient of the maximum function is + # not calculated here for it is not defined at every point where some + # coordinates are identical. However, since the latter set has null + # Lebesgue measure, the result may be interpreted as weak gradient. + + # @note: This function should work correctly for L{vector}s. + # (x, y), (gz, gw) + # gz*dz/dx + gw*dw/dx, gz*dz/dy + gw*dw/dy + # gMax * dMax/dx + gArgMax * dArgMax/dx, + # gMax * dMax/daxis + gArgMax * dArgMax/daxis + # g_max has one less dimension than x, so you need to complete + # g_max to x's shape when axis=0 the broadcasting mechanism + # does it automatically + x = inp[0] + if self.axis is None: + self.axis = tuple(range(x.ndim)) + axis = as_tensor_variable(self.axis) + (g_max,) = grads + + g_max_disconnected = isinstance(g_max.type, DisconnectedType) + + # if the op is totally disconnected, so are its inputs + if g_max_disconnected: + return [DisconnectedType()()] + + if NoneConst.equals(axis): + axis_ = list(range(x.ndim)) + else: + axis_ = axis + xmax = max(x, axis_) + + # Raise the g_max and xmax to the same number of dim as the input. + pattern = [] + out_dim = 0 + if NoneConst.equals(axis): + # We are taking the max/argmax over all dimensions. + axis = None + for i in range(x.ndim): + if axis is None or i in axis.data: + pattern.append("x") + else: + pattern.append(out_dim) + out_dim += 1 + g_max_pad = DimShuffle(g_max.broadcastable, pattern)(g_max) + xmax_pad = DimShuffle(xmax.broadcastable, pattern)(xmax) + + # Set the grad to the correct position. + g_x = eq(xmax_pad, x) * g_max_pad + return (g_x,) + + def R_op(self, inputs, eval_points): + if eval_points[0] is None: + return [None, None] + if len(self.axis) != 1: + raise ValueError("R_op supported for arg_max only for one axis!") + if self.axis[0] > 1: + raise ValueError("R_op supported for arg_max only when axis is 0 or 1") + if inputs[0].ndim != 2: + raise ValueError("R_op supported for arg_max only when input is a matrix") + max_pos = Argmax(self.axis).make_node(*inputs).outputs + # print(eval_points[0].eval()) + if self.axis[0] == 0: + return [eval_points[0][max_pos, arange(eval_points[0].shape[1])], None] + else: + return [eval_points[0][arange(eval_points[0].shape[0]), max_pos], None] + class Min(NonZeroDimsCAReduce): nfunc_spec = ("min", 1, 1) diff --git a/pytensor/tensor/rewriting/uncanonicalize.py b/pytensor/tensor/rewriting/uncanonicalize.py index 624353c253..194734f567 100644 --- a/pytensor/tensor/rewriting/uncanonicalize.py +++ b/pytensor/tensor/rewriting/uncanonicalize.py @@ -31,36 +31,16 @@ """ +from pytensor import scalar as ps from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter from pytensor.tensor.basic import Alloc, alloc, constant -from pytensor.tensor.elemwise import DimShuffle - -# from pytensor.tensor.math import Argmax, Max, MaxAndArgmax, Min, neg -from pytensor.tensor.math import Min, TensorMax, neg +from pytensor.tensor.elemwise import CAReduce, DimShuffle +from pytensor.tensor.math import Min, neg from pytensor.tensor.rewriting.basic import register_uncanonicalize from pytensor.tensor.shape import Reshape, reshape from pytensor.tensor.subtensor import Subtensor -# @register_uncanonicalize -# @node_rewriter([MaxAndArgmax]) -# def local_max_and_argmax(fgraph, node): -# """ -# If we don't use the argmax, change it to a max only. -# """ -# if isinstance(node.op, MaxAndArgmax): -# axis = node.op.axis -# if len(fgraph.clients[node.outputs[1]]) == 0: -# new = Max(axis)(node.inputs[0]) -# copy_stack_trace(node.outputs[0], new) -# return [new, None] - -# if len(fgraph.clients[node.outputs[0]]) == 0: -# new = Argmax(axis)(node.inputs[0]) -# copy_stack_trace(node.outputs[0], new) -# return [None, new] - - @register_uncanonicalize @node_rewriter([neg]) def local_max_to_min(fgraph, node): @@ -75,13 +55,14 @@ def local_max_to_min(fgraph, node): the interface put only MaxAndArgmax into the graph. """ - # pytensor.dprint(node) - # print() - # print(node.op == neg) if node.op == neg and node.inputs[0].owner: max = node.inputs[0] # print(max.owner.op.scalar_op) - if max.owner and isinstance(max.owner.op, TensorMax): + if ( + max.owner + and isinstance(max.owner.op, CAReduce) + and max.owner.op.scalar_op == ps.scalar_maximum + ): neg_node = max.owner.inputs[0] if neg_node.owner and neg_node.owner.op == neg: new = Min(max.owner.op.axis)(neg_node.owner.inputs[0]) diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py index 5830983518..8c68600aed 100644 --- a/tests/link/numba/test_basic.py +++ b/tests/link/numba/test_basic.py @@ -256,6 +256,8 @@ def compare_numba_and_py( if assert_fn is None: def assert_fn(x, y): + print(x) + print(y) return np.testing.assert_allclose(x, y, rtol=1e-4) and compare_shape_dtype( x, y ) diff --git a/tests/link/numba/test_elemwise.py b/tests/link/numba/test_elemwise.py index b8c131ead6..8bbbe164fc 100644 --- a/tests/link/numba/test_elemwise.py +++ b/tests/link/numba/test_elemwise.py @@ -552,8 +552,53 @@ def test_LogSoftmax(x, axis, exc): ), ], ) -def test_MaxAndArgmax(x, axes, exc): - g = ptm.MaxAndArgmax(axes)(x) +def test_Max(x, axes, exc): + g = ptm.Max(axes)(x) + + if isinstance(g, list): + g_fg = FunctionGraph(outputs=g) + else: + g_fg = FunctionGraph(outputs=[g]) + + cm = contextlib.suppress() if exc is None else pytest.warns(exc) + with cm: + compare_numba_and_py( + g_fg, + [ + i.tag.test_value + for i in g_fg.inputs + if not isinstance(i, SharedVariable | Constant) + ], + ) + + +@pytest.mark.parametrize( + "x, axes, exc", + [ + ( + set_test_value(pt.dscalar(), np.array(0.0, dtype="float64")), + [], + None, + ), + ( + set_test_value(pt.dvector(), rng.random(size=(3,)).astype("float64")), + [0], + None, + ), + ( + set_test_value(pt.dmatrix(), rng.random(size=(3, 2)).astype("float64")), + [0], + None, + ), + ( + set_test_value(pt.dmatrix(), rng.random(size=(3, 2)).astype("float64")), + [0, 1], + None, + ), + ], +) +def test_Argmax(x, axes, exc): + g = ptm.Argmax(axes)(x) if isinstance(g, list): g_fg = FunctionGraph(outputs=g) diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index 84322989bf..29c07456b5 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -38,7 +38,7 @@ from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.math import ( Dot, - MaxAndArgmax, + Max, Prod, Sum, _conj, @@ -3734,8 +3734,8 @@ def check_max_log_sum_exp(x, axis, dimshuffle_op=None): return # In mode FAST_COMPILE, the rewrites don't replace the - # `MaxAndArgmax` `Op`. - if isinstance(node.op, MaxAndArgmax): + # `Max` `Op`. + if isinstance(node.op, Max): return # TODO FIXME: Refactor this test so that it makes a direct assertion and diff --git a/tests/tensor/rewriting/test_uncanonicalize.py b/tests/tensor/rewriting/test_uncanonicalize.py index 32c141cab3..9d5011b6db 100644 --- a/tests/tensor/rewriting/test_uncanonicalize.py +++ b/tests/tensor/rewriting/test_uncanonicalize.py @@ -9,7 +9,6 @@ from pytensor.graph.rewriting.basic import out2in from pytensor.link.basic import PerformLinker from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise -from pytensor.tensor.math import TensorMax from pytensor.tensor.math import min as pt_min from pytensor.tensor.rewriting.uncanonicalize import ( local_alloc_dimshuffle, @@ -22,69 +21,12 @@ from tests.link.test_link import make_function -# class TestMaxAndArgmax: -# def test_optimization(self): -# # If we use only the max output, we should replace this op with -# # a faster one. -# mode = pytensor.compile.mode.get_default_mode().including( -# "canonicalize", "fast_run" -# ) - -# for axis in [0, 1, -1]: -# n = matrix() - -# f = function([n], max_and_argmax(n, axis)[0], mode=mode) -# topo = f.maker.fgraph.toposort() -# assert len(topo) == 1 -# assert isinstance(topo[0].op, CAReduce) - -# f = function([n], max_and_argmax(n, axis), mode=mode) -# topo = f.maker.fgraph.toposort() -# assert len(topo) == 1 -# assert isinstance(topo[0].op, MaxAndArgmax) - - class TestMinMax: def setup_method(self): self.mode = pytensor.compile.mode.get_default_mode().including( "canonicalize", "fast_run" ) - # def test_optimization_max(self): - # data = np.asarray(np.random.random((2, 3)), dtype=config.floatX) - # n = matrix() - - # for axis in [0, 1, -1]: - # f = function([n], pt_max(n, axis), mode=self.mode) - # topo = f.maker.fgraph.toposort() - # assert len(topo) == 1 - # # assert isinstance(topo[0].op, CAReduce) - # f(data) - - # f = function([n], pt_max(-n, axis), mode=self.mode) - # topo = f.maker.fgraph.toposort() - # import pytensor - # pytensor.dprint(topo) - # assert len(topo) == 2 - # assert isinstance(topo[0].op, Elemwise) - # assert isinstance(topo[0].op.scalar_op, ps.Neg) - # assert isinstance(topo[1].op, CAReduce) - # f(data) - - # f = function([n], -pt_max(n, axis), mode=self.mode) - # topo = f.maker.fgraph.toposort() - # assert len(topo) == 2 - # assert isinstance(topo[0].op, CAReduce) - # assert isinstance(topo[1].op, Elemwise) - # assert isinstance(topo[1].op.scalar_op, ps.Neg) - # f(data) - - # f = function([n], -pt_max(-n, axis), mode=self.mode) - # topo = f.maker.fgraph.toposort() - # assert len(topo) == 1 - # assert isinstance(topo[0].op, CAReduce) # min - # f(data) - def test_optimization_min(self): data = np.asarray(np.random.random((2, 3)), dtype=config.floatX) n = matrix() @@ -100,7 +42,7 @@ def test_optimization_min(self): f = function([n], pt_min(-n, axis), mode=self.mode) topo = f.maker.fgraph.toposort() assert len(topo) == 2 - assert isinstance(topo[0].op, TensorMax) # max + assert isinstance(topo[0].op, CAReduce) # max assert isinstance(topo[1].op, Elemwise) assert isinstance(topo[1].op.scalar_op, ps.Neg) f(data) @@ -110,13 +52,13 @@ def test_optimization_min(self): assert len(topo) == 2 assert isinstance(topo[0].op, Elemwise) assert isinstance(topo[0].op.scalar_op, ps.Neg) - assert isinstance(topo[1].op, TensorMax) # max + assert isinstance(topo[1].op, CAReduce) # max f(data) f = function([n], -pt_min(-n, axis), mode=self.mode) topo = f.maker.fgraph.toposort() assert len(topo) == 1 - assert isinstance(topo[0].op, TensorMax) # max + assert isinstance(topo[0].op, CAReduce) # max f(data) diff --git a/tests/tensor/test_math.py b/tests/tensor/test_math.py index bb460d1fa6..f83944eb1c 100644 --- a/tests/tensor/test_math.py +++ b/tests/tensor/test_math.py @@ -11,6 +11,7 @@ from numpy.testing import assert_array_equal from scipy.special import logsumexp as scipy_logsumexp +import pytensor import pytensor.scalar as ps from pytensor.compile.debugmode import DebugMode from pytensor.compile.function import function @@ -39,7 +40,7 @@ from pytensor.tensor.math import ( Argmax, Dot, - MaxAndArgmax, + Max, Mean, Prod, ProdWithoutZeros, @@ -760,11 +761,12 @@ def test_isnan(): class TestMaxAndArgmax: def setup_method(self): - MaxAndArgmax.debug = 0 + Max.debug = 0 + Argmax.debug = 0 def test_basic(self): - n = as_tensor_variable([5.0]) - v, i = eval_outputs(max_and_argmax(n)) + n = as_tensor_variable(5) + v, i = eval_outputs(max_and_argmax(n, axis=())) assert v == 5.0 assert i == 0 assert i.dtype == "int64" @@ -806,11 +808,7 @@ def test_basic_2(self, axis, np_axis): v_shape, i_shape = eval_outputs([vt.shape, it.shape]) assert tuple(v_shape) == vt.type.shape assert tuple(i_shape) == it.type.shape - # Test values - v, i = eval_outputs([vt, it]) - assert i.dtype == "int64" - assert np.all(v == np_max) - assert np.all(i == np_argm) + # Test valuesgi @pytest.mark.parametrize( "axis,np_axis", @@ -1032,29 +1030,46 @@ def test_vectorize(self, core_axis, batch_axis): # Test MaxAndArgmax max_x, argmax_x = max_and_argmax(x, axis=core_axis) - node = max_x.owner - assert isinstance(node.op, MaxAndArgmax) - - new_node = vectorize_node(node, batch_x) - assert isinstance(new_node.op, MaxAndArgmax) - assert new_node.op.axis == batch_axis + max_node = max_x.owner + assert isinstance(max_node.op, Max) - # Test Argmax - # Argmax is not user-facing, so we have to create it manually - node = Argmax(axis=node.op.axis).make_node(x) + arg_max_node = argmax_x.owner + new_node = vectorize_node(arg_max_node, batch_x) - new_node = vectorize_node(node, batch_x) assert isinstance(new_node.op, Argmax) assert new_node.op.axis == batch_axis + def test_max_empty_axis(self): + x = np.random.normal(size=(2, 3, 5, 7)) + axis = () + + non_axis = tuple(i for i in range(x.ndim) if i not in axis) + shape_axis = tuple(x.shape[dim] for dim in axis) + shape_non_axis = tuple(x.shape[dim] for dim in non_axis) + x_transposed = x.transpose(*axis, *non_axis) + + x_axis_raveled = x_transposed.reshape( + np.prod(shape_axis, dtype=int), np.prod(shape_non_axis, dtype=int) + ) + max_x = max_and_argmax(x, axis=axis)[0].eval() + argmax_x = max_and_argmax(x, axis=axis)[1].eval() + + raveled_max = x_axis_raveled[ + argmax_x.ravel(), np.arange(np.prod(shape_non_axis, dtype=int)) + ] + indirect_max = raveled_max.reshape(shape_non_axis) + + np.testing.assert_allclose(max_x, x.max(axis=axis)) + np.testing.assert_allclose(indirect_max, x.max(axis=axis)) + class TestArgminArgmax: def setup_method(self): - MaxAndArgmax.debug = 0 + Argmax.debug = 0 def test_scalar(self): for fct in [argmin, argmax]: - n = as_tensor_variable(5.0) + n = as_tensor_variable([5.0]) i = eval_outputs(fct(n)) assert i == 0 v = eval_outputs(fct(n).shape) @@ -1212,7 +1227,7 @@ def test_bool(self): class TestMinMax: def setup_method(self): - MaxAndArgmax.debug = 0 + Max.debug = 0 def test_scalar(self): for fct in [max, min]: @@ -1404,6 +1419,11 @@ def test_bool(self): assert np.all(i) +def test_MaxandArgmax_deprecated(): + with pytest.warns(FutureWarning, match=".*deprecated.*"): + pytensor.tensor.math.MaxandArgmax + + rng = np.random.default_rng(seed=utt.fetch_seed()) TestClip1 = makeTester( name="ClipTester", @@ -2572,27 +2592,50 @@ def test_Mean(self): [adtens3], [Mean(aiscal_val)(adtens3)], [adtens3_val], Mean ) - def test_MaxAndArgmax(self): + def test_Max(self): + adtens3 = dtensor3() + adtens3_val = random(4, 5, 3) + self._compile_and_check( + [adtens3], max_and_argmax(adtens3, None), [adtens3_val], Max + ) + + self._compile_and_check( + [adtens3], max_and_argmax(adtens3, 0), [adtens3_val], Max + ) + + self._compile_and_check( + [adtens3], max_and_argmax(adtens3, 1), [adtens3_val], Max + ) + + self._compile_and_check( + [adtens3], max_and_argmax(adtens3, 2), [adtens3_val], Max + ) + + self._compile_and_check( + [adtens3], max_and_argmax(adtens3, [0, 1, 2]), [adtens3_val], Max + ) + + def test_Argmax(self): adtens3 = dtensor3() adtens3_val = random(4, 5, 3) self._compile_and_check( - [adtens3], max_and_argmax(adtens3, None), [adtens3_val], MaxAndArgmax + [adtens3], max_and_argmax(adtens3, None), [adtens3_val], Argmax ) self._compile_and_check( - [adtens3], max_and_argmax(adtens3, 0), [adtens3_val], MaxAndArgmax + [adtens3], max_and_argmax(adtens3, 0), [adtens3_val], Argmax ) self._compile_and_check( - [adtens3], max_and_argmax(adtens3, 1), [adtens3_val], MaxAndArgmax + [adtens3], max_and_argmax(adtens3, 1), [adtens3_val], Argmax ) self._compile_and_check( - [adtens3], max_and_argmax(adtens3, 2), [adtens3_val], MaxAndArgmax + [adtens3], max_and_argmax(adtens3, 2), [adtens3_val], Argmax ) self._compile_and_check( - [adtens3], max_and_argmax(adtens3, [0, 1, 2]), [adtens3_val], MaxAndArgmax + [adtens3], max_and_argmax(adtens3, [0, 1, 2]), [adtens3_val], Argmax ) def test_Dot(self): diff --git a/tests/tensor/test_max_argmax.py b/tests/tensor/test_max_argmax.py deleted file mode 100644 index d9b263e57c..0000000000 --- a/tests/tensor/test_max_argmax.py +++ /dev/null @@ -1,693 +0,0 @@ -import builtins - -import numpy as np -import pytest - -from pytensor.compile.function import function -from pytensor.compile.mode import get_default_mode -from pytensor.compile.sharedvalue import shared -from pytensor.configdefaults import config -from pytensor.gradient import grad, numeric_grad -from pytensor.graph.replace import vectorize_node -from pytensor.tensor.basic import ( - as_tensor_variable, - constant, - get_underlying_scalar_constant_value, -) -from pytensor.tensor.math import ( - Argmax, - TensorMax, - argmax, - argmin, - max, - max_and_argmax, - min, -) -from pytensor.tensor.type import ( - matrix, - tensor, -) -from pytensor.tensor.type_other import NoneConst -from tests import unittest_tools as utt -from tests.tensor.utils import ( - eval_outputs, - random, -) - - -if config.mode == "FAST_COMPILE": - mode_opt = "FAST_RUN" -else: - mode_opt = get_default_mode() - - -class TestMaxAndArgmax: - def setup_method(self): - TensorMax.debug = 0 - - def test_basic(self): - # dbt: for some reason, Argmax does not work when I pass: n = as_tensor_variable(5.0) - n = as_tensor_variable([5.0]) - v, i = eval_outputs(max_and_argmax(n)) - assert v == 5.0 - assert i == 0 - assert i.dtype == "int64" - v = eval_outputs(max_and_argmax(n)[0].shape) - assert len(v) == 0 - v = eval_outputs(max_and_argmax(n)[1].shape) - assert len(v) == 0 - - def test_basic_1(self): - n = as_tensor_variable([1, 2, 3, 2, -6]) - v, i = eval_outputs(max_and_argmax(n)) - assert v == 3 - assert i == 2 - assert i.dtype == "int64" - v = eval_outputs(max_and_argmax(n)[0].shape) - assert len(v) == 0 - - @pytest.mark.parametrize( - "axis,np_axis", - [ - (-1, -1), - (0, 0), - (1, 1), - (None, None), - ([0, 1], None), - ([1, 0], None), - (NoneConst.clone(), None), - (constant(0), 0), - ], - ) - def test_basic_2(self, axis, np_axis): - data = random(2, 3) - n = as_tensor_variable(data) - # Test shape propagates (static & eval) - vt, it = max_and_argmax(n, axis) - np_max, np_argm = np.max(data, np_axis), np.argmax(data, np_axis) - assert vt.type.shape == np_max.shape - assert it.type.shape == np_argm.shape - v_shape, i_shape = eval_outputs([vt.shape, it.shape]) - assert tuple(v_shape) == vt.type.shape - assert tuple(i_shape) == it.type.shape - # Test values - v, i = eval_outputs([vt, it]) - assert i.dtype == "int64" - assert np.all(v == np_max) - assert np.all(i == np_argm) - - @pytest.mark.parametrize( - "axis,np_axis", - [ - (-1, -1), - (0, 0), - (1, 1), - (None, None), - ([0, 1], None), - ([1, 0], None), - (NoneConst.clone(), None), - (constant(0), 0), - ], - ) - def test_basic_2_float16(self, axis, np_axis): - # Test negative values and bigger range to make sure numpy don't do the argmax as on uint16 - data = (random(20, 30).astype("float16") - 0.5) * 20 - n = as_tensor_variable(data) - # Test shape propagates (static & eval) - vt, it = max_and_argmax(n, axis) - np_max, np_argm = np.max(data, np_axis), np.argmax(data, np_axis) - assert vt.type.shape == np_max.shape - assert it.type.shape == np_argm.shape - v_shape, i_shape = eval_outputs([vt.shape, it.shape]) - assert tuple(v_shape) == vt.type.shape - assert tuple(i_shape) == it.type.shape - # Test values - v, i = eval_outputs([vt, it]) - assert i.dtype == "int64" - assert np.all(v == np_max) - assert np.all(i == np_argm) - - def test_basic_2_invalid(self): - n = as_tensor_variable(random(2, 3)) - with pytest.raises(ValueError): - eval_outputs(max_and_argmax(n, 3)) - - n = as_tensor_variable(random(2, 3)) - with pytest.raises(ValueError): - eval_outputs(max_and_argmax(n, -3)) - - def test_basic_2_valid_neg(self): - n = as_tensor_variable(random(2, 3)) - v, i = eval_outputs(max_and_argmax(n, -1)) - assert i.dtype == "int64" - assert v.shape == (2,) - assert i.shape == (2,) - assert np.all(v == np.max(n.value, -1)) - assert np.all(i == np.argmax(n.value, -1)) - v, i = eval_outputs(max_and_argmax(n, -2)) - assert i.dtype == "int64" - assert v.shape == (3,) - assert i.shape == (3,) - assert np.all(v == np.max(n.value, -2)) - assert np.all(i == np.argmax(n.value, -2)) - v = eval_outputs(max_and_argmax(n, -1)[0].shape) - assert v == (2) - v = eval_outputs(max_and_argmax(n, -2)[0].shape) - assert v == (3) - - @pytest.mark.parametrize( - "axis,np_axis", - [ - (-1, -1), - (0, 0), - (1, 1), - (None, None), - ([0, 1, 2], None), - ([1, 2, 0], None), - ], - ) - def test_basic_3(self, axis, np_axis): - data = random(2, 3, 4) - n = as_tensor_variable(data) - # Test shape propagates (static & eval) - vt, it = max_and_argmax(n, axis) - np_max, np_argm = np.max(data, np_axis), np.argmax(data, np_axis) - assert vt.type.shape == np_max.shape - assert it.type.shape == np_argm.shape - v_shape, i_shape = eval_outputs([vt.shape, it.shape]) - assert tuple(v_shape) == vt.type.shape - assert tuple(i_shape) == it.type.shape - # Test values - v, i = eval_outputs([vt, it]) - assert i.dtype == "int64" - assert np.all(v == np_max) - assert np.all(i == np_argm) - - def test_arg_grad(self): - # The test checks that the gradient of argmax(x).sum() is 0 - - x = matrix() - cost = argmax(x, axis=0).sum() - gx = grad(cost, x) - val = get_underlying_scalar_constant_value(gx) - assert val == 0.0 - - def test_grad(self): - data = random(2, 3) - n = as_tensor_variable(data) - - def safe_verify_grad(func, data): - # Wrapper around 'verify_grad' that picks a proper value for epsilon. - # - # This is needed because 'verify_grad' may fail when its epsilon is - # too large, due to the fact the argmax is not continuous. - # We make sure epsilon is less than the minimum absolute value found - # in the matrix of pairwise differences between all elements in the - # data. This way, the argmax will not change when adding epsilon. - - # 'data' is a one-element list. - (data_tensor,) = data - # Flatten it into a 1D vector. - data_vector = data_tensor.flatten() - # Compute pairwise absolute differences. - diff = np.abs(data_vector.reshape((-1, 1)) - data_vector) - # Alter the diagonal to avoid a zero minimum. - for i in range(len(diff)): - diff[i, i] = 1 - # Find an appropriate epsilon. - eps = builtins.min(numeric_grad.type_eps[config.floatX], diff.min() / 2) - # Run gradient verification. - utt.verify_grad(func, data, eps=eps) - - def check_grad_max(data, max_grad_data, axis=None): - # Why this is needed? verify_grad is not enough? - # This works only for axis in [0, None]. - assert axis in [0, None] - z = np.zeros_like(data) - z = z.flatten() - argmax = np.argmax(data, axis=axis) - if argmax.ndim == 0: - z[argmax] += 1 - else: - for id, v in enumerate(argmax): - z[v * np.prod(data.shape[data.ndim - 1 : axis : -1]) + id] += 1 - - z = z.reshape(data.shape) - assert np.all(max_grad_data == z) - - for axis in (-1, 0, 1, None): - for j in range(2): - safe_verify_grad(lambda v: max_and_argmax(v, axis=axis)[j], [data]) - if axis != 1: - safe_verify_grad( - lambda v: max_and_argmax(v.flatten(), axis=axis)[j], [data] - ) - if axis in (0, None): - check_grad_max( - data, - eval_outputs(grad(max_and_argmax(n, axis=axis)[0].sum(), n)), - axis=axis, - ) - check_grad_max(data, eval_outputs(grad(max_and_argmax(n.flatten())[0], n))) - - # Test 3d inner dimensions - data = random(3, 4, 5) - - for i in [0, 1, 2]: - safe_verify_grad(lambda v: max_and_argmax(v, axis=[i])[0], [data]) - safe_verify_grad(lambda v: max_and_argmax(v, axis=[i])[1], [data]) - - # Test 4d inner dimensions - data = random(2, 3, 4, 5) - - for i in [0, 1, 2, 3]: - safe_verify_grad(lambda v: max_and_argmax(v, axis=[i])[0], [data]) - safe_verify_grad(lambda v: max_and_argmax(v, axis=[i])[1], [data]) - - # Test grad with multiple axes - for i in [[0, 1], [0, 0]]: - safe_verify_grad(lambda v: max_and_argmax(v, axis=i)[0], [data]) - safe_verify_grad(lambda v: max_and_argmax(v, axis=i)[1], [data]) - - def test_preserve_broadcastable(self): - # Ensure the original broadcastable flags are preserved by Max/Argmax. - x = matrix().dimshuffle("x", 0, "x", 1, "x") - y = x.max(axis=1) - assert y.type.shape == (1, 1, None, 1) - assert y.type.broadcastable == (True, True, False, True) - - def test_multiple_axes(self): - data = np.arange(24).reshape(3, 2, 4) - x = as_tensor_variable(data) - vt, it = max_and_argmax(x, [1, -1]) - assert vt.type.shape == it.type.shape == (3,) - v, i = eval_outputs([vt, it]) - assert np.all(v == np.array([7, 15, 23])) - assert np.all(i == np.array([7, 7, 7])) - v = eval_outputs(vt.shape) - assert tuple(v) == vt.type.shape - - def test_zero_shape(self): - x = matrix() - m, i = max_and_argmax(x, axis=1) - f = function([x], [m, i]) - xv = np.zeros((0, 4), dtype=config.floatX) - mv, iv = f(xv) - assert mv.shape == (0,) - assert iv.shape == (0,) - - def test_numpy_input(self): - ar = np.array([1, 2, 3]) - max_pt, argmax_pt = max_and_argmax(ar, axis=None) - assert max_pt.eval() == 3 - assert argmax_pt.eval() == 2 - - @pytest.mark.parametrize( - "core_axis, batch_axis", - [ - (None, (1, 2, 3, 4)), - (0, (1,)), - ((1, -1), (2, 4)), - ], - ) - def test_vectorize(self, core_axis, batch_axis): - x = tensor(shape=(5, 5, 5, 5)) - batch_x = tensor(shape=(3, 5, 5, 5, 5)) - - # Test MaxAndArgmax - max_x, argmax_x = max_and_argmax(x, axis=core_axis) - node = max_x.owner - assert isinstance(node.op, TensorMax) - - # dbt: how to make Argmax user facing? - # new_node = vectorize_node(node, batch_x) - # pytensor.dprint(new_node) - # assert isinstance(new_node.op, Argmax) - # assert new_node.op.axis == batch_axis - - # Test Argmax - # Argmax is not user-facing, so we have to create it manually - node = Argmax(axis=node.op.axis).make_node(x) - - new_node = vectorize_node(node, batch_x) - # print() - # pytensor.dprint(new_node) - # print() - assert isinstance(new_node.op, Argmax) - assert new_node.op.axis == batch_axis - - -class TestArgminArgmax: - def setup_method(self): - TensorMax.debug = 0 - - def test_scalar(self): - for fct in [argmin, argmax]: - n = as_tensor_variable([5.0]) - i = eval_outputs(fct(n)) - assert i == 0 - v = eval_outputs(fct(n).shape) - assert len(v) == 0 - - def test_list(self): - n = as_tensor_variable([1, 2, 3, 2, -6]) - i = eval_outputs(argmin(n)) - assert i == 4 - v = eval_outputs(argmin(n).shape) - assert len(v) == 0 - - n = as_tensor_variable([1, 2, 3, 2, -6]) - i = eval_outputs(argmax(n)) - assert i == 2 - v = eval_outputs(argmax(n).shape) - assert len(v) == 0 - - def test2(self): - data = random(2, 3) - n = as_tensor_variable(data) - for fct, nfct in [(argmax, np.argmax), (argmin, np.argmin)]: - for axis, np_axis in [ - (-1, -1), - (0, 0), - (1, 1), - (None, None), - ([0, 1], None), - ([1, 0], None), - ]: - v = eval_outputs(fct(n, axis)) - assert np.all(v == nfct(data, np_axis)) - v_shape = eval_outputs(fct(n, axis).shape) - assert tuple(v_shape) == nfct(data, np_axis).shape - - def test2_float16(self): - # Test negative values and bigger range to make sure numpy don't do the argmax as on uint16 - data = (random(20, 30).astype("float16") - 0.5) * 20 - n = shared(data) - mode = get_default_mode().including("local_max_and_argmax", "uncanonicalize") - for fct, nfct in [(argmax, np.argmax), (argmin, np.argmin)]: - for axis, np_axis in [ - (-1, -1), - (0, 0), - (1, 1), - (None, None), - ([0, 1], None), - ([1, 0], None), - ]: - v = eval_outputs(fct(n, axis), (Argmax,), mode=mode) - assert np.all(v == nfct(data, np_axis)) - v_shape = eval_outputs(fct(n, axis).shape, mode=mode) - assert tuple(v_shape) == nfct(data, np_axis).shape - - def test2_invalid(self): - for fct, nfct in [(argmax, np.argmax), (argmin, np.argmin)]: - n = as_tensor_variable(random(2, 3)) - with pytest.raises(ValueError): - eval_outputs(fct(n, 3)) - with pytest.raises(ValueError): - eval_outputs(fct(n, -3)) - - def test2_valid_neg(self): - for fct, nfct in [(argmax, np.argmax), (argmin, np.argmin)]: - n = as_tensor_variable(random(2, 3)) - i = eval_outputs(fct(n, -1)) - assert i.shape == (2,) - assert np.all(i == nfct(n.value, -1)) - i = eval_outputs(fct(n, -2)) - assert i.shape == (3,) - assert np.all(i == nfct(n.value, -2)) - - v = eval_outputs(fct(n, -1).shape) - assert v == (2) - v = eval_outputs(fct(n, -2).shape) - assert v == (3) - - def test3(self): - data = random(2, 3, 4) - n = as_tensor_variable(data) - for fct, nfct in [(argmax, np.argmax), (argmin, np.argmin)]: - for axis, np_axis in [ - (-1, -1), - (0, 0), - (1, 1), - (2, 2), - (None, None), - ([0, 1, 2], None), - ([1, 0, 2], None), - ]: - v = eval_outputs(fct(n, axis)) - assert np.all(v == nfct(data, np_axis)) - v_shape = eval_outputs(fct(n, axis).shape) - assert tuple(v_shape) == nfct(data, np_axis).shape - - def test_grad_argmin(self): - data = random(2, 3) - n = as_tensor_variable(data) - n.name = "n" - - # test grad of argmin - utt.verify_grad(lambda v: argmin(v, axis=-1), [data]) - - utt.verify_grad(lambda v: argmin(v, axis=[0]), [data]) - - utt.verify_grad(lambda v: argmin(v, axis=[1]), [data]) - - utt.verify_grad(lambda v: argmin(v.flatten()), [data]) - - try: - cost = argmin(n, axis=-1) - cost.name = None - grad(cost, n) - raise Exception("Expected an error") - except TypeError: - pass - - def test_grad_argmax(self): - data = random(2, 3) - n = as_tensor_variable(data) - - # test grad of argmax - utt.verify_grad(lambda v: argmax(v, axis=-1), [data]) - - utt.verify_grad(lambda v: argmax(v, axis=[0]), [data]) - - utt.verify_grad(lambda v: argmax(v, axis=[1]), [data]) - - utt.verify_grad(lambda v: argmax(v.flatten()), [data]) - - try: - grad(argmax(n, axis=-1), n) - raise Exception("Expected an error") - except TypeError: - pass - - def test_uint(self): - for dtype in ("uint8", "uint16", "uint32", "uint64"): - itype = np.iinfo(dtype) - data = np.array([itype.min + 3, itype.min, itype.max - 5, itype.max], dtype) - n = as_tensor_variable(data) - i = eval_outputs(argmin(n)) - assert i == 1 - i = eval_outputs(argmax(n)) - assert i == 3 - - def test_bool(self): - data = np.array([True, False], "bool") - n = as_tensor_variable(data) - i = eval_outputs(argmin(n)) - assert i == 1 - i = eval_outputs(argmax(n)) - assert i == 0 - - -class TestMinMax: - def setup_method(self): - TensorMax.debug = 0 - - def test_scalar(self): - for fct in [max, min]: - n = as_tensor_variable(5.0) - v = eval_outputs(fct(n)) - assert v == 5.0 - - v = eval_outputs(fct(n).shape) - assert len(v) == 0 - - def test_list(self): - for fct, nfct in [(max, np.max), (min, np.min)]: - n = as_tensor_variable([1, 2, 3, 2, -6]) - v = eval_outputs([fct(n)]) - assert v == nfct(n.value) - - v = eval_outputs(fct(n).shape) - assert len(v) == 0 - - def test2(self): - data = random(2, 3) - n = as_tensor_variable(data) - for fct, nfct in [(max, np.max), (min, np.min)]: - for axis, np_axis in [ - (-1, -1), - (0, 0), - (1, 1), - (None, None), - ([0, 1], None), - ([1, 0], None), - ]: - v = eval_outputs(fct(n, axis)) - assert np.all(v == nfct(data, np_axis)) - v_shape = eval_outputs(fct(n, axis).shape) - assert tuple(v_shape) == nfct(data, np_axis).shape - - def test2_invalid(self): - for fct in [max, min]: - n = as_tensor_variable(random(2, 3)) - with pytest.raises(ValueError): - eval_outputs(fct(n, 3)) - with pytest.raises(ValueError): - eval_outputs(fct(n, -3)) - - def test2_valid_neg(self): - for fct, nfct in [(max, np.max), (min, np.min)]: - n = as_tensor_variable(random(2, 3)) - v = eval_outputs(fct(n, -1)) - assert v.shape == (2,) - assert np.all(v == nfct(n.value, -1)) - v = eval_outputs(fct(n, -2)) - assert v.shape == (3,) - assert np.all(v == nfct(n.value, -2)) - - v = eval_outputs(fct(n, -1).shape) - assert v == (2) - v = eval_outputs(fct(n, -2).shape) - assert v == (3) - - def test3(self): - # Test with 1 axis or all axis out of 3 dims - data = random(2, 3, 4) - n = as_tensor_variable(data) - for fct, nfct in [(max, np.max), (min, np.min)]: - for axis, np_axis in [ - (-1, -1), - (0, 0), - (1, 1), - (2, 2), - (None, None), - ([0, 1, 2], None), - ([1, 0, 2], None), - ]: - v = eval_outputs(fct(n, axis)) - assert np.all(v == nfct(data, np_axis)) - v_shape = eval_outputs(fct(n, axis).shape) - assert tuple(v_shape) == nfct(data, np_axis).shape - - def test3b(self): - # Test with 2 axis out of 3 dims - data = random(2, 3, 4) - n = as_tensor_variable(data) - for fct, nfct in [(max, np.max), (min, np.min)]: - for axis in [[0, 1], [1, 2], [0, 2]]: - v = eval_outputs(fct(n, axis)) - np_v = nfct(nfct(data, axis[1]), axis[0]) - assert np.all(v == np_v) - v_shape = eval_outputs(fct(n, axis).shape) - assert tuple(v_shape) == np_v.shape - - def test_grad_max(self): - data = random(2, 3) - n = as_tensor_variable(data) - - def check_grad_max(data, max_grad_data, axis=None): - # This work only for axis in [0,None] - assert axis in [0, None] - z = np.zeros_like(data) - z = z.flatten() - argmax = np.argmax(data, axis=axis) - if argmax.ndim == 0: - z[np.argmax(data, axis=axis)] += 1 - else: - for id, v in enumerate(argmax): - z[v * np.prod(data.shape[data.ndim - 1 : axis : -1]) + id] += 1 - - z = z.reshape(data.shape) - assert np.all(max_grad_data == z) - - # test grad of max - # axis is the last one - utt.verify_grad(lambda v: max(v, axis=-1), [data]) - - utt.verify_grad(lambda v: max(v, axis=[0]), [data]) - check_grad_max(data, eval_outputs(grad(max(n, axis=0).sum(), n)), axis=0) - - utt.verify_grad(lambda v: max(v, axis=[1]), [data]) - # check_grad_max(data,eval_outputs(grad(max(n,axis=1),n)),axis=1) - - utt.verify_grad(lambda v: max(v.flatten()), [data]) - check_grad_max(data, eval_outputs(grad(max(n.flatten()), n))) - - def test_grad_min(self): - data = random(2, 3) - n = as_tensor_variable(data) - - def check_grad_min(data, min_grad_data, axis=None): - # This work only for axis in [0, None] - assert axis in [0, None] - z = np.zeros_like(data) - z = z.flatten() - argmin = np.argmin(data, axis=axis) - if argmin.ndim == 0: - z[np.argmin(data, axis=axis)] += 1 - else: - for id, v in enumerate(argmin): - z[v * np.prod(data.shape[data.ndim - 1 : axis : -1]) + id] += 1 - - z = z.reshape(data.shape) - assert np.all(min_grad_data == z) - - # test grad of min - # axis is the last one - utt.verify_grad(lambda v: min(v, axis=-1), [data]) - - utt.verify_grad(lambda v: min(v, axis=[0]), [data]) - check_grad_min(data, eval_outputs(grad(min(n, axis=0).sum(), n)), axis=0) - - utt.verify_grad(lambda v: min(v, axis=[1]), [data]) - # check_grad_min(data,eval_outputs(grad(min(n,axis=1),n)),axis=1) - - utt.verify_grad(lambda v: min(v.flatten()), [data]) - check_grad_min(data, eval_outputs(grad(min(n.flatten()), n))) - - def _grad_list(self): - # Test the gradient when we have multiple axis at the same time. - # - # This not implemented, so we disable the test. See ticket: - # http://www.assembla.com/spaces/pytensor/tickets/511 - data = random(2, 3) - for fct in [max_and_argmax, max, min]: - utt.verify_grad(lambda v: fct(v, axis=[0, 1]), [data]) - # n = as_tensor_variable(data) - # check_grad_max(data, eval_outputs(grad(max_and_argmax(n, - # axis=1)[0], n)),axis=1) - - def test_uint(self): - for dtype in ("uint8", "uint16", "uint32", "uint64"): - itype = np.iinfo(dtype) - data = np.array([itype.min + 3, itype.min, itype.max - 5, itype.max], dtype) - n = as_tensor_variable(data) - assert min(n).dtype == dtype - i = eval_outputs(min(n)) - assert i == itype.min - assert max(n).dtype == dtype - i = eval_outputs(max(n)) - assert i == itype.max - - def test_bool(self): - data = np.array([True, False], "bool") - n = as_tensor_variable(data) - assert min(n).dtype == "bool" - i = eval_outputs(min(n)) - assert i.ndim == 0 - assert not np.any(i) - assert max(n).dtype == "bool" - i = eval_outputs(max(n)) - assert i.ndim == 0 - assert np.all(i) diff --git a/tests/tensor/utils.py b/tests/tensor/utils.py index 51c3a28ce3..40c71dfd61 100644 --- a/tests/tensor/utils.py +++ b/tests/tensor/utils.py @@ -107,6 +107,7 @@ def inplace_func( def eval_outputs(outputs, ops=(), mode=None): f = inplace_func([], outputs, mode=mode) variables = f() + if ops: assert any(isinstance(node.op, ops) for node in f.maker.fgraph.apply_nodes) if isinstance(variables, tuple | list) and len(variables) == 1: From 25af747e3f2e5fdcd3bff26779e24da519813a1c Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi Date: Wed, 22 May 2024 21:53:38 +0530 Subject: [PATCH 3/4] XFAIL pytensor tests for uint64 data type --- pytensor/compile/function/__init__.py | 1 - pytensor/compile/function/types.py | 1 - pytensor/graph/op.py | 1 - pytensor/link/jax/dispatch/nlinalg.py | 49 +-------------------- pytensor/tensor/math.py | 32 +++++++------- pytensor/tensor/rewriting/uncanonicalize.py | 1 - tests/link/numba/test_basic.py | 2 - tests/tensor/test_math.py | 17 ++++--- tests/tensor/utils.py | 1 - 9 files changed, 27 insertions(+), 78 deletions(-) diff --git a/pytensor/compile/function/__init__.py b/pytensor/compile/function/__init__.py index f020e84376..b7ebba01e8 100644 --- a/pytensor/compile/function/__init__.py +++ b/pytensor/compile/function/__init__.py @@ -312,7 +312,6 @@ def opt_log1p(node): else: # note: pfunc will also call orig_function -- orig_function is # a choke point that all compilation must pass through - fn = pfunc( params=inputs, outputs=outputs, diff --git a/pytensor/compile/function/types.py b/pytensor/compile/function/types.py index 959a0cb14c..c221d7cf41 100644 --- a/pytensor/compile/function/types.py +++ b/pytensor/compile/function/types.py @@ -1758,7 +1758,6 @@ def orig_function( name=name, fgraph=fgraph, ) - print(m) with config.change_flags(compute_test_value="off"): fn = m.create(defaults) finally: diff --git a/pytensor/graph/op.py b/pytensor/graph/op.py index a70c2b9730..160a65dd7a 100644 --- a/pytensor/graph/op.py +++ b/pytensor/graph/op.py @@ -291,7 +291,6 @@ def __call__( """ node = self.make_node(*inputs, **kwargs) - if name is not None: if len(node.outputs) == 1: node.outputs[0].name = name diff --git a/pytensor/link/jax/dispatch/nlinalg.py b/pytensor/link/jax/dispatch/nlinalg.py index e76781afa2..81ff82ada2 100644 --- a/pytensor/link/jax/dispatch/nlinalg.py +++ b/pytensor/link/jax/dispatch/nlinalg.py @@ -104,56 +104,11 @@ def batched_dot(a, b): return batched_dot -# @jax_funcify.register(Max) -# @jax_funcify.register(Argmax) -# def jax_funcify_MaxAndArgmax(op, **kwargs): -# axis = op.axis - -# def maxandargmax(x, axis=axis): -# if axis is None: -# axes = tuple(range(x.ndim)) -# else: -# axes = tuple(int(ax) for ax in axis) - -# max_res = jnp.max(x, axis) - -# # NumPy does not support multiple axes for argmax; this is a -# # work-around -# keep_axes = jnp.array( -# [i for i in range(x.ndim) if i not in axes], dtype="int64" -# ) -# # Not-reduced axes in front -# transposed_x = jnp.transpose( -# x, jnp.concatenate((keep_axes, jnp.array(axes, dtype="int64"))) -# ) -# kept_shape = transposed_x.shape[: len(keep_axes)] -# reduced_shape = transposed_x.shape[len(keep_axes) :] - -# # Numpy.prod returns 1.0 when arg is empty, so we cast it to int64 -# # Otherwise reshape would complain citing float arg -# new_shape = ( -# *kept_shape, -# jnp.prod(jnp.array(reduced_shape, dtype="int64"), dtype="int64"), -# ) -# reshaped_x = transposed_x.reshape(new_shape) - -# max_idx_res = jnp.argmax(reshaped_x, axis=-1).astype("int64") - -# return max_res, max_idx_res - -# return maxandargmax - - @jax_funcify.register(Max) def jax_funcify_Max(op, **kwargs): axis = op.axis - def max(x, axis=axis): - # if axis is None: - # axes = tuple(range(x.ndim)) - # else: - # axes = tuple(int(ax) for ax in axis) - + def max(x): max_res = jnp.max(x, axis) return max_res @@ -165,7 +120,7 @@ def max(x, axis=axis): def jax_funcify_Argmax(op, **kwargs): axis = op.axis - def argmax(x, axis=axis): + def argmax(x): if axis is None: axes = tuple(range(x.ndim)) else: diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py index da872bc791..264d643762 100644 --- a/pytensor/tensor/math.py +++ b/pytensor/tensor/math.py @@ -107,6 +107,14 @@ float64_atol = 1e-8 +def __getattr__(name): + if name == "MaxAndArgmax": + raise AttributeError( + "The class `MaxAndArgmax` has been deprecated. " + "Call `Max` and `Argmax` separately as an alternative." + ) + + def _get_atol_rtol(a, b): tiny = ("float16",) narrow = ("float32", "complex64") @@ -134,15 +142,6 @@ def _allclose(a, b, rtol=None, atol=None): return np.allclose(a, b, atol=atol_, rtol=rtol_) -def __getattr__(name): - if name == "MaxandArgmax": - warnings.warn( - "The class `MaxandArgmax` has been deprecated. " - "Call `Max` and `Argmax` seperately as an alternative.", - FutureWarning, - ) - - class Argmax(COp): """ Calculate the argmax over a given axis or over all axes. @@ -193,10 +192,8 @@ def perform(self, node, inp, outs): (x,) = inp axes = self.axis (max_idx,) = outs - if axes is None: axes = tuple(range(x.ndim)) - # Numpy does not support multiple axes for argmax # Work around keep_axes = np.array([i for i in range(x.ndim) if i not in axes], dtype="int64") @@ -398,21 +395,21 @@ def max_and_argmax(a, axis=None, keepdims=False): """ # Check axis and convert it to a Python list of integers. - # Axis will be used as an op param of MaxAndArgmax. + # Axis will be used as an op param of Max and Argmax. a = as_tensor_variable(a) - flag = False + is_axis_empty = False if axis == (): - flag = True + is_axis_empty = True axis = check_and_normalize_axes(a, axis) - if len(axis) == 0 and not flag: + if len(axis) == 0 and not is_axis_empty: axis = None out = Max(axis)(a) - if not flag: + if not is_axis_empty: argout = Argmax(axis)(a) else: argout = zeros_like(a, dtype="int64") @@ -495,7 +492,8 @@ def grad(self, inp, grads): if g_max_disconnected: return [DisconnectedType()()] - if NoneConst.equals(axis): + # if NoneConst.equals(axis): + if axis is None: axis_ = list(range(x.ndim)) else: axis_ = axis diff --git a/pytensor/tensor/rewriting/uncanonicalize.py b/pytensor/tensor/rewriting/uncanonicalize.py index 194734f567..5f6cdc05aa 100644 --- a/pytensor/tensor/rewriting/uncanonicalize.py +++ b/pytensor/tensor/rewriting/uncanonicalize.py @@ -57,7 +57,6 @@ def local_max_to_min(fgraph, node): """ if node.op == neg and node.inputs[0].owner: max = node.inputs[0] - # print(max.owner.op.scalar_op) if ( max.owner and isinstance(max.owner.op, CAReduce) diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py index 8c68600aed..5830983518 100644 --- a/tests/link/numba/test_basic.py +++ b/tests/link/numba/test_basic.py @@ -256,8 +256,6 @@ def compare_numba_and_py( if assert_fn is None: def assert_fn(x, y): - print(x) - print(y) return np.testing.assert_allclose(x, y, rtol=1e-4) and compare_shape_dtype( x, y ) diff --git a/tests/tensor/test_math.py b/tests/tensor/test_math.py index f83944eb1c..720aba697c 100644 --- a/tests/tensor/test_math.py +++ b/tests/tensor/test_math.py @@ -808,7 +808,11 @@ def test_basic_2(self, axis, np_axis): v_shape, i_shape = eval_outputs([vt.shape, it.shape]) assert tuple(v_shape) == vt.type.shape assert tuple(i_shape) == it.type.shape - # Test valuesgi + # Test values + v, i = eval_outputs([vt, it]) + assert i.dtype == "int64" + assert np.all(v == np_max) + assert np.all(i == np_argm) @pytest.mark.parametrize( "axis,np_axis", @@ -1029,9 +1033,7 @@ def test_vectorize(self, core_axis, batch_axis): batch_x = tensor(shape=(3, 5, 5, 5, 5)) # Test MaxAndArgmax - max_x, argmax_x = max_and_argmax(x, axis=core_axis) - max_node = max_x.owner - assert isinstance(max_node.op, Max) + argmax_x = argmax(x, axis=core_axis) arg_max_node = argmax_x.owner new_node = vectorize_node(arg_max_node, batch_x) @@ -1394,6 +1396,7 @@ def _grad_list(self): # check_grad_max(data, eval_outputs(grad(max_and_argmax(n, # axis=1)[0], n)),axis=1) + @pytest.mark.xfail(reason="Fails due to #770") def test_uint(self): for dtype in ("uint8", "uint16", "uint32", "uint64"): itype = np.iinfo(dtype) @@ -1419,9 +1422,9 @@ def test_bool(self): assert np.all(i) -def test_MaxandArgmax_deprecated(): - with pytest.warns(FutureWarning, match=".*deprecated.*"): - pytensor.tensor.math.MaxandArgmax +def test_MaxAndArgmax_deprecated(): + with pytest.raises(AttributeError): + pytensor.tensor.math.MaxAndArgmax rng = np.random.default_rng(seed=utt.fetch_seed()) diff --git a/tests/tensor/utils.py b/tests/tensor/utils.py index 40c71dfd61..51c3a28ce3 100644 --- a/tests/tensor/utils.py +++ b/tests/tensor/utils.py @@ -107,7 +107,6 @@ def inplace_func( def eval_outputs(outputs, ops=(), mode=None): f = inplace_func([], outputs, mode=mode) variables = f() - if ops: assert any(isinstance(node.op, ops) for node in f.maker.fgraph.apply_nodes) if isinstance(variables, tuple | list) and len(variables) == 1: From d953a0dab3d21b5231519868455ef1d7c94e60a8 Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi Date: Thu, 13 Jun 2024 14:52:32 +0530 Subject: [PATCH 4/4] Deprecate and raise AttributeError for MaxAndArgmax --- pytensor/tensor/math.py | 14 ++------------ pytensor/tensor/rewriting/uncanonicalize.py | 2 +- tests/link/jax/test_nlinalg.py | 5 +++-- tests/tensor/test_math.py | 6 ++++-- 4 files changed, 10 insertions(+), 17 deletions(-) diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py index 264d643762..181e813f50 100644 --- a/pytensor/tensor/math.py +++ b/pytensor/tensor/math.py @@ -110,9 +110,9 @@ def __getattr__(name): if name == "MaxAndArgmax": raise AttributeError( - "The class `MaxAndArgmax` has been deprecated. " - "Call `Max` and `Argmax` separately as an alternative." + "The class `MaxandArgmax` has been deprecated. Call `Max` and `Argmax` seperately as an alternative." ) + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") def _get_atol_rtol(a, b): @@ -565,16 +565,6 @@ def max(x, axis=None, keepdims=False): We return an error as numpy when we reduce a dim with a shape of 0. """ - - # We have a choice of implementing this call with the - # CAReduce op or the MaxAndArgmax op. - - # MaxAndArgmax supports grad and Rop, so we prefer to use that. - # CAReduce is faster, but optimizations will replace MaxAndArgmax[0] - # with CAReduce at compile time, so at this stage the important - # thing is supporting all user interface features, not speed. - # Some cases can be implemented only with CAReduce. - out = max_and_argmax(x, axis)[0] if keepdims: diff --git a/pytensor/tensor/rewriting/uncanonicalize.py b/pytensor/tensor/rewriting/uncanonicalize.py index 5f6cdc05aa..a44870ded2 100644 --- a/pytensor/tensor/rewriting/uncanonicalize.py +++ b/pytensor/tensor/rewriting/uncanonicalize.py @@ -52,7 +52,7 @@ def local_max_to_min(fgraph, node): Notes ----- We don't need an opt that will do the reverse as by default - the interface put only MaxAndArgmax into the graph. + the interface put only Max into the graph. """ if node.op == neg and node.inputs[0].owner: diff --git a/tests/link/jax/test_nlinalg.py b/tests/link/jax/test_nlinalg.py index 3a64fda364..2175670ee6 100644 --- a/tests/link/jax/test_nlinalg.py +++ b/tests/link/jax/test_nlinalg.py @@ -11,7 +11,7 @@ from pytensor.link.jax import JAXLinker from pytensor.tensor import blas as pt_blas from pytensor.tensor import nlinalg as pt_nlinalg -from pytensor.tensor.math import MaxAndArgmax, maximum +from pytensor.tensor.math import Argmax, Max, maximum from pytensor.tensor.math import max as pt_max from pytensor.tensor.type import dvector, matrix, scalar, tensor3, vector from tests.link.jax.test_basic import compare_jax_and_py @@ -88,7 +88,8 @@ def test_jax_basic_multiout_omni(): # Test that a single output of a multi-output `Op` can be used as input to # another `Op` x = dvector() - mx, amx = MaxAndArgmax([0])(x) + mx = Max([0])(x) + amx = Argmax([0])(x) out = mx * amx out_fg = FunctionGraph([x], [out]) compare_jax_and_py(out_fg, [np.r_[1, 2]]) diff --git a/tests/tensor/test_math.py b/tests/tensor/test_math.py index 720aba697c..6b6f8def13 100644 --- a/tests/tensor/test_math.py +++ b/tests/tensor/test_math.py @@ -1032,7 +1032,6 @@ def test_vectorize(self, core_axis, batch_axis): x = tensor(shape=(5, 5, 5, 5)) batch_x = tensor(shape=(3, 5, 5, 5, 5)) - # Test MaxAndArgmax argmax_x = argmax(x, axis=core_axis) arg_max_node = argmax_x.owner @@ -1423,7 +1422,10 @@ def test_bool(self): def test_MaxAndArgmax_deprecated(): - with pytest.raises(AttributeError): + with pytest.raises( + AttributeError, + match="The class `MaxandArgmax` has been deprecated. Call `Max` and `Argmax` seperately as an alternative.", + ): pytensor.tensor.math.MaxAndArgmax