Skip to content

Commit a14cb2b

Browse files
committed
Do not use Numba objmode for supported AdvancedSubtensor operations
Use ScalarTypes in MakeSlice for compatibility with Numba
1 parent a9c52dd commit a14cb2b

File tree

4 files changed

+193
-86
lines changed

4 files changed

+193
-86
lines changed

pytensor/link/numba/dispatch/subtensor.py

Lines changed: 43 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
1-
import warnings
2-
3-
import numba
41
import numpy as np
52

63
from pytensor.graph import Type
74
from pytensor.link.numba.dispatch import numba_funcify
8-
from pytensor.link.numba.dispatch.basic import numba_njit
5+
from pytensor.link.numba.dispatch.basic import generate_fallback_impl, numba_njit
96
from pytensor.link.utils import compile_function_src, unique_name_generator
7+
from pytensor.tensor import TensorType
108
from pytensor.tensor.subtensor import (
119
AdvancedIncSubtensor,
1210
AdvancedIncSubtensor1,
@@ -17,7 +15,10 @@
1715
)
1816

1917

20-
def create_index_func(node, objmode=False):
18+
@numba_funcify.register(Subtensor)
19+
@numba_funcify.register(IncSubtensor)
20+
@numba_funcify.register(AdvancedSubtensor1)
21+
def numba_funcify_default_subtensor(op, node, **kwargs):
2122
"""Create a Python function that assembles and uses an index on an array."""
2223

2324
unique_names = unique_name_generator(
@@ -40,13 +41,13 @@ def convert_indices(indices, entry):
4041
raise ValueError()
4142

4243
set_or_inc = isinstance(
43-
node.op, IncSubtensor | AdvancedIncSubtensor1 | AdvancedIncSubtensor
44+
op, IncSubtensor | AdvancedIncSubtensor1 | AdvancedIncSubtensor
4445
)
4546
index_start_idx = 1 + int(set_or_inc)
4647

4748
input_names = [unique_names(v, force_unique=True) for v in node.inputs]
4849
op_indices = list(node.inputs[index_start_idx:])
49-
idx_list = getattr(node.op, "idx_list", None)
50+
idx_list = getattr(op, "idx_list", None)
5051

5152
indices_creation_src = (
5253
tuple(convert_indices(op_indices, idx) for idx in idx_list)
@@ -61,8 +62,7 @@ def convert_indices(indices, entry):
6162
indices_creation_src = f"indices = ({indices_creation_src})"
6263

6364
if set_or_inc:
64-
fn_name = "incsubtensor"
65-
if node.op.inplace:
65+
if op.inplace:
6666
index_prologue = f"z = {input_names[0]}"
6767
else:
6868
index_prologue = f"z = np.copy({input_names[0]})"
@@ -74,84 +74,57 @@ def convert_indices(indices, entry):
7474
else:
7575
y_name = input_names[1]
7676

77-
if node.op.set_instead_of_inc:
77+
if op.set_instead_of_inc:
78+
function_name = "setsubtensor"
7879
index_body = f"z[indices] = {y_name}"
7980
else:
81+
function_name = "incsubtensor"
8082
index_body = f"z[indices] += {y_name}"
8183
else:
82-
fn_name = "subtensor"
84+
function_name = "subtensor"
8385
index_prologue = ""
8486
index_body = f"z = {input_names[0]}[indices]"
8587

86-
if objmode:
87-
output_var = node.outputs[0]
88-
89-
if not set_or_inc:
90-
# Since `z` is being "created" while in object mode, it's
91-
# considered an "outgoing" variable and needs to be manually typed
92-
output_sig = f"z='{output_var.dtype}[{', '.join([':'] * output_var.ndim)}]'"
93-
else:
94-
output_sig = ""
95-
96-
index_body = f"""
97-
with objmode({output_sig}):
98-
{index_body}
99-
"""
100-
10188
subtensor_def_src = f"""
102-
def {fn_name}({", ".join(input_names)}):
89+
def {function_name}({", ".join(input_names)}):
10390
{index_prologue}
10491
{indices_creation_src}
10592
{index_body}
10693
return np.asarray(z)
10794
"""
10895

109-
return subtensor_def_src
110-
111-
112-
@numba_funcify.register(Subtensor)
113-
@numba_funcify.register(AdvancedSubtensor1)
114-
def numba_funcify_Subtensor(op, node, **kwargs):
115-
objmode = isinstance(op, AdvancedSubtensor)
116-
if objmode:
117-
warnings.warn(
118-
("Numba will use object mode to allow run " "AdvancedSubtensor."),
119-
UserWarning,
120-
)
121-
122-
subtensor_def_src = create_index_func(node, objmode=objmode)
123-
124-
global_env = {"np": np}
125-
if objmode:
126-
global_env["objmode"] = numba.objmode
127-
128-
subtensor_fn = compile_function_src(
129-
subtensor_def_src, "subtensor", {**globals(), **global_env}
96+
func = compile_function_src(
97+
subtensor_def_src,
98+
function_name=function_name,
99+
global_env=globals() | {"np": np},
130100
)
131-
132-
return numba_njit(subtensor_fn, boundscheck=True)
133-
134-
135-
@numba_funcify.register(IncSubtensor)
136-
def numba_funcify_IncSubtensor(op, node, **kwargs):
137-
objmode = isinstance(op, AdvancedIncSubtensor)
138-
if objmode:
139-
warnings.warn(
140-
("Numba will use object mode to allow run " "AdvancedIncSubtensor."),
141-
UserWarning,
101+
return numba_njit(func, boundscheck=True)
102+
103+
104+
@numba_funcify.register(AdvancedSubtensor)
105+
@numba_funcify.register(AdvancedIncSubtensor)
106+
def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
107+
idxs = node.inputs[1:] if isinstance(op, AdvancedSubtensor) else node.inputs[2:]
108+
adv_idxs_dims = [
109+
idx.type.ndim
110+
for idx in idxs
111+
if (isinstance(idx.type, TensorType) and idx.type.ndim > 0)
112+
]
113+
114+
if (
115+
# Numba does not support indexes with more than one dimension
116+
# Nor multiple vector indexes
117+
(len(adv_idxs_dims) > 1 or adv_idxs_dims[0] > 1)
118+
# The default index implementation does not handle duplicate indices correctly
119+
or (
120+
isinstance(op, AdvancedIncSubtensor)
121+
and not op.set_instead_of_inc
122+
and not op.ignore_duplicates
142123
)
124+
):
125+
return generate_fallback_impl(op, node, **kwargs)
143126

144-
incsubtensor_def_src = create_index_func(node, objmode=objmode)
145-
146-
global_env = {"np": np}
147-
if objmode:
148-
global_env["objmode"] = numba.objmode
149-
150-
incsubtensor_fn = compile_function_src(
151-
incsubtensor_def_src, "incsubtensor", {**globals(), **global_env}
152-
)
153-
154-
return numba_njit(incsubtensor_fn, boundscheck=True)
127+
return numba_funcify_default_subtensor(op, node, **kwargs)
155128

156129

157130
@numba_funcify.register(AdvancedIncSubtensor1)

pytensor/tensor/subtensor.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,12 @@
2121
from pytensor.printing import Printer, pprint, set_precedence
2222
from pytensor.scalar.basic import ScalarConstant
2323
from pytensor.tensor import _get_vector_length, as_tensor_variable, get_vector_length
24-
from pytensor.tensor.basic import alloc, get_underlying_scalar_constant_value, nonzero
24+
from pytensor.tensor.basic import (
25+
ScalarFromTensor,
26+
alloc,
27+
get_underlying_scalar_constant_value,
28+
nonzero,
29+
)
2530
from pytensor.tensor.blockwise import vectorize_node_fallback
2631
from pytensor.tensor.elemwise import DimShuffle
2732
from pytensor.tensor.exceptions import AdvancedIndexingError, NotScalarConstantError
@@ -168,8 +173,16 @@ def as_index_literal(
168173
if isinstance(idx, Constant):
169174
return idx.data.item() if isinstance(idx, np.ndarray) else idx.data
170175

171-
if isinstance(getattr(idx, "type", None), SliceType):
172-
idx = slice(*idx.owner.inputs)
176+
if isinstance(idx, Variable):
177+
if (
178+
isinstance(idx.type, ps.ScalarType)
179+
and idx.owner
180+
and isinstance(idx.owner.op, ScalarFromTensor)
181+
):
182+
return as_index_literal(idx.owner.inputs[0])
183+
184+
if isinstance(idx.type, SliceType):
185+
idx = slice(*idx.owner.inputs)
173186

174187
if isinstance(idx, slice):
175188
return slice(

pytensor/tensor/type_other.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def as_int_none_variable(x):
1818
return NoneConst
1919
elif NoneConst.equals(x):
2020
return x
21-
x = pytensor.tensor.as_tensor_variable(x, ndim=0)
21+
x = pytensor.scalar.as_scalar(x)
2222
if x.type.dtype not in integer_dtypes:
2323
raise TypeError("index must be integers")
2424
return x

0 commit comments

Comments
 (0)