Commit 61985e6 (parent a149f6c)

Remove Scalar prefix from Maximum, Minimum, and Softplus

11 files changed (+43, -56 lines)
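
In short: the scalar op classes ScalarMaximum and ScalarMinimum become Maximum and Minimum, the pre-built instances scalar_maximum and scalar_minimum become maximum and minimum, and the Softplus instance's display name drops its scalar_ prefix. A minimal sketch of the renamed API at the scalar level (assuming the usual `import pytensor.scalar as ps` seen throughout the diffs below):

import pytensor.scalar as ps

x = ps.float64("x")
y = ps.float64("y")

# The pre-built instances drop the scalar_ prefix...
z = ps.maximum(x, y)  # was: ps.scalar_maximum(x, y)

# ...and the op classes drop the Scalar prefix:
assert isinstance(z.owner.op, ps.Maximum)  # was: ps.ScalarMaximum

This brings the scalar names in line with their tensor counterparts (pt.maximum, pt.minimum).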

pytensor/compile/profiling.py

Lines changed: 2 additions & 2 deletions
@@ -1480,8 +1480,8 @@ def print_tips(self, file):
             ps.XOR,
             ps.AND,
             ps.Invert,
-            ps.ScalarMaximum,
-            ps.ScalarMinimum,
+            ps.Maximum,
+            ps.Minimum,
             ps.Add,
             ps.Mul,
             ps.Sub,

pytensor/link/numba/dispatch/elemwise.py

Lines changed: 9 additions & 9 deletions
@@ -26,13 +26,13 @@
     XOR,
     Add,
     IntDiv,
+    Maximum,
+    Minimum,
     Mul,
-    ScalarMaximum,
-    ScalarMinimum,
     Sub,
     TrueDiv,
     get_scalar_type,
-    scalar_maximum,
+    maximum,
 )
 from pytensor.scalar.basic import add as add_as
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
@@ -103,16 +103,16 @@ def scalar_in_place_fn_IntDiv(op, idx, res, arr):
     return f"{res}[{idx}] //= {arr}"
 
 
-@scalar_in_place_fn.register(ScalarMaximum)
-def scalar_in_place_fn_ScalarMaximum(op, idx, res, arr):
+@scalar_in_place_fn.register(Maximum)
+def scalar_in_place_fn_Maximum(op, idx, res, arr):
     return f"""
 if {res}[{idx}] < {arr}:
     {res}[{idx}] = {arr}
 """
 
 
-@scalar_in_place_fn.register(ScalarMinimum)
-def scalar_in_place_fn_ScalarMinimum(op, idx, res, arr):
+@scalar_in_place_fn.register(Minimum)
+def scalar_in_place_fn_Minimum(op, idx, res, arr):
     return f"""
 if {res}[{idx}] > {arr}:
     {res}[{idx}] = {arr}
@@ -458,7 +458,7 @@ def numba_funcify_Softmax(op, node, **kwargs):
     if axis is not None:
         axis = normalize_axis_index(axis, x_at.ndim)
         reduce_max_py = create_multiaxis_reducer(
-            scalar_maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True
+            maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True
         )
         reduce_sum_py = create_multiaxis_reducer(
             add_as, 0.0, (axis,), x_at.ndim, x_dtype, keepdims=True
@@ -522,7 +522,7 @@ def numba_funcify_LogSoftmax(op, node, **kwargs):
     if axis is not None:
         axis = normalize_axis_index(axis, x_at.ndim)
         reduce_max_py = create_multiaxis_reducer(
-            scalar_maximum,
+            maximum,
             -np.inf,
             (axis,),
             x_at.ndim,
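
Both funcify functions build a maximum reduction before exponentiating, the standard trick for numerically stable (log-)softmax. A plain-NumPy sketch of the 1-D computation the generated Numba code corresponds to (illustrative only, not the generated source):

import numpy as np

def softmax_1d(x):
    # Subtracting the max keeps exp() from overflowing; this is the
    # reduction that create_multiaxis_reducer(maximum, -np.inf, ...) builds.
    e = np.exp(x - np.max(x))
    return e / e.sum()

print(softmax_1d(np.array([1000.0, 1001.0])))  # [0.26894142 0.73105858]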

pytensor/scalar/basic.py

Lines changed: 4 additions & 4 deletions
@@ -1868,7 +1868,7 @@ def c_code(self, node, name, inputs, outputs, sub):
 ##############
 # Arithmetic
 ##############
-class ScalarMaximum(BinaryScalarOp):
+class Maximum(BinaryScalarOp):
     commutative = True
     associative = True
     nfunc_spec = ("maximum", 2, 1)
@@ -1908,10 +1908,10 @@ def L_op(self, inputs, outputs, gout):
         return (gx, gy)
 
 
-scalar_maximum = ScalarMaximum(upcast_out, name="maximum")
+maximum = Maximum(upcast_out, name="maximum")
 
 
-class ScalarMinimum(BinaryScalarOp):
+class Minimum(BinaryScalarOp):
     commutative = True
     associative = True
     nfunc_spec = ("minimum", 2, 1)
@@ -1950,7 +1950,7 @@ def L_op(self, inputs, outputs, gout):
         return (gx, gy)
 
 
-scalar_minimum = ScalarMinimum(upcast_out, name="minimum")
+minimum = Minimum(upcast_out, name="minimum")
 
 
 class Add(ScalarOp):
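
Apart from the class names the ops are unchanged: nfunc_spec = ("maximum", 2, 1) still declares the NumPy equivalent (np.maximum, two inputs, one output). A quick sketch checking that correspondence through the tensor API:

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
y = pt.vector("y")
f = pytensor.function([x, y], pt.maximum(x, y))

a = np.array([1.0, 5.0, -2.0])
b = np.array([3.0, 4.0, -1.0])
# nfunc_spec = ("maximum", 2, 1): same semantics as np.maximum
assert np.array_equal(f(a, b), np.maximum(a, b))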

pytensor/scalar/math.py

Lines changed: 4 additions & 6 deletions
@@ -32,8 +32,8 @@
     isinf,
     log,
     log1p,
+    maximum,
     reciprocal,
-    scalar_maximum,
     sqrt,
     switch,
     true_div,
@@ -1305,7 +1305,7 @@ def c_code_cache_version(self):
         return v
 
 
-softplus = Softplus(upgrade_to_float, name="scalar_softplus")
+softplus = Softplus(upgrade_to_float, name="softplus")
 
 
 class Log1mexp(UnaryScalarOp):
@@ -1575,9 +1575,7 @@ def inner_loop(
         derivative_new = K * (F1 * dK + F2)
 
         errapx = scalar_abs(derivative - derivative_new)
-        d_errapx = errapx / scalar_maximum(
-            err_threshold, scalar_abs(derivative_new)
-        )
+        d_errapx = errapx / maximum(err_threshold, scalar_abs(derivative_new))
 
         min_iters_cond = n > (min_iters - 1)
         derivative = switch(
@@ -1823,7 +1821,7 @@ def inner_loop(*args):
         if len(grad_incs) == 1:
             [max_abs_grad_inc] = grad_incs
         else:
-            max_abs_grad_inc = reduce(scalar_maximum, abs_grad_incs)
+            max_abs_grad_inc = reduce(maximum, abs_grad_incs)
 
         return (
             (*grads, *log_gs, *log_gs_signs, log_t, log_t_sign, sign_zk, k),
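
The softplus change only renames the instance (its Python binding was already softplus); the op still computes softplus(x) = log(1 + exp(x)). A NumPy sketch of a stable evaluation, split into branches so exp() never sees a large positive argument (the mathematical definition, not PyTensor's exact piecewise implementation):

import numpy as np

def softplus(x):
    # log(1 + exp(x)) without overflow for large positive x
    x = np.asarray(x, dtype=float)
    out = np.empty_like(x)
    pos = x > 0
    out[pos] = x[pos] + np.log1p(np.exp(-x[pos]))
    out[~pos] = np.log1p(np.exp(x[~pos]))
    return out

print(softplus(np.array([-100.0, 0.0, 100.0])))  # ~[0.0, 0.6931, 100.0]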

pytensor/tensor/basic.py

Lines changed: 2 additions & 2 deletions
@@ -262,8 +262,8 @@ def _obj_is_wrappable_as_tensor(x):
         ps.Mul,
         ps.IntDiv,
         ps.TrueDiv,
-        ps.ScalarMinimum,
-        ps.ScalarMaximum,
+        ps.Minimum,
+        ps.Maximum,
     )
 
 

pytensor/tensor/blas.py

Lines changed: 2 additions & 2 deletions
@@ -947,8 +947,8 @@ def infer_shape(self, fgraph, node, input_shapes):
         z_shape, _, x_shape, y_shape, _ = input_shapes
         return [
             (
-                pytensor.scalar.scalar_maximum(z_shape[0], x_shape[0]),
-                pytensor.scalar.scalar_maximum(z_shape[1], y_shape[1]),
+                pytensor.scalar.maximum(z_shape[0], x_shape[0]),
+                pytensor.scalar.maximum(z_shape[1], y_shape[1]),
             )
         ]
 
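
Here the renamed instance is applied directly to symbolic shape scalars inside Gemm.infer_shape. A minimal sketch of using the scalar op that way (assuming scalar graphs compile with pytensor.function as usual):

import pytensor
import pytensor.scalar as ps

m = ps.int64("m")
n = ps.int64("n")

f = pytensor.function([m, n], ps.maximum(m, n))
assert f(3, 7) == 7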

pytensor/tensor/math.py

Lines changed: 4 additions & 4 deletions
@@ -406,7 +406,7 @@ class Max(NonZeroDimsCAReduce):
     nfunc_spec = ("max", 1, 1)
 
     def __init__(self, axis):
-        super().__init__(ps.scalar_maximum, axis)
+        super().__init__(ps.maximum, axis)
 
     def clone(self, **kwargs):
         axis = kwargs.get("axis", self.axis)
@@ -464,7 +464,7 @@ class Min(NonZeroDimsCAReduce):
     nfunc_spec = ("min", 1, 1)
 
     def __init__(self, axis):
-        super().__init__(ps.scalar_minimum, axis)
+        super().__init__(ps.minimum, axis)
 
     def clone(self, **kwargs):
         axis = kwargs.get("axis", self.axis)
@@ -2757,7 +2757,7 @@ def median(x: TensorLike, axis=None) -> TensorVariable:
     return ifelse(even_k, even_median, odd_median, name="median")
 
 
-@scalar_elemwise(symbolname="scalar_maximum")
+@scalar_elemwise
 def maximum(x, y):
     """elemwise maximum. See max for the maximum in one tensor
 
@@ -2793,7 +2793,7 @@ def maximum(x, y):
     # see decorator for function body
 
 
-@scalar_elemwise(symbolname="scalar_minimum")
+@scalar_elemwise
 def minimum(x, y):
     """elemwise minimum. See min for the minimum in one tensor
 
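
With the symbolname override dropped, scalar_elemwise can resolve the scalar op from the decorated function's own name (maximum, minimum), which now matches the instances exported by pytensor.scalar. Tensor-level behavior is unchanged, including the Max/Min reductions built on these ops; a small sketch:

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
f = pytensor.function([x], [pt.max(x, axis=0), pt.min(x, axis=0)])

mx, mn = f(np.array([[1.0, 4.0], [3.0, 2.0]]))
print(mx, mn)  # [3. 4.] [1. 2.]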

pytensor/tensor/rewriting/math.py

Lines changed: 3 additions & 3 deletions
@@ -1450,7 +1450,7 @@ def local_useless_elemwise_comparison(fgraph, node):
 
     # Elemwise[{minimum,maximum}](X, X) -> X
     if (
-        isinstance(node.op.scalar_op, ps.ScalarMinimum | ps.ScalarMaximum)
+        isinstance(node.op.scalar_op, ps.Minimum | ps.Maximum)
        and node.inputs[0] is node.inputs[1]
     ):
         res = node.inputs[0]
@@ -1493,7 +1493,7 @@ def local_useless_elemwise_comparison(fgraph, node):
         return [res]
 
     # Elemwise[maximum](X.shape[i], 0) -> X.shape[i]
-    if isinstance(node.op.scalar_op, ps.ScalarMaximum):
+    if isinstance(node.op.scalar_op, ps.Maximum):
         for idx in range(2):
             if (
                 node.inputs[idx].owner
@@ -1512,7 +1512,7 @@ def local_useless_elemwise_comparison(fgraph, node):
         return [res]
 
     # Elemwise[minimum](X.shape[i], 0) -> 0
-    if isinstance(node.op.scalar_op, ps.ScalarMinimum):
+    if isinstance(node.op.scalar_op, ps.Minimum):
         for idx in range(2):
             if (
                 node.inputs[idx].owner
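
These isinstance checks just track the renamed classes; the rewrites behave as before. For example, maximum(X, X) should still collapse to X under the default rewrites (a sketch; the compiled graph is expected to contain no Elemwise node at all):

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.elemwise import Elemwise

x = pt.vector("x")
f = pytensor.function([x], pt.maximum(x, x))

# Elemwise[maximum](X, X) -> X: only a copy of the input should remain
assert not any(isinstance(node.op, Elemwise) for node in f.maker.fgraph.toposort())
assert np.array_equal(f(np.array([1.0, -2.0])), [1.0, -2.0])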

pytensor/tensor/rewriting/uncanonicalize.py

Lines changed: 1 addition & 1 deletion
@@ -60,7 +60,7 @@ def local_max_to_min(fgraph, node):
     if (
         max.owner
         and isinstance(max.owner.op, CAReduce)
-        and max.owner.op.scalar_op == ps.scalar_maximum
+        and max.owner.op.scalar_op == ps.maximum
     ):
         neg_node = max.owner.inputs[0]
         if neg_node.owner and neg_node.owner.op == neg:
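
local_max_to_min still turns max(-x) into -min(x); only the comparison against the renamed ps.maximum instance changes. The identity the rewrite exploits, stated in plain NumPy:

import numpy as np

x = np.array([3.0, -1.0, 7.0])
assert np.max(-x) == -np.min(x)  # max(-x) == -min(x)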

tests/tensor/rewriting/test_math.py

Lines changed: 1 addition & 4 deletions
@@ -3829,10 +3829,7 @@ def check_max_log_sum_exp(x, axis, dimshuffle_op=None):
 
     fgraph = f.maker.fgraph.toposort()
     for node in fgraph:
-        if (
-            hasattr(node.op, "scalar_op")
-            and node.op.scalar_op == ps.basic.scalar_maximum
-        ):
+        if hasattr(node.op, "scalar_op") and node.op.scalar_op == ps.basic.maximum:
             return
 
     # In mode FAST_COMPILE, the rewrites don't replace the

tests/tensor/test_elemwise.py

Lines changed: 11 additions & 19 deletions
@@ -544,14 +544,14 @@ def with_mode(
         elif scalar_op == ps.mul:
             for axis in sorted(tosum, reverse=True):
                 zv = np.multiply.reduce(zv, axis)
-        elif scalar_op == ps.scalar_maximum:
+        elif scalar_op == ps.maximum:
             # There is no identity value for the maximum function
             # So we can't support shape of dimensions 0.
             if np.prod(zv.shape) == 0:
                 continue
             for axis in sorted(tosum, reverse=True):
                 zv = np.maximum.reduce(zv, axis)
-        elif scalar_op == ps.scalar_minimum:
+        elif scalar_op == ps.minimum:
             # There is no identity value for the minimum function
             # So we can't support shape of dimensions 0.
             if np.prod(zv.shape) == 0:
@@ -594,7 +594,7 @@ def with_mode(
             tosum = list(range(len(xsh)))
         f = pytensor.function([x], e.shape, mode=mode, on_unused_input="ignore")
         if not (
-            scalar_op in [ps.scalar_maximum, ps.scalar_minimum]
+            scalar_op in [ps.maximum, ps.minimum]
             and (xsh == () or np.prod(xsh) == 0)
         ):
             assert all(f(xv) == zv.shape)
@@ -606,8 +606,8 @@ def test_perform(self):
         for dtype in ["bool", "floatX", "complex64", "complex128", "int8", "uint8"]:
             self.with_mode(Mode(linker="py"), ps.add, dtype=dtype)
             self.with_mode(Mode(linker="py"), ps.mul, dtype=dtype)
-            self.with_mode(Mode(linker="py"), ps.scalar_maximum, dtype=dtype)
-            self.with_mode(Mode(linker="py"), ps.scalar_minimum, dtype=dtype)
+            self.with_mode(Mode(linker="py"), ps.maximum, dtype=dtype)
+            self.with_mode(Mode(linker="py"), ps.minimum, dtype=dtype)
             self.with_mode(Mode(linker="py"), ps.and_, dtype=dtype, tensor_op=pt_all)
             self.with_mode(Mode(linker="py"), ps.or_, dtype=dtype, tensor_op=pt_any)
         for dtype in ["int8", "uint8"]:
@@ -619,12 +619,8 @@ def test_perform_nan(self):
         for dtype in ["floatX", "complex64", "complex128"]:
             self.with_mode(Mode(linker="py"), ps.add, dtype=dtype, test_nan=True)
             self.with_mode(Mode(linker="py"), ps.mul, dtype=dtype, test_nan=True)
-            self.with_mode(
-                Mode(linker="py"), ps.scalar_maximum, dtype=dtype, test_nan=True
-            )
-            self.with_mode(
-                Mode(linker="py"), ps.scalar_minimum, dtype=dtype, test_nan=True
-            )
+            self.with_mode(Mode(linker="py"), ps.maximum, dtype=dtype, test_nan=True)
+            self.with_mode(Mode(linker="py"), ps.minimum, dtype=dtype, test_nan=True)
             self.with_mode(
                 Mode(linker="py"),
                 ps.or_,
@@ -659,8 +655,8 @@ def test_c(self):
             self.with_mode(Mode(linker="c"), ps.add, dtype=dtype)
             self.with_mode(Mode(linker="c"), ps.mul, dtype=dtype)
         for dtype in ["bool", "floatX", "int8", "uint8"]:
-            self.with_mode(Mode(linker="c"), ps.scalar_minimum, dtype=dtype)
-            self.with_mode(Mode(linker="c"), ps.scalar_maximum, dtype=dtype)
+            self.with_mode(Mode(linker="c"), ps.minimum, dtype=dtype)
+            self.with_mode(Mode(linker="c"), ps.maximum, dtype=dtype)
             self.with_mode(Mode(linker="c"), ps.and_, dtype=dtype, tensor_op=pt_all)
             self.with_mode(Mode(linker="c"), ps.or_, dtype=dtype, tensor_op=pt_any)
         for dtype in ["bool", "int8", "uint8"]:
@@ -678,12 +674,8 @@ def test_c_nan(self):
             self.with_mode(Mode(linker="c"), ps.add, dtype=dtype, test_nan=True)
             self.with_mode(Mode(linker="c"), ps.mul, dtype=dtype, test_nan=True)
         for dtype in ["floatX"]:
-            self.with_mode(
-                Mode(linker="c"), ps.scalar_minimum, dtype=dtype, test_nan=True
-            )
-            self.with_mode(
-                Mode(linker="c"), ps.scalar_maximum, dtype=dtype, test_nan=True
-            )
+            self.with_mode(Mode(linker="c"), ps.minimum, dtype=dtype, test_nan=True)
+            self.with_mode(Mode(linker="c"), ps.maximum, dtype=dtype, test_nan=True)
 
     def test_infer_shape(self, dtype=None, pre_scalar_op=None):
         if dtype is None:
