From 6bde4dfebfb244338da531b8f6ea4934d9cf47ed Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Mon, 10 Feb 2025 23:49:31 +0100
Subject: [PATCH 1/2] Remove unnecessary type ignore in new version of mypy

---
 pytensor/link/vm.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytensor/link/vm.py b/pytensor/link/vm.py
index af44af3254..c6e1283806 100644
--- a/pytensor/link/vm.py
+++ b/pytensor/link/vm.py
@@ -118,7 +118,7 @@ def calculate_reallocate_info(
         # where gc
         for i in range(idx + 1, len(order)):
             if reuse_out is not None:
-                break  # type: ignore
+                break
             for out in order[i].outputs:
                 if (
                     getattr(out.type, "ndim", None) == 0

From 2543387c0817519c5e6292af3456668e6d941949 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Wed, 5 Feb 2025 10:24:47 +0100
Subject: [PATCH 2/2] Implement gradient for vector repetitions

Also cleans up the implementation and documentation
---
 pytensor/tensor/extra_ops.py   | 176 ++++++++++++++++++++-------------
 tests/tensor/test_extra_ops.py |  28 ++++--
 2 files changed, 131 insertions(+), 73 deletions(-)

diff --git a/pytensor/tensor/extra_ops.py b/pytensor/tensor/extra_ops.py
index fedcd32ab9..27eabc5ba4 100644
--- a/pytensor/tensor/extra_ops.py
+++ b/pytensor/tensor/extra_ops.py
@@ -646,12 +646,17 @@ class Repeat(Op):

     __props__ = ("axis",)

-    def __init__(self, axis=None):
+    def __init__(self, axis: int | None = None):
+        if axis is not None:
+            if not isinstance(axis, int) or axis < 0:
+                raise ValueError(
+                    f"Repeat only accepts non-negative integer axis or None, got {axis}"
+                )
         self.axis = axis

     def make_node(self, x, repeats):
         x = ptb.as_tensor_variable(x)
-        repeats = ptb.as_tensor_variable(repeats)
+        repeats = ptb.as_tensor_variable(repeats, dtype="int64")

         if repeats.dtype not in integer_dtypes:
             raise TypeError("repeats.dtype must be an integer.")
@@ -687,17 +692,12 @@ def make_node(self, x, repeats):
             out_shape = list(x.type.shape)
             out_shape[self.axis] = None

-        out_type = TensorType(
-            x.dtype, shape=tuple(1 if s == 1 else None for s in out_shape)
-        )
-
+        out_type = TensorType(x.dtype, shape=out_shape)
         return Apply(self, [x, repeats], [out_type()])

     def perform(self, node, inputs, output_storage):
-        x = inputs[0]
-        repeats = inputs[1]
-        z = output_storage[0]
-        z[0] = np.repeat(x, repeats=repeats, axis=self.axis)
+        [x, repeats] = inputs
+        output_storage[0][0] = np.repeat(x, repeats=repeats, axis=self.axis)

     def connection_pattern(self, node):
         return [[True], [False]]
@@ -705,40 +705,51 @@ def connection_pattern(self, node):
     def grad(self, inputs, gout):
         (x, repeats) = inputs
         (gz,) = gout
+        axis = self.axis
         if repeats.ndim == 0:
-            if self.axis is None:
-                axis = x.ndim
-            else:
-                if self.axis >= 0:
-                    axis = self.axis + 1
-                else:
-                    axis = self.axis + x.ndim + 1
-
-            shape = [x.shape[k] for k in range(x.ndim)]
-            shape.insert(axis, repeats)
+            # When repeats is a scalar (same number of reps for all elements),
+            # we can split the repetitions into their own axis with a reshape
+            # and sum them back into the original element locations
+            sum_axis = x.ndim if axis is None else axis + 1
+            shape = list(x.shape)
+            shape.insert(sum_axis, repeats)
+            gx = gz.reshape(shape).sum(axis=sum_axis)
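+            # Worked example of the trick above: for x = [a, b] with
+            # repeats = 2 and axis=None, the forward output is [a, a, b, b];
+            # reshaping gz from (4,) to (2, 2) and summing axis 1 gives
+            # gx = [gz[0] + gz[1], gz[2] + gz[3]], one sum per original element.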

-            return [
-                gz.reshape(shape, ndim=x.ndim + 1).sum(axis=axis),
-                DisconnectedType()(),
-            ]
         elif repeats.ndim == 1:
-            # For this implementation, we would need to specify the length
-            # of repeats in order to split gz in the right way to sum
-            # the good part.
-            raise NotImplementedError()
+            # To sum the gradients that belong to the same repeated x,
+            # we create a repeated eye (identity matrix) and matrix-multiply
+            # it with the gradient.
+            axis_size = x.size if axis is None else x.shape[axis]
+            repeated_eye = repeat(
+                ptb.eye(axis_size), repeats, axis=0
+            )  # A sparse repeat would be neat
+
+            if axis is None:
+                gx = gz @ repeated_eye
+                # Undo the ravelling when axis=None
+                gx = gx.reshape(x.shape)
+            else:
+                # Place gradient axis at end for dot product
+                gx = ptb.moveaxis(gz, axis, -1)
+                gx = gx @ repeated_eye
+                # Place gradient back into the correct axis
+                gx = ptb.moveaxis(gx, -1, axis)
+
         else:
             raise ValueError()

+        return [gx, DisconnectedType()()]
+
     def infer_shape(self, fgraph, node, ins_shapes):
         i0_shapes = ins_shapes[0]
         repeats = node.inputs[1]
         out_shape = list(i0_shapes)
+        axis = self.axis

         # uint64 shapes are not supported.
         dtype = None
         if repeats.dtype in ("uint8", "uint16", "uint32"):
             dtype = "int64"
-        if self.axis is None:
+        if axis is None:
             if repeats.ndim == 0:
                 if len(i0_shapes) == 0:
                     out_shape = [repeats]
@@ -751,82 +762,115 @@ def infer_shape(self, fgraph, node, ins_shapes):
                 out_shape = [pt_sum(repeats, dtype=dtype)]
         else:
             if repeats.ndim == 0:
-                out_shape[self.axis] = out_shape[self.axis] * repeats
+                out_shape[axis] = out_shape[axis] * repeats
             else:
-                out_shape[self.axis] = pt_sum(repeats, dtype=dtype)
+                out_shape[axis] = pt_sum(repeats, dtype=dtype)
         return [out_shape]


-def repeat(x, repeats, axis=None):
-    """Repeat elements of an array.
+def repeat(
+    a: TensorLike, repeats: TensorLike, axis: int | None = None
+) -> TensorVariable:
+    """Repeat elements of a tensor.

-    It returns an array which has the same shape as `x`, except along the given
-    `axis`. The `axis` parameter is used to specify the axis along which values
-    are repeated. By default, a flattened version of `x` is used.
+    See :func:`numpy.repeat` for more information.

-    The number of repetitions for each element is `repeats`. `repeats` is
-    broadcasted to fit the length of the given `axis`.

     Parameters
     ----------
-    x
-        Input data, tensor variable.
-    repeats
-        int, scalar or tensor variable
+    a : tensor_like
+        Input tensor.
+    repeats : tensor_like
+        The number of repetitions for each element. `repeats` is broadcast
+        to fit the shape of the given axis.
     axis : int, optional
+        The axis along which to repeat values. By default, use the flattened
+        input array, and return a flat output array.

-    See Also
+    Returns
+    -------
+    repeated_tensor : TensorVariable
+        Output tensor which has the same shape as `a`, except along the
+        given axis.
+
+    Examples
     --------
-    tensor.tile
+
+    .. testcode::
+
+        import pytensor.tensor as pt
+
+        a = pt.arange(4).reshape((2, 2))
+        out = pt.repeat(a, repeats=[2, 3], axis=0)
+        print(out.eval())
+
+    .. testoutput::
+
+        [[0 1]
+         [0 1]
+         [2 3]
+         [2 3]
+         [2 3]]
+
+    When `axis` is None, the array is first flattened and then repeated
+
+    .. testcode::
+
+        import pytensor.tensor as pt
+
+        a = pt.arange(4).reshape((2, 2))
+        out = pt.repeat(a, repeats=[2, 3, 0, 1], axis=None)
+        print(out.eval())
+
+    .. testoutput::
+
+        [0 0 1 1 1 3]
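+
+    A scalar `repeats` repeats every element the same number of times
+
+    .. testcode::
+
+        import pytensor.tensor as pt
+
+        a = pt.arange(4).reshape((2, 2))
+        out = pt.repeat(a, repeats=2, axis=1)
+        print(out.eval())
+
+    .. testoutput::
+
+        [[0 0 1 1]
+         [2 2 3 3]]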
+
+    .. versionadded:: 0.6

     """
+    a = ptb.as_tensor_variable(a)
+
+    if axis is not None:
+        axis = normalize_axis_index(axis, a.ndim)
+
     repeats = ptb.as_tensor_variable(repeats, dtype=np.int64)

     if repeats.ndim > 1:
         raise ValueError("The dimension of repeats should not exceed 1.")

     if repeats.ndim == 1 and not repeats.broadcastable[0]:
-        return Repeat(axis=axis)(x, repeats)
+        # We only use the Repeat Op for vector repeats
+        return Repeat(axis=axis)(a, repeats)
     else:
         if repeats.ndim == 1:
             repeats = repeats[0]

-        if x.dtype == "uint64":
+        if a.dtype == "uint64":
+            # Multiplying int64 (shape) by uint64 (repeats) yields a float64,
+            # which is not valid for the `reshape` operation at the end
             raise TypeError("repeat doesn't support dtype uint64")

         if axis is None:
             axis = 0
-            x = x.flatten()
-        else:
-            if axis >= x.ndim:
-                raise ValueError("Axis should not exceed x.ndim-1.")
-            if axis < 0:
-                axis = x.ndim + axis
+            a = a.flatten()

-        shape = [x.shape[i] for i in range(x.ndim)]
+        repeat_shape = list(a.shape)

-        # shape_ is the shape of the intermediate tensor which has
+        # alloc_shape is the shape of the intermediate tensor which has
         # an additional dimension compared to x. We use alloc to
         # allocate space for this intermediate tensor to replicate x
         # along that additional dimension.
-        shape_ = shape[:]
-        shape_.insert(axis + 1, repeats)
+        alloc_shape = repeat_shape[:]
+        alloc_shape.insert(axis + 1, repeats)

-        # shape is now the shape of output, where shape[axis] becomes
+        # repeat_shape is now the shape of output, where shape[axis] becomes
         # shape[axis]*repeats.
-        shape[axis] = shape[axis] * repeats
-
-        # dims_ is the dimension of that intermediate tensor.
-        dims_ = list(np.arange(x.ndim))
-        dims_.insert(axis + 1, "x")
+        repeat_shape[axis] = repeat_shape[axis] * repeats
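+        # Worked sketch of the broadcast trick: for a = [x, y], repeats = 3,
+        # axis = 0, expand_dims gives shape (2, 1); alloc broadcasts it to
+        # alloc_shape (2, 3) -> [[x, x, x], [y, y, y]]; the final reshape to
+        # repeat_shape (6,) gives [x, x, x, y, y, y], matching np.repeat.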

         # After the original tensor is duplicated along the additional
-        # dimension, we reshape it to the expected output shape, and
-        # return the output z.
-        z = ptb.alloc(x.dimshuffle(*dims_), *shape_).reshape(shape)
-        return z
+        # dimension, we reshape it to the expected output shape
+        return ptb.alloc(ptb.expand_dims(a, axis + 1), *alloc_shape).reshape(
+            repeat_shape
+        )


 class Bartlett(Op):
diff --git a/tests/tensor/test_extra_ops.py b/tests/tensor/test_extra_ops.py
index c45e6b1e48..e4f4945393 100644
--- a/tests/tensor/test_extra_ops.py
+++ b/tests/tensor/test_extra_ops.py
@@ -595,7 +595,6 @@ def test_basic(self, ndim, dtype):
                 isinstance(n.op, Repeat) for n in f.maker.fgraph.toposort()
             )

-    @pytest.mark.slow
     @pytest.mark.parametrize("ndim", [1, 3])
     @pytest.mark.parametrize("dtype", ["int8", "uint8", "uint64"])
     def test_infer_shape(self, ndim, dtype):
@@ -606,6 +605,10 @@ def test_infer_shape(self, ndim, dtype):
         a = rng.random(shp).astype(config.floatX)

         for axis in self._possible_axis(ndim):
+            if axis is not None and axis < 0:
+                # Operator does not support negative axis
+                continue
+
             r_var = scalar(dtype=dtype)
             r = np.asarray(3, dtype=dtype)
             if dtype in self.numpy_unsupported_dtypes:
@@ -635,12 +638,23 @@ def test_infer_shape(self, ndim, dtype):
                     self.op_class,
                 )

-    @pytest.mark.parametrize("ndim", range(3))
-    def test_grad(self, ndim):
-        a = np.random.random((10,) * ndim).astype(config.floatX)
-
-        for axis in self._possible_axis(ndim):
-            utt.verify_grad(lambda x: Repeat(axis=axis)(x, 3), [a])
+    @pytest.mark.parametrize("x_ndim", [2, 3], ids=lambda x: f"x_ndim={x}")
+    @pytest.mark.parametrize("repeats_ndim", [0, 1], ids=lambda r: f"repeats_ndim={r}")
+    @pytest.mark.parametrize("axis", [None, 0, 1], ids=lambda a: f"axis={a}")
+    def test_grad(self, x_ndim, repeats_ndim, axis):
+        rng = np.random.default_rng(
+            [653, x_ndim, 2 if axis is None else axis, repeats_ndim]
+        )
+        x_test = rng.normal(size=np.arange(3, 3 + x_ndim))
+        if repeats_ndim == 0:
+            repeats_size = ()
+        else:
+            repeats_size = (x_test.shape[axis] if axis is not None else x_test.size,)
+        repeats = rng.integers(1, 6, size=repeats_size)
+        utt.verify_grad(
+            lambda x: Repeat(axis=axis)(x, repeats),
+            [x_test],
+        )

     def test_broadcastable(self):
         x = TensorType(config.floatX, shape=(None, 1, None))()
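
A quick end-to-end check of the new vector-repeats gradient (a sketch, not part of
the patch; the variable names are illustrative). Each `x[i]` should receive its own
repetition count as the gradient of `repeat(x, repeats).sum()`:

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    x = pt.vector("x")
    repeats = np.array([2, 3, 1])
    # Forward output is [x0, x0, x1, x1, x1, x2], so d(sum)/dx_i = repeats[i]
    g = pytensor.grad(pt.repeat(x, repeats, axis=0).sum(), x)
    print(pytensor.function([x], g)(np.zeros(3, dtype=x.dtype)))  # [2. 3. 1.]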