We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 5b8c55e commit 2ca6d73Copy full SHA for 2ca6d73
array_api_compat/cupy/_aliases.py
@@ -125,6 +125,20 @@ def astype(
125
return out.copy() if copy and out is x else out
126
127
128
+# cupy.count_nonzero does not have keepdims
129
+def count_nonzero(
130
+ x: ndarray,
131
+ axis=None,
132
+ keepdims=False
133
+) -> ndarray:
134
+ result = cp.count_nonzero(x, axis)
135
+ if keepdims:
136
+ if axis is None:
137
+ return cp.reshape(result, [1]*x.ndim)
138
+ return cp.expand_dims(result, axis)
139
+ return result
140
+
141
142
# These functions are completely new here. If the library already has them
143
# (i.e., numpy 2.0), use the library version instead of our wrapper.
144
if hasattr(cp, 'vecdot'):
@@ -146,6 +160,6 @@ def astype(
146
160
'acos', 'acosh', 'asin', 'asinh', 'atan',
147
161
'atan2', 'atanh', 'bitwise_left_shift',
148
162
'bitwise_invert', 'bitwise_right_shift',
149
- 'bool', 'concat', 'pow', 'sign']
163
+ 'bool', 'concat', 'count_nonzero', 'pow', 'sign']
150
164
151
165
_all_ignore = ['cp', 'get_xp']
0 commit comments