|
| 1 | +import dpctl.tensor as dpt |
1 | 2 | import numpy
|
2 | 3 | import pytest
|
3 | 4 | from numpy.testing import assert_allclose
|
|
7 | 8 | from .helper import get_all_dtypes
|
8 | 9 |
|
9 | 10 |
|
| 11 | +@pytest.mark.parametrize("func", ["argmax", "argmin"]) |
10 | 12 | @pytest.mark.parametrize("axis", [None, 0, 1, -1, 2, -2])
|
11 | 13 | @pytest.mark.parametrize("keepdims", [False, True])
|
12 | 14 | @pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
|
13 |
| -def test_argmax_argmin(axis, keepdims, dtype): |
| 15 | +def test_argmax_argmin(func, axis, keepdims, dtype): |
14 | 16 | a = numpy.arange(768, dtype=dtype).reshape((4, 4, 6, 8))
|
15 | 17 | ia = dpnp.array(a)
|
16 | 18 |
|
17 |
| - np_res = numpy.argmax(a, axis=axis, keepdims=keepdims) |
18 |
| - dpnp_res = dpnp.argmax(ia, axis=axis, keepdims=keepdims) |
19 |
| - |
20 |
| - assert dpnp_res.shape == np_res.shape |
21 |
| - assert_allclose(dpnp_res, np_res) |
22 |
| - |
23 |
| - np_res = numpy.argmin(a, axis=axis, keepdims=keepdims) |
24 |
| - dpnp_res = dpnp.argmin(ia, axis=axis, keepdims=keepdims) |
| 19 | + np_res = getattr(numpy, func)(a, axis=axis, keepdims=keepdims) |
| 20 | + dpnp_res = getattr(dpnp, func)(ia, axis=axis, keepdims=keepdims) |
25 | 21 |
|
26 | 22 | assert dpnp_res.shape == np_res.shape
|
27 | 23 | assert_allclose(dpnp_res, np_res)
|
28 | 24 |
|
29 | 25 |
|
| 26 | +@pytest.mark.parametrize("func", ["argmax", "argmin"]) |
30 | 27 | @pytest.mark.parametrize("axis", [None, 0, 1, -1])
|
31 | 28 | @pytest.mark.parametrize("keepdims", [False, True])
|
32 |
| -def test_argmax_argmin_bool(axis, keepdims): |
| 29 | +def test_argmax_argmin_bool(func, axis, keepdims): |
33 | 30 | a = numpy.arange(2, dtype=dpnp.bool)
|
34 | 31 | a = numpy.tile(a, (2, 2))
|
35 | 32 | ia = dpnp.array(a)
|
36 | 33 |
|
37 |
| - np_res = numpy.argmax(a, axis=axis, keepdims=keepdims) |
38 |
| - dpnp_res = dpnp.argmax(ia, axis=axis, keepdims=keepdims) |
| 34 | + np_res = getattr(numpy, func)(a, axis=axis, keepdims=keepdims) |
| 35 | + dpnp_res = getattr(dpnp, func)(ia, axis=axis, keepdims=keepdims) |
39 | 36 |
|
40 | 37 | assert dpnp_res.shape == np_res.shape
|
41 | 38 | assert_allclose(dpnp_res, np_res)
|
42 | 39 |
|
43 |
| - np_res = numpy.argmin(a, axis=axis, keepdims=keepdims) |
44 |
| - dpnp_res = dpnp.argmin(ia, axis=axis, keepdims=keepdims) |
45 | 40 |
|
46 |
| - assert dpnp_res.shape == np_res.shape |
47 |
| - assert_allclose(dpnp_res, np_res) |
48 |
| - |
49 |
| - |
50 |
| -@pytest.mark.parametrize("axis", [None, 0, 1, -1, 2, -2]) |
51 |
| -@pytest.mark.parametrize("keepdims", [False, True]) |
52 |
| -@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True)) |
53 |
| -def test_argmax_argmin_out(axis, keepdims, dtype): |
54 |
| - a = numpy.arange(768, dtype=dtype).reshape((4, 4, 6, 8)) |
| 41 | +@pytest.mark.parametrize("func", ["argmax", "argmin"]) |
| 42 | +def test_argmax_argmin_out(func): |
| 43 | + a = numpy.arange(6).reshape((2, 3)) |
55 | 44 | ia = dpnp.array(a)
|
56 | 45 |
|
57 |
| - np_res = numpy.argmax(a, axis=axis, keepdims=keepdims) |
| 46 | + np_res = getattr(numpy, func)(a, axis=0) |
58 | 47 | dpnp_res = dpnp.array(numpy.empty_like(np_res))
|
59 |
| - dpnp.argmax(ia, axis=axis, keepdims=keepdims, out=dpnp_res) |
| 48 | + getattr(dpnp, func)(ia, axis=0, out=dpnp_res) |
| 49 | + assert_allclose(dpnp_res, np_res) |
60 | 50 |
|
61 |
| - assert dpnp_res.shape == np_res.shape |
| 51 | + dpnp_res = dpt.asarray(numpy.empty_like(np_res)) |
| 52 | + getattr(dpnp, func)(ia, axis=0, out=dpnp_res) |
62 | 53 | assert_allclose(dpnp_res, np_res)
|
63 | 54 |
|
64 |
| - np_res = numpy.argmin(a, axis=axis, keepdims=keepdims) |
65 |
| - dpnp_res = dpnp.array(numpy.empty_like(np_res)) |
66 |
| - dpnp.argmin(ia, axis=axis, keepdims=keepdims, out=dpnp_res) |
| 55 | + dpnp_res = numpy.empty_like(np_res) |
| 56 | + with pytest.raises(TypeError): |
| 57 | + getattr(dpnp, func)(ia, axis=0, out=dpnp_res) |
67 | 58 |
|
68 |
| - assert dpnp_res.shape == np_res.shape |
69 |
| - assert_allclose(dpnp_res, np_res) |
| 59 | + dpnp_res = dpnp.array(numpy.empty((2, 3))) |
| 60 | + with pytest.raises(ValueError): |
| 61 | + getattr(dpnp, func)(ia, axis=0, out=dpnp_res) |
0 commit comments