Skip to content

Commit e026c37

Browse files
committed
add tests for negative use cases to improve coverage
1 parent 14f043f commit e026c37

File tree

1 file changed

+23
-31
lines changed

1 file changed

+23
-31
lines changed

tests/test_search.py

Lines changed: 23 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import dpctl.tensor as dpt
12
import numpy
23
import pytest
34
from numpy.testing import assert_allclose
@@ -7,63 +8,54 @@
78
from .helper import get_all_dtypes
89

910

11+
@pytest.mark.parametrize("func", ["argmax", "argmin"])
1012
@pytest.mark.parametrize("axis", [None, 0, 1, -1, 2, -2])
1113
@pytest.mark.parametrize("keepdims", [False, True])
1214
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
13-
def test_argmax_argmin(axis, keepdims, dtype):
15+
def test_argmax_argmin(func, axis, keepdims, dtype):
1416
a = numpy.arange(768, dtype=dtype).reshape((4, 4, 6, 8))
1517
ia = dpnp.array(a)
1618

17-
np_res = numpy.argmax(a, axis=axis, keepdims=keepdims)
18-
dpnp_res = dpnp.argmax(ia, axis=axis, keepdims=keepdims)
19-
20-
assert dpnp_res.shape == np_res.shape
21-
assert_allclose(dpnp_res, np_res)
22-
23-
np_res = numpy.argmin(a, axis=axis, keepdims=keepdims)
24-
dpnp_res = dpnp.argmin(ia, axis=axis, keepdims=keepdims)
19+
np_res = getattr(numpy, func)(a, axis=axis, keepdims=keepdims)
20+
dpnp_res = getattr(dpnp, func)(ia, axis=axis, keepdims=keepdims)
2521

2622
assert dpnp_res.shape == np_res.shape
2723
assert_allclose(dpnp_res, np_res)
2824

2925

26+
@pytest.mark.parametrize("func", ["argmax", "argmin"])
3027
@pytest.mark.parametrize("axis", [None, 0, 1, -1])
3128
@pytest.mark.parametrize("keepdims", [False, True])
32-
def test_argmax_argmin_bool(axis, keepdims):
29+
def test_argmax_argmin_bool(func, axis, keepdims):
3330
a = numpy.arange(2, dtype=dpnp.bool)
3431
a = numpy.tile(a, (2, 2))
3532
ia = dpnp.array(a)
3633

37-
np_res = numpy.argmax(a, axis=axis, keepdims=keepdims)
38-
dpnp_res = dpnp.argmax(ia, axis=axis, keepdims=keepdims)
34+
np_res = getattr(numpy, func)(a, axis=axis, keepdims=keepdims)
35+
dpnp_res = getattr(dpnp, func)(ia, axis=axis, keepdims=keepdims)
3936

4037
assert dpnp_res.shape == np_res.shape
4138
assert_allclose(dpnp_res, np_res)
4239

43-
np_res = numpy.argmin(a, axis=axis, keepdims=keepdims)
44-
dpnp_res = dpnp.argmin(ia, axis=axis, keepdims=keepdims)
4540

46-
assert dpnp_res.shape == np_res.shape
47-
assert_allclose(dpnp_res, np_res)
48-
49-
50-
@pytest.mark.parametrize("axis", [None, 0, 1, -1, 2, -2])
51-
@pytest.mark.parametrize("keepdims", [False, True])
52-
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
53-
def test_argmax_argmin_out(axis, keepdims, dtype):
54-
a = numpy.arange(768, dtype=dtype).reshape((4, 4, 6, 8))
41+
@pytest.mark.parametrize("func", ["argmax", "argmin"])
42+
def test_argmax_argmin_out(func):
43+
a = numpy.arange(6).reshape((2, 3))
5544
ia = dpnp.array(a)
5645

57-
np_res = numpy.argmax(a, axis=axis, keepdims=keepdims)
46+
np_res = getattr(numpy, func)(a, axis=0)
5847
dpnp_res = dpnp.array(numpy.empty_like(np_res))
59-
dpnp.argmax(ia, axis=axis, keepdims=keepdims, out=dpnp_res)
48+
getattr(dpnp, func)(ia, axis=0, out=dpnp_res)
49+
assert_allclose(dpnp_res, np_res)
6050

61-
assert dpnp_res.shape == np_res.shape
51+
dpnp_res = dpt.asarray(numpy.empty_like(np_res))
52+
getattr(dpnp, func)(ia, axis=0, out=dpnp_res)
6253
assert_allclose(dpnp_res, np_res)
6354

64-
np_res = numpy.argmin(a, axis=axis, keepdims=keepdims)
65-
dpnp_res = dpnp.array(numpy.empty_like(np_res))
66-
dpnp.argmin(ia, axis=axis, keepdims=keepdims, out=dpnp_res)
55+
dpnp_res = numpy.empty_like(np_res)
56+
with pytest.raises(TypeError):
57+
getattr(dpnp, func)(ia, axis=0, out=dpnp_res)
6758

68-
assert dpnp_res.shape == np_res.shape
69-
assert_allclose(dpnp_res, np_res)
59+
dpnp_res = dpnp.array(numpy.empty((2, 3)))
60+
with pytest.raises(ValueError):
61+
getattr(dpnp, func)(ia, axis=0, out=dpnp_res)

0 commit comments

Comments
 (0)