Skip to content

Commit c2276bf

Browse files
committed
Minor test_concat improvements
1 parent b3d90a9 commit c2276bf

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

array_api_tests/test_manipulation_functions.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import math
2+
13
from hypothesis import given
24
from hypothesis import strategies as st
35

@@ -11,16 +13,23 @@
1113
@given(
1214
shape=hh.shapes(min_dims=1),
1315
dtypes=hh.mutually_promotable_dtypes(None, dtypes=dh.numeric_dtypes),
16+
kw=hh.kwargs(axis=st.just(0) | st.none()), # TODO: test with axis >= 1
1417
data=st.data(),
1518
)
16-
def test_concat(shape, dtypes, data):
19+
def test_concat(shape, dtypes, kw, data):
1720
arrays = []
1821
for i, dtype in enumerate(dtypes, 1):
1922
x = data.draw(xps.arrays(dtype=dtype, shape=shape), label=f"x{i}")
2023
arrays.append(x)
21-
out = xp.concat(arrays)
24+
out = xp.concat(arrays, **kw)
2225
ph.assert_dtype("concat", dtypes, out.dtype)
23-
# TODO
26+
shapes = tuple(x.shape for x in arrays)
27+
if kw.get("axis", 0) == 0:
28+
pass # TODO: assert expected shape
29+
elif kw["axis"] is None:
30+
size = sum(math.prod(s) for s in shapes)
31+
ph.assert_result_shape("concat", shapes, out.shape, (size,), **kw)
32+
# TODO: assert out elements match input arrays
2433

2534

2635
@given(

0 commit comments

Comments
 (0)