Skip to content

Commit 9790466

Browse files
committed
Test ifft, rfft and irfft
1 parent 22f9815 commit 9790466

File tree

1 file changed

+89
-20
lines changed

1 file changed

+89
-20
lines changed

array_api_tests/test_fft.py

Lines changed: 89 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
import math
2+
from typing import Optional
23

34
import pytest
45
from hypothesis import given
6+
from hypothesis import strategies as st
57

6-
from array_api_tests.typing import DataType
8+
from array_api_tests.typing import Array, DataType
79

8-
from . import _array_module as xp
910
from . import dtype_helpers as dh
1011
from . import hypothesis_helpers as hh
1112
from . import pytest_helpers as ph
1213
from . import xps
14+
from ._array_module import mod as xp
1315

1416
pytestmark = [
1517
pytest.mark.ci,
@@ -21,6 +23,22 @@
2123
fft_shapes_strat = hh.shapes(min_dims=1).filter(lambda s: math.prod(s) > 1)
2224

2325

26+
def n_axis_norm_kwargs(x: Array, data: st.DataObject) -> tuple:
27+
size = math.prod(x.shape)
28+
n = data.draw(st.none() | st.integers(size // 2, size * 2), label="n")
29+
axis = data.draw(st.integers(-1, x.ndim - 1), label="axis")
30+
norm = data.draw(st.sampled_from(["backward", "ortho", "forward"]), label="norm")
31+
kwargs = data.draw(
32+
hh.specified_kwargs(
33+
("n", n, None),
34+
("axis", axis, -1),
35+
("norm", norm, "backward"),
36+
),
37+
label="kwargs",
38+
)
39+
return n, axis, norm, kwargs
40+
41+
2442
def assert_fft_dtype(func_name: str, *, in_dtype: DataType, out_dtype: DataType):
2543
if in_dtype == xp.float32:
2644
expected = xp.complex64
@@ -34,29 +52,80 @@ def assert_fft_dtype(func_name: str, *, in_dtype: DataType, out_dtype: DataType)
3452
)
3553

3654

37-
@given(x=xps.arrays(dtype=hh.all_floating_dtypes(), shape=fft_shapes_strat))
38-
def test_fft(x):
39-
out = xp.fft.fft(x)
55+
def assert_n_axis_shape(
56+
func_name: str, *, x: Array, n: Optional[int], axis: int, out: Array
57+
):
58+
if n is None:
59+
expected_shape = x.shape
60+
else:
61+
_axis = len(x.shape) - 1 if axis == -1 else axis
62+
expected_shape = x.shape[:_axis] + (n,) + x.shape[_axis + 1 :]
63+
ph.assert_shape(func_name, out_shape=out.shape, expected=expected_shape)
64+
65+
66+
@given(
67+
x=xps.arrays(dtype=hh.all_floating_dtypes(), shape=fft_shapes_strat),
68+
data=st.data(),
69+
)
70+
def test_fft(x, data):
71+
n, axis, norm, kwargs = n_axis_norm_kwargs(x, data)
72+
73+
out = xp.fft.fft(x, **kwargs)
74+
4075
assert_fft_dtype("fft", in_dtype=x.dtype, out_dtype=out.dtype)
41-
ph.assert_shape("fft", out_shape=out.shape, expected=x.shape)
76+
assert_n_axis_shape("fft", x=x, n=n, axis=axis, out=out)
77+
4278

79+
@given(
80+
x=xps.arrays(dtype=hh.all_floating_dtypes(), shape=fft_shapes_strat),
81+
data=st.data(),
82+
)
83+
def test_ifft(x, data):
84+
n, axis, norm, kwargs = n_axis_norm_kwargs(x, data)
85+
86+
out = xp.fft.ifft(x, **kwargs)
4387

44-
@given(x=xps.arrays(dtype=hh.all_floating_dtypes(), shape=fft_shapes_strat))
45-
def test_ifft(x):
46-
out = xp.fft.ifft(x)
4788
assert_fft_dtype("ifft", in_dtype=x.dtype, out_dtype=out.dtype)
48-
ph.assert_shape("ifft", out_shape=out.shape, expected=x.shape)
89+
assert_n_axis_shape("ifft", x=x, n=n, axis=axis, out=out)
90+
91+
92+
# TODO:
93+
# test_fftn
94+
# test_ifftn
95+
96+
97+
@given(
98+
x=xps.arrays(dtype=xps.floating_dtypes(), shape=fft_shapes_strat),
99+
data=st.data(),
100+
)
101+
def test_rfft(x, data):
102+
n, axis, norm, kwargs = n_axis_norm_kwargs(x, data)
103+
104+
out = xp.fft.rfft(x, **kwargs)
105+
106+
assert_fft_dtype("rfft", in_dtype=x.dtype, out_dtype=out.dtype)
107+
assert_n_axis_shape("rfft", x=x, n=n, axis=axis, out=out)
108+
109+
110+
@given(
111+
x=xps.arrays(dtype=xps.floating_dtypes(), shape=fft_shapes_strat),
112+
data=st.data(),
113+
)
114+
def test_irfft(x, data):
115+
n, axis, norm, kwargs = n_axis_norm_kwargs(x, data)
49116

117+
out = xp.fft.irfft(x, **kwargs)
50118

51-
@given(x=xps.arrays(dtype=hh.all_floating_dtypes(), shape=fft_shapes_strat))
52-
def test_fftn(x):
53-
out = xp.fft.fftn(x)
54-
assert_fft_dtype("fftn", in_dtype=x.dtype, out_dtype=out.dtype)
55-
ph.assert_shape("fftn", out_shape=out.shape, expected=x.shape)
119+
assert_fft_dtype("irfft", in_dtype=x.dtype, out_dtype=out.dtype)
120+
# TODO: assert shape
56121

57122

58-
@given(x=xps.arrays(dtype=hh.all_floating_dtypes(), shape=fft_shapes_strat))
59-
def test_ifftn(x):
60-
out = xp.fft.ifftn(x)
61-
assert_fft_dtype("ifftn", in_dtype=x.dtype, out_dtype=out.dtype)
62-
ph.assert_shape("ifftn", out_shape=out.shape, expected=x.shape)
123+
# TODO:
124+
# test_rfftn
125+
# test_irfftn
126+
# test_hfft
127+
# test_ihfft
128+
# fftfreq
129+
# rfftfreq
130+
# fftshift
131+
# ifftshift

0 commit comments

Comments
 (0)