@@ -203,12 +203,11 @@ def push_out_non_seq_scan(fgraph, node):
203
203
if not isinstance (node .op , Scan ):
204
204
return False
205
205
206
- # this flag tells if there was any change during the last iterations
207
- clean_inputs , clean_outputs = reconstruct_graph (node .op .inputs , node .op .outputs )
206
+ node_inputs , node_outputs = node .op .inputs , node .op .outputs
208
207
209
- local_fgraph_topo = io_toposort (clean_inputs , clean_outputs )
210
- local_fgraph_outs_set = set (clean_outputs )
211
- local_fgraph_outs_map = {v : k for k , v in enumerate (clean_outputs )}
208
+ local_fgraph_topo = io_toposort (node_inputs , node_outputs )
209
+ local_fgraph_outs_set = set (node_outputs )
210
+ local_fgraph_outs_map = {v : k for k , v in enumerate (node_outputs )}
212
211
213
212
to_remove_set = set ()
214
213
to_replace_set = set ()
@@ -221,18 +220,21 @@ def add_to_replace(y):
221
220
222
221
add_to_replace .n = 0
223
222
223
+ # The variables that will replace the variables pushed-out of the
224
+ # inner-graph
224
225
replace_with_in = []
226
+ # The variables that have been pushed-out of the graph
225
227
replace_with_out = []
226
228
227
229
op = node .op
228
230
# Construct the list of non_sequences to simplify a few things
229
- inner_non_seqs = op .inner_non_seqs (clean_inputs )
231
+ inner_non_seqs = op .inner_non_seqs (node_inputs )
230
232
inner_non_seqs_set = set (inner_non_seqs )
231
233
inner_non_seqs_map = {v : k for k , v in enumerate (inner_non_seqs )}
232
234
233
235
outer_non_seqs = op .outer_non_seqs (node .inputs )
234
236
235
- inner_seqs = op .inner_seqs (clean_inputs )
237
+ inner_seqs = op .inner_seqs (node_inputs )
236
238
outer_seqs = op .outer_seqs (node .inputs )
237
239
238
240
assert len (inner_non_seqs ) == len (outer_non_seqs )
@@ -242,55 +244,60 @@ def add_to_replace(y):
242
244
if ( # we haven't already looked at this node
243
245
nd not in to_remove_set
244
246
and all (
245
- [
246
- (
247
- (x in inner_non_seqs_set )
248
- or (x .owner in to_remove_set )
249
- or isinstance (x , Constant )
250
- )
251
- for x in nd .inputs
252
- ]
247
+ (
248
+ (x in inner_non_seqs_set )
249
+ or (x .owner in to_remove_set )
250
+ or isinstance (x , Constant )
251
+ )
252
+ for x in nd .inputs
253
253
)
254
- and
255
- # we can do this because the assumption is that a
256
- # viewOp or deepCopyOp will be just at the end of the
257
- # function and not somewhere in the middle ..
258
- not isinstance (nd .op , aesara .compile .ViewOp )
254
+ # We can (supposedly) do this because the assumption is that a
255
+ # `ViewOp` or `DeepCopyOp` will be just at the end of the
256
+ # function and not somewhere in the middle
257
+ and not isinstance (nd .op , aesara .compile .ViewOp )
259
258
and not isinstance (nd .op , aesara .compile .DeepCopyOp )
260
259
):
261
-
262
- # We have a candidate node to removable
263
- # Step 1. Reconstruct it on outside
260
+ # We have a candidate node to remove from the inner-graph
261
+
262
+ # Step 1. Reconstruct the node using the relevant outer-inputs.
263
+ #
264
+ # More specifically, the node's current inputs are either
265
+ # a) inner-graph input place-holders for non-sequences,
266
+ # b) the outputs of other nodes being pushed out of the inner-graph,
267
+ # c) or constants.
264
268
to_remove_set .add (nd )
265
- outside_ins = []
266
- for x in nd .inputs :
267
- if x in inner_non_seqs_set :
268
- _idx = inner_non_seqs_map [x ]
269
- outside_ins .append (outer_non_seqs [_idx ])
270
- elif x in to_replace_set :
271
- outside_ins .append (replace_with_out [to_replace_map [x ]])
272
- elif isinstance (x , Constant ):
273
- outside_ins .append (x .clone ())
269
+ new_inputs = []
270
+ for old_input in nd .inputs :
271
+ if old_input in inner_non_seqs_set :
272
+ # This is case a), so we want to use the corresponding
273
+ # outer-graph input as the input to our new pushed-out node
274
+ _idx = inner_non_seqs_map [old_input ]
275
+ new_input = outer_non_seqs [_idx ]
276
+ elif old_input in to_replace_set :
277
+ # This is case b), so we want to use the new pushed-out node
278
+ # as the input to this new pushed-out node
279
+ new_input = replace_with_out [to_replace_map [old_input ]]
274
280
else :
275
- # TODO: Explain why is this an error, and raise an
276
- # appropriate exception type.
277
- raise RuntimeError ()
278
- outside_ins = [
279
- x .type .filter_variable (y ) for x , y in zip (nd .inputs , outside_ins )
280
- ]
281
+ assert isinstance (old_input , Constant )
282
+ new_input = old_input
281
283
282
- nw_outer_node = nd .op .make_node (* outside_ins )
284
+ new_input = old_input .type .filter_variable (new_input )
285
+ new_inputs .append (new_input )
286
+
287
+ pushed_out_node = nd .op .make_node (* new_inputs )
283
288
284
289
if config .compute_test_value != "off" :
285
- compute_test_value (nw_outer_node )
290
+ compute_test_value (pushed_out_node )
286
291
287
- # Step 2. Create variables for replacements
292
+ # Step 2. Create variables to replace the old outputs of the node
293
+ # that we're pushing out of the inner-graph
288
294
for idx , y in enumerate (nd .outputs ):
289
- y_place_holder = safe_new (y , "_replace" )
295
+ y_place_holder = y .clone ()
296
+ # y_place_holder = safe_new(y, "_replace")
290
297
add_to_replace (y )
291
298
replace_with_in .append (y_place_holder )
292
- assert isinstance (y , type (nw_outer_node .outputs [idx ]))
293
- replace_with_out .append (nw_outer_node .outputs [idx ])
299
+ assert isinstance (y , type (pushed_out_node .outputs [idx ]))
300
+ replace_with_out .append (pushed_out_node .outputs [idx ])
294
301
295
302
# We need to check all candidate replacements and choose those that
296
303
# make sense for us
@@ -326,14 +333,14 @@ def add_to_replace(y):
326
333
clean_to_replace , clean_replace_with_in , clean_replace_with_out
327
334
):
328
335
if isinstance (repl_out , Constant ):
329
- repl_in = repl_out . clone ()
336
+ repl_in = repl_out
330
337
else :
331
338
nw_inner .append (repl_in )
332
339
nw_outer .append (repl_out )
333
340
givens [to_repl ] = repl_in
334
341
335
- op_outs = clone_replace (clean_outputs , replace = givens )
336
- op_ins = clean_inputs + nw_inner
342
+ op_outs = clone_replace (node_outputs , replace = givens )
343
+ op_ins = node_inputs + nw_inner
337
344
338
345
# Reconstruct node
339
346
nwScan = Scan (
0 commit comments