40
40
41
41
from numpy .core import (array , asarray , shape , conjugate , take , sqrt , prod )
42
42
43
+ _max_threads_count = mkl .get_max_threads ()
44
+
45
+
43
46
__all__ = ['fft' , 'ifft' , 'fft2' , 'ifft2' , 'fftn' , 'ifftn' ,
44
47
'rfft' , 'irfft' , 'rfft2' , 'irfft2' , 'rfftn' , 'irfftn' ,
45
48
'hfft' , 'ihfft' , 'hfft2' , 'ihfft2' , 'hfftn' , 'ihfftn' ,
@@ -101,9 +104,20 @@ def _tot_size(x, axes):
101
104
102
105
103
106
def _workers_to_num_threads (w ):
107
+ """Handle conversion of workers to a positive number of threads in the
108
+ same way as scipy.fft.helpers._workers.
109
+ """
104
110
if w is None :
105
- return mkl .domain_get_max_threads (domain = 'fft' )
106
- return int (w )
111
+ return get_workers ()
112
+ _w = int (w )
113
+ if (_w == 0 ):
114
+ raise ValueError ("Number of workers must be nonzero" )
115
+ if (_w < 0 ):
116
+ _w += _max_threads_count + 1
117
+ if _w <= 0 :
118
+ raise ValueError ("workers value out of range; got {}, must not be"
119
+ " less than {}" .format (w , - _max_threads_count ))
120
+ return _w
107
121
108
122
109
123
class Workers :
@@ -119,8 +133,7 @@ def __enter__(self):
119
133
120
134
def __exit__ (self , * args ):
121
135
# restore default
122
- max_num_threads = mkl .domain_get_max_threads (domain = 'fft' )
123
- mkl .domain_set_num_threads (max_num_threads , domain = 'fft' )
136
+ mkl .domain_set_num_threads (_max_threads_count , domain = 'fft' )
124
137
125
138
126
139
@_implements (_fft .fft )
0 commit comments