Skip to content

Commit 1c3bb46

Browse files
committed
Test 0 sided indexed arrays
1 parent 670ab72 commit 1c3bb46

File tree

1 file changed

+11
-5
lines changed

1 file changed

+11
-5
lines changed

xptests/test_array_object.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,29 +46,35 @@ def test_getitem(shape, data):
4646

4747
out = x[key]
4848

49+
ph.assert_dtype("__getitem__", x.dtype, out.dtype)
50+
4951
_key = tuple(key) if isinstance(key, tuple) else (key,)
5052
if Ellipsis in _key:
5153
start_a = _key.index(Ellipsis)
5254
stop_a = start_a + (len(shape) - (len(_key) - 1))
5355
slices = tuple(slice(None, None) for _ in range(start_a, stop_a))
5456
_key = _key[:start_a] + slices + _key[start_a + 1 :]
5557
axes_indices = []
58+
out_shape = []
5659
for a, i in enumerate(_key):
5760
if isinstance(i, int):
5861
axes_indices.append([i])
5962
else:
6063
side = shape[a]
6164
indices = range(side)[i]
62-
assume(len(indices) > 0) # TODO: test 0-sided arrays
6365
axes_indices.append(indices)
64-
expected = []
66+
out_shape.append(len(indices))
67+
out_shape = tuple(out_shape)
68+
ph.assert_shape("__getitem__", out.shape, out_shape)
69+
assume(all(len(indices) > 0 for indices in axes_indices))
70+
out_obj = []
6571
for idx in product(*axes_indices):
6672
val = obj
6773
for i in idx:
6874
val = val[i]
69-
expected.append(val)
70-
expected = reshape(expected, out.shape)
71-
expected = xp.asarray(expected, dtype=dtype)
75+
out_obj.append(val)
76+
out_obj = reshape(out_obj, out_shape)
77+
expected = xp.asarray(out_obj, dtype=dtype)
7278
ph.assert_array("__getitem__", out, expected)
7379

7480

0 commit comments

Comments
 (0)