Skip to content

Commit 462d0d3

Browse files
committed
Avoid calling hh.arrays with dtype=None for complex types
1 parent 11bb686 commit 462d0d3

File tree

2 files changed

+98
-95
lines changed

2 files changed

+98
-95
lines changed

array_api_tests/test_fft.py

Lines changed: 95 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -109,44 +109,45 @@ def assert_s_axes_shape(
109109
ph.assert_shape(func_name, out_shape=out.shape, expected=tuple(expected))
110110

111111

112-
@given(x=hh.arrays(dtype=hh.complex_dtypes, shape=fft_shapes_strat), data=st.data())
113-
def test_fft(x, data):
114-
n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data)
112+
if hh.complex_dtypes:
113+
@given(x=hh.arrays(dtype=hh.complex_dtypes, shape=fft_shapes_strat), data=st.data())
114+
def test_fft(x, data):
115+
n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data)
115116

116-
out = xp.fft.fft(x, **kwargs)
117+
out = xp.fft.fft(x, **kwargs)
117118

118-
ph.assert_dtype("fft", in_dtype=x.dtype, out_dtype=out.dtype)
119-
assert_n_axis_shape("fft", x=x, n=n, axis=axis, out=out)
119+
ph.assert_dtype("fft", in_dtype=x.dtype, out_dtype=out.dtype)
120+
assert_n_axis_shape("fft", x=x, n=n, axis=axis, out=out)
120121

121122

122-
@given(x=hh.arrays(dtype=hh.complex_dtypes, shape=fft_shapes_strat), data=st.data())
123-
def test_ifft(x, data):
124-
n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data)
123+
@given(x=hh.arrays(dtype=hh.complex_dtypes, shape=fft_shapes_strat), data=st.data())
124+
def test_ifft(x, data):
125+
n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data)
125126

126-
out = xp.fft.ifft(x, **kwargs)
127+
out = xp.fft.ifft(x, **kwargs)
127128

128-
ph.assert_dtype("ifft", in_dtype=x.dtype, out_dtype=out.dtype)
129-
assert_n_axis_shape("ifft", x=x, n=n, axis=axis, out=out)
129+
ph.assert_dtype("ifft", in_dtype=x.dtype, out_dtype=out.dtype)
130+
assert_n_axis_shape("ifft", x=x, n=n, axis=axis, out=out)
130131

131132

132-
@given(x=hh.arrays(dtype=hh.complex_dtypes, shape=fft_shapes_strat), data=st.data())
133-
def test_fftn(x, data):
134-
s, axes, norm, kwargs = draw_s_axes_norm_kwargs(x, data)
133+
@given(x=hh.arrays(dtype=hh.complex_dtypes, shape=fft_shapes_strat), data=st.data())
134+
def test_fftn(x, data):
135+
s, axes, norm, kwargs = draw_s_axes_norm_kwargs(x, data)
135136

136-
out = xp.fft.fftn(x, **kwargs)
137+
out = xp.fft.fftn(x, **kwargs)
137138

138-
ph.assert_dtype("fftn", in_dtype=x.dtype, out_dtype=out.dtype)
139-
assert_s_axes_shape("fftn", x=x, s=s, axes=axes, out=out)
139+
ph.assert_dtype("fftn", in_dtype=x.dtype, out_dtype=out.dtype)
140+
assert_s_axes_shape("fftn", x=x, s=s, axes=axes, out=out)
140141

141142

142-
@given(x=hh.arrays(dtype=hh.complex_dtypes, shape=fft_shapes_strat), data=st.data())
143-
def test_ifftn(x, data):
144-
s, axes, norm, kwargs = draw_s_axes_norm_kwargs(x, data)
143+
@given(x=hh.arrays(dtype=hh.complex_dtypes, shape=fft_shapes_strat), data=st.data())
144+
def test_ifftn(x, data):
145+
s, axes, norm, kwargs = draw_s_axes_norm_kwargs(x, data)
145146

146-
out = xp.fft.ifftn(x, **kwargs)
147+
out = xp.fft.ifftn(x, **kwargs)
147148

148-
ph.assert_dtype("ifftn", in_dtype=x.dtype, out_dtype=out.dtype)
149-
assert_s_axes_shape("ifftn", x=x, s=s, axes=axes, out=out)
149+
ph.assert_dtype("ifftn", in_dtype=x.dtype, out_dtype=out.dtype)
150+
assert_s_axes_shape("ifftn", x=x, s=s, axes=axes, out=out)
150151

151152

152153
@given(x=hh.arrays(dtype=hh.real_floating_dtypes, shape=fft_shapes_strat), data=st.data())
@@ -166,26 +167,27 @@ def test_rfft(x, data):
166167
ph.assert_shape("rfft", out_shape=out.shape, expected=expected_shape)
167168

168169

169-
@given(x=hh.arrays(dtype=hh.complex_dtypes, shape=fft_shapes_strat), data=st.data())
170-
def test_irfft(x, data):
171-
n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data, size_gt_1=True)
170+
if hh.complex_dtypes:
171+
@given(x=hh.arrays(dtype=hh.complex_dtypes, shape=fft_shapes_strat), data=st.data())
172+
def test_irfft(x, data):
173+
n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data, size_gt_1=True)
172174

173-
out = xp.fft.irfft(x, **kwargs)
175+
out = xp.fft.irfft(x, **kwargs)
174176

175-
ph.assert_dtype(
176-
"irfft",
177-
in_dtype=x.dtype,
178-
out_dtype=out.dtype,
179-
expected=dh.dtype_components[x.dtype],
180-
)
177+
ph.assert_dtype(
178+
"irfft",
179+
in_dtype=x.dtype,
180+
out_dtype=out.dtype,
181+
expected=dh.dtype_components[x.dtype],
182+
)
181183

182-
_axis = x.ndim - 1 if axis == -1 else axis
183-
if n is None:
184-
axis_side = 2 * (x.shape[_axis] - 1)
185-
else:
186-
axis_side = n
187-
expected_shape = x.shape[:_axis] + (axis_side,) + x.shape[_axis + 1 :]
188-
ph.assert_shape("irfft", out_shape=out.shape, expected=expected_shape)
184+
_axis = x.ndim - 1 if axis == -1 else axis
185+
if n is None:
186+
axis_side = 2 * (x.shape[_axis] - 1)
187+
else:
188+
axis_side = n
189+
expected_shape = x.shape[:_axis] + (axis_side,) + x.shape[_axis + 1 :]
190+
ph.assert_shape("irfft", out_shape=out.shape, expected=expected_shape)
189191

190192

191193
@given(x=hh.arrays(dtype=hh.real_floating_dtypes, shape=fft_shapes_strat), data=st.data())
@@ -209,59 +211,60 @@ def test_rfftn(x, data):
209211
ph.assert_shape("rfftn", out_shape=out.shape, expected=tuple(expected))
210212

211213

212-
@given(
213-
x=hh.arrays(
214-
dtype=hh.complex_dtypes, shape=fft_shapes_strat.filter(lambda s: s[-1] > 1)
215-
),
216-
data=st.data(),
217-
)
218-
def test_irfftn(x, data):
219-
s, axes, norm, kwargs = draw_s_axes_norm_kwargs(x, data)
220-
221-
out = xp.fft.irfftn(x, **kwargs)
222-
223-
ph.assert_dtype(
224-
"irfftn",
225-
in_dtype=x.dtype,
226-
out_dtype=out.dtype,
227-
expected=dh.dtype_components[x.dtype],
228-
)
229-
230-
# TODO: assert shape correctly
231-
# _axes = sh.normalize_axis(axes, x.ndim)
232-
# _s = x.shape if s is None else s
233-
# expected = []
234-
# for i in range(x.ndim):
235-
# if i in _axes:
236-
# side = _s[_axes.index(i)]
237-
# else:
238-
# side = x.shape[i]
239-
# expected.append(side)
240-
# last_axis = max(_axes)
241-
# expected[last_axis] = _s[_axes.index(last_axis)] // 2 + 1
242-
# ph.assert_shape("irfftn", out_shape=out.shape, expected=tuple(expected))
243-
244-
245-
@given(x=hh.arrays(dtype=hh.complex_dtypes, shape=fft_shapes_strat), data=st.data())
246-
def test_hfft(x, data):
247-
n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data, size_gt_1=True)
248-
249-
out = xp.fft.hfft(x, **kwargs)
250-
251-
ph.assert_dtype(
252-
"hfft",
253-
in_dtype=x.dtype,
254-
out_dtype=out.dtype,
255-
expected=dh.dtype_components[x.dtype],
214+
if hh.complex_dtypes:
215+
@given(
216+
x=hh.arrays(
217+
dtype=hh.complex_dtypes, shape=fft_shapes_strat.filter(lambda s: s[-1] > 1)
218+
),
219+
data=st.data(),
256220
)
221+
def test_irfftn(x, data):
222+
s, axes, norm, kwargs = draw_s_axes_norm_kwargs(x, data)
223+
224+
out = xp.fft.irfftn(x, **kwargs)
225+
226+
ph.assert_dtype(
227+
"irfftn",
228+
in_dtype=x.dtype,
229+
out_dtype=out.dtype,
230+
expected=dh.dtype_components[x.dtype],
231+
)
232+
233+
# TODO: assert shape correctly
234+
# _axes = sh.normalize_axis(axes, x.ndim)
235+
# _s = x.shape if s is None else s
236+
# expected = []
237+
# for i in range(x.ndim):
238+
# if i in _axes:
239+
# side = _s[_axes.index(i)]
240+
# else:
241+
# side = x.shape[i]
242+
# expected.append(side)
243+
# last_axis = max(_axes)
244+
# expected[last_axis] = _s[_axes.index(last_axis)] // 2 + 1
245+
# ph.assert_shape("irfftn", out_shape=out.shape, expected=tuple(expected))
246+
247+
248+
@given(x=hh.arrays(dtype=hh.complex_dtypes, shape=fft_shapes_strat), data=st.data())
249+
def test_hfft(x, data):
250+
n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data, size_gt_1=True)
251+
252+
out = xp.fft.hfft(x, **kwargs)
253+
254+
ph.assert_dtype(
255+
"hfft",
256+
in_dtype=x.dtype,
257+
out_dtype=out.dtype,
258+
expected=dh.dtype_components[x.dtype],
259+
)
257260

258-
_axis = x.ndim - 1 if axis == -1 else axis
259-
if n is None:
260-
axis_side = 2 * (x.shape[_axis] - 1)
261-
else:
262-
axis_side = n
263-
expected_shape = x.shape[:_axis] + (axis_side,) + x.shape[_axis + 1 :]
264-
ph.assert_shape("hfft", out_shape=out.shape, expected=expected_shape)
261+
_axis = x.ndim - 1 if axis == -1 else axis
262+
if n is None:
263+
axis_side = 2 * (x.shape[_axis] - 1)
264+
else:
265+
axis_side = n
266+
expected_shape = x.shape[:_axis] + (axis_side,) + x.shape[_axis + 1 :]
267+
ph.assert_shape("hfft", out_shape=out.shape, expected=expected_shape)
265268

266269

267270
@given(x=hh.arrays(dtype=hh.real_floating_dtypes, shape=fft_shapes_strat), data=st.data())

array_api_tests/test_operators_and_elementwise_functions.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1061,7 +1061,7 @@ def refimpl(_x, _min, _max):
10611061
)
10621062

10631063

1064-
if api_version >= "2022.12":
1064+
if api_version >= "2022.12" and hh.complex_dtypes:
10651065

10661066
@given(hh.arrays(dtype=hh.complex_dtypes, shape=hh.shapes()))
10671067
def test_conj(x):
@@ -1263,7 +1263,7 @@ def test_hypot(x1, x2):
12631263
binary_assert_against_refimpl("hypot", x1, x2, out, math.hypot)
12641264

12651265

1266-
if api_version >= "2022.12":
1266+
if api_version >= "2022.12" and hh.complex_dtypes:
12671267

12681268
@given(hh.arrays(dtype=hh.complex_dtypes, shape=hh.shapes()))
12691269
def test_imag(x):
@@ -1559,7 +1559,7 @@ def test_pow(ctx, data):
15591559
# Values testing pow is too finicky
15601560

15611561

1562-
if api_version >= "2022.12":
1562+
if api_version >= "2022.12" and hh.complex_dtypes:
15631563

15641564
@given(hh.arrays(dtype=hh.complex_dtypes, shape=hh.shapes()))
15651565
def test_real(x):

0 commit comments

Comments
 (0)