Constant fold branches of variadic add/mul #1422

Merged
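In outline, this PR replaces the old `local_add_mul_fusion` rewrite with two rewrites registered under a dedicated `add_mul_fusion` database: `flatten_nested_add_mul`, which merges nested `add`/`mul` nodes into one variadic node, and `constant_fold_branches_of_add_mul`, which folds several constant inputs of such a node into a single constant when doing so does not grow intermediate memory. A minimal sketch of the intended effect (not part of the diff), assuming PyTensor's public API and the `rewrite_graph` helper the new test uses:

```python
# Sketch of the intended effect of this PR's rewrites, assuming
# pytensor's public API; behavior mirrors the new test below.
import pytensor.tensor as pt
from pytensor.graph.rewriting.utils import rewrite_graph

x = pt.vector("x")

# Nested adds with two constant branches: add(add(x, 1.0), 2.0)
out = (x + 1.0) + 2.0

# flatten_nested_add_mul first turns this into the variadic add(x, 1.0, 2.0);
# constant_fold_branches_of_add_mul then folds 1.0 + 2.0 into a single
# constant, leaving add(x, 3.0).
new_out = rewrite_graph(out, include=("add_mul_fusion",))
print(new_out.owner.inputs)  # expect two inputs: x and the folded constant
```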
112 changes: 83 additions & 29 deletions pytensor/tensor/rewriting/elemwise.py
@@ -1,8 +1,9 @@
 import itertools
+import operator
 import sys
 from collections import Counter, defaultdict, deque
 from collections.abc import Generator
-from functools import cache
+from functools import cache, reduce
 from typing import TypeVar
 from warnings import warn
 
@@ -16,11 +17,11 @@
 from pytensor.graph.features import ReplaceValidate
 from pytensor.graph.fg import Output
 from pytensor.graph.rewriting.basic import (
-    EquilibriumGraphRewriter,
     GraphRewriter,
     copy_stack_trace,
     in2out,
     node_rewriter,
+    out2in,
 )
 from pytensor.graph.rewriting.db import SequenceDB
 from pytensor.graph.utils import InconsistencyError, MethodNotDefined
@@ -29,13 +30,15 @@
     MakeVector,
     alloc,
     cast,
+    constant,
     get_underlying_scalar_constant_value,
 )
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
-from pytensor.tensor.math import exp
+from pytensor.tensor.math import add, exp, mul
 from pytensor.tensor.rewriting.basic import (
     alloc_like,
+    broadcasted_by,
     register_canonicalize,
     register_specialize,
 )
@@ -542,8 +545,8 @@
     return rval
 
 
-@node_rewriter([Elemwise])
-def local_add_mul_fusion(fgraph, node):
+@node_rewriter([add, mul])
+def flatten_nested_add_mul(fgraph, node):
     """Fuse consecutive add or mul in one such node with more inputs.
 
     It is better to fuse add/mul that way then in a Composite node as
@@ -554,27 +557,16 @@
     This rewrite is almost useless after the AlgebraicCanonizer is used,
     but it catches a few edge cases that are not canonicalized by it
     """
-    if not (
-        isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, ps.Add | ps.Mul)
-    ):
-        return False
-
-    s_op = node.op.scalar_op.__class__
+    s_op = node.op.scalar_op
     new_inp = []
     fused = False
-    nb_inputs = len(node.inputs)
-    max_inputs = float("inf")
-    if hasattr(node.op, "max_inputs"):
-        max_inputs = node.op.max_inputs(node)
     for inp in node.inputs:
         if (
             inp.owner
             and isinstance(inp.owner.op, Elemwise)
-            and isinstance(inp.owner.op.scalar_op, s_op)
-            and
+            and inp.owner.op.scalar_op == s_op
             # Do not duplicate the operation.
-            len(fgraph.clients[inp]) == 1
-            and (nb_inputs + len(inp.owner.inputs) - 1) <= max_inputs
+            and len(fgraph.clients[inp]) == 1
         ):
             new_inp.extend(inp.owner.inputs)
             fused = True
@@ -590,7 +582,7 @@
         # Do the recursion here to help lower the number of
         # FusionOptimizer iteration.
         if output.owner:
-            output2 = local_add_mul_fusion.transform(fgraph, output.owner)
+            output2 = flatten_nested_add_mul.transform(fgraph, output.owner)
             if output2:
                 return output2
         return [output]
@@ -1237,6 +1229,76 @@
     return new_outputs
 
 
+@node_rewriter(tracks=[add, mul])
+def constant_fold_branches_of_add_mul(fgraph, node):
+    old_constants = [inp for inp in node.inputs if isinstance(inp, TensorConstant)]
+
+    if len(old_constants) <= 1:
+        return None
+
+    new_constants = old_constants.copy()
+
+    # Multiply constants if it doesn't result in higher intermediate memory
+    while True:
+        n_constants = len(new_constants)
+        if n_constants <= 1:
+            break
+
+        for i in range(n_constants):
+            reference_inp = new_constants[i]
+            other_inps = []
+            for j in range(n_constants):
+                if i == j:
+                    continue
+                other_inp = new_constants[j]
+                if not broadcasted_by(reference_inp, other_inp):
+                    other_inps.append(other_inp)
+            if other_inps:
+                python_op = operator.mul if node.op == mul else operator.add
+                folded_inputs = [reference_inp, *other_inps]
+                new_inp = constant(
+                    reduce(python_op, (const.data for const in folded_inputs))
+                )
+                new_constants = [
+                    new_inp,
+                    *(inp for inp in new_constants if inp not in folded_inputs),
+                ]
+                break
+        else:  # no-break
+            break
+
+    if len(new_constants) == len(old_constants):
+        return None
+
+    non_constants = [inp for inp in node.inputs if not isinstance(inp, TensorConstant)]
+    new_out = node.op(
+        *new_constants,
+        *non_constants,
+    )
+    copy_stack_trace(node.outputs[0], new_out)
+    return [new_out]
+
+
+add_mul_fusion_seqopt = SequenceDB()
+compile.optdb.register(
+    "add_mul_fusion",
+    add_mul_fusion_seqopt,
+    "fast_run",
+    position=48,  # Before Elemwise fusion
+)
+add_mul_fusion_seqopt.register(
+    flatten_nested_add_mul.__name__,
+    out2in(flatten_nested_add_mul, ignore_newtrees=False),
+    "fast_run",
+    position=0,
+)
+add_mul_fusion_seqopt.register(
+    constant_fold_branches_of_add_mul.__name__,
+    in2out(constant_fold_branches_of_add_mul, ignore_newtrees=True),
+    "fast_run",
+    position=1,
+)
+
 # Register fusion database just before AddDestroyHandler(49.5) (inplace rewrites)
 fuse_seqopt = SequenceDB()
 compile.optdb.register(
@@ -1248,14 +1310,6 @@
     "FusionOptimizer",
     position=49,
 )
-
-fuse_seqopt.register(
-    "local_add_mul_fusion",
-    EquilibriumGraphRewriter(rewriters=[local_add_mul_fusion], max_use_ratio=1000),
-    "fast_run",
-    "fusion",
-    position=0,
-)
 fuse_seqopt.register(
     "composite_elemwise_fusion",
     FusionOptimizer(),
@@ -1279,7 +1333,7 @@
 )
 fuse_seqopt.register(
     "local_inline_composite_constants",
-    in2out(local_inline_composite_constants),
+    in2out(local_inline_composite_constants, ignore_newtrees=True),
     "fast_run",
     "fusion",
     position=20,
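The constant-folding loop above only merges a constant into `reference_inp` when `broadcasted_by(reference_inp, other_inp)` is false, i.e. when broadcasting cannot enlarge the reference. A standalone numpy sketch (not from the PR) of that shape arithmetic, using the shapes from the new test below:

```python
# Standalone numpy sketch of the memory guard behind
# constant_fold_branches_of_add_mul; shapes taken from the new test.
import numpy as np

a = np.zeros((1, 512, 5))   # 2560 elements
b = np.zeros((1, 512, 1))   #  512 elements
c = np.zeros((1024, 1, 1))  # 1024 elements

# Folding b into a is free: the result keeps a's shape.
assert np.broadcast_shapes(a.shape, b.shape) == (1, 512, 5)

# Folding c into a would materialize a (1024, 512, 5) constant,
# far larger than any input, so c stays a separate branch.
assert np.broadcast_shapes(a.shape, c.shape) == (1024, 512, 5)
```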
22 changes: 22 additions & 0 deletions tests/tensor/rewriting/test_elemwise.py
@@ -238,6 +238,7 @@ class TestFusion:
         include=[
             "canonicalize",
             "fusion",
+            "add_mul_fusion",
             "inplace",
         ],
         exclude=["cxx_only", "BlasOpt"],
@@ -1507,3 +1508,24 @@ def test_local_useless_dimshuffle_makevector():
     )
 
     assert y_rewritten_fg.outputs[0] == a
+
+
+@pytest.mark.parametrize("op", (add, mul))
+def test_constant_fold_branches_add_mul(op):
+    rng = np.random.default_rng()
+    py_op = np.add if op is add else np.multiply
+
+    x = pt.vector("x")
+    a = rng.normal(size=(1, 512, 5))
+    b = rng.normal(size=(1, 512, 1))
+    out = op(op(a, x), b)
+    new_out = rewrite_graph(out, include=("add_mul_fusion",))
+    assert len(new_out.owner.inputs) == 2
+    assert equal_computations([new_out], [op(py_op(a, b), x)])
+
+    # c shouldn't be folded as it would increase the memory usage
+    c = rng.normal(size=(1024, 1, 1))
+    out = op(op(op(a, x), c), b)
+    new_out = rewrite_graph(out, include=("add_mul_fusion",))
+    assert len(new_out.owner.inputs) == 3
+    assert equal_computations([new_out], [op(py_op(a, b), c, x)])
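A note on control flow in `constant_fold_branches_of_add_mul`: the `else:  # no-break` clause of a Python `for` loop runs only when the loop finishes without hitting `break`, so the outer `while True` rescans after every successful fold and exits once a full pass folds nothing. A self-contained sketch of that while/for/else fixpoint pattern, with plain integers standing in for `TensorConstant`s and an equality test standing in for the `broadcasted_by` check:

```python
# Sketch (not from the PR) of the while/for/else fixpoint pattern:
# repeatedly merge "compatible" items until a full pass merges nothing.
items = [3, 3, 5, 7, 7, 7]

while len(items) > 1:
    for i, x in enumerate(items):
        # Hypothetical mergeability test (equality) standing in for
        # the broadcasted_by shape check in the real rewrite:
        partners = [y for j, y in enumerate(items) if j != i and y == x]
        if partners:
            items = [x + sum(partners), *(y for y in items if y != x)]
            break  # folded something: rescan the shorter list
    else:  # no-break: a full pass folded nothing, fixpoint reached
        break

print(items)  # -> [21, 6, 5]
```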