Skip to content

Commit 2ca6d73

Browse files
committed
BUG: cupy: fix count_nonzero(... keepdims=True)
1 parent 5b8c55e commit 2ca6d73

File tree

1 file changed

+15
-1
lines changed

1 file changed

+15
-1
lines changed

array_api_compat/cupy/_aliases.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,20 @@ def astype(
125125
return out.copy() if copy and out is x else out
126126

127127

128+
# cupy.count_nonzero does not have keepdims
129+
def count_nonzero(
130+
x: ndarray,
131+
axis=None,
132+
keepdims=False
133+
) -> ndarray:
134+
result = cp.count_nonzero(x, axis)
135+
if keepdims:
136+
if axis is None:
137+
return cp.reshape(result, [1]*x.ndim)
138+
return cp.expand_dims(result, axis)
139+
return result
140+
141+
128142
# These functions are completely new here. If the library already has them
129143
# (i.e., numpy 2.0), use the library version instead of our wrapper.
130144
if hasattr(cp, 'vecdot'):
@@ -146,6 +160,6 @@ def astype(
146160
'acos', 'acosh', 'asin', 'asinh', 'atan',
147161
'atan2', 'atanh', 'bitwise_left_shift',
148162
'bitwise_invert', 'bitwise_right_shift',
149-
'bool', 'concat', 'pow', 'sign']
163+
'bool', 'concat', 'count_nonzero', 'pow', 'sign']
150164

151165
_all_ignore = ['cp', 'get_xp']

0 commit comments

Comments
 (0)