|
3 | 3 | from itertools import product
|
4 | 4 | from typing import Iterable, Iterator, Tuple, Union
|
5 | 5 |
|
| 6 | +import pytest |
6 | 7 | from hypothesis import assume, given
|
7 | 8 | from hypothesis import strategies as st
|
8 | 9 |
|
@@ -168,23 +169,26 @@ def test_expand_dims(x, axis):
|
168 | 169 | data=st.data(),
|
169 | 170 | )
|
170 | 171 | def test_squeeze(x, data):
|
171 |
| - # TODO: generate valid negative axis (which keep uniqueness) |
172 |
| - squeezable_axes = st.sampled_from( |
173 |
| - [i for i, side in enumerate(x.shape) if side == 1] |
174 |
| - ) |
| 172 | + axes = st.integers(-x.ndim, x.ndim - 1) |
175 | 173 | axis = data.draw(
|
176 |
| - squeezable_axes | st.lists(squeezable_axes, unique=True).map(tuple), |
| 174 | + axes |
| 175 | + | st.lists(axes, unique_by=lambda i: i if i >= 0 else i + x.ndim).map(tuple), |
177 | 176 | label="axis",
|
178 | 177 | )
|
179 | 178 |
|
| 179 | + axes = (axis,) if isinstance(axis, int) else axis |
| 180 | + axes = normalise_axis(axes, x.ndim) |
| 181 | + |
| 182 | + squeezable_axes = [i for i, side in enumerate(x.shape) if side == 1] |
| 183 | + if any(i not in squeezable_axes for i in axes): |
| 184 | + with pytest.raises(ValueError): |
| 185 | + xp.squeeze(x, axis) |
| 186 | + return |
| 187 | + |
180 | 188 | out = xp.squeeze(x, axis)
|
181 | 189 |
|
182 | 190 | ph.assert_dtype("squeeze", x.dtype, out.dtype)
|
183 | 191 |
|
184 |
| - if isinstance(axis, int): |
185 |
| - axes = (axis,) |
186 |
| - else: |
187 |
| - axes = axis |
188 | 192 | shape = []
|
189 | 193 | for i, side in enumerate(x.shape):
|
190 | 194 | if i not in axes:
|
|
0 commit comments