Skip to content

Commit 305ad42

Browse files
committed
Test concat()
1 parent 0b36ade commit 305ad42

File tree

1 file changed

+51
-12
lines changed

1 file changed

+51
-12
lines changed

array_api_tests/test_manipulation_functions.py

Lines changed: 51 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
from . import xps
1414
from .typing import Array, Shape
1515

16+
MAX_SIDE = hh.MAX_ARRAY_SIZE // 64
17+
MAX_DIMS = min(hh.MAX_ARRAY_SIZE // MAX_SIDE, 32) # NumPy only supports up to 32 dims
18+
1619

1720
def shared_shapes(*args, **kwargs) -> st.SearchStrategy[Shape]:
1821
key = "shape"
@@ -40,26 +43,63 @@ def assert_array_ndindex(
4043
assert out[out_idx] == x[x_idx], msg
4144

4245

46+
@st.composite
47+
def concat_shapes(draw, shape, axis):
48+
shape = list(shape)
49+
shape[axis] = draw(st.integers(1, MAX_SIDE))
50+
return tuple(shape)
51+
52+
4353
@given(
44-
shape=hh.shapes(min_dims=1),
4554
dtypes=hh.mutually_promotable_dtypes(None, dtypes=dh.numeric_dtypes),
46-
kw=hh.kwargs(axis=st.just(0) | st.none()), # TODO: test with axis >= 1
55+
kw=hh.kwargs(axis=st.none() | st.integers(-MAX_DIMS, MAX_DIMS - 1)),
4756
data=st.data(),
4857
)
49-
def test_concat(shape, dtypes, kw, data):
58+
def test_concat(dtypes, kw, data):
59+
axis = kw.get("axis", 0)
60+
if axis is None:
61+
shape_strat = hh.shapes()
62+
else:
63+
_axis = axis if axis >= 0 else abs(axis) - 1
64+
shape_strat = shared_shapes(min_dims=_axis + 1).flatmap(
65+
lambda s: concat_shapes(s, axis)
66+
)
5067
arrays = []
5168
for i, dtype in enumerate(dtypes, 1):
52-
x = data.draw(xps.arrays(dtype=dtype, shape=shape), label=f"x{i}")
69+
x = data.draw(xps.arrays(dtype=dtype, shape=shape_strat), label=f"x{i}")
5370
arrays.append(x)
71+
5472
out = xp.concat(arrays, **kw)
73+
5574
ph.assert_dtype("concat", dtypes, out.dtype)
75+
5676
shapes = tuple(x.shape for x in arrays)
57-
if kw.get("axis", 0) == 0:
58-
pass # TODO: assert expected shape
59-
elif kw["axis"] is None:
77+
axis = kw.get("axis", 0)
78+
if axis is None:
6079
size = sum(math.prod(s) for s in shapes)
61-
ph.assert_result_shape("concat", shapes, out.shape, (size,), **kw)
62-
# TODO: assert out elements match input arrays
80+
shape = (size,)
81+
else:
82+
shape = list(shapes[0])
83+
for other_shape in shapes[1:]:
84+
shape[axis] += other_shape[axis]
85+
shape = tuple(shape)
86+
ph.assert_result_shape("concat", shapes, out.shape, shape, **kw)
87+
88+
# TODO: adjust indices with nonzero axis
89+
if axis is None or axis == 0:
90+
out_indices = ah.ndindex(out.shape)
91+
for i, x in enumerate(arrays, 1):
92+
msg_suffix = f" [concat({ph.fmt_kw(kw)})]\nx{i}={x!r}\n{out=}"
93+
for x_idx in ah.ndindex(x.shape):
94+
out_idx = next(out_indices)
95+
msg = (
96+
f"out[{out_idx}]={out[out_idx]}, should be x{i}[{x_idx}]={x[x_idx]}"
97+
)
98+
msg += msg_suffix
99+
if dh.is_float_dtype(x.dtype) and xp.isnan(x[x_idx]):
100+
assert xp.isnan(out[out_idx]), msg
101+
else:
102+
assert out[out_idx] == x[x_idx], msg
63103

64104

65105
@given(
@@ -169,9 +209,8 @@ def test_permute_dims(x, axes):
169209
# TODO: test elements
170210

171211

172-
MAX_RESHAPE_SIDE = hh.MAX_ARRAY_SIZE // 64
173212
reshape_x_shapes = st.shared(
174-
hh.shapes().filter(lambda s: math.prod(s) <= MAX_RESHAPE_SIDE),
213+
hh.shapes().filter(lambda s: math.prod(s) <= MAX_SIDE),
175214
key="reshape x shape",
176215
)
177216

@@ -180,7 +219,7 @@ def test_permute_dims(x, axes):
180219
def reshape_shapes(draw, shape):
181220
size = 1 if len(shape) == 0 else math.prod(shape)
182221
rshape = draw(st.lists(st.integers(0)).filter(lambda s: math.prod(s) == size))
183-
assume(all(side <= MAX_RESHAPE_SIDE for side in rshape))
222+
assume(all(side <= MAX_SIDE for side in rshape))
184223
if len(rshape) != 0 and size > 0 and draw(st.booleans()):
185224
index = draw(st.integers(0, len(rshape) - 1))
186225
rshape[index] = -1

0 commit comments

Comments
 (0)