|
10 | 10 | from . import dtype_helpers as dh
|
11 | 11 | from . import hypothesis_helpers as hh
|
12 | 12 | from . import pytest_helpers as ph
|
| 13 | +from . import shape_helpers as sh |
13 | 14 | from . import xps
|
14 | 15 | from .typing import DataType, Param, Scalar, ScalarType, Shape
|
15 | 16 |
|
@@ -140,12 +141,26 @@ def test_getitem_mask(shape, data):
|
140 | 141 | out_shape = (size,) + x.shape[key.ndim :]
|
141 | 142 | ph.assert_shape("__getitem__", out.shape, out_shape)
|
142 | 143 |
|
| 144 | + if not any(s == 0 for s in key.shape): |
| 145 | + assume(key.ndim == x.ndim) # TODO: test key.ndim < x.ndim scenarios |
| 146 | + out_indices = sh.ndindex(out.shape) |
| 147 | + for x_idx in sh.ndindex(x.shape): |
| 148 | + if key[x_idx]: |
| 149 | + out_idx = next(out_indices) |
| 150 | + ph.assert_0d_equals( |
| 151 | + "__getitem__", |
| 152 | + f"x[{x_idx}]", |
| 153 | + x[x_idx], |
| 154 | + f"out[{out_idx}]", |
| 155 | + out[out_idx], |
| 156 | + ) |
| 157 | + |
143 | 158 |
|
144 | 159 | @given(hh.shapes(min_side=1), st.data())
|
145 | 160 | def test_setitem_mask(shape, data):
|
146 | 161 | x = data.draw(xps.arrays(xps.scalar_dtypes(), shape=shape), label="x")
|
147 | 162 | key = data.draw(xps.arrays(dtype=xp.bool, shape=shape), label="key")
|
148 |
| - value = data.draw(xps.from_dtype(x.dtype), label="value") |
| 163 | + value = data.draw(xps.from_dtype(x.dtype), label="value") # TODO: more values |
149 | 164 |
|
150 | 165 | res = xp.asarray(x, copy=True)
|
151 | 166 | res[key] = value
|
|
0 commit comments