@@ -88,6 +88,44 @@ def test_argmin(x, data):
88
88
ph .assert_scalar_equals ("argmin" , type_ = int , idx = out_idx , out = min_i , expected = expected )
89
89
90
90
91
+ @pytest .mark .min_version ("2024.12" )
92
+ @given (
93
+ x = hh .arrays (
94
+ dtype = hh .real_dtypes ,
95
+ shape = hh .shapes (min_dims = 1 , min_side = 1 ),
96
+ elements = {"allow_nan" : False },
97
+ ),
98
+ data = st .data (),
99
+ )
100
+ def test_count_nonzero (x , data ):
101
+ kw = data .draw (
102
+ hh .kwargs (
103
+ axis = st .none () | st .integers (- x .ndim , max (x .ndim - 1 , 0 )),
104
+ keepdims = st .booleans (),
105
+ ),
106
+ label = "kw" ,
107
+ )
108
+ keepdims = kw .get ("keepdims" , False )
109
+
110
+ out = xp .count_nonzero (x , ** kw )
111
+
112
+ ph .assert_default_index ("count_nonzero" , out .dtype )
113
+ axes = sh .normalize_axis (kw .get ("axis" , None ), x .ndim )
114
+ ph .assert_keepdimable_shape (
115
+ "count_nonzero" , in_shape = x .shape , out_shape = out .shape , axes = axes , keepdims = keepdims , kw = kw
116
+ )
117
+ scalar_type = dh .get_scalar_type (x .dtype )
118
+
119
+ for indices , out_idx in zip (sh .axes_ndindex (x .shape , axes ), sh .ndindex (out .shape )):
120
+ count = int (out [out_idx ])
121
+ elements = []
122
+ for idx in indices :
123
+ s = scalar_type (x [idx ])
124
+ elements .append (s )
125
+ expected = sum (el != 0 for el in elements )
126
+ ph .assert_scalar_equals ("count_nonzero" , type_ = int , idx = out_idx , out = count , expected = expected )
127
+
128
+
91
129
@given (hh .arrays (dtype = hh .all_dtypes , shape = ()))
92
130
def test_nonzero_zerodim_error (x ):
93
131
with pytest .raises (Exception ):
0 commit comments