Constant fold branches of variadic add/mul #1422

Open
wants to merge 1 commit into
base: main
from

ricardoV94
Member

@ricardoV94 ricardoV94 commented May 27, 2025

Refactoring and renaming:

  • Renamed the local_add_mul_fusion function to flatten_nested_add_mul to better reflect its purpose of flattening nested add/mul operations. The function now explicitly tracks add and mul operations instead of relying on generic Elemwise checks.
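
As an illustration of what flattening nested add/mul means, here is a minimal sketch on a toy tuple-based expression tree. It is not the PyTensor implementation: the `flatten_nested` helper and the `(op, arg1, arg2, ...)` representation are hypothetical and only show the idea of merging a nested add into its parent add (and likewise for mul).

```python
def flatten_nested(expr):
    """Flatten nested variadic add/mul in a toy expression tree.

    An expression is either a leaf (anything that is not a tuple) or a
    tuple ``(op, arg1, arg2, ...)`` where ``op`` is "add" or "mul".
    This representation is an assumption made for illustration only.
    """
    if not isinstance(expr, tuple):
        return expr  # leaves are returned unchanged
    op, *args = expr
    flat_args = []
    for arg in map(flatten_nested, args):
        if isinstance(arg, tuple) and arg[0] == op:
            # Same op as the parent: splice the child's inputs into the parent.
            flat_args.extend(arg[1:])
        else:
            flat_args.append(arg)
    return (op, *flat_args)


nested = ("add", "x", ("add", "y", ("add", "z", 1)))
flat = flatten_nested(nested)
```

In this sketch a mul nested inside an add is left in place, since only inputs with the same operation as their parent are merged.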

New optimization for constant folding:

  • Introduced a new rewrite function, constant_fold_branches_of_add_mul, which folds constants in add/mul operations when it does not result in higher intermediate memory usage. This optimization is registered in a new sequence database, add_mul_flat_seqopt, which runs before generic elementwise fusion.

  • This is pulled out into a separate database so that it is included in the JAX rewrites (JAX does not include fusion rewrites). We've found this could help avoid XLA constant folding (CC @lucianopaz)
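
One possible reading of the memory condition is sketched below. This is an assumption, since the description does not spell out the exact rule: constant inputs are folded only if their broadcast result is no larger than the largest constant already held. The helper names `broadcast_shape` and `fold_keeps_memory_bounded` are hypothetical and work on shapes only.

```python
from itertools import zip_longest
from math import prod


def broadcast_shape(*shapes):
    """Return the NumPy-style broadcast shape of the given shapes."""
    result = []
    # Align shapes on their trailing dimensions, padding missing ones with 1.
    for dims in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        non_one = {d for d in dims if d != 1}
        if len(non_one) > 1:
            raise ValueError(f"shapes {shapes} are not broadcastable")
        result.append(non_one.pop() if non_one else 1)
    return tuple(reversed(result))


def fold_keeps_memory_bounded(constant_shapes):
    """Assumed criterion: the folded constant must not have more elements
    than the largest constant input."""
    folded = broadcast_shape(*constant_shapes)
    return prod(folded) <= max(prod(s) for s in constant_shapes)


scalar_plus_vector = fold_keeps_memory_bounded([(), (3,)])
row_plus_column = fold_keeps_memory_bounded([(1, 3), (3, 1)])
```

Under this assumed rule, a scalar constant added to a length-3 vector constant would be folded, while a row constant added to a column constant would not, because the folded result would be a full 3x3 array.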


📚 Documentation preview 📚: https://pytensor--1422.org.readthedocs.build/en/1422/

Member

@lucianopaz lucianopaz left a comment

I don’t know why the test fails but it looks like the fusion rewrite is applied only once. Maybe the equilibrium rewrite that you took out should be added back in?

nb_inputs = len(node.inputs)
max_inputs = float("inf")
if hasattr(node.op, "max_inputs"):
    max_inputs = node.op.max_inputs(node)
Member

Isn’t this needed?

# Do not duplicate the operation.
len(fgraph.clients[inp]) == 1
and (nb_inputs + len(inp.owner.inputs) - 1) <= max_inputs
Member

Isn’t this needed?
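
For context, the guard quoted in the diff fragment above can be read as simple arithmetic: inlining a nested node replaces one input of the parent with all inputs of the child, so the merged node has `nb_inputs + len(inp.owner.inputs) - 1` inputs. The sketch below restates that check with plain integers. The function name and parameters are hypothetical, not part of PyTensor.

```python
import math


def can_flatten_into_parent(n_parent_inputs, n_child_inputs, n_child_clients,
                            max_inputs=math.inf):
    """Toy restatement of the quoted guard (names are illustrative only)."""
    # The child occupies one parent input slot, which is replaced by all of
    # the child's own inputs.
    merged_inputs = n_parent_inputs + n_child_inputs - 1
    # A child with more than one client would have to be kept anyway, so
    # inlining it would duplicate the operation.
    return n_child_clients == 1 and merged_inputs <= max_inputs


fits = can_flatten_into_parent(2, 3, n_child_clients=1, max_inputs=4)
too_many = can_flatten_into_parent(2, 3, n_child_clients=1, max_inputs=3)
shared = can_flatten_into_parent(2, 3, n_child_clients=2)
```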


fuse_seqopt.register(
    "local_add_mul_fusion",
    EquilibriumGraphRewriter(rewriters=[local_add_mul_fusion], max_use_ratio=1000),
Member

Isn’t this needed?

)
add_mul_flat_seqopt.register(
    constant_fold_branches_of_add_mul.__name__,
    in2out(constant_fold_branches_of_add_mul, ignore_newtrees=True),
Member

Why should ignore_newtrees be True?

@lucianopaz
Member

lucianopaz commented May 29, 2025

@ricardoV94, I just went through your branch's code and found that the error comes from the TestFusion class including the "canonicalize", "fusion", and "inplace" rewrite databases. The add_mul flatten and fusion rewrites that you moved or added here are only included in "fast_run". My question then is whether your rewrites should also be added to "fusion", or whether you would rather add the "fast_run" database rewrites to the TestFusion includes?
