1
1
import math
2
+ from typing import Optional
2
3
3
4
import pytest
4
5
from hypothesis import given
6
+ from hypothesis import strategies as st
5
7
6
- from array_api_tests .typing import DataType
8
+ from array_api_tests .typing import Array , DataType
7
9
8
- from . import _array_module as xp
9
10
from . import dtype_helpers as dh
10
11
from . import hypothesis_helpers as hh
11
12
from . import pytest_helpers as ph
12
13
from . import xps
14
+ from ._array_module import mod as xp
13
15
14
16
pytestmark = [
15
17
pytest .mark .ci ,
21
23
fft_shapes_strat = hh .shapes (min_dims = 1 ).filter (lambda s : math .prod (s ) > 1 )
22
24
23
25
26
+ def n_axis_norm_kwargs (x : Array , data : st .DataObject ) -> tuple :
27
+ size = math .prod (x .shape )
28
+ n = data .draw (st .none () | st .integers (size // 2 , size * 2 ), label = "n" )
29
+ axis = data .draw (st .integers (- 1 , x .ndim - 1 ), label = "axis" )
30
+ norm = data .draw (st .sampled_from (["backward" , "ortho" , "forward" ]), label = "norm" )
31
+ kwargs = data .draw (
32
+ hh .specified_kwargs (
33
+ ("n" , n , None ),
34
+ ("axis" , axis , - 1 ),
35
+ ("norm" , norm , "backward" ),
36
+ ),
37
+ label = "kwargs" ,
38
+ )
39
+ return n , axis , norm , kwargs
40
+
41
+
24
42
def assert_fft_dtype (func_name : str , * , in_dtype : DataType , out_dtype : DataType ):
25
43
if in_dtype == xp .float32 :
26
44
expected = xp .complex64
@@ -34,29 +52,80 @@ def assert_fft_dtype(func_name: str, *, in_dtype: DataType, out_dtype: DataType)
34
52
)
35
53
36
54
37
- @given (x = xps .arrays (dtype = hh .all_floating_dtypes (), shape = fft_shapes_strat ))
38
- def test_fft (x ):
39
- out = xp .fft .fft (x )
55
+ def assert_n_axis_shape (
56
+ func_name : str , * , x : Array , n : Optional [int ], axis : int , out : Array
57
+ ):
58
+ if n is None :
59
+ expected_shape = x .shape
60
+ else :
61
+ _axis = len (x .shape ) - 1 if axis == - 1 else axis
62
+ expected_shape = x .shape [:_axis ] + (n ,) + x .shape [_axis + 1 :]
63
+ ph .assert_shape (func_name , out_shape = out .shape , expected = expected_shape )
64
+
65
+
66
+ @given (
67
+ x = xps .arrays (dtype = hh .all_floating_dtypes (), shape = fft_shapes_strat ),
68
+ data = st .data (),
69
+ )
70
+ def test_fft (x , data ):
71
+ n , axis , norm , kwargs = n_axis_norm_kwargs (x , data )
72
+
73
+ out = xp .fft .fft (x , ** kwargs )
74
+
40
75
assert_fft_dtype ("fft" , in_dtype = x .dtype , out_dtype = out .dtype )
41
- ph .assert_shape ("fft" , out_shape = out .shape , expected = x .shape )
76
+ assert_n_axis_shape ("fft" , x = x , n = n , axis = axis , out = out )
77
+
42
78
79
+ @given (
80
+ x = xps .arrays (dtype = hh .all_floating_dtypes (), shape = fft_shapes_strat ),
81
+ data = st .data (),
82
+ )
83
+ def test_ifft (x , data ):
84
+ n , axis , norm , kwargs = n_axis_norm_kwargs (x , data )
85
+
86
+ out = xp .fft .ifft (x , ** kwargs )
43
87
44
- @given (x = xps .arrays (dtype = hh .all_floating_dtypes (), shape = fft_shapes_strat ))
45
- def test_ifft (x ):
46
- out = xp .fft .ifft (x )
47
88
assert_fft_dtype ("ifft" , in_dtype = x .dtype , out_dtype = out .dtype )
48
- ph .assert_shape ("ifft" , out_shape = out .shape , expected = x .shape )
89
+ assert_n_axis_shape ("ifft" , x = x , n = n , axis = axis , out = out )
90
+
91
+
92
+ # TODO:
93
+ # test_fftn
94
+ # test_ifftn
95
+
96
+
97
+ @given (
98
+ x = xps .arrays (dtype = xps .floating_dtypes (), shape = fft_shapes_strat ),
99
+ data = st .data (),
100
+ )
101
+ def test_rfft (x , data ):
102
+ n , axis , norm , kwargs = n_axis_norm_kwargs (x , data )
103
+
104
+ out = xp .fft .rfft (x , ** kwargs )
105
+
106
+ assert_fft_dtype ("rfft" , in_dtype = x .dtype , out_dtype = out .dtype )
107
+ assert_n_axis_shape ("rfft" , x = x , n = n , axis = axis , out = out )
108
+
109
+
110
+ @given (
111
+ x = xps .arrays (dtype = xps .floating_dtypes (), shape = fft_shapes_strat ),
112
+ data = st .data (),
113
+ )
114
+ def test_irfft (x , data ):
115
+ n , axis , norm , kwargs = n_axis_norm_kwargs (x , data )
49
116
117
+ out = xp .fft .irfft (x , ** kwargs )
50
118
51
- @given (x = xps .arrays (dtype = hh .all_floating_dtypes (), shape = fft_shapes_strat ))
52
- def test_fftn (x ):
53
- out = xp .fft .fftn (x )
54
- assert_fft_dtype ("fftn" , in_dtype = x .dtype , out_dtype = out .dtype )
55
- ph .assert_shape ("fftn" , out_shape = out .shape , expected = x .shape )
119
+ assert_fft_dtype ("irfft" , in_dtype = x .dtype , out_dtype = out .dtype )
120
+ # TODO: assert shape
56
121
57
122
58
- @given (x = xps .arrays (dtype = hh .all_floating_dtypes (), shape = fft_shapes_strat ))
59
- def test_ifftn (x ):
60
- out = xp .fft .ifftn (x )
61
- assert_fft_dtype ("ifftn" , in_dtype = x .dtype , out_dtype = out .dtype )
62
- ph .assert_shape ("ifftn" , out_shape = out .shape , expected = x .shape )
123
+ # TODO:
124
+ # test_rfftn
125
+ # test_irfftn
126
+ # test_hfft
127
+ # test_ihfft
128
+ # fftfreq
129
+ # rfftfreq
130
+ # fftshift
131
+ # ifftshift
0 commit comments