9
9
from . import pytest_helpers as ph
10
10
from . import xps
11
11
12
+ shared_shapes = st .shared (hh .shapes (), key = "shape" )
13
+
12
14
13
15
@given (
14
16
shape = hh .shapes (min_dims = 1 ),
@@ -32,6 +34,81 @@ def test_concat(shape, dtypes, kw, data):
32
34
# TODO: assert out elements match input arrays
33
35
34
36
37
+ @given (
38
+ x = xps .arrays (dtype = xps .scalar_dtypes (), shape = shared_shapes ),
39
+ axis = shared_shapes .flatmap (lambda s : st .integers (- len (s ), len (s ))),
40
+ )
41
+ def test_expand_dims (x , axis ):
42
+ xp .expand_dims (x , axis = axis )
43
+ # TODO
44
+
45
+
46
+ @given (
47
+ x = xps .arrays (dtype = xps .scalar_dtypes (), shape = shared_shapes ),
48
+ kw = hh .kwargs (
49
+ axis = st .one_of (
50
+ st .none (),
51
+ shared_shapes .flatmap (
52
+ lambda s : st .none ()
53
+ if len (s ) == 0
54
+ else st .integers (- len (s ) + 1 , len (s ) - 1 ),
55
+ ),
56
+ )
57
+ ),
58
+ )
59
+ def test_flip (x , kw ):
60
+ xp .flip (x , ** kw )
61
+ # TODO
62
+
63
+
64
+ @given (
65
+ x = xps .arrays (dtype = xps .scalar_dtypes (), shape = shared_shapes ),
66
+ axes = shared_shapes .flatmap (
67
+ lambda s : st .lists (
68
+ st .integers (0 , max (len (s ) - 1 , 0 )),
69
+ min_size = len (s ),
70
+ max_size = len (s ),
71
+ unique = True ,
72
+ ).map (tuple )
73
+ ),
74
+ )
75
+ def test_permute_dims (x , axes ):
76
+ xp .permute_dims (x , axes )
77
+ # TODO
78
+
79
+
80
+ @given (
81
+ x = xps .arrays (dtype = xps .scalar_dtypes (), shape = shared_shapes ),
82
+ shape = shared_shapes , # TODO: test more compatible shapes
83
+ )
84
+ def test_reshape (x , shape ):
85
+ xp .reshape (x , shape )
86
+ # TODO
87
+
88
+
89
+ @given (
90
+ # TODO: axis arguments, update shift respectively
91
+ x = xps .arrays (dtype = xps .scalar_dtypes (), shape = shared_shapes ),
92
+ shift = shared_shapes .flatmap (lambda s : st .integers (0 , max (math .prod (s ) - 1 , 0 ))),
93
+ )
94
+ def test_roll (x , shift ):
95
+ xp .roll (x , shift )
96
+ # TODO
97
+
98
+
99
+ @given (
100
+ x = xps .arrays (dtype = xps .scalar_dtypes (), shape = shared_shapes ),
101
+ axis = shared_shapes .flatmap (
102
+ lambda s : st .just (0 )
103
+ if len (s ) == 0
104
+ else st .integers (- len (s ) + 1 , len (s ) - 1 ).filter (lambda i : s [i ] == 1 )
105
+ ), # TODO: tuple of axis i.e. axes
106
+ )
107
+ def test_squeeze (x , axis ):
108
+ xp .squeeze (x , axis )
109
+ # TODO
110
+
111
+
35
112
@given (
36
113
shape = hh .shapes (),
37
114
dtypes = hh .mutually_promotable_dtypes (None ),
0 commit comments