|
| 1 | +import math |
| 2 | + |
1 | 3 | from hypothesis import given
|
2 | 4 | from hypothesis import strategies as st
|
3 | 5 |
|
|
11 | 13 | @given(
|
12 | 14 | shape=hh.shapes(min_dims=1),
|
13 | 15 | dtypes=hh.mutually_promotable_dtypes(None, dtypes=dh.numeric_dtypes),
|
| 16 | + kw=hh.kwargs(axis=st.just(0) | st.none()), # TODO: test with axis >= 1 |
14 | 17 | data=st.data(),
|
15 | 18 | )
|
16 |
| -def test_concat(shape, dtypes, data): |
| 19 | +def test_concat(shape, dtypes, kw, data): |
17 | 20 | arrays = []
|
18 | 21 | for i, dtype in enumerate(dtypes, 1):
|
19 | 22 | x = data.draw(xps.arrays(dtype=dtype, shape=shape), label=f"x{i}")
|
20 | 23 | arrays.append(x)
|
21 |
| - out = xp.concat(arrays) |
| 24 | + out = xp.concat(arrays, **kw) |
22 | 25 | ph.assert_dtype("concat", dtypes, out.dtype)
|
23 |
| - # TODO |
| 26 | + shapes = tuple(x.shape for x in arrays) |
| 27 | + if kw.get("axis", 0) == 0: |
| 28 | + pass # TODO: assert expected shape |
| 29 | + elif kw["axis"] is None: |
| 30 | + size = sum(math.prod(s) for s in shapes) |
| 31 | + ph.assert_result_shape("concat", shapes, out.shape, (size,), **kw) |
| 32 | + # TODO: assert out elements match input arrays |
24 | 33 |
|
25 | 34 |
|
26 | 35 | @given(
|
|
0 commit comments