1
1
import math
2
- from typing import Optional
2
+ from typing import List , Optional
3
3
4
4
import pytest
5
5
from hypothesis import given
10
10
from . import dtype_helpers as dh
11
11
from . import hypothesis_helpers as hh
12
12
from . import pytest_helpers as ph
13
+ from . import shape_helpers as sh
13
14
from . import xps
14
15
from ._array_module import mod as xp
15
16
23
24
fft_shapes_strat = hh .shapes (min_dims = 1 ).filter (lambda s : math .prod (s ) > 1 )
24
25
25
26
26
- def n_axis_norm_kwargs (x : Array , data : st .DataObject ) -> tuple :
27
+ def draw_n_axis_norm_kwargs (x : Array , data : st .DataObject ) -> tuple :
27
28
size = math .prod (x .shape )
28
- n = data .draw (st .none () | st .integers (size // 2 , size * 2 ), label = "n" )
29
+ n = data .draw (st .none () | st .integers (( size // 2 ), math . ceil ( size * 1.5 ) ), label = "n" )
29
30
axis = data .draw (st .integers (- 1 , x .ndim - 1 ), label = "axis" )
30
31
norm = data .draw (st .sampled_from (["backward" , "ortho" , "forward" ]), label = "norm" )
31
32
kwargs = data .draw (
@@ -39,6 +40,32 @@ def n_axis_norm_kwargs(x: Array, data: st.DataObject) -> tuple:
39
40
return n , axis , norm , kwargs
40
41
41
42
43
+ def draw_s_axes_norm_kwargs (x : Array , data : st .DataObject ) -> tuple :
44
+ all_axes = list (range (x .ndim ))
45
+ axes = data .draw (
46
+ st .none () | st .lists (st .sampled_from (all_axes ), min_size = 1 , unique = True ),
47
+ label = "axes" ,
48
+ )
49
+ _axes = all_axes if axes is None else axes
50
+ axes_sides = [x .shape [axis ] for axis in _axes ]
51
+ s_strat = st .tuples (
52
+ * [st .integers (max (side // 2 , 1 ), math .ceil (side * 1.5 )) for side in axes_sides ]
53
+ )
54
+ if axes is None :
55
+ s_strat = st .none () | s_strat
56
+ s = data .draw (s_strat , label = "s" )
57
+ norm = data .draw (st .sampled_from (["backward" , "ortho" , "forward" ]), label = "norm" )
58
+ kwargs = data .draw (
59
+ hh .specified_kwargs (
60
+ ("s" , s , None ),
61
+ ("axes" , axes , None ),
62
+ ("norm" , norm , "backward" ),
63
+ ),
64
+ label = "kwargs" ,
65
+ )
66
+ return s , axes , norm , kwargs
67
+
68
+
42
69
def assert_fft_dtype (func_name : str , * , in_dtype : DataType , out_dtype : DataType ):
43
70
if in_dtype == xp .float32 :
44
71
expected = xp .complex64
@@ -63,12 +90,32 @@ def assert_n_axis_shape(
63
90
ph .assert_shape (func_name , out_shape = out .shape , expected = expected_shape )
64
91
65
92
93
+ def assert_s_axes_shape (
94
+ func_name : str ,
95
+ * ,
96
+ x : Array ,
97
+ s : Optional [List [int ]],
98
+ axes : Optional [List [int ]],
99
+ out : Array ,
100
+ ):
101
+ _axes = sh .normalise_axis (axes , x .ndim )
102
+ _s = x .shape if s is None else s
103
+ expected = []
104
+ for i in range (x .ndim ):
105
+ if i in _axes :
106
+ side = _s [_axes .index (i )]
107
+ else :
108
+ side = x .shape [i ]
109
+ expected .append (side )
110
+ ph .assert_shape (func_name , out_shape = out .shape , expected = tuple (expected ))
111
+
112
+
66
113
@given (
67
114
x = xps .arrays (dtype = hh .all_floating_dtypes (), shape = fft_shapes_strat ),
68
115
data = st .data (),
69
116
)
70
117
def test_fft (x , data ):
71
- n , axis , norm , kwargs = n_axis_norm_kwargs (x , data )
118
+ n , axis , norm , kwargs = draw_n_axis_norm_kwargs (x , data )
72
119
73
120
out = xp .fft .fft (x , ** kwargs )
74
121
@@ -81,25 +128,46 @@ def test_fft(x, data):
81
128
data = st .data (),
82
129
)
83
130
def test_ifft (x , data ):
84
- n , axis , norm , kwargs = n_axis_norm_kwargs (x , data )
131
+ n , axis , norm , kwargs = draw_n_axis_norm_kwargs (x , data )
85
132
86
133
out = xp .fft .ifft (x , ** kwargs )
87
134
88
135
assert_fft_dtype ("ifft" , in_dtype = x .dtype , out_dtype = out .dtype )
89
136
assert_n_axis_shape ("ifft" , x = x , n = n , axis = axis , out = out )
90
137
91
138
92
- # TODO:
93
- # test_fftn
94
- # test_ifftn
139
+ @given (
140
+ x = xps .arrays (dtype = hh .all_floating_dtypes (), shape = fft_shapes_strat ),
141
+ data = st .data (),
142
+ )
143
+ def test_fftn (x , data ):
144
+ s , axes , norm , kwargs = draw_s_axes_norm_kwargs (x , data )
145
+
146
+ out = xp .fft .fftn (x , ** kwargs )
147
+
148
+ assert_fft_dtype ("fftn" , in_dtype = x .dtype , out_dtype = out .dtype )
149
+ assert_s_axes_shape ("fftn" , x = x , s = s , axes = axes , out = out )
150
+
151
+
152
+ @given (
153
+ x = xps .arrays (dtype = hh .all_floating_dtypes (), shape = fft_shapes_strat ),
154
+ data = st .data (),
155
+ )
156
+ def test_ifftn (x , data ):
157
+ s , axes , norm , kwargs = draw_s_axes_norm_kwargs (x , data )
158
+
159
+ out = xp .fft .ifftn (x , ** kwargs )
160
+
161
+ assert_fft_dtype ("ifftn" , in_dtype = x .dtype , out_dtype = out .dtype )
162
+ assert_s_axes_shape ("ifftn" , x = x , s = s , axes = axes , out = out )
95
163
96
164
97
165
@given (
98
166
x = xps .arrays (dtype = xps .floating_dtypes (), shape = fft_shapes_strat ),
99
167
data = st .data (),
100
168
)
101
169
def test_rfft (x , data ):
102
- n , axis , norm , kwargs = n_axis_norm_kwargs (x , data )
170
+ n , axis , norm , kwargs = draw_n_axis_norm_kwargs (x , data )
103
171
104
172
out = xp .fft .rfft (x , ** kwargs )
105
173
@@ -112,7 +180,7 @@ def test_rfft(x, data):
112
180
data = st .data (),
113
181
)
114
182
def test_irfft (x , data ):
115
- n , axis , norm , kwargs = n_axis_norm_kwargs (x , data )
183
+ n , axis , norm , kwargs = draw_n_axis_norm_kwargs (x , data )
116
184
117
185
out = xp .fft .irfft (x , ** kwargs )
118
186
0 commit comments