Skip to content

Commit 4d1829f

Browse files
committed
Add sum() and prod() wrappers to handle float promotion
1 parent c28b5f1 commit 4d1829f

File tree

1 file changed

+27
-1
lines changed

1 file changed

+27
-1
lines changed

numpy_array_api_compat/_aliases.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,32 @@ def sort(
310310
res = np.flip(res, axis=axis)
311311
return res
312312

313+
# sum() and prod() should always upcast when dtype=None
314+
def sum(
315+
x: ndarray,
316+
/,
317+
*,
318+
axis: Optional[Union[int, Tuple[int, ...]]] = None,
319+
dtype: Optional[Dtype] = None,
320+
keepdims: bool = False,
321+
) -> ndarray:
322+
# `np.sum` already upcasts integers, but not floats
323+
if dtype is None and x.dtype == np.float32:
324+
dtype = np.float64
325+
return np.sum(x, axis=axis, dtype=dtype, keepdims=keepdims)
326+
327+
def prod(
328+
x: ndarray,
329+
/,
330+
*,
331+
axis: Optional[Union[int, Tuple[int, ...]]] = None,
332+
dtype: Optional[Dtype] = None,
333+
keepdims: bool = False,
334+
) -> ndarray:
335+
if dtype is None and x.dtype == np.float32:
336+
dtype = np.float64
337+
return np.prod(x, dtype=dtype, axis=axis, keepdims=keepdims)
338+
313339
# from numpy import * doesn't overwrite these builtin names
314340
from numpy import abs, max, min, round
315341

@@ -321,4 +347,4 @@ def sort(
321347
'round', 'std', 'var', 'permute_dims', 'asarray', 'arange',
322348
'empty', 'empty_like', 'eye', 'full', 'full_like', 'linspace',
323349
'ones', 'ones_like', 'zeros', 'zeros_like', 'reshape', 'argsort',
324-
'sort']
350+
'sort', 'sum', 'prod']

0 commit comments

Comments
 (0)