Commit 30e1a42

WIP Implement index operations for XTensorVariables

1 parent 5a7b23c

6 files changed (+476 -3 lines)

pytensor/xtensor/__init__.py

Lines changed: 0 additions & 1 deletion

@@ -7,7 +7,6 @@
 )
 from pytensor.xtensor.shape import concat
 from pytensor.xtensor.type import (
-    XTensorType,
     as_xtensor,
     xtensor,
     xtensor_constant,

pytensor/xtensor/indexing.py

Lines changed: 142 additions & 0 deletions

@@ -0,0 +1,142 @@
# HERE LIE DRAGONS
# Useful links to make sense of all the numpy/xarray complexity
# https://numpy.org/devdocs//user/basics.indexing.html
# https://numpy.org/neps/nep-0021-advanced-indexing.html
# https://docs.xarray.dev/en/latest/user-guide/indexing.html
# https://tutorial.xarray.dev/intermediate/indexing/advanced-indexing.html

from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.scalar.basic import discrete_dtypes
from pytensor.tensor import TensorType
from pytensor.tensor.basic import as_tensor
from pytensor.tensor.type_other import NoneTypeT, SliceType, make_slice
from pytensor.xtensor.basic import XOp, xtensor_from_tensor
from pytensor.xtensor.type import XTensorType, as_xtensor, xtensor


def as_idx_variable(idx):
    if idx is None or (isinstance(idx, Variable) and isinstance(idx.type, NoneTypeT)):
        raise TypeError(
            "XTensors do not support indexing with None (np.newaxis), use expand_dims instead"
        )
    if isinstance(idx, slice):
        idx = make_slice(idx)
    elif isinstance(idx, Variable) and isinstance(idx.type, SliceType):
        pass
    elif isinstance(idx, tuple) and len(idx) == 2 and isinstance(idx[0], str):
        # Special case for ("x", array) that xarray supports
        # TODO: Check if this can be used to rename existing xarray dimensions or only for numpy
        dim, idx = idx
        idx = xtensor_from_tensor(as_tensor(idx), dims=(dim,))
    else:
        # Must be integer indices; we already accounted for None and slices
        try:
            idx = as_xtensor(idx)
        except TypeError:
            idx = as_tensor(idx)
        if idx.type.dtype == "bool":
            raise NotImplementedError("Boolean indexing not yet supported")
        if idx.type.dtype not in discrete_dtypes:
            raise TypeError("Numerical indices must be integers or boolean")
        if idx.type.dtype == "bool" and idx.type.ndim == 0:
            # This can't be triggered right now, but will once we lift the boolean restriction
            raise NotImplementedError("Scalar boolean indices not supported")
    return idx

def get_static_slice_length(slc: Variable, dim_length: None | int) -> int | None:
    if dim_length is None:
        return None
    if isinstance(slc, Constant):
        d = slc.data
        start, stop, step = d.start, d.stop, d.step
    elif slc.owner is None:
        # It's a root variable; no way of knowing what we're getting
        return None
    else:
        # It's a MakeSliceOp
        start, stop, step = slc.owner.inputs
        if isinstance(start, Constant):
            start = start.data
        else:
            return None
        if isinstance(stop, Constant):
            stop = stop.data
        else:
            return None
        if isinstance(step, Constant):
            step = step.data
        else:
            return None
    return len(range(*slice(start, stop, step).indices(dim_length)))

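The final len(range(...)) line leans on Python's own slice arithmetic: slice.indices(dim_length) clamps start, stop, and step against the known dimension length, so the length of the resulting range is exactly the number of elements the slice selects. A quick illustration in plain Python:

>>> slice(None, None, 2).indices(5)  # start/stop/step resolved against length 5
(0, 5, 2)
>>> len(range(0, 5, 2))  # elements [0, 2, 4]
3
>>> len(range(*slice(-2, None).indices(5)))  # negative start normalized to 3
2
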
class Index(XOp):
    __props__ = ()

    def make_node(self, x, *idxs):
        x = as_xtensor(x)
        idxs = [as_idx_variable(idx) for idx in idxs]

        x_ndim = x.type.ndim
        x_dims = x.type.dims
        x_shape = x.type.shape
        out_dims = []
        out_shape = []
        for i, idx in enumerate(idxs):
            if i == x_ndim:
                raise IndexError("Too many indices")
            if isinstance(idx.type, SliceType):
                out_dims.append(x_dims[i])
                out_shape.append(get_static_slice_length(idx, x_shape[i]))
            else:
                if idx.type.ndim == 0:
                    # Scalar index, dimension is dropped
                    continue

                if isinstance(idx.type, TensorType):
                    if idx.type.ndim > 1:
                        # Same error that xarray raises
                        raise IndexError(
                            "Unlabeled multi-dimensional array cannot be used for indexing"
                        )

                    # This is implicitly an XTensorVariable with dim matching the indexed one
                    idx = idxs[i] = xtensor_from_tensor(idx, dims=(x_dims[i],))

                assert isinstance(idx.type, XTensorType)

                idx_dims = idx.type.dims
                for dim in idx_dims:
                    idx_dim_shape = idx.type.shape[idx_dims.index(dim)]
                    if dim in out_dims:
                        # Dim already introduced in output by a previous index
                        # Update static shape or raise if incompatible
                        out_dim_pos = out_dims.index(dim)
                        out_dim_shape = out_shape[out_dim_pos]
                        if out_dim_shape is None:
                            # We don't know the size of the dimension yet
                            out_shape[out_dim_pos] = idx_dim_shape
                        elif (
                            idx_dim_shape is not None and idx_dim_shape != out_dim_shape
                        ):
                            raise IndexError(
                                f"Dimension of indexers mismatch for dim {dim}"
                            )
                    else:
                        # New dimension
                        out_dims.append(dim)
                        out_shape.append(idx_dim_shape)

        for dim_i, shape_i in zip(x_dims[i + 1 :], x_shape[i + 1 :]):
            # Add back any unindexed dimensions
            if dim_i not in out_dims:
                # If the dimension was not indexed, we keep it as is
                out_dims.append(dim_i)
                out_shape.append(shape_i)

        output = xtensor(dtype=x.type.dtype, shape=out_shape, dims=out_dims)
        return Apply(self, [x, *idxs], [output])


index = Index()
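
Taken together, make_node mirrors xarray's indexing rules: a scalar index drops the indexed dimension, a slice keeps it (with the static length computed above), an unlabeled 1-d tensor implicitly reuses the name of the dimension it indexes, and a labeled index contributes its own dimensions, with repeated labels checked for compatible static shapes. A minimal sketch of the resulting output types, assuming the keyword form of xtensor() used inside make_node:

# Sketch of the inferred dims/shapes (illustrative, not part of the commit)
from pytensor.xtensor.indexing import index
from pytensor.xtensor.type import xtensor

x = xtensor(dtype="float64", shape=(5, 4), dims=("a", "b"))

index(x, 1).type.dims             # ("b",): scalar index drops "a"
index(x, slice(1, 3)).type.shape  # (2, 4): slice keeps "a", length inferred

i = xtensor(dtype="int64", shape=(3,), dims=("new",))
index(x, i).type.dims             # ("new", "b"): "a" replaced by "new"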
pytensor/xtensor/rewriting/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -1,4 +1,5 @@
 import pytensor.xtensor.rewriting.basic
+import pytensor.xtensor.rewriting.indexing
 import pytensor.xtensor.rewriting.reduction
 import pytensor.xtensor.rewriting.shape
 import pytensor.xtensor.rewriting.vectorization
pytensor/xtensor/rewriting/indexing.py

Lines changed: 102 additions & 0 deletions

@@ -0,0 +1,102 @@
from itertools import zip_longest

from pytensor import as_symbolic
from pytensor.graph import Constant, node_rewriter
from pytensor.tensor import arange, specify_shape
from pytensor.tensor.type_other import NoneTypeT, SliceType
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
from pytensor.xtensor.indexing import Index
from pytensor.xtensor.rewriting.utils import register_xcanonicalize
from pytensor.xtensor.type import XTensorType


def to_basic_idx(idx):
    if isinstance(idx.type, SliceType):
        if isinstance(idx, Constant):
            return idx.data
        elif idx.owner:
            # MakeSlice Op
            # We transform NoneConsts to regular None so that basic Subtensor can be used if possible
            return slice(
                *[
                    None if isinstance(i.type, NoneTypeT) else i
                    for i in idx.owner.inputs
                ]
            )
        else:
            return idx
    if (
        isinstance(idx.type, XTensorType)
        and idx.type.ndim == 0
        and idx.type.dtype != "bool"  # dtype is a string name
    ):
        return idx.values
    raise TypeError("Cannot convert idx to basic idx")

@register_xcanonicalize
@node_rewriter(tracks=[Index])
def lower_index(fgraph, node):
    x, *idxs = node.inputs
    [out] = node.outputs
    x_tensor = tensor_from_xtensor(x)

    if all(
        (
            isinstance(idx.type, SliceType)
            or (isinstance(idx.type, XTensorType) and idx.type.ndim == 0)
        )
        for idx in idxs
    ):
        # Special case: just basic indexing
        x_tensor_indexed = x_tensor[tuple(to_basic_idx(idx) for idx in idxs)]

    else:
        # General case: we have to align the indices positionally to achieve vectorized or orthogonal indexing.
        # May need to convert basic indexing to advanced indexing if it acts on a dimension
        # that is also indexed by an advanced index.
        x_dims = x.type.dims
        x_shape = tuple(x.shape)
        out_ndim = out.type.ndim
        out_xdims = out.type.dims
        aligned_idxs = []
        # zip_longest adds the implicit slice(None)
        for i, (idx, x_dim) in enumerate(
            zip_longest(idxs, x_dims, fillvalue=as_symbolic(slice(None)))
        ):
            if isinstance(idx.type, SliceType):
                if not any(
                    (
                        isinstance(other_idx.type, XTensorType)
                        and x_dim in other_idx.dims
                    )
                    for j, other_idx in enumerate(idxs)
                    if j != i
                ):
                    # We can use basic indexing directly if no other index acts on this dimension
                    aligned_idxs.append(idx)
                else:
                    # Otherwise we need to convert the basic index into an equivalent advanced index
                    # and align it so it interacts correctly with the other advanced indices
                    adv_idx_equivalent = arange(x_shape[i])[idx]
                    ds_order = ["x"] * out_ndim
                    ds_order[out_xdims.index(x_dim)] = 0
                    aligned_idxs.append(adv_idx_equivalent.dimshuffle(ds_order))
            else:
                assert isinstance(idx.type, XTensorType)
                if idx.type.ndim == 0:
                    # Scalar index, we can use it directly
                    aligned_idxs.append(idx.values)
                else:
                    # Vector index, we need to align the indexing dimensions with the base_dims
                    ds_order = ["x"] * out_ndim
                    for j, idx_dim in enumerate(idx.dims):
                        ds_order[out_xdims.index(idx_dim)] = j
                    aligned_idxs.append(idx.values.dimshuffle(ds_order))
        x_tensor_indexed = x_tensor[tuple(aligned_idxs)]
        # TODO: Align output dimensions if necessary

    # Restore any static shape information lost by the indexing
    x_tensor_indexed = specify_shape(x_tensor_indexed, out.type.shape)
    new_out = xtensor_from_tensor(x_tensor_indexed, dims=out.type.dims)
    return [new_out]
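
The pivotal move in the mixed case is arange(x_shape[i])[idx]: a basic slice is replaced by the integer positions it selects, which can then be dimshuffled into the output's dimension order so it broadcasts orthogonally against the other advanced indices, as xarray semantics require. The same trick in plain numpy:

>>> import numpy as np
>>> x = np.arange(20).reshape(5, 4)
>>> row_idx = np.arange(x.shape[0])[1:4]  # the slice 1:4 as explicit positions
>>> row_idx
array([1, 2, 3])
>>> x[row_idx[:, None], np.array([0, 2])]  # broadcast -> orthogonal (3, 2) result
array([[ 4,  6],
       [ 8, 10],
       [12, 14]])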

pytensor/xtensor/type.py

Lines changed: 112 additions & 2 deletions

@@ -1,3 +1,5 @@
+import warnings
+
 from pytensor.tensor import TensorType
 from pytensor.tensor.math import variadic_mul

@@ -10,7 +12,7 @@
     XARRAY_AVAILABLE = False

 from collections.abc import Sequence
-from typing import TypeVar
+from typing import Any, Literal, TypeVar

 import numpy as np
@@ -339,7 +341,115 @@ def sel(self, *args, **kwargs):
         raise NotImplementedError("sel not implemented for XTensorVariable")

     def __getitem__(self, idx):
-        raise NotImplementedError("Indexing not yet implemnented")
+        if isinstance(idx, dict):
+            return self.isel(idx)
+
+        if not isinstance(idx, tuple):
+            idx = (idx,)
+
+        # Check for ellipsis not in the last position (last one is useless anyway)
+        if any(idx_item is Ellipsis for idx_item in idx):
+            if idx.count(Ellipsis) > 1:
+                raise IndexError("an index can only have a single ellipsis ('...')")
+            # Convert intermediate Ellipsis to slice(None)
+            ellipsis_loc = idx.index(Ellipsis)
+            n_implied_none_slices = self.type.ndim - (len(idx) - 1)
+            idx = (
+                *idx[:ellipsis_loc],
+                *((slice(None),) * n_implied_none_slices),
+                *idx[ellipsis_loc + 1 :],
+            )
+
+        return px.indexing.index(self, *idx)
+
+    def isel(
+        self,
+        indexers: dict[str, Any] | None = None,
+        drop: bool = False,  # Unused by PyTensor
+        missing_dims: Literal["raise", "warn", "ignore"] = "raise",
+        **indexers_kwargs,
+    ):
+        if indexers_kwargs:
+            if indexers is not None:
+                raise ValueError(
+                    "Cannot pass both indexers and indexers_kwargs to isel"
+                )
+            indexers = indexers_kwargs
+
+        if missing_dims not in {"raise", "warn", "ignore"}:
+            raise ValueError(
+                f"Unrecognized options {missing_dims} for missing_dims argument"
+            )
+
+        # Sort indices and pass them to index
+        dims = self.type.dims
+        indices = [slice(None)] * self.type.ndim
+        for key, idx in indexers.items():
+            if idx is Ellipsis:
+                # Xarray raises a less informative error, suggesting indices must be integer
+                # But slices are also fine
+                raise TypeError("Ellipsis (...) is an invalid labeled index")
+            try:
+                # tuple.index raises ValueError for unknown dims
+                indices[dims.index(key)] = idx
+            except ValueError:
+                if missing_dims == "raise":
+                    raise ValueError(
+                        f"Dimension {key} does not exist. Expected one of {dims}"
+                    )
+                elif missing_dims == "warn":
+                    warnings.warn(
+                        f"Dimension {key} does not exist. Expected one of {dims}",
+                        UserWarning,
+                    )
+
+        return px.indexing.index(self, *indices)
+
+    def _head_tail_or_thin(
+        self,
+        indexers: dict[str, Any] | int | None,
+        indexers_kwargs: dict[str, Any],
+        *,
+        kind: Literal["head", "tail", "thin"],
+    ):
+        if indexers_kwargs:
+            if indexers is not None:
+                raise ValueError(
+                    f"Cannot pass both indexers and indexers_kwargs to {kind}"
+                )
+            indexers = indexers_kwargs
+
+        if indexers is None:
+            if kind == "thin":
+                raise TypeError(
+                    "thin() indexers must be either dict-like or a single integer"
+                )
+            else:
+                # Default to 5 for head and tail
+                indexers = {dim: 5 for dim in self.type.dims}
+
+        elif not isinstance(indexers, dict):
+            indexers = {dim: indexers for dim in self.type.dims}
+
+        if kind == "head":
+            indices = {dim: slice(None, value) for dim, value in indexers.items()}
+        elif kind == "tail":
+            sizes = self.sizes
+            # Can't use slice(-value, None), in case value is zero
+            indices = {
+                dim: slice(sizes[dim] - value, None) for dim, value in indexers.items()
+            }
+        elif kind == "thin":
+            indices = {dim: slice(None, None, value) for dim, value in indexers.items()}
+        return self.isel(indices)
+
+    def head(self, indexers: dict[str, Any] | int | None = None, **indexers_kwargs):
+        return self._head_tail_or_thin(indexers, indexers_kwargs, kind="head")
+
+    def tail(self, indexers: dict[str, Any] | int | None = None, **indexers_kwargs):
+        return self._head_tail_or_thin(indexers, indexers_kwargs, kind="tail")
+
+    def thin(self, indexers: dict[str, Any] | int | None = None, **indexers_kwargs):
+        return self._head_tail_or_thin(indexers, indexers_kwargs, kind="thin")

     # ndarray methods
     # https://docs.xarray.dev/en/latest/api.html#id7

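head, tail, and thin all reduce to isel with one computed slice per dimension; tail derives an explicit start of sizes[dim] - value because slice(-value, None) would select the whole dimension when value is zero. Rough equivalences as a sketch (the xtensor() keyword form is the one this commit uses elsewhere):

# Sketch of the user-facing API added above (illustrative, not part of the commit)
from pytensor.xtensor.type import xtensor

x = xtensor(dtype="float64", shape=(10, 8), dims=("a", "b"))

x[{"a": 0}]  # dict indexing dispatches to x.isel({"a": 0})
x[..., 0]    # Ellipsis expands to slice(None) for "a"; scalar 0 drops "b"
x.head(a=3)  # same as x.isel(a=slice(None, 3))
x.tail(a=3)  # same as x.isel(a=slice(10 - 3, None)): explicit start, safe for 0
x.thin(a=2)  # same as x.isel(a=slice(None, None, 2))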