Skip to content

Commit 8236912

Browse files
committed
Add tests for (i)rfftn and (i)hfft
And rudimentary support for `2*(m-1)` size rules
1 parent c8822ab commit 8236912

File tree

1 file changed

+72
-12
lines changed

1 file changed

+72
-12
lines changed

array_api_tests/test_fft.py

Lines changed: 72 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from typing import List, Optional
33

44
import pytest
5-
from hypothesis import given
5+
from hypothesis import assume, given
66
from hypothesis import strategies as st
77

88
from array_api_tests.typing import Array, DataType
@@ -24,10 +24,15 @@
2424
fft_shapes_strat = hh.shapes(min_dims=1).filter(lambda s: math.prod(s) > 1)
2525

2626

27-
def draw_n_axis_norm_kwargs(x: Array, data: st.DataObject) -> tuple:
27+
def draw_n_axis_norm_kwargs(x: Array, data: st.DataObject, *, size_gt_1=False) -> tuple:
2828
size = math.prod(x.shape)
29-
n = data.draw(st.none() | st.integers((size // 2), math.ceil(size * 1.5)), label="n")
29+
n = data.draw(
30+
st.none() | st.integers((size // 2), math.ceil(size * 1.5)), label="n"
31+
)
3032
axis = data.draw(st.integers(-1, x.ndim - 1), label="axis")
33+
if size_gt_1:
34+
_axis = x.ndim - 1 if axis == -1 else axis
35+
assume(x.shape[_axis] > 1)
3136
norm = data.draw(st.sampled_from(["backward", "ortho", "forward"]), label="norm")
3237
kwargs = data.draw(
3338
hh.specified_kwargs(
@@ -40,7 +45,7 @@ def draw_n_axis_norm_kwargs(x: Array, data: st.DataObject) -> tuple:
4045
return n, axis, norm, kwargs
4146

4247

43-
def draw_s_axes_norm_kwargs(x: Array, data: st.DataObject) -> tuple:
48+
def draw_s_axes_norm_kwargs(x: Array, data: st.DataObject, *, size_gt_1=False) -> tuple:
4449
all_axes = list(range(x.ndim))
4550
axes = data.draw(
4651
st.none() | st.lists(st.sampled_from(all_axes), min_size=1, unique=True),
@@ -54,6 +59,14 @@ def draw_s_axes_norm_kwargs(x: Array, data: st.DataObject) -> tuple:
5459
if axes is None:
5560
s_strat = st.none() | s_strat
5661
s = data.draw(s_strat, label="s")
62+
if size_gt_1:
63+
_s = x.shape if s is None else s
64+
for i in range(x.ndim):
65+
if i in _axes:
66+
side = _s[_axes.index(i)]
67+
else:
68+
side = x.shape[i]
69+
assume(side > 1)
5770
norm = data.draw(st.sampled_from(["backward", "ortho", "forward"]), label="norm")
5871
kwargs = data.draw(
5972
hh.specified_kwargs(
@@ -163,7 +176,7 @@ def test_ifftn(x, data):
163176

164177

165178
@given(
166-
x=xps.arrays(dtype=xps.floating_dtypes(), shape=fft_shapes_strat),
179+
x=xps.arrays(dtype=xps.complex_dtypes(), shape=fft_shapes_strat),
167180
data=st.data(),
168181
)
169182
def test_rfft(x, data):
@@ -176,23 +189,70 @@ def test_rfft(x, data):
176189

177190

178191
@given(
179-
x=xps.arrays(dtype=xps.floating_dtypes(), shape=fft_shapes_strat),
192+
x=xps.arrays(dtype=xps.complex_dtypes(), shape=fft_shapes_strat),
180193
data=st.data(),
181194
)
182195
def test_irfft(x, data):
183-
n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data)
196+
n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data, size_gt_1=True)
184197

185198
out = xp.fft.irfft(x, **kwargs)
186199

187200
assert_fft_dtype("irfft", in_dtype=x.dtype, out_dtype=out.dtype)
188201
# TODO: assert shape
189202

190203

191-
# TODO:
192-
# test_rfftn
193-
# test_irfftn
194-
# test_hfft
195-
# test_ihfft
204+
@given(
205+
x=xps.arrays(dtype=xps.complex_dtypes(), shape=fft_shapes_strat),
206+
data=st.data(),
207+
)
208+
def test_rfftn(x, data):
209+
s, axes, norm, kwargs = draw_s_axes_norm_kwargs(x, data)
210+
211+
out = xp.fft.rfftn(x, **kwargs)
212+
213+
assert_fft_dtype("rfftn", in_dtype=x.dtype, out_dtype=out.dtype)
214+
assert_s_axes_shape("rfftn", x=x, s=s, axes=axes, out=out)
215+
216+
217+
@given(
218+
x=xps.arrays(dtype=xps.complex_dtypes(), shape=fft_shapes_strat),
219+
data=st.data(),
220+
)
221+
def test_irfftn(x, data):
222+
s, axes, norm, kwargs = draw_s_axes_norm_kwargs(x, data, size_gt_1=True)
223+
224+
out = xp.fft.irfftn(x, **kwargs)
225+
226+
assert_fft_dtype("irfftn", in_dtype=x.dtype, out_dtype=out.dtype)
227+
assert_s_axes_shape("irfftn", x=x, s=s, axes=axes, out=out)
228+
229+
230+
@given(
231+
x=xps.arrays(dtype=hh.all_floating_dtypes(), shape=fft_shapes_strat),
232+
data=st.data(),
233+
)
234+
def test_hfft(x, data):
235+
n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data, size_gt_1=True)
236+
237+
out = xp.fft.hfft(x, **kwargs)
238+
239+
assert_fft_dtype("hfft", in_dtype=x.dtype, out_dtype=out.dtype)
240+
# TODO: shape
241+
242+
243+
@given(
244+
x=xps.arrays(dtype=hh.all_floating_dtypes(), shape=fft_shapes_strat),
245+
data=st.data(),
246+
)
247+
def test_ihfft(x, data):
248+
n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data)
249+
250+
out = xp.fft.ihfft(x, **kwargs)
251+
252+
assert_fft_dtype("ihfft", in_dtype=x.dtype, out_dtype=out.dtype)
253+
# TODO: shape
254+
255+
196256
# fftfreq
197257
# rfftfreq
198258
# fftshift

0 commit comments

Comments
 (0)