Commit edb5b11

Simplify logic with variadic_add and variadic_mul helpers
1 parent 1c225e6 commit edb5b11

6 files changed: +49 -59 lines changed
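Most of the files below apply the same simplification: the repeated length-checking around add(*...) / mul(*...) collapses into a single call to the new helpers. A minimal sketch of the pattern, using hypothetical names (terms, combine_before, combine_after) rather than code from any one file, and assuming a PyTensor build that includes this commit:

    from pytensor.tensor.math import add, variadic_add

    def combine_before(terms):
        # old style: branch on the number of inputs before calling add()
        return add(*terms) if len(terms) > 1 else terms[0]

    def combine_after(terms):
        # new style: variadic_add handles the zero- and one-input cases itself
        return variadic_add(*terms)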

pytensor/tensor/blas.py

Lines changed: 2 additions & 6 deletions
@@ -102,7 +102,7 @@
 from pytensor.tensor.basic import expand_dims
 from pytensor.tensor.blas_headers import blas_header_text, blas_header_version
 from pytensor.tensor.elemwise import DimShuffle
-from pytensor.tensor.math import add, mul, neg, sub
+from pytensor.tensor.math import add, mul, neg, sub, variadic_add
 from pytensor.tensor.shape import shape_padright, specify_broadcastable
 from pytensor.tensor.type import DenseTensorType, TensorType, integer_dtypes, tensor

@@ -1399,11 +1399,7 @@ def item_to_var(t):
                     item_to_var(input) for k, input in enumerate(lst) if k not in (i, j)
                 ]
                 add_inputs.extend(gemm_of_sM_list)
-                if len(add_inputs) > 1:
-                    rval = [add(*add_inputs)]
-                else:
-                    rval = add_inputs
-                # print "RETURNING GEMM THING", rval
+                rval = [variadic_add(*add_inputs)]
                 return rval, old_dot22

pytensor/tensor/math.py

Lines changed: 20 additions & 12 deletions
@@ -1425,18 +1425,8 @@ def mean(input, axis=None, dtype=None, op=False, keepdims=False, acc_dtype=None)
     else:
         shp = cast(shp, "float64")

-    if axis is None:
-        axis = list(range(input.ndim))
-    elif isinstance(axis, int | np.integer):
-        axis = [axis]
-    elif isinstance(axis, np.ndarray) and axis.ndim == 0:
-        axis = [int(axis)]
-    else:
-        axis = [int(a) for a in axis]
-
-    # This sequential division will possibly be optimized by PyTensor:
-    for i in axis:
-        s = true_div(s, shp[i])
+    canonical_axis = s.axis
+    s /= variadic_mul(*[shp[i] for i in canonical_axis])

     # This can happen when axis is an empty list/tuple
     if s.dtype != shp.dtype and s.dtype in discrete_dtypes:

@@ -1592,6 +1582,15 @@ def add(a, *other_terms):
     # see decorator for function body


+def variadic_add(*args):
+    """Add that accepts arbitrary number of inputs, including zero or one."""
+    if not args:
+        return 0
+    if len(args) == 1:
+        return args[0]
+    return add(*args)
+
+
 @scalar_elemwise
 def sub(a, b):
     """elementwise subtraction"""

@@ -1604,6 +1603,15 @@ def mul(a, *other_terms):
     # see decorator for function body


+def variadic_mul(*args):
+    """Mul that accepts arbitrary number of inputs, including zero or one."""
+    if not args:
+        return 1
+    if len(args) == 1:
+        return args[0]
+    return mul(*args)
+
+
 @scalar_elemwise
 def true_div(a, b):
     """elementwise [true] division (inverse of multiplication)"""

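As a rough usage sketch of the two helpers defined above (again assuming this commit is installed): with no arguments they return the Python identity element (0 or 1), with one argument they return it unchanged without building an Add/Mul node, and otherwise they defer to add / mul:

    import pytensor.tensor as pt
    from pytensor.tensor.math import variadic_add, variadic_mul

    x = pt.vector("x")
    y = pt.vector("y")

    variadic_add()         # -> 0 (plain Python int, no graph node)
    variadic_add(x)        # -> x itself
    variadic_add(x, y, x)  # -> same graph as add(x, y, x)

    variadic_mul()         # -> 1
    variadic_mul(y)        # -> y
    variadic_mul(x, y)     # -> same graph as mul(x, y)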
pytensor/tensor/rewriting/basic.py

Lines changed: 4 additions & 9 deletions
@@ -67,7 +67,7 @@
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.extra_ops import broadcast_arrays
-from pytensor.tensor.math import Sum, add, eq
+from pytensor.tensor.math import Sum, eq, variadic_add
 from pytensor.tensor.shape import Shape_i, shape_padleft
 from pytensor.tensor.type import DenseTensorType, TensorType
 from pytensor.tensor.variable import TensorConstant, TensorVariable

@@ -938,14 +938,9 @@ def local_sum_make_vector(fgraph, node):
     if acc_dtype == "float64" and out_dtype != "float64" and config.floatX != "float64":
         return

-    if len(elements) == 0:
-        element_sum = zeros(dtype=out_dtype, shape=())
-    elif len(elements) == 1:
-        element_sum = cast(elements[0], out_dtype)
-    else:
-        element_sum = cast(
-            add(*[cast(value, acc_dtype) for value in elements]), out_dtype
-        )
+    element_sum = cast(
+        variadic_add(*[cast(value, acc_dtype) for value in elements]), out_dtype
+    )

     return [element_sum]

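One consequence in local_sum_make_vector above: the explicit zeros(...) branch for an empty MakeVector is no longer needed, because variadic_add() with no arguments returns the plain Python 0 and cast() wraps plain numbers into constant tensors. A small sketch of that edge case (not part of the commit itself):

    from pytensor.tensor.basic import cast
    from pytensor.tensor.math import variadic_add

    # With no elements, variadic_add() is just 0; cast turns it into a
    # 0-d constant of the requested dtype.
    element_sum = cast(variadic_add(), "float32")
    assert element_sum.dtype == "float32"
    assert element_sum.eval() == 0.0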
pytensor/tensor/rewriting/blas.py

Lines changed: 10 additions & 5 deletions
@@ -96,7 +96,15 @@
 )
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
-from pytensor.tensor.math import Dot, _matrix_matrix_matmul, add, mul, neg, sub
+from pytensor.tensor.math import (
+    Dot,
+    _matrix_matrix_matmul,
+    add,
+    mul,
+    neg,
+    sub,
+    variadic_add,
+)
 from pytensor.tensor.rewriting.elemwise import local_dimshuffle_lift
 from pytensor.tensor.type import (
     DenseTensorType,

@@ -386,10 +394,7 @@ def item_to_var(t):
                     item_to_var(input) for k, input in enumerate(lst) if k not in (i, j)
                 ]
                 add_inputs.extend(gemm_of_sM_list)
-                if len(add_inputs) > 1:
-                    rval = [add(*add_inputs)]
-                else:
-                    rval = add_inputs
+                rval = [variadic_add(*add_inputs)]
                 # print "RETURNING GEMM THING", rval
                 return rval, old_dot22

pytensor/tensor/rewriting/math.py

Lines changed: 7 additions & 18 deletions
@@ -81,6 +81,8 @@
     sub,
     tri_gamma,
     true_div,
+    variadic_add,
+    variadic_mul,
 )
 from pytensor.tensor.math import abs as pt_abs
 from pytensor.tensor.math import max as pt_max

@@ -1270,17 +1272,13 @@ def local_sum_prod_of_mul_or_div(fgraph, node):

             if not outer_terms:
                 return None
-            elif len(outer_terms) == 1:
-                [outer_term] = outer_terms
             else:
-                outer_term = mul(*outer_terms)
+                outer_term = variadic_mul(*outer_terms)

             if not inner_terms:
                 inner_term = None
-            elif len(inner_terms) == 1:
-                [inner_term] = inner_terms
             else:
-                inner_term = mul(*inner_terms)
+                inner_term = variadic_mul(*inner_terms)

         else:  # true_div
             # We only care about removing the denominator out of the reduction

@@ -2163,10 +2161,7 @@ def local_add_remove_zeros(fgraph, node):
         assert cst.type.broadcastable == (True,) * ndim
         return [alloc_like(cst, node_output, fgraph)]

-    if len(new_inputs) == 1:
-        ret = [alloc_like(new_inputs[0], node_output, fgraph)]
-    else:
-        ret = [alloc_like(add(*new_inputs), node_output, fgraph)]
+    ret = [alloc_like(variadic_add(*new_inputs), node_output, fgraph)]

     # The dtype should not be changed. It can happen if the input
     # that was forcing upcasting was equal to 0.

@@ -2277,10 +2272,7 @@ def local_log1p(fgraph, node):
     # scalar_inputs are potentially dimshuffled and fill'd scalars
     if scalars and np.allclose(np.sum(scalars), 1):
         if nonconsts:
-            if len(nonconsts) > 1:
-                ninp = add(*nonconsts)
-            else:
-                ninp = nonconsts[0]
+            ninp = variadic_add(*nonconsts)
             if ninp.dtype != log_arg.type.dtype:
                 ninp = ninp.astype(node.outputs[0].dtype)
             return [alloc_like(log1p(ninp), node.outputs[0], fgraph)]

@@ -3104,10 +3096,7 @@ def local_exp_over_1_plus_exp(fgraph, node):
             return
         # put the new numerator together
         new_num = sigmoids + [exp(t) for t in num_exp_x] + num_rest
-        if len(new_num) == 1:
-            new_num = new_num[0]
-        else:
-            new_num = mul(*new_num)
+        new_num = variadic_mul(*new_num)

         if num_neg ^ denom_neg:
             new_num = -new_num

pytensor/tensor/rewriting/subtensor.py

Lines changed: 6 additions & 9 deletions
@@ -48,6 +48,7 @@
     maximum,
     minimum,
     or_,
+    variadic_add,
 )
 from pytensor.tensor.math import all as pt_all
 from pytensor.tensor.rewriting.basic import (

@@ -1218,15 +1219,11 @@ def movable(i):
         new_inputs = [i for i in node.inputs if not movable(i)] + [
             mi.owner.inputs[0] for mi in movable_inputs
         ]
-        if len(new_inputs) == 0:
-            new_add = new_inputs[0]
-        else:
-            new_add = add(*new_inputs)
-
-            # Copy over stacktrace from original output, as an error
-            # (e.g. an index error) in this add operation should
-            # correspond to an error in the original add operation.
-            copy_stack_trace(node.outputs[0], new_add)
+        new_add = variadic_add(*new_inputs)
+        # Copy over stacktrace from original output, as an error
+        # (e.g. an index error) in this add operation should
+        # correspond to an error in the original add operation.
+        copy_stack_trace(node.outputs[0], new_add)

         # stack up the new incsubtensors
         tip = new_add
