Skip to content

Commit 933b8ed

Browse files
authored
conditional import scipy_fft (#195)
1 parent 6524960 commit 933b8ed

File tree

9 files changed

+56
-23
lines changed

9 files changed

+56
-23
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,5 @@ __pycache__/
77

88
mkl_fft/_pydfti.c
99
mkl_fft/_pydfti.cpython*.so
10+
mkl_fft/_pydfti.*-win_amd64.pyd
1011
mkl_fft/src/mklfft.c

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ To build `mkl_fft` from sources on Linux with Intel® OneMKL:
9393
- `git clone https://github.com/IntelPython/mkl_fft.git mkl_fft`
9494
- `cd mkl_fft`
9595
- `python -m pip install .`
96+
- `pip install scipy` (optional: for using `mkl_fft.interface.scipy_fft` module)
9697
- `cd ..`
9798
- `python -c "import mkl_fft"`
9899

@@ -103,5 +104,6 @@ To build `mkl_fft` from sources on Linux with conda follow these steps:
103104
- `git clone https://github.com/IntelPython/mkl_fft.git mkl_fft`
104105
- `cd mkl_fft`
105106
- `python -m pip install .`
107+
- `conda install scipy` (optional: for using `mkl_fft.interface.scipy_fft` module)
106108
- `cd ..`
107109
- `python -c "import mkl_fft"`

conda-recipe-cf/meta.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ test:
3232
- pytest -v --pyargs mkl_fft
3333
requires:
3434
- pytest
35-
- scipy
35+
- scipy >=1.10
3636
imports:
3737
- mkl_fft
3838
- mkl_fft.interfaces

conda-recipe/meta.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ test:
3333
- pytest -v --pyargs mkl_fft
3434
requires:
3535
- pytest
36-
- scipy
36+
- scipy >=1.10
3737
imports:
3838
- mkl_fft
3939
- mkl_fft.interfaces

mkl_fft/interfaces/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,11 @@
2323
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
2424
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2525

26-
from . import numpy_fft, scipy_fft
26+
from . import numpy_fft
27+
28+
try:
29+
import scipy.fft
30+
except ImportError:
31+
pass
32+
else:
33+
from . import scipy_fft

mkl_fft/tests/test_interfaces.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,26 @@
2929

3030
import mkl_fft.interfaces as mfi
3131

32+
try:
33+
scipy_fft = mfi.scipy_fft
34+
except AttributeError:
35+
scipy_fft = None
36+
37+
interfaces = []
38+
ids = []
39+
if scipy_fft is not None:
40+
interfaces.append(scipy_fft)
41+
ids.append("scipy")
42+
interfaces.append(mfi.numpy_fft)
43+
ids.append("numpy")
44+
3245

3346
@pytest.mark.parametrize("norm", [None, "forward", "backward", "ortho"])
3447
@pytest.mark.parametrize(
3548
"dtype", [np.float32, np.float64, np.complex64, np.complex128]
3649
)
3750
def test_scipy_fft(norm, dtype):
51+
pytest.importorskip("scipy", reason="requires scipy")
3852
x = np.ones(511, dtype=dtype)
3953
w = mfi.scipy_fft.fft(x, norm=norm, workers=None, plan=None)
4054
xx = mfi.scipy_fft.ifft(w, norm=norm, workers=None, plan=None)
@@ -57,6 +71,7 @@ def test_numpy_fft(norm, dtype):
5771
@pytest.mark.parametrize("norm", [None, "forward", "backward", "ortho"])
5872
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
5973
def test_scipy_rfft(norm, dtype):
74+
pytest.importorskip("scipy", reason="requires scipy")
6075
x = np.ones(511, dtype=dtype)
6176
w = mfi.scipy_fft.rfft(x, norm=norm, workers=None, plan=None)
6277
xx = mfi.scipy_fft.irfft(
@@ -87,6 +102,7 @@ def test_numpy_rfft(norm, dtype):
87102
"dtype", [np.float32, np.float64, np.complex64, np.complex128]
88103
)
89104
def test_scipy_fftn(norm, dtype):
105+
pytest.importorskip("scipy", reason="requires scipy")
90106
x = np.ones((37, 83), dtype=dtype)
91107
w = mfi.scipy_fft.fftn(x, norm=norm, workers=None, plan=None)
92108
xx = mfi.scipy_fft.ifftn(w, norm=norm, workers=None, plan=None)
@@ -109,6 +125,7 @@ def test_numpy_fftn(norm, dtype):
109125
@pytest.mark.parametrize("norm", [None, "forward", "backward", "ortho"])
110126
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
111127
def test_scipy_rfftn(norm, dtype):
128+
pytest.importorskip("scipy", reason="requires scipy")
112129
x = np.ones((37, 83), dtype=dtype)
113130
w = mfi.scipy_fft.rfftn(x, norm=norm, workers=None, plan=None)
114131
xx = mfi.scipy_fft.irfftn(w, s=x.shape, norm=norm, workers=None, plan=None)
@@ -143,32 +160,30 @@ def _get_blacklisted_dtypes():
143160

144161
@pytest.mark.parametrize("dtype", _get_blacklisted_dtypes())
145162
def test_scipy_no_support_for(dtype):
163+
pytest.importorskip("scipy", reason="requires scipy")
146164
x = np.ones(16, dtype=dtype)
147165
assert_raises(NotImplementedError, mfi.scipy_fft.ifft, x)
148166

149167

150168
def test_scipy_fft_arg_validate():
169+
pytest.importorskip("scipy", reason="requires scipy")
151170
with pytest.raises(ValueError):
152171
mfi.scipy_fft.fft([1, 2, 3, 4], norm=b"invalid")
153172

154173
with pytest.raises(NotImplementedError):
155174
mfi.scipy_fft.fft([1, 2, 3, 4], plan="magic")
156175

157176

158-
@pytest.mark.parametrize(
159-
"func", [mfi.scipy_fft.rfft2, mfi.numpy_fft.rfft2], ids=["scipy", "numpy"]
160-
)
161-
def test_axes(func):
177+
@pytest.mark.parametrize("interface", interfaces, ids=ids)
178+
def test_axes(interface):
162179
x = np.arange(24.0).reshape(2, 3, 4)
163-
res = func(x, axes=(1, 2))
180+
res = interface.rfft2(x, axes=(1, 2))
164181
exp = np.fft.rfft2(x, axes=(1, 2))
165182
tol = 64 * np.finfo(np.float64).eps
166183
assert np.allclose(res, exp, atol=tol, rtol=tol)
167184

168185

169-
@pytest.mark.parametrize(
170-
"interface", [mfi.scipy_fft, mfi.numpy_fft], ids=["scipy", "numpy"]
171-
)
186+
@pytest.mark.parametrize("interface", interfaces, ids=ids)
172187
@pytest.mark.parametrize(
173188
"func", ["fftshift", "ifftshift", "fftfreq", "rfftfreq"]
174189
)

mkl_fft/tests/third_party/scipy/test_basic.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,22 +7,26 @@
77

88
import numpy as np
99
import pytest
10-
import scipy
1110
from numpy.random import random
1211
from numpy.testing import assert_allclose, assert_array_almost_equal
1312
from pytest import raises as assert_raises
1413

1514
# pylint: disable=possibly-used-before-assignment
16-
if scipy.__version__ < "1.12":
17-
# scipy from Intel channel is 1.10 with python 3.9 and 3.10
18-
pytest.skip("This test file needs scipy>=1.12", allow_module_level=True)
19-
elif scipy.__version__ < "1.14":
20-
# For python-3.11 and 3.12, scipy<1.14 is installed from Intel channel
21-
# For python<=3.9, scipy<1.14 is installed from conda channel
22-
# pylint: disable=no-name-in-module
23-
from scipy._lib._array_api import size as xp_size
15+
try:
16+
import scipy
17+
except ImportError:
18+
pytest.skip("This test file needs scipy", allow_module_level=True)
2419
else:
25-
from scipy._lib._array_api import xp_size
20+
if np.lib.NumpyVersion(scipy.__version__) < "1.12.0":
21+
# scipy from Intel channel is 1.10 with python 3.9 and 3.10
22+
pytest.skip("This test file needs scipy>=1.12", allow_module_level=True)
23+
elif np.lib.NumpyVersion(scipy.__version__) < "1.14.0":
24+
# For python-3.11 and 3.12, scipy<1.14 is installed from Intel channel
25+
# For python<=3.9, scipy<1.14 is installed from conda channel
26+
# pylint: disable=no-name-in-module
27+
from scipy._lib._array_api import size as xp_size
28+
else:
29+
from scipy._lib._array_api import xp_size
2630

2731
from scipy._lib._array_api import is_numpy, xp_assert_close, xp_assert_equal
2832

mkl_fft/tests/third_party/scipy/test_multithreading.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@
88
import pytest
99
from numpy.testing import assert_allclose
1010

11-
import mkl_fft.interfaces.scipy_fft as fft
11+
try:
12+
import mkl_fft.interfaces.scipy_fft as fft
13+
except ImportError:
14+
pytest.skip("This test file needs scipy", allow_module_level=True)
1215

1316

1417
@pytest.fixture(scope="module")

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@ readme = {file = "README.md", content-type = "text/markdown"}
5959
requires-python = ">=3.9,<3.13"
6060

6161
[project.optional-dependencies]
62-
test = ["pytest", "scipy"]
62+
scipy_interface = ["scipy>=1.10"]
63+
test = ["pytest", "scipy>=1.10"]
6364

6465
[project.urls]
6566
Download = "http://github.com/IntelPython/mkl_fft"

0 commit comments

Comments
 (0)