Commit 2742594

WIP Implement index operations on XTensorTypes
1 parent 8c2d953 commit 2742594

6 files changed: +263 −9 lines


pytensor/xtensor/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -2,7 +2,6 @@
 
 import pytensor.xtensor.rewriting
 from pytensor.xtensor.type import (
-    XTensorType,
     as_xtensor,
     as_xtensor_variable,
     xtensor,

pytensor/xtensor/indexing.py

Lines changed: 142 additions & 0 deletions
@@ -0,0 +1,142 @@
# HERE LIE DRAGONS
# Useful links to make sense of all the numpy/xarray complexity
# https://numpy.org/devdocs//user/basics.indexing.html
# https://numpy.org/neps/nep-0021-advanced-indexing.html
# https://docs.xarray.dev/en/latest/user-guide/indexing.html
# https://tutorial.xarray.dev/intermediate/indexing/advanced-indexing.html

from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.scalar.basic import discrete_dtypes
from pytensor.tensor.basic import as_tensor
from pytensor.tensor.type_other import NoneTypeT, SliceType, make_slice
from pytensor.xtensor.basic import XOp
from pytensor.xtensor.type import XTensorType, as_xtensor, xtensor


def as_idx_variable(idx):
    if idx is None or (isinstance(idx, Variable) and isinstance(idx.type, NoneTypeT)):
        raise TypeError(
            "XTensors do not support indexing with None (np.newaxis), use expand_dims instead"
        )
    if isinstance(idx, slice):
        idx = make_slice(idx)
    elif isinstance(idx, Variable) and isinstance(idx.type, SliceType):
        pass
    else:
        # Must be integer indices, we already accounted for None and slices
        try:
            idx = as_xtensor(idx)
        except TypeError:
            idx = as_tensor(idx)
        if idx.type.dtype == "bool":
            raise NotImplementedError("Boolean indexing not yet supported")
        if idx.type.dtype not in discrete_dtypes:
            raise TypeError("Numerical indices must be integers or boolean")
        if idx.type.dtype == "bool" and idx.type.ndim == 0:
            # This can't be triggered right now, but will once we lift the boolean restriction
            raise NotImplementedError("Scalar boolean indices not supported")
    return idx


def get_static_slice_length(slc: Variable, dim_length: None | int) -> int | None:
    if dim_length is None:
        return None
    if isinstance(slc, Constant):
        d = slc.data
        start, stop, step = d.start, d.stop, d.step
    elif slc.owner is None:
        # It's a root variable, no way of knowing what we're getting
        return None
    else:
        # It's a MakeSliceOp
        start, stop, step = slc.owner.inputs
        if isinstance(start, Constant):
            start = start.data
        else:
            return None
        if isinstance(stop, Constant):
            stop = stop.data
        else:
            return None
        if isinstance(step, Constant):
            step = step.data
        else:
            return None
    return len(range(*slice(start, stop, step).indices(dim_length)))


class Index(XOp):
    __props__ = ()

    def make_node(self, x, *idxs):
        x = as_xtensor(x)
        idxs = [as_idx_variable(idx) for idx in idxs]

        x_ndim = x.type.ndim
        x_dims = x.type.dims
        x_shape = x.type.shape
        out_dims = []
        out_shape = []
        has_unlabeled_vector_idx = False
        has_labeled_vector_idx = False
        for i, idx in enumerate(idxs):
            if i == x_ndim:
                raise IndexError("Too many indices")
            if isinstance(idx.type, SliceType):
                out_dims.append(x_dims[i])
                out_shape.append(get_static_slice_length(idx, x_shape[i]))
            elif isinstance(idx.type, XTensorType):
                if has_unlabeled_vector_idx:
                    raise NotImplementedError(
                        "Mixing of labeled and unlabeled vector indexing not implemented"
                    )
                has_labeled_vector_idx = True
                idx_dims = idx.type.dims
                for dim in idx_dims:
                    idx_dim_shape = idx.type.shape[idx_dims.index(dim)]
                    if dim in out_dims:
                        # Dim already introduced in output by a previous index
                        # Update static shape or raise if incompatible
                        out_dim_pos = out_dims.index(dim)
                        out_dim_shape = out_shape[out_dim_pos]
                        if out_dim_shape is None:
                            # We don't know the size of the dimension yet
                            out_shape[out_dim_pos] = idx_dim_shape
                        elif (
                            idx_dim_shape is not None and idx_dim_shape != out_dim_shape
                        ):
                            raise IndexError(
                                f"Dimension of indexers mismatch for dim {dim}"
                            )
                    else:
                        # New dimension
                        out_dims.append(dim)
                        out_shape.append(idx_dim_shape)

            else:  # TensorType
                if idx.type.ndim == 0:
                    # Scalar, dimension is dropped
                    pass
                elif idx.type.ndim == 1:
                    if has_labeled_vector_idx:
                        raise NotImplementedError(
                            "Mixing of labeled and unlabeled vector indexing not implemented"
                        )
                    has_unlabeled_vector_idx = True
                    out_dims.append(x_dims[i])
                    out_shape.append(idx.type.shape[0])
                else:
                    # Same error that xarray raises
                    raise IndexError(
                        "Unlabeled multi-dimensional array cannot be used for indexing"
                    )
        for j in range(i + 1, x_ndim):
            # Add any unindexed dimensions
            out_dims.append(x_dims[j])
            out_shape.append(x_shape[j])

        output = xtensor(dtype=x.type.dtype, shape=out_shape, dims=out_dims)
        return Apply(self, [x, *idxs], [output])


index = Index()
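
As a rough sketch of the dim/shape inference that `Index.make_node` performs (illustrative only, assuming the `xtensor` and `index` helpers behave exactly as defined in this commit):

from pytensor.xtensor.indexing import index
from pytensor.xtensor.type import xtensor

x = xtensor(dims=("a", "b"), shape=(5, 7))

# A scalar index drops "a"; the slice keeps "b" with a statically known length
out = index(x, 0, slice(1, None))
print(out.type.dims, out.type.shape)  # expected: ("b",) (6,)

# A labeled integer vector introduces its own dim; unindexed dims are appended
ix = xtensor(dims=("points",), shape=(4,), dtype="int64")
out = index(x, ix)
print(out.type.dims, out.type.shape)  # expected: ("points", "b") (4, 7)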

pytensor/xtensor/rewriting/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,2 +1,3 @@
 import pytensor.xtensor.rewriting.basic
+import pytensor.xtensor.rewriting.indexing
 import pytensor.xtensor.rewriting.shape

pytensor/xtensor/rewriting/indexing.py

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
from pytensor.graph import node_rewriter
from pytensor.tensor import TensorType
from pytensor.tensor.type_other import SliceType
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
from pytensor.xtensor.indexing import Index
from pytensor.xtensor.rewriting.utils import register_xcanonicalize


def is_basic_idx(idx):
    return isinstance(idx.type, SliceType) or (
        isinstance(idx.type, TensorType)
        and idx.type.ndim == 0
        and idx.type.dtype != "bool"
    )


@register_xcanonicalize
@node_rewriter(tracks=[Index])
def lower_index(fgraph, node):
    x, *idxs = node.inputs
    x_tensor = tensor_from_xtensor(x)
    if all(is_basic_idx(idx) for idx in idxs):
        # Simple case
        x_tensor_indexed = x_tensor[tuple(idxs)]
        new_out = xtensor_from_tensor(x_tensor_indexed, dims=node.outputs[0].type.dims)
        return [new_out]
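
For the basic-index case, the rewrite simply forwards the indices to the unwrapped tensor and re-attaches the inferred dims. A hand-written equivalent of what `lower_index` builds (a sketch, assuming the helpers behave as in `pytensor.xtensor.basic`):

from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
from pytensor.xtensor.type import xtensor

x = xtensor(dims=("a", "b"), shape=(2, 3))

# Index the underlying TensorVariable, then wrap the result back up with the
# dims that Index.make_node infers for x[0, 1:] (only "b" survives).
x_tensor_indexed = tensor_from_xtensor(x)[0, 1:]
lowered = xtensor_from_tensor(x_tensor_indexed, dims=("b",))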

pytensor/xtensor/type.py

Lines changed: 56 additions & 8 deletions
@@ -1,3 +1,6 @@
+import warnings
+
+
 try:
     import xarray as xr
 
@@ -6,13 +9,13 @@
     XARRAY_AVAILABLE = False
 
 from collections.abc import Sequence
-from typing import TypeVar
+from typing import Any, Literal, TypeVar
 
 import numpy as np
 
 from pytensor import _as_symbolic, config
 from pytensor.graph import Apply, Constant
-from pytensor.graph.basic import Variable, OptionalApplyType
+from pytensor.graph.basic import OptionalApplyType, Variable
 from pytensor.graph.type import HasDataType, HasShape, Type
 from pytensor.tensor.utils import hash_from_ndarray
 from pytensor.utils import hash_from_code
@@ -141,17 +144,62 @@ def __getitem__(self, idx):
         if isinstance(idx, dict):
             return self.isel(idx)
 
+        if any(idx_item is Ellipsis for idx_item in idx):
+            # TODO: Convert Ellipsis to slice(None)
+            raise NotImplementedError(
+                "Ellipsis (...) is not yet supported for indexing"
+            )
+
         return index(self, *idx)
 
+    def sel(self, *args, **kwargs):
+        raise NotImplementedError(
+            "sel not implemented for XTensorVariable, use isel instead"
+        )
 
-class XTensorVariable(Variable):
-    pass
+    def isel(
+        self,
+        indexers: dict[str, Any] | None = None,
+        drop: bool = False,  # Unused by PyTensor
+        missing_dims: Literal["raise", "warn", "ignore"] = "raise",
+        **indexers_kwargs,
+    ):
+        from pytensor.xtensor.indexing import index
+
+        if indexers_kwargs:
+            if indexers is not None:
+                raise ValueError(
+                    "Cannot pass both indexers and indexers_kwargs to isel"
+                )
+            indexers = indexers_kwargs
 
-    # def __str__(self):
-    #     return f"{self.__class__.__name__}{{{self.format},{self.dtype}}}"
+        if missing_dims not in {"raise", "warn", "ignore"}:
+            raise ValueError(
+                f"Unrecognized options {missing_dims} for missing_dims argument"
+            )
 
-    # def __repr__(self):
-    #     return str(self)
+        # Sort indices and pass them to index
+        dims = self.type.dims
+        indices = [slice(None) for _ in range(self.type.ndim)]
+        for key, idx in indexers.items():
+            if idx is Ellipsis:
+                # Xarray raises a less informative error, suggesting indices must be integer
+                # But slices are also fine
+                raise TypeError("Ellipsis (...) is an invalid labeled index")
+            try:
+                indices[dims.index(key)] = idx
+            except ValueError:
+                if missing_dims == "raise":
+                    raise ValueError(
+                        f"Dimension {key} does not exist. Expected one of {dims}"
+                    )
                elif missing_dims == "warn":
+                    warnings.warn(
+                        f"Dimension {key} does not exist. Expected one of {dims}",
+                        UserWarning,
+                    )
+
+        return index(self, *indices)
 
 
 class XTensorConstantSignature(tuple):
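
A brief sketch of how the new `isel` and `__getitem__` entry points are meant to be used (illustrative only; assumes the default dtype and the `index` Op defined in `pytensor.xtensor.indexing`):

from pytensor.xtensor import xtensor

x = xtensor(dims=("a", "b"), shape=(2, 3))

out1 = x.isel({"b": 0})          # dict form: "b" is dropped, "a" is kept
out2 = x.isel(b=0)               # keyword form, equivalent to the dict form
out3 = x[{"a": slice(1, None)}]  # __getitem__ with a dict delegates to isel
out4 = x[0, slice(None)]         # positional indexing goes straight to index()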

tests/xtensor/test_indexing.py

Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
import numpy as np
import pytest
from xarray import DataArray
from xtensor.util import xr_assert_allclose, xr_function

from pytensor.xtensor import xtensor


@pytest.mark.parametrize(
    "indices",
    [
        (0,),
        (slice(1, None),),
        (slice(None, -1),),
        (slice(None, None, -1),),
        (0, slice(None), -1, slice(1, None)),
        (..., 0, -1),
    ],
)
@pytest.mark.parametrize("labeled", (False, True), ids=["unlabeled", "labeled"])
def test_basic_indexing(labeled, indices):
    dims = ("a", "b", "c", "d")
    x = xtensor(dims=dims, shape=(2, 3, 5, 7))

    if labeled:
        shuffled_dims = tuple(np.random.permutation(dims))
        indices = dict(zip(shuffled_dims, indices, strict=False))
    out = x[indices]

    fn = xr_function([x], out)
    x_test_values = np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape(
        x.type.shape
    )
    x_test = DataArray(x_test_values, dims=x.type.dims)
    res = fn(x_test)
    expected_res = x_test[indices]
    xr_assert_allclose(res, expected_res)
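
In the labeled variant, the tuple of positional indices is re-keyed by a random permutation of the dims; `zip(..., strict=False)` lets a shorter index tuple pair with only a prefix of the shuffled dims. For example (hypothetical permutation shown purely for illustration):

dims = ("a", "b", "c", "d")
indices = (0, slice(1, None))
shuffled_dims = ("c", "a", "b", "d")  # one possible permutation
labeled_indices = dict(zip(shuffled_dims, indices, strict=False))
# labeled_indices == {"c": 0, "a": slice(1, None)}; x[labeled_indices] routes through isel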
