Commit 5632777

Simplify logic with variadic_add and variadic_mul helpers
1 parent cdae903 commit 5632777

6 files changed (+53, -59 lines)

pytensor/tensor/blas.py

Lines changed: 2 additions & 6 deletions
@@ -102,7 +102,7 @@
 from pytensor.tensor.basic import expand_dims
 from pytensor.tensor.blas_headers import blas_header_text, blas_header_version
 from pytensor.tensor.elemwise import DimShuffle
-from pytensor.tensor.math import add, mul, neg, sub
+from pytensor.tensor.math import add, mul, neg, sub, variadic_add
 from pytensor.tensor.shape import shape_padright, specify_broadcastable
 from pytensor.tensor.type import DenseTensorType, TensorType, integer_dtypes, tensor

@@ -1399,11 +1399,7 @@ def item_to_var(t):
                     item_to_var(input) for k, input in enumerate(lst) if k not in (i, j)
                 ]
                 add_inputs.extend(gemm_of_sM_list)
-                if len(add_inputs) > 1:
-                    rval = [add(*add_inputs)]
-                else:
-                    rval = add_inputs
-                # print "RETURNING GEMM THING", rval
+                rval = [variadic_add(*add_inputs)]
                 return rval, old_dot22

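The removed two-way branch and the new single call build the same one-element list whenever add_inputs is non-empty, because variadic_add (defined in pytensor/tensor/math.py below) passes a single argument through untouched. A minimal sketch of that equivalence, with plain Python numbers standing in for tensor variables (illustration only, not PyTensor code):

def variadic_add_sketch(*terms):
    # mirrors the helper added in pytensor/tensor/math.py below
    if not terms:
        return 0
    if len(terms) == 1:
        return terms[0]
    return sum(terms)

for add_inputs in ([5], [1, 2, 3]):
    old_rval = [sum(add_inputs)] if len(add_inputs) > 1 else add_inputs
    new_rval = [variadic_add_sketch(*add_inputs)]
    assert old_rval == new_rval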
pytensor/tensor/math.py

Lines changed: 24 additions & 12 deletions
@@ -1429,18 +1429,12 @@ def mean(input, axis=None, dtype=None, op=False, keepdims=False, acc_dtype=None):
     else:
         shp = cast(shp, "float64")

-    if axis is None:
-        axis = list(range(input.ndim))
-    elif isinstance(axis, int | np.integer):
-        axis = [axis]
-    elif isinstance(axis, np.ndarray) and axis.ndim == 0:
-        axis = [int(axis)]
-    else:
-        axis = [int(a) for a in axis]
-
-    # This sequential division will possibly be optimized by PyTensor:
-    for i in axis:
-        s = true_div(s, shp[i])
+    reduced_dims = (
+        shp
+        if axis is None
+        else [shp[i] for i in normalize_axis_tuple(axis, input.type.ndim)]
+    )
+    s /= variadic_mul(*reduced_dims).astype(shp.dtype)

     # This can happen when axis is an empty list/tuple
     if s.dtype != shp.dtype and s.dtype in discrete_dtypes:
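Roughly, the rewritten mean() now divides the summed value once by the product of the reduced dimensions instead of dividing once per axis. A plain-Python sketch of that divisor (normalize_axis_tuple is NumPy's axis-normalization helper; the hand-rolled normalization below only stands in for it, illustration only):

import math

def mean_divisor(shape, axis=None):
    # product of the reduced dims, or the total element count when axis is None
    if axis is None:
        dims = list(shape)
    else:
        axes = (axis,) if isinstance(axis, int) else tuple(axis)
        dims = [shape[a % len(shape)] for a in axes]  # crude normalize_axis_tuple stand-in
    return math.prod(dims)  # empty dims gives 1, matching variadic_mul() == constant(1)

assert mean_divisor((3, 4, 5), axis=1) == 4
assert mean_divisor((3, 4, 5), axis=(0, 2)) == 15
assert mean_divisor((3, 4, 5)) == 60
assert mean_divisor((3, 4, 5), axis=()) == 1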
@@ -1596,6 +1590,15 @@ def add(a, *other_terms):
     # see decorator for function body


+def variadic_add(*args):
+    """Add that accepts arbitrary number of inputs, including zero or one."""
+    if not args:
+        return constant(0)
+    if len(args) == 1:
+        return args[0]
+    return add(*args)
+
+
 @scalar_elemwise
 def sub(a, b):
     """elementwise subtraction"""
@@ -1608,6 +1611,15 @@ def mul(a, *other_terms):
     # see decorator for function body


+def variadic_mul(*args):
+    """Mul that accepts arbitrary number of inputs, including zero or one."""
+    if not args:
+        return constant(1)
+    if len(args) == 1:
+        return args[0]
+    return mul(*args)
+
+
 @scalar_elemwise
 def true_div(a, b):
     """elementwise [true] division (inverse of multiplication)"""

pytensor/tensor/rewriting/basic.py

Lines changed: 4 additions & 9 deletions
@@ -68,7 +68,7 @@
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.extra_ops import broadcast_arrays
-from pytensor.tensor.math import Sum, add, eq
+from pytensor.tensor.math import Sum, add, eq, variadic_add
 from pytensor.tensor.shape import Shape_i, shape_padleft
 from pytensor.tensor.type import DenseTensorType, TensorType
 from pytensor.tensor.variable import TensorConstant, TensorVariable
@@ -939,14 +939,9 @@ def local_sum_make_vector(fgraph, node):
     if acc_dtype == "float64" and out_dtype != "float64" and config.floatX != "float64":
         return

-    if len(elements) == 0:
-        element_sum = zeros(dtype=out_dtype, shape=())
-    elif len(elements) == 1:
-        element_sum = cast(elements[0], out_dtype)
-    else:
-        element_sum = cast(
-            add(*[cast(value, acc_dtype) for value in elements]), out_dtype
-        )
+    element_sum = cast(
+        variadic_add(*[cast(value, acc_dtype) for value in elements]), out_dtype
+    )

     return [element_sum]

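One way to exercise this rewrite is to sum a MakeVector graph; the numerical behavior is unchanged, only the branching inside the rewrite is gone. A quick check (assuming PyTensor is installed; the assertions hold whether or not the rewrite fires):

import pytensor
import pytensor.tensor as pt

x = pt.scalar("x")
y = pt.scalar("y")

f = pytensor.function([x, y], pt.stack([x, y, 1.0]).sum())
assert f(2.0, 3.0) == 6.0

g = pytensor.function([x], pt.stack([x]).sum())  # single-element case
assert g(5.0) == 5.0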
pytensor/tensor/rewriting/blas.py

Lines changed: 10 additions & 5 deletions
@@ -96,7 +96,15 @@
 )
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
-from pytensor.tensor.math import Dot, _matrix_matrix_matmul, add, mul, neg, sub
+from pytensor.tensor.math import (
+    Dot,
+    _matrix_matrix_matmul,
+    add,
+    mul,
+    neg,
+    sub,
+    variadic_add,
+)
 from pytensor.tensor.rewriting.elemwise import local_dimshuffle_lift
 from pytensor.tensor.type import (
     DenseTensorType,
@@ -386,10 +394,7 @@ def item_to_var(t):
                     item_to_var(input) for k, input in enumerate(lst) if k not in (i, j)
                 ]
                 add_inputs.extend(gemm_of_sM_list)
-                if len(add_inputs) > 1:
-                    rval = [add(*add_inputs)]
-                else:
-                    rval = add_inputs
+                rval = [variadic_add(*add_inputs)]
                 # print "RETURNING GEMM THING", rval
                 return rval, old_dot22

pytensor/tensor/rewriting/math.py

Lines changed: 7 additions & 18 deletions
@@ -76,6 +76,8 @@
     sub,
     tri_gamma,
     true_div,
+    variadic_add,
+    variadic_mul,
 )
 from pytensor.tensor.math import abs as pt_abs
 from pytensor.tensor.math import max as pt_max
@@ -1270,17 +1272,13 @@ def local_sum_prod_of_mul_or_div(fgraph, node):

         if not outer_terms:
             return None
-        elif len(outer_terms) == 1:
-            [outer_term] = outer_terms
         else:
-            outer_term = mul(*outer_terms)
+            outer_term = variadic_mul(*outer_terms)

         if not inner_terms:
             inner_term = None
-        elif len(inner_terms) == 1:
-            [inner_term] = inner_terms
         else:
-            inner_term = mul(*inner_terms)
+            inner_term = variadic_mul(*inner_terms)

     else:  # true_div
         # We only care about removing the denominator out of the reduction
@@ -2143,10 +2141,7 @@ def local_add_remove_zeros(fgraph, node):
             assert cst.type.broadcastable == (True,) * ndim
             return [alloc_like(cst, node_output, fgraph)]

-        if len(new_inputs) == 1:
-            ret = [alloc_like(new_inputs[0], node_output, fgraph)]
-        else:
-            ret = [alloc_like(add(*new_inputs), node_output, fgraph)]
+        ret = [alloc_like(variadic_add(*new_inputs), node_output, fgraph)]

         # The dtype should not be changed. It can happen if the input
         # that was forcing upcasting was equal to 0.
@@ -2257,10 +2252,7 @@ def local_log1p(fgraph, node):
         # scalar_inputs are potentially dimshuffled and fill'd scalars
         if scalars and np.allclose(np.sum(scalars), 1):
             if nonconsts:
-                if len(nonconsts) > 1:
-                    ninp = add(*nonconsts)
-                else:
-                    ninp = nonconsts[0]
+                ninp = variadic_add(*nonconsts)
                 if ninp.dtype != log_arg.type.dtype:
                     ninp = ninp.astype(node.outputs[0].dtype)
                 return [alloc_like(log1p(ninp), node.outputs[0], fgraph)]
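For context, local_log1p is the rewrite that turns log(1 + x) into log1p(x); when several non-constant terms appear, variadic_add now combines them in one step. A small illustration of the pattern this rewrite targets (assuming PyTensor is installed; dprint is used only to inspect which ops survive rewriting):

import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
y = pt.vector("y")

f = pytensor.function([x, y], pt.log(1 + x + y))
pytensor.dprint(f)  # after rewriting, expected to show log1p applied to (x + y)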
@@ -3084,10 +3076,7 @@ def local_exp_over_1_plus_exp(fgraph, node):
             return
         # put the new numerator together
         new_num = sigmoids + [exp(t) for t in num_exp_x] + num_rest
-        if len(new_num) == 1:
-            new_num = new_num[0]
-        else:
-            new_num = mul(*new_num)
+        new_num = variadic_mul(*new_num)

         if num_neg ^ denom_neg:
             new_num = -new_num
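Similarly, local_exp_over_1_plus_exp rewrites exp(x) / (1 + exp(x)) into sigmoid(x); variadic_mul now assembles the remaining numerator factors without a one-factor special case. A small illustration of the pattern it targets (same assumptions as the previous sketch):

import pytensor
import pytensor.tensor as pt

x = pt.vector("x")

f = pytensor.function([x], pt.exp(x) / (1 + pt.exp(x)))
pytensor.dprint(f)  # after rewriting, expected to show a sigmoid op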

pytensor/tensor/rewriting/subtensor.py

Lines changed: 6 additions & 9 deletions
@@ -48,6 +48,7 @@
     maximum,
     minimum,
     or_,
+    variadic_add,
 )
 from pytensor.tensor.math import all as pt_all
 from pytensor.tensor.rewriting.basic import (
@@ -1241,15 +1242,11 @@ def movable(i):
             new_inputs = [i for i in node.inputs if not movable(i)] + [
                 mi.owner.inputs[0] for mi in movable_inputs
             ]
-            if len(new_inputs) == 0:
-                new_add = new_inputs[0]
-            else:
-                new_add = add(*new_inputs)
-
-            # Copy over stacktrace from original output, as an error
-            # (e.g. an index error) in this add operation should
-            # correspond to an error in the original add operation.
-            copy_stack_trace(node.outputs[0], new_add)
+            new_add = variadic_add(*new_inputs)
+            # Copy over stacktrace from original output, as an error
+            # (e.g. an index error) in this add operation should
+            # correspond to an error in the original add operation.
+            copy_stack_trace(node.outputs[0], new_add)

             # stack up the new incsubtensors
             tip = new_add
