@@ -60,3 +60,55 @@ def test_take(x, data):
60
60
# sanity check
61
61
with pytest .raises (StopIteration ):
62
62
next (out_indices )
63
+
64
+
65
+
66
+ @pytest .mark .unvectorized
67
+ @pytest .mark .min_version ("2024.12" )
68
+ @given (
69
+ x = hh .arrays (hh .all_dtypes , hh .shapes (min_dims = 1 , min_side = 1 )),
70
+ data = st .data (),
71
+ )
72
+ def test_take_along_axis (x , data ):
73
+ # TODO
74
+ # 1. negative axis
75
+ # 2. negative indices
76
+ # 3. different dtypes for indices
77
+ axis = data .draw (st .integers (0 , max (x .ndim - 1 , 0 )), label = "axis" )
78
+ len_axis = data .draw (st .integers (0 , 2 * x .shape [axis ]), label = "len_axis" )
79
+
80
+ idx_shape = x .shape [:axis ] + (len_axis ,) + x .shape [axis + 1 :]
81
+ indices = data .draw (
82
+ hh .arrays (
83
+ shape = idx_shape ,
84
+ dtype = dh .default_int ,
85
+ elements = {"min_value" : 0 , "max_value" : x .shape [axis ]- 1 }
86
+ ),
87
+ label = "indices"
88
+ )
89
+ note (f"{ indices = } { idx_shape = } " )
90
+
91
+ out = xp .take_along_axis (x , indices , axis = axis )
92
+
93
+ ph .assert_dtype ("take_along_axis" , in_dtype = x .dtype , out_dtype = out .dtype )
94
+ ph .assert_shape (
95
+ "take_along_axis" ,
96
+ out_shape = out .shape ,
97
+ expected = x .shape [:axis ] + (len_axis ,) + x .shape [axis + 1 :],
98
+ kw = dict (
99
+ x = x ,
100
+ indices = indices ,
101
+ axis = axis ,
102
+ ),
103
+ )
104
+
105
+ # value test: notation is from `np.take_along_axis` docstring
106
+ Ni , Nk = x .shape [:axis ], x .shape [axis + 1 :]
107
+ for ii in sh .ndindex (Ni ):
108
+ for kk in sh .ndindex (Nk ):
109
+ a_1d = x [ii + (slice (None ),) + kk ]
110
+ i_1d = indices [ii + (slice (None ),) + kk ]
111
+ o_1d = out [ii + (slice (None ),) + kk ]
112
+ for j in range (len_axis ):
113
+ assert o_1d [j ] == a_1d [i_1d [j ]], f'{ ii = } , { kk = } , { j = } '
114
+
0 commit comments