From 37ce6702ff38cac24ae80bbb390551f30ccf28ac Mon Sep 17 00:00:00 2001 From: crusaderky Date: Tue, 14 Jan 2025 16:57:22 +0000 Subject: [PATCH 1/4] nunique --- docs/api-reference.md | 1 + src/array_api_extra/__init__.py | 2 ++ src/array_api_extra/_funcs.py | 40 +++++++++++++++++++++++++++++++++ tests/test_funcs.py | 19 ++++++++++++++++ 4 files changed, 62 insertions(+) diff --git a/docs/api-reference.md b/docs/api-reference.md index b43c960f..279c84c4 100644 --- a/docs/api-reference.md +++ b/docs/api-reference.md @@ -12,6 +12,7 @@ create_diagonal expand_dims kron + nunique setdiff1d sinc ``` diff --git a/src/array_api_extra/__init__.py b/src/array_api_extra/__init__.py index a4f6815f..3f973307 100644 --- a/src/array_api_extra/__init__.py +++ b/src/array_api_extra/__init__.py @@ -7,6 +7,7 @@ create_diagonal, expand_dims, kron, + nunique, pad, setdiff1d, sinc, @@ -23,6 +24,7 @@ "create_diagonal", "expand_dims", "kron", + "nunique", "pad", "setdiff1d", "sinc", diff --git a/src/array_api_extra/_funcs.py b/src/array_api_extra/_funcs.py index 7502561a..8c8ebc3d 100644 --- a/src/array_api_extra/_funcs.py +++ b/src/array_api_extra/_funcs.py @@ -3,6 +3,7 @@ # https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972 from __future__ import annotations +import math import operator import warnings from collections.abc import Callable @@ -13,8 +14,10 @@ from ._lib import _compat, _utils from ._lib._compat import ( array_namespace, + device, is_jax_array, is_writeable_array, + size, ) from ._lib._typing import Array, Index @@ -25,6 +28,7 @@ "create_diagonal", "expand_dims", "kron", + "nunique", "pad", "setdiff1d", "sinc", @@ -638,6 +642,42 @@ def pad( return at(padded, tuple(slices)).set(x) +def nunique(x: Array, /, *, xp: ModuleType | None = None) -> Array: + """ + Count the number of unique elements in an array. + + Compatible with JAX and Dask, whose laziness would be otherwise + problematic. + + Parameters + ---------- + x : Array + Input array. + xp : array_namespace, optional + The standard-compatible namespace for `x`. Default: infer. + + Returns + ------- + array: Scalar integer array + The number of unique elements in `x`. It can be lazy. + """ + if xp is None: + xp = array_namespace(x) + + if is_jax_array(x): + # size= is JAX-specific + # https://github.com/data-apis/array-api/issues/883 + _, counts = xp.unique_counts(x, size=size(x)) + return xp.astype(counts, xp.bool).sum() + + _, counts = xp.unique_counts(x) + n = size(counts) + # FIXME https://github.com/data-apis/array-api-compat/pull/231 + if n is None or math.isnan(n): # e.g. Dask, ndonnx + return xp.astype(counts, xp.bool).sum() + return xp.asarray(n, device=device(x)) + + class _AtOp(Enum): """Operations for use in `xpx.at`.""" diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 5f18ef61..201295da 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -11,6 +11,7 @@ create_diagonal, expand_dims, kron, + nunique, pad, setdiff1d, sinc, @@ -448,3 +449,21 @@ def test_list_of_tuples_width(self, xp: ModuleType): padded = pad(a, [(1, 0), (0, 0)]) assert padded.shape == (4, 4) + + +class TestNUnique: + def test_simple(self, xp: ModuleType): + a = xp.asarray([[1, 1], [0, 2], [2, 2]]) + xp_assert_equal(nunique(a), xp.asarray(3)) + + def test_empty(self, xp: ModuleType): + a = xp.asarray([]) + xp_assert_equal(nunique(a), xp.asarray(0)) + + def test_device(self, xp: ModuleType, device: Device): + a = xp.asarray(0.0, device=device) + assert get_device(nunique(a)) == device + + def test_xp(self, xp: ModuleType): + a = xp.asarray([[1, 1], [0, 2], [2, 2]]) + xp_assert_equal(nunique(a, xp=xp), xp.asarray(3)) From 46f8bb65393a7ceb4bde28e3b288a6d48f807598 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Tue, 14 Jan 2025 17:04:31 +0000 Subject: [PATCH 2/4] Use _compat ns --- src/array_api_extra/_funcs.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/array_api_extra/_funcs.py b/src/array_api_extra/_funcs.py index 8c8ebc3d..4c2cec4e 100644 --- a/src/array_api_extra/_funcs.py +++ b/src/array_api_extra/_funcs.py @@ -14,10 +14,8 @@ from ._lib import _compat, _utils from ._lib._compat import ( array_namespace, - device, is_jax_array, is_writeable_array, - size, ) from ._lib._typing import Array, Index @@ -667,15 +665,15 @@ def nunique(x: Array, /, *, xp: ModuleType | None = None) -> Array: if is_jax_array(x): # size= is JAX-specific # https://github.com/data-apis/array-api/issues/883 - _, counts = xp.unique_counts(x, size=size(x)) + _, counts = xp.unique_counts(x, size=_compat.size(x)) return xp.astype(counts, xp.bool).sum() _, counts = xp.unique_counts(x) - n = size(counts) + n = _compat.size(counts) # FIXME https://github.com/data-apis/array-api-compat/pull/231 if n is None or math.isnan(n): # e.g. Dask, ndonnx return xp.astype(counts, xp.bool).sum() - return xp.asarray(n, device=device(x)) + return xp.asarray(n, device=_compat.device(x)) class _AtOp(Enum): From 33947e96a72d8fd69f065e9043efd9cf7be4dd45 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Tue, 14 Jan 2025 17:09:05 +0000 Subject: [PATCH 3/4] Disable too-many-lines --- pixi.lock | 2 +- pyproject.toml | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/pixi.lock b/pixi.lock index 2790b207..ed262f2f 100644 --- a/pixi.lock +++ b/pixi.lock @@ -2469,7 +2469,7 @@ packages: - pypi: . name: array-api-extra version: 0.5.1.dev0 - sha256: 8b4533cc75534abb69425a1e5c9f6a4ab96949562d2e90d41ea0e22187a02c1b + sha256: 09d6a4b1405fd64596379826065a09bc3787a4fc4e1535dc369f74a3b96f86e3 requires_dist: - array-api-compat>=1.10.0,<2 - furo>=2023.8.17 ; extra == 'docs' diff --git a/pyproject.toml b/pyproject.toml index 4f5ddac0..a5594541 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -299,6 +299,7 @@ messages_control.disable = [ "line-too-long", "missing-module-docstring", "missing-function-docstring", + "too-many-lines", "wrong-import-position", ] From 07b278256317234925da10237f79de761b2b9d48 Mon Sep 17 00:00:00 2001 From: Lucas Colley Date: Tue, 14 Jan 2025 17:18:21 +0000 Subject: [PATCH 4/4] Update src/array_api_extra/_funcs.py --- src/array_api_extra/_funcs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/array_api_extra/_funcs.py b/src/array_api_extra/_funcs.py index 4c2cec4e..017c7297 100644 --- a/src/array_api_extra/_funcs.py +++ b/src/array_api_extra/_funcs.py @@ -656,7 +656,7 @@ def nunique(x: Array, /, *, xp: ModuleType | None = None) -> Array: Returns ------- - array: Scalar integer array + array: 0-dimensional integer array The number of unique elements in `x`. It can be lazy. """ if xp is None: