Skip to content

Commit 74f66a9

Browse files
committed
WIP Implement index operations for XTensorVariables
1 parent 4cb2392 commit 74f66a9

File tree

6 files changed

+361
-3
lines changed

6 files changed

+361
-3
lines changed

pytensor/xtensor/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
)
88
from pytensor.xtensor.shape import concat
99
from pytensor.xtensor.type import (
10-
XTensorType,
1110
as_xtensor,
1211
xtensor,
1312
xtensor_constant,

pytensor/xtensor/indexing.py

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
# HERE LIE DRAGONS
2+
# Uselful links to make sense of all the numpy/xarray complexity
3+
# https://numpy.org/devdocs//user/basics.indexing.html
4+
# https://numpy.org/neps/nep-0021-advanced-indexing.html
5+
# https://docs.xarray.dev/en/latest/user-guide/indexing.html
6+
# https://tutorial.xarray.dev/intermediate/indexing/advanced-indexing.html
7+
8+
from pytensor.graph.basic import Apply, Constant, Variable
9+
from pytensor.scalar.basic import discrete_dtypes
10+
from pytensor.tensor.basic import as_tensor
11+
from pytensor.tensor.type_other import NoneTypeT, SliceType, make_slice
12+
from pytensor.xtensor.basic import XOp
13+
from pytensor.xtensor.type import XTensorType, as_xtensor, xtensor
14+
15+
16+
def as_idx_variable(idx):
17+
if idx is None or (isinstance(idx, Variable) and isinstance(idx.type, NoneTypeT)):
18+
raise TypeError(
19+
"XTensors do not support indexing with None (np.newaxis), use expand_dims instead"
20+
)
21+
if isinstance(idx, slice):
22+
idx = make_slice(idx)
23+
elif isinstance(idx, Variable) and isinstance(idx.type, SliceType):
24+
pass
25+
else:
26+
# Must be integer indices, we already counted for None and slices
27+
try:
28+
idx = as_tensor(idx)
29+
except TypeError:
30+
idx = as_xtensor(idx)
31+
if idx.type.dtype == "bool":
32+
raise NotImplementedError("Boolean indexing not yet supported")
33+
if idx.type.dtype not in discrete_dtypes:
34+
raise TypeError("Numerical indices must be integers or boolean")
35+
if idx.type.dtype == "bool" and idx.type.ndim == 0:
36+
# This can't be triggered right now, but will once we lift the boolean restriction
37+
raise NotImplementedError("Scalar boolean indices not supported")
38+
return idx
39+
40+
41+
def get_static_slice_length(slc: Variable, dim_length: None | int) -> int | None:
42+
if dim_length is None:
43+
return None
44+
if isinstance(slc, Constant):
45+
d = slc.data
46+
start, stop, step = d.start, d.stop, d.step
47+
elif slc.owner is None:
48+
# It's a root variable no way of knowing what we're getting
49+
return None
50+
else:
51+
# It's a MakeSliceOp
52+
start, stop, step = slc.owner.inputs
53+
if isinstance(start, Constant):
54+
start = start.data
55+
else:
56+
return None
57+
if isinstance(stop, Constant):
58+
stop = stop.data
59+
else:
60+
return None
61+
if isinstance(step, Constant):
62+
step = step.data
63+
else:
64+
return None
65+
return len(range(*slice(start, stop, step).indices(dim_length)))
66+
67+
68+
class Index(XOp):
69+
__props__ = ()
70+
71+
def make_node(self, x, *idxs):
72+
x = as_xtensor(x)
73+
idxs = [as_idx_variable(idx) for idx in idxs]
74+
75+
x_ndim = x.type.ndim
76+
x_dims = x.type.dims
77+
x_shape = x.type.shape
78+
out_dims = []
79+
out_shape = []
80+
has_unlabeled_vector_idx = False
81+
has_labeled_vector_idx = False
82+
for i, idx in enumerate(idxs):
83+
if i == x_ndim:
84+
raise IndexError("Too many indices")
85+
if isinstance(idx.type, SliceType):
86+
out_dims.append(x_dims[i])
87+
out_shape.append(get_static_slice_length(idx, x_shape[i]))
88+
elif isinstance(idx.type, XTensorType):
89+
if has_unlabeled_vector_idx:
90+
raise NotImplementedError(
91+
"Mixing of labeled and unlabeled vector indexing not implemented"
92+
)
93+
has_labeled_vector_idx = True
94+
idx_dims = idx.type.dims
95+
for dim in idx_dims:
96+
idx_dim_shape = idx.type.shape[idx_dims.index(dim)]
97+
if dim in out_dims:
98+
# Dim already introduced in output by a previous index
99+
# Update static shape or raise if incompatible
100+
out_dim_pos = out_dims.index(dim)
101+
out_dim_shape = out_shape[out_dim_pos]
102+
if out_dim_shape is None:
103+
# We don't know the size of the dimension yet
104+
out_shape[out_dim_pos] = idx_dim_shape
105+
elif (
106+
idx_dim_shape is not None and idx_dim_shape != out_dim_shape
107+
):
108+
raise IndexError(
109+
f"Dimension of indexers mismatch for dim {dim}"
110+
)
111+
else:
112+
# New dimension
113+
out_dims.append(dim)
114+
out_shape.append(idx_dim_shape)
115+
116+
else: # TensorType
117+
if idx.type.ndim == 0:
118+
# Scalar, dimension is dropped
119+
pass
120+
elif idx.type.ndim == 1:
121+
if has_labeled_vector_idx:
122+
raise NotImplementedError(
123+
"Mixing of labeled and unlabeled vector indexing not implemented"
124+
)
125+
has_unlabeled_vector_idx = True
126+
out_dims.append(x_dims[i])
127+
out_shape.append(idx.type.shape[0])
128+
else:
129+
# Same error that xarray raises
130+
raise IndexError(
131+
"Unlabeled multi-dimensional array cannot be used for indexing"
132+
)
133+
for j in range(i + 1, x_ndim):
134+
# Add any unindexed dimensions
135+
out_dims.append(x_dims[j])
136+
out_shape.append(x_shape[j])
137+
138+
output = xtensor(dtype=x.type.dtype, shape=out_shape, dims=out_dims)
139+
return Apply(self, [x, *idxs], [output])
140+
141+
142+
index = Index()
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import pytensor.xtensor.rewriting.basic
2+
import pytensor.xtensor.rewriting.indexing
23
import pytensor.xtensor.rewriting.reduction
34
import pytensor.xtensor.rewriting.shape
45
import pytensor.xtensor.rewriting.vectorization
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
from pytensor.graph import Constant, node_rewriter
2+
from pytensor.tensor import TensorType, specify_shape
3+
from pytensor.tensor.type_other import NoneTypeT, SliceType
4+
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
5+
from pytensor.xtensor.indexing import Index
6+
from pytensor.xtensor.rewriting.utils import register_xcanonicalize
7+
from pytensor.xtensor.type import XTensorType
8+
9+
10+
def to_basic_idx(idx):
11+
if isinstance(idx.type, SliceType):
12+
if isinstance(idx, Constant):
13+
return idx.data
14+
elif idx.owner:
15+
# MakeSlice Op
16+
# We transform NoneConsts to regular None so that basic Subtensor can be used if possible
17+
return slice(
18+
*[
19+
None if isinstance(i.type, NoneTypeT) else i
20+
for i in idx.owner.inputs
21+
]
22+
)
23+
else:
24+
return idx
25+
if (
26+
isinstance(idx.type, XTensorType | TensorType)
27+
and idx.type.ndim == 0
28+
and idx.type.dtype != bool
29+
):
30+
return idx
31+
raise TypeError("Cannot convert idx to basic idx")
32+
33+
34+
def _count_idx_types(idxs):
35+
basic, vector, xvector = 0, 0, 0
36+
for idx in idxs:
37+
if isinstance(idx.type, SliceType):
38+
basic += 1
39+
elif idx.type.ndim == 0:
40+
basic += 1
41+
elif isinstance(idx.type, TensorType):
42+
vector += 1
43+
else:
44+
xvector += 1
45+
return basic, vector, xvector
46+
47+
48+
@register_xcanonicalize
49+
@node_rewriter(tracks=[Index])
50+
def lower_index(fgraph, node):
51+
x, *idxs = node.inputs
52+
[out] = node.outputs
53+
x_tensor = tensor_from_xtensor(x)
54+
n_basic, n_vector, n_xvector = _count_idx_types(idxs)
55+
if n_xvector == 0 and n_vector == 0:
56+
x_tensor_indexed = x_tensor[tuple(to_basic_idx(idx) for idx in idxs)]
57+
elif n_vector == 1 and n_xvector == 0:
58+
# Special case for single vector index, no orthogonal indexing
59+
x_tensor_indexed = x_tensor[tuple(idxs)]
60+
else:
61+
# Not yet implemented
62+
return None
63+
64+
# Add lost shape if any
65+
x_tensor_indexed = specify_shape(x_tensor_indexed, out.type.shape)
66+
new_out = xtensor_from_tensor(x_tensor_indexed, dims=out.type.dims)
67+
return [new_out]

pytensor/xtensor/type.py

Lines changed: 109 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import warnings
2+
13
from pytensor.tensor import TensorType
24
from pytensor.tensor.math import variadic_mul
35

@@ -10,7 +12,7 @@
1012
XARRAY_AVAILABLE = False
1113

1214
from collections.abc import Sequence
13-
from typing import TypeVar
15+
from typing import Any, Literal, TypeVar
1416

1517
import numpy as np
1618

@@ -339,7 +341,112 @@ def sel(self, *args, **kwargs):
339341
raise NotImplementedError("sel not implemented for XTensorVariable")
340342

341343
def __getitem__(self, idx):
342-
raise NotImplementedError("Indexing not yet implemnented")
344+
if isinstance(idx, dict):
345+
return self.isel(idx)
346+
347+
# Check for ellipsis not in the last position (last one is useless anyway)
348+
if any(idx_item is Ellipsis for idx_item in idx):
349+
if idx.count(Ellipsis) > 1:
350+
raise IndexError("an index can only have a single ellipsis ('...')")
351+
# Convert intermediate Ellipsis to slice(None)
352+
ellipsis_loc = idx.index(Ellipsis)
353+
n_implied_none_slices = self.type.ndim - (len(idx) - 1)
354+
idx = (
355+
*idx[:ellipsis_loc],
356+
*((slice(None),) * n_implied_none_slices),
357+
*idx[ellipsis_loc + 1 :],
358+
)
359+
360+
return px.indexing.index(self, *idx)
361+
362+
def isel(
363+
self,
364+
indexers: dict[str, Any] | None = None,
365+
drop: bool = False, # Unused by PyTensor
366+
missing_dims: Literal["raise", "warn", "ignore"] = "raise",
367+
**indexers_kwargs,
368+
):
369+
if indexers_kwargs:
370+
if indexers is not None:
371+
raise ValueError(
372+
"Cannot pass both indexers and indexers_kwargs to isel"
373+
)
374+
indexers = indexers_kwargs
375+
376+
if missing_dims not in {"raise", "warn", "ignore"}:
377+
raise ValueError(
378+
f"Unrecognized options {missing_dims} for missing_dims argument"
379+
)
380+
381+
# Sort indices and pass them to index
382+
dims = self.type.dims
383+
indices = [slice(None)] * self.type.ndim
384+
for key, idx in indexers.items():
385+
if idx is Ellipsis:
386+
# Xarray raises a less informative error, suggesting indices must be integer
387+
# But slices are also fine
388+
raise TypeError("Ellipsis (...) is an invalid labeled index")
389+
try:
390+
indices[dims.index(key)] = idx
391+
except IndexError:
392+
if missing_dims == "raise":
393+
raise ValueError(
394+
f"Dimension {key} does not exist. Expected one of {dims}"
395+
)
396+
elif missing_dims == "warn":
397+
warnings.warn(
398+
UserWarning,
399+
f"Dimension {key} does not exist. Expected one of {dims}",
400+
)
401+
402+
return px.indexing.index(self, *indices)
403+
404+
def _head_tail_or_thin(
405+
self,
406+
indexers: dict[str, Any] | int | None,
407+
indexers_kwargs: dict[str, Any],
408+
*,
409+
kind: Literal["head", "tail", "thin"],
410+
):
411+
if indexers_kwargs:
412+
if indexers is not None:
413+
raise ValueError(
414+
"Cannot pass both indexers and indexers_kwargs to head"
415+
)
416+
indexers = indexers_kwargs
417+
418+
if indexers is None:
419+
if kind == "thin":
420+
raise TypeError(
421+
"thin() indexers must be either dict-like or a single integer"
422+
)
423+
else:
424+
# Default to 5 for head and tail
425+
indexers = {dim: 5 for dim in self.type.dims}
426+
427+
elif not isinstance(indexers, dict):
428+
indexers = {dim: indexers for dim in self.type.dims}
429+
430+
if kind == "head":
431+
indices = {dim: slice(None, value) for dim, value in indexers.items()}
432+
elif kind == "tail":
433+
sizes = self.sizes
434+
# Can't use slice(-value, None), in case value is zero
435+
indices = {
436+
dim: slice(sizes[dim] - value, None) for dim, value in indexers.items()
437+
}
438+
elif kind == "thin":
439+
indices = {dim: slice(None, None, value) for dim, value in indexers.items()}
440+
return self.isel(indices)
441+
442+
def head(self, indexers: dict[str, Any] | int | None = None, **indexers_kwargs):
443+
return self._head_tail_or_thin(indexers, indexers_kwargs, kind="head")
444+
445+
def tail(self, indexers: dict[str, Any] | int | None = None, **indexers_kwargs):
446+
return self._head_tail_or_thin(indexers, indexers_kwargs, kind="tail")
447+
448+
def thin(self, indexers: dict[str, Any] | int | None = None, **indexers_kwargs):
449+
return self._head_tail_or_thin(indexers, indexers_kwargs, kind="thin")
343450

344451
# ndarray methods
345452
# https://docs.xarray.dev/en/latest/api.html#id7

tests/xtensor/test_indexing.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import numpy as np
2+
import pytest
3+
from xarray import DataArray
4+
5+
from pytensor.xtensor import xtensor
6+
from tests.xtensor.util import xr_assert_allclose, xr_function
7+
8+
9+
@pytest.mark.parametrize(
10+
"indices",
11+
[
12+
(0,),
13+
(slice(1, None),),
14+
(slice(None, -1),),
15+
(slice(None, None, -1),),
16+
(0, slice(None), -1, slice(1, None)),
17+
(..., 0, -1),
18+
(0, ..., -1),
19+
(0, -1, ...),
20+
],
21+
)
22+
@pytest.mark.parametrize("labeled", (False, True), ids=["unlabeled", "labeled"])
23+
def test_basic_indexing(labeled, indices):
24+
if ... in indices and labeled:
25+
pytest.skip("Ellipsis not supported with labeled indexing")
26+
27+
dims = ("a", "b", "c", "d")
28+
x = xtensor(dims=dims, shape=(2, 3, 5, 7))
29+
30+
if labeled:
31+
shufled_dims = tuple(np.random.permutation(dims))
32+
indices = dict(zip(shufled_dims, indices, strict=False))
33+
out = x[indices]
34+
35+
fn = xr_function([x], out)
36+
x_test_values = np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape(
37+
x.type.shape
38+
)
39+
x_test = DataArray(x_test_values, dims=x.type.dims)
40+
res = fn(x_test)
41+
expected_res = x_test[indices]
42+
xr_assert_allclose(res, expected_res)

0 commit comments

Comments
 (0)