From 83b658131d2cce5d63328c69bb0fa88b3eb43a3a Mon Sep 17 00:00:00 2001 From: Vahid Tavanashad Date: Fri, 25 Apr 2025 12:50:43 -0700 Subject: [PATCH 1/3] update handling shape and axes of scipy interface --- .github/workflows/conda-package.yml | 8 +- CHANGELOG.md | 3 +- mkl_fft/_pydfti.pyx | 8 +- mkl_fft/interfaces/_scipy_fft.py | 156 +++++++++++++--------------- 4 files changed, 83 insertions(+), 92 deletions(-) diff --git a/.github/workflows/conda-package.yml b/.github/workflows/conda-package.yml index 86afbe76..f3bc08d2 100644 --- a/.github/workflows/conda-package.yml +++ b/.github/workflows/conda-package.yml @@ -131,7 +131,8 @@ jobs: - name: Install mkl_fft run: | CHANNELS="-c $GITHUB_WORKSPACE/channel ${{ env.CHANNELS }}" - conda create -n ${{ env.TEST_ENV_NAME }} python=${{ matrix.python }} $PACKAGE_NAME pytest scipy $CHANNELS + conda create -n ${{ env.TEST_ENV_NAME }} python=${{ matrix.python }} "scipy>=1.10" $CHANNELS + conda install -n ${{ env.TEST_ENV_NAME }} $PACKAGE_NAME pytest $CHANNELS # Test installed packages conda list -n ${{ env.TEST_ENV_NAME }} @@ -295,8 +296,8 @@ jobs: FOR /F "tokens=* USEBACKQ" %%F IN (`python -c "%SCRIPT%"`) DO ( SET PACKAGE_VERSION=%%F ) - SET "TEST_DEPENDENCIES=pytest pytest-cov" - conda install -n ${{ env.TEST_ENV_NAME }} ${{ env.PACKAGE_NAME }}=%PACKAGE_VERSION% %TEST_DEPENDENCIES% python=${{ matrix.python }} scipy -c ${{ env.workdir }}/channel ${{ env.CHANNELS }} + SET "TEST_DEPENDENCIES=pytest scipy" + conda install -n ${{ env.TEST_ENV_NAME }} ${{ env.PACKAGE_NAME }}=%PACKAGE_VERSION% %TEST_DEPENDENCIES% python=${{ matrix.python }} -c ${{ env.workdir }}/channel ${{ env.CHANNELS }} - name: Report content of test environment shell: cmd /C CALL {0} @@ -304,6 +305,7 @@ jobs: echo "Value of CONDA environment variable was: " %CONDA% echo "Value of CONDA_PREFIX environment variable was: " %CONDA_PREFIX% conda info && conda list -n ${{ env.TEST_ENV_NAME }} + - name: Run tests shell: cmd /C CALL {0} run: >- diff --git a/CHANGELOG.md b/CHANGELOG.md index ebb5d30e..2d5cc3ec 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,7 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * Added support for `out` kwarg to all FFT functions in `mkl_fft` and `mkl_fft.interfaces.numpy_fft` [gh-157](https://github.com/IntelPython/mkl_fft/pull/157) ### Changed -* NumPy interface `mkl_fft.interfaces.numpy_fft` is aligned with numpy-2.* [gh-139](https://github.com/IntelPython/mkl_fft/pull/139), [gh-157](https://github.com/IntelPython/mkl_fft/pull/157) +* NumPy interface `mkl_fft.interfaces.numpy_fft` is aligned with numpy-2.x.x [gh-139](https://github.com/IntelPython/mkl_fft/pull/139), [gh-157](https://github.com/IntelPython/mkl_fft/pull/157) +* SciPy interface `mkl_fft.interfaces.scipy_fft` uses the same function from SciPy for handling `s` and `axes` for N-D FFTs[gh-181](https://github.com/IntelPython/mkl_fft/pull/181) ## [1.3.14] (04/10/2025) diff --git a/mkl_fft/_pydfti.pyx b/mkl_fft/_pydfti.pyx index 74f1d69c..e6d1d221 100644 --- a/mkl_fft/_pydfti.pyx +++ b/mkl_fft/_pydfti.pyx @@ -89,7 +89,7 @@ def _tls_dfti_cache_capsule(): cdef extern from "Python.h": ctypedef int size_t - long PyInt_AsLong(object ob) + long PyLong_AsLong(object ob) int PyObject_HasAttrString(object, char*) @@ -262,7 +262,7 @@ cdef cnp.ndarray _process_arguments( xnd[0] = cnp.PyArray_NDIM(x_arr) # tensor-rank of the array err = 0 - axis_[0] = PyInt_AsLong(axis) + axis_[0] = PyLong_AsLong(axis) if (axis_[0] == -1 and PyErr_Occurred()): PyErr_Clear() err = 1 @@ -278,7 +278,7 @@ cdef cnp.ndarray _process_arguments( n_[0] = x_arr.shape[axis_[0]] else: try: - n_[0] = PyInt_AsLong(n) + n_[0] = PyLong_AsLong(n) except: err = 1 @@ -334,7 +334,7 @@ cdef int _is_integral(object num): if num is None: return 0 try: - n = PyInt_AsLong(num) + n = PyLong_AsLong(num) _integral = 1 if n > 0 else 0 except: _integral = 0 diff --git a/mkl_fft/interfaces/_scipy_fft.py b/mkl_fft/interfaces/_scipy_fft.py index e310da1d..8443eaca 100644 --- a/mkl_fft/interfaces/_scipy_fft.py +++ b/mkl_fft/interfaces/_scipy_fft.py @@ -33,6 +33,7 @@ import contextvars import operator import os +from numbers import Number import mkl import numpy as np @@ -194,30 +195,65 @@ def _check_plan(plan): ) -def _check_overwrite_x(overwrite_x): - if overwrite_x: - raise NotImplementedError( - "Overwriting the content of `x` is currently not supported" - ) +# copied from scipy.fft._pocketfft.helper +# https://github.com/scipy/scipy/blob/main/scipy/fft/_pocketfft/helper.py +def _iterable_of_int(x, name=None): + if isinstance(x, Number): + x = (x,) + try: + x = [operator.index(a) for a in x] + except TypeError as e: + name = name or "value" + raise ValueError( + f"{name} must be a scalar or iterable of integers" + ) from e -def _cook_nd_args(x, s=None, axes=None, invreal=False): - if s is None: - shapeless = True - if axes is None: - s = list(x.shape) - else: - s = np.take(x.shape, axes) + return x + + +# copied and modified from scipy.fft._pocketfft.helper +# https://github.com/scipy/scipy/blob/main/scipy/fft/_pocketfft/helper.py +def _init_nd_shape_and_axes(x, shape, axes, invreal=False): + noshape = shape is None + noaxes = axes is None + + if not noaxes: + axes = _iterable_of_int(axes, "axes") + axes = [a + x.ndim if a < 0 else a for a in axes] + + if any(a >= x.ndim or a < 0 for a in axes): + raise ValueError("axes exceeds dimensionality of input") + if len(set(axes)) != len(axes): + raise ValueError("all axes must be unique") + + if not noshape: + shape = _iterable_of_int(shape, "shape") + + if axes and len(axes) != len(shape): + raise ValueError( + "when given, axes and shape arguments" + " have to be of the same length" + ) + if noaxes: + if len(shape) > x.ndim: + raise ValueError("shape requires more axes than are present") + axes = range(x.ndim - len(shape), x.ndim) + + shape = [x.shape[a] if s == -1 else s for s, a in zip(shape, axes)] + elif noaxes: + shape = list(x.shape) + axes = range(x.ndim) else: - shapeless = False - s = list(s) - if axes is None: - axes = list(range(-len(s), 0)) - if len(s) != len(axes): - raise ValueError("Shape and axes have different lengths.") - if invreal and shapeless: - s[-1] = (x.shape[axes[-1]] - 1) * 2 - return s, axes + shape = [x.shape[a] for a in axes] + + if noshape and invreal: + shape[-1] = (x.shape[axes[-1]] - 1) * 2 + + if any(s < 1 for s in shape): + raise ValueError(f"invalid number of data points ({shape}) specified") + + return tuple(shape), list(axes) def _validate_input(x): @@ -339,7 +375,7 @@ def fftn( """ _check_plan(plan) x = _validate_input(x) - s, axes = _cook_nd_args(x, s, axes) + s, axes = _init_nd_shape_and_axes(x, s, axes) fsc = _compute_fwd_scale(norm, s, x.shape) with _Workers(workers): @@ -366,7 +402,7 @@ def ifftn( """ _check_plan(plan) x = _validate_input(x) - s, axes = _cook_nd_args(x, s, axes) + s, axes = _init_nd_shape_and_axes(x, s, axes) fsc = _compute_fwd_scale(norm, s, x.shape) with _Workers(workers): @@ -383,17 +419,13 @@ def rfft( For full documentation refer to `scipy.fft.rfft`. - Limitation - ----------- - The kwarg `overwrite_x` is only supported with its default value. - """ _check_plan(plan) - _check_overwrite_x(overwrite_x) x = _validate_input(x) fsc = _compute_fwd_scale(norm, n, x.shape[axis]) with _Workers(workers): + # Note: overwrite_x is not utilized return mkl_fft.rfft(x, n=n, axis=axis, fwd_scale=fsc) @@ -405,17 +437,13 @@ def irfft( For full documentation refer to `scipy.fft.irfft`. - Limitation - ----------- - The kwarg `overwrite_x` is only supported with its default value. - """ _check_plan(plan) - _check_overwrite_x(overwrite_x) x = _validate_input(x) fsc = _compute_fwd_scale(norm, n, 2 * (x.shape[axis] - 1)) with _Workers(workers): + # Note: overwrite_x is not utilized return mkl_fft.irfft(x, n=n, axis=axis, fwd_scale=fsc) @@ -434,10 +462,6 @@ def rfft2( For full documentation refer to `scipy.fft.rfft2`. - Limitation - ----------- - The kwarg `overwrite_x` is only supported with its default value. - """ return rfftn( x, @@ -465,10 +489,6 @@ def irfft2( For full documentation refer to `scipy.fft.irfft2`. - Limitation - ----------- - The kwarg `overwrite_x` is only supported with its default value. - """ return irfftn( x, @@ -496,18 +516,14 @@ def rfftn( For full documentation refer to `scipy.fft.rfftn`. - Limitation - ----------- - The kwarg `overwrite_x` is only supported with its default value. - """ _check_plan(plan) - _check_overwrite_x(overwrite_x) x = _validate_input(x) - s, axes = _cook_nd_args(x, s, axes) + s, axes = _init_nd_shape_and_axes(x, s, axes) fsc = _compute_fwd_scale(norm, s, x.shape) with _Workers(workers): + # Note: overwrite_x is not utilized return mkl_fft.rfftn(x, s, axes, fwd_scale=fsc) @@ -526,18 +542,14 @@ def irfftn( For full documentation refer to `scipy.fft.irfftn`. - Limitation - ----------- - The kwarg `overwrite_x` is only supported with its default value. - """ _check_plan(plan) - _check_overwrite_x(overwrite_x) x = _validate_input(x) - s, axes = _cook_nd_args(x, s, axes, invreal=True) + s, axes = _init_nd_shape_and_axes(x, s, axes, invreal=True) fsc = _compute_fwd_scale(norm, s, x.shape) with _Workers(workers): + # Note: overwrite_x is not utilized return mkl_fft.irfftn(x, s, axes, fwd_scale=fsc) @@ -548,15 +560,10 @@ def hfft( Compute the FFT of a signal that has Hermitian symmetry, i.e., a real spectrum. - For full documentation refer to `scipy.fft.hfft`. - - Limitation - ----------- - The kwarg `overwrite_x` is only supported with its default value. + For full documentation refer to `scipy.fft.hfft`.. """ _check_plan(plan) - _check_overwrite_x(overwrite_x) x = _validate_input(x) norm = _swap_direction(norm) x = np.array(x, copy=True) @@ -564,6 +571,7 @@ def hfft( fsc = _compute_fwd_scale(norm, n, 2 * (x.shape[axis] - 1)) with _Workers(workers): + # Note: overwrite_x is not utilized return mkl_fft.irfft(x, n=n, axis=axis, fwd_scale=fsc) @@ -575,18 +583,14 @@ def ihfft( For full documentation refer to `scipy.fft.ihfft`. - Limitation - ----------- - The kwarg `overwrite_x` is only supported with its default value. - """ _check_plan(plan) - _check_overwrite_x(overwrite_x) x = _validate_input(x) norm = _swap_direction(norm) fsc = _compute_fwd_scale(norm, n, x.shape[axis]) with _Workers(workers): + # Note: overwrite_x is not utilized result = mkl_fft.rfft(x, n=n, axis=axis, fwd_scale=fsc) np.conjugate(result, out=result) @@ -608,10 +612,6 @@ def hfft2( For full documentation refer to `scipy.fft.hfft2`. - Limitation - ----------- - The kwarg `overwrite_x` is only supported with its default value. - """ return hfftn( x, @@ -639,10 +639,6 @@ def ihfft2( For full documentation refer to `scipy.fft.ihfft2`. - Limitation - ----------- - The kwarg `overwrite_x` is only supported with its default value. - """ return ihfftn( x, @@ -671,21 +667,17 @@ def hfftn( For full documentation refer to `scipy.fft.hfftn`. - Limitation - ----------- - The kwarg `overwrite_x` is only supported with its default value. - """ _check_plan(plan) - _check_overwrite_x(overwrite_x) x = _validate_input(x) norm = _swap_direction(norm) x = np.array(x, copy=True) np.conjugate(x, out=x) - s, axes = _cook_nd_args(x, s, axes, invreal=True) + s, axes = _init_nd_shape_and_axes(x, s, axes, invreal=True) fsc = _compute_fwd_scale(norm, s, x.shape) with _Workers(workers): + # Note: overwrite_x is not utilized return mkl_fft.irfftn(x, s, axes, fwd_scale=fsc) @@ -704,19 +696,15 @@ def ihfftn( For full documentation refer to `scipy.fft.ihfftn`. - Limitation - ----------- - The kwarg `overwrite_x` is only supported with its default value. - """ _check_plan(plan) - _check_overwrite_x(overwrite_x) x = _validate_input(x) norm = _swap_direction(norm) - s, axes = _cook_nd_args(x, s, axes) + s, axes = _init_nd_shape_and_axes(x, s, axes) fsc = _compute_fwd_scale(norm, s, x.shape) with _Workers(workers): + # Note: overwrite_x is not utilized result = mkl_fft.rfftn(x, s, axes, fwd_scale=fsc) np.conjugate(result, out=result) From a337e491264cc34b8b21ddb7290e4726db7fef31 Mon Sep 17 00:00:00 2001 From: Vahid Tavanashad <120411540+vtavana@users.noreply.github.com> Date: Fri, 9 May 2025 09:28:01 -0500 Subject: [PATCH 2/3] Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- CHANGELOG.md | 2 +- mkl_fft/interfaces/_scipy_fft.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2d5cc3ec..cea0cb1f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,7 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed * NumPy interface `mkl_fft.interfaces.numpy_fft` is aligned with numpy-2.x.x [gh-139](https://github.com/IntelPython/mkl_fft/pull/139), [gh-157](https://github.com/IntelPython/mkl_fft/pull/157) -* SciPy interface `mkl_fft.interfaces.scipy_fft` uses the same function from SciPy for handling `s` and `axes` for N-D FFTs[gh-181](https://github.com/IntelPython/mkl_fft/pull/181) +* SciPy interface `mkl_fft.interfaces.scipy_fft` uses the same function from SciPy for handling `s` and `axes` for N-D FFTs [gh-181](https://github.com/IntelPython/mkl_fft/pull/181) ## [1.3.14] (04/10/2025) diff --git a/mkl_fft/interfaces/_scipy_fft.py b/mkl_fft/interfaces/_scipy_fft.py index 8443eaca..331dae0a 100644 --- a/mkl_fft/interfaces/_scipy_fft.py +++ b/mkl_fft/interfaces/_scipy_fft.py @@ -560,7 +560,7 @@ def hfft( Compute the FFT of a signal that has Hermitian symmetry, i.e., a real spectrum. - For full documentation refer to `scipy.fft.hfft`.. + For full documentation refer to `scipy.fft.hfft`. """ _check_plan(plan) From 02761881a3daea082e04cc135846d3c35bc3f835 Mon Sep 17 00:00:00 2001 From: Vahid Tavanashad Date: Tue, 13 May 2025 10:17:46 -0700 Subject: [PATCH 3/3] include a few skipped tests --- mkl_fft/tests/third_party/scipy/test_basic.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/mkl_fft/tests/third_party/scipy/test_basic.py b/mkl_fft/tests/third_party/scipy/test_basic.py index 470bbfd4..19b00fa2 100644 --- a/mkl_fft/tests/third_party/scipy/test_basic.py +++ b/mkl_fft/tests/third_party/scipy/test_basic.py @@ -230,7 +230,6 @@ def test_irfftn(self, xp): for norm in ["backward", "ortho", "forward"]: xp_assert_close(fft.irfftn(fft.rfftn(x, norm=norm), norm=norm), x) - @pytest.mark.skip("hfft is not supported") def test_hfft(self, xp): x = random(14) + 1j * random(14) x_herm = np.concatenate((random(1), x, random(1))) @@ -246,7 +245,6 @@ def test_hfft(self, xp): ) xp_assert_close(fft.hfft(x_herm, norm="forward"), expect / 30) - @pytest.mark.skip("ihfft is not supported") def test_ihfft(self, xp): x = random(14) + 1j * random(14) x_herm = np.concatenate((random(1), x, random(1))) @@ -259,14 +257,12 @@ def test_ihfft(self, xp): fft.ihfft(fft.hfft(x_herm, norm=norm), norm=norm), x_herm ) - @pytest.mark.skip("hfft2 is not supported") def test_hfft2(self, xp): x = xp.asarray(random((30, 20))) xp_assert_close(fft.hfft2(fft.ihfft2(x)), x) for norm in ["backward", "ortho", "forward"]: xp_assert_close(fft.hfft2(fft.ihfft2(x, norm=norm), norm=norm), x) - @pytest.mark.skip("ihfft2 is not supported") def test_ihfft2(self, xp): x = xp.asarray(random((30, 20)), dtype=xp.float64) expect = fft.ifft2(xp.asarray(x, dtype=xp.complex128))[:, :11] @@ -278,14 +274,12 @@ def test_ihfft2(self, xp): ) xp_assert_close(fft.ihfft2(x, norm="forward"), expect * (30 * 20)) - @pytest.mark.skip("hfftn is not supported") def test_hfftn(self, xp): x = xp.asarray(random((30, 20, 10))) xp_assert_close(fft.hfftn(fft.ihfftn(x)), x) for norm in ["backward", "ortho", "forward"]: xp_assert_close(fft.hfftn(fft.ihfftn(x, norm=norm), norm=norm), x) - @pytest.mark.skip("ihfftn is not supported") def test_ihfftn(self, xp): x = xp.asarray(random((30, 20, 10)), dtype=xp.float64) expect = fft.ifftn(xp.asarray(x, dtype=xp.complex128))[:, :, :6]