Commit de9e283

Implement index operations for XTensorVariables

1 parent b94af96 commit de9e283

6 files changed: +768 -3 lines changed
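For orientation, here is a minimal sketch of what the new Index Op computes, written against the API introduced in the diffs below (the `index` instance and the `xtensor` constructor, both defined in pytensor/xtensor/indexing.py and used exactly this way in `Index.make_node`). It is an illustration, not code from the commit:

    from pytensor.xtensor.indexing import index
    from pytensor.xtensor.type import xtensor

    # A 3x4 variable with labeled dimensions
    x = xtensor(dtype="float64", shape=(3, 4), dims=("a", "b"))

    # Slice dim "a"; index dim "b" with xarray's ("dim", values) labeled form
    out = index(x, slice(None, 2), ("b", [0, 3, 1]))
    print(out.type.dims)   # expected: ('a', 'b')
    print(out.type.shape)  # expected: (2, 3) -- static slice length is inferred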

pytensor/xtensor/__init__.py

Lines changed: 0 additions & 1 deletion

@@ -7,7 +7,6 @@
 )
 from pytensor.xtensor.shape import concat
 from pytensor.xtensor.type import (
-    XTensorType,
     as_xtensor,
     xtensor,
     xtensor_constant,
pytensor/xtensor/indexing.py

Lines changed: 185 additions & 0 deletions
# HERE LIE DRAGONS
# Useful links to make sense of all the numpy/xarray complexity
# https://numpy.org/devdocs//user/basics.indexing.html
# https://numpy.org/neps/nep-0021-advanced-indexing.html
# https://docs.xarray.dev/en/latest/user-guide/indexing.html
# https://tutorial.xarray.dev/intermediate/indexing/advanced-indexing.html

from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.scalar.basic import discrete_dtypes
from pytensor.tensor.basic import as_tensor
from pytensor.tensor.type_other import NoneTypeT, SliceType, make_slice
from pytensor.xtensor.basic import XOp, xtensor_from_tensor
from pytensor.xtensor.type import XTensorType, as_xtensor, xtensor


def as_idx_variable(idx, indexed_dim: str):
    if idx is None or (isinstance(idx, Variable) and isinstance(idx.type, NoneTypeT)):
        raise TypeError(
            "XTensors do not support indexing with None (np.newaxis), use expand_dims instead"
        )
    if isinstance(idx, slice):
        idx = make_slice(idx)
    elif isinstance(idx, Variable) and isinstance(idx.type, SliceType):
        pass
    elif (
        isinstance(idx, tuple)
        and len(idx) == 2
        and (
            isinstance(idx[0], str)
            or (
                isinstance(idx[0], tuple | list)
                and all(isinstance(d, str) for d in idx[0])
            )
        )
    ):
        # Special case for ("x", array) that xarray supports
        dim, idx = idx
        if isinstance(idx, Variable) and isinstance(idx.type, XTensorType):
            raise TypeError(
                "Giving a dimension name to an XTensorVariable indexer is not supported"
            )
        if isinstance(dim, str):
            dims = (dim,)
        else:
            dims = tuple(dim)
        idx = as_xtensor(as_tensor(idx), dims=dims)
    else:
        # Must be integer indices, we already accounted for None and slices
        try:
            idx = as_xtensor(idx)
        except TypeError:
            idx = as_tensor(idx)
            if idx.type.ndim > 1:
                # Same error that xarray raises
                raise IndexError(
                    "Unlabeled multi-dimensional array cannot be used for indexing"
                )
            # This is implicitly an XTensorVariable with dim matching the indexed one
            idx = xtensor_from_tensor(idx, dims=(indexed_dim,)[: idx.type.ndim])

    if idx.type.dtype == "bool":
        if idx.type.ndim != 1:
            # xarray allows `x[True]`, but I think it is a bug: https://github.com/pydata/xarray/issues/10379
            # Otherwise, it is always restricted to 1d boolean indexing arrays
            raise NotImplementedError(
                "Only 1d boolean indexing arrays are supported"
            )
        if idx.type.dims != (indexed_dim,):
            raise IndexError(
                "Boolean indexer should be unlabeled or on the same dimension as the indexed array. "
                f"Indexer is on {idx.type.dims} but the target dimension is {indexed_dim}."
            )

        # Convert to nonzero indices
        idx = as_xtensor(idx.values.nonzero()[0], dims=idx.type.dims)

    elif idx.type.dtype not in discrete_dtypes:
        raise TypeError("Numerical indices must be integers or boolean")
    return idx


def get_static_slice_length(slc: Variable, dim_length: None | int) -> int | None:
    if dim_length is None:
        return None
    if isinstance(slc, Constant):
        d = slc.data
        start, stop, step = d.start, d.stop, d.step
    elif slc.owner is None:
        # It's a root variable; no way of knowing what we're getting
        return None
    else:
        # It's a MakeSliceOp
        start, stop, step = slc.owner.inputs
        if isinstance(start, Constant):
            start = start.data
        else:
            return None
        if isinstance(stop, Constant):
            stop = stop.data
        else:
            return None
        if isinstance(step, Constant):
            step = step.data
        else:
            return None
    return len(range(*slice(start, stop, step).indices(dim_length)))


class Index(XOp):
    __props__ = ()

    def make_node(self, x, *idxs):
        x = as_xtensor(x)

        if any(idx is Ellipsis for idx in idxs):
            if idxs.count(Ellipsis) > 1:
                raise IndexError("an index can only have a single ellipsis ('...')")
            # Convert intermediate Ellipsis to slice(None)
            ellipsis_loc = idxs.index(Ellipsis)
            n_implied_none_slices = x.type.ndim - (len(idxs) - 1)
            idxs = (
                *idxs[:ellipsis_loc],
                *((slice(None),) * n_implied_none_slices),
                *idxs[ellipsis_loc + 1 :],
            )

        x_ndim = x.type.ndim
        x_dims = x.type.dims
        x_shape = x.type.shape
        out_dims = []
        out_shape = []

        def combine_dim_info(idx_dim, idx_dim_shape):
            if idx_dim not in out_dims:
                # First information about the dimension length
                out_dims.append(idx_dim)
                out_shape.append(idx_dim_shape)
            else:
                # Dim already introduced in output by a previous index
                # Update static shape or raise if incompatible
                out_dim_pos = out_dims.index(idx_dim)
                out_dim_shape = out_shape[out_dim_pos]
                if out_dim_shape is None:
                    # We don't know the size of the dimension yet
                    out_shape[out_dim_pos] = idx_dim_shape
                elif idx_dim_shape is not None and idx_dim_shape != out_dim_shape:
                    raise IndexError(
                        f"Dimension of indexers mismatch for dim {idx_dim}"
                    )

        if len(idxs) > x_ndim:
            raise IndexError("Too many indices")

        idxs = [
            as_idx_variable(idx, dim) for idx, dim in zip(idxs, x_dims, strict=False)
        ]

        for i, idx in enumerate(idxs):
            if isinstance(idx.type, SliceType):
                idx_dim = x_dims[i]
                idx_dim_shape = get_static_slice_length(idx, x_shape[i])
                combine_dim_info(idx_dim, idx_dim_shape)
            else:
                if idx.type.ndim == 0:
                    # Scalar index, dimension is dropped
                    continue

                assert isinstance(idx.type, XTensorType)

                idx_dims = idx.type.dims
                for idx_dim in idx_dims:
                    idx_dim_shape = idx.type.shape[idx_dims.index(idx_dim)]
                    combine_dim_info(idx_dim, idx_dim_shape)

        for dim_i, shape_i in zip(x_dims[i + 1 :], x_shape[i + 1 :]):
            # Add back any unindexed dimensions
            if dim_i not in out_dims:
                # If the dimension was not indexed, we keep it as is
                combine_dim_info(dim_i, shape_i)

        output = xtensor(dtype=x.type.dtype, shape=out_shape, dims=out_dims)
        return Apply(self, [x, *idxs], [output])


index = Index()
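A quick sketch of the boolean path in `as_idx_variable` above: a 1d boolean mask labeled with the indexed dimension is converted to integer positions via `nonzero()`, so the static length of the output dimension becomes unknown. Illustration only, not part of the commit:

    import numpy as np

    from pytensor.xtensor.indexing import index
    from pytensor.xtensor.type import xtensor

    x = xtensor(dtype="float64", shape=(5,), dims=("a",))
    mask = np.array([True, False, True, False, True])

    out = index(x, ("a", mask))
    print(out.type.dims)   # expected: ('a',)
    print(out.type.shape)  # expected: (None,) -- the mask's nonzero count is not static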
pytensor/xtensor/rewriting/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -1,4 +1,5 @@
 import pytensor.xtensor.rewriting.basic
+import pytensor.xtensor.rewriting.indexing
 import pytensor.xtensor.rewriting.reduction
 import pytensor.xtensor.rewriting.shape
 import pytensor.xtensor.rewriting.vectorization
pytensor/xtensor/rewriting/indexing.py

Lines changed: 150 additions & 0 deletions

from itertools import zip_longest

from pytensor import as_symbolic
from pytensor.graph import Constant, node_rewriter
from pytensor.tensor import TensorType, arange, specify_shape
from pytensor.tensor.subtensor import _non_consecutive_adv_indexing
from pytensor.tensor.type_other import NoneTypeT, SliceType
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
from pytensor.xtensor.indexing import Index
from pytensor.xtensor.rewriting.utils import register_xcanonicalize
from pytensor.xtensor.type import XTensorType


def to_basic_idx(idx):
    if isinstance(idx.type, SliceType):
        if isinstance(idx, Constant):
            return idx.data
        elif idx.owner:
            # MakeSlice Op
            # We transform NoneConsts to regular None so that basic Subtensor can be used if possible
            return slice(
                *[
                    None if isinstance(i.type, NoneTypeT) else i
                    for i in idx.owner.inputs
                ]
            )
        else:
            return idx
    if (
        isinstance(idx.type, XTensorType)
        and idx.type.ndim == 0
        and idx.type.dtype != "bool"
    ):
        return idx.values
    raise TypeError("Cannot convert idx to basic idx")


@register_xcanonicalize
@node_rewriter(tracks=[Index])
def lower_index(fgraph, node):
    """Lower XTensorVariable indexing to regular TensorVariable indexing.

    xarray-like indexing has two modes:
    1. Orthogonal indexing: indices of different output labeled dimensions are combined to produce all combinations of indices.
    2. Vectorized indexing: indices of the same output labeled dimension are combined point-wise like in regular numpy advanced indexing.

    An Index Op can combine both modes.
    To achieve orthogonal indexing using numpy semantics we must use multidimensional advanced indexing.
    We expand the dims of each index so they are as large as the number of output dimensions, place the indices that
    belong to the same output dimension in the same axis, and those that belong to different output dimensions in different axes.

    For instance, to do an outer 2x2 indexing we can select x[arange(x.shape[0])[:, None], arange(x.shape[1])[None, :]].
    This is a generalization of `np.ix_` that allows combining some dimensions, and not others, as well as having
    indices with more than one dimension at the start.

    In addition, xarray basic indices (slices) can be vectorized with other advanced indices (if they act on the same output dimension).
    However, in numpy, basic indices are always orthogonal to advanced indices. To make them behave like vectorized indices
    we have to convert the slices to equivalent advanced indices.
    We do this by creating an `arange` tensor that matches the shape of the dimension being indexed,
    and then indexing it with the original slice. This index is then handled as a regular advanced index.

    Note: The Index Op has only two types of indices: slices and XTensorVariables. Regular array indices
    are converted to the appropriate XTensorVariable by `Index.make_node`.
    """

    x, *idxs = node.inputs
    [out] = node.outputs
    x_tensor = tensor_from_xtensor(x)

    if all(
        (
            isinstance(idx.type, SliceType)
            or (isinstance(idx.type, XTensorType) and idx.type.ndim == 0)
        )
        for idx in idxs
    ):
        # Special case having just basic indexing
        x_tensor_indexed = x_tensor[tuple(to_basic_idx(idx) for idx in idxs)]

    else:
        # General case, we have to align the indices positionally to achieve vectorized or orthogonal indexing
        # May need to convert basic indexing to advanced indexing if it acts on a dimension that is also indexed by an advanced index
        x_dims = x.type.dims
        x_shape = tuple(x.shape)
        out_ndim = out.type.ndim
        out_dims = out.type.dims
        aligned_idxs = []
        basic_idx_axis = []
        # zip_longest adds the implicit slice(None)
        for i, (idx, x_dim) in enumerate(
            zip_longest(idxs, x_dims, fillvalue=as_symbolic(slice(None)))
        ):
            if isinstance(idx.type, SliceType):
                if not any(
                    (
                        isinstance(other_idx.type, XTensorType)
                        and x_dim in other_idx.dims
                    )
                    for j, other_idx in enumerate(idxs)
                    if j != i
                ):
                    # We can use basic indexing directly if no other index acts on this dimension
                    # This is an optimization that avoids creating an unnecessary arange tensor
                    # and facilitates the use of the specialized AdvancedSubtensor1 when possible
                    aligned_idxs.append(idx)
                    basic_idx_axis.append(out_dims.index(x_dim))
                else:
                    # Otherwise we need to convert the basic index into an equivalent advanced indexing
                    # and align it so it interacts correctly with the other advanced indices
                    adv_idx_equivalent = arange(x_shape[i])[to_basic_idx(idx)]
                    ds_order = ["x"] * out_ndim
                    ds_order[out_dims.index(x_dim)] = 0
                    aligned_idxs.append(adv_idx_equivalent.dimshuffle(ds_order))
            else:
                assert isinstance(idx.type, XTensorType)
                if idx.type.ndim == 0:
                    # Scalar index, we can use it directly
                    aligned_idxs.append(idx.values)
                else:
                    # Vector index, we need to align the indexing dimensions with the base_dims
                    ds_order = ["x"] * out_ndim
                    for j, idx_dim in enumerate(idx.dims):
                        ds_order[out_dims.index(idx_dim)] = j
                    aligned_idxs.append(idx.values.dimshuffle(ds_order))

        # Squeeze indexing dimensions that were not used because we kept basic indexing slices
        if basic_idx_axis:
            aligned_idxs = [
                idx.squeeze(axis=basic_idx_axis)
                if (isinstance(idx.type, TensorType) and idx.type.ndim > 0)
                else idx
                for idx in aligned_idxs
            ]

        x_tensor_indexed = x_tensor[tuple(aligned_idxs)]

        if basic_idx_axis and _non_consecutive_adv_indexing(aligned_idxs):
            # Numpy moves advanced indexing dimensions to the front when they are not consecutive
            # We need to transpose them back to the expected output order
            x_tensor_indexed_basic_dims = [out_dims[axis] for axis in basic_idx_axis]
            x_tensor_indexed_dims = [
                dim for dim in out_dims if dim not in x_tensor_indexed_basic_dims
            ] + x_tensor_indexed_basic_dims
            transpose_order = [x_tensor_indexed_dims.index(dim) for dim in out_dims]
            x_tensor_indexed = x_tensor_indexed.transpose(transpose_order)

    # Add lost shape information
    x_tensor_indexed = specify_shape(x_tensor_indexed, out.type.shape)
    new_out = xtensor_from_tensor(x_tensor_indexed, dims=out.type.dims)
    return [new_out]
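To ground the docstring above, here is a plain-numpy demonstration of the two indexing modes and of the slice-to-`arange` trick `lower_index` uses (numpy only; nothing here depends on pytensor):

    import numpy as np

    x = np.arange(12).reshape(3, 4)
    rows = np.array([0, 2])
    cols = np.array([1, 3])

    # Vectorized (numpy default): indices are paired point-wise -> shape (2,)
    print(x[rows, cols])                    # [ 1 11]

    # Orthogonal (xarray-style): all combinations -> shape (2, 2)
    print(x[rows[:, None], cols[None, :]])  # [[ 1  3]
                                            #  [ 9 11]]
    print(x[np.ix_(rows, cols)])            # same result via np.ix_

    # A basic slice converted to an equivalent advanced index, as lower_index does:
    adv_from_slice = np.arange(x.shape[0])[0:3:2]  # array([0, 2])
    assert (x[0:3:2] == x[adv_from_slice]).all()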
