
Commit 414702d

Implement index operations for XTensorVariables
1 parent 1450796 commit 414702d

6 files changed: +784 -3 lines changed


pytensor/xtensor/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -7,7 +7,6 @@
 )
 from pytensor.xtensor.shape import concat
 from pytensor.xtensor.type import (
-    XTensorType,
     as_xtensor,
     xtensor,
     xtensor_constant,

pytensor/xtensor/indexing.py

Lines changed: 186 additions & 0 deletions
@@ -0,0 +1,186 @@
# HERE LIE DRAGONS
# Useful links to make sense of all the numpy/xarray complexity
# https://numpy.org/devdocs//user/basics.indexing.html
# https://numpy.org/neps/nep-0021-advanced-indexing.html
# https://docs.xarray.dev/en/latest/user-guide/indexing.html
# https://tutorial.xarray.dev/intermediate/indexing/advanced-indexing.html

from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.scalar.basic import discrete_dtypes
from pytensor.tensor.basic import as_tensor
from pytensor.tensor.type_other import NoneTypeT, SliceType, make_slice
from pytensor.xtensor.basic import XOp, xtensor_from_tensor
from pytensor.xtensor.type import XTensorType, as_xtensor, xtensor


def as_idx_variable(idx, indexed_dim: str):
    if idx is None or (isinstance(idx, Variable) and isinstance(idx.type, NoneTypeT)):
        raise TypeError(
            "XTensors do not support indexing with None (np.newaxis), use expand_dims instead"
        )
    if isinstance(idx, slice):
        idx = make_slice(idx)
    elif isinstance(idx, Variable) and isinstance(idx.type, SliceType):
        pass
    elif (
        isinstance(idx, tuple)
        and len(idx) == 2
        and (
            isinstance(idx[0], str)
            or (
                isinstance(idx[0], tuple | list)
                and all(isinstance(d, str) for d in idx[0])
            )
        )
    ):
        # Special case for ("x", array) that xarray supports
        dim, idx = idx
        if isinstance(idx, Variable) and isinstance(idx.type, XTensorType):
            raise IndexError(
                f"Giving a dimension name to an XTensorVariable indexer is not supported: {(dim, idx)}. "
                "Use .rename() instead."
            )
        if isinstance(dim, str):
            dims = (dim,)
        else:
            dims = tuple(dim)
        idx = as_xtensor(as_tensor(idx), dims=dims)
    else:
        # Must be integer / boolean indices; we already accounted for None and slices
        try:
            idx = as_xtensor(idx)
        except TypeError:
            idx = as_tensor(idx)
            if idx.type.ndim > 1:
                # Same error that xarray raises
                raise IndexError(
                    "Unlabeled multi-dimensional array cannot be used for indexing"
                )
            # This is implicitly an XTensorVariable with dim matching the indexed one
            idx = xtensor_from_tensor(idx, dims=(indexed_dim,)[: idx.type.ndim])

    if idx.type.dtype == "bool":
        if idx.type.ndim != 1:
            # xarray allows `x[True]`, but I think it is a bug: https://github.com/pydata/xarray/issues/10379
            # Otherwise, it is always restricted to 1d boolean indexing arrays
            raise NotImplementedError(
                "Only 1d boolean indexing arrays are supported"
            )
        if idx.type.dims != (indexed_dim,):
            raise IndexError(
                "Boolean indexer should be unlabeled or on the same dimension as the indexed array. "
                f"Indexer is on {idx.type.dims} but the target dimension is {indexed_dim}."
            )

        # Convert to nonzero indices
        idx = as_xtensor(idx.values.nonzero()[0], dims=idx.type.dims)

    elif idx.type.dtype not in discrete_dtypes:
        raise TypeError("Numerical indices must be integers or boolean")
    return idx
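
The boolean branch above relies on a NumPy equivalence: a 1d boolean mask selects the same elements as the integer positions where it is True. A minimal sketch (illustration only, not part of the committed file):

import numpy as np

x = np.array([10, 20, 30, 40])
mask = np.array([True, False, True, False])

# The mask and its nonzero positions select the same elements
assert np.array_equal(x[mask], x[mask.nonzero()[0]])  # both -> [10, 30]
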
def get_static_slice_length(slc: Variable, dim_length: None | int) -> int | None:
    if dim_length is None:
        return None
    if isinstance(slc, Constant):
        d = slc.data
        start, stop, step = d.start, d.stop, d.step
    elif slc.owner is None:
        # It's a root variable; no way of knowing what we're getting
        return None
    else:
        # It's a MakeSliceOp
        start, stop, step = slc.owner.inputs
        if isinstance(start, Constant):
            start = start.data
        else:
            return None
        if isinstance(stop, Constant):
            stop = stop.data
        else:
            return None
        if isinstance(step, Constant):
            step = step.data
        else:
            return None
    return len(range(*slice(start, stop, step).indices(dim_length)))
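
The return expression above uses a standard-library idiom: `slice.indices` clamps start/stop/step to a concrete dimension length, and the length of the resulting `range` is the static length of the sliced dimension. Quick checks (plain Python, illustration only):

# x[1:10:2] on a dimension of length 6 selects indices 1, 3, 5
assert len(range(*slice(1, 10, 2).indices(6))) == 3
# Negative steps are handled as well
assert len(range(*slice(None, None, -1).indices(4))) == 4
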
class Index(XOp):
    __props__ = ()

    def make_node(self, x, *idxs):
        x = as_xtensor(x)

        if any(idx is Ellipsis for idx in idxs):
            if idxs.count(Ellipsis) > 1:
                raise IndexError("an index can only have a single ellipsis ('...')")
            # Convert intermediate Ellipsis to slice(None)
            ellipsis_loc = idxs.index(Ellipsis)
            n_implied_none_slices = x.type.ndim - (len(idxs) - 1)
            idxs = (
                *idxs[:ellipsis_loc],
                *((slice(None),) * n_implied_none_slices),
                *idxs[ellipsis_loc + 1 :],
            )

        x_ndim = x.type.ndim
        x_dims = x.type.dims
        x_shape = x.type.shape
        out_dims = []
        out_shape = []

        def combine_dim_info(idx_dim, idx_dim_shape):
            if idx_dim not in out_dims:
                # First information about the dimension length
                out_dims.append(idx_dim)
                out_shape.append(idx_dim_shape)
            else:
                # Dim already introduced in output by a previous index
                # Update static shape or raise if incompatible
                out_dim_pos = out_dims.index(idx_dim)
                out_dim_shape = out_shape[out_dim_pos]
                if out_dim_shape is None:
                    # We don't know the size of the dimension yet
                    out_shape[out_dim_pos] = idx_dim_shape
                elif idx_dim_shape is not None and idx_dim_shape != out_dim_shape:
                    raise IndexError(
                        f"Dimension of indexers mismatch for dim {idx_dim}"
                    )

        if len(idxs) > x_ndim:
            raise IndexError("Too many indices")

        idxs = [
            as_idx_variable(idx, dim) for idx, dim in zip(idxs, x_dims, strict=False)
        ]

        for i, idx in enumerate(idxs):
            if isinstance(idx.type, SliceType):
                idx_dim = x_dims[i]
                idx_dim_shape = get_static_slice_length(idx, x_shape[i])
                combine_dim_info(idx_dim, idx_dim_shape)
            else:
                if idx.type.ndim == 0:
                    # Scalar index, dimension is dropped
                    continue

                assert isinstance(idx.type, XTensorType)

                idx_dims = idx.type.dims
                for idx_dim in idx_dims:
                    idx_dim_shape = idx.type.shape[idx_dims.index(idx_dim)]
                    combine_dim_info(idx_dim, idx_dim_shape)

        for dim_i, shape_i in zip(x_dims[i + 1 :], x_shape[i + 1 :]):
            # Add back any unindexed dimensions
            if dim_i not in out_dims:
                # If the dimension was not indexed, we keep it as is
                combine_dim_info(dim_i, shape_i)

        output = xtensor(dtype=x.type.dtype, shape=out_shape, dims=out_dims)
        return Apply(self, [x, *idxs], [output])


index = Index()
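
`make_node` expands a single Ellipsis into the implied number of full slices before converting the indices. The same computation in plain Python (illustration only):

idxs = (0, Ellipsis, slice(1, 3))
x_ndim = 4

ellipsis_loc = idxs.index(Ellipsis)
n_implied_none_slices = x_ndim - (len(idxs) - 1)  # 4 - 2 = 2 full slices
expanded = (
    *idxs[:ellipsis_loc],
    *((slice(None),) * n_implied_none_slices),
    *idxs[ellipsis_loc + 1 :],
)
assert expanded == (0, slice(None), slice(None), slice(1, 3))
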
pytensor/xtensor/rewriting/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -1,4 +1,5 @@
 import pytensor.xtensor.rewriting.basic
+import pytensor.xtensor.rewriting.indexing
 import pytensor.xtensor.rewriting.reduction
 import pytensor.xtensor.rewriting.shape
 import pytensor.xtensor.rewriting.vectorization
pytensor/xtensor/rewriting/indexing.py

Lines changed: 150 additions & 0 deletions

@@ -0,0 +1,150 @@
1+
from itertools import zip_longest
2+
3+
from pytensor import as_symbolic
4+
from pytensor.graph import Constant, node_rewriter
5+
from pytensor.tensor import TensorType, arange, specify_shape
6+
from pytensor.tensor.subtensor import _non_consecutive_adv_indexing
7+
from pytensor.tensor.type_other import NoneTypeT, SliceType
8+
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
9+
from pytensor.xtensor.indexing import Index
10+
from pytensor.xtensor.rewriting.utils import register_xcanonicalize
11+
from pytensor.xtensor.type import XTensorType
12+
13+
14+
def to_basic_idx(idx):
15+
if isinstance(idx.type, SliceType):
16+
if isinstance(idx, Constant):
17+
return idx.data
18+
elif idx.owner:
19+
# MakeSlice Op
20+
# We transform NoneConsts to regular None so that basic Subtensor can be used if possible
21+
return slice(
22+
*[
23+
None if isinstance(i.type, NoneTypeT) else i
24+
for i in idx.owner.inputs
25+
]
26+
)
27+
else:
28+
return idx
29+
if (
30+
isinstance(idx.type, XTensorType)
31+
and idx.type.ndim == 0
32+
and idx.type.dtype != bool
33+
):
34+
return idx.values
35+
raise TypeError("Cannot convert idx to basic idx")
36+
37+
38+
@register_xcanonicalize
39+
@node_rewriter(tracks=[Index])
40+
def lower_index(fgraph, node):
41+
"""Lower XTensorVariable indexing to regular TensorVariable indexing.
42+
43+
xarray-like indexing has two modes:
44+
1. Orthogonal indexing: Indices of different output labeled dimensions are combined to produce all combinations of indices.
45+
2. Vectorized indexing: Indices of the same output labeled dimension are combined point-wise like in regular numpy advanced indexing.
46+
47+
An Index Op can combine both modes.
48+
To achieve orthogonal indexing using numpy semantics we must use multidimensional advanced indexing.
49+
We expand the dims of each index so they are as large as the number of output dimensions, place the indices that
50+
belong to the same output dimension in the same axis, and those that belong to different output dimensions in different axes.
51+
52+
For instance to do an outer 2x2 indexing we can select x[arange(x.shape[0])[:, None], arange(x.shape[1])[None, :]],
53+
This is a generalization of `np.ix_` that allows combining some dimensions, and not others, as well as have
54+
indices that have more than one dimension at the start.
55+
56+
In addition, xarray basic index (slices), can be vectorized with other advanced indices (if they act on the same output dimension).
57+
However, in numpy, basic indices are always orthogonal to advanced indices. To make them behave like vectorized indices
58+
we have to convert the slices to equivalent advanced indices.
59+
We do this by creating an `arange` tensor that matches the shape of the dimension being indexed,
60+
and then indexing it with the original slice. This index is then handled as a regular advanced index.
61+
62+
Note: The IndexOp has only 2 types of indices: Slices and XTensorVariables. Regular array indices
63+
are converted to the appropriate XTensorVariable by `Index.make_node`
64+
"""
65+
66+
x, *idxs = node.inputs
67+
[out] = node.outputs
68+
x_tensor = tensor_from_xtensor(x)
69+
70+
if all(
71+
(
72+
isinstance(idx.type, SliceType)
73+
or (isinstance(idx.type, XTensorType) and idx.type.ndim == 0)
74+
)
75+
for idx in idxs
76+
):
77+
# Special case having just basic indexing
78+
x_tensor_indexed = x_tensor[tuple(to_basic_idx(idx) for idx in idxs)]
79+
80+
else:
81+
# General case, we have to align the indices positionally to achieve vectorized or orthogonal indexing
82+
# May need to convert basic indexing to advanced indexing if it acts on a dimension that is also indexed by an advanced index
83+
x_dims = x.type.dims
84+
x_shape = tuple(x.shape)
85+
out_ndim = out.type.ndim
86+
out_dims = out.type.dims
87+
aligned_idxs = []
88+
basic_idx_axis = []
89+
# zip_longest adds the implicit slice(None)
90+
for i, (idx, x_dim) in enumerate(
91+
zip_longest(idxs, x_dims, fillvalue=as_symbolic(slice(None)))
92+
):
93+
if isinstance(idx.type, SliceType):
94+
if not any(
95+
(
96+
isinstance(other_idx.type, XTensorType)
97+
and x_dim in other_idx.dims
98+
)
99+
for j, other_idx in enumerate(idxs)
100+
if j != i
101+
):
102+
# We can use basic indexing directly if no other index acts on this dimension
103+
# This is an optimization that avoids creating an unnecessary arange tensor
104+
# and facilitates the use of the specialized AdvancedSubtensor1 when possible
105+
aligned_idxs.append(idx)
106+
basic_idx_axis.append(out_dims.index(x_dim))
107+
else:
108+
# Otherwise we need to convert the basic index into an equivalent advanced indexing
109+
# And align it so it interacts correctly with the other advanced indices
110+
adv_idx_equivalent = arange(x_shape[i])[to_basic_idx(idx)]
111+
ds_order = ["x"] * out_ndim
112+
ds_order[out_dims.index(x_dim)] = 0
113+
aligned_idxs.append(adv_idx_equivalent.dimshuffle(ds_order))
114+
else:
115+
assert isinstance(idx.type, XTensorType)
116+
if idx.type.ndim == 0:
117+
# Scalar index, we can use it directly
118+
aligned_idxs.append(idx.values)
119+
else:
120+
# Vector index, we need to align the indexing dimensions with the base_dims
121+
ds_order = ["x"] * out_ndim
122+
for j, idx_dim in enumerate(idx.dims):
123+
ds_order[out_dims.index(idx_dim)] = j
124+
aligned_idxs.append(idx.values.dimshuffle(ds_order))
125+
126+
# Squeeze indexing dimensions that were not used because we kept basic indexing slices
127+
if basic_idx_axis:
128+
aligned_idxs = [
129+
idx.squeeze(axis=basic_idx_axis)
130+
if (isinstance(idx.type, TensorType) and idx.type.ndim > 0)
131+
else idx
132+
for idx in aligned_idxs
133+
]
134+
135+
x_tensor_indexed = x_tensor[tuple(aligned_idxs)]
136+
137+
if basic_idx_axis and _non_consecutive_adv_indexing(aligned_idxs):
138+
# Numpy moves advanced indexing dimensions to the front when they are not consecutive
139+
# We need to transpose them back to the expected output order
140+
x_tensor_indexed_basic_dims = [out_dims[axis] for axis in basic_idx_axis]
141+
x_tensor_indexed_dims = [
142+
dim for dim in out_dims if dim not in x_tensor_indexed_basic_dims
143+
] + x_tensor_indexed_basic_dims
144+
transpose_order = [x_tensor_indexed_dims.index(dim) for dim in out_dims]
145+
x_tensor_indexed = x_tensor_indexed.transpose(transpose_order)
146+
147+
# Add lost shape information
148+
x_tensor_indexed = specify_shape(x_tensor_indexed, out.type.shape)
149+
new_out = xtensor_from_tensor(x_tensor_indexed, dims=out.type.dims)
150+
return [new_out]
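
The strategy in the docstring can be reproduced in plain NumPy: indices aligned on different axes produce orthogonal (outer) indexing, indices sharing an axis produce vectorized (point-wise) indexing, and a slice becomes an advanced index by passing it through an `arange`. A sketch (illustration only, independent of the rewrite itself):

import numpy as np

x = np.arange(12).reshape(3, 4)

# Orthogonal: indices for different output dims live on different axes
rows = np.array([0, 2])[:, None]  # varies along output axis 0
cols = np.array([1, 3])[None, :]  # varies along output axis 1
assert np.array_equal(x[rows, cols], x[np.ix_([0, 2], [1, 3])])

# Vectorized: indices for the same output dim share an axis (point-wise)
assert x[np.array([0, 2]), np.array([1, 3])].tolist() == [1, 11]

# A slice turned into an equivalent advanced index via arange
assert np.array_equal(x[0:2], x[np.arange(3)[0:2]])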

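The final transpose compensates for the NumPy rule mentioned in the comment above: when advanced indices are separated by a basic index, the broadcast dimensions move to the front of the result (plain NumPy, illustration only):

import numpy as np

x = np.arange(5 * 6 * 7).reshape(5, 6, 7)
i = np.array([0, 1])
j = np.array([2, 3])

# Consecutive advanced indices: the broadcast axis stays in place
assert x[:, i, j].shape == (5, 2)

# Non-consecutive (separated by a slice): NumPy moves it to the front,
# which is why the rewrite transposes back to the expected dim order
assert x[i, :, j].shape == (2, 6)
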