@@ -55,6 +55,7 @@ def __init__(
55
55
constant : Sequence [Variable ] | None = None ,
56
56
until : Variable | None = None ,
57
57
name = "ScalarLoop" ,
58
+ ** kwargs ,
58
59
):
59
60
if constant is None :
60
61
constant = []
@@ -75,7 +76,7 @@ def __init__(
75
76
self .nout = len (self .outputs )
76
77
self .name = name
77
78
78
- super ().__init__ ()
79
+ super ().__init__ (** kwargs )
79
80
80
81
def output_types (self , input_types ):
81
82
return self .outputs_type
@@ -115,7 +116,7 @@ def fgraph(self):
115
116
self ._fgraph = fgraph
116
117
return self ._fgraph
117
118
118
- def clone (self ):
119
+ def clone (self , name = None , ** kwargs ):
119
120
if self .is_while :
120
121
* update , until = self .outputs
121
122
else :
@@ -127,28 +128,16 @@ def clone(self):
127
128
update = update ,
128
129
constant = constant ,
129
130
until = until ,
130
- name = self .name ,
131
+ name = self .name if name is None else name ,
132
+ ** kwargs ,
131
133
)
132
134
133
135
@property
134
136
def fn (self ):
135
137
raise NotImplementedError
136
138
137
139
def make_new_inplace (self , output_types_preference = None , name = None ):
138
- """
139
- This op.__init__ fct don't have the same parameter as other scalar op.
140
- This break the insert_inplace_optimizer optimization.
141
- This fct allow fix patch this.
142
-
143
- """
144
- d = {k : getattr (self , k ) for k in self .init_param }
145
- out = self .__class__ (** d )
146
- if name :
147
- out .name = name
148
- else :
149
- name = out .name
150
- super (ScalarLoop , out ).__init__ (output_types_preference , name )
151
- return out
140
+ return self .clone (output_types_preference = output_types_preference , name = name )
152
141
153
142
def make_node (self , n_steps , * inputs ):
154
143
assert len (inputs ) == self .nin - 1
@@ -229,11 +218,11 @@ def c_code_template(self):
229
218
c : f"%(i{ int (i )} )s"
230
219
for i , c in enumerate (fgraph .inputs [n_update :], start = n_update + 1 )
231
220
}
232
- update_subd = {
221
+ out_subd = {
233
222
u : f"%(o{ int (i )} )s" for i , u in enumerate (fgraph .outputs [:n_update ])
234
223
}
235
224
until_subd = {u : "until" for u in fgraph .outputs [n_update :]}
236
- subd = {** carry_subd , ** constant_subd , ** update_subd , ** until_subd }
225
+ subd = {** carry_subd , ** constant_subd , ** until_subd }
237
226
238
227
for var in fgraph .variables :
239
228
if var .owner is None :
@@ -257,11 +246,11 @@ def c_code_template(self):
257
246
_c_code += "bool until = 1;\n \n "
258
247
259
248
# Copy carried inputs
260
- for i , (var , name ) in enumerate (carry_subd .items ()):
261
- copy_var_name = f"{ name } _copy { i } "
262
- _c_code += f"{ var .type .dtype_specs ()[1 ]} { copy_var_name } = { name } ;\n "
263
- carry_subd [var ] = copy_var_name
264
- subd [var ] = copy_var_name
249
+ for i , (var , name ) in enumerate (carry_subd .items (), start = 1 ):
250
+ carry_var_name = f"{ name } _carry { i } "
251
+ _c_code += f"{ var .type .dtype_specs ()[1 ]} { carry_var_name } = { name } ;\n "
252
+ carry_subd [var ] = carry_var_name
253
+ subd [var ] = carry_var_name
265
254
266
255
# _c_code += 'printf("inputs=[");'
267
256
# for i in range(1, len(fgraph.inputs)):
@@ -270,9 +259,8 @@ def c_code_template(self):
270
259
271
260
_c_code += "\n for(%(n_steps_dtype)s i = 0; i < %(n_steps)s; i++){\n "
272
261
273
- self .nodenames = [
274
- f"%(nodename)s_subnode{ int (j )} " for j , n in enumerate (fgraph .toposort ())
275
- ]
262
+ # Used by self.c_support_code_apply
263
+ self .nodenames = nodenames = []
276
264
277
265
i = 0
278
266
for j , node in enumerate (fgraph .toposort ()):
@@ -282,9 +270,13 @@ def c_code_template(self):
282
270
name = f"V%(id)s_tmp{ int (i )} "
283
271
subd [output ] = name
284
272
_c_code += f"{ output .type .dtype_specs ()[1 ]} { name } ;\n "
273
+
274
+ nodename = f"%(nodename)s_subnode{ int (j )} "
275
+ nodenames .append (nodename )
276
+
285
277
s = node .op .c_code (
286
278
node ,
287
- self . nodenames [ j ] ,
279
+ nodename ,
288
280
# Any node that depended on `init` will depend on `update` instead
289
281
# The initial value of `update` was set to `init` before the loop
290
282
[subd [input ] for input in node .inputs ],
@@ -294,10 +286,12 @@ def c_code_template(self):
294
286
_c_code += s
295
287
_c_code += "\n "
296
288
297
- # Set the carry variables to the output variables
289
+ # Update the carry variables to the output variables
298
290
_c_code += "\n "
299
- for init , update in zip (carry_subd .values (), update_subd .values (), strict = True ):
300
- _c_code += f"{ init } = { update } ;\n "
291
+ for carry , out in zip (
292
+ carry_subd .values (), fgraph .outputs [:n_update ], strict = True
293
+ ):
294
+ _c_code += f"{ carry } = { subd [out ]} ;\n "
301
295
302
296
# _c_code += 'printf("%%ld\\n", i);\n'
303
297
# for carry in range(1, 10):
@@ -309,6 +303,10 @@ def c_code_template(self):
309
303
# End of the loop
310
304
_c_code += "}\n "
311
305
306
+ # Assign the carry variables to the outputs
307
+ for out , carry in zip (out_subd .values (), carry_subd .values (), strict = True ):
308
+ _c_code += f"{ out } = { carry } ;\n "
309
+
312
310
# Output until flag
313
311
if self .is_while :
314
312
_c_code += f"%(o{ len (fgraph .outputs )- 1 } )s = until;\n "
@@ -343,4 +341,4 @@ def c_code(self, node, nodename, inames, onames, sub):
343
341
return res
344
342
345
343
def c_code_cache_version_outer (self ):
346
- return (3 ,)
344
+ return (4 ,)
0 commit comments