@@ -505,6 +505,17 @@ def nonzero(x: array, /, **kwargs) -> Tuple[array, ...]:
505
505
raise ValueError ("nonzero() does not support zero-dimensional arrays" )
506
506
return torch .nonzero (x , as_tuple = True , ** kwargs )
507
507
508
+ # torch uses `dim` instead of `axis`
509
+ def count_nonzero (
510
+ x : array ,
511
+ / ,
512
+ * ,
513
+ axis : Optional [Union [int , Tuple [int , ...]]] = None ,
514
+ keepdims : bool = False ,
515
+ ) -> array :
516
+ return torch .count_nonzero (x , dim = axis , keepdims = keepdims )
517
+
518
+
508
519
def where (condition : array , x1 : array , x2 : array , / ) -> array :
509
520
x1 , x2 = _fix_promotion (x1 , x2 )
510
521
return torch .where (condition , x1 , x2 )
@@ -753,7 +764,8 @@ def sign(x: array, /) -> array:
753
764
__all__ = ['__array_namespace_info__' , 'result_type' , 'can_cast' ,
754
765
'permute_dims' , 'bitwise_invert' , 'newaxis' , 'conj' , 'add' ,
755
766
'atan2' , 'bitwise_and' , 'bitwise_left_shift' , 'bitwise_or' ,
756
- 'bitwise_right_shift' , 'bitwise_xor' , 'copysign' , 'divide' ,
767
+ 'bitwise_right_shift' , 'bitwise_xor' , 'copysign' , 'count_nonzero' ,
768
+ 'divide' ,
757
769
'equal' , 'floor_divide' , 'greater' , 'greater_equal' , 'hypot' ,
758
770
'less' , 'less_equal' , 'logaddexp' , 'maximum' , 'minimum' ,
759
771
'multiply' , 'not_equal' , 'pow' , 'remainder' , 'subtract' , 'max' ,
0 commit comments