
Extend sum of mul rewrite for multiple axis #484


Merged
merged 3 commits on Dec 7, 2023
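
For context, a minimal sketch of the kind of graph this PR now optimizes (assuming a pytensor build with this rewrite merged; variable names and the expected rewritten form in the comments are illustrative, not taken from the PR):

```python
import pytensor
import pytensor.tensor as pt

X = pt.tensor3("X")
a = pt.vector("a")

# `a[None, :, None]` is broadcastable along axes 0 and 2, so the extended
# rewrite can factor it out of a reduction over *multiple* axes:
#   sum(a * X, axis=(0, 2))  ->  a * sum(X, axis=(0, 2))
out = pt.sum(a[None, :, None] * X, axis=(0, 2))

f = pytensor.function([a, X], out)
pytensor.dprint(f)  # expect a Mul of `a` with Sum{axes=(0, 2)}(X)
```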
250 changes: 90 additions & 160 deletions pytensor/tensor/rewriting/math.py
@@ -1190,86 +1190,109 @@ def local_neg_to_mul(fgraph, node):
 
 @register_specialize
 @node_rewriter([Sum, Prod])
-def local_sum_prod_mul_by_scalar(fgraph, node):
+def local_sum_prod_of_mul_or_div(fgraph, node):
     """
-    sum(scalar * smth) -> scalar * sum(smth)
-    sum(-smth) -> -sum(smth)
+    sum(a * X) -> a * sum(X), when a is broadcasted along the sum dimensions
 
     or
 
-    prod(scalar * smth) -> scalar ** size(smth) * prod(smth)
-    prod(-smth) -> -1 ** size(smth) * prod(smth)
+    prod(a * X) -> (a ** size(X)) * prod(X)
+
+    It also applies to reduction of X / a,
+    but not a / X, as that would still require inverting every value in X before the reduction
+
+    TODO: In the case where not all axis overlap with broadcast dimensions,
+    consider introducing an outer reduction after factoring out the compatible reduced dimensions
+    E.g. sum(arange(5) * X, axis=(0, 2)) -> sum(sum(X, axis=0) * arange(5), axis=1)
     """
-    # TODO: if the the thing inside the Sum is a division,
-    # we should get at the numerator....
-    if isinstance(node.op, (Sum, Prod)):
-        (node_inps,) = node.inputs
-        if node_inps.owner and node_inps.owner.op == mul:
-            terms = node_inps.owner.inputs
-            scalars = [t.dimshuffle() for t in terms if all(t.type.broadcastable)]
-
-            if len(scalars) == 0:
-                return
-
-            non_scalars = [t for t in terms if not all(t.broadcastable)]
-
-            # Perform the op only on the non-scalar inputs, if applicable
-            if len(non_scalars) == 0:
-                new_op_input_nb_elements = 1
-                new_op_output = 1
-            elif len(non_scalars) == 1:
-                new_op_input_nb_elements = non_scalars[0].size
-                new_op_output = node.op(non_scalars[0])
-            else:
-                new_op_input = mul(*non_scalars)
-                # We assume that errors always come from the prod/mul op in the
-                # original computational graph, and therefore need to only
-                # copy over its output stacktrace.
-                copy_stack_trace(node.outputs, new_op_input)
-
-                new_op_input_nb_elements = new_op_input.size
-                new_op_output = node.op(new_op_input)
-
-            if len(non_scalars) != 0:
-                # Copy over stacktrace from previous output to new mul op,
-                # for same reason as above.
-                copy_stack_trace(node.outputs, new_op_output)
-
-            # If `node.op` is a `Prod`, then the scalars need to be raised to
-            # the power of the number of elements in the input to the `Prod`
-            if isinstance(node.op, Prod) and new_op_input_nb_elements != 1:
-                scalars = [s**new_op_input_nb_elements for s in scalars]
-
-            # Scale the output of the op by the scalars and return as
-            # replacement for the original output
-            mul_inputs = scalars
-            if new_op_input_nb_elements != 1:
-                mul_inputs.append(new_op_output)
-
-            if len(mul_inputs) == 1:
-                # Copy over stacktrace from previous output to new mul op,
-                # for same reason as above.
-                copy_stack_trace(node.outputs, mul_inputs)
-
-                return mul_inputs
-            else:
-                ret = mul(*mul_inputs)
-                # Copy over stacktrace from previous output to new mul op,
-                # for same reason as above.
-                copy_stack_trace(node.outputs, [ret] + mul_inputs)
-
-                return [ret]
-
-    if isinstance(node.op, Sum) and node_inps.owner and node_inps.owner.op == neg:
-        s = node.op(node_inps.owner.inputs[0])
-        ret = neg(s)
-        # There are never errors in the negative op, thus
-        # we need only to copy over stacktrace from previous output node to
-        # the two new ops.
-        copy_stack_trace(node.outputs, [s, ret])
-
-        return [ret]
+    [node_inps] = node.inputs
+    if not node_inps.owner:
+        return None
+
+    inner_op = node_inps.owner.op
+    if not (inner_op == mul or inner_op == true_div):
+        return None
+
+    reduced_axes = node.op.axis
+    if reduced_axes is None:
+        reduced_axes = tuple(range(node_inps.type.ndim))
+
+    # Separate terms that can be moved out of the Sum/Prod and those that cannot
+    if inner_op == mul:
+        # Mul accepts arbitrary inputs, so we need to separate into two groups
+        outer_terms = []
+        inner_terms = []
+        for term in node_inps.owner.inputs:
+            term_bcast = term.type.broadcastable
+            if all(term_bcast[i] for i in reduced_axes):
+                outer_terms.append(term.squeeze(reduced_axes))
+            else:
+                inner_terms.append(term)
+
+        if not outer_terms:
+            return None
+        elif len(outer_terms) == 1:
+            [outer_term] = outer_terms
+        else:
+            outer_term = mul(*outer_terms)
+
+        if not inner_terms:
+            inner_term = None
+        elif len(inner_terms) == 1:
+            [inner_term] = inner_terms
+        else:
+            inner_term = mul(*inner_terms)
+
+    else:  # true_div
+        # We only care about removing the denominator out of the reduction
+        numerator, denominator = node_inps.owner.inputs
+        denominator_bcast = denominator.type.broadcastable
+        if all(denominator_bcast[i] for i in reduced_axes):
+            outer_term = denominator.squeeze(reduced_axes)
+            inner_term = numerator
+        else:
+            return None
+
+    # If we have a `Prod`, then the outside terms need to be raised to the power of the number of elements
+    # that were contracted in the input
+    if isinstance(node.op, Prod) and inner_term:
+        dtype = inner_term.dtype
+        n_reduced_elements = prod(
+            [inner_term.shape[i].astype(dtype) for i in reduced_axes]
+        )
+        outer_term = outer_term**n_reduced_elements
+
+    if not inner_term:
+        # Sum/Prod is useless, just return the outer_term
+        # (This can only happen for mul, not division)
+        new_out = outer_term
+    else:
+        reduced_inner_term = node.op(inner_term)
+        if inner_op == mul:
+            new_out = outer_term * reduced_inner_term
+        else:
+            new_out = reduced_inner_term / outer_term
+        copy_stack_trace(node.outputs, [inner_term, reduced_inner_term, outer_term])
+
+    copy_stack_trace(node.outputs, new_out)
+    return [new_out]
+
+
+@register_specialize
+@node_rewriter([Sum])
+def local_sum_of_neg_to_neg_of_sum(fgraph, node):
+    """Rewrite sum(-X) -> -sum(X)."""
+    [node_inps] = node.inputs
+    if node_inps.owner and node_inps.owner.op == neg:
+        s = node.op(node_inps.owner.inputs[0])
+        ret = neg(s)
+        # There are never errors in the negative op, thus
+        # we need only to copy over stacktrace from previous output node to
+        # the two new ops.
+        copy_stack_trace(node.outputs, [s, ret])
+
+        return [ret]
 
 
 @register_specialize
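
To make the branches of the new rewrite concrete, here is a hedged sketch of graphs each branch should match (variable names and the rewritten forms in the comments are illustrative, under the assumption that the rewrite fires during compilation):

```python
import pytensor.tensor as pt

X = pt.matrix("X")
a = pt.vector("a")

# mul branch: a[:, None] is broadcastable along the reduced axis 1, so for a
# Prod it is factored out and raised to the number of reduced elements:
#   prod(a * X, axis=1)  ->  (a ** X.shape[1]) * prod(X, axis=1)
out_prod = pt.prod(a[:, None] * X, axis=1)

# true_div branch: only a broadcastable *denominator* is factored out. For
# a / X, factoring out `a` would leave 1 / X, which still requires inverting
# every element of X before the reduction, so that case is not rewritten:
#   sum(X / a[None, :], axis=0)  ->  sum(X, axis=0) / a
out_div = pt.sum(X / a[None, :], axis=0)

# The companion rewrite added above also moves a negation out of a Sum:
#   sum(-X)  ->  -sum(X)
out_neg = pt.sum(-X)
```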
@@ -1508,99 +1531,6 @@ def investigate(node):
         return
 
 
-@register_canonicalize
-@register_specialize
-@node_rewriter([Sum, Prod])
-def local_sum_prod_div_dimshuffle(fgraph, node):
-    """
-    sum(a / dimshuffle{...}(b), axis=l) -> sum(a, axis={...}) / b,
-    if dimension l of the DimShuffle is 'x'
-
-    or
-
-    prod(a / dimshuffle{...}(b), axis=l) ->
-    prod(a, axis={...}) / b ** a.shape[l],
-    if dimension l of the DimShuffle is 'x'
-    """
-
-    # It does not make much sense now to extend it to the case where the
-    # dimshuffle is in the numerator, since elemwise inversion of the
-    # denominator would still be needed before the summation or production.
-
-    if isinstance(node.op, (Sum, Prod)):
-        axis = node.op.axis
-        if axis is None:
-            axis = list(range(node.inputs[0].ndim))
-        node_input = node.inputs[0]
-        if node_input.owner and node_input.owner.op == true_div:
-            numerator, denominator = node_input.owner.inputs
-
-            if denominator.owner and isinstance(denominator.owner.op, DimShuffle):
-                dimshuffle_input = denominator.owner.inputs[0]
-                dimshuffle_order = denominator.owner.op.new_order
-
-                compatible_dims = []
-                incompatible_dims = []
-                for ax in axis:
-                    if ax < len(dimshuffle_order) and dimshuffle_order[ax] == "x":
-                        compatible_dims.append(ax)
-                    else:
-                        incompatible_dims.append(ax)
-                reordered_incompatible_dims = []
-                for ic_ax in incompatible_dims:
-                    reordered_incompatible_dims.append(
-                        ic_ax - sum(1 for c_ax in compatible_dims if c_ax < ic_ax)
-                    )
-
-                if len(compatible_dims) > 0:
-                    optimized_dimshuffle_order = [
-                        ax
-                        for i, ax in enumerate(dimshuffle_order)
-                        if (i not in axis) or (ax != "x")
-                    ]
-
-                    # Removing leading 'x' (since it will be done automatically)
-                    while (
-                        len(optimized_dimshuffle_order) > 0
-                        and optimized_dimshuffle_order[0] == "x"
-                    ):
-                        del optimized_dimshuffle_order[0]
-
-                    # if optimized_dimshuffle_order is sorted with
-                    # not 'x', then dimshuffle is useless.
-                    if all(i == e for i, e in enumerate(optimized_dimshuffle_order)):
-                        optimized_dimshuffle = dimshuffle_input
-                    else:
-                        optimized_dimshuffle = DimShuffle(
-                            dimshuffle_input.type.broadcastable,
-                            optimized_dimshuffle_order,
-                        )(dimshuffle_input)
-
-                    if isinstance(node.op, Sum):
-                        op_on_compatible_dims = at_sum(numerator, axis=compatible_dims)
-                        rval = true_div(op_on_compatible_dims, optimized_dimshuffle)
-                        if len(reordered_incompatible_dims) > 0:
-                            rval = at_sum(rval, axis=reordered_incompatible_dims)
-                    elif isinstance(node.op, Prod):
-                        op_on_compatible_dims = prod(numerator, axis=compatible_dims)
-                        dtype = numerator.dtype
-                        rval = true_div(
-                            op_on_compatible_dims,
-                            (
-                                optimized_dimshuffle
-                                ** prod(
-                                    [
-                                        numerator.shape[ax].astype(dtype)
-                                        for ax in compatible_dims
-                                    ]
-                                )
-                            ),
-                        )
-                        if len(reordered_incompatible_dims) > 0:
-                            rval = prod(rval, axis=reordered_incompatible_dims)
-                    return [rval]
-
-
 @register_canonicalize
 @node_rewriter([Sum, Prod])
 def local_sum_prod_all_to_none(fgraph, node):
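
The deleted `local_sum_prod_div_dimshuffle` handled a special case of what `local_sum_prod_of_mul_or_div` now covers: a denominator whose DimShuffle inserts "x" on a reduced axis is exactly a denominator that is broadcastable along that axis. A hedged sketch of a graph both the old and the new rewrite should handle (names illustrative):

```python
import pytensor.tensor as pt

X = pt.matrix("X")
b = pt.vector("b")

# b.dimshuffle(0, "x") has an "x" (broadcastable) entry on axis 1, the
# reduced axis, so the division can be moved out of the reduction:
#   sum(X / b.dimshuffle(0, "x"), axis=1)  ->  sum(X, axis=1) / b
out = pt.sum(X / b.dimshuffle(0, "x"), axis=1)
```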
2 changes: 1 addition & 1 deletion tests/tensor/rewriting/test_elemwise.py
@@ -899,7 +899,7 @@ def large_fuseable_graph(self, n):
             ),
             (fx, fy),
             (fxv, fyv),
-            3,
+            2,
             (
                 np.sum(-((fxv - fyv) ** 2) / 2),
                 -(fxv - fyv),