@@ -310,6 +310,32 @@ def sort(
310
310
res = np .flip (res , axis = axis )
311
311
return res
312
312
313
+ # sum() and prod() should always upcast when dtype=None
314
+ def sum (
315
+ x : ndarray ,
316
+ / ,
317
+ * ,
318
+ axis : Optional [Union [int , Tuple [int , ...]]] = None ,
319
+ dtype : Optional [Dtype ] = None ,
320
+ keepdims : bool = False ,
321
+ ) -> ndarray :
322
+ # `np.sum` already upcasts integers, but not floats
323
+ if dtype is None and x .dtype == np .float32 :
324
+ dtype = np .float64
325
+ return np .sum (x , axis = axis , dtype = dtype , keepdims = keepdims )
326
+
327
+ def prod (
328
+ x : ndarray ,
329
+ / ,
330
+ * ,
331
+ axis : Optional [Union [int , Tuple [int , ...]]] = None ,
332
+ dtype : Optional [Dtype ] = None ,
333
+ keepdims : bool = False ,
334
+ ) -> ndarray :
335
+ if dtype is None and x .dtype == np .float32 :
336
+ dtype = np .float64
337
+ return np .prod (x , dtype = dtype , axis = axis , keepdims = keepdims )
338
+
313
339
# from numpy import * doesn't overwrite these builtin names
314
340
from numpy import abs , max , min , round
315
341
@@ -321,4 +347,4 @@ def sort(
321
347
'round' , 'std' , 'var' , 'permute_dims' , 'asarray' , 'arange' ,
322
348
'empty' , 'empty_like' , 'eye' , 'full' , 'full_like' , 'linspace' ,
323
349
'ones' , 'ones_like' , 'zeros' , 'zeros_like' , 'reshape' , 'argsort' ,
324
- 'sort' ]
350
+ 'sort' , 'sum' , 'prod' ]
0 commit comments