Skip to content

Commit 9a0ccf8

Browse files
committed
Manipulation tests clean up
1 parent a9b191b commit 9a0ccf8

File tree

1 file changed

+18
-25
lines changed

1 file changed

+18
-25
lines changed

array_api_tests/test_manipulation_functions.py

Lines changed: 18 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -129,12 +129,11 @@ def test_expand_dims(x, axis):
129129
data=st.data(),
130130
)
131131
def test_squeeze(x, data):
132-
# axis=shared_shapes(min_side=1).flatmap(lambda s: nd_axes(len(s))),
132+
# TODO: generate valid negative axis (which keep uniqueness)
133133
squeezable_axes = st.sampled_from(
134134
[i for i, side in enumerate(x.shape) if side == 1]
135135
)
136136
axis = data.draw(
137-
# TODO: generate valid negative axis
138137
squeezable_axes | st.lists(squeezable_axes, unique=True).map(tuple),
139138
label="axis",
140139
)
@@ -157,20 +156,19 @@ def test_squeeze(x, data):
157156
assert_array_ndindex("squeeze", x, ah.ndindex(x.shape), out, ah.ndindex(out.shape))
158157

159158

160-
@st.composite
161-
def flip_axis(draw, shape):
162-
if len(shape) == 0 or draw(st.booleans()):
163-
return None
164-
else:
165-
ndim = len(shape)
166-
return draw(st.integers(-ndim, ndim - 1) | xps.valid_tuple_axes(ndim))
167-
168-
169159
@given(
170-
x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes()),
171-
kw=hh.kwargs(axis=shared_shapes().flatmap(flip_axis)),
160+
x=xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes()),
161+
data=st.data(),
172162
)
173-
def test_flip(x, kw):
163+
def test_flip(x, data):
164+
if x.ndim == 0:
165+
axis_strat = st.none()
166+
else:
167+
axis_strat = (
168+
st.none() | st.integers(-x.ndim, x.ndim - 1) | xps.valid_tuple_axes(x.ndim)
169+
)
170+
kw = data.draw(hh.kwargs(axis=axis_strat), label="kw")
171+
174172
out = xp.flip(x, **kw)
175173

176174
ph.assert_dtype("flip", x.dtype, out.dtype)
@@ -209,12 +207,6 @@ def test_permute_dims(x, axes):
209207
# TODO: test elements
210208

211209

212-
reshape_x_shapes = st.shared(
213-
hh.shapes().filter(lambda s: math.prod(s) <= MAX_SIDE),
214-
key="reshape x shape",
215-
)
216-
217-
218210
@st.composite
219211
def reshape_shapes(draw, shape):
220212
size = 1 if len(shape) == 0 else math.prod(shape)
@@ -227,21 +219,22 @@ def reshape_shapes(draw, shape):
227219

228220

229221
@given(
230-
x=xps.arrays(dtype=xps.scalar_dtypes(), shape=reshape_x_shapes),
231-
shape=reshape_x_shapes.flatmap(reshape_shapes),
222+
x=xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(max_side=MAX_SIDE)),
223+
data=st.data(),
232224
)
233-
def test_reshape(x, shape):
234-
assume(math.prod(shape) == math.prod(x.shape))
225+
def test_reshape(x, data):
226+
shape = data.draw(reshape_shapes(x.shape))
235227

236228
out = xp.reshape(x, shape)
237229

238230
ph.assert_dtype("reshape", x.dtype, out.dtype)
239231

240-
_shape = shape
232+
_shape = list(shape)
241233
if any(side == -1 for side in shape):
242234
size = math.prod(x.shape)
243235
rsize = math.prod(shape) * -1
244236
_shape[shape.index(-1)] = size / rsize
237+
_shape = tuple(_shape)
245238
ph.assert_result_shape("reshape", (x.shape,), out.shape, _shape, shape=shape)
246239

247240
assert_array_ndindex("reshape", x, ah.ndindex(x.shape), out, ah.ndindex(out.shape))

0 commit comments

Comments
 (0)