Skip to content

Commit 0588591

Browse files
committed
torch: add count_nonzero
1 parent d1f0a1a commit 0588591

File tree

1 file changed

+13
-1
lines changed

1 file changed

+13
-1
lines changed

array_api_compat/torch/_aliases.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -505,6 +505,17 @@ def nonzero(x: array, /, **kwargs) -> Tuple[array, ...]:
505505
raise ValueError("nonzero() does not support zero-dimensional arrays")
506506
return torch.nonzero(x, as_tuple=True, **kwargs)
507507

508+
# torch uses `dim` instead of `axis`
509+
def count_nonzero(
510+
x: array,
511+
/,
512+
*,
513+
axis: Optional[Union[int, Tuple[int, ...]]] = None,
514+
keepdims: bool = False,
515+
) -> array:
516+
return torch.count_nonzero(x, dim=axis, keepdims=keepdims)
517+
518+
508519
def where(condition: array, x1: array, x2: array, /) -> array:
509520
x1, x2 = _fix_promotion(x1, x2)
510521
return torch.where(condition, x1, x2)
@@ -753,7 +764,8 @@ def sign(x: array, /) -> array:
753764
__all__ = ['__array_namespace_info__', 'result_type', 'can_cast',
754765
'permute_dims', 'bitwise_invert', 'newaxis', 'conj', 'add',
755766
'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or',
756-
'bitwise_right_shift', 'bitwise_xor', 'copysign', 'divide',
767+
'bitwise_right_shift', 'bitwise_xor', 'copysign', 'count_nonzero',
768+
'divide',
757769
'equal', 'floor_divide', 'greater', 'greater_equal', 'hypot',
758770
'less', 'less_equal', 'logaddexp', 'maximum', 'minimum',
759771
'multiply', 'not_equal', 'pow', 'remainder', 'subtract', 'max',

0 commit comments

Comments
 (0)