Skip to content

Commit 674dd0a

Browse files
committed
Rudimentary FFT shift tests
1 parent c434e4d commit 674dd0a

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

array_api_tests/test_fft.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,15 @@ def test_rfftfreq(n, kw):
303303
ph.assert_shape("rfftfreq", out_shape=out.shape, expected=(n // 2 + 1,), kw={"n": n})
304304

305305

306-
# TODO:
307-
# fftshift
308-
# ifftshift
306+
@given(x=hh.arrays(xps.floating_dtypes(), fft_shapes_strat))
307+
def test_fftshift(x):
308+
out = xp.fft.fftshift(x)
309+
ph.assert_dtype("fftshift", in_dtype=x.dtype, out_dtype=out.dtype)
310+
ph.assert_shape("fftshift", out_shape=out.shape, expected=x.shape)
311+
312+
313+
@given(x=hh.arrays(xps.floating_dtypes(), fft_shapes_strat))
314+
def test_ifftshift(x):
315+
out = xp.fft.ifftshift(x)
316+
ph.assert_dtype("ifftshift", in_dtype=x.dtype, out_dtype=out.dtype)
317+
ph.assert_shape("ifftshift", out_shape=out.shape, expected=x.shape)

0 commit comments

Comments
 (0)