Skip to content

Commit eb00670

Browse files
committed
Cover axis in test_concat
1 parent 2eeda9d commit eb00670

File tree

3 files changed

+79
-16
lines changed

3 files changed

+79
-16
lines changed

array_api_tests/meta/test_utils.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import pytest
22

3+
from .. import array_helpers as ah
34
from ..test_creation_functions import frange
5+
from ..test_manipulation_functions import axis_ndindex
46
from ..test_signatures import extension_module
57
from ..test_statistical_functions import axes_ndindex
68

@@ -27,6 +29,30 @@ def test_frange(r, size, elements):
2729
assert list(r) == elements
2830

2931

32+
@pytest.mark.parametrize(
33+
"shape, expected",
34+
[((), [()])],
35+
)
36+
def test_ndindex(shape, expected):
37+
assert list(ah.ndindex(shape)) == expected
38+
39+
40+
@pytest.mark.parametrize(
41+
"shape, axis, expected",
42+
[
43+
((1,), 0, [(slice(None, None),)]),
44+
((1, 2), 0, [(slice(None, None), slice(None, None))]),
45+
(
46+
(2, 4),
47+
1,
48+
[(0, slice(None, None)), (1, slice(None, None))],
49+
),
50+
],
51+
)
52+
def test_axis_ndindex(shape, axis, expected):
53+
assert list(axis_ndindex(shape, axis)) == expected
54+
55+
3056
@pytest.mark.parametrize(
3157
"shape, axes, expected",
3258
[

array_api_tests/test_manipulation_functions.py

Lines changed: 51 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import math
22
from collections import deque
3-
from typing import Iterable, Union
3+
from itertools import product
4+
from typing import Iterable, Iterator, Tuple, Union
45

56
from hypothesis import assume, given
67
from hypothesis import strategies as st
@@ -43,6 +44,28 @@ def assert_array_ndindex(
4344
assert out[out_idx] == x[x_idx], msg
4445

4546

47+
def axis_ndindex(
48+
shape: Shape, axis: int
49+
) -> Iterator[Tuple[Tuple[Union[int, slice], ...], ...]]:
50+
iterables = [range(side) for side in shape[:axis]]
51+
for _ in range(len(shape[axis:])):
52+
iterables.append([slice(None, None)])
53+
yield from product(*iterables)
54+
55+
56+
def assert_equals(
57+
func_name: str, x_repr: str, x_val: Array, out_repr: str, out_val: Array, **kw
58+
):
59+
msg = (
60+
f"{out_repr}={out_val}, should be {x_repr}={x_val} "
61+
f"[{func_name}({ph.fmt_kw(kw)})]"
62+
)
63+
if dh.is_float_dtype(out_val.dtype) and xp.isnan(out_val):
64+
assert xp.isnan(x_val), msg
65+
else:
66+
assert out_val == out_val, msg
67+
68+
4669
@st.composite
4770
def concat_shapes(draw, shape, axis):
4871
shape = list(shape)
@@ -85,21 +108,35 @@ def test_concat(dtypes, kw, data):
85108
shape = tuple(shape)
86109
ph.assert_result_shape("concat", shapes, out.shape, shape, **kw)
87110

88-
# TODO: adjust indices with nonzero axis
89-
if axis is None or axis == 0:
90-
out_indices = ah.ndindex(out.shape)
91-
for i, x in enumerate(arrays, 1):
92-
msg_suffix = f" [concat({ph.fmt_kw(kw)})]\nx{i}={x!r}\n{out=}"
111+
if axis is None:
112+
out_indices = (i for i in range(out.size))
113+
for x_num, x in enumerate(arrays, 1):
93114
for x_idx in ah.ndindex(x.shape):
94-
out_idx = next(out_indices)
95-
msg = (
96-
f"out[{out_idx}]={out[out_idx]}, should be x{i}[{x_idx}]={x[x_idx]}"
115+
out_i = next(out_indices)
116+
assert_equals(
117+
"concat",
118+
f"x{x_num}[{x_idx}]",
119+
x[x_idx],
120+
f"out[{out_i}]",
121+
out[out_i],
122+
**kw,
97123
)
98-
msg += msg_suffix
99-
if dh.is_float_dtype(x.dtype) and xp.isnan(x[x_idx]):
100-
assert xp.isnan(out[out_idx]), msg
101-
else:
102-
assert out[out_idx] == x[x_idx], msg
124+
else:
125+
out_indices = ah.ndindex(out.shape)
126+
for idx in axis_ndindex(shape, axis):
127+
f_idx = ", ".join(str(i) if isinstance(i, int) else ":" for i in idx)
128+
for x_num, x in enumerate(arrays, 1):
129+
indexed_x = x[idx]
130+
for x_idx in ah.ndindex(indexed_x.shape):
131+
out_idx = next(out_indices)
132+
assert_equals(
133+
"concat",
134+
f"x{x_num}[{f_idx}][{x_idx}]",
135+
indexed_x[x_idx],
136+
f"out[{out_idx}]",
137+
out[out_idx],
138+
**kw,
139+
)
103140

104141

105142
@given(

array_api_tests/test_statistical_functions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,11 @@ def axes_ndindex(shape: Shape, axes: Tuple[int, ...]) -> Iterator[Tuple[Shape, .
4242
axes_iterables = []
4343
for axis, side in enumerate(shape):
4444
if axis in axes:
45-
base_iterables.append((None,))
45+
base_iterables.append([None])
4646
axes_iterables.append(range(side))
4747
else:
4848
base_iterables.append(range(side))
49-
axes_iterables.append((None,))
49+
axes_iterables.append([None])
5050
for base_idx in product(*base_iterables):
5151
indices = []
5252
for idx in product(*axes_iterables):

0 commit comments

Comments
 (0)