@@ -410,12 +410,11 @@ def push_out_seq_scan(fgraph, node):
410
410
if not isinstance (node .op , Scan ):
411
411
return False
412
412
413
- # this flag tells if there was any change during the last iterations
414
- clean_inputs , clean_outputs = reconstruct_graph (node .op .inputs , node .op .outputs )
413
+ node_inputs , node_outputs = node .op .inputs , node .op .outputs
415
414
416
- local_fgraph_topo = io_toposort (clean_inputs , clean_outputs )
417
- local_fgraph_outs_set = set (clean_outputs )
418
- local_fgraph_outs_map = {v : k for k , v in enumerate (clean_outputs )}
415
+ local_fgraph_topo = io_toposort (node_inputs , node_outputs )
416
+ local_fgraph_outs_set = set (node_outputs )
417
+ local_fgraph_outs_map = {v : k for k , v in enumerate (node_outputs )}
419
418
420
419
to_remove_set = set ()
421
420
to_replace_set = set ()
@@ -433,12 +432,12 @@ def add_to_replace(y):
433
432
434
433
op = node .op
435
434
# Construct the list of non_sequences to simplify a few things
436
- inner_non_seqs = op .inner_non_seqs (clean_inputs )
435
+ inner_non_seqs = op .inner_non_seqs (node_inputs )
437
436
inner_non_seqs_set = set (inner_non_seqs )
438
437
inner_non_seqs_map = {v : k for k , v in enumerate (inner_non_seqs )}
439
438
440
439
outer_non_seqs = op .outer_non_seqs (node .inputs )
441
- inner_seqs = op .inner_seqs (clean_inputs )
440
+ inner_seqs = op .inner_seqs (node_inputs )
442
441
inner_seqs_set = set (inner_seqs )
443
442
inner_seqs_map = {v : k for k , v in enumerate (inner_seqs )}
444
443
@@ -467,26 +466,18 @@ def add_to_replace(y):
467
466
for x in nd .inputs :
468
467
if x in inner_non_seqs_set :
469
468
_idx = inner_non_seqs_map [x ]
470
- outside_ins . append ( outer_non_seqs [_idx ])
469
+ new_input = outer_non_seqs [_idx ]
471
470
elif x in inner_seqs_set :
472
- outside_ins . append ( outer_seqs [inner_seqs_map [x ]])
471
+ new_input = outer_seqs [inner_seqs_map [x ]]
473
472
depends_on_seqs = True
474
473
elif x in to_replace_set :
475
- outside_ins . append ( replace_with_out [to_replace_map [x ]])
474
+ new_input = replace_with_out [to_replace_map [x ]]
476
475
depends_on_seqs = True
477
- elif isinstance (x , Constant ):
478
- outside_ins .append (x .clone ())
479
476
else :
480
- raise Exception (
481
- (
482
- "Error in the `scan_pushout_seq_"
483
- "operations`. The optimization tries "
484
- "to move some computation from scan "
485
- "which is not allowed to move. Report "
486
- "this on aesara-users list"
487
- ),
488
- x ,
489
- )
477
+ assert isinstance (x , Constant )
478
+ new_input = x
479
+
480
+ outside_ins .append (new_input )
490
481
491
482
if not depends_on_seqs :
492
483
# Removing this node from the inner graph of scan
@@ -580,15 +571,15 @@ def add_to_replace(y):
580
571
clean_to_replace , clean_replace_with_in , clean_replace_with_out
581
572
):
582
573
if isinstance (repl_out , Constant ):
583
- repl_in = repl_out . clone ()
574
+ repl_in = repl_out
584
575
else :
585
576
nw_inner .append (repl_in )
586
577
nw_outer .append (repl_out )
587
578
588
579
givens [to_repl ] = repl_in
589
580
590
- op_outs = clone_replace (clean_outputs , replace = givens )
591
- op_ins = nw_inner + clean_inputs
581
+ op_outs = clone_replace (node_outputs , replace = givens )
582
+ op_ins = nw_inner + node_inputs
592
583
593
584
# Reconstruct node
594
585
nw_info = dataclasses .replace (op .info , n_seqs = op .info .n_seqs + len (nw_inner ))
@@ -621,7 +612,7 @@ def add_to_replace(y):
621
612
if out in local_fgraph_outs_set :
622
613
x = node .outputs [local_fgraph_outs_map [out ]]
623
614
_y = replace_with_out [idx ]
624
- ls = clean_outputs
615
+ ls = node_outputs
625
616
if out in op .inner_mitsot_outs (ls ):
626
617
odx = op .inner_mitsot_outs (ls ).index (out )
627
618
inp = op .outer_mitsot (node .inputs )[odx ]
0 commit comments