Skip to content

Commit 4e105b4

Browse files
committed
Implement index update for XTensorVariables
1 parent 2dd1c8f commit 4e105b4

File tree

4 files changed

+253
-13
lines changed

4 files changed

+253
-13
lines changed

pytensor/xtensor/indexing.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# https://numpy.org/neps/nep-0021-advanced-indexing.html
55
# https://docs.xarray.dev/en/latest/user-guide/indexing.html
66
# https://tutorial.xarray.dev/intermediate/indexing/advanced-indexing.html
7+
from typing import Literal
78

89
from pytensor.graph.basic import Apply, Constant, Variable
910
from pytensor.scalar.basic import discrete_dtypes
@@ -183,3 +184,35 @@ def combine_dim_info(idx_dim, idx_dim_shape):
183184

184185

185186
index = Index()
187+
188+
189+
class IndexUpdate(XOp):
190+
__props__ = ("mode",)
191+
192+
def __init__(self, mode: Literal["set", "inc"]):
193+
if mode not in ("set", "inc"):
194+
raise ValueError("mode must be 'set' or 'inc'")
195+
self.mode = mode
196+
197+
def make_node(self, x, y, *idxs):
198+
# Call Index on (x, *idxs) to process inputs and infer output type
199+
x_view_node = index.make_node(x, *idxs)
200+
x, *idxs = x_view_node.inputs
201+
[x_view] = x_view_node.outputs
202+
203+
try:
204+
y = as_xtensor(y)
205+
except TypeError:
206+
y = as_xtensor(as_tensor(y), dims=x_view.type.dims)
207+
208+
if not set(y.type.dims).issubset(x_view.type.dims):
209+
raise ValueError(
210+
f"Value dimensions {y.type.dims} must be a subset of the indexed dimensions {x_view.type.dims}"
211+
)
212+
213+
out = x.type()
214+
return Apply(self, [x, y, *idxs], [out])
215+
216+
217+
index_assignment = IndexUpdate("set")
218+
index_increment = IndexUpdate("inc")

pytensor/xtensor/rewriting/indexing.py

Lines changed: 72 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
from pytensor import as_symbolic
44
from pytensor.graph import Constant, node_rewriter
55
from pytensor.tensor import TensorType, arange, specify_shape
6-
from pytensor.tensor.subtensor import _non_consecutive_adv_indexing
6+
from pytensor.tensor.subtensor import _non_consecutive_adv_indexing, inc_subtensor
77
from pytensor.tensor.type_other import NoneTypeT, SliceType
88
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
9-
from pytensor.xtensor.indexing import Index
9+
from pytensor.xtensor.indexing import Index, IndexUpdate, index
1010
from pytensor.xtensor.rewriting.utils import register_xcanonicalize
1111
from pytensor.xtensor.type import XTensorType
1212

@@ -35,9 +35,7 @@ def to_basic_idx(idx):
3535
raise TypeError("Cannot convert idx to basic idx")
3636

3737

38-
@register_xcanonicalize
39-
@node_rewriter(tracks=[Index])
40-
def lower_index(fgraph, node):
38+
def _lower_index(node):
4139
"""Lower XTensorVariable indexing to regular TensorVariable indexing.
4240
4341
xarray-like indexing has two modes:
@@ -59,12 +57,18 @@ def lower_index(fgraph, node):
5957
We do this by creating an `arange` tensor that matches the shape of the dimension being indexed,
6058
and then indexing it with the original slice. This index is then handled as a regular advanced index.
6159
62-
Note: The IndexOp has only 2 types of indices: Slices and XTensorVariables. Regular array indices
63-
are converted to the appropriate XTensorVariable by `Index.make_node`
60+
Finally, the location of views resulting from advanced indices follows two distinct behaviors in numpy.
61+
When all advanced indices are consecutive, the respective view is located in the "original" location.
62+
However, if advanced indices are separated by basic indices (slices in our case), the output views
63+
always show up at the front of the array. This information is returned as the second output of this function,
64+
which labels the final position of the indexed dimensions under this rule.
6465
"""
6566

67+
assert isinstance(node.op, Index)
68+
6669
x, *idxs = node.inputs
6770
[out] = node.outputs
71+
x_tensor_indexed_dims = out.type.dims
6872
x_tensor = tensor_from_xtensor(x)
6973

7074
if all(
@@ -141,10 +145,68 @@ def lower_index(fgraph, node):
141145
x_tensor_indexed_dims = [
142146
dim for dim in out_dims if dim not in x_tensor_indexed_basic_dims
143147
] + x_tensor_indexed_basic_dims
144-
transpose_order = [x_tensor_indexed_dims.index(dim) for dim in out_dims]
145-
x_tensor_indexed = x_tensor_indexed.transpose(transpose_order)
148+
149+
return x_tensor_indexed, x_tensor_indexed_dims
150+
151+
152+
@register_xcanonicalize
153+
@node_rewriter(tracks=[Index])
154+
def lower_index(fgraph, node):
155+
"""Lower XTensorVariable indexing to regular TensorVariable indexing.
156+
157+
The bulk of the work is done by `_lower_index`, except for special logic to control the
158+
location of non-consecutive advanced indices, and to preserve static shape information.
159+
"""
160+
161+
[out] = node.outputs
162+
out_dims = out.type.dims
163+
164+
x_tensor_indexed, x_tensor_indexed_dims = _lower_index(node)
165+
if x_tensor_indexed_dims != out_dims:
166+
# Numpy moves advanced indexing dimensions to the front when they are not consecutive
167+
# We need to transpose them back to the expected output order
168+
transpose_order = [x_tensor_indexed_dims.index(dim) for dim in out_dims]
169+
x_tensor_indexed = x_tensor_indexed.transpose(transpose_order)
146170

147171
# Add lost shape information
148172
x_tensor_indexed = specify_shape(x_tensor_indexed, out.type.shape)
149-
new_out = xtensor_from_tensor(x_tensor_indexed, dims=out.type.dims)
173+
174+
new_out = xtensor_from_tensor(x_tensor_indexed, dims=out.dims)
175+
return [new_out]
176+
177+
178+
@register_xcanonicalize
179+
@node_rewriter(tracks=[IndexUpdate])
180+
def lower_index_update(fgraph, node):
181+
"""Lower XTensorVariable index update to regular TensorVariable indexing update.
182+
183+
This rewrite requires converting the index view to a tensor-based equivalent expression,
184+
just like `lower_index`. It then requires aligning the dimensions of y with the
185+
dimensions of the index view, with special care for non-consecutive dimensions being
186+
pulled to the front axis according to numpy rules.
187+
"""
188+
x, y, *idxs = node.inputs
189+
190+
# Lower the indexing part first
191+
indexed_node = index.make_node(x, *idxs)
192+
x_tensor_indexed, x_tensor_indexed_dims = _lower_index(indexed_node)
193+
y_tensor = tensor_from_xtensor(y)
194+
195+
# Align dimensions of y with those of the indexed tensor x
196+
y_dims = y.type.dims
197+
y_dims_set = set(y_dims)
198+
y_order = tuple(
199+
y_dims.index(x_dim) if x_dim in y_dims_set else "x"
200+
for x_dim in x_tensor_indexed_dims
201+
)
202+
# Remove useless left expand_dims
203+
while len(y_order) > 0 and y_order[0] == "x":
204+
y_order = y_order[1:]
205+
if y_order != tuple(range(y_tensor.type.ndim)):
206+
y_tensor = y_tensor.dimshuffle(y_order)
207+
208+
x_tensor_updated = inc_subtensor(
209+
x_tensor_indexed, y_tensor, set_instead_of_inc=node.op.mode == "set"
210+
)
211+
new_out = xtensor_from_tensor(x_tensor_updated, dims=x.type.dims)
150212
return [new_out]

pytensor/xtensor/type.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -340,8 +340,10 @@ def item(self):
340340

341341
# Indexing
342342
# https://docs.xarray.dev/en/latest/api.html#id2
343-
def __setitem__(self, key, value):
344-
raise TypeError("XTensorVariable does not support item assignment.")
343+
def __setitem__(self, idx, value):
344+
raise TypeError(
345+
"XTensorVariable does not support item assignment. Use the output of `x[idx].set` or `x[idx].inc` instead."
346+
)
345347

346348
@property
347349
def loc(self):
@@ -401,6 +403,28 @@ def isel(
401403

402404
return px.indexing.index(self, *indices)
403405

406+
def set(self, value):
407+
if not (
408+
self.owner is not None and isinstance(self.owner.op, px.indexing.Index)
409+
):
410+
raise ValueError(
411+
f"set can only be called on the output of an index (or isel) operation. Self is the result of {self.owner}"
412+
)
413+
414+
x, *idxs = self.owner.inputs
415+
return px.indexing.index_assignment(x, value, *idxs)
416+
417+
def inc(self, value):
418+
if not (
419+
self.owner is not None and isinstance(self.owner.op, px.indexing.Index)
420+
):
421+
raise ValueError(
422+
f"inc can only be called on the output of an index (or isel) operation. Self is the result of {self.owner}"
423+
)
424+
425+
x, *idxs = self.owner.inputs
426+
return px.indexing.index_increment(x, value, *idxs)
427+
404428
def _head_tail_or_thin(
405429
self,
406430
indexers: dict[str, Any] | int | None,

tests/xtensor/test_indexing.py

Lines changed: 122 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,12 @@
44

55
from pytensor.tensor import tensor
66
from pytensor.xtensor import xtensor
7-
from tests.xtensor.util import xr_arange_like, xr_assert_allclose, xr_function
7+
from tests.xtensor.util import (
8+
xr_arange_like,
9+
xr_assert_allclose,
10+
xr_function,
11+
xr_random_like,
12+
)
813

914

1015
@pytest.mark.parametrize(
@@ -331,3 +336,119 @@ def test_boolean_indexing():
331336
expected_res2 = x_test[bool_idx_test, int_idx_test.rename(a="b")]
332337
xr_assert_allclose(res1, expected_res1)
333338
xr_assert_allclose(res2, expected_res2)
339+
340+
341+
@pytest.mark.parametrize("mode", ("set", "inc"))
342+
def test_basic_index_update(mode):
343+
x = xtensor("x", shape=(11, 7), dims=("a", "b"))
344+
y = xtensor("y", shape=(7, 5), dims=("a", "b"))
345+
x_indexed = x[2:-2, 2:]
346+
update_method = getattr(x_indexed, mode)
347+
348+
x_updated = [
349+
update_method(y),
350+
update_method(y.T),
351+
update_method(y.isel(a=-1)),
352+
update_method(y.isel(b=-1)),
353+
update_method(y.isel(a=-2, b=-2)),
354+
]
355+
356+
fn = xr_function([x, y], x_updated)
357+
x_test = xr_random_like(x)
358+
y_test = xr_random_like(y)
359+
results = fn(x_test, y_test)
360+
361+
def update_fn(y):
362+
x = x_test.copy()
363+
if mode == "set":
364+
x[2:-2, 2:] = y
365+
elif mode == "inc":
366+
x[2:-2, 2:] += y
367+
return x
368+
369+
expected_results = [
370+
update_fn(y_test),
371+
update_fn(y_test.T),
372+
update_fn(y_test.isel(a=-1)),
373+
update_fn(y_test.isel(b=-1)),
374+
update_fn(y_test.isel(a=-2, b=-2)),
375+
]
376+
for result, expected_result in zip(results, expected_results):
377+
xr_assert_allclose(result, expected_result)
378+
379+
380+
@pytest.mark.parametrize("mode", ("set", "inc"))
381+
@pytest.mark.parametrize("idx_dtype", (int, bool))
382+
def test_adv_index_update(mode, idx_dtype):
383+
x = xtensor("x", shape=(5, 5), dims=("a", "b"))
384+
y = xtensor("y", shape=(3,), dims=("b",))
385+
idx = xtensor("idx", dtype=idx_dtype, shape=(None,), dims=("a",))
386+
387+
orthogonal_update1 = getattr(x[idx, -3:], mode)(y)
388+
orthogonal_update2 = getattr(x[idx, -3:], mode)(y.rename(b="a"))
389+
if idx_dtype is not bool:
390+
        # Vectorized boolean indexing/update is not allowed
391+
vectorized_update = getattr(x[idx.rename(a="b"), :3], mode)(y)
392+
else:
393+
with pytest.raises(
394+
IndexError,
395+
match="Boolean indexer should be unlabeled or on the same dimension to the indexed array.",
396+
):
397+
getattr(x[idx.rename(a="b"), :3], mode)(y)
398+
vectorized_update = x
399+
400+
outs = [orthogonal_update1, orthogonal_update2, vectorized_update]
401+
402+
fn = xr_function([x, idx, y], outs)
403+
x_test = xr_random_like(x)
404+
y_test = xr_random_like(y)
405+
if idx_dtype is int:
406+
idx_test = DataArray([0, 1, 2], dims=("a",))
407+
else:
408+
idx_test = DataArray([True, False, True, True, False], dims=("a",))
409+
results = fn(x_test, idx_test, y_test)
410+
411+
def update_fn(x, idx, y):
412+
x = x.copy()
413+
if mode == "set":
414+
x[idx] = y
415+
else:
416+
x[idx] += y
417+
return x
418+
419+
expected_results = [
420+
update_fn(x_test, (idx_test, slice(-3, None)), y_test),
421+
update_fn(
422+
x_test,
423+
(idx_test, slice(-3, None)),
424+
y_test.rename(b="a"),
425+
),
426+
update_fn(x_test, (idx_test.rename(a="b"), slice(None, 3)), y_test)
427+
if idx_dtype is not bool
428+
else x_test,
429+
]
430+
for result, expected_result in zip(results, expected_results):
431+
xr_assert_allclose(result, expected_result)
432+
433+
434+
@pytest.mark.parametrize("mode", ("set", "inc"))
435+
def test_non_consecutive_idx_update(mode):
436+
x = xtensor("x", shape=(2, 3, 5, 7), dims=("a", "b", "c", "d"))
437+
y = xtensor("y", shape=(5, 4), dims=("c", "b"))
438+
x_indexed = x[:, [0, 1, 2, 2], :, ("b", [0, 1, 1, 2])]
439+
out = getattr(x_indexed, mode)(y)
440+
441+
fn = xr_function([x, y], out)
442+
x_test = xr_random_like(x)
443+
y_test = xr_random_like(y)
444+
445+
result = fn(x_test, y_test)
446+
expected_result = x_test.copy()
447+
    # xarray fails the in-place operation with the "tuple trick"
448+
# https://github.com/pydata/xarray/issues/10387
449+
d_indexer = DataArray([0, 1, 1, 2], dims=("b",))
450+
if mode == "set":
451+
expected_result[:, [0, 1, 2, 2], :, d_indexer] = y_test
452+
else:
453+
expected_result[:, [0, 1, 2, 2], :, d_indexer] += y_test
454+
xr_assert_allclose(result, expected_result)

0 commit comments

Comments
 (0)