
Commit 65f1385

Support consecutive integer vector indexing in Numba backend

1 parent ae66e82

File tree

4 files changed: +300, -67 lines

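For context, "consecutive integer vector indexing" is the NumPy-style advanced indexing pattern where two or more integer index vectors sit next to each other and are paired elementwise. A minimal NumPy illustration (mine, not part of the commit):

    import numpy as np

    x = np.arange(24).reshape(2, 3, 4)
    rows = np.array([0, 2, 1])  # integer vector indexing axis 1
    cols = np.array([1, 1, 3])  # integer vector indexing axis 2

    # Consecutive vector indices are zipped elementwise:
    # out[k, i] == x[k, rows[i], cols[i]]
    out = x[:, rows, cols]
    assert out.shape == (2, 3)

This commit teaches the Numba backend to compile that pattern natively instead of falling back to object mode.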

pytensor/link/numba/dispatch/subtensor.py

Lines changed: 145 additions & 6 deletions
@@ -5,6 +5,7 @@
 from pytensor.link.numba.dispatch.basic import generate_fallback_impl, numba_njit
 from pytensor.link.utils import compile_function_src, unique_name_generator
 from pytensor.tensor import TensorType
+from pytensor.tensor.rewriting.subtensor import is_full_slice
 from pytensor.tensor.subtensor import (
     AdvancedIncSubtensor,
     AdvancedIncSubtensor1,
@@ -13,6 +14,7 @@
     IncSubtensor,
     Subtensor,
 )
+from pytensor.tensor.type_other import NoneTypeT, SliceType


 @numba_funcify.register(Subtensor)
@@ -104,18 +106,73 @@ def {function_name}({", ".join(input_names)}):
 @numba_funcify.register(AdvancedSubtensor)
 @numba_funcify.register(AdvancedIncSubtensor)
 def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
-    idxs = node.inputs[1:] if isinstance(op, AdvancedSubtensor) else node.inputs[2:]
-    adv_idxs_dims = [
-        idx.type.ndim
+    if isinstance(op, AdvancedSubtensor):
+        x, y, idxs = node.inputs[0], None, node.inputs[1:]
+    else:
+        x, y, *idxs = node.inputs
+
+    basic_idxs = [
+        idx
         for idx in idxs
-        if (isinstance(idx.type, TensorType) and idx.type.ndim > 0)
+        if (
+            isinstance(idx.type, NoneTypeT)
+            or (isinstance(idx.type, SliceType) and not is_full_slice(idx))
+        )
+    ]
+    adv_idxs = [
+        {
+            "axis": i,
+            "dtype": idx.type.dtype,
+            "bcast": idx.type.broadcastable,
+            "ndim": idx.type.ndim,
+        }
+        for i, idx in enumerate(idxs)
+        if isinstance(idx.type, TensorType)
     ]

+    # Special case for consecutive vector indices
+    def broadcasted_to(x_bcast: tuple[bool, ...], to_bcast: tuple[bool, ...]):
+        # Check whether x is broadcasted to `to_bcast` based on broadcastable info
+        if len(x_bcast) < len(to_bcast):
+            return True
+        for x_bcast_dim, to_bcast_dim in zip(x_bcast, to_bcast, strict=True):
+            if x_bcast_dim and not to_bcast_dim:
+                return True
+        return False
+
+    # Special implementation for consecutive vector indices
     if (
+        not basic_idxs
+        and len(adv_idxs) >= 2
+        # Must be integer vectors
+        # Todo: we could allow shape=(1,) if this is the shape of x
+        and all(
+            (adv_idx["bcast"] == (False,) and adv_idx["dtype"] != "bool")
+            for adv_idx in adv_idxs
+        )
+        # Must be consecutive
+        and not op.non_contiguous_adv_indexing(node)
+        # y in set/inc_subtensor cannot be broadcasted
+        and (
+            y is None
+            or not broadcasted_to(
+                y.type.broadcastable,
+                (
+                    x.type.broadcastable[: adv_idxs[0]["axis"]]
+                    + x.type.broadcastable[adv_idxs[-1]["axis"] :]
+                ),
+            )
+        )
+    ):
+        return numba_funcify_multiple_vector_integer_indexing(op, node, **kwargs)
+
+    # Other cases not natively supported by Numba (fallback to obj-mode)
+    if (
         # Numba does not support indexes with more than one dimension
+        any(idx["ndim"] > 1 for idx in adv_idxs)
         # Nor multiple vector indexes
-        (len(adv_idxs_dims) > 1 or adv_idxs_dims[0] > 1)
-        # The default index implementation does not handle duplicate indices correctly
+        or sum(idx["ndim"] > 0 for idx in adv_idxs) > 1
+        # The default PyTensor implementation does not handle duplicate indices correctly
         or (
             isinstance(op, AdvancedIncSubtensor)
             and not op.set_instead_of_inc
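A few concrete cases of the broadcasted_to helper added above, for illustration only (not part of the diff). Since broadcasted_to is a local function inside numba_funcify_AdvancedSubtensor, this sketch redefines it standalone; True means writing y into the indexed x buffer would require broadcasting y:

    def broadcasted_to(x_bcast, to_bcast):
        # Copy of the helper's body from the diff above
        if len(x_bcast) < len(to_bcast):
            return True
        for x_dim, to_dim in zip(x_bcast, to_bcast, strict=True):
            if x_dim and not to_dim:
                return True
        return False

    assert broadcasted_to((True, False), (False, False))       # dim 0 of y must be broadcast
    assert not broadcasted_to((False, False), (False, False))  # shapes line up, no broadcasting
    assert broadcasted_to((False,), (False, False))            # fewer dims than target implies broadcasting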
@@ -124,9 +181,91 @@ def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
     ):
         return generate_fallback_impl(op, node, **kwargs)

+    # What's left should all be supported natively by numba
     return numba_funcify_default_subtensor(op, node, **kwargs)


+def numba_funcify_multiple_vector_integer_indexing(
+    op: AdvancedSubtensor | AdvancedIncSubtensor, node, **kwargs
+):
+    # Special-case implementation for multiple consecutive vector integer indices (and set/incsubtensor)
+    if isinstance(op, AdvancedSubtensor):
+        y, idxs = None, node.inputs[1:]
+    else:
+        y, *idxs = node.inputs[1:]
+
+    first_axis = next(
+        i for i, idx in enumerate(idxs) if isinstance(idx.type, TensorType)
+    )
+    try:
+        after_last_axis = next(
+            i
+            for i, idx in enumerate(idxs[first_axis:], start=first_axis)
+            if not isinstance(idx.type, TensorType)
+        )
+    except StopIteration:
+        after_last_axis = len(idxs)
+
+    if isinstance(op, AdvancedSubtensor):
+
+        @numba_njit
+        def advanced_subtensor_multiple_vector(x, *idxs):
+            none_slices = idxs[:first_axis]
+            vec_idxs = idxs[first_axis:after_last_axis]
+
+            x_shape = x.shape
+            idx_shape = vec_idxs[0].shape
+            shape_bef = x_shape[:first_axis]
+            shape_aft = x_shape[after_last_axis:]
+            out_shape = (*shape_bef, *idx_shape, *shape_aft)
+            out_buffer = np.empty(out_shape, dtype=x.dtype)
+            for i, scalar_idxs in enumerate(zip(*vec_idxs)):  # noqa: B905
+                out_buffer[(*none_slices, i)] = x[(*none_slices, *scalar_idxs)]
+            return out_buffer
+
+        return advanced_subtensor_multiple_vector
+
+    elif op.set_instead_of_inc:
+        inplace = op.inplace
+
+        @numba_njit
+        def advanced_set_subtensor_multiple_vector(x, y, *idxs):
+            vec_idxs = idxs[first_axis:after_last_axis]
+            x_shape = x.shape
+
+            if inplace:
+                out = x
+            else:
+                out = x.copy()
+
+            for outer in np.ndindex(x_shape[:first_axis]):
+                for i, scalar_idxs in enumerate(zip(*vec_idxs)):  # noqa: B905
+                    out[(*outer, *scalar_idxs)] = y[(*outer, i)]
+            return out
+
+        return advanced_set_subtensor_multiple_vector
+
+    else:
+        inplace = op.inplace
+
+        @numba_njit
+        def advanced_inc_subtensor_multiple_vector(x, y, *idxs):
+            vec_idxs = idxs[first_axis:after_last_axis]
+            x_shape = x.shape
+
+            if inplace:
+                out = x
+            else:
+                out = x.copy()
+
+            for outer in np.ndindex(x_shape[:first_axis]):
+                for i, scalar_idxs in enumerate(zip(*vec_idxs)):  # noqa: B905
+                    out[(*outer, *scalar_idxs)] += y[(*outer, i)]
+            return out
+
+        return advanced_inc_subtensor_multiple_vector
+
+
 @numba_funcify.register(AdvancedIncSubtensor1)
 def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
     inplace = op.inplace
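As a sanity check on the gather loop in advanced_subtensor_multiple_vector, the same computation in plain NumPy, with made-up data, a full slice on axis 0, and vector indices on axes 1 and 2 (so first_axis=1, after_last_axis=3; illustration only, not from the commit):

    import numpy as np

    x = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
    v1 = np.array([0, 2, 1])
    v2 = np.array([3, 0, 2])

    expected = x[:, v1, v2]  # native NumPy advanced indexing
    manual = np.empty((2, 3, 5), dtype=x.dtype)
    for i, (a, b) in enumerate(zip(v1, v2)):
        manual[:, i] = x[:, a, b]  # one slice per zipped index pair

    np.testing.assert_array_equal(manual, expected)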

pytensor/tensor/subtensor.py

Lines changed: 25 additions & 0 deletions
@@ -2937,6 +2937,31 @@ def grad(self, inpt, output_gradients):
         gy = _sum_grad_over_bcasted_dims(y, gy)
         return [gx, gy] + [DisconnectedType()() for _ in idxs]

+    @staticmethod
+    def non_contiguous_adv_indexing(node: Apply) -> bool:
+        """
+        Check if the advanced indexing is non-contiguous (i.e. interrupted by basic indexing).
+
+        This function checks if the advanced indexing is non-contiguous,
+        in which case the advanced index dimensions are placed on the left of the
+        output array, regardless of their original position.
+
+        See: https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
+
+        Parameters
+        ----------
+        node : Apply
+            The node of the AdvancedSubtensor operation.
+
+        Returns
+        -------
+        bool
+            True if the advanced indexing is non-contiguous, False otherwise.
+        """
+        _, _, *idxs = node.inputs
+        return _non_contiguous_adv_indexing(idxs)
+

 advanced_inc_subtensor = AdvancedIncSubtensor()
 advanced_set_subtensor = AdvancedIncSubtensor(set_instead_of_inc=True)
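The NumPy rule the docstring refers to, illustrated (example mine, not from the commit):

    import numpy as np

    x = np.zeros((2, 3, 4, 5))
    v = np.array([0, 1])

    # Contiguous advanced indices: result dims stay where the indices were
    assert x[:, v, v, :].shape == (2, 2, 5)

    # Non-contiguous (interrupted by a slice): advanced dims move to the front
    assert x[v, :, v, :].shape == (2, 3, 5)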

tests/link/numba/test_basic.py

Lines changed: 12 additions & 5 deletions
@@ -228,9 +228,11 @@ def compare_numba_and_py(
     fgraph: FunctionGraph | tuple[Sequence["Variable"], Sequence["Variable"]],
     inputs: Sequence["TensorLike"],
     assert_fn: Callable | None = None,
+    *,
     numba_mode=numba_mode,
     py_mode=py_mode,
     updates=None,
+    inplace: bool = False,
     eval_obj_mode: bool = True,
 ) -> tuple[Callable, Any]:
     """Function to compare python graph output and Numba compiled output for testing equality
@@ -276,7 +278,14 @@ def assert_fn(x, y):
     pytensor_py_fn = function(
         fn_inputs, fn_outputs, mode=py_mode, accept_inplace=True, updates=updates
     )
-    py_res = pytensor_py_fn(*inputs)
+
+    test_inputs = (inp.copy() for inp in inputs) if inplace else inputs
+    py_res = pytensor_py_fn(*test_inputs)
+
+    # Get some coverage (and catch errors in python mode before unreadable numba ones)
+    if eval_obj_mode:
+        test_inputs = (inp.copy() for inp in inputs) if inplace else inputs
+        eval_python_only(fn_inputs, fn_outputs, test_inputs, mode=numba_mode)

     pytensor_numba_fn = function(
         fn_inputs,
@@ -285,11 +294,9 @@ def assert_fn(x, y):
         accept_inplace=True,
         updates=updates,
     )
-    numba_res = pytensor_numba_fn(*inputs)

-    # Get some coverage
-    if eval_obj_mode:
-        eval_python_only(fn_inputs, fn_outputs, inputs, mode=numba_mode)
+    test_inputs = (inp.copy() for inp in inputs) if inplace else inputs
+    numba_res = pytensor_numba_fn(*test_inputs)

     if len(fn_outputs) > 1:
         for j, p in zip(numba_res, py_res, strict=True):
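A sketch of how a test might exercise the new keyword (hypothetical graph; the variable names and shapes are made up, not from this commit). With inplace=True, compare_numba_and_py passes a fresh copy of the inputs to each evaluation (python, object mode, numba), so a destructive compiled function cannot contaminate the later runs:

    import numpy as np
    import pytensor.tensor as pt

    x = pt.matrix("x")
    idx = pt.lvector("idx")
    out = pt.set_subtensor(x[idx, idx], 0.0)  # two consecutive integer vector indices

    compare_numba_and_py(
        ([x, idx], [out]),
        [np.ones((5, 5)), np.array([0, 2])],
        inplace=True,  # inputs may be destroyed; give each run its own copy
    )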
