
XTensorVariable indexing update #1438


Merged
7 changes: 1 addition & 6 deletions pytensor/tensor/subtensor.py
@@ -3021,12 +3021,7 @@ def make_node(self, x, y, *inputs):
return Apply(
self,
(x, y, *new_inputs),
-            [
-                tensor(
-                    dtype=x.type.dtype,
-                    shape=tuple(1 if s == 1 else None for s in x.type.shape),
-                )
-            ],
+            [x.type()],
)

def perform(self, node, inputs, out_):
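To make the subtensor.py change above concrete, here is a minimal sketch (not part of the diff; the variable names are hypothetical) of the static shape information each output type carries:

from pytensor.tensor import tensor

x = tensor("x", dtype="float64", shape=(2, 1, None))

# Old construction: only size-1 dims survive, other static sizes are erased
old_out = tensor(
    dtype=x.type.dtype,
    shape=tuple(1 if s == 1 else None for s in x.type.shape),
)
# New construction: x.type() preserves the full static shape of x
new_out = x.type()

print(old_out.type.shape)  # (None, 1, None)
print(new_out.type.shape)  # (2, 1, None)

Since an index update returns an array with exactly x's shape, `x.type()` is the tighter (and simpler) output type.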
33 changes: 33 additions & 0 deletions pytensor/xtensor/indexing.py
@@ -4,6 +4,7 @@
# https://numpy.org/neps/nep-0021-advanced-indexing.html
# https://docs.xarray.dev/en/latest/user-guide/indexing.html
# https://tutorial.xarray.dev/intermediate/indexing/advanced-indexing.html
from typing import Literal

from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.scalar.basic import discrete_dtypes
@@ -184,3 +185,35 @@ def combine_dim_info(idx_dim, idx_dim_shape):


index = Index()


class IndexUpdate(XOp):
__props__ = ("mode",)

def __init__(self, mode: Literal["set", "inc"]):
if mode not in ("set", "inc"):
raise ValueError("mode must be 'set' or 'inc'")
self.mode = mode

def make_node(self, x, y, *idxs):
# Call Index on (x, *idxs) to process inputs and infer output type
x_view_node = index.make_node(x, *idxs)
x, *idxs = x_view_node.inputs
[x_view] = x_view_node.outputs

try:
y = as_xtensor(y)
except TypeError:
y = as_xtensor(as_tensor(y), dims=x_view.type.dims)

if not set(y.type.dims).issubset(x_view.type.dims):
raise ValueError(
f"Value dimensions {y.type.dims} must be a subset of the indexed dimensions {x_view.type.dims}"
)

out = x.type()
return Apply(self, [x, y, *idxs], [out])


index_assignment = IndexUpdate("set")
index_increment = IndexUpdate("inc")
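As a hedged usage sketch (not part of the diff; shapes and names are made up), these Ops are normally reached through the `set`/`inc` methods added to XTensorVariable in type.py below, rather than called directly:

from pytensor.xtensor import xtensor

x = xtensor("x", shape=(5, 3), dims=("a", "b"))
y = xtensor("y", shape=(3,), dims=("b",))

# Both build an IndexUpdate node over the original x and its indices;
# x itself is never mutated, a new variable is returned.
out_set = x[1:4].set(y)  # IndexUpdate(mode="set")
out_inc = x[1:4].inc(y)  # IndexUpdate(mode="inc")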
82 changes: 72 additions & 10 deletions pytensor/xtensor/rewriting/indexing.py
@@ -3,10 +3,10 @@
from pytensor import as_symbolic
from pytensor.graph import Constant, node_rewriter
from pytensor.tensor import TensorType, arange, specify_shape
-from pytensor.tensor.subtensor import _non_consecutive_adv_indexing
+from pytensor.tensor.subtensor import _non_consecutive_adv_indexing, inc_subtensor
from pytensor.tensor.type_other import NoneTypeT, SliceType
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
-from pytensor.xtensor.indexing import Index
+from pytensor.xtensor.indexing import Index, IndexUpdate, index
from pytensor.xtensor.rewriting.utils import register_xcanonicalize
from pytensor.xtensor.type import XTensorType

@@ -35,9 +35,7 @@ def to_basic_idx(idx):
raise TypeError("Cannot convert idx to basic idx")


-@register_xcanonicalize
-@node_rewriter(tracks=[Index])
-def lower_index(fgraph, node):
+def _lower_index(node):
"""Lower XTensorVariable indexing to regular TensorVariable indexing.

xarray-like indexing has two modes:
@@ -59,12 +57,18 @@ def lower_index(fgraph, node):
We do this by creating an `arange` tensor that matches the shape of the dimension being indexed,
and then indexing it with the original slice. This index is then handled as a regular advanced index.

Note: The `Index` Op has only 2 types of indices: Slices and XTensorVariables. Regular array indices
are converted to the appropriate XTensorVariable by `Index.make_node`.
Finally, the location of views resulting from advanced indices follows two distinct behaviors in numpy.
When all advanced indices are consecutive, the respective view is located in the "original" location.
However, if advanced indices are separated by basic indices (slices in our case), the output views
always show up at the front of the array. This information is returned as the second output of this function,
which labels the final position of the indexed dimensions under this rule.
"""

assert isinstance(node.op, Index)

x, *idxs = node.inputs
[out] = node.outputs
x_tensor_indexed_dims = out.type.dims
x_tensor = tensor_from_xtensor(x)

if all(
Expand Down Expand Up @@ -141,10 +145,68 @@ def lower_index(fgraph, node):
x_tensor_indexed_dims = [
dim for dim in out_dims if dim not in x_tensor_indexed_basic_dims
] + x_tensor_indexed_basic_dims
-    transpose_order = [x_tensor_indexed_dims.index(dim) for dim in out_dims]
-    x_tensor_indexed = x_tensor_indexed.transpose(transpose_order)

return x_tensor_indexed, x_tensor_indexed_dims
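The placement rule described in the docstring is plain numpy behavior; a small numpy-only illustration (not part of the diff):

import numpy as np

x = np.arange(2 * 3 * 4).reshape(2, 3, 4)

# Consecutive advanced indices: the broadcast index dimension stays in place
print(x[:, [0, 1], [0, 2]].shape)  # (2, 2)

# Advanced indices separated by a slice: the index dimension moves to the front
print(x[[0, 1], :, [0, 2]].shape)  # (2, 3)

# The "arange trick" for mixed basic/advanced indexing: a slice can be replaced
# by an explicit integer index, so everything becomes one advanced indexing op
rows = np.arange(3)[1:]
assert (x[:, 1:, :] == x[:, rows, :]).all()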


@register_xcanonicalize
@node_rewriter(tracks=[Index])
def lower_index(fgraph, node):
"""Lower XTensorVariable indexing to regular TensorVariable indexing.

The bulk of the work is done by `_lower_index`; this rewrite adds the logic that restores
the expected location of non-consecutive advanced indices and preserves static shape information.
"""

[out] = node.outputs
out_dims = out.type.dims

x_tensor_indexed, x_tensor_indexed_dims = _lower_index(node)
if x_tensor_indexed_dims != out_dims:
# Numpy moves advanced indexing dimensions to the front when they are not consecutive
# We need to transpose them back to the expected output order
transpose_order = [x_tensor_indexed_dims.index(dim) for dim in out_dims]
x_tensor_indexed = x_tensor_indexed.transpose(transpose_order)

# Add lost shape information
x_tensor_indexed = specify_shape(x_tensor_indexed, out.type.shape)
-    new_out = xtensor_from_tensor(x_tensor_indexed, dims=out.dims)
+    new_out = xtensor_from_tensor(x_tensor_indexed, dims=out.type.dims)
return [new_out]
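A tiny worked example (dim names assumed for illustration) of the transpose computed above:

out_dims = ("a", "b", "c")
# "b" was a non-consecutive advanced index, so numpy left it at the front:
x_tensor_indexed_dims = ["b", "a", "c"]

transpose_order = [x_tensor_indexed_dims.index(dim) for dim in out_dims]
print(transpose_order)  # [1, 0, 2] -- moves "b" back to its labeled position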


@register_xcanonicalize
@node_rewriter(tracks=[IndexUpdate])
def lower_index_update(fgraph, node):
"""Lower XTensorVariable index update to regular TensorVariable indexing update.

This rewrite converts the indexed view to an equivalent tensor-based expression,
just like `lower_index`. It then aligns the dimensions of y with those of the
indexed view, taking special care with non-consecutive advanced dimensions,
which numpy pulls to the front.
"""
x, y, *idxs = node.inputs

# Lower the indexing part first
indexed_node = index.make_node(x, *idxs)
x_tensor_indexed, x_tensor_indexed_dims = _lower_index(indexed_node)
y_tensor = tensor_from_xtensor(y)

# Align dimensions of y with those of the indexed tensor x
y_dims = y.type.dims
y_dims_set = set(y_dims)
y_order = tuple(
y_dims.index(x_dim) if x_dim in y_dims_set else "x"
for x_dim in x_tensor_indexed_dims
)
# Remove useless left expand_dims
while len(y_order) > 0 and y_order[0] == "x":
y_order = y_order[1:]
if y_order != tuple(range(y_tensor.type.ndim)):
y_tensor = y_tensor.dimshuffle(y_order)

x_tensor_updated = inc_subtensor(
x_tensor_indexed, y_tensor, set_instead_of_inc=node.op.mode == "set"
)
new_out = xtensor_from_tensor(x_tensor_updated, dims=x.type.dims)
return [new_out]
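A worked example (dims assumed for illustration) of the y-alignment logic above:

y_dims = ("b",)
x_tensor_indexed_dims = ("c", "a", "b")

y_order = tuple(
    y_dims.index(d) if d in set(y_dims) else "x" for d in x_tensor_indexed_dims
)
print(y_order)  # ('x', 'x', 0)

# Leading expand_dims are dropped, since right-aligned broadcasting covers them
while len(y_order) > 0 and y_order[0] == "x":
    y_order = y_order[1:]
print(y_order)  # (0,) -- the identity for a 1d tensor, so no dimshuffle is needed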
28 changes: 26 additions & 2 deletions pytensor/xtensor/type.py
@@ -342,8 +342,10 @@ def item(self):

# Indexing
# https://docs.xarray.dev/en/latest/api.html#id2
-    def __setitem__(self, key, value):
-        raise TypeError("XTensorVariable does not support item assignment.")
+    def __setitem__(self, idx, value):
+        raise TypeError(
+            "XTensorVariable does not support item assignment. Use the output of `x[idx].set` or `x[idx].inc` instead."
+        )

@property
def loc(self):
@@ -403,6 +405,28 @@ def isel(

return px.indexing.index(self, *indices)

def set(self, value):
if not (
self.owner is not None and isinstance(self.owner.op, px.indexing.Index)
):
raise ValueError(
f"set can only be called on the output of an index (or isel) operation. Self is the result of {self.owner}"
)

x, *idxs = self.owner.inputs
return px.indexing.index_assignment(x, value, *idxs)

def inc(self, value):
if not (
self.owner is not None and isinstance(self.owner.op, px.indexing.Index)
):
raise ValueError(
f"inc can only be called on the output of an index (or isel) operation. Self is the result of {self.owner}"
)

x, *idxs = self.owner.inputs
return px.indexing.index_increment(x, value, *idxs)

def _head_tail_or_thin(
self,
indexers: dict[str, Any] | int | None,
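A short sketch (not in the diff; names are made up) of the guard in `set`/`inc`: both must be called directly on the output of an indexing operation:

from pytensor.xtensor import xtensor

x = xtensor("x", shape=(4,), dims=("a",))
y = xtensor("y", shape=(2,), dims=("a",))

out = x[:2].set(y)  # OK: x[:2] is the output of an Index op

try:
    x.set(y)  # x.owner is None, so this is rejected
except ValueError as err:
    print(err)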
144 changes: 142 additions & 2 deletions tests/xtensor/test_indexing.py
@@ -6,7 +6,12 @@

from pytensor.tensor import tensor
from pytensor.xtensor import xtensor
-from tests.xtensor.util import xr_arange_like, xr_assert_allclose, xr_function
+from tests.xtensor.util import (
+    xr_arange_like,
+    xr_assert_allclose,
+    xr_function,
+    xr_random_like,
+)


@pytest.mark.parametrize(
@@ -45,7 +50,7 @@ def test_basic_indexing(labeled, indices):
xr_assert_allclose(res, expected_res)


-def test_single_adv_indexing_on_existing_dim():
+def test_single_vector_indexing_on_existing_dim():
x = xtensor(dims=("a", "b"), shape=(3, 5))
idx = tensor("idx", dtype=int, shape=(4,))
xidx = xtensor("idx", dtype=int, shape=(4,), dims=("a",))
@@ -346,3 +351,138 @@ def test_boolean_indexing():
expected_res2 = x_test[bool_idx_test, int_idx_test.rename(a="b")]
xr_assert_allclose(res1, expected_res1)
xr_assert_allclose(res2, expected_res2)


@pytest.mark.parametrize("mode", ("set", "inc"))
def test_basic_index_update(mode):
x = xtensor("x", shape=(11, 7), dims=("a", "b"))
y = xtensor("y", shape=(7, 5), dims=("a", "b"))
x_indexed = x[2:-2, 2:]
update_method = getattr(x_indexed, mode)

x_updated = [
update_method(y),
update_method(y.T),
update_method(y.isel(a=-1)),
update_method(y.isel(b=-1)),
update_method(y.isel(a=-2, b=-2)),
]

fn = xr_function([x, y], x_updated)
x_test = xr_random_like(x)
y_test = xr_random_like(y)
results = fn(x_test, y_test)

def update_fn(y):
x = x_test.copy()
if mode == "set":
x[2:-2, 2:] = y
elif mode == "inc":
x[2:-2, 2:] += y
return x

expected_results = [
update_fn(y_test),
update_fn(y_test.T),
update_fn(y_test.isel(a=-1)),
update_fn(y_test.isel(b=-1)),
update_fn(y_test.isel(a=-2, b=-2)),
]
for result, expected_result in zip(results, expected_results):
xr_assert_allclose(result, expected_result)


@pytest.mark.parametrize("mode", ("set", "inc"))
@pytest.mark.parametrize("idx_dtype", (int, bool))
def test_adv_index_update(mode, idx_dtype):
x = xtensor("x", shape=(5, 5), dims=("a", "b"))
y = xtensor("y", shape=(3,), dims=("b",))
idx = xtensor("idx", dtype=idx_dtype, shape=(None,), dims=("a",))

orthogonal_update1 = getattr(x[idx, -3:], mode)(y)
orthogonal_update2 = getattr(x[idx, -3:], mode)(y.rename(b="a"))
if idx_dtype is not bool:
# Vectorized boolean indexing/updating is not allowed
vectorized_update = getattr(x[idx.rename(a="b"), :3], mode)(y)
else:
with pytest.raises(
IndexError,
match="Boolean indexer should be unlabeled or on the same dimension to the indexed array.",
):
getattr(x[idx.rename(a="b"), :3], mode)(y)
vectorized_update = x

outs = [orthogonal_update1, orthogonal_update2, vectorized_update]

fn = xr_function([x, idx, y], outs)
x_test = xr_random_like(x)
y_test = xr_random_like(y)
if idx_dtype is int:
idx_test = DataArray([0, 1, 2], dims=("a",))
else:
idx_test = DataArray([True, False, True, True, False], dims=("a",))
results = fn(x_test, idx_test, y_test)

def update_fn(x, idx, y):
x = x.copy()
if mode == "set":
x[idx] = y
else:
x[idx] += y
return x

expected_results = [
update_fn(x_test, (idx_test, slice(-3, None)), y_test),
update_fn(
x_test,
(idx_test, slice(-3, None)),
y_test.rename(b="a"),
),
update_fn(x_test, (idx_test.rename(a="b"), slice(None, 3)), y_test)
if idx_dtype is not bool
else x_test,
]
for result, expected_result in zip(results, expected_results):
xr_assert_allclose(result, expected_result)


@pytest.mark.parametrize("mode", ("set", "inc"))
def test_non_consecutive_idx_update(mode):
x = xtensor("x", shape=(2, 3, 5, 7), dims=("a", "b", "c", "d"))
y = xtensor("y", shape=(5, 4), dims=("c", "b"))
x_indexed = x[:, [0, 1, 2, 2], :, ("b", [0, 1, 1, 2])]
out = getattr(x_indexed, mode)(y)

fn = xr_function([x, y], out)
x_test = xr_random_like(x)
y_test = xr_random_like(y)

result = fn(x_test, y_test)
expected_result = x_test.copy()
# xarray fails inplace operation with the "tuple trick"
# https://github.com/pydata/xarray/issues/10387
d_indexer = DataArray([0, 1, 1, 2], dims=("b",))
if mode == "set":
expected_result[:, [0, 1, 2, 2], :, d_indexer] = y_test
else:
expected_result[:, [0, 1, 2, 2], :, d_indexer] += y_test
xr_assert_allclose(result, expected_result)


def test_indexing_renames_into_update_variable():
x = xtensor("x", shape=(5, 5), dims=("a", "b"))
y = xtensor("y", shape=(3,), dims=("d",))
idx = xtensor("idx", dtype=int, shape=(None,), dims=("d",))

# define "d" dimension by slicing the "a" dimension so we can set y into x
orthogonal_update1 = x[idx].set(y)
fn = xr_function([x, idx, y], orthogonal_update1)

x_test = np.abs(xr_random_like(x))
y_test = -np.abs(xr_random_like(y))
idx_test = DataArray([0, 2, 3], dims=("d",))

result = fn(x_test, idx_test, y_test)
expected_result = x_test.copy()
expected_result[idx_test] = y_test
xr_assert_allclose(result, expected_result)