Skip to content

Commit 8072622

Browse files
authored
implement sort and argsort (#1660)
* implement sort and argsort * add more tests * update for zero dimensional arrays * address comments * fix typo
1 parent b401ae9 commit 8072622

18 files changed

+365
-273
lines changed

dpnp/backend/include/dpnp_iface_fptr.hpp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,6 @@ enum class DPNPFuncName : size_t
7878
DPNP_FN_ARGMAX, /**< Used in numpy.argmax() impl */
7979
DPNP_FN_ARGMIN, /**< Used in numpy.argmin() impl */
8080
DPNP_FN_ARGSORT, /**< Used in numpy.argsort() impl */
81-
DPNP_FN_ARGSORT_EXT, /**< Used in numpy.argsort() impl, requires extra
82-
parameters */
8381
DPNP_FN_AROUND, /**< Used in numpy.around() impl */
8482
DPNP_FN_ASTYPE, /**< Used in numpy.astype() impl */
8583
DPNP_FN_BITWISE_AND, /**< Used in numpy.bitwise_and() impl */
@@ -357,9 +355,7 @@ enum class DPNPFuncName : size_t
357355
DPNP_FN_SIN, /**< Used in numpy.sin() impl */
358356
DPNP_FN_SINH, /**< Used in numpy.sinh() impl */
359357
DPNP_FN_SORT, /**< Used in numpy.sort() impl */
360-
DPNP_FN_SORT_EXT, /**< Used in numpy.sort() impl, requires extra parameters
361-
*/
362-
DPNP_FN_SQRT, /**< Used in numpy.sqrt() impl */
358+
DPNP_FN_SQRT, /**< Used in numpy.sqrt() impl */
363359
DPNP_FN_SQRT_EXT, /**< Used in numpy.sqrt() impl, requires extra parameters
364360
*/
365361
DPNP_FN_SQUARE, /**< Used in numpy.square() impl */

dpnp/backend/kernels/dpnp_krnl_sorting.cpp

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -97,14 +97,6 @@ template <typename _DataType, typename _idx_DataType>
9797
void (*dpnp_argsort_default_c)(void *, void *, size_t) =
9898
dpnp_argsort_c<_DataType, _idx_DataType>;
9999

100-
template <typename _DataType, typename _idx_DataType>
101-
DPCTLSyclEventRef (*dpnp_argsort_ext_c)(DPCTLSyclQueueRef,
102-
void *,
103-
void *,
104-
size_t,
105-
const DPCTLEventVectorRef) =
106-
dpnp_argsort_c<_DataType, _idx_DataType>;
107-
108100
// template void dpnp_argsort_c<double, long>(void* array1_in, void* result1,
109101
// size_t size); template void dpnp_argsort_c<float, long>(void* array1_in,
110102
// void* result1, size_t size); template void dpnp_argsort_c<long, long>(void*
@@ -471,14 +463,6 @@ void dpnp_sort_c(void *array1_in, void *result1, size_t size)
471463
template <typename _DataType>
472464
void (*dpnp_sort_default_c)(void *, void *, size_t) = dpnp_sort_c<_DataType>;
473465

474-
template <typename _DataType>
475-
DPCTLSyclEventRef (*dpnp_sort_ext_c)(DPCTLSyclQueueRef,
476-
void *,
477-
void *,
478-
size_t,
479-
const DPCTLEventVectorRef) =
480-
dpnp_sort_c<_DataType>;
481-
482466
void func_map_init_sorting(func_map_t &fmap)
483467
{
484468
fmap[DPNPFuncName::DPNP_FN_ARGSORT][eft_INT][eft_INT] = {
@@ -490,15 +474,6 @@ void func_map_init_sorting(func_map_t &fmap)
490474
fmap[DPNPFuncName::DPNP_FN_ARGSORT][eft_DBL][eft_DBL] = {
491475
eft_LNG, (void *)dpnp_argsort_default_c<double, int64_t>};
492476

493-
fmap[DPNPFuncName::DPNP_FN_ARGSORT_EXT][eft_INT][eft_INT] = {
494-
eft_LNG, (void *)dpnp_argsort_ext_c<int32_t, int64_t>};
495-
fmap[DPNPFuncName::DPNP_FN_ARGSORT_EXT][eft_LNG][eft_LNG] = {
496-
eft_LNG, (void *)dpnp_argsort_ext_c<int64_t, int64_t>};
497-
fmap[DPNPFuncName::DPNP_FN_ARGSORT_EXT][eft_FLT][eft_FLT] = {
498-
eft_LNG, (void *)dpnp_argsort_ext_c<float, int64_t>};
499-
fmap[DPNPFuncName::DPNP_FN_ARGSORT_EXT][eft_DBL][eft_DBL] = {
500-
eft_LNG, (void *)dpnp_argsort_ext_c<double, int64_t>};
501-
502477
fmap[DPNPFuncName::DPNP_FN_PARTITION][eft_INT][eft_INT] = {
503478
eft_INT, (void *)dpnp_partition_default_c<int32_t>};
504479
fmap[DPNPFuncName::DPNP_FN_PARTITION][eft_LNG][eft_LNG] = {
@@ -550,14 +525,5 @@ void func_map_init_sorting(func_map_t &fmap)
550525
fmap[DPNPFuncName::DPNP_FN_SORT][eft_DBL][eft_DBL] = {
551526
eft_DBL, (void *)dpnp_sort_default_c<double>};
552527

553-
fmap[DPNPFuncName::DPNP_FN_SORT_EXT][eft_INT][eft_INT] = {
554-
eft_INT, (void *)dpnp_sort_ext_c<int32_t>};
555-
fmap[DPNPFuncName::DPNP_FN_SORT_EXT][eft_LNG][eft_LNG] = {
556-
eft_LNG, (void *)dpnp_sort_ext_c<int64_t>};
557-
fmap[DPNPFuncName::DPNP_FN_SORT_EXT][eft_FLT][eft_FLT] = {
558-
eft_FLT, (void *)dpnp_sort_ext_c<float>};
559-
fmap[DPNPFuncName::DPNP_FN_SORT_EXT][eft_DBL][eft_DBL] = {
560-
eft_DBL, (void *)dpnp_sort_ext_c<double>};
561-
562528
return;
563529
}

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
3636
DPNP_FN_ALLCLOSE
3737
DPNP_FN_ALLCLOSE_EXT
3838
DPNP_FN_ARANGE
39-
DPNP_FN_ARGSORT
40-
DPNP_FN_ARGSORT_EXT
4139
DPNP_FN_CHOOSE
4240
DPNP_FN_CHOOSE_EXT
4341
DPNP_FN_COPY
@@ -175,8 +173,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
175173
DPNP_FN_RNG_ZIPF_EXT
176174
DPNP_FN_SEARCHSORTED
177175
DPNP_FN_SEARCHSORTED_EXT
178-
DPNP_FN_SORT
179-
DPNP_FN_SORT_EXT
180176
DPNP_FN_SVD
181177
DPNP_FN_SVD_EXT
182178
DPNP_FN_TRACE
@@ -309,12 +305,6 @@ cpdef dpnp_descriptor dpnp_fmin(dpnp_descriptor x1_obj, dpnp_descriptor x2_obj,
309305
dpnp_descriptor out=*, object where=*)
310306

311307

312-
"""
313-
Sorting functions
314-
"""
315-
cpdef dpnp_descriptor dpnp_argsort(dpnp_descriptor array1)
316-
cpdef dpnp_descriptor dpnp_sort(dpnp_descriptor array1)
317-
318308
"""
319309
Trigonometric functions
320310
"""

dpnp/dpnp_algo/dpnp_algo_sorting.pxi

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,8 @@ and the rest of the library
3636
# NO IMPORTs here. All imports must be placed into main "dpnp_algo.pyx" file
3737

3838
__all__ += [
39-
"dpnp_argsort",
4039
"dpnp_partition",
4140
"dpnp_searchsorted",
42-
"dpnp_sort"
4341
]
4442

4543

@@ -61,13 +59,6 @@ ctypedef c_dpctl.DPCTLSyclEventRef(*fptr_dpnp_searchsorted_t)(c_dpctl.DPCTLSyclQ
6159
const c_dpctl.DPCTLEventVectorRef)
6260

6361

64-
cpdef utils.dpnp_descriptor dpnp_argsort(utils.dpnp_descriptor x1):
65-
cdef shape_type_c result_shape = x1.shape
66-
if result_shape == ():
67-
result_shape = (1,)
68-
return call_fptr_1in_1out(DPNP_FN_ARGSORT_EXT, x1, result_shape)
69-
70-
7162
cpdef utils.dpnp_descriptor dpnp_partition(utils.dpnp_descriptor arr, int kth, axis=-1, kind='introselect', order=None):
7263
cdef shape_type_c shape1 = arr.shape
7364

@@ -148,7 +139,3 @@ cpdef utils.dpnp_descriptor dpnp_searchsorted(utils.dpnp_descriptor arr, utils.d
148139
c_dpctl.DPCTLEvent_Delete(event_ref)
149140

150141
return result
151-
152-
153-
cpdef utils.dpnp_descriptor dpnp_sort(utils.dpnp_descriptor x1):
154-
return call_fptr_1in_1out(DPNP_FN_SORT_EXT, x1, x1.shape)

dpnp/dpnp_array.py

Lines changed: 34 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -510,39 +510,7 @@ def argsort(self, axis=-1, kind=None, order=None):
510510
"""
511511
Return an ndarray of indices that sort the array along the specified axis.
512512
513-
Parameters
514-
----------
515-
axis : int, optional
516-
Axis along which to sort. If None, the default, the flattened array
517-
is used.
518-
.. versionchanged:: 1.13.0
519-
Previously, the default was documented to be -1, but that was
520-
in error. At some future date, the default will change to -1, as
521-
originally intended.
522-
Until then, the axis should be given explicitly when
523-
``arr.ndim > 1``, to avoid a FutureWarning.
524-
kind : {'quicksort', 'mergesort', 'heapsort', 'stable'}, optional
525-
The sorting algorithm used.
526-
order : list, optional
527-
When `a` is an array with fields defined, this argument specifies
528-
which fields to compare first, second, etc. Not all fields need be
529-
specified.
530-
531-
Returns
532-
-------
533-
index_array : ndarray, int
534-
Array of indices that sort `a` along the specified axis.
535-
In other words, ``a[index_array]`` yields a sorted `a`.
536-
537-
See Also
538-
--------
539-
MaskedArray.sort : Describes sorting algorithms used.
540-
:obj:`dpnp.lexsort` : Indirect stable sort with multiple keys.
541-
:obj:`numpy.ndarray.sort` : Inplace sort.
542-
543-
Notes
544-
-----
545-
See `sort` for notes on the different sorting algorithms.
513+
Refer to :obj:`dpnp.argsort` for full documentation.
546514
547515
"""
548516
return dpnp.argsort(self, axis, kind, order)
@@ -1163,14 +1131,44 @@ def size(self):
11631131

11641132
return self._array_obj.size
11651133

1166-
# 'sort',
1134+
def sort(self, axis=-1, kind=None, order=None):
1135+
"""
1136+
Sort an array in-place.
1137+
1138+
Refer to :obj:`dpnp.sort` for full documentation.
1139+
1140+
Note
1141+
----
1142+
`axis` in :obj:`dpnp.sort` could be integr or ``None``. If ``None``,
1143+
the array is flattened before sorting. However, `axis` in :obj:`dpnp.ndarray.sort`
1144+
can only be integer since it sorts an array in-place.
1145+
1146+
Examples
1147+
--------
1148+
>>> import dpnp as np
1149+
>>> a = np.array([[1,4],[3,1]])
1150+
>>> a.sort(axis=1)
1151+
>>> a
1152+
array([[1, 4],
1153+
[1, 3]])
1154+
>>> a.sort(axis=0)
1155+
>>> a
1156+
array([[1, 1],
1157+
[3, 4]])
1158+
1159+
"""
1160+
1161+
if axis is None:
1162+
raise TypeError(
1163+
"'NoneType' object cannot be interpreted as an integer"
1164+
)
1165+
self[...] = dpnp.sort(self, axis=axis, kind=kind, order=order)
11671166

11681167
def squeeze(self, axis=None):
11691168
"""
11701169
Remove single-dimensional entries from the shape of an array.
11711170
1172-
.. seealso::
1173-
:obj:`dpnp.squeeze` for full documentation
1171+
Refer to :obj:`dpnp.squeeze` for full documentation
11741172
11751173
"""
11761174

dpnp/dpnp_iface_indexing.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -851,6 +851,7 @@ def take_along_axis(a, indices, axis):
851851
--------
852852
:obj:`dpnp.take` : Take along an axis, using the same indices for every 1d slice.
853853
:obj:`dpnp.put_along_axis` : Put values into the destination array by matching 1d index and data slices.
854+
:obj:`dpnp.argsort` : Return the indices that would sort an array.
854855
855856
Examples
856857
--------

dpnp/dpnp_iface_mathematical.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2709,7 +2709,7 @@ def sum(
27092709
27102710
Parameters
27112711
----------
2712-
a : {dpnp.ndarray, usm_ndarray}:
2712+
a : {dpnp.ndarray, usm_ndarray}
27132713
Input array.
27142714
axis : int or tuple of ints, optional
27152715
Axis or axes along which sums must be computed. If a tuple
@@ -2762,7 +2762,7 @@ def sum(
27622762
27632763
Limitations
27642764
-----------
2765-
Parameters `initial` and `where` are supported with their default values.
2765+
Parameters `initial` and `where` are only supported with their default values.
27662766
Otherwise ``NotImplementedError`` exception will be raised.
27672767
27682768
See Also

dpnp/dpnp_iface_nanfunctions.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# -*- coding: utf-8 -*-
22
# *****************************************************************************
3-
# Copyright (c) 2016-2024, Intel Corporation
3+
# Copyright (c) 2023-2024, Intel Corporation
44
# All rights reserved.
55
#
66
# Redistribution and use in source and binary forms, with or without
@@ -415,7 +415,7 @@ def nanmean(a, axis=None, dtype=None, out=None, keepdims=False, *, where=True):
415415
416416
Parameters
417417
----------
418-
a : {dpnp.ndarray, usm_ndarray}:
418+
a : {dpnp.ndarray, usm_ndarray}
419419
Input array.
420420
axis : int or tuple of ints, optional
421421
Axis or axes along which the arithmetic means must be computed. If
@@ -696,7 +696,7 @@ def nansum(
696696
697697
Parameters
698698
----------
699-
a : {dpnp.ndarray, usm_ndarray}:
699+
a : {dpnp.ndarray, usm_ndarray}
700700
Input array.
701701
axis : int or tuple of ints, optional
702702
Axis or axes along which sums must be computed. If a tuple
@@ -806,7 +806,7 @@ def nanstd(
806806
807807
Parameters
808808
----------
809-
a : {dpnp.ndarray, usm_ndarray}:
809+
a : {dpnp.ndarray, usm_ndarray}
810810
Input array.
811811
axis : int or tuple of ints, optional
812812
Axis or axes along which the standard deviations must be computed.
@@ -908,7 +908,7 @@ def nanvar(
908908
909909
Parameters
910910
----------
911-
a : {dpnp.ndarray, usm_ndarray}:
911+
a : {dpnp_array, usm_ndarray}
912912
Input array.
913913
axis : int or tuple of ints, optional
914914
axis or axes along which the variances must be computed. If a tuple

0 commit comments

Comments
 (0)