1
1
import math
2
2
from collections import deque
3
3
from itertools import product
4
- from typing import Iterable , Union
4
+ from typing import Iterable , Iterator , Tuple , Union
5
5
6
6
from hypothesis import assume , given
7
7
from hypothesis import strategies as st
@@ -28,6 +28,16 @@ def shared_shapes(*args, **kwargs) -> st.SearchStrategy[Shape]:
28
28
return st .shared (hh .shapes (* args , ** kwargs ), key = "shape" )
29
29
30
30
31
+ def axis_ndindex (
32
+ shape : Shape , axis : int
33
+ ) -> Iterator [Tuple [Tuple [Union [int , slice ], ...], ...]]:
34
+ assert axis >= 0 # sanity check
35
+ axis_indices = [range (side ) for side in shape [:axis ]]
36
+ for _ in range (axis , len (shape )):
37
+ axis_indices .append ([slice (None , None )])
38
+ yield from product (* axis_indices )
39
+
40
+
31
41
def assert_array_ndindex (
32
42
func_name : str ,
33
43
x : Array ,
@@ -115,10 +125,7 @@ def test_concat(dtypes, kw, data):
115
125
)
116
126
else :
117
127
out_indices = ah .ndindex (out .shape )
118
- axis_indices = [range (side ) for side in shapes [0 ][:_axis ]]
119
- for _ in range (_axis , len (shape )):
120
- axis_indices .append ([slice (None , None )])
121
- for idx in product (* axis_indices ):
128
+ for idx in axis_ndindex (shapes [0 ], _axis ):
122
129
f_idx = ", " .join (str (i ) if isinstance (i , int ) else ":" for i in idx )
123
130
for x_num , x in enumerate (arrays , 1 ):
124
131
indexed_x = x [idx ]
@@ -344,18 +351,19 @@ def test_stack(shape, dtypes, kw, data):
344
351
"stack" , tuple (x .shape for x in arrays ), out .shape , _shape , ** kw
345
352
)
346
353
347
- # TODO: adjust indices with nonzero axis
348
- if axis == 0 :
349
- out_indices = ah .ndindex (out .shape )
350
- for i , x in enumerate (arrays , 1 ):
351
- msg_suffix = f" [stack({ ph .fmt_kw (kw )} )]\n x{ i } ={ x !r} \n { out = } "
352
- for x_idx in ah .ndindex (x .shape ):
354
+ out_indices = ah .ndindex (out .shape )
355
+ for idx in axis_ndindex (arrays [0 ].shape , axis = _axis ):
356
+ f_idx = ", " .join (str (i ) if isinstance (i , int ) else ":" for i in idx )
357
+ print (f"{ f_idx = } " )
358
+ for x_num , x in enumerate (arrays , 1 ):
359
+ indexed_x = x [idx ]
360
+ for x_idx in ah .ndindex (indexed_x .shape ):
353
361
out_idx = next (out_indices )
354
- msg = (
355
- f"out[{ out_idx } ]={ out [out_idx ]} , should be x{ i } [{ x_idx } ]={ x [x_idx ]} "
362
+ assert_equals (
363
+ "stack" ,
364
+ f"x{ x_num } [{ f_idx } ][{ x_idx } ]" ,
365
+ indexed_x [x_idx ],
366
+ f"out[{ out_idx } ]" ,
367
+ out [out_idx ],
368
+ ** kw ,
356
369
)
357
- msg += msg_suffix
358
- if dh .is_float_dtype (x .dtype ) and xp .isnan (x [x_idx ]):
359
- assert xp .isnan (out [out_idx ]), msg
360
- else :
361
- assert out [out_idx ] == x [x_idx ], msg
0 commit comments