Skip to content

Commit 3f28c60

Browse files
committed
Cover invalid axis in test_expand_dims
1 parent 6ff33c9 commit 3f28c60

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

array_api_tests/test_manipulation_functions.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,9 +144,17 @@ def test_concat(dtypes, kw, data):
144144

145145
@given(
146146
x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes()),
147-
axis=shared_shapes().flatmap(lambda s: st.integers(-len(s) - 1, len(s))),
147+
axis=shared_shapes().flatmap(
148+
# Generate both valid and invalid axis
149+
lambda s: st.integers(2 * (-len(s) - 1), 2 * len(s))
150+
),
148151
)
149152
def test_expand_dims(x, axis):
153+
if axis < -x.ndim - 1 or axis > x.ndim:
154+
with pytest.raises(IndexError):
155+
xp.expand_dims(x, axis=axis)
156+
return
157+
150158
out = xp.expand_dims(x, axis=axis)
151159

152160
ph.assert_dtype("expand_dims", x.dtype, out.dtype)

0 commit comments

Comments
 (0)