Skip to content

Commit de6e6d4

Browse files
committed
WIP Implement index operations on XTensorTypes
1 parent a4d5727 commit de6e6d4

File tree

7 files changed

+276
-10
lines changed

7 files changed

+276
-10
lines changed

pytensor/xtensor/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import pytensor.xtensor.rewriting
44
from pytensor.xtensor.type import (
5-
XTensorType,
65
as_xtensor,
76
as_xtensor_variable,
87
xtensor,

pytensor/xtensor/basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ class XOp(Op):
1212

1313
def perform(self, node, inputs, outputs):
    """Always raise: this Op must be lowered to tensor operations first.

    Raises
    ------
    NotImplementedError
        Unconditionally; the message includes the string form of this Op.
    """
    message = (
        f"xtensor operation {self} must be lowered to equivalent tensor operations"
    )
    raise NotImplementedError(message)
1717

1818

pytensor/xtensor/indexing.py

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
# HERE LIE DRAGONS
2+
# Useful links to make sense of all the numpy/xarray complexity
3+
# https://numpy.org/devdocs//user/basics.indexing.html
4+
# https://numpy.org/neps/nep-0021-advanced-indexing.html
5+
# https://docs.xarray.dev/en/latest/user-guide/indexing.html
6+
# https://tutorial.xarray.dev/intermediate/indexing/advanced-indexing.html
7+
8+
from pytensor.graph.basic import Apply, Constant, Variable
9+
from pytensor.scalar.basic import discrete_dtypes
10+
from pytensor.tensor.basic import as_tensor
11+
from pytensor.tensor.type_other import NoneTypeT, SliceType, make_slice
12+
from pytensor.xtensor.basic import XOp
13+
from pytensor.xtensor.type import XTensorType, as_xtensor, xtensor
14+
15+
16+
def as_idx_variable(idx):
    """Convert a raw index into a symbolic variable accepted by ``Index``.

    Python ``slice`` objects are wrapped with ``make_slice``; variables whose
    type is already a ``SliceType`` are returned untouched. Anything else is
    converted with ``as_xtensor``, falling back to ``as_tensor`` when that
    raises ``TypeError``, and its dtype is then validated.

    Raises
    ------
    TypeError
        If ``idx`` is ``None`` (or a variable of ``NoneTypeT``), or if the
        converted index does not have a discrete dtype.
    NotImplementedError
        If the converted index has dtype ``"bool"``.
    """
    if idx is None or (isinstance(idx, Variable) and isinstance(idx.type, NoneTypeT)):
        raise TypeError(
            "XTensors do not support indexing with None (np.newaxis), use expand_dims instead"
        )

    if isinstance(idx, slice):
        return make_slice(idx)
    if isinstance(idx, Variable) and isinstance(idx.type, SliceType):
        return idx

    # None and slices were handled above, so this must be a numerical index
    try:
        converted = as_xtensor(idx)
    except TypeError:
        converted = as_tensor(idx)

    dtype = converted.type.dtype
    if dtype == "bool":
        raise NotImplementedError("Boolean indexing not yet supported")
    if dtype not in discrete_dtypes:
        raise TypeError("Numerical indices must be integers or boolean")
    if dtype == "bool" and converted.type.ndim == 0:
        # Cannot be triggered while booleans are rejected above; kept for when
        # the boolean restriction is lifted
        raise NotImplementedError("Scalar boolean indices not supported")
    return converted
39+
40+
41+
def get_static_slice_length(slc: Variable, dim_length: None | int) -> int | None:
    """Return the length a slice selects from a dimension, if known statically.

    Parameters
    ----------
    slc
        Slice variable: a ``Constant``, a root variable, or the output of a
        node whose three inputs are start, stop and step.
    dim_length
        Static length of the dimension being sliced, or ``None`` if unknown.

    Returns
    -------
    int or None
        The number of selected entries, or ``None`` when ``dim_length`` is
        unknown, ``slc`` is a non-constant root variable, or any of
        start/stop/step is not a ``Constant``.
    """
    if dim_length is None:
        return None

    if isinstance(slc, Constant):
        const_slice = slc.data
        start, stop, step = const_slice.start, const_slice.stop, const_slice.step
    elif slc.owner is None:
        # Root variable: nothing is known about the slice it will hold
        return None
    else:
        # Output of a MakeSliceOp; every component must be constant
        start, stop, step = slc.owner.inputs
        if not all(isinstance(part, Constant) for part in (start, stop, step)):
            return None
        start, stop, step = start.data, stop.data, step.data

    return len(range(*slice(start, stop, step).indices(dim_length)))
66+
67+
68+
class Index(XOp):
    """Symbolic indexing of an xtensor variable by slices and integer indices."""

    __props__ = ()

    def make_node(self, x, *idxs):
        """Build the Apply node and infer the output dims and static shape.

        Each index is paired positionally with a dimension of ``x``:

        * slices keep the dimension, with a static length when it can be inferred;
        * xtensor (labeled) indices contribute their own dims to the output;
        * scalar tensor indices drop the dimension;
        * 1d tensor (unlabeled) indices keep the dimension with the index length.

        Dimensions of ``x`` without a matching index are kept unchanged.

        Raises
        ------
        IndexError
            If there are more indices than dimensions, if labeled indices
            disagree on the static length of a shared dim, or if an unlabeled
            index has more than one dimension.
        NotImplementedError
            If labeled and unlabeled vector indices are mixed.
        """
        x = as_xtensor(x)
        idxs = [as_idx_variable(idx) for idx in idxs]

        x_ndim = x.type.ndim
        x_dims = x.type.dims
        x_shape = x.type.shape
        n_idxs = len(idxs)
        if n_idxs > x_ndim:
            raise IndexError("Too many indices")

        out_dims = []
        out_shape = []
        has_unlabeled_vector_idx = False
        has_labeled_vector_idx = False
        for i, idx in enumerate(idxs):
            idx_type = idx.type
            if isinstance(idx_type, SliceType):
                out_dims.append(x_dims[i])
                out_shape.append(get_static_slice_length(idx, x_shape[i]))
            elif isinstance(idx_type, XTensorType):
                if has_unlabeled_vector_idx:
                    raise NotImplementedError(
                        "Mixing of labeled and unlabeled vector indexing not implemented"
                    )
                has_labeled_vector_idx = True
                for dim, idx_dim_shape in zip(
                    idx_type.dims, idx_type.shape, strict=True
                ):
                    if dim not in out_dims:
                        # New dimension
                        out_dims.append(dim)
                        out_shape.append(idx_dim_shape)
                        continue
                    # Dim already introduced in output by a previous index
                    # Update static shape or raise if incompatible
                    out_dim_pos = out_dims.index(dim)
                    out_dim_shape = out_shape[out_dim_pos]
                    if out_dim_shape is None:
                        # We don't know the size of the dimension yet
                        out_shape[out_dim_pos] = idx_dim_shape
                    elif idx_dim_shape is not None and idx_dim_shape != out_dim_shape:
                        raise IndexError(
                            f"Dimension of indexers mismatch for dim {dim}"
                        )
            else:  # TensorType
                if idx_type.ndim == 0:
                    # Scalar, dimension is dropped
                    pass
                elif idx_type.ndim == 1:
                    if has_labeled_vector_idx:
                        raise NotImplementedError(
                            "Mixing of labeled and unlabeled vector indexing not implemented"
                        )
                    has_unlabeled_vector_idx = True
                    out_dims.append(x_dims[i])
                    out_shape.append(idx_type.shape[0])
                else:
                    # Same error that xarray raises
                    raise IndexError(
                        "Unlabeled multi-dimensional array cannot be used for indexing"
                    )

        # Add any unindexed dimensions. Slicing from the number of indices
        # (rather than the last loop counter) also works when no index is given.
        out_dims.extend(x_dims[n_idxs:])
        out_shape.extend(x_shape[n_idxs:])

        output = xtensor(dtype=x.type.dtype, shape=out_shape, dims=out_dims)
        return Apply(self, [x, *idxs], [output])


index = Index()
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
import pytensor.xtensor.rewriting.basic
2+
import pytensor.xtensor.rewriting.indexing
23
import pytensor.xtensor.rewriting.shape
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from pytensor.graph import node_rewriter
2+
from pytensor.tensor import TensorType
3+
from pytensor.tensor.type_other import SliceType
4+
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
5+
from pytensor.xtensor.indexing import Index
6+
from pytensor.xtensor.rewriting.utils import register_xcanonicalize
7+
8+
9+
def is_basic_idx(idx):
    """Return whether ``idx`` is a basic index.

    A basic index is either a slice variable or a scalar (0-dimensional)
    tensor variable whose dtype is not boolean.
    """
    idx_type = idx.type
    if isinstance(idx_type, SliceType):
        return True
    return (
        isinstance(idx_type, TensorType)
        and idx_type.ndim == 0
        # dtype is a string, so it must be compared to "bool" and not the
        # builtin ``bool`` type (that comparison is always unequal)
        and idx_type.dtype != "bool"
    )
16+
17+
18+
@register_xcanonicalize
@node_rewriter(tracks=[Index])
def lower_index(fgraph, node):
    """Lower an ``Index`` node to tensor indexing when every index is basic.

    Returns a one-element list with the replacement output, or ``None`` when at
    least one index is not basic (that case is not handled here).
    """
    x, *idxs = node.inputs
    x_tensor = tensor_from_xtensor(x)
    if not all(map(is_basic_idx, idxs)):
        return None

    # Simple case: index the underlying tensor and re-attach the output dims
    indexed = x_tensor[tuple(idxs)]
    out_dims = node.outputs[0].type.dims
    return [xtensor_from_tensor(indexed, dims=out_dims)]

pytensor/xtensor/type.py

Lines changed: 63 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import warnings
2+
3+
14
try:
25
import xarray as xr
36

@@ -6,13 +9,13 @@
69
XARRAY_AVAILABLE = False
710

811
from collections.abc import Sequence
9-
from typing import TypeVar
12+
from typing import Any, Literal, TypeVar
1013

1114
import numpy as np
1215

1316
from pytensor import _as_symbolic, config
1417
from pytensor.graph import Apply, Constant
15-
from pytensor.graph.basic import Variable, OptionalApplyType
18+
from pytensor.graph.basic import OptionalApplyType, Variable
1619
from pytensor.graph.type import HasDataType, HasShape, Type
1720
from pytensor.tensor.utils import hash_from_ndarray
1821
from pytensor.utils import hash_from_code
@@ -141,17 +144,69 @@ def __getitem__(self, idx):
141144
if isinstance(idx, dict):
142145
return self.isel(idx)
143146

147+
# Check for ellipsis not in the last position (last one is useless anyway)
148+
if any(idx_item is Ellipsis for idx_item in idx):
149+
if idx.count(Ellipsis) > 1:
150+
raise IndexError("an index can only have a single ellipsis ('...')")
151+
# Convert intermediate Ellipsis to slice(None)
152+
ellipsis_loc = idx.index(Ellipsis)
153+
n_implied_none_slices = self.type.ndim - (len(idx) - 1)
154+
idx = (
155+
*idx[:ellipsis_loc],
156+
*((slice(None),) * n_implied_none_slices),
157+
*idx[ellipsis_loc + 1 :],
158+
)
159+
144160
return index(self, *idx)
145161

162+
def sel(self, *args, **kwargs):
    """Not implemented; always raises ``NotImplementedError`` pointing to ``isel``."""
    message = "sel not implemented for XTensorVariable, use isel instead"
    raise NotImplementedError(message)
146166

147-
class XTensorVariable(Variable):
148-
pass
167+
def isel(
168+
self,
169+
indexers: dict[str, Any] | None = None,
170+
drop: bool = False, # Unused by PyTensor
171+
missing_dims: Literal["raise", "warn", "ignore"] = "raise",
172+
**indexers_kwargs,
173+
):
174+
from pytensor.xtensor.indexing import index
175+
176+
if indexers_kwargs:
177+
if indexers is not None:
178+
raise ValueError(
179+
"Cannot pass both indexers and indexers_kwargs to isel"
180+
)
181+
indexers = indexers_kwargs
149182

150-
# def __str__(self):
151-
# return f"{self.__class__.__name__}{{{self.format},{self.dtype}}}"
183+
if missing_dims not in {"raise", "warn", "ignore"}:
184+
raise ValueError(
185+
f"Unrecognized options {missing_dims} for missing_dims argument"
186+
)
152187

153-
# def __repr__(self):
154-
# return str(self)
188+
# Sort indices and pass them to index
189+
dims = self.type.dims
190+
indices = [slice(None)] * self.type.ndim
191+
for key, idx in indexers.items():
192+
if idx is Ellipsis:
193+
# Xarray raises a less informative error, suggesting indices must be integer
194+
# But slices are also fine
195+
raise TypeError("Ellipsis (...) is an invalid labeled index")
196+
try:
197+
indices[dims.index(key)] = idx
198+
except IndexError:
199+
if missing_dims == "raise":
200+
raise ValueError(
201+
f"Dimension {key} does not exist. Expected one of {dims}"
202+
)
203+
elif missing_dims == "warn":
204+
warnings.warn(
205+
UserWarning,
206+
f"Dimension {key} does not exist. Expected one of {dims}",
207+
)
208+
209+
return index(self, *indices)
155210

156211

157212
class XTensorConstantSignature(tuple):

tests/xtensor/test_indexing.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import numpy as np
2+
import pytest
3+
from xarray import DataArray
4+
from xtensor.util import xr_assert_allclose, xr_function
5+
6+
from pytensor.xtensor import xtensor
7+
8+
9+
@pytest.mark.parametrize(
    "indices",
    [
        (0,),
        (slice(1, None),),
        (slice(None, -1),),
        (slice(None, None, -1),),
        (0, slice(None), -1, slice(1, None)),
        (..., 0, -1),
        (0, ..., -1),
        (0, -1, ...),
    ],
)
@pytest.mark.parametrize("labeled", (False, True), ids=["unlabeled", "labeled"])
def test_basic_indexing(labeled, indices):
    """Compare basic indexing of an xtensor against xarray on the same data.

    In the labeled case the positional indices are assigned to randomly
    permuted dimension names and passed as a dict.
    """
    if ... in indices and labeled:
        pytest.skip("Ellipsis not supported with labeled indexing")

    dims = ("a", "b", "c", "d")
    x = xtensor(dims=dims, shape=(2, 3, 5, 7))

    if labeled:
        shuffled_dims = tuple(np.random.permutation(dims))
        indices = dict(zip(shuffled_dims, indices, strict=False))
    out = x[indices]

    fn = xr_function([x], out)
    static_shape = x.type.shape
    x_test_values = np.arange(np.prod(static_shape), dtype=x.type.dtype).reshape(
        static_shape
    )
    x_test = DataArray(x_test_values, dims=x.type.dims)
    xr_assert_allclose(fn(x_test), x_test[indices])

0 commit comments

Comments
 (0)