Skip to content

Commit 041717e

Browse files
committed
add a test for count_nonzero
Parrot the test from test_{argmin, argmax}
1 parent 28f1dbf commit 041717e

File tree

1 file changed

+38
-0
lines changed

1 file changed

+38
-0
lines changed

array_api_tests/test_searching_functions.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,44 @@ def test_argmin(x, data):
8888
ph.assert_scalar_equals("argmin", type_=int, idx=out_idx, out=min_i, expected=expected)
8989

9090

91+
@pytest.mark.min_version("2024.12")
92+
@given(
93+
x=hh.arrays(
94+
dtype=hh.real_dtypes,
95+
shape=hh.shapes(min_dims=1, min_side=1),
96+
elements={"allow_nan": False},
97+
),
98+
data=st.data(),
99+
)
100+
def test_count_nonzero(x, data):
101+
kw = data.draw(
102+
hh.kwargs(
103+
axis=st.none() | st.integers(-x.ndim, max(x.ndim - 1, 0)),
104+
keepdims=st.booleans(),
105+
),
106+
label="kw",
107+
)
108+
keepdims = kw.get("keepdims", False)
109+
110+
out = xp.count_nonzero(x, **kw)
111+
112+
ph.assert_default_index("count_nonzero", out.dtype)
113+
axes = sh.normalize_axis(kw.get("axis", None), x.ndim)
114+
ph.assert_keepdimable_shape(
115+
"count_nonzero", in_shape=x.shape, out_shape=out.shape, axes=axes, keepdims=keepdims, kw=kw
116+
)
117+
scalar_type = dh.get_scalar_type(x.dtype)
118+
119+
for indices, out_idx in zip(sh.axes_ndindex(x.shape, axes), sh.ndindex(out.shape)):
120+
count = int(out[out_idx])
121+
elements = []
122+
for idx in indices:
123+
s = scalar_type(x[idx])
124+
elements.append(s)
125+
expected = sum(el != 0 for el in elements)
126+
ph.assert_scalar_equals("count_nonzero", type_=int, idx=out_idx, out=count, expected=expected)
127+
128+
91129
@given(hh.arrays(dtype=hh.all_dtypes, shape=()))
92130
def test_nonzero_zerodim_error(x):
93131
with pytest.raises(Exception):

0 commit comments

Comments
 (0)