Skip to content

Commit a9b191b

Browse files
committed
Test stack()
1 parent 305ad42 commit a9b191b

File tree

1 file changed

+35
-4
lines changed

1 file changed

+35
-4
lines changed

array_api_tests/test_manipulation_functions.py

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -275,15 +275,46 @@ def test_roll(x, data):
275275

276276

277277
@given(
278-
shape=hh.shapes(),
278+
shape=shared_shapes(min_dims=1),
279279
dtypes=hh.mutually_promotable_dtypes(None),
280+
kw=hh.kwargs(
281+
axis=shared_shapes(min_dims=1).flatmap(
282+
lambda s: st.integers(-len(s), len(s) - 1)
283+
)
284+
),
280285
data=st.data(),
281286
)
282-
def test_stack(shape, dtypes, data):
287+
def test_stack(shape, dtypes, kw, data):
283288
arrays = []
284289
for i, dtype in enumerate(dtypes, 1):
285290
x = data.draw(xps.arrays(dtype=dtype, shape=shape), label=f"x{i}")
286291
arrays.append(x)
287-
out = xp.stack(arrays)
292+
293+
out = xp.stack(arrays, **kw)
294+
288295
ph.assert_dtype("stack", dtypes, out.dtype)
289-
# TODO
296+
297+
axis = kw.get("axis", 0)
298+
_axis = axis if axis >= 0 else len(shape) + axis + 1
299+
_shape = list(shape)
300+
_shape.insert(_axis, len(arrays))
301+
_shape = tuple(_shape)
302+
ph.assert_result_shape(
303+
"stack", tuple(x.shape for x in arrays), out.shape, _shape, **kw
304+
)
305+
306+
# TODO: adjust indices with nonzero axis
307+
if axis == 0:
308+
out_indices = ah.ndindex(out.shape)
309+
for i, x in enumerate(arrays, 1):
310+
msg_suffix = f" [stack({ph.fmt_kw(kw)})]\nx{i}={x!r}\n{out=}"
311+
for x_idx in ah.ndindex(x.shape):
312+
out_idx = next(out_indices)
313+
msg = (
314+
f"out[{out_idx}]={out[out_idx]}, should be x{i}[{x_idx}]={x[x_idx]}"
315+
)
316+
msg += msg_suffix
317+
if dh.is_float_dtype(x.dtype) and xp.isnan(x[x_idx]):
318+
assert xp.isnan(out[out_idx]), msg
319+
else:
320+
assert out[out_idx] == x[x_idx], msg

0 commit comments

Comments
 (0)