@@ -209,7 +209,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
209
209
210
210
211
211
@node_rewriter ([Scan ])
212
- def push_out_non_seq_scan (fgraph , node ):
212
+ def scan_push_out_non_seq (fgraph , node ):
213
213
r"""Push out the variables inside the `Scan` that depend only on non-sequences.
214
214
215
215
This optimizations pushes, out of `Scan`'s inner function and into the outer
@@ -417,10 +417,10 @@ def add_to_replace(y):
417
417
418
418
419
419
@node_rewriter ([Scan ])
420
- def push_out_seq_scan (fgraph , node ):
420
+ def scan_push_out_seq (fgraph , node ):
421
421
r"""Push out the variables inside the `Scan` that depend only on constants and sequences.
422
422
423
- This optimization resembles `push_out_non_seq_scan ` but it tries to push--out of
423
+ This optimization resembles `scan_push_out_non_seq ` but it tries to push--out of
424
424
the inner function--the computation that only relies on sequence and
425
425
non-sequence inputs. The idea behind this optimization is that, when it is
426
426
possible to do so, it is generally more computationally efficient to perform
@@ -822,10 +822,10 @@ def add_nitsot_outputs(
822
822
823
823
824
824
@node_rewriter ([Scan ])
825
- def push_out_add_scan (fgraph , node ):
825
+ def scan_push_out_add (fgraph , node ):
826
826
r"""Push `Add` operations performed at the end of the inner graph to the outside.
827
827
828
- Like `push_out_seq_scan `, this optimization aims to replace many operations
828
+ Like `scan_push_out_seq `, this optimization aims to replace many operations
829
829
on small tensors by few operations on large tensors. It can also lead to
830
830
increased memory usage.
831
831
"""
@@ -1185,7 +1185,7 @@ def while_scan_merge_subtensor_last_element(fgraph, scan_node):
1185
1185
1186
1186
1187
1187
@node_rewriter ([Scan ])
1188
- def save_mem_new_scan (fgraph , node ):
1188
+ def scan_save_mem (fgraph , node ):
1189
1189
r"""Graph optimizer that reduces scan memory consumption.
1190
1190
1191
1191
This optimizations attempts to determine if a `Scan` node, during its execution,
@@ -2282,7 +2282,7 @@ def map_out(outer_i, inner_o, outer_o, seen):
2282
2282
2283
2283
2284
2284
@node_rewriter ([Scan ])
2285
- def push_out_dot1_scan (fgraph , node ):
2285
+ def scan_push_out_dot1 (fgraph , node ):
2286
2286
r"""
2287
2287
This is another optimization that attempts to detect certain patterns of
2288
2288
computation in a `Scan` `Op`'s inner function and move this computation to the
@@ -2483,7 +2483,7 @@ def push_out_dot1_scan(fgraph, node):
2483
2483
# ScanSaveMem should execute only once per node.
2484
2484
optdb .register (
2485
2485
"scan_save_mem" ,
2486
- in2out (save_mem_new_scan , ignore_newtrees = True ),
2486
+ in2out (scan_save_mem , ignore_newtrees = True ),
2487
2487
"fast_run" ,
2488
2488
"scan" ,
2489
2489
position = 1.61 ,
@@ -2511,8 +2511,9 @@ def push_out_dot1_scan(fgraph, node):
2511
2511
2512
2512
2513
2513
scan_seqopt1 .register (
2514
- "scan_pushout_nonseqs_ops" ,
2515
- in2out (push_out_non_seq_scan , ignore_newtrees = True ),
2514
+ "scan_push_out_non_seq" ,
2515
+ in2out (scan_push_out_non_seq , ignore_newtrees = True ),
2516
+ "scan_pushout_nonseqs_ops" , # For backcompat: so it can be tagged with old name
2516
2517
"fast_run" ,
2517
2518
"scan" ,
2518
2519
"scan_pushout" ,
@@ -2521,8 +2522,9 @@ def push_out_dot1_scan(fgraph, node):
2521
2522
2522
2523
2523
2524
scan_seqopt1 .register (
2524
- "scan_pushout_seqs_ops" ,
2525
- in2out (push_out_seq_scan , ignore_newtrees = True ),
2525
+ "scan_push_out_seq" ,
2526
+ in2out (scan_push_out_seq , ignore_newtrees = True ),
2527
+ "scan_pushout_seqs_ops" , # For backcompat: so it can be tagged with old name
2526
2528
"fast_run" ,
2527
2529
"scan" ,
2528
2530
"scan_pushout" ,
@@ -2531,8 +2533,9 @@ def push_out_dot1_scan(fgraph, node):
2531
2533
2532
2534
2533
2535
scan_seqopt1 .register (
2534
- "scan_pushout_dot1" ,
2535
- in2out (push_out_dot1_scan , ignore_newtrees = True ),
2536
+ "scan_push_out_dot1" ,
2537
+ in2out (scan_push_out_dot1 , ignore_newtrees = True ),
2538
+ "scan_pushout_dot1" , # For backcompat: so it can be tagged with old name
2536
2539
"fast_run" ,
2537
2540
"more_mem" ,
2538
2541
"scan" ,
@@ -2542,9 +2545,10 @@ def push_out_dot1_scan(fgraph, node):
2542
2545
2543
2546
2544
2547
scan_seqopt1 .register (
2545
- "scan_pushout_add " ,
2548
+ "scan_push_out_add " ,
2546
2549
# TODO: Perhaps this should be an `EquilibriumGraphRewriter`?
2547
- in2out (push_out_add_scan , ignore_newtrees = False ),
2550
+ in2out (scan_push_out_add , ignore_newtrees = False ),
2551
+ "scan_pushout_add" , # For backcompat: so it can be tagged with old name
2548
2552
"fast_run" ,
2549
2553
"more_mem" ,
2550
2554
"scan" ,
0 commit comments