Skip to content

Commit 6ff33c9

Browse files
committed
Check error raising in test_squeeze, use negative axes
1 parent 567686a commit 6ff33c9

File tree

1 file changed

+13
-9
lines changed

1 file changed

+13
-9
lines changed

array_api_tests/test_manipulation_functions.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from itertools import product
44
from typing import Iterable, Iterator, Tuple, Union
55

6+
import pytest
67
from hypothesis import assume, given
78
from hypothesis import strategies as st
89

@@ -168,23 +169,26 @@ def test_expand_dims(x, axis):
168169
data=st.data(),
169170
)
170171
def test_squeeze(x, data):
171-
# TODO: generate valid negative axis (which keep uniqueness)
172-
squeezable_axes = st.sampled_from(
173-
[i for i, side in enumerate(x.shape) if side == 1]
174-
)
172+
axes = st.integers(-x.ndim, x.ndim - 1)
175173
axis = data.draw(
176-
squeezable_axes | st.lists(squeezable_axes, unique=True).map(tuple),
174+
axes
175+
| st.lists(axes, unique_by=lambda i: i if i >= 0 else i + x.ndim).map(tuple),
177176
label="axis",
178177
)
179178

179+
axes = (axis,) if isinstance(axis, int) else axis
180+
axes = normalise_axis(axes, x.ndim)
181+
182+
squeezable_axes = [i for i, side in enumerate(x.shape) if side == 1]
183+
if any(i not in squeezable_axes for i in axes):
184+
with pytest.raises(ValueError):
185+
xp.squeeze(x, axis)
186+
return
187+
180188
out = xp.squeeze(x, axis)
181189

182190
ph.assert_dtype("squeeze", x.dtype, out.dtype)
183191

184-
if isinstance(axis, int):
185-
axes = (axis,)
186-
else:
187-
axes = axis
188192
shape = []
189193
for i, side in enumerate(x.shape):
190194
if i not in axes:

0 commit comments

Comments
 (0)