5
5
from pytensor .link .numba .dispatch .basic import generate_fallback_impl , numba_njit
6
6
from pytensor .link .utils import compile_function_src , unique_name_generator
7
7
from pytensor .tensor import TensorType
8
+ from pytensor .tensor .rewriting .subtensor import is_full_slice
8
9
from pytensor .tensor .subtensor import (
9
10
AdvancedIncSubtensor ,
10
11
AdvancedIncSubtensor1 ,
13
14
IncSubtensor ,
14
15
Subtensor ,
15
16
)
17
+ from pytensor .tensor .type_other import NoneTypeT , SliceType
16
18
17
19
18
20
@numba_funcify .register (Subtensor )
@@ -104,18 +106,72 @@ def {function_name}({", ".join(input_names)}):
104
106
@numba_funcify .register (AdvancedSubtensor )
105
107
@numba_funcify .register (AdvancedIncSubtensor )
106
108
def numba_funcify_AdvancedSubtensor (op , node , ** kwargs ):
107
- idxs = node .inputs [1 :] if isinstance (op , AdvancedSubtensor ) else node .inputs [2 :]
108
- adv_idxs_dims = [
109
- idx .type .ndim
109
+ if isinstance (op , AdvancedSubtensor ):
110
+ x , y , idxs = node .inputs [0 ], None , node .inputs [1 :]
111
+ else :
112
+ x , y , * idxs = node .inputs
113
+
114
+ basic_idxs = [
115
+ idx
110
116
for idx in idxs
111
- if (isinstance (idx .type , TensorType ) and idx .type .ndim > 0 )
117
+ if (
118
+ isinstance (idx .type , NoneTypeT )
119
+ or (isinstance (idx .type , SliceType ) and not is_full_slice (idx ))
120
+ )
112
121
]
122
+ adv_idxs = [
123
+ {
124
+ "axis" : i ,
125
+ "dtype" : idx .type .dtype ,
126
+ "bcast" : idx .type .broadcastable ,
127
+ "ndim" : idx .type .ndim ,
128
+ }
129
+ for i , idx in enumerate (idxs )
130
+ if isinstance (idx .type , TensorType )
131
+ ]
132
+
133
+ # Special case for consecutive consecutive vector indices
134
+ def broadcasted_to (x_bcast : tuple [bool , ...], to_bcast : tuple [bool , ...]):
135
+ # Check that x is not broadcasted to y based on broadcastable info
136
+ if len (x_bcast ) < len (to_bcast ):
137
+ return True
138
+ for x_bcast_dim , to_bcast_dim in zip (x_bcast , to_bcast , strict = True ):
139
+ if x_bcast_dim and not to_bcast_dim :
140
+ return True
141
+ return False
142
+
143
+ if (
144
+ not basic_idxs
145
+ and len (adv_idxs ) >= 2
146
+ # Must be integer vectors
147
+ # Todo: we could allow shape=(1,) if this is the shape of x
148
+ and all (
149
+ (adv_idx ["bcast" ] == (False ,) and adv_idx ["dtype" ] != "bool" )
150
+ for adv_idx in adv_idxs
151
+ )
152
+ # Must be consecutive
153
+ and not op .non_contiguous_adv_indexing (node )
154
+ # y in set/inc_subtensor cannot be broadcasted
155
+ and (
156
+ y is None
157
+ or not broadcasted_to (
158
+ y .type .broadcastable ,
159
+ (
160
+ x .type .broadcastable [: adv_idxs [0 ]["axis" ]]
161
+ + x .type .broadcastable [adv_idxs [- 1 ]["axis" ] :]
162
+ ),
163
+ )
164
+ )
165
+ ):
166
+ return numba_funcify_multiple_vector_indexing (op , node , ** kwargs )
113
167
168
+ # Cases natively supported by Numba
114
169
if (
115
170
# Numba does not support indexes with more than one dimension
171
+ any (idx ["ndim" ] > 1 for idx in adv_idxs )
116
172
# Nor multiple vector indexes
117
- ( len ( adv_idxs_dims ) > 1 or adv_idxs_dims [ 0 ] > 1 )
118
- # The default index implementation does not handle duplicate indices correctly
173
+ or sum ( idx [ "ndim" ] > 0 for idx in adv_idxs ) > 1
174
+ # The default PyTensor implementation does not handle duplicate indices correctly
119
175
or (
120
176
isinstance (op , AdvancedIncSubtensor )
121
177
and not op .set_instead_of_inc
@@ -127,6 +183,87 @@ def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
127
183
return numba_funcify_default_subtensor (op , node , ** kwargs )
128
184
129
185
186
def numba_funcify_multiple_vector_indexing(
    op: AdvancedSubtensor | AdvancedIncSubtensor, node, **kwargs
):
    """Special-case numba implementation for multiple consecutive vector indices.

    Handles plain advanced indexing (``AdvancedSubtensor``) as well as
    set/inc-subtensor (``AdvancedIncSubtensor``) when the advanced indices are
    a contiguous run of integer vectors. The caller is responsible for
    checking those preconditions before dispatching here; any indices before
    the vector run are assumed to be full slices — TODO confirm against the
    dispatch-site checks.

    Parameters
    ----------
    op : AdvancedSubtensor | AdvancedIncSubtensor
        The op being compiled; for the inc variant, ``op.set_instead_of_inc``
        and ``op.inplace`` select the generated kernel's behavior.
    node
        The Apply node; ``node.inputs`` holds ``x`` (and ``y`` for the inc
        variant) followed by the index variables.

    Returns
    -------
    A numba-jitted function implementing the indexing operation.
    """
    # Split off y (the values to set/increment) when present; for plain
    # subtensor there is no y and all inputs after x are indices.
    if isinstance(op, AdvancedSubtensor):
        y, idxs = None, node.inputs[1:]
    else:
        y, *idxs = node.inputs[1:]

    # Locate the contiguous run of tensor (vector) indices: [first_axis,
    # after_last_axis). The caller guarantees the run is contiguous.
    first_axis = next(
        i for i, idx in enumerate(idxs) if isinstance(idx.type, TensorType)
    )
    try:
        after_last_axis = next(
            i
            for i, idx in enumerate(idxs[first_axis:], start=first_axis)
            if not isinstance(idx.type, TensorType)
        )
    except StopIteration:
        # The vector run extends to the last index.
        after_last_axis = len(idxs)

    if isinstance(op, AdvancedSubtensor):

        @numba_njit
        def advanced_subtensor_multiple_vector(x, *idxs):
            none_slices = idxs[:first_axis]
            vec_idxs = idxs[first_axis:after_last_axis]

            # Output shape: leading dims of x, then the (shared) vector index
            # shape, then the trailing dims of x after the indexed axes.
            x_shape = x.shape
            idx_shape = vec_idxs[0].shape
            shape_bef = x_shape[:first_axis]
            shape_aft = x_shape[after_last_axis:]
            out_shape = (*shape_bef, *idx_shape, *shape_aft)
            out_buffer = np.empty(out_shape, dtype=x.dtype)
            # Gather one scalar-index tuple at a time; zip over the vectors
            # yields the i-th scalar index of each vector.
            for i, scalar_idxs in enumerate(zip(*vec_idxs)):  # noqa: B905
                out_buffer[(*none_slices, i)] = x[(*none_slices, *scalar_idxs)]
            return out_buffer

        return advanced_subtensor_multiple_vector

    elif op.set_instead_of_inc:
        inplace = op.inplace

        @numba_njit
        def advanced_set_subtensor_multiple_vector(x, y, *idxs):
            vec_idxs = idxs[first_axis:after_last_axis]
            x_shape = x.shape

            # Only copy x when the op is not allowed to mutate it in place.
            if inplace:
                out = x
            else:
                out = x.copy()

            # Iterate explicitly over the leading (sliced) axes, then scatter
            # each vector-index tuple from the matching position in y.
            for outer in np.ndindex(x_shape[:first_axis]):
                for i, scalar_idxs in enumerate(zip(*vec_idxs)):  # noqa: B905
                    out[(*outer, *scalar_idxs)] = y[(*outer, i)]
            return out

        return advanced_set_subtensor_multiple_vector

    else:
        inplace = op.inplace

        @numba_njit
        def advanced_inc_subtensor_multiple_vector(x, y, *idxs):
            vec_idxs = idxs[first_axis:after_last_axis]
            x_shape = x.shape

            # Only copy x when the op is not allowed to mutate it in place.
            if inplace:
                out = x
            else:
                out = x.copy()

            # Same scatter loop as the set variant, but accumulating. The
            # sequential loop handles duplicate indices correctly (each hit
            # increments), unlike numpy's buffered fancy-index assignment.
            for outer in np.ndindex(x_shape[:first_axis]):
                for i, scalar_idxs in enumerate(zip(*vec_idxs)):  # noqa: B905
                    out[(*outer, *scalar_idxs)] += y[(*outer, i)]
            return out

        return advanced_inc_subtensor_multiple_vector
265
+
266
+
130
267
@numba_funcify .register (AdvancedIncSubtensor1 )
131
268
def numba_funcify_AdvancedIncSubtensor1 (op , node , ** kwargs ):
132
269
inplace = op .inplace
0 commit comments