Skip to content

Feature/tls dfti cache #44

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 24, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
260 changes: 154 additions & 106 deletions mkl_fft/_pydfti.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,50 @@ except ImportError:
from numpy.core._multiarray_tests import internal_overlap

from libc.string cimport memcpy
cimport cpython.pycapsule
from cpython.exc cimport (PyErr_Occurred, PyErr_Clear)
from cpython.mem cimport (PyMem_Malloc, PyMem_Free)

from threading import local as threading_local

# Thread-local storage: attributes set on this object are visible only to the
# thread that set them. _tls_dfti_cache_capsule() uses it to keep a separate
# 'initialized' flag and 'capsule' object for each calling thread.
_tls = threading_local()

# Name given to the PyCapsule that wraps a DftiCache pointer. The same name is
# passed when the capsule is created, validated, and when its pointer is read.
cdef const char *capsule_name = "dfti_cache"

cdef void _capsule_destructor(object caps):
    # Destructor registered on the capsule built in _tls_dfti_cache_capsule().
    # It releases the resources held by the wrapped DftiCache via
    # _free_dfti_cache() and then frees the struct itself.
    #
    # NOTE(review): this function is declared `cdef void` with no exception
    # clause -- confirm how the installed Cython version reports the
    # ValueError raised at the end (it may not propagate to a caller).
    cdef int rc = 0
    cdef DftiCache *cache_ptr = NULL

    if caps is None:
        print("Nothing to destroy")
        return

    # Recover the DftiCache pointer stored in the capsule under capsule_name.
    cache_ptr = <DftiCache *>cpython.pycapsule.PyCapsule_GetPointer(caps, capsule_name)
    rc = _free_dfti_cache(cache_ptr)
    # The struct is freed before the status is inspected, so the memory is
    # released even when _free_dfti_cache() reports a failure.
    PyMem_Free(cache_ptr)

    if rc == 0:
        return
    raise ValueError("Internal Error: Freeing DFTI Cache returned with error = {}".format(rc))


def _tls_dfti_cache_capsule():
    """Return the calling thread's DFTI cache capsule, creating it on first use.

    On the first call in a given thread a ``DftiCache`` struct is allocated
    with ``PyMem_Malloc``, its ``initialized`` and ``hand`` fields are zeroed,
    and it is wrapped in a PyCapsule named ``capsule_name`` whose destructor
    is ``_capsule_destructor``.  The capsule is stored on the thread-local
    ``_tls`` object and returned on every subsequent call from that thread.

    Returns
    -------
    object
        The PyCapsule wrapping this thread's ``DftiCache`` pointer.

    Raises
    ------
    MemoryError
        If the ``DftiCache`` struct cannot be allocated.
    ValueError
        If the object stored in thread-local storage is not a valid capsule
        named ``capsule_name``.
    """
    cdef DftiCache *_cache_struct

    init = getattr(_tls, 'initialized', None)
    if (init is None):
        _cache_struct = <DftiCache *> PyMem_Malloc(sizeof(DftiCache))
        # PyMem_Malloc returns NULL on failure; writing through it would crash.
        if _cache_struct is NULL:
            raise MemoryError("Internal Error: could not allocate DFTI cache")
        # important to initialize: the struct memory is not zeroed by malloc
        _cache_struct.initialized = 0
        _cache_struct.hand = NULL
        try:
            new_capsule = cpython.pycapsule.PyCapsule_New(
                <void *>_cache_struct, capsule_name, &_capsule_destructor)
        except BaseException:
            # No capsule took ownership of the struct, so free it here.
            PyMem_Free(_cache_struct)
            raise
        # Publish the capsule first and mark the thread initialised last, so
        # a failure above never leaves 'initialized' set without a capsule.
        _tls.capsule = new_capsule
        _tls.initialized = True
    capsule = getattr(_tls, 'capsule', None)
    if (not cpython.pycapsule.PyCapsule_IsValid(capsule, capsule_name)):
        raise ValueError("Internal Error: invalid capsule stored in TLS")
    return capsule

from threading import Lock
_lock = Lock()

cdef extern from "Python.h":
ctypedef int size_t

void* PyMem_Malloc(size_t n)
void PyMem_Free(void* buf)

int PyErr_Occurred()
void PyErr_Clear()
long PyInt_AsLong(object ob)
int PyObject_HasAttrString(object, char*)

Expand All @@ -58,32 +90,36 @@ cdef extern from *:
object PyArray_BASE(cnp.ndarray)

cdef extern from "src/mklfft.h":
int cdouble_mkl_fft1d_in(cnp.ndarray, int, int)
int cfloat_mkl_fft1d_in(cnp.ndarray, int, int)
int float_cfloat_mkl_fft1d_out(cnp.ndarray, int, int, cnp.ndarray, int)
int cfloat_cfloat_mkl_fft1d_out(cnp.ndarray, int, int, cnp.ndarray)
int double_cdouble_mkl_fft1d_out(cnp.ndarray, int, int, cnp.ndarray, int)
int cdouble_cdouble_mkl_fft1d_out(cnp.ndarray, int, int, cnp.ndarray)

int cdouble_mkl_ifft1d_in(cnp.ndarray, int, int)
int cfloat_mkl_ifft1d_in(cnp.ndarray, int, int)
int float_cfloat_mkl_ifft1d_out(cnp.ndarray, int, int, cnp.ndarray, int)
int cfloat_cfloat_mkl_ifft1d_out(cnp.ndarray, int, int, cnp.ndarray)
int double_cdouble_mkl_ifft1d_out(cnp.ndarray, int, int, cnp.ndarray, int)
int cdouble_cdouble_mkl_ifft1d_out(cnp.ndarray, int, int, cnp.ndarray)

int double_mkl_rfft_in(cnp.ndarray, int, int)
int double_mkl_irfft_in(cnp.ndarray, int, int)
int float_mkl_rfft_in(cnp.ndarray, int, int)
int float_mkl_irfft_in(cnp.ndarray, int, int)

int double_double_mkl_rfft_out(cnp.ndarray, int, int, cnp.ndarray)
int double_double_mkl_irfft_out(cnp.ndarray, int, int, cnp.ndarray)
int float_float_mkl_rfft_out(cnp.ndarray, int, int, cnp.ndarray)
int float_float_mkl_irfft_out(cnp.ndarray, int, int, cnp.ndarray)

int cdouble_double_mkl_irfft_out(cnp.ndarray, int, int, cnp.ndarray)
int cfloat_float_mkl_irfft_out(cnp.ndarray, int, int, cnp.ndarray)
cdef struct DftiCache:
void * hand
int initialized
int _free_dfti_cache(DftiCache *)
int cdouble_mkl_fft1d_in(cnp.ndarray, int, int, DftiCache*)
int cfloat_mkl_fft1d_in(cnp.ndarray, int, int, DftiCache*)
int float_cfloat_mkl_fft1d_out(cnp.ndarray, int, int, cnp.ndarray, int, DftiCache*)
int cfloat_cfloat_mkl_fft1d_out(cnp.ndarray, int, int, cnp.ndarray, DftiCache*)
int double_cdouble_mkl_fft1d_out(cnp.ndarray, int, int, cnp.ndarray, int, DftiCache*)
int cdouble_cdouble_mkl_fft1d_out(cnp.ndarray, int, int, cnp.ndarray, DftiCache*)

int cdouble_mkl_ifft1d_in(cnp.ndarray, int, int, DftiCache*)
int cfloat_mkl_ifft1d_in(cnp.ndarray, int, int, DftiCache*)
int float_cfloat_mkl_ifft1d_out(cnp.ndarray, int, int, cnp.ndarray, int, DftiCache*)
# Called from the out-of-place path of _fft1d_impl for NPY_CFLOAT input when
# dir_ < 0; arguments are (input array, n, axis, output array, DFTI cache).
int cfloat_cfloat_mkl_ifft1d_out(cnp.ndarray, int, int, cnp.ndarray, DftiCache*)
int double_cdouble_mkl_ifft1d_out(cnp.ndarray, int, int, cnp.ndarray, int, DftiCache*)
int cdouble_cdouble_mkl_ifft1d_out(cnp.ndarray, int, int, cnp.ndarray, DftiCache*)

int double_mkl_rfft_in(cnp.ndarray, int, int, DftiCache*)
int double_mkl_irfft_in(cnp.ndarray, int, int, DftiCache*)
int float_mkl_rfft_in(cnp.ndarray, int, int, DftiCache*)
int float_mkl_irfft_in(cnp.ndarray, int, int, DftiCache*)

int double_double_mkl_rfft_out(cnp.ndarray, int, int, cnp.ndarray, DftiCache*)
int double_double_mkl_irfft_out(cnp.ndarray, int, int, cnp.ndarray, DftiCache*)
int float_float_mkl_rfft_out(cnp.ndarray, int, int, cnp.ndarray, DftiCache*)
int float_float_mkl_irfft_out(cnp.ndarray, int, int, cnp.ndarray, DftiCache*)

int cdouble_double_mkl_irfft_out(cnp.ndarray, int, int, cnp.ndarray, DftiCache*)
int cfloat_float_mkl_irfft_out(cnp.ndarray, int, int, cnp.ndarray, DftiCache*)

int cdouble_cdouble_mkl_fftnd_in(cnp.ndarray)
int cdouble_cdouble_mkl_ifftnd_in(cnp.ndarray)
Expand Down Expand Up @@ -268,6 +304,7 @@ def _fft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1):
cdef int ALL_HARMONICS = 1
cdef char * c_error_msg = NULL
cdef bytes py_error_msg
cdef DftiCache *_cache

x_arr = __process_arguments(x, n, axis, overwrite_arg, direction,
&axis_, &n_, &in_place, &xnd, &dir_, 0)
Expand Down Expand Up @@ -295,19 +332,20 @@ def _fft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1):
in_place = 1

if in_place:
with _lock:
if x_type is cnp.NPY_CDOUBLE:
if dir_ < 0:
status = cdouble_mkl_ifft1d_in(x_arr, n_, <int> axis_)
else:
status = cdouble_mkl_fft1d_in(x_arr, n_, <int> axis_)
elif x_type is cnp.NPY_CFLOAT:
if dir_ < 0:
status = cfloat_mkl_ifft1d_in(x_arr, n_, <int> axis_)
else:
status = cfloat_mkl_fft1d_in(x_arr, n_, <int> axis_)
_cache_capsule = _tls_dfti_cache_capsule()
_cache = <DftiCache *>cpython.pycapsule.PyCapsule_GetPointer(_cache_capsule, capsule_name)
if x_type is cnp.NPY_CDOUBLE:
if dir_ < 0:
status = cdouble_mkl_ifft1d_in(x_arr, n_, <int> axis_, _cache)
else:
status = cdouble_mkl_fft1d_in(x_arr, n_, <int> axis_, _cache)
elif x_type is cnp.NPY_CFLOAT:
if dir_ < 0:
status = cfloat_mkl_ifft1d_in(x_arr, n_, <int> axis_, _cache)
else:
status = 1
status = cfloat_mkl_fft1d_in(x_arr, n_, <int> axis_, _cache)
else:
status = 1

if status:
c_error_msg = mkl_dfti_error(status)
Expand All @@ -327,37 +365,38 @@ def _fft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1):
f_arr = __allocate_result(x_arr, n_, axis_, f_type);

# call out-of-place FFT
with _lock:
if f_type is cnp.NPY_CDOUBLE:
if x_type is cnp.NPY_DOUBLE:
if dir_ < 0:
status = double_cdouble_mkl_ifft1d_out(
x_arr, n_, <int> axis_, f_arr, ALL_HARMONICS)
else:
status = double_cdouble_mkl_fft1d_out(
x_arr, n_, <int> axis_, f_arr, ALL_HARMONICS)
elif x_type is cnp.NPY_CDOUBLE:
if dir_ < 0:
status = cdouble_cdouble_mkl_ifft1d_out(
x_arr, n_, <int> axis_, f_arr)
else:
status = cdouble_cdouble_mkl_fft1d_out(
x_arr, n_, <int> axis_, f_arr)
else:
if x_type is cnp.NPY_FLOAT:
if dir_ < 0:
status = float_cfloat_mkl_ifft1d_out(
x_arr, n_, <int> axis_, f_arr, ALL_HARMONICS)
else:
status = float_cfloat_mkl_fft1d_out(
x_arr, n_, <int> axis_, f_arr, ALL_HARMONICS)
elif x_type is cnp.NPY_CFLOAT:
if dir_ < 0:
status = cfloat_cfloat_mkl_ifft1d_out(
x_arr, n_, <int> axis_, f_arr)
else:
status = cfloat_cfloat_mkl_fft1d_out(
x_arr, n_, <int> axis_, f_arr)
_cache_capsule = _tls_dfti_cache_capsule()
_cache = <DftiCache *>cpython.pycapsule.PyCapsule_GetPointer(_cache_capsule, capsule_name)
if f_type is cnp.NPY_CDOUBLE:
if x_type is cnp.NPY_DOUBLE:
if dir_ < 0:
status = double_cdouble_mkl_ifft1d_out(
x_arr, n_, <int> axis_, f_arr, ALL_HARMONICS, _cache)
else:
status = double_cdouble_mkl_fft1d_out(
x_arr, n_, <int> axis_, f_arr, ALL_HARMONICS, _cache)
elif x_type is cnp.NPY_CDOUBLE:
if dir_ < 0:
status = cdouble_cdouble_mkl_ifft1d_out(
x_arr, n_, <int> axis_, f_arr, _cache)
else:
status = cdouble_cdouble_mkl_fft1d_out(
x_arr, n_, <int> axis_, f_arr, _cache)
else:
if x_type is cnp.NPY_FLOAT:
if dir_ < 0:
status = float_cfloat_mkl_ifft1d_out(
x_arr, n_, <int> axis_, f_arr, ALL_HARMONICS, _cache)
else:
status = float_cfloat_mkl_fft1d_out(
x_arr, n_, <int> axis_, f_arr, ALL_HARMONICS, _cache)
elif x_type is cnp.NPY_CFLOAT:
if dir_ < 0:
status = cfloat_cfloat_mkl_ifft1d_out(
x_arr, n_, <int> axis_, f_arr, _cache)
else:
status = cfloat_cfloat_mkl_fft1d_out(
x_arr, n_, <int> axis_, f_arr, _cache)

if (status):
c_error_msg = mkl_dfti_error(status)
Expand Down Expand Up @@ -388,6 +427,7 @@ def _rrfft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1):
cdef int x_type, status
cdef char * c_error_msg = NULL
cdef bytes py_error_msg
cdef DftiCache *_cache

x_arr = __process_arguments(x, n, axis, overwrite_arg, direction,
&axis_, &n_, &in_place, &xnd, &dir_, 1)
Expand All @@ -413,19 +453,20 @@ def _rrfft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1):
in_place = 1

if in_place:
with _lock:
if x_type is cnp.NPY_DOUBLE:
if dir_ < 0:
status = double_mkl_irfft_in(x_arr, n_, <int> axis_)
else:
status = double_mkl_rfft_in(x_arr, n_, <int> axis_)
elif x_type is cnp.NPY_FLOAT:
if dir_ < 0:
status = float_mkl_irfft_in(x_arr, n_, <int> axis_)
else:
status = float_mkl_rfft_in(x_arr, n_, <int> axis_)
_cache_capsule = _tls_dfti_cache_capsule()
_cache = <DftiCache *>cpython.pycapsule.PyCapsule_GetPointer(_cache_capsule, capsule_name)
if x_type is cnp.NPY_DOUBLE:
if dir_ < 0:
status = double_mkl_irfft_in(x_arr, n_, <int> axis_, _cache)
else:
status = double_mkl_rfft_in(x_arr, n_, <int> axis_, _cache)
elif x_type is cnp.NPY_FLOAT:
if dir_ < 0:
status = float_mkl_irfft_in(x_arr, n_, <int> axis_, _cache)
else:
status = 1
status = float_mkl_rfft_in(x_arr, n_, <int> axis_, _cache)
else:
status = 1

if status:
c_error_msg = mkl_dfti_error(status)
Expand All @@ -443,17 +484,18 @@ def _rrfft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1):
f_arr = __allocate_result(x_arr, n_, axis_, x_type);

# call out-of-place FFT
with _lock:
if x_type is cnp.NPY_DOUBLE:
if dir_ < 0:
status = double_double_mkl_irfft_out(x_arr, n_, <int> axis_, f_arr)
else:
status = double_double_mkl_rfft_out(x_arr, n_, <int> axis_, f_arr)
_cache_capsule = _tls_dfti_cache_capsule()
_cache = <DftiCache *>cpython.pycapsule.PyCapsule_GetPointer(_cache_capsule, capsule_name)
if x_type is cnp.NPY_DOUBLE:
if dir_ < 0:
status = double_double_mkl_irfft_out(x_arr, n_, <int> axis_, f_arr, _cache)
else:
if dir_ < 0:
status = float_float_mkl_irfft_out(x_arr, n_, <int> axis_, f_arr)
else:
status = float_float_mkl_rfft_out(x_arr, n_, <int> axis_, f_arr)
status = double_double_mkl_rfft_out(x_arr, n_, <int> axis_, f_arr, _cache)
else:
if dir_ < 0:
status = float_float_mkl_irfft_out(x_arr, n_, <int> axis_, f_arr, _cache)
else:
status = float_float_mkl_rfft_out(x_arr, n_, <int> axis_, f_arr, _cache)

if (status):
c_error_msg = mkl_dfti_error(status)
Expand All @@ -479,6 +521,7 @@ def _rc_fft1d_impl(x, n=None, axis=-1, overwrite_arg=False):
cdef int direction = 1 # dummy, only used for the sake of arg-processing
cdef char * c_error_msg = NULL
cdef bytes py_error_msg
cdef DftiCache *_cache

x_arr = __process_arguments(x, n, axis, overwrite_arg, direction,
&axis_, &n_, &in_place, &xnd, &dir_, 1)
Expand Down Expand Up @@ -509,11 +552,13 @@ def _rc_fft1d_impl(x, n=None, axis=-1, overwrite_arg=False):

# call out-of-place FFT
if x_type is cnp.NPY_FLOAT:
with _lock:
status = float_cfloat_mkl_fft1d_out(x_arr, n_, <int> axis_, f_arr, HALF_HARMONICS)
_cache_capsule = _tls_dfti_cache_capsule()
_cache = <DftiCache *>cpython.pycapsule.PyCapsule_GetPointer(_cache_capsule, capsule_name)
status = float_cfloat_mkl_fft1d_out(x_arr, n_, <int> axis_, f_arr, HALF_HARMONICS, _cache)
else:
with _lock:
status = double_cdouble_mkl_fft1d_out(x_arr, n_, <int> axis_, f_arr, HALF_HARMONICS)
_cache_capsule = _tls_dfti_cache_capsule()
_cache = <DftiCache *>cpython.pycapsule.PyCapsule_GetPointer(_cache_capsule, capsule_name)
status = double_cdouble_mkl_fft1d_out(x_arr, n_, <int> axis_, f_arr, HALF_HARMONICS, _cache)

if (status):
c_error_msg = mkl_dfti_error(status)
Expand Down Expand Up @@ -553,6 +598,7 @@ def _rc_ifft1d_impl(x, n=None, axis=-1, overwrite_arg=False):
cdef int direction = 1 # dummy, only used for the sake of arg-processing
cdef char * c_error_msg = NULL
cdef bytes py_error_msg
cdef DftiCache *_cache

int_n = _is_integral(n)
# nn gives the number elements along axis of the input that we use
Expand Down Expand Up @@ -591,11 +637,13 @@ def _rc_ifft1d_impl(x, n=None, axis=-1, overwrite_arg=False):

# call out-of-place FFT
if x_type is cnp.NPY_CFLOAT:
with _lock:
status = cfloat_float_mkl_irfft_out(x_arr, n_, <int> axis_, f_arr)
_cache_capsule = _tls_dfti_cache_capsule()
_cache = <DftiCache *>cpython.pycapsule.PyCapsule_GetPointer(_cache_capsule, capsule_name)
status = cfloat_float_mkl_irfft_out(x_arr, n_, <int> axis_, f_arr, _cache)
else:
with _lock:
status = cdouble_double_mkl_irfft_out(x_arr, n_, <int> axis_, f_arr)
_cache_capsule = _tls_dfti_cache_capsule()
_cache = <DftiCache *>cpython.pycapsule.PyCapsule_GetPointer(_cache_capsule, capsule_name)
status = cdouble_double_mkl_irfft_out(x_arr, n_, <int> axis_, f_arr, _cache)

if (status):
c_error_msg = mkl_dfti_error(status)
Expand Down
Loading