@@ -144,6 +144,7 @@ def _check_device(device):
144
144
if device not in ["cpu" , None ]:
145
145
raise ValueError (f"Unsupported device { device !r} " )
146
146
147
+ # asarray also adds the copy keyword
147
148
def asarray (
148
149
obj : Union [
149
150
ndarray ,
@@ -336,6 +337,23 @@ def prod(
336
337
dtype = np .float64
337
338
return np .prod (x , dtype = dtype , axis = axis , keepdims = keepdims )
338
339
340
+ # ceil, floor, and trunc return integers for integer inputs
341
+
342
+ def ceil (x : ndarray , / ) -> ndarray :
343
+ if np .issubdtype (x .dtype , np .integer ):
344
+ return x
345
+ return np .ceil (x )
346
+
347
+ def floor (x : ndarray , / ) -> ndarray :
348
+ if np .issubdtype (x .dtype , np .integer ):
349
+ return x
350
+ return np .floor (x )
351
+
352
+ def trunc (x : ndarray , / ) -> ndarray :
353
+ if np .issubdtype (x .dtype , np .integer ):
354
+ return x
355
+ return np .trunc (x )
356
+
339
357
# from numpy import * doesn't overwrite these builtin names
340
358
from numpy import abs , max , min , round
341
359
@@ -347,4 +365,4 @@ def prod(
347
365
'round' , 'std' , 'var' , 'permute_dims' , 'asarray' , 'arange' ,
348
366
'empty' , 'empty_like' , 'eye' , 'full' , 'full_like' , 'linspace' ,
349
367
'ones' , 'ones_like' , 'zeros' , 'zeros_like' , 'reshape' , 'argsort' ,
350
- 'sort' , 'sum' , 'prod' ]
368
+ 'sort' , 'sum' , 'prod' , 'ceil' , 'floor' , 'trunc' ]
0 commit comments