Skip to content

Commit 959b7f6

Browse files
committed
Do not use Numba objmode for supported AdvancedSubtensor operations
1 parent 36c55f5 commit 959b7f6

File tree

2 files changed

+53
-56
lines changed

2 files changed

+53
-56
lines changed

pytensor/link/numba/dispatch/basic.py

Lines changed: 24 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -479,6 +479,9 @@ def numba_funcify_FunctionGraph(
479479
)
480480

481481

482+
SET_OR_INC_OPS = IncSubtensor | AdvancedIncSubtensor1 | AdvancedIncSubtensor
483+
484+
482485
def create_index_func(node, objmode=False):
483486
"""Create a Python function that assembles and uses an index on an array."""
484487

@@ -501,9 +504,7 @@ def convert_indices(indices, entry):
501504
else:
502505
raise ValueError()
503506

504-
set_or_inc = isinstance(
505-
node.op, IncSubtensor | AdvancedIncSubtensor1 | AdvancedIncSubtensor
506-
)
507+
set_or_inc = isinstance(node.op, SET_OR_INC_OPS)
507508
index_start_idx = 1 + int(set_or_inc)
508509

509510
input_names = [unique_names(v, force_unique=True) for v in node.inputs]
@@ -545,21 +546,6 @@ def convert_indices(indices, entry):
545546
index_prologue = ""
546547
index_body = f"z = {input_names[0]}[indices]"
547548

548-
if objmode:
549-
output_var = node.outputs[0]
550-
551-
if not set_or_inc:
552-
# Since `z` is being "created" while in object mode, it's
553-
# considered an "outgoing" variable and needs to be manually typed
554-
output_sig = f"z='{output_var.dtype}[{', '.join([':'] * output_var.ndim)}]'"
555-
else:
556-
output_sig = ""
557-
558-
index_body = f"""
559-
with objmode({output_sig}):
560-
{index_body}
561-
"""
562-
563549
subtensor_def_src = f"""
564550
def {fn_name}({", ".join(input_names)}):
565551
{index_prologue}
@@ -572,48 +558,34 @@ def {fn_name}({", ".join(input_names)}):
572558

573559

574560
@numba_funcify.register(Subtensor)
561+
@numba_funcify.register(IncSubtensor)
575562
@numba_funcify.register(AdvancedSubtensor1)
576-
def numba_funcify_Subtensor(op, node, **kwargs):
577-
objmode = isinstance(op, AdvancedSubtensor)
578-
if objmode:
579-
warnings.warn(
580-
("Numba will use object mode to allow run " "AdvancedSubtensor."),
581-
UserWarning,
582-
)
583-
584-
subtensor_def_src = create_index_func(node, objmode=objmode)
585-
586-
global_env = {"np": np}
587-
if objmode:
588-
global_env["objmode"] = numba.objmode
589-
563+
def numba_funcify_default_subtensor(op, node, **kwargs):
564+
function_name = "subtensor"
565+
if isinstance(op, SET_OR_INC_OPS):
566+
function_name = "setsubtensor" if op.set_instead_of_inc else "incsubtensor"
567+
subtensor_def_src = create_index_func(node)
590568
subtensor_fn = compile_function_src(
591-
subtensor_def_src, "subtensor", {**globals(), **global_env}
569+
subtensor_def_src, function_name, {**globals(), "np": np}
592570
)
593-
594571
return numba_njit(subtensor_fn, boundscheck=True)
595572

596573

597-
@numba_funcify.register(IncSubtensor)
598-
def numba_funcify_IncSubtensor(op, node, **kwargs):
599-
objmode = isinstance(op, AdvancedIncSubtensor)
600-
if objmode:
601-
warnings.warn(
602-
("Numba will use object mode to allow run " "AdvancedIncSubtensor."),
603-
UserWarning,
604-
)
605-
606-
incsubtensor_def_src = create_index_func(node, objmode=objmode)
607-
608-
global_env = {"np": np}
609-
if objmode:
610-
global_env["objmode"] = numba.objmode
574+
@numba_funcify.register(AdvancedSubtensor)
575+
def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
576+
idxs = node.inputs[1:]
577+
adv_idxs_dims = [
578+
idx.type.ndim
579+
for idx in idxs
580+
if (isinstance(idx.type, TensorType) and idx.type.ndim > 0)
581+
]
611582

612-
incsubtensor_fn = compile_function_src(
613-
incsubtensor_def_src, "incsubtensor", {**globals(), **global_env}
614-
)
583+
# Numba does not support indexes with more than one dimension
584+
# Nor multiple vector indexes
585+
if len(adv_idxs_dims) > 1 or adv_idxs_dims[0] > 1:
586+
return generate_fallback_impl(op, node, **kwargs)
615587

616-
return numba_njit(incsubtensor_fn, boundscheck=True)
588+
return numba_funcify_default_subtensor(op, node, **kwargs)
617589

618590

619591
@numba_funcify.register(AdvancedIncSubtensor1)

tests/link/numba/test_basic.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -409,21 +409,46 @@ def test_AdvancedSubtensor1_out_of_bounds():
409409

410410

411411
@pytest.mark.parametrize(
412-
"x, indices",
412+
"x, indices, objmode_needed",
413413
[
414-
(pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2], [2, 3])),
414+
(
415+
pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
416+
(0, [1, 2, 2, 3]),
417+
False,
418+
),
419+
(
420+
pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
421+
(np.array([True, False, False])),
422+
False,
423+
),
424+
(
425+
pt.as_tensor(np.arange(3 * 3).reshape((3, 3))),
426+
(np.eye(3).astype(bool)),
427+
True,
428+
),
429+
(pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2], [2, 3]), True),
415430
(
416431
pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
417432
([1, 2], slice(None), [3, 4]),
433+
True,
418434
),
419435
],
420436
)
421-
def test_AdvancedSubtensor(x, indices):
437+
@pytest.mark.filterwarnings("error")
438+
def test_AdvancedSubtensor(x, indices, objmode_needed):
422439
"""Test NumPy's advanced indexing in more than one dimension."""
423440
out_pt = x[indices]
424441
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor)
425442
out_fg = FunctionGraph([], [out_pt])
426-
compare_numba_and_py(out_fg, [])
443+
with (
444+
pytest.warns(
445+
UserWarning,
446+
match="Numba will use object mode to run AdvancedSubtensor's perform method",
447+
)
448+
if objmode_needed
449+
else contextlib.nullcontext()
450+
):
451+
compare_numba_and_py(out_fg, [])
427452

428453

429454
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)