@@ -19,9 +19,7 @@ def compute_itershape(
19
19
ndim = len (in_shapes [0 ])
20
20
shape = [None ] * ndim
21
21
for i in range (ndim ):
22
- for j , (bc , in_shape ) in enumerate (
23
- zip (broadcast_pattern , in_shapes , strict = True )
24
- ):
22
+ for j , (bc , in_shape ) in enumerate (zip (broadcast_pattern , in_shapes )):
25
23
length = in_shape [i ]
26
24
if bc [i ]:
27
25
with builder .if_then (
@@ -151,14 +149,10 @@ def extract_array(aryty, obj):
151
149
# input_scope_set = mod.add_metadata([input_scope, output_scope])
152
150
# output_scope_set = mod.add_metadata([input_scope, output_scope])
153
151
154
- inputs = tuple (
155
- extract_array (aryty , ary )
156
- for aryty , ary in zip (input_types , inputs , strict = True )
157
- )
152
+ inputs = tuple (extract_array (aryty , ary ) for aryty , ary in zip (input_types , inputs ))
158
153
159
154
outputs = tuple (
160
- extract_array (aryty , ary )
161
- for aryty , ary in zip (output_types , outputs , strict = True )
155
+ extract_array (aryty , ary ) for aryty , ary in zip (output_types , outputs )
162
156
)
163
157
164
158
zero = ir .Constant (ir .IntType (64 ), 0 )
@@ -189,8 +183,8 @@ def extract_array(aryty, obj):
189
183
190
184
# Load values from input arrays
191
185
input_vals = []
192
- for array_info , bc in zip (inputs , input_bc , strict = True ):
193
- idxs_bc = [zero if bc else idx for idx , bc in zip (idxs , bc , strict = True )]
186
+ for array_info , bc in zip (inputs , input_bc ):
187
+ idxs_bc = [zero if bc else idx for idx , bc in zip (idxs , bc )]
194
188
ptr = cgutils .get_item_pointer2 (context , builder , * array_info , idxs_bc , * safe )
195
189
val = builder .load (ptr )
196
190
# val.set_metadata("alias.scope", input_scope_set)
@@ -210,9 +204,7 @@ def extract_array(aryty, obj):
210
204
output_values = [output_values ]
211
205
212
206
# Update output value or accumulators respectively
213
- for i , ((accu , _ ), value ) in enumerate (
214
- zip (output_accumulator , output_values , strict = True )
215
- ):
207
+ for i , ((accu , _ ), value ) in enumerate (zip (output_accumulator , output_values )):
216
208
if accu is not None :
217
209
load = builder .load (accu )
218
210
# load.set_metadata("alias.scope", output_scope_set)
@@ -223,9 +215,7 @@ def extract_array(aryty, obj):
223
215
# store.set_metadata("alias.scope", output_scope_set)
224
216
# store.set_metadata("noalias", input_scope_set)
225
217
else :
226
- idxs_bc = [
227
- zero if bc else idx for idx , bc in zip (idxs , output_bc [i ], strict = True )
228
- ]
218
+ idxs_bc = [zero if bc else idx for idx , bc in zip (idxs , output_bc [i ])]
229
219
ptr = cgutils .get_item_pointer2 (context , builder , * outputs [i ], idxs_bc )
230
220
# store = builder.store(value, ptr)
231
221
arrayobj .store_item (context , builder , output_types [i ], value , ptr )
@@ -237,8 +227,7 @@ def extract_array(aryty, obj):
237
227
for output , (accu , accu_depth ) in enumerate (output_accumulator ):
238
228
if accu_depth == depth :
239
229
idxs_bc = [
240
- zero if bc else idx
241
- for idx , bc in zip (idxs , output_bc [output ], strict = True )
230
+ zero if bc else idx for idx , bc in zip (idxs , output_bc [output ])
242
231
]
243
232
ptr = cgutils .get_item_pointer2 (
244
233
context , builder , * outputs [output ], idxs_bc
0 commit comments