Skip to content

Commit 2cf617d

Browse files
ricardoV94rlouf
authored andcommitted
Remove redundant erf(c) rewrites
1 parent bbf937c commit 2cf617d

File tree

1 file changed

+4
-32
lines changed

1 file changed

+4
-32
lines changed

aesara/tensor/rewriting/math.py

Lines changed: 4 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2575,8 +2575,6 @@ def local_greedy_distributor(fgraph, node):
25752575
register_stabilize(local_one_plus_erf)
25762576
register_specialize(local_one_plus_erf)
25772577

2578-
# Only one of the two rewrites below is needed if a canonicalization is added
2579-
# for sub(x, y) -> add(x, -y) or a specialization for add(x, -y) -> sub(x, y)
25802578
# 1-erf(x)=>erfc(x)
25812579
local_one_minus_erf = PatternNodeRewriter(
25822580
(sub, 1, (erf, "x")),
@@ -2590,21 +2588,9 @@ def local_greedy_distributor(fgraph, node):
25902588
register_stabilize(local_one_minus_erf)
25912589
register_specialize(local_one_minus_erf)
25922590

2593-
local_one_minus_erf2 = PatternNodeRewriter(
2594-
(add, 1, (neg, (erf, "x"))),
2595-
(erfc, "x"),
2596-
allow_multiple_clients=True,
2597-
name="local_one_minus_erf2",
2598-
tracks=[erf],
2599-
get_nodes=get_clients_at_depth2,
2600-
)
2601-
register_canonicalize(local_one_minus_erf2)
2602-
register_stabilize(local_one_minus_erf2)
2603-
register_specialize(local_one_minus_erf2)
2604-
26052591
# (-1)+erf(x) => -erfc(x)
2606-
# There is no need for erf(x)+(-1) nor erf(x) - 1, as the canonicalize will
2607-
# convert those to the matched pattern
2592+
# There is no need for erf(x)+(-1) nor erf(x) - 1, as the `local_add_mul`
2593+
# canonicalize will convert those to the matched pattern
26082594
local_erf_minus_one = PatternNodeRewriter(
26092595
(add, -1, (erf, "x")),
26102596
(neg, (erfc, "x")),
@@ -2617,8 +2603,6 @@ def local_greedy_distributor(fgraph, node):
26172603
register_stabilize(local_erf_minus_one)
26182604
register_specialize(local_erf_minus_one)
26192605

2620-
# Only one of the two rewrites below is needed if a canonicalization is added
2621-
# for sub(x, y) -> add(x, -y) or a specialization for add(x, -y) -> sub(x, y)
26222606
# 1-erfc(x) => erf(x)
26232607
local_one_minus_erfc = PatternNodeRewriter(
26242608
(sub, 1, (erfc, "x")),
@@ -2632,21 +2616,9 @@ def local_greedy_distributor(fgraph, node):
26322616
register_stabilize(local_one_minus_erfc)
26332617
register_specialize(local_one_minus_erfc)
26342618

2635-
local_one_minus_erfc2 = PatternNodeRewriter(
2636-
(add, 1, (neg, (erfc, "x"))),
2637-
(erf, "x"),
2638-
allow_multiple_clients=True,
2639-
name="local_one_minus_erfc2",
2640-
tracks=[erfc],
2641-
get_nodes=get_clients_at_depth2,
2642-
)
2643-
register_canonicalize(local_one_minus_erfc2)
2644-
register_stabilize(local_one_minus_erfc2)
2645-
register_specialize(local_one_minus_erfc2)
2646-
2647-
# (-1)+erfc(-x)=>erf(x)
2619+
# erfc(-x)-1=>erf(x)
26482620
local_erf_neg_minus_one = PatternNodeRewriter(
2649-
(add, -1, (erfc, (neg, "x"))),
2621+
(sub, (erfc, (neg, "x")), 1),
26502622
(erf, "x"),
26512623
allow_multiple_clients=True,
26522624
name="local_erf_neg_minus_one",

0 commit comments

Comments
 (0)