Skip to content

Commit 1247710

Browse files
ricardoV94rlouf
authored andcommitted
Add rewrite for subtraction with negation x - (-y) -> x + y
1 parent 77df667 commit 1247710

File tree

2 files changed

+54
-0
lines changed

2 files changed

+54
-0
lines changed

aesara/tensor/rewriting/math.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1788,6 +1788,24 @@ def local_neg_div_neg(fgraph, node):
17881788
return [true_div(new_num, denom)]
17891789

17901790

1791+
@register_canonicalize
1792+
@register_specialize
1793+
@node_rewriter([sub])
1794+
def local_sub_neg_to_add(fgraph, node):
1795+
"""
1796+
x - (-y) -> x + y
1797+
1798+
"""
1799+
if node.op == sub:
1800+
minuend, subtrahend = node.inputs
1801+
1802+
if subtrahend.owner:
1803+
if subtrahend.owner.op == neg:
1804+
pre_neg = subtrahend.owner.inputs[0]
1805+
new_out = add(minuend, pre_neg)
1806+
return [new_out]
1807+
1808+
17911809
@register_canonicalize
17921810
@node_rewriter([mul])
17931811
def local_mul_zero(fgraph, node):

tests/tensor/rewriting/test_math.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4618,3 +4618,39 @@ def test_deprecations():
46184618
"""Make sure we can import from deprecated modules."""
46194619
with pytest.deprecated_call():
46204620
from aesara.tensor.math_opt import AlgebraicCanonizer # noqa: F401 F811
4621+
4622+
4623+
def test_local_sub_neg_to_add():
4624+
x = scalar("x")
4625+
y = vector("y")
4626+
4627+
f = function([x, y], x - (-y), mode=Mode("py"))
4628+
4629+
nodes = [
4630+
node.op
4631+
for node in f.maker.fgraph.toposort()
4632+
if not isinstance(node.op, DimShuffle)
4633+
]
4634+
assert nodes == [at.add]
4635+
4636+
x_test = np.full((), 1.0, dtype=config.floatX)
4637+
y_test = np.full(5, 2.0, dtype=config.floatX)
4638+
assert np.allclose(f(x_test, y_test), x_test - (-y_test))
4639+
4640+
4641+
def test_local_sub_neg_to_add_const():
4642+
# This rewrite is achieved by the local_add_canonizer
4643+
x = vector("x")
4644+
const = 5.0
4645+
4646+
f = function([x], x - (-const), mode=Mode("py"))
4647+
4648+
nodes = [
4649+
node.op
4650+
for node in f.maker.fgraph.toposort()
4651+
if not isinstance(node.op, DimShuffle)
4652+
]
4653+
assert nodes == [at.add]
4654+
4655+
x_test = np.array([3, 4], dtype=config.floatX)
4656+
assert np.allclose(f(x_test), x_test - (-const))

0 commit comments

Comments
 (0)