Skip to content

Commit 66ff0dd

Browse files
added rfft* to scipy.fft backend code
1 parent a44387c commit 66ff0dd

File tree

1 file changed

+111
-11
lines changed

1 file changed

+111
-11
lines changed

mkl_fft/_scipy_fft_backend.py

Lines changed: 111 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#!/usr/bin/env python
2-
# Copyright (c) 2019, Intel Corporation
2+
# Copyright (c) 2019-2020, Intel Corporation
33
#
44
# Redistribution and use in source and binary forms, with or without
55
# modification, are permitted provided that the following conditions are met:
@@ -25,6 +25,7 @@
2525
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2626

2727
from . import _pydfti
28+
from . import _float_utils
2829
import mkl
2930

3031
import scipy.fft as _fft
@@ -37,6 +38,8 @@
3738
get_workers, set_workers
3839
)
3940

41+
from numpy.core import (array, asarray, shape, conjugate, take, sqrt, prod)
42+
4043
__all__ = ['fft', 'ifft', 'fft2', 'ifft2', 'fftn', 'ifftn',
4144
'rfft', 'irfft', 'rfft2', 'irfft2', 'rfftn', 'irfftn',
4245
'hfft', 'ihfft', 'hfft2', 'ihfft2', 'hfftn', 'ihfftn',
@@ -54,6 +57,7 @@ def __ua_function__(method, args, kwargs):
5457
return NotImplemented
5558
return fn(*args, **kwargs)
5659

60+
5761
def _implements(scipy_func):
5862
"""Decorator adds function to the dictionary of implemented UA functions"""
5963
def inner(func):
@@ -70,25 +74,54 @@ def _unitary(norm):
7074
return norm is not None
7175

7276

77+
def _cook_nd_args(a, s=None, axes=None, invreal=0):
78+
if s is None:
79+
shapeless = 1
80+
if axes is None:
81+
s = list(a.shape)
82+
else:
83+
s = take(a.shape, axes)
84+
else:
85+
shapeless = 0
86+
s = list(s)
87+
if axes is None:
88+
axes = list(range(-len(s), 0))
89+
if len(s) != len(axes):
90+
raise ValueError("Shape and axes have different lengths.")
91+
if invreal and shapeless:
92+
s[-1] = (a.shape[axes[-1]] - 1) * 2
93+
return s, axes
94+
95+
96+
def _tot_size(x, axes):
97+
s = x.shape
98+
if axes is None:
99+
return x.size
100+
return prod([s[ai] for ai in axes])
101+
102+
73103
@_implements(_fft.fft)
74-
def fft(x, n=None, axis=-1, norm=None, overwrite_x=False, workers=None):
104+
def fft(a, n=None, axis=-1, norm=None, overwrite_x=False, workers=None):
105+
x = _float_utils.__upcast_float16_array(a)
75106
output = _pydfti.fft(x, n=n, axis=axis, overwrite_x=overwrite_x)
76107
if _unitary(norm):
77108
output *= 1 / sqrt(output.shape[axis])
78109
return output
79110

80111

81112
@_implements(_fft.ifft)
82-
def ifft(x, n=None, axis=-1, norm=None, overwrite_x=False, workers=None):
113+
def ifft(a, n=None, axis=-1, norm=None, overwrite_x=False, workers=None):
114+
x = _float_utils.__upcast_float16_array(a)
83115
output = _pydfti.ifft(x, n=n, axis=axis, overwrite_x=overwrite_x)
84116
if _unitary(norm):
85117
output *= sqrt(output.shape[axis])
86118
return output
87119

88120

89121
@_implements(_fft.fft2)
90-
def fft2(x, s=None, axes=(-2,-1), norm=None, overwrite_x=False, workers=None):
91-
output = _pydfti.fftn(x, s=s, axis=axis, overwrite_x=overwrite_x)
122+
def fft2(a, s=None, axes=(-2,-1), norm=None, overwrite_x=False, workers=None):
123+
x = _float_utils.__upcast_float16_array(a)
124+
output = _pydfti.fftn(x, shape=s, axes=axes, overwrite_x=overwrite_x)
92125
if _unitary(norm):
93126
factor = 1
94127
for axis in axes:
@@ -98,8 +131,9 @@ def fft2(x, s=None, axes=(-2,-1), norm=None, overwrite_x=False, workers=None):
98131

99132

100133
@_implements(_fft.ifft2)
101-
def ifft2(x, s=None, axes=(-2,-1), norm=None, overwrite_x=False, workers=None):
102-
output = _pydfti.ifftn(x, s=s, axis=axis, overwrite_x=overwrite_x)
134+
def ifft2(a, s=None, axes=(-2,-1), norm=None, overwrite_x=False, workers=None):
135+
x = _float_utils.__upcast_float16_array(a)
136+
output = _pydfti.ifftn(x, shape=s, axes=axes, overwrite_x=overwrite_x)
103137
if _unitary(norm):
104138
factor = 1
105139
_axes = range(output.ndim) if axes is None else axes
@@ -110,8 +144,9 @@ def ifft2(x, s=None, axes=(-2,-1), norm=None, overwrite_x=False, workers=None):
110144

111145

112146
@_implements(_fft.fftn)
113-
def fftn(x, s=None, axes=None, norm=None, overwrite_x=False, workers=None):
114-
output = _pydfti.fftn(x, s=s, axis=axis, overwrite_x=overwrite_x)
147+
def fftn(a, s=None, axes=None, norm=None, overwrite_x=False, workers=None):
148+
x = _float_utils.__upcast_float16_array(a)
149+
output = _pydfti.fftn(x, shape=s, axes=axes, overwrite_x=overwrite_x)
115150
if _unitary(norm):
116151
factor = 1
117152
_axes = range(output.ndim) if axes is None else axes
@@ -122,12 +157,77 @@ def fftn(x, s=None, axes=None, norm=None, overwrite_x=False, workers=None):
122157

123158

124159
@_implements(_fft.ifftn)
125-
def ifftn(x, s=None, axes=None, norm=None, overwrite_x=False, workers=None):
126-
output = _pydfti.ifftn(x, s=s, axis=axis, overwrite_x=overwrite_x)
160+
def ifftn(a, s=None, axes=None, norm=None, overwrite_x=False, workers=None):
161+
x = _float_utils.__upcast_float16_array(a)
162+
output = _pydfti.ifftn(x, shape=s, axes=axes, overwrite_x=overwrite_x)
127163
if _unitary(norm):
128164
factor = 1
129165
_axes = range(output.ndim) if axes is None else axes
130166
for axis in _axes:
131167
factor *= sqrt(output.shape[axis])
132168
output *= factor
133169
return output
170+
171+
172+
@_implements(_fft.rfft)
173+
def rfft(a, n=None, axis=-1, norm=None):
174+
x = _float_utils.__upcast_float16_array(a)
175+
unitary = _unitary(norm)
176+
x = _float_utils.__downcast_float128_array(x)
177+
if unitary and n is None:
178+
x = asarray(x)
179+
n = x.shape[axis]
180+
output = _pydfti.rfft_numpy(x, n=n, axis=axis)
181+
if unitary:
182+
output *= 1 / sqrt(n)
183+
return output
184+
185+
186+
@_implements(_fft.irfft)
187+
def irfft(a, n=None, axis=-1, norm=None):
188+
x = _float_utils.__upcast_float16_array(a)
189+
x = _float_utils.__downcast_float128_array(x)
190+
output = _pydfti.irfft_numpy(x, n=n, axis=axis)
191+
if _unitary(norm):
192+
output *= sqrt(output.shape[axis])
193+
return output
194+
195+
196+
@_implements(_fft.rfft2)
197+
def rfft2(a, s=None, axes=(-2, -1), norm=None):
198+
x = _float_utils.__upcast_float16_array(a)
199+
x = _float_utils.__downcast_float128_array(a)
200+
return rfftn(x, s, axes, norm)
201+
202+
203+
@_implements(_fft.irfft2)
204+
def irfft2(a, s=None, axes=(-2, -1), norm=None):
205+
x = _float_utils.__upcast_float16_array(a)
206+
x = _float_utils.__downcast_float128_array(x)
207+
return irfftn(x, s, axes, norm)
208+
209+
210+
@_implements(_fft.rfftn)
211+
def rfftn(a, s=None, axes=None, norm=None):
212+
unitary = _unitary(norm)
213+
x = _float_utils.__upcast_float16_array(a)
214+
x = _float_utils.__downcast_float128_array(x)
215+
if unitary:
216+
x = asarray(x)
217+
s, axes = _cook_nd_args(x, s, axes)
218+
219+
output = _pydfti.rfftn_numpy(x, s, axes)
220+
if unitary:
221+
n_tot = prod(asarray(s, dtype=output.dtype))
222+
output *= 1 / sqrt(n_tot)
223+
return output
224+
225+
226+
@_implements(_fft.irfftn)
227+
def irfftn(a, s=None, axes=None, norm=None):
228+
x = _float_utils.__upcast_float16_array(a)
229+
x = _float_utils.__downcast_float128_array(x)
230+
output = _pydfti.irfftn_numpy(x, s, axes)
231+
if _unitary(norm):
232+
output *= sqrt(_tot_size(output, axes))
233+
return output

0 commit comments

Comments
 (0)