Skip to content

Commit c0e2255

Browse files
committed
Do not use Numba objmode for supported AdvancedSubtensor operations
Use ScalarTypes in MakeSlice for compatibility with Numba
1 parent 36c55f5 commit c0e2255

File tree

4 files changed

+163
-80
lines changed

4 files changed

+163
-80
lines changed

pytensor/link/numba/dispatch/basic.py

Lines changed: 42 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -479,7 +479,13 @@ def numba_funcify_FunctionGraph(
479479
)
480480

481481

482-
def create_index_func(node, objmode=False):
482+
SET_OR_INC_OPS = IncSubtensor | AdvancedIncSubtensor1 | AdvancedIncSubtensor
483+
484+
485+
@numba_funcify.register(Subtensor)
486+
@numba_funcify.register(IncSubtensor)
487+
@numba_funcify.register(AdvancedSubtensor1)
488+
def numba_funcify_default_subtensor(op, node, **kwargs):
483489
"""Create a Python function that assembles and uses an index on an array."""
484490

485491
unique_names = unique_name_generator(
@@ -501,14 +507,12 @@ def convert_indices(indices, entry):
501507
else:
502508
raise ValueError()
503509

504-
set_or_inc = isinstance(
505-
node.op, IncSubtensor | AdvancedIncSubtensor1 | AdvancedIncSubtensor
506-
)
510+
set_or_inc = isinstance(op, SET_OR_INC_OPS) # type: ignore
507511
index_start_idx = 1 + int(set_or_inc)
508512

509513
input_names = [unique_names(v, force_unique=True) for v in node.inputs]
510514
op_indices = list(node.inputs[index_start_idx:])
511-
idx_list = getattr(node.op, "idx_list", None)
515+
idx_list = getattr(op, "idx_list", None)
512516

513517
indices_creation_src = (
514518
tuple(convert_indices(op_indices, idx) for idx in idx_list)
@@ -523,8 +527,7 @@ def convert_indices(indices, entry):
523527
indices_creation_src = f"indices = ({indices_creation_src})"
524528

525529
if set_or_inc:
526-
fn_name = "incsubtensor"
527-
if node.op.inplace:
530+
if op.inplace:
528531
index_prologue = f"z = {input_names[0]}"
529532
else:
530533
index_prologue = f"z = np.copy({input_names[0]})"
@@ -536,84 +539,57 @@ def convert_indices(indices, entry):
536539
else:
537540
y_name = input_names[1]
538541

539-
if node.op.set_instead_of_inc:
542+
if op.set_instead_of_inc:
543+
function_name = "setsubtensor"
540544
index_body = f"z[indices] = {y_name}"
541545
else:
546+
function_name = "incsubtensor"
542547
index_body = f"z[indices] += {y_name}"
543548
else:
544-
fn_name = "subtensor"
549+
function_name = "subtensor"
545550
index_prologue = ""
546551
index_body = f"z = {input_names[0]}[indices]"
547552

548-
if objmode:
549-
output_var = node.outputs[0]
550-
551-
if not set_or_inc:
552-
# Since `z` is being "created" while in object mode, it's
553-
# considered an "outgoing" variable and needs to be manually typed
554-
output_sig = f"z='{output_var.dtype}[{', '.join([':'] * output_var.ndim)}]'"
555-
else:
556-
output_sig = ""
557-
558-
index_body = f"""
559-
with objmode({output_sig}):
560-
{index_body}
561-
"""
562-
563553
subtensor_def_src = f"""
564-
def {fn_name}({", ".join(input_names)}):
554+
def {function_name}({", ".join(input_names)}):
565555
{index_prologue}
566556
{indices_creation_src}
567557
{index_body}
568558
return np.asarray(z)
569559
"""
570560

571-
return subtensor_def_src
572-
573-
574-
@numba_funcify.register(Subtensor)
575-
@numba_funcify.register(AdvancedSubtensor1)
576-
def numba_funcify_Subtensor(op, node, **kwargs):
577-
objmode = isinstance(op, AdvancedSubtensor)
578-
if objmode:
579-
warnings.warn(
580-
("Numba will use object mode to allow run " "AdvancedSubtensor."),
581-
UserWarning,
582-
)
583-
584-
subtensor_def_src = create_index_func(node, objmode=objmode)
585-
586-
global_env = {"np": np}
587-
if objmode:
588-
global_env["objmode"] = numba.objmode
589-
590-
subtensor_fn = compile_function_src(
591-
subtensor_def_src, "subtensor", {**globals(), **global_env}
561+
func = compile_function_src(
562+
subtensor_def_src,
563+
function_name=function_name,
564+
global_env=globals() | {"np": np},
592565
)
566+
return numba_njit(func, boundscheck=True)
593567

594-
return numba_njit(subtensor_fn, boundscheck=True)
595568

569+
@numba_funcify.register(AdvancedSubtensor)
570+
@numba_funcify.register(AdvancedIncSubtensor)
571+
def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
572+
idxs = node.inputs[1:] if isinstance(op, AdvancedSubtensor) else node.inputs[2:]
573+
adv_idxs_dims = [
574+
idx.type.ndim
575+
for idx in idxs
576+
if (isinstance(idx.type, TensorType) and idx.type.ndim > 0)
577+
]
596578

597-
@numba_funcify.register(IncSubtensor)
598-
def numba_funcify_IncSubtensor(op, node, **kwargs):
599-
objmode = isinstance(op, AdvancedIncSubtensor)
600-
if objmode:
601-
warnings.warn(
602-
("Numba will use object mode to allow run " "AdvancedIncSubtensor."),
603-
UserWarning,
579+
if (
580+
# Numba does not support indexes with more than one dimension
581+
# Nor multiple vector indexes
582+
(len(adv_idxs_dims) > 1 or adv_idxs_dims[0] > 1)
583+
# The default index implementation does not handle duplicate indices correctly
584+
or (
585+
isinstance(op, AdvancedIncSubtensor)
586+
and not op.set_instead_of_inc
587+
and not op.ignore_duplicates
604588
)
589+
):
590+
return generate_fallback_impl(op, node, **kwargs)
605591

606-
incsubtensor_def_src = create_index_func(node, objmode=objmode)
607-
608-
global_env = {"np": np}
609-
if objmode:
610-
global_env["objmode"] = numba.objmode
611-
612-
incsubtensor_fn = compile_function_src(
613-
incsubtensor_def_src, "incsubtensor", {**globals(), **global_env}
614-
)
615-
616-
return numba_njit(incsubtensor_fn, boundscheck=True)
592+
return numba_funcify_default_subtensor(op, node, **kwargs)
617593

618594

619595
@numba_funcify.register(AdvancedIncSubtensor1)
@@ -713,7 +689,7 @@ def makeslice(*x):
713689

714690

715691
@numba_funcify.register(MakeSlice)
716-
def numba_funcify_MakeSlice(op, **kwargs):
692+
def numba_funcify_MakeSlice(op, node, **kwargs):
717693
return global_numba_func(makeslice)
718694

719695

pytensor/tensor/subtensor.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,12 @@
2121
from pytensor.printing import Printer, pprint, set_precedence
2222
from pytensor.scalar.basic import ScalarConstant
2323
from pytensor.tensor import _get_vector_length, as_tensor_variable, get_vector_length
24-
from pytensor.tensor.basic import alloc, get_underlying_scalar_constant_value, nonzero
24+
from pytensor.tensor.basic import (
25+
ScalarFromTensor,
26+
alloc,
27+
get_underlying_scalar_constant_value,
28+
nonzero,
29+
)
2530
from pytensor.tensor.blockwise import vectorize_node_fallback
2631
from pytensor.tensor.elemwise import DimShuffle
2732
from pytensor.tensor.exceptions import AdvancedIndexingError, NotScalarConstantError
@@ -168,8 +173,16 @@ def as_index_literal(
168173
if isinstance(idx, Constant):
169174
return idx.data.item() if isinstance(idx, np.ndarray) else idx.data
170175

171-
if isinstance(getattr(idx, "type", None), SliceType):
172-
idx = slice(*idx.owner.inputs)
176+
if isinstance(idx, Variable):
177+
if (
178+
isinstance(idx.type, ps.ScalarType)
179+
and idx.owner
180+
and isinstance(idx.owner.op, ScalarFromTensor)
181+
):
182+
return as_index_literal(idx.owner.inputs[0])
183+
184+
if isinstance(idx.type, SliceType):
185+
idx = slice(*idx.owner.inputs)
173186

174187
if isinstance(idx, slice):
175188
return slice(

pytensor/tensor/type_other.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def as_int_none_variable(x):
1818
return NoneConst
1919
elif NoneConst.equals(x):
2020
return x
21-
x = pytensor.tensor.as_tensor_variable(x, ndim=0)
21+
x = pytensor.scalar.as_scalar(x)
2222
if x.type.dtype not in integer_dtypes:
2323
raise TypeError("index must be integers")
2424
return x

tests/link/numba/test_basic.py

Lines changed: 104 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -409,21 +409,46 @@ def test_AdvancedSubtensor1_out_of_bounds():
409409

410410

411411
@pytest.mark.parametrize(
412-
"x, indices",
412+
"x, indices, objmode_needed",
413413
[
414-
(pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2], [2, 3])),
414+
(
415+
pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
416+
(0, [1, 2, 2, 3]),
417+
False,
418+
),
419+
(
420+
pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
421+
(np.array([True, False, False])),
422+
False,
423+
),
424+
(
425+
pt.as_tensor(np.arange(3 * 3).reshape((3, 3))),
426+
(np.eye(3).astype(bool)),
427+
True,
428+
),
429+
(pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2], [2, 3]), True),
415430
(
416431
pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
417432
([1, 2], slice(None), [3, 4]),
433+
True,
418434
),
419435
],
420436
)
421-
def test_AdvancedSubtensor(x, indices):
437+
@pytest.mark.filterwarnings("error")
438+
def test_AdvancedSubtensor(x, indices, objmode_needed):
422439
"""Test NumPy's advanced indexing in more than one dimension."""
423440
out_pt = x[indices]
424441
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor)
425442
out_fg = FunctionGraph([], [out_pt])
426-
compare_numba_and_py(out_fg, [])
443+
with (
444+
pytest.warns(
445+
UserWarning,
446+
match="Numba will use object mode to run AdvancedSubtensor's perform method",
447+
)
448+
if objmode_needed
449+
else contextlib.nullcontext()
450+
):
451+
compare_numba_and_py(out_fg, [])
427452

428453

429454
@pytest.mark.parametrize(
@@ -534,35 +559,96 @@ def test_AdvancedIncSubtensor1(x, y, indices):
534559

535560

536561
@pytest.mark.parametrize(
537-
"x, y, indices",
562+
"x, y, indices, duplicate_indices, set_requires_objmode, inc_requires_objmode",
538563
[
564+
(
565+
pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
566+
-np.arange(3 * 5).reshape(3, 5),
567+
(slice(None, None, 2), [1, 2, 3]),
568+
False,
569+
False,
570+
False,
571+
),
572+
(
573+
pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
574+
-np.arange(4 * 5).reshape(4, 5),
575+
(0, [1, 2, 2, 3]),
576+
True,
577+
False,
578+
True,
579+
),
580+
(
581+
pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
582+
-np.arange(1 * 4 * 5).reshape(1, 4, 5),
583+
(np.array([True, False, False])),
584+
False,
585+
False,
586+
False,
587+
),
588+
(
589+
pt.as_tensor(np.arange(3 * 3).reshape((3, 3))),
590+
-np.arange(3),
591+
(np.eye(3).astype(bool)),
592+
False,
593+
True,
594+
True,
595+
),
539596
(
540597
pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
541598
pt.as_tensor(rng.poisson(size=(2, 5))),
542599
([1, 2], [2, 3]),
600+
False,
601+
True,
602+
True,
543603
),
544604
(
545605
pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
546606
pt.as_tensor(rng.poisson(size=(2, 4))),
547607
([1, 2], slice(None), [3, 4]),
608+
False,
609+
True,
610+
True,
548611
),
549612
pytest.param(
550613
pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
551614
pt.as_tensor(rng.poisson(size=(2, 5))),
552615
([1, 1], [2, 2]),
616+
False,
617+
True,
618+
True,
553619
),
554620
],
555621
)
556-
def test_AdvancedIncSubtensor(x, y, indices):
622+
@pytest.mark.filterwarnings("error")
623+
def test_AdvancedIncSubtensor(
624+
x, y, indices, duplicate_indices, set_requires_objmode, inc_requires_objmode
625+
):
557626
out_pt = pt.set_subtensor(x[indices], y)
558627
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
559628
out_fg = FunctionGraph([], [out_pt])
560-
compare_numba_and_py(out_fg, [])
561629

562-
out_pt = pt.inc_subtensor(x[indices], y)
630+
with (
631+
pytest.warns(
632+
UserWarning,
633+
match="Numba will use object mode to run AdvancedSetSubtensor's perform method",
634+
)
635+
if set_requires_objmode
636+
else contextlib.nullcontext()
637+
):
638+
compare_numba_and_py(out_fg, [])
639+
640+
out_pt = pt.inc_subtensor(x[indices], y, ignore_duplicates=not duplicate_indices)
563641
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
564642
out_fg = FunctionGraph([], [out_pt])
565-
compare_numba_and_py(out_fg, [])
643+
with (
644+
pytest.warns(
645+
UserWarning,
646+
match="Numba will use object mode to run AdvancedIncSubtensor's perform method",
647+
)
648+
if inc_requires_objmode
649+
else contextlib.nullcontext()
650+
):
651+
compare_numba_and_py(out_fg, [])
566652

567653
x_pt = x.type()
568654
out_pt = pt.set_subtensor(x_pt[indices], y)
@@ -571,7 +657,15 @@ def test_AdvancedIncSubtensor(x, y, indices):
571657
out_pt.owner.op.inplace = True
572658
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
573659
out_fg = FunctionGraph([x_pt], [out_pt])
574-
compare_numba_and_py(out_fg, [x.data])
660+
with (
661+
pytest.warns(
662+
UserWarning,
663+
match="Numba will use object mode to run AdvancedSetSubtensor's perform method",
664+
)
665+
if set_requires_objmode
666+
else contextlib.nullcontext()
667+
):
668+
compare_numba_and_py(out_fg, [x.data])
575669

576670

577671
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)