Skip to content

Commit 42df02e

Browse files
committed
Smoke all manipulation methods
1 parent c2276bf commit 42df02e

File tree

1 file changed

+77
-0
lines changed

1 file changed

+77
-0
lines changed

array_api_tests/test_manipulation_functions.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
from . import pytest_helpers as ph
1010
from . import xps
1111

12+
shared_shapes = st.shared(hh.shapes(), key="shape")
13+
1214

1315
@given(
1416
shape=hh.shapes(min_dims=1),
@@ -32,6 +34,81 @@ def test_concat(shape, dtypes, kw, data):
3234
# TODO: assert out elements match input arrays
3335

3436

37+
@given(
38+
x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes),
39+
axis=shared_shapes.flatmap(lambda s: st.integers(-len(s), len(s))),
40+
)
41+
def test_expand_dims(x, axis):
42+
xp.expand_dims(x, axis=axis)
43+
# TODO
44+
45+
46+
@given(
47+
x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes),
48+
kw=hh.kwargs(
49+
axis=st.one_of(
50+
st.none(),
51+
shared_shapes.flatmap(
52+
lambda s: st.none()
53+
if len(s) == 0
54+
else st.integers(-len(s) + 1, len(s) - 1),
55+
),
56+
)
57+
),
58+
)
59+
def test_flip(x, kw):
60+
xp.flip(x, **kw)
61+
# TODO
62+
63+
64+
@given(
65+
x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes),
66+
axes=shared_shapes.flatmap(
67+
lambda s: st.lists(
68+
st.integers(0, max(len(s) - 1, 0)),
69+
min_size=len(s),
70+
max_size=len(s),
71+
unique=True,
72+
).map(tuple)
73+
),
74+
)
75+
def test_permute_dims(x, axes):
76+
xp.permute_dims(x, axes)
77+
# TODO
78+
79+
80+
@given(
81+
x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes),
82+
shape=shared_shapes, # TODO: test more compatible shapes
83+
)
84+
def test_reshape(x, shape):
85+
xp.reshape(x, shape)
86+
# TODO
87+
88+
89+
@given(
90+
# TODO: axis arguments, update shift respectively
91+
x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes),
92+
shift=shared_shapes.flatmap(lambda s: st.integers(0, max(math.prod(s) - 1, 0))),
93+
)
94+
def test_roll(x, shift):
95+
xp.roll(x, shift)
96+
# TODO
97+
98+
99+
@given(
100+
x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes),
101+
axis=shared_shapes.flatmap(
102+
lambda s: st.just(0)
103+
if len(s) == 0
104+
else st.integers(-len(s) + 1, len(s) - 1).filter(lambda i: s[i] == 1)
105+
), # TODO: tuple of axis i.e. axes
106+
)
107+
def test_squeeze(x, axis):
108+
xp.squeeze(x, axis)
109+
# TODO
110+
111+
35112
@given(
36113
shape=hh.shapes(),
37114
dtypes=hh.mutually_promotable_dtypes(None),

0 commit comments

Comments
 (0)