Skip to content

Commit 2684651

Browse files
Parametrize sorting tests by kind
1 parent a15e4aa commit 2684651

File tree

1 file changed

+54
-17
lines changed

1 file changed

+54
-17
lines changed

dpctl/tests/test_usm_ndarray_sorting.py

Lines changed: 54 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -153,81 +153,117 @@ def test_sort_validation():
153153
dpt.sort(dict())
154154

155155

156+
def test_sort_validation_kind():
157+
get_queue_or_skip()
158+
159+
x = dpt.ones(128, dtype="u1")
160+
161+
with pytest.raises(ValueError):
162+
dpt.sort(x, kind=Ellipsis)
163+
164+
with pytest.raises(ValueError):
165+
dpt.sort(x, kind="invalid")
166+
167+
156168
def test_argsort_validation():
157169
with pytest.raises(TypeError):
158170
dpt.argsort(dict())
159171

160172

161-
def test_sort_axis0():
173+
def test_argsort_validation_kind():
174+
get_queue_or_skip()
175+
176+
x = dpt.arange(127, stop=0, step=-1, dtype="i1")
177+
178+
with pytest.raises(ValueError):
179+
dpt.argsort(x, kind=Ellipsis)
180+
181+
with pytest.raises(ValueError):
182+
dpt.argsort(x, kind="invalid")
183+
184+
185+
_all_kinds = ["stable", "mergesort", "radixsort"]
186+
187+
188+
@pytest.mark.parametrize("kind", _all_kinds)
189+
def test_sort_axis0(kind):
162190
get_queue_or_skip()
163191

164192
n, m = 200, 30
165193
xf = dpt.arange(n * m, 0, step=-1, dtype="i4")
166194
x = dpt.reshape(xf, (n, m))
167-
s = dpt.sort(x, axis=0)
195+
s = dpt.sort(x, axis=0, kind=kind)
168196

169197
assert dpt.all(s[:-1, :] <= s[1:, :])
170198

171199

172-
def test_argsort_axis0():
200+
@pytest.mark.parametrize("kind", _all_kinds)
201+
def test_argsort_axis0(kind):
173202
get_queue_or_skip()
174203

175204
n, m = 200, 30
176205
xf = dpt.arange(n * m, 0, step=-1, dtype="i4")
177206
x = dpt.reshape(xf, (n, m))
178-
idx = dpt.argsort(x, axis=0)
207+
idx = dpt.argsort(x, axis=0, kind=kind)
179208

180209
s = dpt.take_along_axis(x, idx, axis=0)
181210

182211
assert dpt.all(s[:-1, :] <= s[1:, :])
183212

184213

185-
def test_argsort_axis1():
214+
@pytest.mark.parametrize("kind", _all_kinds)
215+
def test_argsort_axis1(kind):
186216
get_queue_or_skip()
187217

188218
n, m = 200, 30
189219
xf = dpt.arange(n * m, 0, step=-1, dtype="i4")
190220
x = dpt.reshape(xf, (n, m))
191-
idx = dpt.argsort(x, axis=1)
221+
idx = dpt.argsort(x, axis=1, kind=kind)
192222

193223
s = dpt.take_along_axis(x, idx, axis=1)
194224

195225
assert dpt.all(s[:, :-1] <= s[:, 1:])
196226

197227

198-
def test_sort_strided():
228+
@pytest.mark.parametrize("kind", _all_kinds)
229+
def test_sort_strided(kind):
199230
get_queue_or_skip()
200231

201232
x_orig = dpt.arange(100, dtype="i4")
202233
x_flipped = dpt.flip(x_orig, axis=0)
203-
s = dpt.sort(x_flipped)
234+
s = dpt.sort(x_flipped, kind=kind)
204235

205236
assert dpt.all(s == x_orig)
206237

207238

208-
def test_argsort_strided():
239+
@pytest.mark.parametrize("kind", _all_kinds)
240+
def test_argsort_strided(kind):
209241
get_queue_or_skip()
210242

211243
x_orig = dpt.arange(100, dtype="i4")
212244
x_flipped = dpt.flip(x_orig, axis=0)
213-
idx = dpt.argsort(x_flipped)
245+
idx = dpt.argsort(x_flipped, kind=kind)
214246
s = dpt.take_along_axis(x_flipped, idx, axis=0)
215247

216248
assert dpt.all(s == x_orig)
217249

218250

219-
def test_sort_0d_array():
251+
@pytest.mark.parametrize("kind", _all_kinds)
252+
def test_sort_0d_array(kind):
220253
get_queue_or_skip()
221254

222255
x = dpt.asarray(1, dtype="i4")
223-
assert dpt.sort(x) == 1
256+
expected = dpt.asarray(1, dtype="i4")
257+
assert dpt.sort(x, kind=kind) == expected
224258

225259

226-
def test_argsort_0d_array():
260+
@pytest.mark.parametrize("kind", _all_kinds)
261+
def test_argsort_0d_array(kind):
227262
get_queue_or_skip()
228263

229264
x = dpt.asarray(1, dtype="i4")
230-
assert dpt.argsort(x) == 0
265+
expected = dpt.asarray(0, dtype="i4")
266+
assert dpt.argsort(x, kind=kind) == expected
231267

232268

233269
@pytest.mark.parametrize(
@@ -238,22 +274,23 @@ def test_argsort_0d_array():
238274
"f8",
239275
],
240276
)
241-
def test_sort_real_fp_nan(dtype):
277+
@pytest.mark.parametrize("kind", _all_kinds)
278+
def test_sort_real_fp_nan(dtype, kind):
242279
q = get_queue_or_skip()
243280
skip_if_dtype_not_supported(dtype, q)
244281

245282
x = dpt.asarray(
246283
[-0.0, 0.1, dpt.nan, 0.0, -0.1, dpt.nan, 0.2, -0.3], dtype=dtype
247284
)
248-
s = dpt.sort(x)
285+
s = dpt.sort(x, kind=kind)
249286

250287
expected = dpt.asarray(
251288
[-0.3, -0.1, -0.0, 0.0, 0.1, 0.2, dpt.nan, dpt.nan], dtype=dtype
252289
)
253290

254291
assert dpt.allclose(s, expected, equal_nan=True)
255292

256-
s = dpt.sort(x, descending=True)
293+
s = dpt.sort(x, descending=True, kind=kind)
257294

258295
expected = dpt.asarray(
259296
[dpt.nan, dpt.nan, 0.2, 0.1, -0.0, 0.0, -0.1, -0.3], dtype=dtype

0 commit comments

Comments
 (0)