Skip to content

Commit b316659

Browse files
initial version of fft backend for scipy 1.4
1 parent 3874305 commit b316659

File tree

2 files changed

+138
-3
lines changed

2 files changed

+138
-3
lines changed

conda-recipe/meta.yaml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
{% set version = "1.0dev" %}
1+
{% set version = "1.1.0" %}
22
{% set buildnumber = 0 %}
33

44

@@ -19,6 +19,7 @@ build:
1919
- {{ SP_DIR.replace('\\', '/') if win else SP_DIR }}/mkl_fft/_pydfti.*
2020
- {{ SP_DIR.replace('\\', '/') if win else SP_DIR }}/mkl_fft/_numpy_fft.py
2121
- {{ SP_DIR.replace('\\', '/') if win else SP_DIR }}/mkl_fft/_scipy_fft.py
22+
- {{ SP_DIR.replace('\\', '/') if win else SP_DIR }}/mkl_fft/_scipy_fft_backend.py
2223
- {{ SP_DIR.replace('\\', '/') if win else SP_DIR }}/mkl_fft/setup.py
2324
- {{ SP_DIR.replace('\\', '/') if win else SP_DIR }}/mkl_fft/tests/test_fft1d.py
2425
- {{ SP_DIR.replace('\\', '/') if win else SP_DIR }}/mkl_fft/__init__.pyc [py27]
@@ -36,12 +37,13 @@ requirements:
3637
- python
3738
- setuptools
3839
- intelpython
39-
- mkl-devel [not nomkl]
40+
- mkl-devel
4041
- cython
4142
- numpy x.x
4243
run:
4344
- python
44-
- mkl [not nomkl]
45+
- mkl
46+
- mkl-service
4547
- intelpython
4648
- numpy x.x
4749

mkl_fft/_scipy_fft_backend.py

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
#!/usr/bin/env python
2+
# Copyright (c) 2019, Intel Corporation
3+
#
4+
# Redistribution and use in source and binary forms, with or without
5+
# modification, are permitted provided that the following conditions are met:
6+
#
7+
# * Redistributions of source code must retain the above copyright notice,
8+
# this list of conditions and the following disclaimer.
9+
# * Redistributions in binary form must reproduce the above copyright
10+
# notice, this list of conditions and the following disclaimer in the
11+
# documentation and/or other materials provided with the distribution.
12+
# * Neither the name of Intel Corporation nor the names of its contributors
13+
# may be used to endorse or promote products derived from this software
14+
# without specific prior written permission.
15+
#
16+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
20+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
22+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
23+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
24+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26+
27+
from . import _pydfti
28+
import mkl
29+
30+
import scipy.fft as _fft
31+
32+
# Complete the namespace (these are not actually used in this module)
33+
from scipy.fft import (
34+
dct, idct, dst, idst, dctn, idctn, dstn, idstn,
35+
hfft2, ihfft2, hfftn, ihfftn,
36+
fftshift, ifftshift, fftfreq, rfftfreq,
37+
get_workers, set_workers
38+
)
39+
40+
__all__ = ['fft', 'ifft', 'fft2', 'ifft2', 'fftn', 'ifftn',
41+
'rfft', 'irfft', 'rfft2', 'irfft2', 'rfftn', 'irfftn',
42+
'hfft', 'ihfft', 'hfft2', 'ihfft2', 'hfftn', 'ihfftn',
43+
'dct', 'idct', 'dst', 'idst', 'dctn', 'idctn', 'dstn', 'idstn',
44+
'fftshift', 'ifftshift', 'fftfreq', 'rfftfreq', 'get_workers',
45+
'set_workers', 'next_fast_len']
46+
47+
__ua_domain__ = 'numpy.scipy.fft'
48+
__implemented = dict()
49+
50+
def __ua_function__(method, args, kwargs):
51+
"""Fetch registered UA function."""
52+
fn = _implemented.get(method, None)
53+
if fn is None:
54+
return NotImplemented
55+
return fn(*args, **kwargs)
56+
57+
def _implements(scipy_func):
58+
"""Decorator adds function to the dictionary of implemented UA functions"""
59+
def inner(func):
60+
_implemented[scipy_func] = func
61+
return func
62+
63+
return inner
64+
65+
66+
def _unitary(norm):
67+
if norm not in (None, "ortho"):
68+
raise ValueError("Invalid norm value %s, should be None or \"ortho\"."
69+
% norm)
70+
return norm is not None
71+
72+
73+
@_implements(_fft.fft)
74+
def fft(x, n=None, axis=-1, norm=None, overwrite_x=False, workers=None):
75+
output = _pydfti.fft(x, n=n, axis=axis, overwrite_x=overwrite_x)
76+
if _unitary(norm):
77+
output *= 1 / sqrt(output.shape[axis])
78+
return output
79+
80+
81+
@_implements(_fft.ifft)
82+
def ifft(x, n=None, axis=-1, norm=None, overwrite_x=False, workers=None):
83+
output = _pydfti.ifft(x, n=n, axis=axis, overwrite_x=overwrite_x)
84+
if _unitary(norm):
85+
output *= sqrt(output.shape[axis])
86+
return output
87+
88+
89+
@_implements(_fft.fft2)
90+
def fft2(x, s=None, axes=(-2,-1), norm=None, overwrite_x=False, workers=None):
91+
output = _pydfti.fftn(x, s=s, axis=axis, overwrite_x=overwrite_x)
92+
if _unitary(norm):
93+
factor = 1
94+
for axis in axes:
95+
factor *= 1 / sqrt(output.shape[axis])
96+
output *= factor
97+
return output
98+
99+
100+
@_implements(_fft.ifft2)
101+
def ifft2(x, s=None, axes=(-2,-1), norm=None, overwrite_x=False, workers=None):
102+
output = _pydfti.ifftn(x, s=s, axis=axis, overwrite_x=overwrite_x)
103+
if _unitary(norm):
104+
factor = 1
105+
_axes = range(output.ndim) if axes is None else axes
106+
for axis in _axes:
107+
factor *= sqrt(output.shape[axis])
108+
output *= factor
109+
return output
110+
111+
112+
@_implements(_fft.fftn)
113+
def fftn(x, s=None, axes=None, norm=None, overwrite_x=False, workers=None):
114+
output = _pydfti.fftn(x, s=s, axis=axis, overwrite_x=overwrite_x)
115+
if _unitary(norm):
116+
factor = 1
117+
_axes = range(output.ndim) if axes is None else axes
118+
for axis in _axes:
119+
factor *= 1 / sqrt(output.shape[axis])
120+
output *= factor
121+
return output
122+
123+
124+
@_implements(_fft.ifftn)
125+
def ifftn(x, s=None, axes=None, norm=None, overwrite_x=False, workers=None):
126+
output = _pydfti.ifftn(x, s=s, axis=axis, overwrite_x=overwrite_x)
127+
if _unitary(norm):
128+
factor = 1
129+
_axes = range(output.ndim) if axes is None else axes
130+
for axis in _axes:
131+
factor *= sqrt(output.shape[axis])
132+
output *= factor
133+
return output

0 commit comments

Comments
 (0)