Skip to content

Commit 056fbe3

Browse files
committed
Cover everything for argmin and argmax tests
1 parent cd8f117 commit 056fbe3

File tree

1 file changed

+85
-10
lines changed

1 file changed

+85
-10
lines changed

array_api_tests/test_searching_functions.py

Lines changed: 85 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,98 @@
11
from hypothesis import given
22
from hypothesis import strategies as st
33

4+
from array_api_tests.test_statistical_functions import (
5+
assert_equals,
6+
assert_keepdimable_shape,
7+
axes_ndindex,
8+
normalise_axis,
9+
)
10+
from array_api_tests.typing import DataType
11+
412
from . import _array_module as xp
13+
from . import array_helpers as ah
14+
from . import dtype_helpers as dh
515
from . import hypothesis_helpers as hh
616
from . import xps
717

818

9-
# TODO: generate kwargs
10-
@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1)))
11-
def test_argmin(x):
12-
xp.argmin(x)
13-
# TODO
19+
def assert_default_index(func_name: str, dtype: DataType):
20+
f_dtype = dh.dtype_to_name[dtype]
21+
msg = (
22+
f"out.dtype={f_dtype}, should be the default index dtype, "
23+
f"which is either int32 or int64 [{func_name}()]"
24+
)
25+
assert dtype in (xp.int32, xp.int64), msg
1426

1527

16-
# TODO: generate kwargs
17-
@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1)))
18-
def test_argmax(x):
19-
xp.argmax(x)
20-
# TODO
28+
@given(
29+
x=xps.arrays(
30+
dtype=xps.numeric_dtypes(),
31+
shape=hh.shapes(min_side=1),
32+
elements={"allow_nan": False},
33+
),
34+
data=st.data(),
35+
)
36+
def test_argmax(x, data):
37+
kw = data.draw(
38+
hh.kwargs(
39+
axis=st.none() | st.integers(-x.ndim, max(x.ndim - 1, 0)),
40+
keepdims=st.booleans(),
41+
),
42+
label="kw",
43+
)
44+
45+
out = xp.argmax(x, **kw)
46+
47+
assert_default_index("argmax", out.dtype)
48+
axes = normalise_axis(kw.get("axis", None), x.ndim)
49+
assert_keepdimable_shape(
50+
"argmax", out.shape, x.shape, axes, kw.get("keepdims", False), **kw
51+
)
52+
scalar_type = dh.get_scalar_type(x.dtype)
53+
for indices, out_idx in zip(axes_ndindex(x.shape, axes), ah.ndindex(out.shape)):
54+
max_i = int(out[out_idx])
55+
elements = []
56+
for idx in indices:
57+
s = scalar_type(x[idx])
58+
elements.append(s)
59+
expected = max(range(len(elements)), key=elements.__getitem__)
60+
assert_equals("argmax", int, out_idx, max_i, expected)
61+
62+
63+
@given(
64+
x=xps.arrays(
65+
dtype=xps.numeric_dtypes(),
66+
shape=hh.shapes(min_side=1),
67+
elements={"allow_nan": False},
68+
),
69+
data=st.data(),
70+
)
71+
def test_argmin(x, data):
72+
kw = data.draw(
73+
hh.kwargs(
74+
axis=st.none() | st.integers(-x.ndim, max(x.ndim - 1, 0)),
75+
keepdims=st.booleans(),
76+
),
77+
label="kw",
78+
)
79+
80+
out = xp.argmin(x, **kw)
81+
82+
assert_default_index("argmin", out.dtype)
83+
axes = normalise_axis(kw.get("axis", None), x.ndim)
84+
assert_keepdimable_shape(
85+
"argmin", out.shape, x.shape, axes, kw.get("keepdims", False), **kw
86+
)
87+
scalar_type = dh.get_scalar_type(x.dtype)
88+
for indices, out_idx in zip(axes_ndindex(x.shape, axes), ah.ndindex(out.shape)):
89+
min_i = int(out[out_idx])
90+
elements = []
91+
for idx in indices:
92+
s = scalar_type(x[idx])
93+
elements.append(s)
94+
expected = min(range(len(elements)), key=elements.__getitem__)
95+
assert_equals("argmin", int, out_idx, min_i, expected)
2196

2297

2398
# TODO: generate kwargs, skip if opted out

0 commit comments

Comments
 (0)