# Patch content for pytensor/tensor/rewriting/math.py
# (git index 75dba82d97..0129740901), reformatted from a collapsed diff.

# Hunk 1 (near line 86): new alias, added next to the existing
# `max as pt_max`, `pow as pt_pow` and `sum as pt_sum` imports.
from pytensor.tensor.math import prod as pt_prod


# Hunk 2 (near line 1755): new rewrite, inserted directly after
# `local_reduce_broadcastable`.
@register_canonicalize
@register_uncanonicalize
@register_specialize
@node_rewriter([Sum, Prod])
def local_useless_join_(fgraph, node):
    """Push a full ``Sum``/``Prod`` reduction through a ``Join``.

    sum(join(axis, t1, t2, ...))  => sum([sum(t1), sum(t2), ...])
    prod(join(axis, t1, t2, ...)) => prod([prod(t1), prod(t2), ...])

    The rewrite is only applied when the reduction collapses its input to a
    scalar. For a partial reduction (e.g. ``sum(join(...), axis=0)``) the
    per-operand results are not scalars and cannot simply be re-reduced, so
    the node is left untouched.

    Parameters
    ----------
    fgraph
        The function graph being rewritten (unused here).
    node
        An ``Apply`` node whose op is ``Sum`` or ``Prod``.

    Returns
    -------
    list or None
        A single-element list holding the replacement variable, or ``None``
        when the pattern does not match.
    """
    (reduced,) = node.inputs
    join_node = reduced.owner
    if join_node is None or not isinstance(join_node.op, Join):
        return None

    (old_out,) = node.outputs
    # A non-scalar output means only some axes are reduced; summing each
    # joined operand down to a scalar would change both shape and value.
    if old_out.type.ndim != 0:
        return None

    if isinstance(node.op, Sum):
        reduce_fn = pt_sum
    elif isinstance(node.op, Prod):
        reduce_fn = pt_prod
    else:
        return None

    # The first input of a `Join` node is the join axis; the rest are the
    # tensors being joined.
    operands = join_node.inputs[1:]

    # Reducing a Python list of scalars introduces a `MakeVector` into the
    # graph, which would then be rewritten again by
    # pytensor/tensor/rewriting/basic.py:local_sum_make_vector
    # A similar rewrite must be created for `prod`.
    new_out = reduce_fn([reduce_fn(operand) for operand in operands])

    # Keep the replacement type-compatible with the node it replaces, e.g.
    # when the original op was built with an explicit output dtype.
    # NOTE(review): a custom `acc_dtype` on the original op is not propagated
    # to the new reductions -- confirm whether that matters for callers.
    out_dtype = old_out.type.dtype
    if new_out.type.dtype != out_dtype:
        new_out = new_out.astype(out_dtype)

    return [new_out]