Skip to content

Commit 2d2f297

Browse files
Refactor push_out_non_seq_scan and remove unnecessary cloning
1 parent d47ce12 commit 2d2f297

File tree

1 file changed

+54
-47
lines changed

1 file changed

+54
-47
lines changed

aesara/scan/opt.py

Lines changed: 54 additions & 47 deletions
Original file line number | Diff line number | Diff line change
@@ -203,12 +203,11 @@ def push_out_non_seq_scan(fgraph, node):
203203
if not isinstance(node.op, Scan):
204204
return False
205205

206-
# this flag tells if there was any change during the last iterations
207-
clean_inputs, clean_outputs = reconstruct_graph(node.op.inputs, node.op.outputs)
206+
node_inputs, node_outputs = node.op.inputs, node.op.outputs
208207

209-
local_fgraph_topo = io_toposort(clean_inputs, clean_outputs)
210-
local_fgraph_outs_set = set(clean_outputs)
211-
local_fgraph_outs_map = {v: k for k, v in enumerate(clean_outputs)}
208+
local_fgraph_topo = io_toposort(node_inputs, node_outputs)
209+
local_fgraph_outs_set = set(node_outputs)
210+
local_fgraph_outs_map = {v: k for k, v in enumerate(node_outputs)}
212211

213212
to_remove_set = set()
214213
to_replace_set = set()
@@ -221,18 +220,21 @@ def add_to_replace(y):
221220

222221
add_to_replace.n = 0
223222

223+
# The variables that will replace the variables pushed-out of the
224+
# inner-graph
224225
replace_with_in = []
226+
# The variables that have been pushed-out of the graph
225227
replace_with_out = []
226228

227229
op = node.op
228230
# Construct the list of non_sequences to simplify a few things
229-
inner_non_seqs = op.inner_non_seqs(clean_inputs)
231+
inner_non_seqs = op.inner_non_seqs(node_inputs)
230232
inner_non_seqs_set = set(inner_non_seqs)
231233
inner_non_seqs_map = {v: k for k, v in enumerate(inner_non_seqs)}
232234

233235
outer_non_seqs = op.outer_non_seqs(node.inputs)
234236

235-
inner_seqs = op.inner_seqs(clean_inputs)
237+
inner_seqs = op.inner_seqs(node_inputs)
236238
outer_seqs = op.outer_seqs(node.inputs)
237239

238240
assert len(inner_non_seqs) == len(outer_non_seqs)
@@ -242,55 +244,60 @@ def add_to_replace(y):
242244
if ( # we haven't already looked at this node
243245
nd not in to_remove_set
244246
and all(
245-
[
246-
(
247-
(x in inner_non_seqs_set)
248-
or (x.owner in to_remove_set)
249-
or isinstance(x, Constant)
250-
)
251-
for x in nd.inputs
252-
]
247+
(
248+
(x in inner_non_seqs_set)
249+
or (x.owner in to_remove_set)
250+
or isinstance(x, Constant)
251+
)
252+
for x in nd.inputs
253253
)
254-
and
255-
# we can do this because the assumption is that a
256-
# viewOp or deepCopyOp will be just at the end of the
257-
# function and not somewhere in the middle ..
258-
not isinstance(nd.op, aesara.compile.ViewOp)
254+
# We can (supposedly) skip `ViewOp` and `DeepCopyOp` nodes because they
255+
# are assumed to appear only at the very end of the inner function,
256+
# never somewhere in the middle
257+
and not isinstance(nd.op, aesara.compile.ViewOp)
259258
and not isinstance(nd.op, aesara.compile.DeepCopyOp)
260259
):
261-
262-
# We have a candidate node to removable
263-
# Step 1. Reconstruct it on outside
260+
# We have a candidate node to remove from the inner-graph
261+
262+
# Step 1. Reconstruct the node using the relevant outer-inputs.
263+
#
264+
# More specifically, the node's current inputs are either
265+
# a) inner-graph input place-holders for non-sequences,
266+
# b) the outputs of other nodes being pushed out of the inner-graph,
267+
# or c) constants.
264268
to_remove_set.add(nd)
265-
outside_ins = []
266-
for x in nd.inputs:
267-
if x in inner_non_seqs_set:
268-
_idx = inner_non_seqs_map[x]
269-
outside_ins.append(outer_non_seqs[_idx])
270-
elif x in to_replace_set:
271-
outside_ins.append(replace_with_out[to_replace_map[x]])
272-
elif isinstance(x, Constant):
273-
outside_ins.append(x.clone())
269+
new_inputs = []
270+
for old_input in nd.inputs:
271+
if old_input in inner_non_seqs_set:
272+
# This is case a), so we want to use the corresponding
273+
# outer-graph input as the input to our new pushed-out node
274+
_idx = inner_non_seqs_map[old_input]
275+
new_input = outer_non_seqs[_idx]
276+
elif old_input in to_replace_set:
277+
# This is case b), so we want to use the output of the already
278+
# pushed-out node as the input to this new pushed-out node
279+
new_input = replace_with_out[to_replace_map[old_input]]
274280
else:
275-
# TODO: Explain why is this an error, and raise an
276-
# appropriate exception type.
277-
raise RuntimeError()
278-
outside_ins = [
279-
x.type.filter_variable(y) for x, y in zip(nd.inputs, outside_ins)
280-
]
281+
assert isinstance(old_input, Constant)
282+
new_input = old_input
281283

282-
nw_outer_node = nd.op.make_node(*outside_ins)
284+
new_input = old_input.type.filter_variable(new_input)
285+
new_inputs.append(new_input)
286+
287+
pushed_out_node = nd.op.make_node(*new_inputs)
283288

284289
if config.compute_test_value != "off":
285-
compute_test_value(nw_outer_node)
290+
compute_test_value(pushed_out_node)
286291

287-
# Step 2. Create variables for replacements
292+
# Step 2. Create variables to replace the old outputs of the node
293+
# that we're pushing out of the inner-graph
288294
for idx, y in enumerate(nd.outputs):
289-
y_place_holder = safe_new(y, "_replace")
295+
y_place_holder = y.clone()
296+
# y_place_holder = safe_new(y, "_replace")
290297
add_to_replace(y)
291298
replace_with_in.append(y_place_holder)
292-
assert isinstance(y, type(nw_outer_node.outputs[idx]))
293-
replace_with_out.append(nw_outer_node.outputs[idx])
299+
assert isinstance(y, type(pushed_out_node.outputs[idx]))
300+
replace_with_out.append(pushed_out_node.outputs[idx])
294301

295302
# We need to check all candidate replacements and choose those that
296303
# make sense for us
@@ -326,14 +333,14 @@ def add_to_replace(y):
326333
clean_to_replace, clean_replace_with_in, clean_replace_with_out
327334
):
328335
if isinstance(repl_out, Constant):
329-
repl_in = repl_out.clone()
336+
repl_in = repl_out
330337
else:
331338
nw_inner.append(repl_in)
332339
nw_outer.append(repl_out)
333340
givens[to_repl] = repl_in
334341

335-
op_outs = clone_replace(clean_outputs, replace=givens)
336-
op_ins = clean_inputs + nw_inner
342+
op_outs = clone_replace(node_outputs, replace=givens)
343+
op_ins = node_inputs + nw_inner
337344

338345
# Reconstruct node
339346
nwScan = Scan(

0 commit comments

Comments
 (0)