Skip to content

Commit 63f630c

Browse files
committed
implement dpnp.argmin and dpnp.argmax using dpctl.tensor
1 parent 01b3948 commit 63f630c

File tree

13 files changed

+302
-274
lines changed

13 files changed

+302
-274
lines changed

dpnp/backend/include/dpnp_iface_fptr.hpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,11 +76,7 @@ enum class DPNPFuncName : size_t
7676
DPNP_FN_ARCTAN2, /**< Used in numpy.arctan2() impl */
7777
DPNP_FN_ARCTANH, /**< Used in numpy.arctanh() impl */
7878
DPNP_FN_ARGMAX, /**< Used in numpy.argmax() impl */
79-
DPNP_FN_ARGMAX_EXT, /**< Used in numpy.argmax() impl, requires extra
80-
parameters */
8179
DPNP_FN_ARGMIN, /**< Used in numpy.argmin() impl */
82-
DPNP_FN_ARGMIN_EXT, /**< Used in numpy.argmin() impl, requires extra
83-
parameters */
8480
DPNP_FN_ARGSORT, /**< Used in numpy.argsort() impl */
8581
DPNP_FN_ARGSORT_EXT, /**< Used in numpy.argsort() impl, requires extra
8682
parameters */

dpnp/backend/kernels/dpnp_krnl_searching.cpp

Lines changed: 0 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -78,14 +78,6 @@ void (*dpnp_argmax_default_c)(void *,
7878
void *,
7979
size_t) = dpnp_argmax_c<_DataType, _idx_DataType>;
8080

81-
template <typename _DataType, typename _idx_DataType>
82-
DPCTLSyclEventRef (*dpnp_argmax_ext_c)(DPCTLSyclQueueRef,
83-
void *,
84-
void *,
85-
size_t,
86-
const DPCTLEventVectorRef) =
87-
dpnp_argmax_c<_DataType, _idx_DataType>;
88-
8981
template <typename _DataType, typename _idx_DataType>
9082
class dpnp_argmin_c_kernel;
9183

@@ -133,14 +125,6 @@ void (*dpnp_argmin_default_c)(void *,
133125
void *,
134126
size_t) = dpnp_argmin_c<_DataType, _idx_DataType>;
135127

136-
template <typename _DataType, typename _idx_DataType>
137-
DPCTLSyclEventRef (*dpnp_argmin_ext_c)(DPCTLSyclQueueRef,
138-
void *,
139-
void *,
140-
size_t,
141-
const DPCTLEventVectorRef) =
142-
dpnp_argmin_c<_DataType, _idx_DataType>;
143-
144128
void func_map_init_searching(func_map_t &fmap)
145129
{
146130
fmap[DPNPFuncName::DPNP_FN_ARGMAX][eft_INT][eft_INT] = {
@@ -160,23 +144,6 @@ void func_map_init_searching(func_map_t &fmap)
160144
fmap[DPNPFuncName::DPNP_FN_ARGMAX][eft_DBL][eft_LNG] = {
161145
eft_LNG, (void *)dpnp_argmax_default_c<double, int64_t>};
162146

163-
fmap[DPNPFuncName::DPNP_FN_ARGMAX_EXT][eft_INT][eft_INT] = {
164-
eft_INT, (void *)dpnp_argmax_ext_c<int32_t, int32_t>};
165-
fmap[DPNPFuncName::DPNP_FN_ARGMAX_EXT][eft_INT][eft_LNG] = {
166-
eft_LNG, (void *)dpnp_argmax_ext_c<int32_t, int64_t>};
167-
fmap[DPNPFuncName::DPNP_FN_ARGMAX_EXT][eft_LNG][eft_INT] = {
168-
eft_INT, (void *)dpnp_argmax_ext_c<int64_t, int32_t>};
169-
fmap[DPNPFuncName::DPNP_FN_ARGMAX_EXT][eft_LNG][eft_LNG] = {
170-
eft_LNG, (void *)dpnp_argmax_ext_c<int64_t, int64_t>};
171-
fmap[DPNPFuncName::DPNP_FN_ARGMAX_EXT][eft_FLT][eft_INT] = {
172-
eft_INT, (void *)dpnp_argmax_ext_c<float, int32_t>};
173-
fmap[DPNPFuncName::DPNP_FN_ARGMAX_EXT][eft_FLT][eft_LNG] = {
174-
eft_LNG, (void *)dpnp_argmax_ext_c<float, int64_t>};
175-
fmap[DPNPFuncName::DPNP_FN_ARGMAX_EXT][eft_DBL][eft_INT] = {
176-
eft_INT, (void *)dpnp_argmax_ext_c<double, int32_t>};
177-
fmap[DPNPFuncName::DPNP_FN_ARGMAX_EXT][eft_DBL][eft_LNG] = {
178-
eft_LNG, (void *)dpnp_argmax_ext_c<double, int64_t>};
179-
180147
fmap[DPNPFuncName::DPNP_FN_ARGMIN][eft_INT][eft_INT] = {
181148
eft_INT, (void *)dpnp_argmin_default_c<int32_t, int32_t>};
182149
fmap[DPNPFuncName::DPNP_FN_ARGMIN][eft_INT][eft_LNG] = {
@@ -194,22 +161,5 @@ void func_map_init_searching(func_map_t &fmap)
194161
fmap[DPNPFuncName::DPNP_FN_ARGMIN][eft_DBL][eft_LNG] = {
195162
eft_LNG, (void *)dpnp_argmin_default_c<double, int64_t>};
196163

197-
fmap[DPNPFuncName::DPNP_FN_ARGMIN_EXT][eft_INT][eft_INT] = {
198-
eft_INT, (void *)dpnp_argmin_ext_c<int32_t, int32_t>};
199-
fmap[DPNPFuncName::DPNP_FN_ARGMIN_EXT][eft_INT][eft_LNG] = {
200-
eft_LNG, (void *)dpnp_argmin_ext_c<int32_t, int64_t>};
201-
fmap[DPNPFuncName::DPNP_FN_ARGMIN_EXT][eft_LNG][eft_INT] = {
202-
eft_INT, (void *)dpnp_argmin_ext_c<int64_t, int32_t>};
203-
fmap[DPNPFuncName::DPNP_FN_ARGMIN_EXT][eft_LNG][eft_LNG] = {
204-
eft_LNG, (void *)dpnp_argmin_ext_c<int64_t, int64_t>};
205-
fmap[DPNPFuncName::DPNP_FN_ARGMIN_EXT][eft_FLT][eft_INT] = {
206-
eft_INT, (void *)dpnp_argmin_ext_c<float, int32_t>};
207-
fmap[DPNPFuncName::DPNP_FN_ARGMIN_EXT][eft_FLT][eft_LNG] = {
208-
eft_LNG, (void *)dpnp_argmin_ext_c<float, int64_t>};
209-
fmap[DPNPFuncName::DPNP_FN_ARGMIN_EXT][eft_DBL][eft_INT] = {
210-
eft_INT, (void *)dpnp_argmin_ext_c<double, int32_t>};
211-
fmap[DPNPFuncName::DPNP_FN_ARGMIN_EXT][eft_DBL][eft_LNG] = {
212-
eft_LNG, (void *)dpnp_argmin_ext_c<double, int64_t>};
213-
214164
return;
215165
}

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
3636
DPNP_FN_ALLCLOSE
3737
DPNP_FN_ALLCLOSE_EXT
3838
DPNP_FN_ARANGE
39-
DPNP_FN_ARGMAX
40-
DPNP_FN_ARGMAX_EXT
41-
DPNP_FN_ARGMIN
42-
DPNP_FN_ARGMIN_EXT
4339
DPNP_FN_ARGSORT
4440
DPNP_FN_ARGSORT_EXT
4541
DPNP_FN_CBRT
@@ -375,12 +371,6 @@ Sorting functions
375371
cpdef dpnp_descriptor dpnp_argsort(dpnp_descriptor array1)
376372
cpdef dpnp_descriptor dpnp_sort(dpnp_descriptor array1)
377373

378-
"""
379-
Searching functions
380-
"""
381-
cpdef dpnp_descriptor dpnp_argmax(dpnp_descriptor array1)
382-
cpdef dpnp_descriptor dpnp_argmin(dpnp_descriptor array1)
383-
384374
"""
385375
Trigonometric functions
386376
"""

dpnp/dpnp_algo/dpnp_algo_searching.pxi

Lines changed: 0 additions & 119 deletions
This file was deleted.

dpnp/dpnp_array.py

Lines changed: 4 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -490,24 +490,14 @@ def argmax(self, axis=None, out=None):
490490
"""
491491
Returns array of indices of the maximum values along the given axis.
492492
493-
Parameters
494-
----------
495-
axis : {None, integer}
496-
If None, the index is into the flattened array, otherwise along
497-
the specified axis
498-
out : {None, array}, optional
499-
Array into which the result can be placed. Its type is preserved
500-
and it must be of the right shape to hold the output.
501-
502-
Returns
503-
-------
504-
index_array : {integer_array}
493+
Refer to :obj:`dpnp.argmax` for full documentation.
505494
506495
Examples
507496
--------
497+
>>> import dpnp as np
508498
>>> a = np.arange(6).reshape(2,3)
509499
>>> a.argmax()
510-
5
500+
array(5)
511501
>>> a.argmax(0)
512502
array([1, 1, 1])
513503
>>> a.argmax(1)
@@ -520,21 +510,7 @@ def argmin(self, axis=None, out=None):
520510
"""
521511
Return array of indices to the minimum values along the given axis.
522512
523-
Parameters
524-
----------
525-
axis : {None, integer}
526-
If None, the index is into the flattened array, otherwise along
527-
the specified axis
528-
out : {None, array}, optional
529-
Array into which the result can be placed. Its type is preserved
530-
and it must be of the right shape to hold the output.
531-
532-
Returns
533-
-------
534-
ndarray or scalar
535-
If multi-dimension input, returns a new ndarray of indices to the
536-
minimum values along the given axis. Otherwise, returns a scalar
537-
of index to the minimum values along the given axis.
513+
Refer to :obj:`dpnp.argmin` for full documentation.
538514
539515
"""
540516
return dpnp.argmin(self, axis, out)

0 commit comments

Comments
 (0)