Skip to content

Commit 4948903

Browse files
committed
Harmonize Scan rewrite and tag names
1 parent 20b6a20 commit 4948903

File tree

3 files changed

+24
-20
lines changed

3 files changed

+24
-20
lines changed

pytensor/link/numba/dispatch/scan.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ def add_inner_in_expr(
184184
# rotation for initially truncated storage.
185185
output_storage_post_proc_stmts: list[str] = []
186186

187-
# In truncated storage situations (e.g. created by `save_mem_new_scan`),
187+
# In truncated storage situations (e.g. created by `scan_save_mem`),
188188
# the taps and output storage overlap, instead of the standard situation in
189189
# which the output storage is large enough to contain both the initial taps
190190
# values and the output storage. In this truncated case, we use the

pytensor/scan/rewriting.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
209209

210210

211211
@node_rewriter([Scan])
212-
def push_out_non_seq_scan(fgraph, node):
212+
def scan_push_out_non_seq(fgraph, node):
213213
r"""Push out the variables inside the `Scan` that depend only on non-sequences.
214214
215215
This optimizations pushes, out of `Scan`'s inner function and into the outer
@@ -417,10 +417,10 @@ def add_to_replace(y):
417417

418418

419419
@node_rewriter([Scan])
420-
def push_out_seq_scan(fgraph, node):
420+
def scan_push_out_seq(fgraph, node):
421421
r"""Push out the variables inside the `Scan` that depend only on constants and sequences.
422422
423-
This optimization resembles `push_out_non_seq_scan` but it tries to push--out of
423+
This optimization resembles `scan_push_out_non_seq` but it tries to push--out of
424424
the inner function--the computation that only relies on sequence and
425425
non-sequence inputs. The idea behind this optimization is that, when it is
426426
possible to do so, it is generally more computationally efficient to perform
@@ -822,10 +822,10 @@ def add_nitsot_outputs(
822822

823823

824824
@node_rewriter([Scan])
825-
def push_out_add_scan(fgraph, node):
825+
def scan_push_out_add(fgraph, node):
826826
r"""Push `Add` operations performed at the end of the inner graph to the outside.
827827
828-
Like `push_out_seq_scan`, this optimization aims to replace many operations
828+
Like `scan_push_out_seq`, this optimization aims to replace many operations
829829
on small tensors by few operations on large tensors. It can also lead to
830830
increased memory usage.
831831
"""
@@ -1185,7 +1185,7 @@ def while_scan_merge_subtensor_last_element(fgraph, scan_node):
11851185

11861186

11871187
@node_rewriter([Scan])
1188-
def save_mem_new_scan(fgraph, node):
1188+
def scan_save_mem(fgraph, node):
11891189
r"""Graph optimizer that reduces scan memory consumption.
11901190
11911191
This optimizations attempts to determine if a `Scan` node, during its execution,
@@ -2282,7 +2282,7 @@ def map_out(outer_i, inner_o, outer_o, seen):
22822282

22832283

22842284
@node_rewriter([Scan])
2285-
def push_out_dot1_scan(fgraph, node):
2285+
def scan_push_out_dot1(fgraph, node):
22862286
r"""
22872287
This is another optimization that attempts to detect certain patterns of
22882288
computation in a `Scan` `Op`'s inner function and move this computation to the
@@ -2483,7 +2483,7 @@ def push_out_dot1_scan(fgraph, node):
24832483
# ScanSaveMem should execute only once per node.
24842484
optdb.register(
24852485
"scan_save_mem",
2486-
in2out(save_mem_new_scan, ignore_newtrees=True),
2486+
in2out(scan_save_mem, ignore_newtrees=True),
24872487
"fast_run",
24882488
"scan",
24892489
position=1.61,
@@ -2511,8 +2511,9 @@ def push_out_dot1_scan(fgraph, node):
25112511

25122512

25132513
scan_seqopt1.register(
2514-
"scan_pushout_nonseqs_ops",
2515-
in2out(push_out_non_seq_scan, ignore_newtrees=True),
2514+
"scan_push_out_non_seq",
2515+
in2out(scan_push_out_non_seq, ignore_newtrees=True),
2516+
"scan_pushout_nonseqs_ops", # For backcompat: so it can be tagged with old name
25162517
"fast_run",
25172518
"scan",
25182519
"scan_pushout",
@@ -2521,8 +2522,9 @@ def push_out_dot1_scan(fgraph, node):
25212522

25222523

25232524
scan_seqopt1.register(
2524-
"scan_pushout_seqs_ops",
2525-
in2out(push_out_seq_scan, ignore_newtrees=True),
2525+
"scan_push_out_seq",
2526+
in2out(scan_push_out_seq, ignore_newtrees=True),
2527+
"scan_pushout_seqs_ops", # For backcompat: so it can be tagged with old name
25262528
"fast_run",
25272529
"scan",
25282530
"scan_pushout",
@@ -2531,8 +2533,9 @@ def push_out_dot1_scan(fgraph, node):
25312533

25322534

25332535
scan_seqopt1.register(
2534-
"scan_pushout_dot1",
2535-
in2out(push_out_dot1_scan, ignore_newtrees=True),
2536+
"scan_push_out_dot1",
2537+
in2out(scan_push_out_dot1, ignore_newtrees=True),
2538+
"scan_pushout_dot1", # For backcompat: so it can be tagged with old name
25362539
"fast_run",
25372540
"more_mem",
25382541
"scan",
@@ -2542,9 +2545,10 @@ def push_out_dot1_scan(fgraph, node):
25422545

25432546

25442547
scan_seqopt1.register(
2545-
"scan_pushout_add",
2548+
"scan_push_out_add",
25462549
# TODO: Perhaps this should be an `EquilibriumGraphRewriter`?
2547-
in2out(push_out_add_scan, ignore_newtrees=False),
2550+
in2out(scan_push_out_add, ignore_newtrees=False),
2551+
"scan_pushout_add", # For backcompat: so it can be tagged with old name
25482552
"fast_run",
25492553
"more_mem",
25502554
"scan",

tests/scan/test_rewriting.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,7 @@ def fn(i, i_tm1):
304304

305305
class TestPushOutNonSeqScan:
306306
"""
307-
Tests for the `push_out_non_seq_scan` optimization in the case where the inner
307+
Tests for the `scan_push_out_non_seq` optimization in the case where the inner
308308
function of a `Scan` `Op` has an output which is the result of a `Dot` product
309309
on a non-sequence matrix input to `Scan` and a vector that is the result of
310310
computation in the inner function.
@@ -595,7 +595,7 @@ def inner_func(x):
595595

596596
class TestPushOutAddScan:
597597
"""
598-
Test case for the `push_out_add_scan` optimization in the case where the `Scan`
598+
Test case for the `scan_push_out_add` optimization in the case where the `Scan`
599599
is used to compute the sum over the dot products between the corresponding
600600
elements of two list of matrices.
601601
@@ -1208,7 +1208,7 @@ def test_inplace3(self):
12081208

12091209

12101210
class TestSaveMem:
1211-
mode = get_default_mode().including("scan_save_mem", "save_mem_new_scan")
1211+
mode = get_default_mode().including("scan_save_mem", "scan_save_mem")
12121212

12131213
def test_save_mem(self):
12141214
rng = np.random.default_rng(utt.fetch_seed())

0 commit comments

Comments
 (0)