Skip to content

Commit 74af76f

Browse files
committed
Add boundschecks in numba backend
1 parent 7caee0e commit 74af76f

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

pytensor/link/numba/dispatch/basic.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -554,7 +554,7 @@ def numba_funcify_Subtensor(op, node, **kwargs):
554554
subtensor_def_src, "subtensor", {**globals(), **global_env}
555555
)
556556

557-
return numba_njit(subtensor_fn)
557+
return numba_njit(subtensor_fn, boundscheck=True)
558558

559559

560560
@numba_funcify.register(IncSubtensor)
@@ -570,7 +570,7 @@ def numba_funcify_IncSubtensor(op, node, **kwargs):
570570
incsubtensor_def_src, "incsubtensor", {**globals(), **global_env}
571571
)
572572

573-
return numba_njit(incsubtensor_fn)
573+
return numba_njit(incsubtensor_fn, boundscheck=True)
574574

575575

576576
@numba_funcify.register(AdvancedIncSubtensor1)
@@ -580,15 +580,15 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
580580

581581
if set_instead_of_inc:
582582

583-
@numba_njit
583+
@numba_njit(boundscheck=True)
584584
def advancedincsubtensor1_inplace(x, vals, idxs):
585585
for idx, val in zip(idxs, vals):
586586
x[idx] = val
587587
return x
588588

589589
else:
590590

591-
@numba_njit
591+
@numba_njit(boundscheck=True)
592592
def advancedincsubtensor1_inplace(x, vals, idxs):
593593
for idx, val in zip(idxs, vals):
594594
x[idx] += val

tests/link/numba/test_basic.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,14 @@ def test_AdvancedSubtensor1(x, indices):
373373
compare_numba_and_py(out_fg, [])
374374

375375

376+
def test_AdvancedSubtensor1_out_of_bounds():
377+
out_at = at_subtensor.advanced_subtensor1(np.arange(3), [4])
378+
assert isinstance(out_at.owner.op, at_subtensor.AdvancedSubtensor1)
379+
out_fg = FunctionGraph([], [out_at])
380+
with pytest.raises(IndexError):
381+
compare_numba_and_py(out_fg, [])
382+
383+
376384
@pytest.mark.parametrize(
377385
"x, indices",
378386
[

0 commit comments

Comments
 (0)