1
- import warnings
2
-
3
- import numba
4
1
import numpy as np
5
2
6
3
from pytensor .graph import Type
7
4
from pytensor .link .numba .dispatch import numba_funcify
8
- from pytensor .link .numba .dispatch .basic import numba_njit
5
+ from pytensor .link .numba .dispatch .basic import generate_fallback_impl , numba_njit
9
6
from pytensor .link .utils import compile_function_src , unique_name_generator
7
+ from pytensor .tensor import TensorType
10
8
from pytensor .tensor .subtensor import (
11
9
AdvancedIncSubtensor ,
12
10
AdvancedIncSubtensor1 ,
17
15
)
18
16
19
17
20
- def create_index_func (node , objmode = False ):
18
+ @numba_funcify .register (Subtensor )
19
+ @numba_funcify .register (IncSubtensor )
20
+ @numba_funcify .register (AdvancedSubtensor1 )
21
+ def numba_funcify_default_subtensor (op , node , ** kwargs ):
21
22
"""Create a Python function that assembles and uses an index on an array."""
22
23
23
24
unique_names = unique_name_generator (
@@ -40,13 +41,13 @@ def convert_indices(indices, entry):
40
41
raise ValueError ()
41
42
42
43
set_or_inc = isinstance (
43
- node . op , IncSubtensor | AdvancedIncSubtensor1 | AdvancedIncSubtensor
44
+ op , IncSubtensor | AdvancedIncSubtensor1 | AdvancedIncSubtensor
44
45
)
45
46
index_start_idx = 1 + int (set_or_inc )
46
47
47
48
input_names = [unique_names (v , force_unique = True ) for v in node .inputs ]
48
49
op_indices = list (node .inputs [index_start_idx :])
49
- idx_list = getattr (node . op , "idx_list" , None )
50
+ idx_list = getattr (op , "idx_list" , None )
50
51
51
52
indices_creation_src = (
52
53
tuple (convert_indices (op_indices , idx ) for idx in idx_list )
@@ -61,8 +62,7 @@ def convert_indices(indices, entry):
61
62
indices_creation_src = f"indices = ({ indices_creation_src } )"
62
63
63
64
if set_or_inc :
64
- fn_name = "incsubtensor"
65
- if node .op .inplace :
65
+ if op .inplace :
66
66
index_prologue = f"z = { input_names [0 ]} "
67
67
else :
68
68
index_prologue = f"z = np.copy({ input_names [0 ]} )"
@@ -74,84 +74,57 @@ def convert_indices(indices, entry):
74
74
else :
75
75
y_name = input_names [1 ]
76
76
77
- if node .op .set_instead_of_inc :
77
+ if op .set_instead_of_inc :
78
+ function_name = "setsubtensor"
78
79
index_body = f"z[indices] = { y_name } "
79
80
else :
81
+ function_name = "incsubtensor"
80
82
index_body = f"z[indices] += { y_name } "
81
83
else :
82
- fn_name = "subtensor"
84
+ function_name = "subtensor"
83
85
index_prologue = ""
84
86
index_body = f"z = { input_names [0 ]} [indices]"
85
87
86
- if objmode :
87
- output_var = node .outputs [0 ]
88
-
89
- if not set_or_inc :
90
- # Since `z` is being "created" while in object mode, it's
91
- # considered an "outgoing" variable and needs to be manually typed
92
- output_sig = f"z='{ output_var .dtype } [{ ', ' .join ([':' ] * output_var .ndim )} ]'"
93
- else :
94
- output_sig = ""
95
-
96
- index_body = f"""
97
- with objmode({ output_sig } ):
98
- { index_body }
99
- """
100
-
101
88
subtensor_def_src = f"""
102
- def { fn_name } ({ ", " .join (input_names )} ):
89
+ def { function_name } ({ ", " .join (input_names )} ):
103
90
{ index_prologue }
104
91
{ indices_creation_src }
105
92
{ index_body }
106
93
return np.asarray(z)
107
94
"""
108
95
109
- return subtensor_def_src
110
-
111
-
112
- @numba_funcify .register (Subtensor )
113
- @numba_funcify .register (AdvancedSubtensor1 )
114
- def numba_funcify_Subtensor (op , node , ** kwargs ):
115
- objmode = isinstance (op , AdvancedSubtensor )
116
- if objmode :
117
- warnings .warn (
118
- ("Numba will use object mode to allow run " "AdvancedSubtensor." ),
119
- UserWarning ,
120
- )
121
-
122
- subtensor_def_src = create_index_func (node , objmode = objmode )
123
-
124
- global_env = {"np" : np }
125
- if objmode :
126
- global_env ["objmode" ] = numba .objmode
127
-
128
- subtensor_fn = compile_function_src (
129
- subtensor_def_src , "subtensor" , {** globals (), ** global_env }
96
+ func = compile_function_src (
97
+ subtensor_def_src ,
98
+ function_name = function_name ,
99
+ global_env = globals () | {"np" : np },
130
100
)
131
-
132
- return numba_njit (subtensor_fn , boundscheck = True )
133
-
134
-
135
- @numba_funcify .register (IncSubtensor )
136
- def numba_funcify_IncSubtensor (op , node , ** kwargs ):
137
- objmode = isinstance (op , AdvancedIncSubtensor )
138
- if objmode :
139
- warnings .warn (
140
- ("Numba will use object mode to allow run " "AdvancedIncSubtensor." ),
141
- UserWarning ,
101
+ return numba_njit (func , boundscheck = True )
102
+
103
+
104
+ @numba_funcify .register (AdvancedSubtensor )
105
+ @numba_funcify .register (AdvancedIncSubtensor )
106
+ def numba_funcify_AdvancedSubtensor (op , node , ** kwargs ):
107
+ idxs = node .inputs [1 :] if isinstance (op , AdvancedSubtensor ) else node .inputs [2 :]
108
+ adv_idxs_dims = [
109
+ idx .type .ndim
110
+ for idx in idxs
111
+ if (isinstance (idx .type , TensorType ) and idx .type .ndim > 0 )
112
+ ]
113
+
114
+ if (
115
+ # Numba does not support indexes with more than one dimension
116
+ # Nor multiple vector indexes
117
+ (len (adv_idxs_dims ) > 1 or adv_idxs_dims [0 ] > 1 )
118
+ # The default index implementation does not handle duplicate indices correctly
119
+ or (
120
+ isinstance (op , AdvancedIncSubtensor )
121
+ and not op .set_instead_of_inc
122
+ and not op .ignore_duplicates
142
123
)
124
+ ):
125
+ return generate_fallback_impl (op , node , ** kwargs )
143
126
144
- incsubtensor_def_src = create_index_func (node , objmode = objmode )
145
-
146
- global_env = {"np" : np }
147
- if objmode :
148
- global_env ["objmode" ] = numba .objmode
149
-
150
- incsubtensor_fn = compile_function_src (
151
- incsubtensor_def_src , "incsubtensor" , {** globals (), ** global_env }
152
- )
153
-
154
- return numba_njit (incsubtensor_fn , boundscheck = True )
127
+ return numba_funcify_default_subtensor (op , node , ** kwargs )
155
128
156
129
157
130
@numba_funcify .register (AdvancedIncSubtensor1 )
0 commit comments