Skip to content

Commit bbf937c

Browse files
ricardoV94rlouf
authored andcommitted
Add rewrite for addition with negation x + (-y) -> x - y
* Also reverses test expectation for `local_expm1` rewrite, as previously missed case is now detected
1 parent 1247710 commit bbf937c

File tree

2 files changed

+75
-3
lines changed

2 files changed

+75
-3
lines changed

aesara/tensor/rewriting/math.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1806,6 +1806,35 @@ def local_sub_neg_to_add(fgraph, node):
18061806
return [new_out]
18071807

18081808

1809+
@register_specialize
1810+
@node_rewriter([add])
1811+
def local_add_neg_to_sub(fgraph, node):
1812+
"""
1813+
-x + y -> y - x
1814+
x + (-y) -> x - y
1815+
1816+
"""
1817+
# This rewrite is only registered during specialization, because the
1818+
# `local_neg_to_mul` rewrite modifies the relevant pattern during canonicalization
1819+
1820+
# Rewrite is only applicable when there are two inputs to add
1821+
if node.op == add and len(node.inputs) == 2:
1822+
1823+
# Look for pattern with either input order
1824+
for first, second in (node.inputs, reversed(node.inputs)):
1825+
if second.owner:
1826+
if second.owner.op == neg:
1827+
pre_neg = second.owner.inputs[0]
1828+
new_out = sub(first, pre_neg)
1829+
return [new_out]
1830+
1831+
# Check if it is a negative constant
1832+
const = get_constant(second)
1833+
if const is not None and const < 0:
1834+
new_out = sub(first, np.abs(const))
1835+
return [new_out]
1836+
1837+
18091838
@register_canonicalize
18101839
@node_rewriter([mul])
18111840
def local_mul_zero(fgraph, node):

tests/tensor/rewriting/test_math.py

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4042,9 +4042,14 @@ def test_local_expm1():
40424042
for n in h.maker.fgraph.toposort()
40434043
)
40444044

4045-
assert not any(
4046-
isinstance(n.op, Elemwise) and isinstance(n.op.scalar_op, aes.basic.Expm1)
4047-
for n in r.maker.fgraph.toposort()
4045+
# This rewrite works when `local_add_neg_to_sub` specialization rewrite is invoked
4046+
expect_rewrite = config.mode != "FAST_COMPILE"
4047+
assert (
4048+
any(
4049+
isinstance(n.op, Elemwise) and isinstance(n.op.scalar_op, aes.basic.Expm1)
4050+
for n in r.maker.fgraph.toposort()
4051+
)
4052+
== expect_rewrite
40484053
)
40494054

40504055

@@ -4654,3 +4659,41 @@ def test_local_sub_neg_to_add_const():
46544659

46554660
x_test = np.array([3, 4], dtype=config.floatX)
46564661
assert np.allclose(f(x_test), x_test - (-const))
4662+
4663+
4664+
@pytest.mark.parametrize("first_negative", (True, False))
4665+
def test_local_add_neg_to_sub(first_negative):
4666+
x = scalar("x")
4667+
y = vector("y")
4668+
out = -x + y if first_negative else x + (-y)
4669+
4670+
f = function([x, y], out, mode=Mode("py"))
4671+
4672+
nodes = [
4673+
node.op
4674+
for node in f.maker.fgraph.toposort()
4675+
if not isinstance(node.op, DimShuffle)
4676+
]
4677+
assert nodes == [at.sub]
4678+
4679+
x_test = np.full((), 1.0, dtype=config.floatX)
4680+
y_test = np.full(5, 2.0, dtype=config.floatX)
4681+
exp = -x_test + y_test if first_negative else x_test + (-y_test)
4682+
assert np.allclose(f(x_test, y_test), exp)
4683+
4684+
4685+
def test_local_add_neg_to_sub_const():
4686+
x = vector("x")
4687+
const = 5.0
4688+
4689+
f = function([x], x + (-const), mode=Mode("py"))
4690+
4691+
nodes = [
4692+
node.op
4693+
for node in f.maker.fgraph.toposort()
4694+
if not isinstance(node.op, DimShuffle)
4695+
]
4696+
assert nodes == [at.sub]
4697+
4698+
x_test = np.array([3, 4], dtype=config.floatX)
4699+
assert np.allclose(f(x_test), x_test + (-const))

0 commit comments

Comments
 (0)