13
13
from . import xps
14
14
from .typing import Array , Shape
15
15
16
# Cap on any single dimension drawn for these tests, kept well below
# hh.MAX_ARRAY_SIZE so that several such arrays can be concatenated
# without exceeding the suite's overall array-size budget.
MAX_SIDE = hh.MAX_ARRAY_SIZE // 64
MAX_DIMS = min(hh.MAX_ARRAY_SIZE // MAX_SIDE, 32)  # NumPy only supports up to 32 dims
16
19
17
20
def shared_shapes (* args , ** kwargs ) -> st .SearchStrategy [Shape ]:
18
21
key = "shape"
@@ -40,26 +43,63 @@ def assert_array_ndindex(
40
43
assert out [out_idx ] == x [x_idx ], msg
41
44
42
45
46
@st.composite
def concat_shapes(draw, shape, axis):
    """Strategy: a copy of ``shape`` with a freshly drawn size along ``axis``.

    Every dimension other than ``axis`` is preserved, so shapes produced
    from the same base ``shape`` are concatenatable along ``axis``. The
    drawn side length lies in ``[1, MAX_SIDE]``. Negative ``axis`` values
    index from the end, as with ordinary sequence indexing.
    """
    dims = list(shape)
    dims[axis] = draw(st.integers(1, MAX_SIDE))
    return tuple(dims)
43
53
@given(
    dtypes=hh.mutually_promotable_dtypes(None, dtypes=dh.numeric_dtypes),
    kw=hh.kwargs(axis=st.none() | st.integers(-MAX_DIMS, MAX_DIMS - 1)),
    data=st.data(),
)
def test_concat(dtypes, kw, data):
    """Test ``concat``: result dtype, result shape, and element order.

    Draws 2+ arrays of mutually promotable numeric dtypes and concatenates
    them either along a drawn ``axis`` (possibly negative), with ``axis=None``
    (flatten-then-concatenate), or with the kwarg omitted (spec default 0).
    """
    # Missing kwarg defaults to axis=0; axis=None means flatten everything.
    axis = kw.get("axis", 0)
    if axis is None:
        # With axis=None the inputs are flattened first, so the arrays'
        # shapes need not relate to each other at all.
        shape_strat = hh.shapes()
    else:
        # Minimum rank the axis requires: axis=-k needs at least k dims,
        # axis=k needs at least k+1 dims.
        _axis = axis if axis >= 0 else abs(axis) - 1
        # All arrays share one base shape (shared strategy), then each
        # independently re-draws its size along the concatenation axis.
        shape_strat = shared_shapes(min_dims=_axis + 1).flatmap(
            lambda s: concat_shapes(s, axis)
        )

    arrays = []
    for i, dtype in enumerate(dtypes, 1):
        x = data.draw(xps.arrays(dtype=dtype, shape=shape_strat), label=f"x{i}")
        arrays.append(x)

    out = xp.concat(arrays, **kw)

    ph.assert_dtype("concat", dtypes, out.dtype)

    shapes = tuple(x.shape for x in arrays)
    axis = kw.get("axis", 0)
    if axis is None:
        # Flattened concatenation: result is 1-D with the combined size.
        size = sum(math.prod(s) for s in shapes)
        shape = (size,)
    else:
        # Expected shape: first input's shape with the axis dimension
        # summed across all inputs.
        shape = list(shapes[0])
        for other_shape in shapes[1:]:
            shape[axis] += other_shape[axis]
        shape = tuple(shape)
    ph.assert_result_shape("concat", shapes, out.shape, shape, **kw)

    # TODO: adjust indices with nonzero axis
    if axis is None or axis == 0:
        # For axis 0 (and for axis=None on row-major flattening), iterating
        # out's indices in ndindex order walks each input's elements in
        # their own ndindex order, one input after another.
        out_indices = ah.ndindex(out.shape)
        for i, x in enumerate(arrays, 1):
            msg_suffix = f" [concat({ph.fmt_kw(kw)})]\nx{i}={x!r}\n{out=}"
            for x_idx in ah.ndindex(x.shape):
                out_idx = next(out_indices)
                msg = (
                    f"out[{out_idx}]={out[out_idx]}, should be x{i}[{x_idx}]={x[x_idx]}"
                )
                msg += msg_suffix
                # NaN != NaN, so float NaNs need an explicit isnan check.
                if dh.is_float_dtype(x.dtype) and xp.isnan(x[x_idx]):
                    assert xp.isnan(out[out_idx]), msg
                else:
                    assert out[out_idx] == x[x_idx], msg
63
103
64
104
65
105
@given (
@@ -169,9 +209,8 @@ def test_permute_dims(x, axes):
169
209
# TODO: test elements
170
210
171
211
172
- MAX_RESHAPE_SIDE = hh .MAX_ARRAY_SIZE // 64
173
212
# Shared base-shape strategy for the reshape tests: total element count is
# bounded by MAX_SIDE so every reshaped side also fits within the limit.
# st.shared lets the input-array strategy and the target-shape strategy
# observe the same drawn shape.
reshape_x_shapes = st.shared(
    hh.shapes().filter(lambda s: math.prod(s) <= MAX_SIDE),
    key="reshape x shape",
)
177
216
@@ -180,7 +219,7 @@ def test_permute_dims(x, axes):
180
219
def reshape_shapes (draw , shape ):
181
220
size = 1 if len (shape ) == 0 else math .prod (shape )
182
221
rshape = draw (st .lists (st .integers (0 )).filter (lambda s : math .prod (s ) == size ))
183
- assume (all (side <= MAX_RESHAPE_SIDE for side in rshape ))
222
+ assume (all (side <= MAX_SIDE for side in rshape ))
184
223
if len (rshape ) != 0 and size > 0 and draw (st .booleans ()):
185
224
index = draw (st .integers (0 , len (rshape ) - 1 ))
186
225
rshape [index ] = - 1
0 commit comments