@@ -275,15 +275,46 @@ def test_roll(x, data):
275
275
276
276
277
277
@given (
278
- shape = hh . shapes ( ),
278
+ shape = shared_shapes ( min_dims = 1 ),
279
279
dtypes = hh .mutually_promotable_dtypes (None ),
280
+ kw = hh .kwargs (
281
+ axis = shared_shapes (min_dims = 1 ).flatmap (
282
+ lambda s : st .integers (- len (s ), len (s ) - 1 )
283
+ )
284
+ ),
280
285
data = st .data (),
281
286
)
282
- def test_stack (shape , dtypes , data ):
287
+ def test_stack (shape , dtypes , kw , data ):
283
288
arrays = []
284
289
for i , dtype in enumerate (dtypes , 1 ):
285
290
x = data .draw (xps .arrays (dtype = dtype , shape = shape ), label = f"x{ i } " )
286
291
arrays .append (x )
287
- out = xp .stack (arrays )
292
+
293
+ out = xp .stack (arrays , ** kw )
294
+
288
295
ph .assert_dtype ("stack" , dtypes , out .dtype )
289
- # TODO
296
+
297
+ axis = kw .get ("axis" , 0 )
298
+ _axis = axis if axis >= 0 else len (shape ) + axis + 1
299
+ _shape = list (shape )
300
+ _shape .insert (_axis , len (arrays ))
301
+ _shape = tuple (_shape )
302
+ ph .assert_result_shape (
303
+ "stack" , tuple (x .shape for x in arrays ), out .shape , _shape , ** kw
304
+ )
305
+
306
+ # TODO: adjust indices with nonzero axis
307
+ if axis == 0 :
308
+ out_indices = ah .ndindex (out .shape )
309
+ for i , x in enumerate (arrays , 1 ):
310
+ msg_suffix = f" [stack({ ph .fmt_kw (kw )} )]\n x{ i } ={ x !r} \n { out = } "
311
+ for x_idx in ah .ndindex (x .shape ):
312
+ out_idx = next (out_indices )
313
+ msg = (
314
+ f"out[{ out_idx } ]={ out [out_idx ]} , should be x{ i } [{ x_idx } ]={ x [x_idx ]} "
315
+ )
316
+ msg += msg_suffix
317
+ if dh .is_float_dtype (x .dtype ) and xp .isnan (x [x_idx ]):
318
+ assert xp .isnan (out [out_idx ]), msg
319
+ else :
320
+ assert out [out_idx ] == x [x_idx ], msg
0 commit comments