Skip to content

Commit f00a882

Browse files
committed
ENH: add count_nonzero
1 parent cf3f717 commit f00a882

File tree

3 files changed

+22
-3
lines changed

3 files changed

+22
-3
lines changed

array_api_strict/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -293,9 +293,9 @@
293293

294294
__all__ += ["concat", "expand_dims", "flip", "moveaxis", "permute_dims", "repeat", "reshape", "roll", "squeeze", "stack", "tile", "unstack"]
295295

296-
from ._searching_functions import argmax, argmin, nonzero, searchsorted, where
296+
from ._searching_functions import argmax, argmin, nonzero, count_nonzero, searchsorted, where
297297

298-
__all__ += ["argmax", "argmin", "nonzero", "searchsorted", "where"]
298+
__all__ += ["argmax", "argmin", "nonzero", "count_nonzero", "searchsorted", "where"]
299299

300300
from ._set_functions import unique_all, unique_counts, unique_inverse, unique_values
301301

array_api_strict/_searching_functions.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from typing import TYPE_CHECKING
88
if TYPE_CHECKING:
9-
from typing import Literal, Optional, Tuple
9+
from typing import Literal, Optional, Tuple, Union
1010

1111
import numpy as np
1212

@@ -45,6 +45,24 @@ def nonzero(x: Array, /) -> Tuple[Array, ...]:
4545
raise ValueError("nonzero is not allowed on 0-dimensional arrays")
4646
return tuple(Array._new(i, device=x.device) for i in np.nonzero(x._array))
4747

48+
49+
@requires_api_version('2024.12')
50+
def count_nonzero(
51+
x: Array,
52+
/,
53+
*,
54+
axis: Optional[Union[int, Tuple[int, ...]]] = None,
55+
keepdims: bool = False,
56+
) -> Array:
57+
"""
58+
Array API compatible wrapper for :py:func:`np.count_nonzero <numpy.count_nonzero>`
59+
60+
See its docstring for more information.
61+
"""
62+
arr = np.count_nonzero(x._array, axis=axis, keepdims=keepdims)
63+
return Array._new(np.asarray(arr), device=x.device)
64+
65+
4866
@requires_api_version('2023.12')
4967
def searchsorted(
5068
x1: Array,

array_api_strict/tests/test_flags.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,7 @@ def test_api_version_2023_12(func_name):
307307
'reciprocal': lambda: xp.reciprocal(xp.asarray([2.])),
308308
'take_along_axis': lambda: xp.take_along_axis(xp.zeros((2, 3)),
309309
xp.zeros((1, 4), dtype=xp.int64)),
310+
'count_nonzero': lambda: xp.count_nonzero(xp.arange(3)),
310311
}
311312

312313
@pytest.mark.parametrize('func_name', api_version_2024_12_examples.keys())

0 commit comments

Comments
 (0)