diff --git a/CHANGELOG.md b/CHANGELOG.md index 6ad2a04a..d3d2368b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed * NumPy interface `mkl_fft.interfaces.numpy_fft` is aligned with numpy-2.x.x [gh-139](https://github.com/IntelPython/mkl_fft/pull/139), [gh-157](https://github.com/IntelPython/mkl_fft/pull/157) * To set `mkl_fft` as the backend for SciPy is only possible through `mkl_fft.interfaces.scipy_fft` [gh-179](https://github.com/IntelPython/mkl_fft/pull/179) +* SciPy interface `mkl_fft.interfaces.scipy_fft` uses the same function from SciPy for handling `s` and `axes` for N-D FFTs [gh-181](https://github.com/IntelPython/mkl_fft/pull/181) ## [1.3.14] (04/10/2025) diff --git a/mkl_fft/interfaces/_scipy_fft.py b/mkl_fft/interfaces/_scipy_fft.py index d1f8facb..fa8482e3 100644 --- a/mkl_fft/interfaces/_scipy_fft.py +++ b/mkl_fft/interfaces/_scipy_fft.py @@ -33,6 +33,7 @@ import contextvars import operator import os +from numbers import Number import mkl import numpy as np @@ -156,30 +157,65 @@ def _check_plan(plan): ) -def _check_overwrite_x(overwrite_x): - if overwrite_x: - raise NotImplementedError( - "Overwriting the content of `x` is currently not supported" - ) +# copied from scipy.fft._pocketfft.helper +# https://github.com/scipy/scipy/blob/main/scipy/fft/_pocketfft/helper.py +def _iterable_of_int(x, name=None): + if isinstance(x, Number): + x = (x,) + try: + x = [operator.index(a) for a in x] + except TypeError as e: + name = name or "value" + raise ValueError( + f"{name} must be a scalar or iterable of integers" + ) from e -def _cook_nd_args(x, s=None, axes=None, invreal=False): - if s is None: - shapeless = True - if axes is None: - s = list(x.shape) - else: - s = np.take(x.shape, axes) + return x + + +# copied and modified from scipy.fft._pocketfft.helper +# https://github.com/scipy/scipy/blob/main/scipy/fft/_pocketfft/helper.py +def _init_nd_shape_and_axes(x, shape, axes, invreal=False): + noshape = shape is None + noaxes = axes is None + + if not noaxes: + axes = _iterable_of_int(axes, "axes") + axes = [a + x.ndim if a < 0 else a for a in axes] + + if any(a >= x.ndim or a < 0 for a in axes): + raise ValueError("axes exceeds dimensionality of input") + if len(set(axes)) != len(axes): + raise ValueError("all axes must be unique") + + if not noshape: + shape = _iterable_of_int(shape, "shape") + + if axes and len(axes) != len(shape): + raise ValueError( + "when given, axes and shape arguments" + " have to be of the same length" + ) + if noaxes: + if len(shape) > x.ndim: + raise ValueError("shape requires more axes than are present") + axes = range(x.ndim - len(shape), x.ndim) + + shape = [x.shape[a] if s == -1 else s for s, a in zip(shape, axes)] + elif noaxes: + shape = list(x.shape) + axes = range(x.ndim) else: - shapeless = False - s = list(s) - if axes is None: - axes = list(range(-len(s), 0)) - if len(s) != len(axes): - raise ValueError("Shape and axes have different lengths.") - if invreal and shapeless: - s[-1] = (x.shape[axes[-1]] - 1) * 2 - return s, axes + shape = [x.shape[a] for a in axes] + + if noshape and invreal: + shape[-1] = (x.shape[axes[-1]] - 1) * 2 + + if any(s < 1 for s in shape): + raise ValueError(f"invalid number of data points ({shape}) specified") + + return tuple(shape), list(axes) def _validate_input(x): @@ -301,7 +337,7 @@ def fftn( """ _check_plan(plan) x = _validate_input(x) - s, axes = _cook_nd_args(x, s, axes) + s, axes = _init_nd_shape_and_axes(x, s, axes) fsc = _compute_fwd_scale(norm, s, x.shape) with _Workers(workers): @@ -328,7 +364,7 @@ def ifftn( """ _check_plan(plan) x = _validate_input(x) - s, axes = _cook_nd_args(x, s, axes) + s, axes = _init_nd_shape_and_axes(x, s, axes) fsc = _compute_fwd_scale(norm, s, x.shape) with _Workers(workers): @@ -345,17 +381,13 @@ def rfft( For full documentation refer to `scipy.fft.rfft`. - Limitation - ----------- - The kwarg `overwrite_x` is only supported with its default value. - """ _check_plan(plan) - _check_overwrite_x(overwrite_x) x = _validate_input(x) fsc = _compute_fwd_scale(norm, n, x.shape[axis]) with _Workers(workers): + # Note: overwrite_x is not utilized return mkl_fft.rfft(x, n=n, axis=axis, fwd_scale=fsc) @@ -367,17 +399,13 @@ def irfft( For full documentation refer to `scipy.fft.irfft`. - Limitation - ----------- - The kwarg `overwrite_x` is only supported with its default value. - """ _check_plan(plan) - _check_overwrite_x(overwrite_x) x = _validate_input(x) fsc = _compute_fwd_scale(norm, n, 2 * (x.shape[axis] - 1)) with _Workers(workers): + # Note: overwrite_x is not utilized return mkl_fft.irfft(x, n=n, axis=axis, fwd_scale=fsc) @@ -396,10 +424,6 @@ def rfft2( For full documentation refer to `scipy.fft.rfft2`. - Limitation - ----------- - The kwarg `overwrite_x` is only supported with its default value. - """ return rfftn( x, @@ -427,10 +451,6 @@ def irfft2( For full documentation refer to `scipy.fft.irfft2`. - Limitation - ----------- - The kwarg `overwrite_x` is only supported with its default value. - """ return irfftn( x, @@ -458,18 +478,14 @@ def rfftn( For full documentation refer to `scipy.fft.rfftn`. - Limitation - ----------- - The kwarg `overwrite_x` is only supported with its default value. - """ _check_plan(plan) - _check_overwrite_x(overwrite_x) x = _validate_input(x) - s, axes = _cook_nd_args(x, s, axes) + s, axes = _init_nd_shape_and_axes(x, s, axes) fsc = _compute_fwd_scale(norm, s, x.shape) with _Workers(workers): + # Note: overwrite_x is not utilized return mkl_fft.rfftn(x, s, axes, fwd_scale=fsc) @@ -488,18 +504,14 @@ def irfftn( For full documentation refer to `scipy.fft.irfftn`. - Limitation - ----------- - The kwarg `overwrite_x` is only supported with its default value. - """ _check_plan(plan) - _check_overwrite_x(overwrite_x) x = _validate_input(x) - s, axes = _cook_nd_args(x, s, axes, invreal=True) + s, axes = _init_nd_shape_and_axes(x, s, axes, invreal=True) fsc = _compute_fwd_scale(norm, s, x.shape) with _Workers(workers): + # Note: overwrite_x is not utilized return mkl_fft.irfftn(x, s, axes, fwd_scale=fsc) @@ -512,13 +524,8 @@ def hfft( For full documentation refer to `scipy.fft.hfft`. - Limitation - ----------- - The kwarg `overwrite_x` is only supported with its default value. - """ _check_plan(plan) - _check_overwrite_x(overwrite_x) x = _validate_input(x) norm = _swap_direction(norm) x = np.array(x, copy=True) @@ -526,6 +533,7 @@ def hfft( fsc = _compute_fwd_scale(norm, n, 2 * (x.shape[axis] - 1)) with _Workers(workers): + # Note: overwrite_x is not utilized return mkl_fft.irfft(x, n=n, axis=axis, fwd_scale=fsc) @@ -537,18 +545,14 @@ def ihfft( For full documentation refer to `scipy.fft.ihfft`. - Limitation - ----------- - The kwarg `overwrite_x` is only supported with its default value. - """ _check_plan(plan) - _check_overwrite_x(overwrite_x) x = _validate_input(x) norm = _swap_direction(norm) fsc = _compute_fwd_scale(norm, n, x.shape[axis]) with _Workers(workers): + # Note: overwrite_x is not utilized result = mkl_fft.rfft(x, n=n, axis=axis, fwd_scale=fsc) np.conjugate(result, out=result) @@ -570,10 +574,6 @@ def hfft2( For full documentation refer to `scipy.fft.hfft2`. - Limitation - ----------- - The kwarg `overwrite_x` is only supported with its default value. - """ return hfftn( x, @@ -601,10 +601,6 @@ def ihfft2( For full documentation refer to `scipy.fft.ihfft2`. - Limitation - ----------- - The kwarg `overwrite_x` is only supported with its default value. - """ return ihfftn( x, @@ -633,21 +629,17 @@ def hfftn( For full documentation refer to `scipy.fft.hfftn`. - Limitation - ----------- - The kwarg `overwrite_x` is only supported with its default value. - """ _check_plan(plan) - _check_overwrite_x(overwrite_x) x = _validate_input(x) norm = _swap_direction(norm) x = np.array(x, copy=True) np.conjugate(x, out=x) - s, axes = _cook_nd_args(x, s, axes, invreal=True) + s, axes = _init_nd_shape_and_axes(x, s, axes, invreal=True) fsc = _compute_fwd_scale(norm, s, x.shape) with _Workers(workers): + # Note: overwrite_x is not utilized return mkl_fft.irfftn(x, s, axes, fwd_scale=fsc) @@ -666,19 +658,15 @@ def ihfftn( For full documentation refer to `scipy.fft.ihfftn`. - Limitation - ----------- - The kwarg `overwrite_x` is only supported with its default value. - """ _check_plan(plan) - _check_overwrite_x(overwrite_x) x = _validate_input(x) norm = _swap_direction(norm) - s, axes = _cook_nd_args(x, s, axes) + s, axes = _init_nd_shape_and_axes(x, s, axes) fsc = _compute_fwd_scale(norm, s, x.shape) with _Workers(workers): + # Note: overwrite_x is not utilized result = mkl_fft.rfftn(x, s, axes, fwd_scale=fsc) np.conjugate(result, out=result) diff --git a/mkl_fft/tests/third_party/scipy/test_basic.py b/mkl_fft/tests/third_party/scipy/test_basic.py index 470bbfd4..19b00fa2 100644 --- a/mkl_fft/tests/third_party/scipy/test_basic.py +++ b/mkl_fft/tests/third_party/scipy/test_basic.py @@ -230,7 +230,6 @@ def test_irfftn(self, xp): for norm in ["backward", "ortho", "forward"]: xp_assert_close(fft.irfftn(fft.rfftn(x, norm=norm), norm=norm), x) - @pytest.mark.skip("hfft is not supported") def test_hfft(self, xp): x = random(14) + 1j * random(14) x_herm = np.concatenate((random(1), x, random(1))) @@ -246,7 +245,6 @@ def test_hfft(self, xp): ) xp_assert_close(fft.hfft(x_herm, norm="forward"), expect / 30) - @pytest.mark.skip("ihfft is not supported") def test_ihfft(self, xp): x = random(14) + 1j * random(14) x_herm = np.concatenate((random(1), x, random(1))) @@ -259,14 +257,12 @@ def test_ihfft(self, xp): fft.ihfft(fft.hfft(x_herm, norm=norm), norm=norm), x_herm ) - @pytest.mark.skip("hfft2 is not supported") def test_hfft2(self, xp): x = xp.asarray(random((30, 20))) xp_assert_close(fft.hfft2(fft.ihfft2(x)), x) for norm in ["backward", "ortho", "forward"]: xp_assert_close(fft.hfft2(fft.ihfft2(x, norm=norm), norm=norm), x) - @pytest.mark.skip("ihfft2 is not supported") def test_ihfft2(self, xp): x = xp.asarray(random((30, 20)), dtype=xp.float64) expect = fft.ifft2(xp.asarray(x, dtype=xp.complex128))[:, :11] @@ -278,14 +274,12 @@ def test_ihfft2(self, xp): ) xp_assert_close(fft.ihfft2(x, norm="forward"), expect * (30 * 20)) - @pytest.mark.skip("hfftn is not supported") def test_hfftn(self, xp): x = xp.asarray(random((30, 20, 10))) xp_assert_close(fft.hfftn(fft.ihfftn(x)), x) for norm in ["backward", "ortho", "forward"]: xp_assert_close(fft.hfftn(fft.ihfftn(x, norm=norm), norm=norm), x) - @pytest.mark.skip("ihfftn is not supported") def test_ihfftn(self, xp): x = xp.asarray(random((30, 20, 10)), dtype=xp.float64) expect = fft.ifftn(xp.asarray(x, dtype=xp.complex128))[:, :, :6]