Skip to content

Commit 37ef829

Browse files
committed
Support consecutive vector indices in Numba backend
1 parent ae66e82 commit 37ef829

File tree

4 files changed

+298
-67
lines changed

4 files changed

+298
-67
lines changed

pytensor/link/numba/dispatch/subtensor.py

Lines changed: 143 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from pytensor.link.numba.dispatch.basic import generate_fallback_impl, numba_njit
66
from pytensor.link.utils import compile_function_src, unique_name_generator
77
from pytensor.tensor import TensorType
8+
from pytensor.tensor.rewriting.subtensor import is_full_slice
89
from pytensor.tensor.subtensor import (
910
AdvancedIncSubtensor,
1011
AdvancedIncSubtensor1,
@@ -13,6 +14,7 @@
1314
IncSubtensor,
1415
Subtensor,
1516
)
17+
from pytensor.tensor.type_other import NoneTypeT, SliceType
1618

1719

1820
@numba_funcify.register(Subtensor)
@@ -104,18 +106,72 @@ def {function_name}({", ".join(input_names)}):
104106
@numba_funcify.register(AdvancedSubtensor)
105107
@numba_funcify.register(AdvancedIncSubtensor)
106108
def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
107-
idxs = node.inputs[1:] if isinstance(op, AdvancedSubtensor) else node.inputs[2:]
108-
adv_idxs_dims = [
109-
idx.type.ndim
109+
if isinstance(op, AdvancedSubtensor):
110+
x, y, idxs = node.inputs[0], None, node.inputs[1:]
111+
else:
112+
x, y, *idxs = node.inputs
113+
114+
basic_idxs = [
115+
idx
110116
for idx in idxs
111-
if (isinstance(idx.type, TensorType) and idx.type.ndim > 0)
117+
if (
118+
isinstance(idx.type, NoneTypeT)
119+
or (isinstance(idx.type, SliceType) and not is_full_slice(idx))
120+
)
112121
]
122+
adv_idxs = [
123+
{
124+
"axis": i,
125+
"dtype": idx.type.dtype,
126+
"bcast": idx.type.broadcastable,
127+
"ndim": idx.type.ndim,
128+
}
129+
for i, idx in enumerate(idxs)
130+
if isinstance(idx.type, TensorType)
131+
]
132+
133+
# Special case for consecutive vector indices
134+
def broadcasted_to(x_bcast: tuple[bool, ...], to_bcast: tuple[bool, ...]):
135+
# Return True if x would be broadcasted to `to_bcast`, based on broadcastable info
136+
if len(x_bcast) < len(to_bcast):
137+
return True
138+
for x_bcast_dim, to_bcast_dim in zip(x_bcast, to_bcast, strict=True):
139+
if x_bcast_dim and not to_bcast_dim:
140+
return True
141+
return False
142+
143+
if (
144+
not basic_idxs
145+
and len(adv_idxs) >= 2
146+
# Must be integer vectors
147+
# Todo: we could allow shape=(1,) if this is the shape of x
148+
and all(
149+
(adv_idx["bcast"] == (False,) and adv_idx["dtype"] != "bool")
150+
for adv_idx in adv_idxs
151+
)
152+
# Must be consecutive
153+
and not op.non_contiguous_adv_indexing(node)
154+
# y in set/inc_subtensor cannot be broadcasted
155+
and (
156+
y is None
157+
or not broadcasted_to(
158+
y.type.broadcastable,
159+
(
160+
x.type.broadcastable[: adv_idxs[0]["axis"]]
161+
+ x.type.broadcastable[adv_idxs[-1]["axis"] :]
162+
),
163+
)
164+
)
165+
):
166+
return numba_funcify_multiple_vector_indexing(op, node, **kwargs)
113167

168+
# Cases natively supported by Numba
114169
if (
115170
# Numba does not support indexes with more than one dimension
171+
any(idx["ndim"] > 1 for idx in adv_idxs)
116172
# Nor multiple vector indexes
117-
(len(adv_idxs_dims) > 1 or adv_idxs_dims[0] > 1)
118-
# The default index implementation does not handle duplicate indices correctly
173+
or sum(idx["ndim"] > 0 for idx in adv_idxs) > 1
174+
# The default PyTensor implementation does not handle duplicate indices correctly
119175
or (
120176
isinstance(op, AdvancedIncSubtensor)
121177
and not op.set_instead_of_inc
@@ -127,6 +183,87 @@ def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
127183
return numba_funcify_default_subtensor(op, node, **kwargs)
128184

129185

186+
def numba_funcify_multiple_vector_indexing(
187+
op: AdvancedSubtensor | AdvancedIncSubtensor, node, **kwargs
188+
):
189+
# Special-case implementation for multiple consecutive vector indices (and set/incsubtensor)
190+
if isinstance(op, AdvancedSubtensor):
191+
y, idxs = None, node.inputs[1:]
192+
else:
193+
y, *idxs = node.inputs[1:]
194+
195+
first_axis = next(
196+
i for i, idx in enumerate(idxs) if isinstance(idx.type, TensorType)
197+
)
198+
try:
199+
after_last_axis = next(
200+
i
201+
for i, idx in enumerate(idxs[first_axis:], start=first_axis)
202+
if not isinstance(idx.type, TensorType)
203+
)
204+
except StopIteration:
205+
after_last_axis = len(idxs)
206+
207+
if isinstance(op, AdvancedSubtensor):
208+
209+
@numba_njit
210+
def advanced_subtensor_multiple_vector(x, *idxs):
211+
none_slices = idxs[:first_axis]
212+
vec_idxs = idxs[first_axis:after_last_axis]
213+
214+
x_shape = x.shape
215+
idx_shape = vec_idxs[0].shape
216+
shape_bef = x_shape[:first_axis]
217+
shape_aft = x_shape[after_last_axis:]
218+
out_shape = (*shape_bef, *idx_shape, *shape_aft)
219+
out_buffer = np.empty(out_shape, dtype=x.dtype)
220+
for i, scalar_idxs in enumerate(zip(*vec_idxs)): # noqa: B905
221+
out_buffer[(*none_slices, i)] = x[(*none_slices, *scalar_idxs)]
222+
return out_buffer
223+
224+
return advanced_subtensor_multiple_vector
225+
226+
elif op.set_instead_of_inc:
227+
inplace = op.inplace
228+
229+
@numba_njit
230+
def advanced_set_subtensor_multiple_vector(x, y, *idxs):
231+
vec_idxs = idxs[first_axis:after_last_axis]
232+
x_shape = x.shape
233+
234+
if inplace:
235+
out = x
236+
else:
237+
out = x.copy()
238+
239+
for outer in np.ndindex(x_shape[:first_axis]):
240+
for i, scalar_idxs in enumerate(zip(*vec_idxs)): # noqa: B905
241+
out[(*outer, *scalar_idxs)] = y[(*outer, i)]
242+
return out
243+
244+
return advanced_set_subtensor_multiple_vector
245+
246+
else:
247+
inplace = op.inplace
248+
249+
@numba_njit
250+
def advanced_inc_subtensor_multiple_vector(x, y, *idxs):
251+
vec_idxs = idxs[first_axis:after_last_axis]
252+
x_shape = x.shape
253+
254+
if inplace:
255+
out = x
256+
else:
257+
out = x.copy()
258+
259+
for outer in np.ndindex(x_shape[:first_axis]):
260+
for i, scalar_idxs in enumerate(zip(*vec_idxs)): # noqa: B905
261+
out[(*outer, *scalar_idxs)] += y[(*outer, i)]
262+
return out
263+
264+
return advanced_inc_subtensor_multiple_vector
265+
266+
130267
@numba_funcify.register(AdvancedIncSubtensor1)
131268
def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
132269
inplace = op.inplace

pytensor/tensor/subtensor.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2937,6 +2937,31 @@ def grad(self, inpt, output_gradients):
29372937
gy = _sum_grad_over_bcasted_dims(y, gy)
29382938
return [gx, gy] + [DisconnectedType()() for _ in idxs]
29392939

2940+
@staticmethod
2941+
def non_contiguous_adv_indexing(node: Apply) -> bool:
2942+
"""
2943+
Check if the advanced indexing is non-contiguous (i.e. interrupted by basic indexing).
2944+
2945+
This function checks if the advanced indexing is non-contiguous,
2946+
in which case the advanced index dimensions are placed on the left of the
2947+
output array, regardless of their original position.
2948+
2949+
See: https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
2950+
2951+
2952+
Parameters
2953+
----------
2954+
node : Apply
2955+
The node of the AdvancedSubtensor operation.
2956+
2957+
Returns
2958+
-------
2959+
bool
2960+
True if the advanced indexing is non-contiguous, False otherwise.
2961+
"""
2962+
_, _, *idxs = node.inputs
2963+
return _non_contiguous_adv_indexing(idxs)
2964+
29402965

29412966
advanced_inc_subtensor = AdvancedIncSubtensor()
29422967
advanced_set_subtensor = AdvancedIncSubtensor(set_instead_of_inc=True)

tests/link/numba/test_basic.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -228,9 +228,11 @@ def compare_numba_and_py(
228228
fgraph: FunctionGraph | tuple[Sequence["Variable"], Sequence["Variable"]],
229229
inputs: Sequence["TensorLike"],
230230
assert_fn: Callable | None = None,
231+
*,
231232
numba_mode=numba_mode,
232233
py_mode=py_mode,
233234
updates=None,
235+
inplace: bool = False,
234236
eval_obj_mode: bool = True,
235237
) -> tuple[Callable, Any]:
236238
"""Function to compare python graph output and Numba compiled output for testing equality
@@ -276,7 +278,14 @@ def assert_fn(x, y):
276278
pytensor_py_fn = function(
277279
fn_inputs, fn_outputs, mode=py_mode, accept_inplace=True, updates=updates
278280
)
279-
py_res = pytensor_py_fn(*inputs)
281+
282+
test_inputs = (inp.copy() for inp in inputs) if inplace else inputs
283+
py_res = pytensor_py_fn(*test_inputs)
284+
285+
# Get some coverage (and catch errors in python mode before unreadable numba ones)
286+
if eval_obj_mode:
287+
test_inputs = (inp.copy() for inp in inputs) if inplace else inputs
288+
eval_python_only(fn_inputs, fn_outputs, test_inputs, mode=numba_mode)
280289

281290
pytensor_numba_fn = function(
282291
fn_inputs,
@@ -285,11 +294,9 @@ def assert_fn(x, y):
285294
accept_inplace=True,
286295
updates=updates,
287296
)
288-
numba_res = pytensor_numba_fn(*inputs)
289297

290-
# Get some coverage
291-
if eval_obj_mode:
292-
eval_python_only(fn_inputs, fn_outputs, inputs, mode=numba_mode)
298+
test_inputs = (inp.copy() for inp in inputs) if inplace else inputs
299+
numba_res = pytensor_numba_fn(*test_inputs)
293300

294301
if len(fn_outputs) > 1:
295302
for j, p in zip(numba_res, py_res, strict=True):

0 commit comments

Comments
 (0)