From 70db72e2a7fc0ee7a3df8e4304b1acc6b2a6c94f Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Tue, 27 May 2025 13:36:55 +0200 Subject: [PATCH] Constant fold branches of variadic add/mul --- pytensor/tensor/rewriting/elemwise.py | 112 ++++++++++++++++++------ tests/tensor/rewriting/test_elemwise.py | 22 +++++ 2 files changed, 105 insertions(+), 29 deletions(-) diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index 4b5a5075eb..98fc4e074c 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -1,8 +1,9 @@ import itertools +import operator import sys from collections import Counter, defaultdict, deque from collections.abc import Generator -from functools import cache +from functools import cache, reduce from typing import TypeVar from warnings import warn @@ -16,11 +17,11 @@ from pytensor.graph.features import ReplaceValidate from pytensor.graph.fg import Output from pytensor.graph.rewriting.basic import ( - EquilibriumGraphRewriter, GraphRewriter, copy_stack_trace, in2out, node_rewriter, + out2in, ) from pytensor.graph.rewriting.db import SequenceDB from pytensor.graph.utils import InconsistencyError, MethodNotDefined @@ -29,13 +30,15 @@ MakeVector, alloc, cast, + constant, get_underlying_scalar_constant_value, ) from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.exceptions import NotScalarConstantError -from pytensor.tensor.math import exp +from pytensor.tensor.math import add, exp, mul from pytensor.tensor.rewriting.basic import ( alloc_like, + broadcasted_by, register_canonicalize, register_specialize, ) @@ -542,8 +545,8 @@ def local_upcast_elemwise_constant_inputs(fgraph, node): return rval -@node_rewriter([Elemwise]) -def local_add_mul_fusion(fgraph, node): +@node_rewriter([add, mul]) +def flatten_nested_add_mul(fgraph, node): """Fuse consecutive add or mul in one such node with more inputs. 
It is better to fuse add/mul that way then in a Composite node as @@ -554,27 +557,16 @@ def local_add_mul_fusion(fgraph, node): This rewrite is almost useless after the AlgebraicCanonizer is used, but it catches a few edge cases that are not canonicalized by it """ - if not ( - isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, ps.Add | ps.Mul) - ): - return False - - s_op = node.op.scalar_op.__class__ + s_op = node.op.scalar_op new_inp = [] fused = False - nb_inputs = len(node.inputs) - max_inputs = float("inf") - if hasattr(node.op, "max_inputs"): - max_inputs = node.op.max_inputs(node) for inp in node.inputs: if ( inp.owner and isinstance(inp.owner.op, Elemwise) - and isinstance(inp.owner.op.scalar_op, s_op) - and + and inp.owner.op.scalar_op == s_op # Do not duplicate the operation. - len(fgraph.clients[inp]) == 1 - and (nb_inputs + len(inp.owner.inputs) - 1) <= max_inputs + and len(fgraph.clients[inp]) == 1 ): new_inp.extend(inp.owner.inputs) fused = True @@ -590,7 +582,7 @@ def local_add_mul_fusion(fgraph, node): # Do the recursion here to help lower the number of # FusionOptimizer iteration. 
if output.owner: - output2 = local_add_mul_fusion.transform(fgraph, output.owner) + output2 = flatten_nested_add_mul.transform(fgraph, output.owner) if output2: return output2 return [output] @@ -1237,6 +1229,76 @@ def local_inline_composite_constants(fgraph, node): return new_outputs +@node_rewriter(tracks=[add, mul]) +def constant_fold_branches_of_add_mul(fgraph, node): + old_constants = [inp for inp in node.inputs if isinstance(inp, TensorConstant)] + + if len(old_constants) <= 1: + return None + + new_constants = old_constants.copy() + + # Fold (add or multiply) constants together if it doesn't result in higher intermediate memory + while True: + n_constants = len(new_constants) + if n_constants <= 1: + break + + for i in range(n_constants): + reference_inp = new_constants[i] + other_inps = [] + for j in range(n_constants): + if i == j: + continue + other_inp = new_constants[j] + if not broadcasted_by(reference_inp, other_inp): + other_inps.append(other_inp) + if other_inps: + python_op = operator.mul if node.op == mul else operator.add + folded_inputs = [reference_inp, *other_inps] + new_inp = constant( + reduce(python_op, (const.data for const in folded_inputs)) + ) + new_constants = [ + new_inp, + *(inp for inp in new_constants if inp not in folded_inputs), + ] + break + else: # no-break + break + + if len(new_constants) == len(old_constants): + return None + + non_constants = [inp for inp in node.inputs if not isinstance(inp, TensorConstant)] + new_out = node.op( + *new_constants, + *non_constants, + ) + copy_stack_trace(node.outputs[0], new_out) + return [new_out] + + +add_mul_fusion_seqopt = SequenceDB() +compile.optdb.register( + "add_mul_fusion", + add_mul_fusion_seqopt, + "fast_run", + position=48, # Before Elemwise fusion +) +add_mul_fusion_seqopt.register( + flatten_nested_add_mul.__name__, + out2in(flatten_nested_add_mul, ignore_newtrees=False), + "fast_run", + position=0, +) +add_mul_fusion_seqopt.register( + constant_fold_branches_of_add_mul.__name__, + 
in2out(constant_fold_branches_of_add_mul, ignore_newtrees=True), + "fast_run", + position=1, +) + # Register fusion database just before AddDestroyHandler(49.5) (inplace rewrites) fuse_seqopt = SequenceDB() compile.optdb.register( @@ -1248,14 +1310,6 @@ def local_inline_composite_constants(fgraph, node): "FusionOptimizer", position=49, ) - -fuse_seqopt.register( - "local_add_mul_fusion", - EquilibriumGraphRewriter(rewriters=[local_add_mul_fusion], max_use_ratio=1000), - "fast_run", - "fusion", - position=0, -) fuse_seqopt.register( "composite_elemwise_fusion", FusionOptimizer(), @@ -1279,7 +1333,7 @@ def local_inline_composite_constants(fgraph, node): ) fuse_seqopt.register( "local_inline_composite_constants", - in2out(local_inline_composite_constants), + in2out(local_inline_composite_constants, ignore_newtrees=True), "fast_run", "fusion", position=20, diff --git a/tests/tensor/rewriting/test_elemwise.py b/tests/tensor/rewriting/test_elemwise.py index 4e7fe54581..cdd1f6bd77 100644 --- a/tests/tensor/rewriting/test_elemwise.py +++ b/tests/tensor/rewriting/test_elemwise.py @@ -238,6 +238,7 @@ class TestFusion: include=[ "canonicalize", "fusion", + "add_mul_fusion", "inplace", ], exclude=["cxx_only", "BlasOpt"], @@ -1507,3 +1508,24 @@ def test_local_useless_dimshuffle_makevector(): ) assert y_rewritten_fg.outputs[0] == a + + +@pytest.mark.parametrize("op", (add, mul)) +def test_constant_fold_branches_add_mul(op): + rng = np.random.default_rng() + py_op = np.add if op is add else np.multiply + + x = pt.vector("x") + a = rng.normal(size=(1, 512, 5)) + b = rng.normal(size=(1, 512, 1)) + out = op(op(a, x), b) + new_out = rewrite_graph(out, include=("add_mul_fusion",)) + assert len(new_out.owner.inputs) == 2 + assert equal_computations([new_out], [op(py_op(a, b), x)]) + + # c shouldn't be folded as it would increase the memory usage + c = rng.normal(size=(1024, 1, 1)) + out = op(op(op(a, x), c), b) + new_out = rewrite_graph(out, include=("add_mul_fusion",)) + assert 
len(new_out.owner.inputs) == 3 + assert equal_computations([new_out], [op(py_op(a, b), c, x)])