Skip to content

Commit ea690e6

Browse files
committed
Implement index update for XTensorVariables
1 parent 9971ca3 commit ea690e6

File tree

4 files changed

+272
-13
lines changed

4 files changed

+272
-13
lines changed

pytensor/xtensor/indexing.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# https://numpy.org/neps/nep-0021-advanced-indexing.html
55
# https://docs.xarray.dev/en/latest/user-guide/indexing.html
66
# https://tutorial.xarray.dev/intermediate/indexing/advanced-indexing.html
7+
from typing import Literal
78

89
from pytensor.graph.basic import Apply, Constant, Variable
910
from pytensor.scalar.basic import discrete_dtypes
@@ -184,3 +185,35 @@ def combine_dim_info(idx_dim, idx_dim_shape):
184185

185186

186187
index = Index()
188+
189+
190+
class IndexUpdate(XOp):
191+
__props__ = ("mode",)
192+
193+
def __init__(self, mode: Literal["set", "inc"]):
194+
if mode not in ("set", "inc"):
195+
raise ValueError("mode must be 'set' or 'inc'")
196+
self.mode = mode
197+
198+
def make_node(self, x, y, *idxs):
199+
# Call Index on (x, *idxs) to process inputs and infer output type
200+
x_view_node = index.make_node(x, *idxs)
201+
x, *idxs = x_view_node.inputs
202+
[x_view] = x_view_node.outputs
203+
204+
try:
205+
y = as_xtensor(y)
206+
except TypeError:
207+
y = as_xtensor(as_tensor(y), dims=x_view.type.dims)
208+
209+
if not set(y.type.dims).issubset(x_view.type.dims):
210+
raise ValueError(
211+
f"Value dimensions {y.type.dims} must be a subset of the indexed dimensions {x_view.type.dims}"
212+
)
213+
214+
out = x.type()
215+
return Apply(self, [x, y, *idxs], [out])
216+
217+
218+
index_assignment = IndexUpdate("set")
219+
index_increment = IndexUpdate("inc")

pytensor/xtensor/rewriting/indexing.py

Lines changed: 72 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
from pytensor import as_symbolic
44
from pytensor.graph import Constant, node_rewriter
55
from pytensor.tensor import TensorType, arange, specify_shape
6-
from pytensor.tensor.subtensor import _non_consecutive_adv_indexing
6+
from pytensor.tensor.subtensor import _non_consecutive_adv_indexing, inc_subtensor
77
from pytensor.tensor.type_other import NoneTypeT, SliceType
88
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
9-
from pytensor.xtensor.indexing import Index
9+
from pytensor.xtensor.indexing import Index, IndexUpdate, index
1010
from pytensor.xtensor.rewriting.utils import register_xcanonicalize
1111
from pytensor.xtensor.type import XTensorType
1212

@@ -35,9 +35,7 @@ def to_basic_idx(idx):
3535
raise TypeError("Cannot convert idx to basic idx")
3636

3737

38-
@register_xcanonicalize
39-
@node_rewriter(tracks=[Index])
40-
def lower_index(fgraph, node):
38+
def _lower_index(node):
4139
"""Lower XTensorVariable indexing to regular TensorVariable indexing.
4240
4341
xarray-like indexing has two modes:
@@ -59,12 +57,18 @@ def lower_index(fgraph, node):
5957
We do this by creating an `arange` tensor that matches the shape of the dimension being indexed,
6058
and then indexing it with the original slice. This index is then handled as a regular advanced index.
6159
62-
Note: The IndexOp has only 2 types of indices: Slices and XTensorVariables. Regular array indices
63-
are converted to the appropriate XTensorVariable by `Index.make_node`
60+
Finally, the location of views resulting from advanced indices follows two distinct behaviors in numpy.
61+
When all advanced indices are consecutive, the respective view is located in the "original" location.
62+
However, if advanced indices are separated by basic indices (slices in our case), the output views
63+
always show up at the front of the array. This information is returned as the second output of this function,
64+
which labels the final position of the indexed dimensions under this rule.
6465
"""
6566

67+
assert isinstance(node.op, Index)
68+
6669
x, *idxs = node.inputs
6770
[out] = node.outputs
71+
x_tensor_indexed_dims = out.type.dims
6872
x_tensor = tensor_from_xtensor(x)
6973

7074
if all(
@@ -141,10 +145,68 @@ def lower_index(fgraph, node):
141145
x_tensor_indexed_dims = [
142146
dim for dim in out_dims if dim not in x_tensor_indexed_basic_dims
143147
] + x_tensor_indexed_basic_dims
144-
transpose_order = [x_tensor_indexed_dims.index(dim) for dim in out_dims]
145-
x_tensor_indexed = x_tensor_indexed.transpose(transpose_order)
148+
149+
return x_tensor_indexed, x_tensor_indexed_dims
150+
151+
152+
@register_xcanonicalize
153+
@node_rewriter(tracks=[Index])
154+
def lower_index(fgraph, node):
155+
"""Lower XTensorVariable indexing to regular TensorVariable indexing.
156+
157+
The bulk of the work is done by `_lower_index`, except for special logic to control the
158+
location of non-consecutive advanced indices, and to preserve static shape information.
159+
"""
160+
161+
[out] = node.outputs
162+
out_dims = out.type.dims
163+
164+
x_tensor_indexed, x_tensor_indexed_dims = _lower_index(node)
165+
if x_tensor_indexed_dims != out_dims:
166+
# Numpy moves advanced indexing dimensions to the front when they are not consecutive
167+
# We need to transpose them back to the expected output order
168+
transpose_order = [x_tensor_indexed_dims.index(dim) for dim in out_dims]
169+
x_tensor_indexed = x_tensor_indexed.transpose(transpose_order)
146170

147171
# Add lost shape information
148172
x_tensor_indexed = specify_shape(x_tensor_indexed, out.type.shape)
149-
new_out = xtensor_from_tensor(x_tensor_indexed, dims=out.type.dims)
173+
174+
new_out = xtensor_from_tensor(x_tensor_indexed, dims=out.dims)
175+
return [new_out]
176+
177+
178+
@register_xcanonicalize
179+
@node_rewriter(tracks=[IndexUpdate])
180+
def lower_index_update(fgraph, node):
181+
"""Lower XTensorVariable index update to regular TensorVariable indexing update.
182+
183+
This rewrite requires converting the index view to a tensor-based equivalent expression,
184+
just like `lower_index`. It then requires aligning the dimensions of y with the
185+
dimensions of the index view, with special care for non-consecutive dimensions being
186+
pulled to the front axis according to numpy rules.
187+
"""
188+
x, y, *idxs = node.inputs
189+
190+
# Lower the indexing part first
191+
indexed_node = index.make_node(x, *idxs)
192+
x_tensor_indexed, x_tensor_indexed_dims = _lower_index(indexed_node)
193+
y_tensor = tensor_from_xtensor(y)
194+
195+
# Align dimensions of y with those of the indexed tensor x
196+
y_dims = y.type.dims
197+
y_dims_set = set(y_dims)
198+
y_order = tuple(
199+
y_dims.index(x_dim) if x_dim in y_dims_set else "x"
200+
for x_dim in x_tensor_indexed_dims
201+
)
202+
# Remove useless left expand_dims
203+
while len(y_order) > 0 and y_order[0] == "x":
204+
y_order = y_order[1:]
205+
if y_order != tuple(range(y_tensor.type.ndim)):
206+
y_tensor = y_tensor.dimshuffle(y_order)
207+
208+
x_tensor_updated = inc_subtensor(
209+
x_tensor_indexed, y_tensor, set_instead_of_inc=node.op.mode == "set"
210+
)
211+
new_out = xtensor_from_tensor(x_tensor_updated, dims=x.type.dims)
150212
return [new_out]

pytensor/xtensor/type.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -342,8 +342,10 @@ def item(self):
342342

343343
# Indexing
344344
# https://docs.xarray.dev/en/latest/api.html#id2
345-
def __setitem__(self, key, value):
346-
raise TypeError("XTensorVariable does not support item assignment.")
345+
def __setitem__(self, idx, value):
346+
raise TypeError(
347+
"XTensorVariable does not support item assignment. Use the output of `x[idx].set` or `x[idx].inc` instead."
348+
)
347349

348350
@property
349351
def loc(self):
@@ -403,6 +405,28 @@ def isel(
403405

404406
return px.indexing.index(self, *indices)
405407

408+
def set(self, value):
409+
if not (
410+
self.owner is not None and isinstance(self.owner.op, px.indexing.Index)
411+
):
412+
raise ValueError(
413+
f"set can only be called on the output of an index (or isel) operation. Self is the result of {self.owner}"
414+
)
415+
416+
x, *idxs = self.owner.inputs
417+
return px.indexing.index_assignment(x, value, *idxs)
418+
419+
def inc(self, value):
420+
if not (
421+
self.owner is not None and isinstance(self.owner.op, px.indexing.Index)
422+
):
423+
raise ValueError(
424+
f"inc can only be called on the output of an index (or isel) operation. Self is the result of {self.owner}"
425+
)
426+
427+
x, *idxs = self.owner.inputs
428+
return px.indexing.index_increment(x, value, *idxs)
429+
406430
def _head_tail_or_thin(
407431
self,
408432
indexers: dict[str, Any] | int | None,

tests/xtensor/test_indexing.py

Lines changed: 141 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,12 @@
66

77
from pytensor.tensor import tensor
88
from pytensor.xtensor import xtensor
9-
from tests.xtensor.util import xr_arange_like, xr_assert_allclose, xr_function
9+
from tests.xtensor.util import (
10+
xr_arange_like,
11+
xr_assert_allclose,
12+
xr_function,
13+
xr_random_like,
14+
)
1015

1116

1217
@pytest.mark.parametrize(
@@ -346,3 +351,138 @@ def test_boolean_indexing():
346351
expected_res2 = x_test[bool_idx_test, int_idx_test.rename(a="b")]
347352
xr_assert_allclose(res1, expected_res1)
348353
xr_assert_allclose(res2, expected_res2)
354+
355+
356+
@pytest.mark.parametrize("mode", ("set", "inc"))
357+
def test_basic_index_update(mode):
358+
x = xtensor("x", shape=(11, 7), dims=("a", "b"))
359+
y = xtensor("y", shape=(7, 5), dims=("a", "b"))
360+
x_indexed = x[2:-2, 2:]
361+
update_method = getattr(x_indexed, mode)
362+
363+
x_updated = [
364+
update_method(y),
365+
update_method(y.T),
366+
update_method(y.isel(a=-1)),
367+
update_method(y.isel(b=-1)),
368+
update_method(y.isel(a=-2, b=-2)),
369+
]
370+
371+
fn = xr_function([x, y], x_updated)
372+
x_test = xr_random_like(x)
373+
y_test = xr_random_like(y)
374+
results = fn(x_test, y_test)
375+
376+
def update_fn(y):
377+
x = x_test.copy()
378+
if mode == "set":
379+
x[2:-2, 2:] = y
380+
elif mode == "inc":
381+
x[2:-2, 2:] += y
382+
return x
383+
384+
expected_results = [
385+
update_fn(y_test),
386+
update_fn(y_test.T),
387+
update_fn(y_test.isel(a=-1)),
388+
update_fn(y_test.isel(b=-1)),
389+
update_fn(y_test.isel(a=-2, b=-2)),
390+
]
391+
for result, expected_result in zip(results, expected_results):
392+
xr_assert_allclose(result, expected_result)
393+
394+
395+
@pytest.mark.parametrize("mode", ("set", "inc"))
396+
@pytest.mark.parametrize("idx_dtype", (int, bool))
397+
def test_adv_index_update(mode, idx_dtype):
398+
x = xtensor("x", shape=(5, 5), dims=("a", "b"))
399+
y = xtensor("y", shape=(3,), dims=("b",))
400+
idx = xtensor("idx", dtype=idx_dtype, shape=(None,), dims=("a",))
401+
402+
orthogonal_update1 = getattr(x[idx, -3:], mode)(y)
403+
orthogonal_update2 = getattr(x[idx, -3:], mode)(y.rename(b="a"))
404+
if idx_dtype is not bool:
405+
# Vectorized booling indexing/update is not allowed
406+
vectorized_update = getattr(x[idx.rename(a="b"), :3], mode)(y)
407+
else:
408+
with pytest.raises(
409+
IndexError,
410+
match="Boolean indexer should be unlabeled or on the same dimension to the indexed array.",
411+
):
412+
getattr(x[idx.rename(a="b"), :3], mode)(y)
413+
vectorized_update = x
414+
415+
outs = [orthogonal_update1, orthogonal_update2, vectorized_update]
416+
417+
fn = xr_function([x, idx, y], outs)
418+
x_test = xr_random_like(x)
419+
y_test = xr_random_like(y)
420+
if idx_dtype is int:
421+
idx_test = DataArray([0, 1, 2], dims=("a",))
422+
else:
423+
idx_test = DataArray([True, False, True, True, False], dims=("a",))
424+
results = fn(x_test, idx_test, y_test)
425+
426+
def update_fn(x, idx, y):
427+
x = x.copy()
428+
if mode == "set":
429+
x[idx] = y
430+
else:
431+
x[idx] += y
432+
return x
433+
434+
expected_results = [
435+
update_fn(x_test, (idx_test, slice(-3, None)), y_test),
436+
update_fn(
437+
x_test,
438+
(idx_test, slice(-3, None)),
439+
y_test.rename(b="a"),
440+
),
441+
update_fn(x_test, (idx_test.rename(a="b"), slice(None, 3)), y_test)
442+
if idx_dtype is not bool
443+
else x_test,
444+
]
445+
for result, expected_result in zip(results, expected_results):
446+
xr_assert_allclose(result, expected_result)
447+
448+
449+
@pytest.mark.parametrize("mode", ("set", "inc"))
450+
def test_non_consecutive_idx_update(mode):
451+
x = xtensor("x", shape=(2, 3, 5, 7), dims=("a", "b", "c", "d"))
452+
y = xtensor("y", shape=(5, 4), dims=("c", "b"))
453+
x_indexed = x[:, [0, 1, 2, 2], :, ("b", [0, 1, 1, 2])]
454+
out = getattr(x_indexed, mode)(y)
455+
456+
fn = xr_function([x, y], out)
457+
x_test = xr_random_like(x)
458+
y_test = xr_random_like(y)
459+
460+
result = fn(x_test, y_test)
461+
expected_result = x_test.copy()
462+
# xarray fails inplace operation with the "tuple trick"
463+
# https://github.com/pydata/xarray/issues/10387
464+
d_indexer = DataArray([0, 1, 1, 2], dims=("b",))
465+
if mode == "set":
466+
expected_result[:, [0, 1, 2, 2], :, d_indexer] = y_test
467+
else:
468+
expected_result[:, [0, 1, 2, 2], :, d_indexer] += y_test
469+
xr_assert_allclose(result, expected_result)
470+
471+
472+
def test_indexing_renames_into_update_variable():
473+
x = xtensor("x", shape=(5, 5), dims=("a", "b"))
474+
y = xtensor("y", shape=(3,), dims=("d",))
475+
idx = xtensor("idx", dtype=int, shape=(None,), dims=("d",))
476+
477+
# define "d" dimension by slicing the "a" dimension so we can set y into x
478+
orthogonal_update1 = x[idx].set(y)
479+
fn = xr_function([x, idx, y], orthogonal_update1)
480+
481+
x_test = np.abs(xr_random_like(x))
482+
y_test = -np.abs(xr_random_like(y))
483+
idx_test = DataArray([0, 2, 3], dims=("d",))
484+
485+
result = fn(x_test, idx_test, y_test)
486+
expected_result = x_test.copy()
487+
expected_result[idx_test] = y_test
488+
xr_assert_allclose(result, expected_result)

0 commit comments

Comments
 (0)