Skip to content

Commit 54561bc

Browse files
committed
Cover all axis scenarios in test_stack
1 parent 326eff8 commit 54561bc

File tree

1 file changed

+26
-18
lines changed

1 file changed

+26
-18
lines changed

array_api_tests/test_manipulation_functions.py

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import math
22
from collections import deque
33
from itertools import product
4-
from typing import Iterable, Union
4+
from typing import Iterable, Iterator, Tuple, Union
55

66
from hypothesis import assume, given
77
from hypothesis import strategies as st
@@ -28,6 +28,16 @@ def shared_shapes(*args, **kwargs) -> st.SearchStrategy[Shape]:
2828
return st.shared(hh.shapes(*args, **kwargs), key="shape")
2929

3030

31+
def axis_ndindex(
32+
shape: Shape, axis: int
33+
) -> Iterator[Tuple[Tuple[Union[int, slice], ...], ...]]:
34+
assert axis >= 0 # sanity check
35+
axis_indices = [range(side) for side in shape[:axis]]
36+
for _ in range(axis, len(shape)):
37+
axis_indices.append([slice(None, None)])
38+
yield from product(*axis_indices)
39+
40+
3141
def assert_array_ndindex(
3242
func_name: str,
3343
x: Array,
@@ -115,10 +125,7 @@ def test_concat(dtypes, kw, data):
115125
)
116126
else:
117127
out_indices = ah.ndindex(out.shape)
118-
axis_indices = [range(side) for side in shapes[0][:_axis]]
119-
for _ in range(_axis, len(shape)):
120-
axis_indices.append([slice(None, None)])
121-
for idx in product(*axis_indices):
128+
for idx in axis_ndindex(shapes[0], _axis):
122129
f_idx = ", ".join(str(i) if isinstance(i, int) else ":" for i in idx)
123130
for x_num, x in enumerate(arrays, 1):
124131
indexed_x = x[idx]
@@ -344,18 +351,19 @@ def test_stack(shape, dtypes, kw, data):
344351
"stack", tuple(x.shape for x in arrays), out.shape, _shape, **kw
345352
)
346353

347-
# TODO: adjust indices with nonzero axis
348-
if axis == 0:
349-
out_indices = ah.ndindex(out.shape)
350-
for i, x in enumerate(arrays, 1):
351-
msg_suffix = f" [stack({ph.fmt_kw(kw)})]\nx{i}={x!r}\n{out=}"
352-
for x_idx in ah.ndindex(x.shape):
354+
out_indices = ah.ndindex(out.shape)
355+
for idx in axis_ndindex(arrays[0].shape, axis=_axis):
356+
f_idx = ", ".join(str(i) if isinstance(i, int) else ":" for i in idx)
357+
print(f"{f_idx=}")
358+
for x_num, x in enumerate(arrays, 1):
359+
indexed_x = x[idx]
360+
for x_idx in ah.ndindex(indexed_x.shape):
353361
out_idx = next(out_indices)
354-
msg = (
355-
f"out[{out_idx}]={out[out_idx]}, should be x{i}[{x_idx}]={x[x_idx]}"
362+
assert_equals(
363+
"stack",
364+
f"x{x_num}[{f_idx}][{x_idx}]",
365+
indexed_x[x_idx],
366+
f"out[{out_idx}]",
367+
out[out_idx],
368+
**kw,
356369
)
357-
msg += msg_suffix
358-
if dh.is_float_dtype(x.dtype) and xp.isnan(x[x_idx]):
359-
assert xp.isnan(out[out_idx]), msg
360-
else:
361-
assert out[out_idx] == x[x_idx], msg

0 commit comments

Comments
 (0)