Skip to content

Commit b0cb903

Browse files
committed
Stop defining min as negative of max
My best guess as to why this was done is that, historically, Max was implemented before Min. Now that we have Min, and we always rewrite into it by the end, there seems to be no good reason not to start with it.
1 parent 2b6981b commit b0cb903

File tree

2 files changed

+7
-46
lines changed

2 files changed

+7
-46
lines changed

pytensor/tensor/math.py

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -609,19 +609,11 @@ def min(x, axis=None, keepdims=False):
609609
will broadcast correctly against the original tensor.
610610
611611
"""
612-
x = as_tensor_variable(x)
613-
str_x_type = str(x.dtype)
614-
if str_x_type.startswith("float") or str_x_type in int_dtypes:
615-
return -max(-x, axis=axis, keepdims=keepdims)
616-
elif str_x_type in uint_dtypes:
617-
itype = np.iinfo(x.dtype)
618-
max_val = np.array(itype.max, dtype=itype.dtype)
619-
return max_val - max(max_val - x, axis=axis, keepdims=keepdims)
620-
elif str_x_type == "bool":
621-
return ~max(~x, axis=axis, keepdims=keepdims)
622-
else:
623-
# Be careful about unsigned integers, complex
624-
raise NotImplementedError()
612+
out = Min(axis=axis)(x)
613+
614+
if keepdims:
615+
out = makeKeepDims(x, out, axis)
616+
return out
625617

626618

627619
def argmin(x, axis=None, keepdims=False):

pytensor/tensor/rewriting/uncanonicalize.py

Lines changed: 2 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -31,45 +31,14 @@
3131
3232
"""
3333

34-
from pytensor import scalar as ps
35-
from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
34+
from pytensor.graph.rewriting.basic import node_rewriter
3635
from pytensor.tensor.basic import Alloc, alloc, constant
37-
from pytensor.tensor.elemwise import CAReduce, DimShuffle
38-
from pytensor.tensor.math import Min, neg
36+
from pytensor.tensor.elemwise import DimShuffle
3937
from pytensor.tensor.rewriting.basic import register_uncanonicalize
4038
from pytensor.tensor.shape import Reshape, reshape
4139
from pytensor.tensor.subtensor import Subtensor
4240

4341

44-
@register_uncanonicalize
45-
@node_rewriter([neg])
46-
def local_max_to_min(fgraph, node):
47-
"""
48-
Change -(max(-x)) to min.
49-
50-
This is tested in tensor/tests/test_basic.py:test_min_max.
51-
52-
Notes
53-
-----
54-
We don't need an opt that will do the reverse as by default
55-
the interface put only Max into the graph.
56-
57-
"""
58-
if node.op == neg and node.inputs[0].owner:
59-
max = node.inputs[0]
60-
if (
61-
max.owner
62-
and isinstance(max.owner.op, CAReduce)
63-
and max.owner.op.scalar_op == ps.scalar_maximum
64-
):
65-
neg_node = max.owner.inputs[0]
66-
if neg_node.owner and neg_node.owner.op == neg:
67-
new = Min(max.owner.op.axis)(neg_node.owner.inputs[0])
68-
return [copy_stack_trace(node.outputs[0], new)]
69-
70-
return False
71-
72-
7342
@register_uncanonicalize
7443
@node_rewriter([Alloc])
7544
def local_alloc_dimshuffle(fgraph, node):

0 commit comments

Comments
 (0)