|
16 | 16 | from . import xps
|
17 | 17 |
|
18 | 18 |
|
19 |
| -def assert_default_index(func_name: str, dtype: DataType): |
| 19 | +def assert_default_index(func_name: str, dtype: DataType, repr_name="out.dtype"): |
20 | 20 | f_dtype = dh.dtype_to_name[dtype]
|
21 | 21 | msg = (
|
22 |
| - f"out.dtype={f_dtype}, should be the default index dtype, " |
| 22 | + f"{repr_name}={f_dtype}, should be the default index dtype, " |
23 | 23 | f"which is either int32 or int64 [{func_name}()]"
|
24 | 24 | )
|
25 | 25 | assert dtype in (xp.int32, xp.int64), msg
|
@@ -95,11 +95,42 @@ def test_argmin(x, data):
|
95 | 95 | assert_equals("argmin", int, out_idx, min_i, expected)
|
96 | 96 |
|
97 | 97 |
|
98 |
| -# TODO: generate kwargs, skip if opted out |
99 | 98 | @given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1)))
|
100 | 99 | def test_nonzero(x):
|
101 |
| - xp.nonzero(x) |
102 |
| - # TODO |
| 100 | + out = xp.nonzero(x) |
| 101 | + if x.ndim == 0: |
| 102 | + assert len(out) == 1, f"{len(out)=}, but should be 1 for 0-dimensional arrays" |
| 103 | + else: |
| 104 | + assert len(out) == x.ndim, f"{len(out)=}, but should be {x.ndim=}" |
| 105 | + size = out[0].size |
| 106 | + for i in range(len(out)): |
| 107 | + assert out[i].ndim == 1, f"out[{i}].ndim={x.ndim}, but should be 1" |
| 108 | + assert ( |
| 109 | + out[i].size == size |
| 110 | + ), f"out[{i}].size={x.size}, but should be out[0].size={size}" |
| 111 | + assert_default_index("nonzero", out[i].dtype, repr_name=f"out[{i}].dtype") |
| 112 | + indices = [] |
| 113 | + if x.dtype == xp.bool: |
| 114 | + for idx in ah.ndindex(x.shape): |
| 115 | + if x[idx]: |
| 116 | + indices.append(idx) |
| 117 | + else: |
| 118 | + for idx in ah.ndindex(x.shape): |
| 119 | + if x[idx] != 0: |
| 120 | + indices.append(idx) |
| 121 | + if x.ndim == 0: |
| 122 | + assert out[0].size == len( |
| 123 | + indices |
| 124 | + ), f"{out[0].size=}, but should be {len(indices)}" |
| 125 | + else: |
| 126 | + for i in range(size): |
| 127 | + idx = tuple(int(x[i]) for x in out) |
| 128 | + f_idx = f"Extrapolated index (x[{i}] for x in out)={idx}" |
| 129 | + f_element = f"x[{idx}]={x[idx]}" |
| 130 | + assert idx in indices, f"{f_idx} results in {f_element}, a zero element" |
| 131 | + assert ( |
| 132 | + idx == indices[i] |
| 133 | + ), f"{f_idx} is in the wrong position, should be {indices.index(idx)}" |
103 | 134 |
|
104 | 135 |
|
105 | 136 | # TODO: skip if opted out
|
|
0 commit comments