Constant fold branches of variadic add/mul #1422
base: main
I don’t know why the test fails but it looks like the fusion rewrite is applied only once. Maybe the equilibrium rewrite that you took out should be added back in?
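To illustrate the difference the comment above is pointing at, here is a minimal sketch on toy tuple-based expression trees, not PyTensor graphs. The helpers `flatten_once` and `flatten_to_equilibrium` are hypothetical names introduced only for this illustration, and the toy only flattens at the root node. A single application merges one level of nesting, while repeating the rewrite until nothing changes flattens the whole chain.

```python
# Toy expressions: ("add", arg1, arg2, ...) for an op, or a plain string for a leaf.
def flatten_once(expr):
    """Inline direct children that use the same op as the parent (one level only)."""
    if not isinstance(expr, tuple):
        return expr
    op, *args = expr
    new_args = []
    for arg in args:
        if isinstance(arg, tuple) and arg[0] == op:
            new_args.extend(arg[1:])  # merge the child's inputs into the parent
        else:
            new_args.append(arg)
    return (op, *new_args)


def flatten_to_equilibrium(expr, max_iters=1000):
    """Re-apply flatten_once until the expression stops changing."""
    for _ in range(max_iters):
        new_expr = flatten_once(expr)
        if new_expr == expr:
            return expr
        expr = new_expr
    raise RuntimeError("rewrite did not converge")


nested = ("add", ("add", ("add", "a", "b"), "c"), "d")
single_pass = flatten_once(nested)            # still contains a nested add
fixed_point = flatten_to_equilibrium(nested)  # fully flat
```

If the real rewrite only handles one level per application, as in this toy, then something has to drive it to a fixed point for deeply nested inputs.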
nb_inputs = len(node.inputs)
max_inputs = float("inf")
if hasattr(node.op, "max_inputs"):
    max_inputs = node.op.max_inputs(node)
Isn’t this needed?
# Do not duplicate the operation.
len(fgraph.clients[inp]) == 1
and (nb_inputs + len(inp.owner.inputs) - 1) <= max_inputs
Isn’t this needed?
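For context on what the two quoted checks guard against, here is a standalone sketch of the same arithmetic. The function `can_merge_into_parent` and its parameter names are hypothetical and exist only for this illustration; the condition itself mirrors the quoted lines. Merging a child into its parent replaces one parent input with all of the child's inputs, so the merged node ends up with `nb_inputs + child_nb_inputs - 1` inputs.

```python
def can_merge_into_parent(nb_inputs, child_nb_inputs, child_nb_clients,
                          max_inputs=float("inf")):
    """Decide whether a nested add/mul child may be inlined into its parent."""
    # Do not duplicate the operation: the child must feed only this parent,
    # otherwise its computation would be repeated for the other clients.
    if child_nb_clients != 1:
        return False
    # One parent input is swapped for all of the child's inputs.
    return (nb_inputs + child_nb_inputs - 1) <= max_inputs


unbounded = can_merge_into_parent(3, 2, 1)                  # no input limit
shared_child = can_merge_into_parent(3, 2, 2)               # child has 2 clients
over_limit = can_merge_into_parent(3, 2, 1, max_inputs=3)   # 3 + 2 - 1 = 4 > 3
at_limit = can_merge_into_parent(3, 2, 1, max_inputs=4)     # 4 <= 4
```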
fuse_seqopt.register(
    "local_add_mul_fusion",
    EquilibriumGraphRewriter(rewriters=[local_add_mul_fusion], max_use_ratio=1000),
Isn’t this needed?
)
add_mul_flat_seqopt.register(
    constant_fold_branches_of_add_mul.__name__,
    in2out(constant_fold_branches_of_add_mul, ignore_newtrees=True),
Why should ignore new trees be true?
@ricardoV94, I just went through your branch's code and found that the error is coming from the fact that the
Refactoring and renaming:
Renamed the `local_add_mul_fusion` function to `flatten_nested_add_mul` to better reflect its purpose of flattening nested add/mul operations. The function now explicitly tracks `add` and `mul` operations instead of relying on generic `Elemwise` checks. [1] [2] [3]

New optimization for constant folding:
Introduced a new rewrite function, `constant_fold_branches_of_add_mul`, which folds constants in add/mul operations when doing so does not result in higher intermediate memory usage. This optimization is registered in a new sequence database, `add_mul_flat_seqopt`, which runs before generic elementwise fusion.

This is pulled out into a separate database so it's included in JAX rewrites (JAX does not include fusion rewrites). We've found this could help avoid XLA constant folding (CC @lucianopaz)
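As a rough illustration of the "no higher intermediate memory usage" condition, the sketch below reasons about static shapes only. It is not the PR's implementation: `broadcast_shape` and `should_fold` are hypothetical helpers, and the criterion used here (the folded constant must be no larger than the largest constant being folded) is an assumption chosen for the example.

```python
from itertools import zip_longest
from math import prod


def broadcast_shape(*shapes):
    """NumPy-style broadcasting of static shapes (right-aligned, 1 stretches)."""
    result = []
    for dims in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        non_one = {d for d in dims if d != 1}
        if len(non_one) > 1:
            raise ValueError(f"incompatible dimensions: {dims}")
        result.append(non_one.pop() if non_one else 1)
    return tuple(reversed(result))


def should_fold(constant_shapes):
    """Assumed criterion: fold only if the result is no larger than the largest input."""
    if len(constant_shapes) < 2:
        return False  # nothing to fold
    folded_size = prod(broadcast_shape(*constant_shapes))
    return folded_size <= max(prod(s) for s in constant_shapes)


scalars = should_fold([(), ()])             # add(x, 1, 2) -> add(x, 3)
outer = should_fold([(3, 1), (1, 3)])       # folding would materialize a (3, 3) constant
same_size = should_fold([(3,), (5, 3)])     # folded result is still (5, 3)
```

Under this assumed rule, two scalar constants are always folded, while a column constant and a row constant are left alone because their sum would broadcast to a larger array than either input.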
📚 Documentation preview 📚: https://pytensor--1422.org.readthedocs.build/en/1422/