Skip to content

Commit cd803b8

Browse files
committed
size_gt_1 testing in assert_n_axis_shape()
1 parent 671e07e commit cd803b8

File tree

1 file changed

+34
-9
lines changed

1 file changed

+34
-9
lines changed

array_api_tests/test_fft.py

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -93,14 +93,24 @@ def assert_fft_dtype(func_name: str, *, in_dtype: DataType, out_dtype: DataType)
9393

9494

9595
def assert_n_axis_shape(
96-
func_name: str, *, x: Array, n: Optional[int], axis: int, out: Array
96+
func_name: str,
97+
*,
98+
x: Array,
99+
n: Optional[int],
100+
axis: int,
101+
out: Array,
102+
size_gt_1=False,
97103
):
104+
_axis = len(x.shape) - 1 if axis == -1 else axis
98105
if n is None:
99-
expected_shape = x.shape
106+
if size_gt_1:
107+
axis_side = 2 * (x.shape[_axis] - 1)
108+
else:
109+
axis_side = x.shape[_axis]
100110
else:
101-
_axis = len(x.shape) - 1 if axis == -1 else axis
102-
expected_shape = x.shape[:_axis] + (n,) + x.shape[_axis + 1 :]
103-
ph.assert_shape(func_name, out_shape=out.shape, expected=expected_shape)
111+
axis_side = n
112+
expected = x.shape[:_axis] + (axis_side,) + x.shape[_axis + 1 :]
113+
ph.assert_shape(func_name, out_shape=out.shape, expected=expected)
104114

105115

106116
def assert_s_axes_shape(
@@ -198,7 +208,14 @@ def test_irfft(x, data):
198208
out = xp.fft.irfft(x, **kwargs)
199209

200210
assert_fft_dtype("irfft", in_dtype=x.dtype, out_dtype=out.dtype)
201-
# TODO: assert shape
211+
212+
_axis = x.ndim - 1 if axis == -1 else axis
213+
if n is None:
214+
axis_side = 2 * (x.shape[_axis] - 1)
215+
else:
216+
axis_side = n
217+
expected_shape = x.shape[:_axis] + (axis_side,) + x.shape[_axis + 1 :]
218+
ph.assert_shape("irfft", out_shape=out.shape, expected=expected_shape)
202219

203220

204221
@given(
@@ -224,7 +241,7 @@ def test_irfftn(x, data):
224241
out = xp.fft.irfftn(x, **kwargs)
225242

226243
assert_fft_dtype("irfftn", in_dtype=x.dtype, out_dtype=out.dtype)
227-
assert_s_axes_shape("irfftn", x=x, s=s, axes=axes, out=out)
244+
# TODO: shape
228245

229246

230247
@given(
@@ -237,7 +254,14 @@ def test_hfft(x, data):
237254
out = xp.fft.hfft(x, **kwargs)
238255

239256
assert_fft_dtype("hfft", in_dtype=x.dtype, out_dtype=out.dtype)
240-
# TODO: shape
257+
258+
_axis = x.ndim - 1 if axis == -1 else axis
259+
if n is None:
260+
axis_side = 2 * (x.shape[_axis] - 1)
261+
else:
262+
axis_side = n
263+
expected_shape = x.shape[:_axis] + (axis_side,) + x.shape[_axis + 1 :]
264+
ph.assert_shape("hfft", out_shape=out.shape, expected=expected_shape)
241265

242266

243267
@given(
@@ -250,9 +274,10 @@ def test_ihfft(x, data):
250274
out = xp.fft.ihfft(x, **kwargs)
251275

252276
assert_fft_dtype("ihfft", in_dtype=x.dtype, out_dtype=out.dtype)
253-
# TODO: shape
277+
assert_n_axis_shape("ihfft", x=x, n=n, axis=axis, out=out, size_gt_1=True)
254278

255279

280+
# TODO:
256281
# fftfreq
257282
# rfftfreq
258283
# fftshift

0 commit comments

Comments
 (0)