Skip to content

Commit 0d12385

Browse files
committed
Move argmax helper close to class definition
1 parent 1206acb commit 0d12385

File tree

1 file changed

+22
-22
lines changed

1 file changed

+22
-22
lines changed

pytensor/tensor/math.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,28 @@ def grad(self, inp, grads):
277277
return [x.zeros_like()]
278278

279279

280+
def argmax(x, axis=None, keepdims=False):
281+
"""
282+
Returns indices of maximum elements obtained by iterating over given axis.
283+
284+
When axis is None (the default value), the argmax is performed
285+
over the flattened tensor.
286+
287+
Parameters
288+
----------
289+
keepdims : bool
290+
If this is set to True, the axes which are reduced are left in
291+
the result as dimensions with size one. With this option, the result
292+
will broadcast correctly against the original tensor.
293+
294+
"""
295+
argout = max_and_argmax(x, axis)[1]
296+
297+
if keepdims:
298+
argout = makeKeepDims(x, argout, axis)
299+
return argout
300+
301+
280302
@_vectorize_node.register(Argmax)
281303
def vectorize_argmax_node(op, node, batch_x):
282304
core_ndim = node.inputs[0].type.ndim
@@ -549,28 +571,6 @@ def max(x, axis=None, keepdims=False):
549571
return out
550572

551573

552-
def argmax(x, axis=None, keepdims=False):
553-
"""
554-
Returns indices of maximum elements obtained by iterating over given axis.
555-
556-
When axis is None (the default value), the argmax is performed
557-
over the flattened tensor.
558-
559-
Parameters
560-
----------
561-
keepdims : bool
562-
If this is set to True, the axes which are reduced are left in
563-
the result as dimensions with size one. With this option, the result
564-
will broadcast correctly against the original tensor.
565-
566-
"""
567-
argout = max_and_argmax(x, axis)[1]
568-
569-
if keepdims:
570-
argout = makeKeepDims(x, argout, axis)
571-
return argout
572-
573-
574574
def min(x, axis=None, keepdims=False):
575575
"""
576576
Returns minimum elements obtained by iterating over given axis.

0 commit comments

Comments
 (0)