Skip to content

Commit d039a05

Browse files
Remove unnecessary cloning from push_out_seq_scan
1 parent 2d2f297 commit d039a05

File tree

1 file changed

+17
-26
lines changed

1 file changed

+17
-26
lines changed

aesara/scan/opt.py

Lines changed: 17 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -410,12 +410,11 @@ def push_out_seq_scan(fgraph, node):
410410
if not isinstance(node.op, Scan):
411411
return False
412412

413-
# this flag tells if there was any change during the last iterations
414-
clean_inputs, clean_outputs = reconstruct_graph(node.op.inputs, node.op.outputs)
413+
node_inputs, node_outputs = node.op.inputs, node.op.outputs
415414

416-
local_fgraph_topo = io_toposort(clean_inputs, clean_outputs)
417-
local_fgraph_outs_set = set(clean_outputs)
418-
local_fgraph_outs_map = {v: k for k, v in enumerate(clean_outputs)}
415+
local_fgraph_topo = io_toposort(node_inputs, node_outputs)
416+
local_fgraph_outs_set = set(node_outputs)
417+
local_fgraph_outs_map = {v: k for k, v in enumerate(node_outputs)}
419418

420419
to_remove_set = set()
421420
to_replace_set = set()
@@ -433,12 +432,12 @@ def add_to_replace(y):
433432

434433
op = node.op
435434
# Construct the list of non_sequences to simplify a few things
436-
inner_non_seqs = op.inner_non_seqs(clean_inputs)
435+
inner_non_seqs = op.inner_non_seqs(node_inputs)
437436
inner_non_seqs_set = set(inner_non_seqs)
438437
inner_non_seqs_map = {v: k for k, v in enumerate(inner_non_seqs)}
439438

440439
outer_non_seqs = op.outer_non_seqs(node.inputs)
441-
inner_seqs = op.inner_seqs(clean_inputs)
440+
inner_seqs = op.inner_seqs(node_inputs)
442441
inner_seqs_set = set(inner_seqs)
443442
inner_seqs_map = {v: k for k, v in enumerate(inner_seqs)}
444443

@@ -467,26 +466,18 @@ def add_to_replace(y):
467466
for x in nd.inputs:
468467
if x in inner_non_seqs_set:
469468
_idx = inner_non_seqs_map[x]
470-
outside_ins.append(outer_non_seqs[_idx])
469+
new_input = outer_non_seqs[_idx]
471470
elif x in inner_seqs_set:
472-
outside_ins.append(outer_seqs[inner_seqs_map[x]])
471+
new_input = outer_seqs[inner_seqs_map[x]]
473472
depends_on_seqs = True
474473
elif x in to_replace_set:
475-
outside_ins.append(replace_with_out[to_replace_map[x]])
474+
new_input = replace_with_out[to_replace_map[x]]
476475
depends_on_seqs = True
477-
elif isinstance(x, Constant):
478-
outside_ins.append(x.clone())
479476
else:
480-
raise Exception(
481-
(
482-
"Error in the `scan_pushout_seq_"
483-
"operations`. The optimization tries "
484-
"to move some computation from scan "
485-
"which is not allowed to move. Report "
486-
"this on aesara-users list"
487-
),
488-
x,
489-
)
477+
assert isinstance(x, Constant)
478+
new_input = x
479+
480+
outside_ins.append(new_input)
490481

491482
if not depends_on_seqs:
492483
# Removing this node from the inner graph of scan
@@ -580,15 +571,15 @@ def add_to_replace(y):
580571
clean_to_replace, clean_replace_with_in, clean_replace_with_out
581572
):
582573
if isinstance(repl_out, Constant):
583-
repl_in = repl_out.clone()
574+
repl_in = repl_out
584575
else:
585576
nw_inner.append(repl_in)
586577
nw_outer.append(repl_out)
587578

588579
givens[to_repl] = repl_in
589580

590-
op_outs = clone_replace(clean_outputs, replace=givens)
591-
op_ins = nw_inner + clean_inputs
581+
op_outs = clone_replace(node_outputs, replace=givens)
582+
op_ins = nw_inner + node_inputs
592583

593584
# Reconstruct node
594585
nw_info = dataclasses.replace(op.info, n_seqs=op.info.n_seqs + len(nw_inner))
@@ -621,7 +612,7 @@ def add_to_replace(y):
621612
if out in local_fgraph_outs_set:
622613
x = node.outputs[local_fgraph_outs_map[out]]
623614
_y = replace_with_out[idx]
624-
ls = clean_outputs
615+
ls = node_outputs
625616
if out in op.inner_mitsot_outs(ls):
626617
odx = op.inner_mitsot_outs(ls).index(out)
627618
inp = op.outer_mitsot(node.inputs)[odx]

0 commit comments

Comments
 (0)