Skip to content

Commit 0f6dbc8

Browse files
committed
Remove py310 only strict arg to zip
1 parent b55f51b commit 0f6dbc8

File tree

1 file changed

+8
-19
lines changed

1 file changed

+8
-19
lines changed

pytensor/link/numba/dispatch/elemwise_codegen.py

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,7 @@ def compute_itershape(
1919
ndim = len(in_shapes[0])
2020
shape = [None] * ndim
2121
for i in range(ndim):
22-
for j, (bc, in_shape) in enumerate(
23-
zip(broadcast_pattern, in_shapes, strict=True)
24-
):
22+
for j, (bc, in_shape) in enumerate(zip(broadcast_pattern, in_shapes)):
2523
length = in_shape[i]
2624
if bc[i]:
2725
with builder.if_then(
@@ -151,14 +149,10 @@ def extract_array(aryty, obj):
151149
# input_scope_set = mod.add_metadata([input_scope, output_scope])
152150
# output_scope_set = mod.add_metadata([input_scope, output_scope])
153151

154-
inputs = tuple(
155-
extract_array(aryty, ary)
156-
for aryty, ary in zip(input_types, inputs, strict=True)
157-
)
152+
inputs = tuple(extract_array(aryty, ary) for aryty, ary in zip(input_types, inputs))
158153

159154
outputs = tuple(
160-
extract_array(aryty, ary)
161-
for aryty, ary in zip(output_types, outputs, strict=True)
155+
extract_array(aryty, ary) for aryty, ary in zip(output_types, outputs)
162156
)
163157

164158
zero = ir.Constant(ir.IntType(64), 0)
@@ -189,8 +183,8 @@ def extract_array(aryty, obj):
189183

190184
# Load values from input arrays
191185
input_vals = []
192-
for array_info, bc in zip(inputs, input_bc, strict=True):
193-
idxs_bc = [zero if bc else idx for idx, bc in zip(idxs, bc, strict=True)]
186+
for array_info, bc in zip(inputs, input_bc):
187+
idxs_bc = [zero if bc else idx for idx, bc in zip(idxs, bc)]
194188
ptr = cgutils.get_item_pointer2(context, builder, *array_info, idxs_bc, *safe)
195189
val = builder.load(ptr)
196190
# val.set_metadata("alias.scope", input_scope_set)
@@ -210,9 +204,7 @@ def extract_array(aryty, obj):
210204
output_values = [output_values]
211205

212206
# Update output value or accumulators respectively
213-
for i, ((accu, _), value) in enumerate(
214-
zip(output_accumulator, output_values, strict=True)
215-
):
207+
for i, ((accu, _), value) in enumerate(zip(output_accumulator, output_values)):
216208
if accu is not None:
217209
load = builder.load(accu)
218210
# load.set_metadata("alias.scope", output_scope_set)
@@ -223,9 +215,7 @@ def extract_array(aryty, obj):
223215
# store.set_metadata("alias.scope", output_scope_set)
224216
# store.set_metadata("noalias", input_scope_set)
225217
else:
226-
idxs_bc = [
227-
zero if bc else idx for idx, bc in zip(idxs, output_bc[i], strict=True)
228-
]
218+
idxs_bc = [zero if bc else idx for idx, bc in zip(idxs, output_bc[i])]
229219
ptr = cgutils.get_item_pointer2(context, builder, *outputs[i], idxs_bc)
230220
# store = builder.store(value, ptr)
231221
arrayobj.store_item(context, builder, output_types[i], value, ptr)
@@ -237,8 +227,7 @@ def extract_array(aryty, obj):
237227
for output, (accu, accu_depth) in enumerate(output_accumulator):
238228
if accu_depth == depth:
239229
idxs_bc = [
240-
zero if bc else idx
241-
for idx, bc in zip(idxs, output_bc[output], strict=True)
230+
zero if bc else idx for idx, bc in zip(idxs, output_bc[output])
242231
]
243232
ptr = cgutils.get_item_pointer2(
244233
context, builder, *outputs[output], idxs_bc

0 commit comments

Comments
 (0)