1
+ from typing import Any , List , Optional , Tuple
2
+
1
3
import numba
2
4
import numpy as np
3
5
from llvmlite import ir
10
12
def compute_itershape (
11
13
ctx : BaseContext ,
12
14
builder : ir .IRBuilder ,
13
- in_shapes ,
14
- broadcast_pattern ,
15
+ in_shapes : Tuple [ ir . Instruction , ...] ,
16
+ broadcast_pattern : Tuple [ Tuple [ bool , ...], ...] ,
15
17
):
16
18
one = ir .IntType (64 )(1 )
17
19
ndim = len (in_shapes [0 ])
@@ -59,16 +61,23 @@ def compute_itershape(
59
61
60
62
61
63
def make_outputs (
62
- ctx , builder : ir .IRBuilder , iter_shape , out_bc , dtypes , inplace , inputs , input_types
64
+ ctx : numba .core .base .BaseContext ,
65
+ builder : ir .IRBuilder ,
66
+ iter_shape : Tuple [ir .Instruction , ...],
67
+ out_bc : Tuple [Tuple [bool , ...], ...],
68
+ dtypes : Tuple [Any , ...],
69
+ inplace : Tuple [Tuple [int , int ], ...],
70
+ inputs : Tuple [Any , ...],
71
+ input_types : Tuple [Any , ...],
63
72
):
64
73
arrays = []
65
74
ar_types : list [types .Array ] = []
66
75
one = ir .IntType (64 )(1 )
67
- inplace = dict (inplace )
76
+ inplace_dict = dict (inplace )
68
77
for i , (bc , dtype ) in enumerate (zip (out_bc , dtypes )):
69
- if i in inplace :
70
- arrays .append (inputs [inplace [i ]])
71
- ar_types .append (input_types [inplace [i ]])
78
+ if i in inplace_dict :
79
+ arrays .append (inputs [inplace_dict [i ]])
80
+ ar_types .append (input_types [inplace_dict [i ]])
72
81
# We need to incref once we return the inplace objects
73
82
continue
74
83
dtype = numba .from_dtype (np .dtype (dtype ))
@@ -95,15 +104,15 @@ def make_loop_call(
95
104
typingctx ,
96
105
context : numba .core .base .BaseContext ,
97
106
builder : ir .IRBuilder ,
98
- scalar_func ,
99
- scalar_signature ,
100
- iter_shape ,
101
- inputs ,
102
- outputs ,
103
- input_bc ,
104
- output_bc ,
105
- input_types ,
106
- output_types ,
107
+ scalar_func : Any ,
108
+ scalar_signature : types . FunctionType ,
109
+ iter_shape : Tuple [ ir . Instruction , ...] ,
110
+ inputs : Tuple [ ir . Instruction , ...] ,
111
+ outputs : Tuple [ ir . Instruction , ...] ,
112
+ input_bc : Tuple [ Tuple [ bool , ...], ...] ,
113
+ output_bc : Tuple [ Tuple [ bool , ...], ...] ,
114
+ input_types : Tuple [ Any , ...] ,
115
+ output_types : Tuple [ Any , ...] ,
107
116
):
108
117
safe = (False , False )
109
118
n_outputs = len (outputs )
@@ -142,23 +151,25 @@ def extract_array(aryty, obj):
142
151
# input_scope_set = mod.add_metadata([input_scope, output_scope])
143
152
# output_scope_set = mod.add_metadata([input_scope, output_scope])
144
153
145
- inputs = [
154
+ inputs = tuple (
146
155
extract_array (aryty , ary )
147
156
for aryty , ary in zip (input_types , inputs , strict = True )
148
- ]
157
+ )
149
158
150
- outputs = [
159
+ outputs = tuple (
151
160
extract_array (aryty , ary )
152
161
for aryty , ary in zip (output_types , outputs , strict = True )
153
- ]
162
+ )
154
163
155
164
zero = ir .Constant (ir .IntType (64 ), 0 )
156
165
157
166
# Setup loops and initialize accumulators for outputs
158
167
# This part corresponds to opening the loops
159
168
loop_stack = []
160
169
loops = []
161
- output_accumulator = [(None , None )] * n_outputs
170
+ output_accumulator : List [Tuple [Optional [Any ], Optional [int ]]] = [
171
+ (None , None )
172
+ ] * n_outputs
162
173
for dim , length in enumerate (iter_shape ):
163
174
# Find outputs that only have accumulations left
164
175
for output in range (n_outputs ):
0 commit comments