|
24 | 24 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
25 | 25 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
26 | 26 |
|
27 |
| -from . import _float_utils |
28 |
| -from . import _pydfti as mkl_fft # pylint: disable=no-name-in-module |
| 27 | +import mkl_fft |
| 28 | + |
| 29 | +from ._float_utils import __upcast_float16_array |
29 | 30 |
|
30 | 31 | __all__ = ["fft", "ifft", "fftn", "ifftn", "fft2", "ifft2", "rfft", "irfft"]
|
31 | 32 |
|
32 | 33 |
|
33 | 34 | def fft(a, n=None, axis=-1, overwrite_x=False):
|
34 |
| - x = _float_utils.__upcast_float16_array(a) |
| 35 | + x = __upcast_float16_array(a) |
35 | 36 | return mkl_fft.fft(x, n=n, axis=axis, overwrite_x=overwrite_x)
|
36 | 37 |
|
37 | 38 |
|
38 | 39 | def ifft(a, n=None, axis=-1, overwrite_x=False):
|
39 |
| - x = _float_utils.__upcast_float16_array(a) |
| 40 | + x = __upcast_float16_array(a) |
40 | 41 | return mkl_fft.ifft(x, n=n, axis=axis, overwrite_x=overwrite_x)
|
41 | 42 |
|
42 | 43 |
|
43 | 44 | def fftn(a, shape=None, axes=None, overwrite_x=False):
|
44 |
| - x = _float_utils.__upcast_float16_array(a) |
45 |
| - return mkl_fft.fftn(x, shape=shape, axes=axes, overwrite_x=overwrite_x) |
| 45 | + x = __upcast_float16_array(a) |
| 46 | + return mkl_fft.fftn(x, s=shape, axes=axes, overwrite_x=overwrite_x) |
46 | 47 |
|
47 | 48 |
|
48 | 49 | def ifftn(a, shape=None, axes=None, overwrite_x=False):
|
49 |
| - x = _float_utils.__upcast_float16_array(a) |
50 |
| - return mkl_fft.ifftn(x, shape=shape, axes=axes, overwrite_x=overwrite_x) |
| 50 | + x = __upcast_float16_array(a) |
| 51 | + return mkl_fft.ifftn(x, s=shape, axes=axes, overwrite_x=overwrite_x) |
51 | 52 |
|
52 | 53 |
|
53 | 54 | def fft2(a, shape=None, axes=(-2, -1), overwrite_x=False):
|
54 |
| - x = _float_utils.__upcast_float16_array(a) |
55 |
| - return mkl_fft.fftn(x, shape=shape, axes=axes, overwrite_x=overwrite_x) |
| 55 | + x = __upcast_float16_array(a) |
| 56 | + return mkl_fft.fftn(x, s=shape, axes=axes, overwrite_x=overwrite_x) |
56 | 57 |
|
57 | 58 |
|
58 | 59 | def ifft2(a, shape=None, axes=(-2, -1), overwrite_x=False):
|
59 |
| - x = _float_utils.__upcast_float16_array(a) |
60 |
| - return mkl_fft.ifftn(x, shape=shape, axes=axes, overwrite_x=overwrite_x) |
| 60 | + x = __upcast_float16_array(a) |
| 61 | + return mkl_fft.ifftn(x, s=shape, axes=axes, overwrite_x=overwrite_x) |
61 | 62 |
|
62 | 63 |
|
63 | 64 | def rfft(a, n=None, axis=-1, overwrite_x=False):
|
64 |
| - x = _float_utils.__upcast_float16_array(a) |
| 65 | + x = __upcast_float16_array(a) |
65 | 66 | return mkl_fft.rfftpack(x, n=n, axis=axis, overwrite_x=overwrite_x)
|
66 | 67 |
|
67 | 68 |
|
68 | 69 | def irfft(a, n=None, axis=-1, overwrite_x=False):
|
69 |
| - x = _float_utils.__upcast_float16_array(a) |
| 70 | + x = __upcast_float16_array(a) |
70 | 71 | return mkl_fft.irfftpack(x, n=n, axis=axis, overwrite_x=overwrite_x)
|
0 commit comments