Skip to content

Commit 8fd718c

Browse files
committed
Do not use Numba objmode for supported AdvancedSubtensor operations
1 parent 36c55f5 commit 8fd718c

File tree

2 files changed

+59
-60
lines changed

2 files changed

+59
-60
lines changed

pytensor/link/numba/dispatch/basic.py

Lines changed: 30 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -479,6 +479,9 @@ def numba_funcify_FunctionGraph(
479479
)
480480

481481

482+
SET_OR_INC_OPS = IncSubtensor | AdvancedIncSubtensor1 | AdvancedIncSubtensor
483+
484+
482485
def create_index_func(node, objmode=False):
483486
"""Create a Python function that assembles and uses an index on an array."""
484487

@@ -501,14 +504,13 @@ def convert_indices(indices, entry):
501504
else:
502505
raise ValueError()
503506

504-
set_or_inc = isinstance(
505-
node.op, IncSubtensor | AdvancedIncSubtensor1 | AdvancedIncSubtensor
506-
)
507+
op = node.op
508+
set_or_inc = isinstance(op, SET_OR_INC_OPS) # type: ignore
507509
index_start_idx = 1 + int(set_or_inc)
508510

509511
input_names = [unique_names(v, force_unique=True) for v in node.inputs]
510512
op_indices = list(node.inputs[index_start_idx:])
511-
idx_list = getattr(node.op, "idx_list", None)
513+
idx_list = getattr(op, "idx_list", None)
512514

513515
indices_creation_src = (
514516
tuple(convert_indices(op_indices, idx) for idx in idx_list)
@@ -523,8 +525,7 @@ def convert_indices(indices, entry):
523525
indices_creation_src = f"indices = ({indices_creation_src})"
524526

525527
if set_or_inc:
526-
fn_name = "incsubtensor"
527-
if node.op.inplace:
528+
if op.inplace:
528529
index_prologue = f"z = {input_names[0]}"
529530
else:
530531
index_prologue = f"z = np.copy({input_names[0]})"
@@ -536,30 +537,17 @@ def convert_indices(indices, entry):
536537
else:
537538
y_name = input_names[1]
538539

539-
if node.op.set_instead_of_inc:
540+
if op.set_instead_of_inc:
541+
fn_name = "setsubtensor"
540542
index_body = f"z[indices] = {y_name}"
541543
else:
544+
fn_name = "incsubtensor"
542545
index_body = f"z[indices] += {y_name}"
543546
else:
544547
fn_name = "subtensor"
545548
index_prologue = ""
546549
index_body = f"z = {input_names[0]}[indices]"
547550

548-
if objmode:
549-
output_var = node.outputs[0]
550-
551-
if not set_or_inc:
552-
# Since `z` is being "created" while in object mode, it's
553-
# considered an "outgoing" variable and needs to be manually typed
554-
output_sig = f"z='{output_var.dtype}[{', '.join([':'] * output_var.ndim)}]'"
555-
else:
556-
output_sig = ""
557-
558-
index_body = f"""
559-
with objmode({output_sig}):
560-
{index_body}
561-
"""
562-
563551
subtensor_def_src = f"""
564552
def {fn_name}({", ".join(input_names)}):
565553
{index_prologue}
@@ -572,48 +560,34 @@ def {fn_name}({", ".join(input_names)}):
572560

573561

574562
@numba_funcify.register(Subtensor)
563+
@numba_funcify.register(IncSubtensor)
575564
@numba_funcify.register(AdvancedSubtensor1)
576-
def numba_funcify_Subtensor(op, node, **kwargs):
577-
objmode = isinstance(op, AdvancedSubtensor)
578-
if objmode:
579-
warnings.warn(
580-
("Numba will use object mode to allow run " "AdvancedSubtensor."),
581-
UserWarning,
582-
)
583-
584-
subtensor_def_src = create_index_func(node, objmode=objmode)
585-
586-
global_env = {"np": np}
587-
if objmode:
588-
global_env["objmode"] = numba.objmode
589-
565+
def numba_funcify_default_subtensor(op, node, **kwargs):
566+
function_name = "subtensor"
567+
if isinstance(op, SET_OR_INC_OPS): # type: ignore
568+
function_name = "setsubtensor" if op.set_instead_of_inc else "incsubtensor"
569+
subtensor_def_src = create_index_func(node)
590570
subtensor_fn = compile_function_src(
591-
subtensor_def_src, "subtensor", {**globals(), **global_env}
571+
subtensor_def_src, function_name, {**globals(), "np": np}
592572
)
593-
594573
return numba_njit(subtensor_fn, boundscheck=True)
595574

596575

597-
@numba_funcify.register(IncSubtensor)
598-
def numba_funcify_IncSubtensor(op, node, **kwargs):
599-
objmode = isinstance(op, AdvancedIncSubtensor)
600-
if objmode:
601-
warnings.warn(
602-
("Numba will use object mode to allow run " "AdvancedIncSubtensor."),
603-
UserWarning,
604-
)
605-
606-
incsubtensor_def_src = create_index_func(node, objmode=objmode)
607-
608-
global_env = {"np": np}
609-
if objmode:
610-
global_env["objmode"] = numba.objmode
576+
@numba_funcify.register(AdvancedSubtensor)
577+
def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
578+
idxs = node.inputs[1:]
579+
adv_idxs_dims = [
580+
idx.type.ndim
581+
for idx in idxs
582+
if (isinstance(idx.type, TensorType) and idx.type.ndim > 0)
583+
]
611584

612-
incsubtensor_fn = compile_function_src(
613-
incsubtensor_def_src, "incsubtensor", {**globals(), **global_env}
614-
)
585+
# Numba does not support indexes with more than one dimension
586+
# Nor multiple vector indexes
587+
if len(adv_idxs_dims) > 1 or adv_idxs_dims[0] > 1:
588+
return generate_fallback_impl(op, node, **kwargs)
615589

616-
return numba_njit(incsubtensor_fn, boundscheck=True)
590+
return numba_funcify_default_subtensor(op, node, **kwargs)
617591

618592

619593
@numba_funcify.register(AdvancedIncSubtensor1)

tests/link/numba/test_basic.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -409,21 +409,46 @@ def test_AdvancedSubtensor1_out_of_bounds():
409409

410410

411411
@pytest.mark.parametrize(
412-
"x, indices",
412+
"x, indices, objmode_needed",
413413
[
414-
(pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2], [2, 3])),
414+
(
415+
pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
416+
(0, [1, 2, 2, 3]),
417+
False,
418+
),
419+
(
420+
pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
421+
(np.array([True, False, False])),
422+
False,
423+
),
424+
(
425+
pt.as_tensor(np.arange(3 * 3).reshape((3, 3))),
426+
(np.eye(3).astype(bool)),
427+
True,
428+
),
429+
(pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2], [2, 3]), True),
415430
(
416431
pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
417432
([1, 2], slice(None), [3, 4]),
433+
True,
418434
),
419435
],
420436
)
421-
def test_AdvancedSubtensor(x, indices):
437+
@pytest.mark.filterwarnings("error")
438+
def test_AdvancedSubtensor(x, indices, objmode_needed):
422439
"""Test NumPy's advanced indexing in more than one dimension."""
423440
out_pt = x[indices]
424441
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor)
425442
out_fg = FunctionGraph([], [out_pt])
426-
compare_numba_and_py(out_fg, [])
443+
with (
444+
pytest.warns(
445+
UserWarning,
446+
match="Numba will use object mode to run AdvancedSubtensor's perform method",
447+
)
448+
if objmode_needed
449+
else contextlib.nullcontext()
450+
):
451+
compare_numba_and_py(out_fg, [])
427452

428453

429454
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)