
Rewrite specifically for Sum and Prod to remove Join #951

Closed — wants to merge 1 commit
25 changes: 25 additions & 0 deletions pytensor/tensor/rewriting/math.py
@@ -86,6 +86,7 @@
from pytensor.tensor.math import max as pt_max
from pytensor.tensor.math import pow as pt_pow
from pytensor.tensor.math import sum as pt_sum
from pytensor.tensor.math import prod as pt_prod
from pytensor.tensor.rewriting.basic import (
alloc_like,
broadcasted_by,
@@ -1754,6 +1755,30 @@ def local_reduce_broadcastable(fgraph, node):
# -- in this case we can remove the reduction completely
return [new_reduced.astype(odtype)]

@register_canonicalize
@register_uncanonicalize
@register_specialize
@node_rewriter([Sum, Prod])
ricardoV94 (Member) commented on Jul 22, 2024:

When there's nothing special about Sum and Prod, we should apply the rewrites to all CAReduce operations, which also include ops like Max, All, Any, etc.; Sum/Prod are just two more instances.

See the related PR I linked to in the comment
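The point above can be checked numerically: any associative, commutative reduction distributes over a concatenation in the same way, so the rewrite is not specific to Sum/Prod. A minimal plain-Python sketch (using lists as stand-ins for joined tensors):

```python
from math import prod

a = [3, 1, 4]
b = [1, 5, 9, 2]

# For any associative, commutative reduction r:
#   r(a ++ b) == r([r(a), r(b)])
# Sum/Prod are just two instances; Max/All/Any obey the same law.
assert sum(a + b) == sum([sum(a), sum(b)])
assert prod(a + b) == prod([prod(a), prod(b)])
assert max(a + b) == max(max(a), max(b))
assert all(a + b) == (all(a) and all(b))
assert any(a + b) == (any(a) or any(b))
```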

Contributor Author replied:

Oops, 2nd time. I did not see this one show up in #59 😄.

ricardoV94 (Member) replied on Jul 22, 2024:

No, they are not exactly the same.

There's an existing rewrite for a reduction along axis 0 of a join along axis 0.

My PR extends this to any axis.

There's then the question of multiple axes, of which axis=None is the most extreme (all axes). This PR can cover that case.

We may also want to think about multiple-but-not-all axes. In which cases can we reduce first and join later?

def local_useless_join_(fgraph, node):
    """
    sum(join(tensor1, tensor2, ...)) => sum(sum(tensor) for tensor in tensors)
    or
    prod(join(tensor1, tensor2, ...)) => prod(prod(tensor) for tensor in tensors)

    """
    (node_inps,) = node.inputs
    # Only handle full reductions (axis=None): reducing along a subset
    # of axes is not equivalent to fully reducing each joined input.
    if node.op.axis is not None:
        return None
    if node_inps.owner and isinstance(node_inps.owner.op, Join):
        # inputs[0] of a Join is the join axis; the rest are the tensors
        inpts = node_inps.owner.inputs[1:]
        # This specific implementation introduces a `MakeVector`
        # into the graph, which is then rewritten again by
        # pytensor/tensor/rewriting/basic.py:local_sum_make_vector.
        # A similar rewrite must be created for `prod`.
        if isinstance(node.op, Sum):
            return [pt_sum([pt_sum(inp) for inp in inpts])]
        elif isinstance(node.op, Prod):
            return [pt_prod([pt_prod(inp) for inp in inpts])]
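The identity this rewrite relies on can be spot-checked in plain Python, using list concatenation as a stand-in for Join and the full (axis=None) reductions the docstring describes:

```python
from math import prod

t1, t2, t3 = [2.0, 3.0], [5.0], [7.0, 11.0]
joined = t1 + t2 + t3  # stand-in for Join's output

# sum(join(...)) => sum(sum(t) for t in tensors)
assert sum(joined) == sum(sum(t) for t in (t1, t2, t3))
# prod(join(...)) => prod(prod(t) for t in tensors)
assert prod(joined) == prod(prod(t) for t in (t1, t2, t3))
```

Note that for floating-point inputs the rewrite reassociates the reduction, so results may differ in the last bits; the values above are chosen to be exact.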


@register_specialize
@node_rewriter([Sum, Prod])