Skip to content

Commit 56dfe4b

Browse files
committed
get rid of direction
1 parent 8a6858e commit 56dfe4b

File tree

1 file changed

+21
-36
lines changed

1 file changed

+21
-36
lines changed

mkl_fft/_pydfti.pyx

Lines changed: 21 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -224,12 +224,10 @@ cdef cnp.ndarray _process_arguments(
224224
object x,
225225
object n,
226226
object axis,
227-
object direction,
228227
long *axis_,
229228
long *n_,
230229
int *in_place,
231230
int *xnd,
232-
int *dir_,
233231
int realQ,
234232
):
235233
"""
@@ -239,11 +237,6 @@ cdef cnp.ndarray _process_arguments(
239237
cdef long n_max = 0
240238
cdef cnp.ndarray x_arr "xx_arrayObject"
241239

242-
if direction not in [-1, +1]:
243-
raise ValueError("Direction of FFT should +1 or -1")
244-
else:
245-
dir_[0] = -1 if direction is -1 else +1
246-
247240
# convert x to ndarray, ensure that strides are multiples of itemsize
248241
x_arr = PyArray_CheckFromAny(
249242
x, NULL, 0, 0,
@@ -382,18 +375,18 @@ def _c2c_fft1d_impl(x, n=None, axis=-1, direction=+1, double fsc=1.0, out=None):
382375
"""
383376
cdef cnp.ndarray x_arr "x_arrayObject"
384377
cdef cnp.ndarray f_arr "f_arrayObject"
385-
cdef int xnd, n_max = 0, in_place, dir_
378+
cdef int xnd, n_max = 0, in_place
386379
cdef long n_, axis_
387380
cdef int x_type, f_type, status = 0
388381
cdef int ALL_HARMONICS = 1
389382
cdef char * c_error_msg = NULL
390383
cdef bytes py_error_msg
391384
cdef DftiCache *_cache
392385

393-
x_arr = _process_arguments(
394-
x, n, axis, direction, &axis_, &n_, &in_place, &xnd, &dir_, 0
395-
)
386+
if direction not in [-1, +1]:
387+
raise ValueError("Direction of FFT should +1 or -1")
396388

389+
x_arr = _process_arguments(x, n, axis, &axis_, &n_, &in_place, &xnd, 0)
397390
x_type = cnp.PyArray_TYPE(x_arr)
398391

399392
if out is not None:
@@ -429,7 +422,7 @@ def _c2c_fft1d_impl(x, n=None, axis=-1, direction=+1, double fsc=1.0, out=None):
429422
_cache_capsule, capsule_name
430423
)
431424
if x_type is cnp.NPY_CDOUBLE:
432-
if dir_ < 0:
425+
if direction < 0:
433426
status = cdouble_mkl_ifft1d_in(
434427
x_arr, n_, <int> axis_, fsc, _cache
435428
)
@@ -438,7 +431,7 @@ def _c2c_fft1d_impl(x, n=None, axis=-1, direction=+1, double fsc=1.0, out=None):
438431
x_arr, n_, <int> axis_, fsc, _cache
439432
)
440433
elif x_type is cnp.NPY_CFLOAT:
441-
if dir_ < 0:
434+
if direction < 0:
442435
status = cfloat_mkl_ifft1d_in(
443436
x_arr, n_, <int> axis_, fsc, _cache
444437
)
@@ -487,7 +480,7 @@ def _c2c_fft1d_impl(x, n=None, axis=-1, direction=+1, double fsc=1.0, out=None):
487480
)
488481
if f_type is cnp.NPY_CDOUBLE:
489482
if x_type is cnp.NPY_DOUBLE:
490-
if dir_ < 0:
483+
if direction < 0:
491484
status = double_cdouble_mkl_ifft1d_out(
492485
x_arr,
493486
n_,
@@ -508,7 +501,7 @@ def _c2c_fft1d_impl(x, n=None, axis=-1, direction=+1, double fsc=1.0, out=None):
508501
_cache,
509502
)
510503
elif x_type is cnp.NPY_CDOUBLE:
511-
if dir_ < 0:
504+
if direction < 0:
512505
status = cdouble_cdouble_mkl_ifft1d_out(
513506
x_arr, n_, <int> axis_, f_arr, fsc, _cache
514507
)
@@ -518,7 +511,7 @@ def _c2c_fft1d_impl(x, n=None, axis=-1, direction=+1, double fsc=1.0, out=None):
518511
)
519512
else:
520513
if x_type is cnp.NPY_FLOAT:
521-
if dir_ < 0:
514+
if direction < 0:
522515
status = float_cfloat_mkl_ifft1d_out(
523516
x_arr,
524517
n_,
@@ -539,7 +532,7 @@ def _c2c_fft1d_impl(x, n=None, axis=-1, direction=+1, double fsc=1.0, out=None):
539532
_cache,
540533
)
541534
elif x_type is cnp.NPY_CFLOAT:
542-
if dir_ < 0:
535+
if direction < 0:
543536
status = cfloat_cfloat_mkl_ifft1d_out(
544537
x_arr, n_, <int> axis_, f_arr, fsc, _cache
545538
)
@@ -571,18 +564,15 @@ def _r2c_fft1d_impl(
571564
"""
572565
cdef cnp.ndarray x_arr "x_arrayObject"
573566
cdef cnp.ndarray f_arr "f_arrayObject"
574-
cdef int xnd, in_place, dir_
567+
cdef int xnd, in_place
575568
cdef long n_, axis_
576569
cdef int x_type, f_type, status, requirement
577570
cdef int HALF_HARMONICS = 0 # give only positive index harmonics
578-
cdef int direction = 1 # dummy, only used for the sake of arg-processing
579571
cdef char * c_error_msg = NULL
580572
cdef bytes py_error_msg
581573
cdef DftiCache *_cache
582574

583-
x_arr = _process_arguments(
584-
x, n, axis, direction, &axis_, &n_, &in_place, &xnd, &dir_, 1
585-
)
575+
x_arr = _process_arguments(x, n, axis, &axis_, &n_, &in_place, &xnd, 1)
586576

587577
x_type = cnp.PyArray_TYPE(x_arr)
588578

@@ -672,20 +662,17 @@ def _c2r_fft1d_impl(
672662
"""
673663
cdef cnp.ndarray x_arr "x_arrayObject"
674664
cdef cnp.ndarray f_arr "f_arrayObject"
675-
cdef int xnd, in_place, dir_, int_n
665+
cdef int xnd, in_place, int_n
676666
cdef long n_, axis_
677667
cdef int x_type, f_type, status
678-
cdef int direction = 1 # dummy, only used for the sake of arg-processing
679668
cdef char * c_error_msg = NULL
680669
cdef bytes py_error_msg
681670
cdef DftiCache *_cache
682671

683672
int_n = _is_integral(n)
684673
# nn gives the number elements along axis of the input that we use
685674
nn = (n // 2 + 1) if int_n and n > 0 else n
686-
x_arr = _process_arguments(
687-
x, nn, axis, direction, &axis_, &n_, &in_place, &xnd, &dir_, 0
688-
)
675+
x_arr = _process_arguments(x, nn, axis, &axis_, &n_, &in_place, &xnd, 0)
689676
n_ = 2*(n_ - 1)
690677
if int_n and (n % 2 == 1):
691678
n_ += 1
@@ -774,12 +761,10 @@ def _direct_fftnd(
774761
cdef int err
775762
cdef cnp.ndarray x_arr "xxnd_arrayObject"
776763
cdef cnp.ndarray f_arr "ffnd_arrayObject"
777-
cdef int dir_, in_place, x_type, f_type
764+
cdef int in_place, x_type, f_type
778765

779766
if direction not in [-1, +1]:
780767
raise ValueError("Direction of FFT should +1 or -1")
781-
else:
782-
dir_ = -1 if direction is -1 else +1
783768

784769
# convert x to ndarray, ensure that strides are multiples of itemsize
785770
x_arr = PyArray_CheckFromAny(
@@ -824,12 +809,12 @@ def _direct_fftnd(
824809

825810
if in_place:
826811
if x_type == cnp.NPY_CDOUBLE:
827-
if dir_ == 1:
812+
if direction == 1:
828813
err = cdouble_cdouble_mkl_fftnd_in(x_arr, fsc)
829814
else:
830815
err = cdouble_cdouble_mkl_ifftnd_in(x_arr, fsc)
831816
elif x_type == cnp.NPY_CFLOAT:
832-
if dir_ == 1:
817+
if direction == 1:
833818
err = cfloat_cfloat_mkl_fftnd_in(x_arr, fsc)
834819
else:
835820
err = cfloat_cfloat_mkl_ifftnd_in(x_arr, fsc)
@@ -856,22 +841,22 @@ def _direct_fftnd(
856841
f_arr = _allocate_result(x_arr, -1, 0, f_type)
857842

858843
if x_type == cnp.NPY_CDOUBLE:
859-
if dir_ == 1:
844+
if direction == 1:
860845
err = cdouble_cdouble_mkl_fftnd_out(x_arr, f_arr, fsc)
861846
else:
862847
err = cdouble_cdouble_mkl_ifftnd_out(x_arr, f_arr, fsc)
863848
elif x_type == cnp.NPY_CFLOAT:
864-
if dir_ == 1:
849+
if direction == 1:
865850
err = cfloat_cfloat_mkl_fftnd_out(x_arr, f_arr, fsc)
866851
else:
867852
err = cfloat_cfloat_mkl_ifftnd_out(x_arr, f_arr, fsc)
868853
elif x_type == cnp.NPY_DOUBLE:
869-
if dir_ == 1:
854+
if direction == 1:
870855
err = double_cdouble_mkl_fftnd_out(x_arr, f_arr, fsc)
871856
else:
872857
err = double_cdouble_mkl_ifftnd_out(x_arr, f_arr, fsc)
873858
elif x_type == cnp.NPY_FLOAT:
874-
if dir_ == 1:
859+
if direction == 1:
875860
err = float_cfloat_mkl_fftnd_out(x_arr, f_arr, fsc)
876861
else:
877862
err = float_cfloat_mkl_ifftnd_out(x_arr, f_arr, fsc)

0 commit comments

Comments
 (0)