From 28bff59da38aa2a910273b13d2ecd91fd6d49d56 Mon Sep 17 00:00:00 2001 From: Lucas Colley Date: Thu, 14 Nov 2024 19:36:59 +0000 Subject: [PATCH 01/10] ENH: add `sinc` --- docs/api-reference.md | 1 + src/array_api_extra/__init__.py | 4 +- src/array_api_extra/_funcs.py | 79 ++++++++++++++++++++++++++++++++- 3 files changed, 81 insertions(+), 3 deletions(-) diff --git a/docs/api-reference.md b/docs/api-reference.md index a459743d..1dc3dfc9 100644 --- a/docs/api-reference.md +++ b/docs/api-reference.md @@ -10,4 +10,5 @@ cov expand_dims kron + sinc ``` diff --git a/src/array_api_extra/__init__.py b/src/array_api_extra/__init__.py index 7a46760b..30492748 100644 --- a/src/array_api_extra/__init__.py +++ b/src/array_api_extra/__init__.py @@ -1,7 +1,7 @@ from __future__ import annotations -from ._funcs import atleast_nd, cov, expand_dims, kron +from ._funcs import atleast_nd, cov, expand_dims, kron, sinc __version__ = "0.1.2.dev0" -__all__ = ["__version__", "atleast_nd", "cov", "expand_dims", "kron"] +__all__ = ["__version__", "atleast_nd", "cov", "expand_dims", "kron", "sinc"] diff --git a/src/array_api_extra/_funcs.py b/src/array_api_extra/_funcs.py index a371768b..05a8d17e 100644 --- a/src/array_api_extra/_funcs.py +++ b/src/array_api_extra/_funcs.py @@ -6,7 +6,7 @@ if TYPE_CHECKING: from ._typing import Array, ModuleType -__all__ = ["atleast_nd", "cov", "expand_dims", "kron"] +__all__ = ["atleast_nd", "cov", "expand_dims", "kron", "sinc"] def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType) -> Array: @@ -348,3 +348,80 @@ def kron(a: Array, b: Array, /, *, xp: ModuleType) -> Array: a_shape = xp.asarray(a_shape) b_shape = xp.asarray(b_shape) return xp.reshape(result, tuple(xp.multiply(a_shape, b_shape))) + + +def sinc(x: Array, /, *, xp: ModuleType) -> Array: + r""" + Return the normalized sinc function. + + The sinc function is equal to :math:`\sin(\pi x)/(\pi x)` for any argument + :math:`x\ne 0`. ``sinc(0)`` takes the limit value 1, making ``sinc`` not + only everywhere continuous but also infinitely differentiable. + + .. note:: + + Note the normalization factor of ``pi`` used in the definition. + This is the most commonly used definition in signal processing. + Use ``sinc(x / np.pi)`` to obtain the unnormalized sinc function + :math:`\sin(x)/x` that is more common in mathematics. + + Parameters + ---------- + x : array + Array (possibly multi-dimensional) of values for which to calculate + ``sinc(x)``. + + Returns + ------- + out : ndarray + ``sinc(x)`` calculated elementwise, which has the same shape as the input. + + Notes + ----- + The name sinc is short for "sine cardinal" or "sinus cardinalis". + + The sinc function is used in various signal processing applications, + including in anti-aliasing, in the construction of a Lanczos resampling + filter, and in interpolation. + + For bandlimited interpolation of discrete-time signals, the ideal + interpolation kernel is proportional to the sinc function. + + References + ---------- + .. [1] Weisstein, Eric W. "Sinc Function." From MathWorld--A Wolfram Web + Resource. https://mathworld.wolfram.com/SincFunction.html + .. [2] Wikipedia, "Sinc function", + https://en.wikipedia.org/wiki/Sinc_function + + Examples + -------- + >>> import array_api_strict as xp + >>> import array_api_extra as xpx + >>> x = xp.linspace(-4, 4, 41) + >>> xpx.sinc(x, xp=xp) + Array([-3.89817183e-17, -4.92362781e-02, + -8.40918587e-02, -8.90384387e-02, + -5.84680802e-02, 3.89817183e-17, + 6.68206631e-02, 1.16434881e-01, + 1.26137788e-01, 8.50444803e-02, + -3.89817183e-17, -1.03943254e-01, + -1.89206682e-01, -2.16236208e-01, + -1.55914881e-01, 3.89817183e-17, + 2.33872321e-01, 5.04551152e-01, + 7.56826729e-01, 9.35489284e-01, + 1.00000000e+00, 9.35489284e-01, + 7.56826729e-01, 5.04551152e-01, + 2.33872321e-01, 3.89817183e-17, + -1.55914881e-01, -2.16236208e-01, + -1.89206682e-01, -1.03943254e-01, + -3.89817183e-17, 8.50444803e-02, + 1.26137788e-01, 1.16434881e-01, + 6.68206631e-02, 3.89817183e-17, + -5.84680802e-02, -8.90384387e-02, + -8.40918587e-02, -4.92362781e-02, + -3.89817183e-17], dtype=array_api_strict.float64) + + """ + y = xp.pi * xp.where(x == 0, xp.asarray(1.0e-20), x) + return xp.sin(y) / y From 7baaa4af0b151f617ca08d3bf3b8e70fae8f4e6c Mon Sep 17 00:00:00 2001 From: Lucas Colley Date: Thu, 14 Nov 2024 19:52:07 +0000 Subject: [PATCH 02/10] TST: sinc: add tests --- src/array_api_extra/_funcs.py | 7 +++++-- tests/test_funcs.py | 15 ++++++++++++++- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/src/array_api_extra/_funcs.py b/src/array_api_extra/_funcs.py index 05a8d17e..5e72d3af 100644 --- a/src/array_api_extra/_funcs.py +++ b/src/array_api_extra/_funcs.py @@ -367,9 +367,9 @@ def sinc(x: Array, /, *, xp: ModuleType) -> Array: Parameters ---------- - x : array + x : array of floats Array (possibly multi-dimensional) of values for which to calculate - ``sinc(x)``. + ``sinc(x)``. Should have a floating point dtype. Returns ------- @@ -423,5 +423,8 @@ def sinc(x: Array, /, *, xp: ModuleType) -> Array: -3.89817183e-17], dtype=array_api_strict.float64) """ + if not xp.isdtype(x.dtype, "real floating"): + err_msg = "`x` must have a real floating data type." + raise ValueError(err_msg) y = xp.pi * xp.where(x == 0, xp.asarray(1.0e-20), x) return xp.sin(y) / y diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 556add12..b37d32bc 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -9,7 +9,7 @@ import pytest from numpy.testing import assert_allclose, assert_array_equal, assert_equal -from array_api_extra import atleast_nd, cov, expand_dims, kron +from array_api_extra import atleast_nd, cov, expand_dims, kron, sinc if TYPE_CHECKING: Array = Any # To be changed to a Protocol later (see array-api#589) @@ -224,3 +224,16 @@ def test_positive_negative_repeated(self): a = xp.empty((2, 3, 4, 5)) with pytest.raises(ValueError, match="Duplicate dimensions"): expand_dims(a, axis=(3, -3), xp=xp) + + +class TestSinc: + def test_simple(self): + assert_array_equal(sinc(xp.asarray(0.0), xp=xp), xp.asarray(1.0)) + w = sinc(xp.linspace(-1, 1, 100), xp=xp) + # check symmetry + assert_allclose(w, xp.flip(w, axis=0)) + + @pytest.mark.parametrize("x", [0, 1 + 3j]) + def test_dtype(self, x): + with pytest.raises(ValueError, match="real floating data type"): + sinc(xp.asarray(x), xp=xp) From 7eed00d123a1a98d8fd030f8fad431e1b1537c14 Mon Sep 17 00:00:00 2001 From: Lucas Colley Date: Thu, 14 Nov 2024 19:53:12 +0000 Subject: [PATCH 03/10] DOC: sinc: tweak --- src/array_api_extra/_funcs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/array_api_extra/_funcs.py b/src/array_api_extra/_funcs.py index 5e72d3af..a13aebb4 100644 --- a/src/array_api_extra/_funcs.py +++ b/src/array_api_extra/_funcs.py @@ -373,7 +373,7 @@ def sinc(x: Array, /, *, xp: ModuleType) -> Array: Returns ------- - out : ndarray + res : array ``sinc(x)`` calculated elementwise, which has the same shape as the input. Notes From 68e0a6f2f2f27ca48857b685d68978e52a7ea01e Mon Sep 17 00:00:00 2001 From: Lucas Colley Date: Thu, 14 Nov 2024 19:53:54 +0000 Subject: [PATCH 04/10] DOC: sinc: more tweaks --- src/array_api_extra/_funcs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/array_api_extra/_funcs.py b/src/array_api_extra/_funcs.py index a13aebb4..84db814f 100644 --- a/src/array_api_extra/_funcs.py +++ b/src/array_api_extra/_funcs.py @@ -367,9 +367,9 @@ def sinc(x: Array, /, *, xp: ModuleType) -> Array: Parameters ---------- - x : array of floats + x : array Array (possibly multi-dimensional) of values for which to calculate - ``sinc(x)``. Should have a floating point dtype. + ``sinc(x)``. Must have a real floating point dtype. Returns ------- From 1fe453fb67ebfb2e6753e97a6bea7ea62e447b22 Mon Sep 17 00:00:00 2001 From: Lucas Colley Date: Thu, 14 Nov 2024 20:01:38 +0000 Subject: [PATCH 05/10] TST: sinc: add 3D test --- tests/test_funcs.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index b37d32bc..30e71eba 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -237,3 +237,9 @@ def test_simple(self): def test_dtype(self, x): with pytest.raises(ValueError, match="real floating data type"): sinc(xp.asarray(x), xp=xp) + + def test_3d(self): + x = xp.reshape(xp.arange(18, dtype=xp.float64), (3, 3, 2)) + expected = xp.zeros((3, 3, 2)) + expected[0, 0, 0] = 1.0 + assert_allclose(sinc(x, xp=xp), expected, atol=1e-15) From bb17b5efbff57e146f905669544e795b80d23564 Mon Sep 17 00:00:00 2001 From: Lucas Colley Date: Thu, 14 Nov 2024 21:06:27 +0000 Subject: [PATCH 06/10] improve `where` call Co-authored-by: jakirkham --- src/array_api_extra/_funcs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/array_api_extra/_funcs.py b/src/array_api_extra/_funcs.py index 84db814f..7734f25b 100644 --- a/src/array_api_extra/_funcs.py +++ b/src/array_api_extra/_funcs.py @@ -426,5 +426,5 @@ def sinc(x: Array, /, *, xp: ModuleType) -> Array: if not xp.isdtype(x.dtype, "real floating"): err_msg = "`x` must have a real floating data type." raise ValueError(err_msg) - y = xp.pi * xp.where(x == 0, xp.asarray(1.0e-20), x) + y = xp.pi * xp.where(x, x, xp.finfo(x.dtype).smallest_normal) return xp.sin(y) / y From 9ce01a95bba13aebda667a8795709c979aad000a Mon Sep 17 00:00:00 2001 From: Lucas Colley Date: Thu, 14 Nov 2024 21:11:19 +0000 Subject: [PATCH 07/10] Apply suggestions from code review --- src/array_api_extra/_funcs.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/array_api_extra/_funcs.py b/src/array_api_extra/_funcs.py index 7734f25b..f80f75e4 100644 --- a/src/array_api_extra/_funcs.py +++ b/src/array_api_extra/_funcs.py @@ -362,7 +362,7 @@ def sinc(x: Array, /, *, xp: ModuleType) -> Array: Note the normalization factor of ``pi`` used in the definition. This is the most commonly used definition in signal processing. - Use ``sinc(x / np.pi)`` to obtain the unnormalized sinc function + Use ``sinc(x / xp.pi)`` to obtain the unnormalized sinc function :math:`\sin(x)/x` that is more common in mathematics. Parameters @@ -426,5 +426,6 @@ def sinc(x: Array, /, *, xp: ModuleType) -> Array: if not xp.isdtype(x.dtype, "real floating"): err_msg = "`x` must have a real floating data type." raise ValueError(err_msg) - y = xp.pi * xp.where(x, x, xp.finfo(x.dtype).smallest_normal) + # no scalars in `where` - array-api#807 + y = xp.pi * xp.where(x, x, xp.asarray(xp.finfo(x.dtype).smallest_normal)) return xp.sin(y) / y From b1d5edc6e99f91176744bcf97d1c178fe28acf5b Mon Sep 17 00:00:00 2001 From: Lucas Colley Date: Thu, 14 Nov 2024 21:35:42 +0000 Subject: [PATCH 08/10] DOC: sinc: add `xp` param --- src/array_api_extra/_funcs.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/array_api_extra/_funcs.py b/src/array_api_extra/_funcs.py index f80f75e4..92b3b3eb 100644 --- a/src/array_api_extra/_funcs.py +++ b/src/array_api_extra/_funcs.py @@ -370,6 +370,8 @@ def sinc(x: Array, /, *, xp: ModuleType) -> Array: x : array Array (possibly multi-dimensional) of values for which to calculate ``sinc(x)``. Must have a real floating point dtype. + xp : array_namespace + The standard-compatible namespace for `x`. Returns ------- From 59f14283b7819c58671ce24e12fac08b54bd0560 Mon Sep 17 00:00:00 2001 From: Lucas Colley Date: Sun, 17 Nov 2024 19:35:48 +0000 Subject: [PATCH 09/10] Update _funcs.py Co-authored-by: jakirkham --- src/array_api_extra/_funcs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/array_api_extra/_funcs.py b/src/array_api_extra/_funcs.py index 92b3b3eb..6b465138 100644 --- a/src/array_api_extra/_funcs.py +++ b/src/array_api_extra/_funcs.py @@ -429,5 +429,5 @@ def sinc(x: Array, /, *, xp: ModuleType) -> Array: err_msg = "`x` must have a real floating data type." raise ValueError(err_msg) # no scalars in `where` - array-api#807 - y = xp.pi * xp.where(x, x, xp.asarray(xp.finfo(x.dtype).smallest_normal)) + y = xp.pi * xp.where(x, x, xp.asarray(xp.finfo(x.dtype).smallest_normal, dtype=x.dtype)) return xp.sin(y) / y From 4e5c0d819cafbdfff34707e19b931d7dd4160224 Mon Sep 17 00:00:00 2001 From: Lucas Colley Date: Sun, 17 Nov 2024 19:42:10 +0000 Subject: [PATCH 10/10] appease formatter --- src/array_api_extra/_funcs.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/array_api_extra/_funcs.py b/src/array_api_extra/_funcs.py index 6b465138..459db639 100644 --- a/src/array_api_extra/_funcs.py +++ b/src/array_api_extra/_funcs.py @@ -429,5 +429,7 @@ def sinc(x: Array, /, *, xp: ModuleType) -> Array: err_msg = "`x` must have a real floating data type." raise ValueError(err_msg) # no scalars in `where` - array-api#807 - y = xp.pi * xp.where(x, x, xp.asarray(xp.finfo(x.dtype).smallest_normal, dtype=x.dtype)) + y = xp.pi * xp.where( + x, x, xp.asarray(xp.finfo(x.dtype).smallest_normal, dtype=x.dtype) + ) return xp.sin(y) / y