Skip to content

Commit 0a0b5c0

Browse files
committed
Cover everything for test_nonzero
1 parent 064aae6 commit 0a0b5c0

File tree

1 file changed

+36
-5
lines changed

1 file changed

+36
-5
lines changed

array_api_tests/test_searching_functions.py

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@
1616
from . import xps
1717

1818

19-
def assert_default_index(func_name: str, dtype: DataType):
19+
def assert_default_index(func_name: str, dtype: DataType, repr_name="out.dtype"):
2020
f_dtype = dh.dtype_to_name[dtype]
2121
msg = (
22-
f"out.dtype={f_dtype}, should be the default index dtype, "
22+
f"{repr_name}={f_dtype}, should be the default index dtype, "
2323
f"which is either int32 or int64 [{func_name}()]"
2424
)
2525
assert dtype in (xp.int32, xp.int64), msg
@@ -95,11 +95,42 @@ def test_argmin(x, data):
9595
assert_equals("argmin", int, out_idx, min_i, expected)
9696

9797

98-
# TODO: generate kwargs, skip if opted out
9998
@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1)))
10099
def test_nonzero(x):
101-
xp.nonzero(x)
102-
# TODO
100+
out = xp.nonzero(x)
101+
if x.ndim == 0:
102+
assert len(out) == 1, f"{len(out)=}, but should be 1 for 0-dimensional arrays"
103+
else:
104+
assert len(out) == x.ndim, f"{len(out)=}, but should be {x.ndim=}"
105+
size = out[0].size
106+
for i in range(len(out)):
107+
assert out[i].ndim == 1, f"out[{i}].ndim={x.ndim}, but should be 1"
108+
assert (
109+
out[i].size == size
110+
), f"out[{i}].size={x.size}, but should be out[0].size={size}"
111+
assert_default_index("nonzero", out[i].dtype, repr_name=f"out[{i}].dtype")
112+
indices = []
113+
if x.dtype == xp.bool:
114+
for idx in ah.ndindex(x.shape):
115+
if x[idx]:
116+
indices.append(idx)
117+
else:
118+
for idx in ah.ndindex(x.shape):
119+
if x[idx] != 0:
120+
indices.append(idx)
121+
if x.ndim == 0:
122+
assert out[0].size == len(
123+
indices
124+
), f"{out[0].size=}, but should be {len(indices)}"
125+
else:
126+
for i in range(size):
127+
idx = tuple(int(x[i]) for x in out)
128+
f_idx = f"Extrapolated index (x[{i}] for x in out)={idx}"
129+
f_element = f"x[{idx}]={x[idx]}"
130+
assert idx in indices, f"{f_idx} results in {f_element}, a zero element"
131+
assert (
132+
idx == indices[i]
133+
), f"{f_idx} is in the wrong position, should be {indices.index(idx)}"
103134

104135

105136
# TODO: skip if opted out

0 commit comments

Comments
 (0)