Skip to content

Commit 9fe718a

Browse files
committed
Do not use Numba objmode for supported AdvancedSubtensor operations
MakeSlice uses scalar types
1 parent 36c55f5 commit 9fe718a

File tree

3 files changed

+145
-68
lines changed

3 files changed

+145
-68
lines changed

pytensor/link/numba/dispatch/basic.py

Lines changed: 40 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -479,6 +479,9 @@ def numba_funcify_FunctionGraph(
479479
)
480480

481481

482+
SET_OR_INC_OPS = IncSubtensor | AdvancedIncSubtensor1 | AdvancedIncSubtensor
483+
484+
482485
def create_index_func(node, objmode=False):
483486
"""Create a Python function that assembles and uses an index on an array."""
484487

@@ -501,14 +504,13 @@ def convert_indices(indices, entry):
501504
else:
502505
raise ValueError()
503506

504-
set_or_inc = isinstance(
505-
node.op, IncSubtensor | AdvancedIncSubtensor1 | AdvancedIncSubtensor
506-
)
507+
op = node.op
508+
set_or_inc = isinstance(op, SET_OR_INC_OPS) # type: ignore
507509
index_start_idx = 1 + int(set_or_inc)
508510

509511
input_names = [unique_names(v, force_unique=True) for v in node.inputs]
510512
op_indices = list(node.inputs[index_start_idx:])
511-
idx_list = getattr(node.op, "idx_list", None)
513+
idx_list = getattr(op, "idx_list", None)
512514

513515
indices_creation_src = (
514516
tuple(convert_indices(op_indices, idx) for idx in idx_list)
@@ -523,8 +525,7 @@ def convert_indices(indices, entry):
523525
indices_creation_src = f"indices = ({indices_creation_src})"
524526

525527
if set_or_inc:
526-
fn_name = "incsubtensor"
527-
if node.op.inplace:
528+
if op.inplace:
528529
index_prologue = f"z = {input_names[0]}"
529530
else:
530531
index_prologue = f"z = np.copy({input_names[0]})"
@@ -536,30 +537,17 @@ def convert_indices(indices, entry):
536537
else:
537538
y_name = input_names[1]
538539

539-
if node.op.set_instead_of_inc:
540+
if op.set_instead_of_inc:
541+
fn_name = "setsubtensor"
540542
index_body = f"z[indices] = {y_name}"
541543
else:
544+
fn_name = "incsubtensor"
542545
index_body = f"z[indices] += {y_name}"
543546
else:
544547
fn_name = "subtensor"
545548
index_prologue = ""
546549
index_body = f"z = {input_names[0]}[indices]"
547550

548-
if objmode:
549-
output_var = node.outputs[0]
550-
551-
if not set_or_inc:
552-
# Since `z` is being "created" while in object mode, it's
553-
# considered an "outgoing" variable and needs to be manually typed
554-
output_sig = f"z='{output_var.dtype}[{', '.join([':'] * output_var.ndim)}]'"
555-
else:
556-
output_sig = ""
557-
558-
index_body = f"""
559-
with objmode({output_sig}):
560-
{index_body}
561-
"""
562-
563551
subtensor_def_src = f"""
564552
def {fn_name}({", ".join(input_names)}):
565553
{index_prologue}
@@ -572,48 +560,43 @@ def {fn_name}({", ".join(input_names)}):
572560

573561

574562
@numba_funcify.register(Subtensor)
563+
@numba_funcify.register(IncSubtensor)
575564
@numba_funcify.register(AdvancedSubtensor1)
576-
def numba_funcify_Subtensor(op, node, **kwargs):
577-
objmode = isinstance(op, AdvancedSubtensor)
578-
if objmode:
579-
warnings.warn(
580-
("Numba will use object mode to allow run " "AdvancedSubtensor."),
581-
UserWarning,
582-
)
583-
584-
subtensor_def_src = create_index_func(node, objmode=objmode)
585-
586-
global_env = {"np": np}
587-
if objmode:
588-
global_env["objmode"] = numba.objmode
589-
565+
def numba_funcify_default_subtensor(op, node, **kwargs):
566+
function_name = "subtensor"
567+
if isinstance(op, SET_OR_INC_OPS): # type: ignore
568+
function_name = "setsubtensor" if op.set_instead_of_inc else "incsubtensor"
569+
subtensor_def_src = create_index_func(node)
590570
subtensor_fn = compile_function_src(
591-
subtensor_def_src, "subtensor", {**globals(), **global_env}
571+
subtensor_def_src, function_name, {**globals(), "np": np}
592572
)
593-
594573
return numba_njit(subtensor_fn, boundscheck=True)
595574

596575

597-
@numba_funcify.register(IncSubtensor)
598-
def numba_funcify_IncSubtensor(op, node, **kwargs):
599-
objmode = isinstance(op, AdvancedIncSubtensor)
600-
if objmode:
601-
warnings.warn(
602-
("Numba will use object mode to allow run " "AdvancedIncSubtensor."),
603-
UserWarning,
604-
)
605-
606-
incsubtensor_def_src = create_index_func(node, objmode=objmode)
607-
608-
global_env = {"np": np}
609-
if objmode:
610-
global_env["objmode"] = numba.objmode
576+
@numba_funcify.register(AdvancedSubtensor)
577+
@numba_funcify.register(AdvancedIncSubtensor)
578+
def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
579+
idxs = node.inputs[1:] if isinstance(op, AdvancedSubtensor) else node.inputs[2:]
580+
adv_idxs_dims = [
581+
idx.type.ndim
582+
for idx in idxs
583+
if (isinstance(idx.type, TensorType) and idx.type.ndim > 0)
584+
]
611585

612-
incsubtensor_fn = compile_function_src(
613-
incsubtensor_def_src, "incsubtensor", {**globals(), **global_env}
614-
)
586+
if (
587+
# Numba does not support indexes with more than one dimension
588+
# Nor multiple vector indexes
589+
(len(adv_idxs_dims) > 1 or adv_idxs_dims[0] > 1)
590+
# The default index implementation does not handle duplicate indices correctly
591+
or (
592+
isinstance(op, AdvancedIncSubtensor)
593+
and not op.set_instead_of_inc
594+
and not op.ignore_duplicates
595+
)
596+
):
597+
return generate_fallback_impl(op, node, **kwargs)
615598

616-
return numba_njit(incsubtensor_fn, boundscheck=True)
599+
return numba_funcify_default_subtensor(op, node, **kwargs)
617600

618601

619602
@numba_funcify.register(AdvancedIncSubtensor1)
@@ -713,7 +696,7 @@ def makeslice(*x):
713696

714697

715698
@numba_funcify.register(MakeSlice)
716-
def numba_funcify_MakeSlice(op, **kwargs):
699+
def numba_funcify_MakeSlice(op, node, **kwargs):
717700
return global_numba_func(makeslice)
718701

719702

pytensor/tensor/type_other.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def as_int_none_variable(x):
1818
return NoneConst
1919
elif NoneConst.equals(x):
2020
return x
21-
x = pytensor.tensor.as_tensor_variable(x, ndim=0)
21+
x = pytensor.scalar.as_scalar(x)
2222
if x.type.dtype not in integer_dtypes:
2323
raise TypeError("index must be integers")
2424
return x

tests/link/numba/test_basic.py

Lines changed: 104 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -409,21 +409,46 @@ def test_AdvancedSubtensor1_out_of_bounds():
409409

410410

411411
@pytest.mark.parametrize(
412-
"x, indices",
412+
"x, indices, objmode_needed",
413413
[
414-
(pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2], [2, 3])),
414+
(
415+
pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
416+
(0, [1, 2, 2, 3]),
417+
False,
418+
),
419+
(
420+
pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
421+
(np.array([True, False, False])),
422+
False,
423+
),
424+
(
425+
pt.as_tensor(np.arange(3 * 3).reshape((3, 3))),
426+
(np.eye(3).astype(bool)),
427+
True,
428+
),
429+
(pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2], [2, 3]), True),
415430
(
416431
pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
417432
([1, 2], slice(None), [3, 4]),
433+
True,
418434
),
419435
],
420436
)
421-
def test_AdvancedSubtensor(x, indices):
437+
@pytest.mark.filterwarnings("error")
438+
def test_AdvancedSubtensor(x, indices, objmode_needed):
422439
"""Test NumPy's advanced indexing in more than one dimension."""
423440
out_pt = x[indices]
424441
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor)
425442
out_fg = FunctionGraph([], [out_pt])
426-
compare_numba_and_py(out_fg, [])
443+
with (
444+
pytest.warns(
445+
UserWarning,
446+
match="Numba will use object mode to run AdvancedSubtensor's perform method",
447+
)
448+
if objmode_needed
449+
else contextlib.nullcontext()
450+
):
451+
compare_numba_and_py(out_fg, [])
427452

428453

429454
@pytest.mark.parametrize(
@@ -534,35 +559,96 @@ def test_AdvancedIncSubtensor1(x, y, indices):
534559

535560

536561
@pytest.mark.parametrize(
537-
"x, y, indices",
562+
"x, y, indices, duplicate_indices, set_requires_objmode, inc_requires_objmode",
538563
[
564+
(
565+
pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
566+
-np.arange(3 * 5).reshape(3, 5),
567+
(slice(None, None, 2), [1, 2, 3]),
568+
False,
569+
False,
570+
False,
571+
),
572+
(
573+
pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
574+
-np.arange(4 * 5).reshape(4, 5),
575+
(0, [1, 2, 2, 3]),
576+
True,
577+
False,
578+
True,
579+
),
580+
(
581+
pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
582+
-np.arange(1 * 4 * 5).reshape(1, 4, 5),
583+
(np.array([True, False, False])),
584+
False,
585+
False,
586+
False,
587+
),
588+
(
589+
pt.as_tensor(np.arange(3 * 3).reshape((3, 3))),
590+
-np.arange(3),
591+
(np.eye(3).astype(bool)),
592+
False,
593+
True,
594+
True,
595+
),
539596
(
540597
pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
541598
pt.as_tensor(rng.poisson(size=(2, 5))),
542599
([1, 2], [2, 3]),
600+
False,
601+
True,
602+
True,
543603
),
544604
(
545605
pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
546606
pt.as_tensor(rng.poisson(size=(2, 4))),
547607
([1, 2], slice(None), [3, 4]),
608+
False,
609+
True,
610+
True,
548611
),
549612
pytest.param(
550613
pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
551614
pt.as_tensor(rng.poisson(size=(2, 5))),
552615
([1, 1], [2, 2]),
616+
False,
617+
True,
618+
True,
553619
),
554620
],
555621
)
556-
def test_AdvancedIncSubtensor(x, y, indices):
622+
@pytest.mark.filterwarnings("error")
623+
def test_AdvancedIncSubtensor(
624+
x, y, indices, duplicate_indices, set_requires_objmode, inc_requires_objmode
625+
):
557626
out_pt = pt.set_subtensor(x[indices], y)
558627
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
559628
out_fg = FunctionGraph([], [out_pt])
560-
compare_numba_and_py(out_fg, [])
561629

562-
out_pt = pt.inc_subtensor(x[indices], y)
630+
with (
631+
pytest.warns(
632+
UserWarning,
633+
match="Numba will use object mode to run AdvancedSetSubtensor's perform method",
634+
)
635+
if set_requires_objmode
636+
else contextlib.nullcontext()
637+
):
638+
compare_numba_and_py(out_fg, [])
639+
640+
out_pt = pt.inc_subtensor(x[indices], y, ignore_duplicates=not duplicate_indices)
563641
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
564642
out_fg = FunctionGraph([], [out_pt])
565-
compare_numba_and_py(out_fg, [])
643+
with (
644+
pytest.warns(
645+
UserWarning,
646+
match="Numba will use object mode to run AdvancedIncSubtensor's perform method",
647+
)
648+
if inc_requires_objmode
649+
else contextlib.nullcontext()
650+
):
651+
compare_numba_and_py(out_fg, [])
566652

567653
x_pt = x.type()
568654
out_pt = pt.set_subtensor(x_pt[indices], y)
@@ -571,7 +657,15 @@ def test_AdvancedIncSubtensor(x, y, indices):
571657
out_pt.owner.op.inplace = True
572658
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
573659
out_fg = FunctionGraph([x_pt], [out_pt])
574-
compare_numba_and_py(out_fg, [x.data])
660+
with (
661+
pytest.warns(
662+
UserWarning,
663+
match="Numba will use object mode to run AdvancedSetSubtensor's perform method",
664+
)
665+
if set_requires_objmode
666+
else contextlib.nullcontext()
667+
):
668+
compare_numba_and_py(out_fg, [x.data])
575669

576670

577671
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)