@@ -129,12 +129,11 @@ def test_expand_dims(x, axis):
129
129
data = st .data (),
130
130
)
131
131
def test_squeeze (x , data ):
132
- # axis=shared_shapes(min_side=1).flatmap(lambda s: nd_axes(len(s))),
132
+ # TODO: generate valid negative axis (which keep uniqueness)
133
133
squeezable_axes = st .sampled_from (
134
134
[i for i , side in enumerate (x .shape ) if side == 1 ]
135
135
)
136
136
axis = data .draw (
137
- # TODO: generate valid negative axis
138
137
squeezable_axes | st .lists (squeezable_axes , unique = True ).map (tuple ),
139
138
label = "axis" ,
140
139
)
@@ -157,20 +156,19 @@ def test_squeeze(x, data):
157
156
assert_array_ndindex ("squeeze" , x , ah .ndindex (x .shape ), out , ah .ndindex (out .shape ))
158
157
159
158
160
- @st .composite
161
- def flip_axis (draw , shape ):
162
- if len (shape ) == 0 or draw (st .booleans ()):
163
- return None
164
- else :
165
- ndim = len (shape )
166
- return draw (st .integers (- ndim , ndim - 1 ) | xps .valid_tuple_axes (ndim ))
167
-
168
-
169
159
@given (
170
- x = xps .arrays (dtype = xps .scalar_dtypes (), shape = shared_shapes ()),
171
- kw = hh . kwargs ( axis = shared_shapes (). flatmap ( flip_axis ) ),
160
+ x = xps .arrays (dtype = xps .scalar_dtypes (), shape = hh . shapes ()),
161
+ data = st . data ( ),
172
162
)
173
- def test_flip (x , kw ):
163
+ def test_flip (x , data ):
164
+ if x .ndim == 0 :
165
+ axis_strat = st .none ()
166
+ else :
167
+ axis_strat = (
168
+ st .none () | st .integers (- x .ndim , x .ndim - 1 ) | xps .valid_tuple_axes (x .ndim )
169
+ )
170
+ kw = data .draw (hh .kwargs (axis = axis_strat ), label = "kw" )
171
+
174
172
out = xp .flip (x , ** kw )
175
173
176
174
ph .assert_dtype ("flip" , x .dtype , out .dtype )
@@ -209,12 +207,6 @@ def test_permute_dims(x, axes):
209
207
# TODO: test elements
210
208
211
209
212
- reshape_x_shapes = st .shared (
213
- hh .shapes ().filter (lambda s : math .prod (s ) <= MAX_SIDE ),
214
- key = "reshape x shape" ,
215
- )
216
-
217
-
218
210
@st .composite
219
211
def reshape_shapes (draw , shape ):
220
212
size = 1 if len (shape ) == 0 else math .prod (shape )
@@ -227,21 +219,22 @@ def reshape_shapes(draw, shape):
227
219
228
220
229
221
@given (
230
- x = xps .arrays (dtype = xps .scalar_dtypes (), shape = reshape_x_shapes ),
231
- shape = reshape_x_shapes . flatmap ( reshape_shapes ),
222
+ x = xps .arrays (dtype = xps .scalar_dtypes (), shape = hh . shapes ( max_side = MAX_SIDE ) ),
223
+ data = st . data ( ),
232
224
)
233
- def test_reshape (x , shape ):
234
- assume ( math . prod ( shape ) == math . prod (x .shape ))
225
+ def test_reshape (x , data ):
226
+ shape = data . draw ( reshape_shapes (x .shape ))
235
227
236
228
out = xp .reshape (x , shape )
237
229
238
230
ph .assert_dtype ("reshape" , x .dtype , out .dtype )
239
231
240
- _shape = shape
232
+ _shape = list ( shape )
241
233
if any (side == - 1 for side in shape ):
242
234
size = math .prod (x .shape )
243
235
rsize = math .prod (shape ) * - 1
244
236
_shape [shape .index (- 1 )] = size / rsize
237
+ _shape = tuple (_shape )
245
238
ph .assert_result_shape ("reshape" , (x .shape ,), out .shape , _shape , shape = shape )
246
239
247
240
assert_array_ndindex ("reshape" , x , ah .ndindex (x .shape ), out , ah .ndindex (out .shape ))
0 commit comments