2
2
from typing import List , Optional
3
3
4
4
import pytest
5
- from hypothesis import given
5
+ from hypothesis import assume , given
6
6
from hypothesis import strategies as st
7
7
8
8
from array_api_tests .typing import Array , DataType
24
24
fft_shapes_strat = hh .shapes (min_dims = 1 ).filter (lambda s : math .prod (s ) > 1 )
25
25
26
26
27
- def draw_n_axis_norm_kwargs (x : Array , data : st .DataObject ) -> tuple :
27
+ def draw_n_axis_norm_kwargs (x : Array , data : st .DataObject , * , size_gt_1 = False ) -> tuple :
28
28
size = math .prod (x .shape )
29
- n = data .draw (st .none () | st .integers ((size // 2 ), math .ceil (size * 1.5 )), label = "n" )
29
+ n = data .draw (
30
+ st .none () | st .integers ((size // 2 ), math .ceil (size * 1.5 )), label = "n"
31
+ )
30
32
axis = data .draw (st .integers (- 1 , x .ndim - 1 ), label = "axis" )
33
+ if size_gt_1 :
34
+ _axis = x .ndim - 1 if axis == - 1 else axis
35
+ assume (x .shape [_axis ] > 1 )
31
36
norm = data .draw (st .sampled_from (["backward" , "ortho" , "forward" ]), label = "norm" )
32
37
kwargs = data .draw (
33
38
hh .specified_kwargs (
@@ -40,7 +45,7 @@ def draw_n_axis_norm_kwargs(x: Array, data: st.DataObject) -> tuple:
40
45
return n , axis , norm , kwargs
41
46
42
47
43
- def draw_s_axes_norm_kwargs (x : Array , data : st .DataObject ) -> tuple :
48
+ def draw_s_axes_norm_kwargs (x : Array , data : st .DataObject , * , size_gt_1 = False ) -> tuple :
44
49
all_axes = list (range (x .ndim ))
45
50
axes = data .draw (
46
51
st .none () | st .lists (st .sampled_from (all_axes ), min_size = 1 , unique = True ),
@@ -54,6 +59,14 @@ def draw_s_axes_norm_kwargs(x: Array, data: st.DataObject) -> tuple:
54
59
if axes is None :
55
60
s_strat = st .none () | s_strat
56
61
s = data .draw (s_strat , label = "s" )
62
+ if size_gt_1 :
63
+ _s = x .shape if s is None else s
64
+ for i in range (x .ndim ):
65
+ if i in _axes :
66
+ side = _s [_axes .index (i )]
67
+ else :
68
+ side = x .shape [i ]
69
+ assume (side > 1 )
57
70
norm = data .draw (st .sampled_from (["backward" , "ortho" , "forward" ]), label = "norm" )
58
71
kwargs = data .draw (
59
72
hh .specified_kwargs (
@@ -163,7 +176,7 @@ def test_ifftn(x, data):
163
176
164
177
165
178
@given (
166
- x = xps .arrays (dtype = xps .floating_dtypes (), shape = fft_shapes_strat ),
179
+ x = xps .arrays (dtype = xps .complex_dtypes (), shape = fft_shapes_strat ),
167
180
data = st .data (),
168
181
)
169
182
def test_rfft (x , data ):
@@ -176,23 +189,70 @@ def test_rfft(x, data):
176
189
177
190
178
191
@given (
179
- x = xps .arrays (dtype = xps .floating_dtypes (), shape = fft_shapes_strat ),
192
+ x = xps .arrays (dtype = xps .complex_dtypes (), shape = fft_shapes_strat ),
180
193
data = st .data (),
181
194
)
182
195
def test_irfft (x , data ):
183
- n , axis , norm , kwargs = draw_n_axis_norm_kwargs (x , data )
196
+ n , axis , norm , kwargs = draw_n_axis_norm_kwargs (x , data , size_gt_1 = True )
184
197
185
198
out = xp .fft .irfft (x , ** kwargs )
186
199
187
200
assert_fft_dtype ("irfft" , in_dtype = x .dtype , out_dtype = out .dtype )
188
201
# TODO: assert shape
189
202
190
203
191
- # TODO:
192
- # test_rfftn
193
- # test_irfftn
194
- # test_hfft
195
- # test_ihfft
204
+ @given (
205
+ x = xps .arrays (dtype = xps .complex_dtypes (), shape = fft_shapes_strat ),
206
+ data = st .data (),
207
+ )
208
+ def test_rfftn (x , data ):
209
+ s , axes , norm , kwargs = draw_s_axes_norm_kwargs (x , data )
210
+
211
+ out = xp .fft .rfftn (x , ** kwargs )
212
+
213
+ assert_fft_dtype ("rfftn" , in_dtype = x .dtype , out_dtype = out .dtype )
214
+ assert_s_axes_shape ("rfftn" , x = x , s = s , axes = axes , out = out )
215
+
216
+
217
+ @given (
218
+ x = xps .arrays (dtype = xps .complex_dtypes (), shape = fft_shapes_strat ),
219
+ data = st .data (),
220
+ )
221
+ def test_irfftn (x , data ):
222
+ s , axes , norm , kwargs = draw_s_axes_norm_kwargs (x , data , size_gt_1 = True )
223
+
224
+ out = xp .fft .irfftn (x , ** kwargs )
225
+
226
+ assert_fft_dtype ("irfftn" , in_dtype = x .dtype , out_dtype = out .dtype )
227
+ assert_s_axes_shape ("irfftn" , x = x , s = s , axes = axes , out = out )
228
+
229
+
230
+ @given (
231
+ x = xps .arrays (dtype = hh .all_floating_dtypes (), shape = fft_shapes_strat ),
232
+ data = st .data (),
233
+ )
234
+ def test_hfft (x , data ):
235
+ n , axis , norm , kwargs = draw_n_axis_norm_kwargs (x , data , size_gt_1 = True )
236
+
237
+ out = xp .fft .hfft (x , ** kwargs )
238
+
239
+ assert_fft_dtype ("hfft" , in_dtype = x .dtype , out_dtype = out .dtype )
240
+ # TODO: shape
241
+
242
+
243
+ @given (
244
+ x = xps .arrays (dtype = hh .all_floating_dtypes (), shape = fft_shapes_strat ),
245
+ data = st .data (),
246
+ )
247
+ def test_ihfft (x , data ):
248
+ n , axis , norm , kwargs = draw_n_axis_norm_kwargs (x , data )
249
+
250
+ out = xp .fft .ihfft (x , ** kwargs )
251
+
252
+ assert_fft_dtype ("ihfft" , in_dtype = x .dtype , out_dtype = out .dtype )
253
+ # TODO: shape
254
+
255
+
196
256
# fftfreq
197
257
# rfftfreq
198
258
# fftshift
0 commit comments