Skip to content

Commit 952b9c3

Browse files
committed
Implement test_fftn and test_ifftn
1 parent 9790466 commit 952b9c3

File tree

2 files changed

+82
-12
lines changed

2 files changed

+82
-12
lines changed

array_api_tests/shape_helpers.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import math
22
from itertools import product
3-
from typing import Iterator, List, Optional, Tuple, Union
3+
from typing import Iterator, List, Optional, Sequence, Tuple, Union
44

55
from ndindex import iter_indices as _iter_indices
66

@@ -66,10 +66,12 @@ def broadcast_shapes(*shapes: Shape):
6666

6767

6868
def normalise_axis(
69-
axis: Optional[Union[int, Tuple[int, ...]]], ndim: int
69+
axis: Optional[Union[int, Sequence[int]]], ndim: int
7070
) -> Tuple[int, ...]:
7171
if axis is None:
7272
return tuple(range(ndim))
73+
elif isinstance(axis, Sequence) and not isinstance(axis, tuple):
74+
axis = tuple(axis)
7375
axes = axis if isinstance(axis, tuple) else (axis,)
7476
axes = tuple(axis if axis >= 0 else ndim + axis for axis in axes)
7577
return axes

array_api_tests/test_fft.py

Lines changed: 78 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import math
2-
from typing import Optional
2+
from typing import List, Optional
33

44
import pytest
55
from hypothesis import given
@@ -10,6 +10,7 @@
1010
from . import dtype_helpers as dh
1111
from . import hypothesis_helpers as hh
1212
from . import pytest_helpers as ph
13+
from . import shape_helpers as sh
1314
from . import xps
1415
from ._array_module import mod as xp
1516

@@ -23,9 +24,9 @@
2324
fft_shapes_strat = hh.shapes(min_dims=1).filter(lambda s: math.prod(s) > 1)
2425

2526

26-
def n_axis_norm_kwargs(x: Array, data: st.DataObject) -> tuple:
27+
def draw_n_axis_norm_kwargs(x: Array, data: st.DataObject) -> tuple:
2728
size = math.prod(x.shape)
28-
n = data.draw(st.none() | st.integers(size // 2, size * 2), label="n")
29+
n = data.draw(st.none() | st.integers((size // 2), math.ceil(size * 1.5)), label="n")
2930
axis = data.draw(st.integers(-1, x.ndim - 1), label="axis")
3031
norm = data.draw(st.sampled_from(["backward", "ortho", "forward"]), label="norm")
3132
kwargs = data.draw(
@@ -39,6 +40,32 @@ def n_axis_norm_kwargs(x: Array, data: st.DataObject) -> tuple:
3940
return n, axis, norm, kwargs
4041

4142

43+
def draw_s_axes_norm_kwargs(x: Array, data: st.DataObject) -> tuple:
44+
all_axes = list(range(x.ndim))
45+
axes = data.draw(
46+
st.none() | st.lists(st.sampled_from(all_axes), min_size=1, unique=True),
47+
label="axes",
48+
)
49+
_axes = all_axes if axes is None else axes
50+
axes_sides = [x.shape[axis] for axis in _axes]
51+
s_strat = st.tuples(
52+
*[st.integers(max(side // 2, 1), math.ceil(side * 1.5)) for side in axes_sides]
53+
)
54+
if axes is None:
55+
s_strat = st.none() | s_strat
56+
s = data.draw(s_strat, label="s")
57+
norm = data.draw(st.sampled_from(["backward", "ortho", "forward"]), label="norm")
58+
kwargs = data.draw(
59+
hh.specified_kwargs(
60+
("s", s, None),
61+
("axes", axes, None),
62+
("norm", norm, "backward"),
63+
),
64+
label="kwargs",
65+
)
66+
return s, axes, norm, kwargs
67+
68+
4269
def assert_fft_dtype(func_name: str, *, in_dtype: DataType, out_dtype: DataType):
4370
if in_dtype == xp.float32:
4471
expected = xp.complex64
@@ -63,12 +90,32 @@ def assert_n_axis_shape(
6390
ph.assert_shape(func_name, out_shape=out.shape, expected=expected_shape)
6491

6592

93+
def assert_s_axes_shape(
94+
func_name: str,
95+
*,
96+
x: Array,
97+
s: Optional[List[int]],
98+
axes: Optional[List[int]],
99+
out: Array,
100+
):
101+
_axes = sh.normalise_axis(axes, x.ndim)
102+
_s = x.shape if s is None else s
103+
expected = []
104+
for i in range(x.ndim):
105+
if i in _axes:
106+
side = _s[_axes.index(i)]
107+
else:
108+
side = x.shape[i]
109+
expected.append(side)
110+
ph.assert_shape(func_name, out_shape=out.shape, expected=tuple(expected))
111+
112+
66113
@given(
67114
x=xps.arrays(dtype=hh.all_floating_dtypes(), shape=fft_shapes_strat),
68115
data=st.data(),
69116
)
70117
def test_fft(x, data):
71-
n, axis, norm, kwargs = n_axis_norm_kwargs(x, data)
118+
n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data)
72119

73120
out = xp.fft.fft(x, **kwargs)
74121

@@ -81,25 +128,46 @@ def test_fft(x, data):
81128
data=st.data(),
82129
)
83130
def test_ifft(x, data):
84-
n, axis, norm, kwargs = n_axis_norm_kwargs(x, data)
131+
n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data)
85132

86133
out = xp.fft.ifft(x, **kwargs)
87134

88135
assert_fft_dtype("ifft", in_dtype=x.dtype, out_dtype=out.dtype)
89136
assert_n_axis_shape("ifft", x=x, n=n, axis=axis, out=out)
90137

91138

92-
# TODO:
93-
# test_fftn
94-
# test_ifftn
139+
@given(
140+
x=xps.arrays(dtype=hh.all_floating_dtypes(), shape=fft_shapes_strat),
141+
data=st.data(),
142+
)
143+
def test_fftn(x, data):
144+
s, axes, norm, kwargs = draw_s_axes_norm_kwargs(x, data)
145+
146+
out = xp.fft.fftn(x, **kwargs)
147+
148+
assert_fft_dtype("fftn", in_dtype=x.dtype, out_dtype=out.dtype)
149+
assert_s_axes_shape("fftn", x=x, s=s, axes=axes, out=out)
150+
151+
152+
@given(
153+
x=xps.arrays(dtype=hh.all_floating_dtypes(), shape=fft_shapes_strat),
154+
data=st.data(),
155+
)
156+
def test_ifftn(x, data):
157+
s, axes, norm, kwargs = draw_s_axes_norm_kwargs(x, data)
158+
159+
out = xp.fft.ifftn(x, **kwargs)
160+
161+
assert_fft_dtype("ifftn", in_dtype=x.dtype, out_dtype=out.dtype)
162+
assert_s_axes_shape("ifftn", x=x, s=s, axes=axes, out=out)
95163

96164

97165
@given(
98166
x=xps.arrays(dtype=xps.floating_dtypes(), shape=fft_shapes_strat),
99167
data=st.data(),
100168
)
101169
def test_rfft(x, data):
102-
n, axis, norm, kwargs = n_axis_norm_kwargs(x, data)
170+
n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data)
103171

104172
out = xp.fft.rfft(x, **kwargs)
105173

@@ -112,7 +180,7 @@ def test_rfft(x, data):
112180
data=st.data(),
113181
)
114182
def test_irfft(x, data):
115-
n, axis, norm, kwargs = n_axis_norm_kwargs(x, data)
183+
n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data)
116184

117185
out = xp.fft.irfft(x, **kwargs)
118186

0 commit comments

Comments
 (0)