@@ -277,6 +277,28 @@ def grad(self, inp, grads):
         return [x.zeros_like()]
 
 
+def argmax(x, axis=None, keepdims=False):
+    """
+    Returns indices of maximum elements obtained by iterating over given axis.
+
+    When axis is None (the default value), the argmax is performed
+    over the flattened tensor.
+
+    Parameters
+    ----------
+    keepdims : bool
+        If this is set to True, the axes which are reduced are left in
+        the result as dimensions with size one. With this option, the result
+        will broadcast correctly against the original tensor.
+
+    """
+    argout = max_and_argmax(x, axis)[1]
+
+    if keepdims:
+        argout = makeKeepDims(x, argout, axis)
+    return argout
+
+
 @_vectorize_node.register(Argmax)
 def vectorize_argmax_node(op, node, batch_x):
     core_ndim = node.inputs[0].type.ndim
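The docstring in the hunk above documents the two behaviours of the relocated `argmax`: with `axis=None` the index is taken over the flattened tensor, and with `keepdims=True` the reduced axis is kept with size one so the result broadcasts against the input. A minimal usage sketch, assuming this diff targets PyTensor and that `argmax` is reachable as `pytensor.tensor.argmax` (the sketch is illustrative, not part of the patch):

```python
# Usage sketch (assumption: PyTensor installed, argmax exposed via pytensor.tensor).
import numpy as np
import pytensor.tensor as pt

x = pt.matrix("x")

# Default axis=None: the argmax is taken over the flattened tensor.
flat_idx = pt.argmax(x)

# keepdims=True keeps the reduced axis as size one, shape (2, 1) here.
row_idx = pt.argmax(x, axis=1, keepdims=True)

data = np.array([[1.0, 3.0, 2.0],
                 [9.0, 0.0, 5.0]])
print(flat_idx.eval({x: data}))   # 3   (value 9 at flattened index 3)
print(row_idx.eval({x: data}))    # [[1], [0]]

# Because row_idx has shape (2, 1), it broadcasts against x-shaped tensors,
# e.g. to mark the column holding each row's maximum:
mask = pt.eq(pt.arange(3)[None, :], row_idx)
print(mask.eval({x: data}).astype(int))
# [[0 1 0]
#  [1 0 0]]
```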
@@ -549,28 +571,6 @@ def max(x, axis=None, keepdims=False):
     return out
 
 
-def argmax(x, axis=None, keepdims=False):
-    """
-    Returns indices of maximum elements obtained by iterating over given axis.
-
-    When axis is None (the default value), the argmax is performed
-    over the flattened tensor.
-
-    Parameters
-    ----------
-    keepdims : bool
-        If this is set to True, the axes which are reduced are left in
-        the result as dimensions with size one. With this option, the result
-        will broadcast correctly against the original tensor.
-
-    """
-    argout = max_and_argmax(x, axis)[1]
-
-    if keepdims:
-        argout = makeKeepDims(x, argout, axis)
-    return argout
-
-
 def min(x, axis=None, keepdims=False):
     """
     Returns minimum elements obtained by iterating over given axis.