Skip to content

Commit b55f51b

Browse files
committed
Add typing for some numba elemwise
1 parent 5d2582d commit b55f51b

File tree

1 file changed

+32
-21
lines changed

1 file changed

+32
-21
lines changed

pytensor/link/numba/dispatch/elemwise_codegen.py

Lines changed: 32 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Any, List, Optional, Tuple
2+
13
import numba
24
import numpy as np
35
from llvmlite import ir
@@ -10,8 +12,8 @@
1012
def compute_itershape(
1113
ctx: BaseContext,
1214
builder: ir.IRBuilder,
13-
in_shapes,
14-
broadcast_pattern,
15+
in_shapes: Tuple[ir.Instruction, ...],
16+
broadcast_pattern: Tuple[Tuple[bool, ...], ...],
1517
):
1618
one = ir.IntType(64)(1)
1719
ndim = len(in_shapes[0])
@@ -59,16 +61,23 @@ def compute_itershape(
5961

6062

6163
def make_outputs(
62-
ctx, builder: ir.IRBuilder, iter_shape, out_bc, dtypes, inplace, inputs, input_types
64+
ctx: numba.core.base.BaseContext,
65+
builder: ir.IRBuilder,
66+
iter_shape: Tuple[ir.Instruction, ...],
67+
out_bc: Tuple[Tuple[bool, ...], ...],
68+
dtypes: Tuple[Any, ...],
69+
inplace: Tuple[Tuple[int, int], ...],
70+
inputs: Tuple[Any, ...],
71+
input_types: Tuple[Any, ...],
6372
):
6473
arrays = []
6574
ar_types: list[types.Array] = []
6675
one = ir.IntType(64)(1)
67-
inplace = dict(inplace)
76+
inplace_dict = dict(inplace)
6877
for i, (bc, dtype) in enumerate(zip(out_bc, dtypes)):
69-
if i in inplace:
70-
arrays.append(inputs[inplace[i]])
71-
ar_types.append(input_types[inplace[i]])
78+
if i in inplace_dict:
79+
arrays.append(inputs[inplace_dict[i]])
80+
ar_types.append(input_types[inplace_dict[i]])
7281
# We need to incref once we return the inplace objects
7382
continue
7483
dtype = numba.from_dtype(np.dtype(dtype))
@@ -95,15 +104,15 @@ def make_loop_call(
95104
typingctx,
96105
context: numba.core.base.BaseContext,
97106
builder: ir.IRBuilder,
98-
scalar_func,
99-
scalar_signature,
100-
iter_shape,
101-
inputs,
102-
outputs,
103-
input_bc,
104-
output_bc,
105-
input_types,
106-
output_types,
107+
scalar_func: Any,
108+
scalar_signature: types.FunctionType,
109+
iter_shape: Tuple[ir.Instruction, ...],
110+
inputs: Tuple[ir.Instruction, ...],
111+
outputs: Tuple[ir.Instruction, ...],
112+
input_bc: Tuple[Tuple[bool, ...], ...],
113+
output_bc: Tuple[Tuple[bool, ...], ...],
114+
input_types: Tuple[Any, ...],
115+
output_types: Tuple[Any, ...],
107116
):
108117
safe = (False, False)
109118
n_outputs = len(outputs)
@@ -142,23 +151,25 @@ def extract_array(aryty, obj):
142151
# input_scope_set = mod.add_metadata([input_scope, output_scope])
143152
# output_scope_set = mod.add_metadata([input_scope, output_scope])
144153

145-
inputs = [
154+
inputs = tuple(
146155
extract_array(aryty, ary)
147156
for aryty, ary in zip(input_types, inputs, strict=True)
148-
]
157+
)
149158

150-
outputs = [
159+
outputs = tuple(
151160
extract_array(aryty, ary)
152161
for aryty, ary in zip(output_types, outputs, strict=True)
153-
]
162+
)
154163

155164
zero = ir.Constant(ir.IntType(64), 0)
156165

157166
# Setup loops and initialize accumulators for outputs
158167
# This part corresponds to opening the loops
159168
loop_stack = []
160169
loops = []
161-
output_accumulator = [(None, None)] * n_outputs
170+
output_accumulator: List[Tuple[Optional[Any], Optional[int]]] = [
171+
(None, None)
172+
] * n_outputs
162173
for dim, length in enumerate(iter_shape):
163174
# Find outputs that only have accumulations left
164175
for output in range(n_outputs):

0 commit comments

Comments
 (0)