Skip to content

Commit da8e374

Browse files
committed
Rudimentary test_fft
1 parent f82c7bc commit da8e374

File tree

3 files changed

+40
-2
lines changed

3 files changed

+40
-2
lines changed

array_api_tests/_array_module.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def __repr__(self):
6363
_constants = ["e", "inf", "nan", "pi"]
6464
_funcs = [f.__name__ for funcs in stubs.category_to_funcs.values() for f in funcs]
6565
_funcs += ["take", "isdtype", "conj", "imag", "real"] # TODO: bump spec and update array-api-tests to new spec layout
66-
_top_level_attrs = _dtypes + _constants + _funcs + stubs.EXTENSIONS
66+
_top_level_attrs = _dtypes + _constants + _funcs + stubs.EXTENSIONS + ["fft"]
6767

6868
for attr in _top_level_attrs:
6969
try:

array_api_tests/stubs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
all_funcs.extend(funcs)
5353
name_to_func: Dict[str, FunctionType] = {f.__name__: f for f in all_funcs}
5454

55-
EXTENSIONS: str = ["linalg"]
55+
EXTENSIONS: List[str] = ["linalg"] # TODO: add "fft" once stubs available
5656
extension_to_funcs: Dict[str, List[FunctionType]] = {}
5757
for ext in EXTENSIONS:
5858
mod = name_to_mod[ext]

array_api_tests/test_fft.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import math
2+
3+
import pytest
4+
from hypothesis import given
5+
6+
from array_api_tests.typing import DataType
7+
8+
from . import _array_module as xp
9+
from . import hypothesis_helpers as hh
10+
from . import pytest_helpers as ph
11+
from . import xps
12+
13+
pytestmark = [
14+
pytest.mark.ci,
15+
pytest.mark.xp_extension("fft"),
16+
pytest.mark.min_version("draft"),
17+
]
18+
19+
20+
fft_shapes_strat = hh.shapes(min_dims=1).filter(lambda s: math.prod(s) > 1)
21+
22+
23+
def assert_fft_dtype(func_name: str, *, in_dtype: DataType, out_dtype: DataType):
24+
if in_dtype == xp.float32:
25+
expected = xp.complex64
26+
else:
27+
assert in_dtype == xp.float64 # sanity check
28+
expected = xp.complex128
29+
ph.assert_dtype(
30+
func_name, in_dtype=in_dtype, out_dtype=out_dtype, expected=expected
31+
)
32+
33+
34+
@given(x=xps.arrays(dtype=xps.floating_dtypes(), shape=fft_shapes_strat))
35+
def test_fft(x):
36+
out = xp.fft.fft(x)
37+
assert_fft_dtype("fft", in_dtype=x.dtype, out_dtype=out.dtype)
38+
ph.assert_shape("fft", out_shape=out.shape, expected=x.shape)

0 commit comments

Comments
 (0)