1
1
#!/usr/bin/env python
2
- # Copyright (c) 2019, Intel Corporation
2
+ # Copyright (c) 2019-2020 , Intel Corporation
3
3
#
4
4
# Redistribution and use in source and binary forms, with or without
5
5
# modification, are permitted provided that the following conditions are met:
25
25
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26
26
27
27
from . import _pydfti
28
+ from . import _float_utils
28
29
import mkl
29
30
30
31
import scipy .fft as _fft
37
38
get_workers , set_workers
38
39
)
39
40
41
+ from numpy .core import (array , asarray , shape , conjugate , take , sqrt , prod )
42
+
40
43
__all__ = ['fft' , 'ifft' , 'fft2' , 'ifft2' , 'fftn' , 'ifftn' ,
41
44
'rfft' , 'irfft' , 'rfft2' , 'irfft2' , 'rfftn' , 'irfftn' ,
42
45
'hfft' , 'ihfft' , 'hfft2' , 'ihfft2' , 'hfftn' , 'ihfftn' ,
@@ -54,6 +57,7 @@ def __ua_function__(method, args, kwargs):
54
57
return NotImplemented
55
58
return fn (* args , ** kwargs )
56
59
60
+
57
61
def _implements (scipy_func ):
58
62
"""Decorator adds function to the dictionary of implemented UA functions"""
59
63
def inner (func ):
@@ -70,25 +74,54 @@ def _unitary(norm):
70
74
return norm is not None
71
75
72
76
77
+ def _cook_nd_args (a , s = None , axes = None , invreal = 0 ):
78
+ if s is None :
79
+ shapeless = 1
80
+ if axes is None :
81
+ s = list (a .shape )
82
+ else :
83
+ s = take (a .shape , axes )
84
+ else :
85
+ shapeless = 0
86
+ s = list (s )
87
+ if axes is None :
88
+ axes = list (range (- len (s ), 0 ))
89
+ if len (s ) != len (axes ):
90
+ raise ValueError ("Shape and axes have different lengths." )
91
+ if invreal and shapeless :
92
+ s [- 1 ] = (a .shape [axes [- 1 ]] - 1 ) * 2
93
+ return s , axes
94
+
95
+
96
+ def _tot_size (x , axes ):
97
+ s = x .shape
98
+ if axes is None :
99
+ return x .size
100
+ return prod ([s [ai ] for ai in axes ])
101
+
102
+
73
103
@_implements (_fft .fft )
74
- def fft (x , n = None , axis = - 1 , norm = None , overwrite_x = False , workers = None ):
104
+ def fft (a , n = None , axis = - 1 , norm = None , overwrite_x = False , workers = None ):
105
+ x = _float_utils .__upcast_float16_array (a )
75
106
output = _pydfti .fft (x , n = n , axis = axis , overwrite_x = overwrite_x )
76
107
if _unitary (norm ):
77
108
output *= 1 / sqrt (output .shape [axis ])
78
109
return output
79
110
80
111
81
112
@_implements (_fft .ifft )
82
- def ifft (x , n = None , axis = - 1 , norm = None , overwrite_x = False , workers = None ):
113
+ def ifft (a , n = None , axis = - 1 , norm = None , overwrite_x = False , workers = None ):
114
+ x = _float_utils .__upcast_float16_array (a )
83
115
output = _pydfti .ifft (x , n = n , axis = axis , overwrite_x = overwrite_x )
84
116
if _unitary (norm ):
85
117
output *= sqrt (output .shape [axis ])
86
118
return output
87
119
88
120
89
121
@_implements (_fft .fft2 )
90
- def fft2 (x , s = None , axes = (- 2 ,- 1 ), norm = None , overwrite_x = False , workers = None ):
91
- output = _pydfti .fftn (x , s = s , axis = axis , overwrite_x = overwrite_x )
122
+ def fft2 (a , s = None , axes = (- 2 ,- 1 ), norm = None , overwrite_x = False , workers = None ):
123
+ x = _float_utils .__upcast_float16_array (a )
124
+ output = _pydfti .fftn (x , shape = s , axes = axes , overwrite_x = overwrite_x )
92
125
if _unitary (norm ):
93
126
factor = 1
94
127
for axis in axes :
@@ -98,8 +131,9 @@ def fft2(x, s=None, axes=(-2,-1), norm=None, overwrite_x=False, workers=None):
98
131
99
132
100
133
@_implements (_fft .ifft2 )
101
- def ifft2 (x , s = None , axes = (- 2 ,- 1 ), norm = None , overwrite_x = False , workers = None ):
102
- output = _pydfti .ifftn (x , s = s , axis = axis , overwrite_x = overwrite_x )
134
+ def ifft2 (a , s = None , axes = (- 2 ,- 1 ), norm = None , overwrite_x = False , workers = None ):
135
+ x = _float_utils .__upcast_float16_array (a )
136
+ output = _pydfti .ifftn (x , shape = s , axes = axes , overwrite_x = overwrite_x )
103
137
if _unitary (norm ):
104
138
factor = 1
105
139
_axes = range (output .ndim ) if axes is None else axes
@@ -110,8 +144,9 @@ def ifft2(x, s=None, axes=(-2,-1), norm=None, overwrite_x=False, workers=None):
110
144
111
145
112
146
@_implements (_fft .fftn )
113
- def fftn (x , s = None , axes = None , norm = None , overwrite_x = False , workers = None ):
114
- output = _pydfti .fftn (x , s = s , axis = axis , overwrite_x = overwrite_x )
147
+ def fftn (a , s = None , axes = None , norm = None , overwrite_x = False , workers = None ):
148
+ x = _float_utils .__upcast_float16_array (a )
149
+ output = _pydfti .fftn (x , shape = s , axes = axes , overwrite_x = overwrite_x )
115
150
if _unitary (norm ):
116
151
factor = 1
117
152
_axes = range (output .ndim ) if axes is None else axes
@@ -122,12 +157,77 @@ def fftn(x, s=None, axes=None, norm=None, overwrite_x=False, workers=None):
122
157
123
158
124
159
@_implements (_fft .ifftn )
125
- def ifftn (x , s = None , axes = None , norm = None , overwrite_x = False , workers = None ):
126
- output = _pydfti .ifftn (x , s = s , axis = axis , overwrite_x = overwrite_x )
160
+ def ifftn (a , s = None , axes = None , norm = None , overwrite_x = False , workers = None ):
161
+ x = _float_utils .__upcast_float16_array (a )
162
+ output = _pydfti .ifftn (x , shape = s , axes = axes , overwrite_x = overwrite_x )
127
163
if _unitary (norm ):
128
164
factor = 1
129
165
_axes = range (output .ndim ) if axes is None else axes
130
166
for axis in _axes :
131
167
factor *= sqrt (output .shape [axis ])
132
168
output *= factor
133
169
return output
170
+
171
+
172
+ @_implements (_fft .rfft )
173
+ def rfft (a , n = None , axis = - 1 , norm = None ):
174
+ x = _float_utils .__upcast_float16_array (a )
175
+ unitary = _unitary (norm )
176
+ x = _float_utils .__downcast_float128_array (x )
177
+ if unitary and n is None :
178
+ x = asarray (x )
179
+ n = x .shape [axis ]
180
+ output = _pydfti .rfft_numpy (x , n = n , axis = axis )
181
+ if unitary :
182
+ output *= 1 / sqrt (n )
183
+ return output
184
+
185
+
186
+ @_implements (_fft .irfft )
187
+ def irfft (a , n = None , axis = - 1 , norm = None ):
188
+ x = _float_utils .__upcast_float16_array (a )
189
+ x = _float_utils .__downcast_float128_array (x )
190
+ output = _pydfti .irfft_numpy (x , n = n , axis = axis )
191
+ if _unitary (norm ):
192
+ output *= sqrt (output .shape [axis ])
193
+ return output
194
+
195
+
196
+ @_implements (_fft .rfft2 )
197
+ def rfft2 (a , s = None , axes = (- 2 , - 1 ), norm = None ):
198
+ x = _float_utils .__upcast_float16_array (a )
199
+ x = _float_utils .__downcast_float128_array (a )
200
+ return rfftn (x , s , axes , norm )
201
+
202
+
203
+ @_implements (_fft .irfft2 )
204
+ def irfft2 (a , s = None , axes = (- 2 , - 1 ), norm = None ):
205
+ x = _float_utils .__upcast_float16_array (a )
206
+ x = _float_utils .__downcast_float128_array (x )
207
+ return irfftn (x , s , axes , norm )
208
+
209
+
210
+ @_implements (_fft .rfftn )
211
+ def rfftn (a , s = None , axes = None , norm = None ):
212
+ unitary = _unitary (norm )
213
+ x = _float_utils .__upcast_float16_array (a )
214
+ x = _float_utils .__downcast_float128_array (x )
215
+ if unitary :
216
+ x = asarray (x )
217
+ s , axes = _cook_nd_args (x , s , axes )
218
+
219
+ output = _pydfti .rfftn_numpy (x , s , axes )
220
+ if unitary :
221
+ n_tot = prod (asarray (s , dtype = output .dtype ))
222
+ output *= 1 / sqrt (n_tot )
223
+ return output
224
+
225
+
226
+ @_implements (_fft .irfftn )
227
+ def irfftn (a , s = None , axes = None , norm = None ):
228
+ x = _float_utils .__upcast_float16_array (a )
229
+ x = _float_utils .__downcast_float128_array (x )
230
+ output = _pydfti .irfftn_numpy (x , s , axes )
231
+ if _unitary (norm ):
232
+ output *= sqrt (_tot_size (output , axes ))
233
+ return output
0 commit comments