From d042566d799a48acace07b33cee5cc37cf5bdbf9 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Tue, 3 Jun 2025 18:17:10 +0200 Subject: [PATCH 1/3] .tweak test name --- tests/xtensor/test_indexing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/xtensor/test_indexing.py b/tests/xtensor/test_indexing.py index e00adb4d86..721fd1e695 100644 --- a/tests/xtensor/test_indexing.py +++ b/tests/xtensor/test_indexing.py @@ -45,7 +45,7 @@ def test_basic_indexing(labeled, indices): xr_assert_allclose(res, expected_res) -def test_single_adv_indexing_on_existing_dim(): +def test_single_vector_indexing_on_existing_dim(): x = xtensor(dims=("a", "b"), shape=(3, 5)) idx = tensor("idx", dtype=int, shape=(4,)) xidx = xtensor("idx", dtype=int, shape=(4,), dims=("a",)) From fcbe46c5625f4a0db91cf8a31be3f260b7ee4c10 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Mon, 2 Jun 2025 11:22:48 +0200 Subject: [PATCH 2/3] Don't lose static shape in AdvancedIncSubtensor --- pytensor/tensor/subtensor.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py index 278d1e8da6..99ae67af9b 100644 --- a/pytensor/tensor/subtensor.py +++ b/pytensor/tensor/subtensor.py @@ -3021,12 +3021,7 @@ def make_node(self, x, y, *inputs): return Apply( self, (x, y, *new_inputs), - [ - tensor( - dtype=x.type.dtype, - shape=tuple(1 if s == 1 else None for s in x.type.shape), - ) - ], + [x.type()], ) def perform(self, node, inputs, out_): From 2ad906a2696d3480d51ea116f8bc4fc82f3cb156 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Mon, 2 Jun 2025 11:27:11 +0200 Subject: [PATCH 3/3] Implement index update for XTensorVariables --- pytensor/xtensor/indexing.py | 33 ++++++ pytensor/xtensor/rewriting/indexing.py | 82 ++++++++++++-- pytensor/xtensor/type.py | 28 ++++- tests/xtensor/test_indexing.py | 142 ++++++++++++++++++++++++- 4 files changed, 272 insertions(+), 13 deletions(-) diff --git a/pytensor/xtensor/indexing.py b/pytensor/xtensor/indexing.py index 91e74017c9..01517db55d 100644 --- a/pytensor/xtensor/indexing.py +++ b/pytensor/xtensor/indexing.py @@ -4,6 +4,7 @@ # https://numpy.org/neps/nep-0021-advanced-indexing.html # https://docs.xarray.dev/en/latest/user-guide/indexing.html # https://tutorial.xarray.dev/intermediate/indexing/advanced-indexing.html +from typing import Literal from pytensor.graph.basic import Apply, Constant, Variable from pytensor.scalar.basic import discrete_dtypes @@ -184,3 +185,35 @@ def combine_dim_info(idx_dim, idx_dim_shape): index = Index() + + +class IndexUpdate(XOp): + __props__ = ("mode",) + + def __init__(self, mode: Literal["set", "inc"]): + if mode not in ("set", "inc"): + raise ValueError("mode must be 'set' or 'inc'") + self.mode = mode + + def make_node(self, x, y, *idxs): + # Call Index on (x, *idxs) to process inputs and infer output type + x_view_node = index.make_node(x, *idxs) + x, *idxs = x_view_node.inputs + [x_view] = x_view_node.outputs + + try: + y = as_xtensor(y) + except TypeError: + y = as_xtensor(as_tensor(y), dims=x_view.type.dims) + + if not set(y.type.dims).issubset(x_view.type.dims): + raise ValueError( + f"Value dimensions {y.type.dims} must be a subset of the indexed dimensions {x_view.type.dims}" + ) + + out = x.type() + return Apply(self, [x, y, *idxs], [out]) + + +index_assignment = IndexUpdate("set") +index_increment = IndexUpdate("inc") diff --git a/pytensor/xtensor/rewriting/indexing.py b/pytensor/xtensor/rewriting/indexing.py index 70f232ffb1..6b0b650848 100644 --- a/pytensor/xtensor/rewriting/indexing.py +++ b/pytensor/xtensor/rewriting/indexing.py @@ -3,10 +3,10 @@ from pytensor import as_symbolic from pytensor.graph import Constant, node_rewriter from pytensor.tensor import TensorType, arange, specify_shape -from pytensor.tensor.subtensor import _non_consecutive_adv_indexing +from pytensor.tensor.subtensor import _non_consecutive_adv_indexing, inc_subtensor from pytensor.tensor.type_other import NoneTypeT, SliceType from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor -from pytensor.xtensor.indexing import Index +from pytensor.xtensor.indexing import Index, IndexUpdate, index from pytensor.xtensor.rewriting.utils import register_xcanonicalize from pytensor.xtensor.type import XTensorType @@ -35,9 +35,7 @@ def to_basic_idx(idx): raise TypeError("Cannot convert idx to basic idx") -@register_xcanonicalize -@node_rewriter(tracks=[Index]) -def lower_index(fgraph, node): +def _lower_index(node): """Lower XTensorVariable indexing to regular TensorVariable indexing. xarray-like indexing has two modes: @@ -59,12 +57,18 @@ def lower_index(fgraph, node): We do this by creating an `arange` tensor that matches the shape of the dimension being indexed, and then indexing it with the original slice. This index is then handled as a regular advanced index. - Note: The IndexOp has only 2 types of indices: Slices and XTensorVariables. Regular array indices - are converted to the appropriate XTensorVariable by `Index.make_node` + Finally, the location of views resulting from advanced indices follows two distinct behaviors in numpy. + When all advanced indices are consecutive, the respective view is located in the "original" location. + However, if advanced indices are separated by basic indices (slices in our case), the output views + always show up at the front of the array. This information is returned as the second output of this function, + which labels the final position of the indexed dimensions under this rule. """ + assert isinstance(node.op, Index) + x, *idxs = node.inputs [out] = node.outputs + x_tensor_indexed_dims = out.type.dims x_tensor = tensor_from_xtensor(x) if all( @@ -141,10 +145,68 @@ def lower_index(fgraph, node): x_tensor_indexed_dims = [ dim for dim in out_dims if dim not in x_tensor_indexed_basic_dims ] + x_tensor_indexed_basic_dims - transpose_order = [x_tensor_indexed_dims.index(dim) for dim in out_dims] - x_tensor_indexed = x_tensor_indexed.transpose(transpose_order) + + return x_tensor_indexed, x_tensor_indexed_dims + + +@register_xcanonicalize +@node_rewriter(tracks=[Index]) +def lower_index(fgraph, node): + """Lower XTensorVariable indexing to regular TensorVariable indexing. + + The bulk of the work is done by `_lower_index`, except for special logic to control the + location of non-consecutive advanced indices, and to preserve static shape information. + """ + + [out] = node.outputs + out_dims = out.type.dims + + x_tensor_indexed, x_tensor_indexed_dims = _lower_index(node) + if x_tensor_indexed_dims != out_dims: + # Numpy moves advanced indexing dimensions to the front when they are not consecutive + # We need to transpose them back to the expected output order + transpose_order = [x_tensor_indexed_dims.index(dim) for dim in out_dims] + x_tensor_indexed = x_tensor_indexed.transpose(transpose_order) # Add lost shape information x_tensor_indexed = specify_shape(x_tensor_indexed, out.type.shape) - new_out = xtensor_from_tensor(x_tensor_indexed, dims=out.type.dims) + + new_out = xtensor_from_tensor(x_tensor_indexed, dims=out.dims) + return [new_out] + + +@register_xcanonicalize +@node_rewriter(tracks=[IndexUpdate]) +def lower_index_update(fgraph, node): + """Lower XTensorVariable index update to regular TensorVariable indexing update. + + This rewrite requires converting the index view to a tensor-based equivalent expression, + just like `lower_index`. It then requires aligning the dimensions of y with the + dimensions of the index view, with special care for non-consecutive dimensions being + pulled to the front axis according to numpy rules. + """ + x, y, *idxs = node.inputs + + # Lower the indexing part first + indexed_node = index.make_node(x, *idxs) + x_tensor_indexed, x_tensor_indexed_dims = _lower_index(indexed_node) + y_tensor = tensor_from_xtensor(y) + + # Align dimensions of y with those of the indexed tensor x + y_dims = y.type.dims + y_dims_set = set(y_dims) + y_order = tuple( + y_dims.index(x_dim) if x_dim in y_dims_set else "x" + for x_dim in x_tensor_indexed_dims + ) + # Remove useless left expand_dims + while len(y_order) > 0 and y_order[0] == "x": + y_order = y_order[1:] + if y_order != tuple(range(y_tensor.type.ndim)): + y_tensor = y_tensor.dimshuffle(y_order) + + x_tensor_updated = inc_subtensor( + x_tensor_indexed, y_tensor, set_instead_of_inc=node.op.mode == "set" + ) + new_out = xtensor_from_tensor(x_tensor_updated, dims=x.type.dims) return [new_out] diff --git a/pytensor/xtensor/type.py b/pytensor/xtensor/type.py index adb7218147..deb8fe7291 100644 --- a/pytensor/xtensor/type.py +++ b/pytensor/xtensor/type.py @@ -342,8 +342,10 @@ def item(self): # Indexing # https://docs.xarray.dev/en/latest/api.html#id2 - def __setitem__(self, key, value): - raise TypeError("XTensorVariable does not support item assignment.") + def __setitem__(self, idx, value): + raise TypeError( + "XTensorVariable does not support item assignment. Use the output of `x[idx].set` or `x[idx].inc` instead." + ) @property def loc(self): @@ -403,6 +405,28 @@ def isel( return px.indexing.index(self, *indices) + def set(self, value): + if not ( + self.owner is not None and isinstance(self.owner.op, px.indexing.Index) + ): + raise ValueError( + f"set can only be called on the output of an index (or isel) operation. Self is the result of {self.owner}" + ) + + x, *idxs = self.owner.inputs + return px.indexing.index_assignment(x, value, *idxs) + + def inc(self, value): + if not ( + self.owner is not None and isinstance(self.owner.op, px.indexing.Index) + ): + raise ValueError( + f"inc can only be called on the output of an index (or isel) operation. Self is the result of {self.owner}" + ) + + x, *idxs = self.owner.inputs + return px.indexing.index_increment(x, value, *idxs) + def _head_tail_or_thin( self, indexers: dict[str, Any] | int | None, diff --git a/tests/xtensor/test_indexing.py b/tests/xtensor/test_indexing.py index 721fd1e695..c7d8572bdc 100644 --- a/tests/xtensor/test_indexing.py +++ b/tests/xtensor/test_indexing.py @@ -6,7 +6,12 @@ from pytensor.tensor import tensor from pytensor.xtensor import xtensor -from tests.xtensor.util import xr_arange_like, xr_assert_allclose, xr_function +from tests.xtensor.util import ( + xr_arange_like, + xr_assert_allclose, + xr_function, + xr_random_like, +) @pytest.mark.parametrize( @@ -346,3 +351,138 @@ def test_boolean_indexing(): expected_res2 = x_test[bool_idx_test, int_idx_test.rename(a="b")] xr_assert_allclose(res1, expected_res1) xr_assert_allclose(res2, expected_res2) + + +@pytest.mark.parametrize("mode", ("set", "inc")) +def test_basic_index_update(mode): + x = xtensor("x", shape=(11, 7), dims=("a", "b")) + y = xtensor("y", shape=(7, 5), dims=("a", "b")) + x_indexed = x[2:-2, 2:] + update_method = getattr(x_indexed, mode) + + x_updated = [ + update_method(y), + update_method(y.T), + update_method(y.isel(a=-1)), + update_method(y.isel(b=-1)), + update_method(y.isel(a=-2, b=-2)), + ] + + fn = xr_function([x, y], x_updated) + x_test = xr_random_like(x) + y_test = xr_random_like(y) + results = fn(x_test, y_test) + + def update_fn(y): + x = x_test.copy() + if mode == "set": + x[2:-2, 2:] = y + elif mode == "inc": + x[2:-2, 2:] += y + return x + + expected_results = [ + update_fn(y_test), + update_fn(y_test.T), + update_fn(y_test.isel(a=-1)), + update_fn(y_test.isel(b=-1)), + update_fn(y_test.isel(a=-2, b=-2)), + ] + for result, expected_result in zip(results, expected_results): + xr_assert_allclose(result, expected_result) + + +@pytest.mark.parametrize("mode", ("set", "inc")) +@pytest.mark.parametrize("idx_dtype", (int, bool)) +def test_adv_index_update(mode, idx_dtype): + x = xtensor("x", shape=(5, 5), dims=("a", "b")) + y = xtensor("y", shape=(3,), dims=("b",)) + idx = xtensor("idx", dtype=idx_dtype, shape=(None,), dims=("a",)) + + orthogonal_update1 = getattr(x[idx, -3:], mode)(y) + orthogonal_update2 = getattr(x[idx, -3:], mode)(y.rename(b="a")) + if idx_dtype is not bool: + # Vectorized booling indexing/update is not allowed + vectorized_update = getattr(x[idx.rename(a="b"), :3], mode)(y) + else: + with pytest.raises( + IndexError, + match="Boolean indexer should be unlabeled or on the same dimension to the indexed array.", + ): + getattr(x[idx.rename(a="b"), :3], mode)(y) + vectorized_update = x + + outs = [orthogonal_update1, orthogonal_update2, vectorized_update] + + fn = xr_function([x, idx, y], outs) + x_test = xr_random_like(x) + y_test = xr_random_like(y) + if idx_dtype is int: + idx_test = DataArray([0, 1, 2], dims=("a",)) + else: + idx_test = DataArray([True, False, True, True, False], dims=("a",)) + results = fn(x_test, idx_test, y_test) + + def update_fn(x, idx, y): + x = x.copy() + if mode == "set": + x[idx] = y + else: + x[idx] += y + return x + + expected_results = [ + update_fn(x_test, (idx_test, slice(-3, None)), y_test), + update_fn( + x_test, + (idx_test, slice(-3, None)), + y_test.rename(b="a"), + ), + update_fn(x_test, (idx_test.rename(a="b"), slice(None, 3)), y_test) + if idx_dtype is not bool + else x_test, + ] + for result, expected_result in zip(results, expected_results): + xr_assert_allclose(result, expected_result) + + +@pytest.mark.parametrize("mode", ("set", "inc")) +def test_non_consecutive_idx_update(mode): + x = xtensor("x", shape=(2, 3, 5, 7), dims=("a", "b", "c", "d")) + y = xtensor("y", shape=(5, 4), dims=("c", "b")) + x_indexed = x[:, [0, 1, 2, 2], :, ("b", [0, 1, 1, 2])] + out = getattr(x_indexed, mode)(y) + + fn = xr_function([x, y], out) + x_test = xr_random_like(x) + y_test = xr_random_like(y) + + result = fn(x_test, y_test) + expected_result = x_test.copy() + # xarray fails inplace operation with the "tuple trick" + # https://github.com/pydata/xarray/issues/10387 + d_indexer = DataArray([0, 1, 1, 2], dims=("b",)) + if mode == "set": + expected_result[:, [0, 1, 2, 2], :, d_indexer] = y_test + else: + expected_result[:, [0, 1, 2, 2], :, d_indexer] += y_test + xr_assert_allclose(result, expected_result) + + +def test_indexing_renames_into_update_variable(): + x = xtensor("x", shape=(5, 5), dims=("a", "b")) + y = xtensor("y", shape=(3,), dims=("d",)) + idx = xtensor("idx", dtype=int, shape=(None,), dims=("d",)) + + # define "d" dimension by slicing the "a" dimension so we can set y into x + orthogonal_update1 = x[idx].set(y) + fn = xr_function([x, idx, y], orthogonal_update1) + + x_test = np.abs(xr_random_like(x)) + y_test = -np.abs(xr_random_like(y)) + idx_test = DataArray([0, 2, 3], dims=("d",)) + + result = fn(x_test, idx_test, y_test) + expected_result = x_test.copy() + expected_result[idx_test] = y_test + xr_assert_allclose(result, expected_result)