Skip to content

Commit 66f995b

Browse files
authored
update_argsort_test (#1667)
1 parent e404fa6 commit 66f995b

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

tests/test_sort.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,8 @@ def test_argsort_dtype(self, dtype):
9494
np_array = numpy.array(a, dtype=dtype)
9595
dp_array = dpnp.array(np_array)
9696

97-
result = dpnp.argsort(dp_array)
98-
expected = numpy.argsort(np_array)
97+
result = dpnp.argsort(dp_array, kind="stable")
98+
expected = numpy.argsort(np_array, kind="stable")
9999
assert_dtype_allclose(result, expected)
100100

101101
@pytest.mark.parametrize("dtype", get_complex_dtypes())

tests/third_party/cupy/sorting_tests/test_sort.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -296,12 +296,12 @@ def test_F_order(self, xp):
296296
)
297297
)
298298
class TestArgsort(unittest.TestCase):
299-
def argsort(self, a, axis=-1):
299+
def argsort(self, a, axis=-1, kind=None):
300300
if self.external:
301301
xp = cupy.get_array_module(a)
302-
return xp.argsort(a, axis=axis)
302+
return xp.argsort(a, axis=axis, kind=kind)
303303
else:
304-
return a.argsort(axis=axis)
304+
return a.argsort(axis=axis, kind=kind)
305305

306306
# Test base cases
307307

@@ -317,7 +317,7 @@ def test_argsort_zero_dim(self, xp, dtype):
317317
@testing.numpy_cupy_array_equal()
318318
def test_argsort_one_dim(self, xp, dtype):
319319
a = testing.shaped_random((10,), xp, dtype)
320-
return self.argsort(a)
320+
return self.argsort(a, axis=-1, kind="stable")
321321

322322
@testing.for_all_dtypes()
323323
@testing.numpy_cupy_array_equal()

0 commit comments

Comments
 (0)