Skip to content

Commit ac4e03a

Browse files
committed
Add aliases for ceil(), floor(), and trunc()
1 parent e8f7c24 commit ac4e03a

File tree

1 file changed

+19
-1
lines changed

1 file changed

+19
-1
lines changed

numpy_array_api_compat/_aliases.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ def _check_device(device):
144144
if device not in ["cpu", None]:
145145
raise ValueError(f"Unsupported device {device!r}")
146146

147+
# asarray also adds the copy keyword
147148
def asarray(
148149
obj: Union[
149150
ndarray,
@@ -336,6 +337,23 @@ def prod(
336337
dtype = np.float64
337338
return np.prod(x, dtype=dtype, axis=axis, keepdims=keepdims)
338339

340+
# ceil, floor, and trunc return integers for integer inputs
341+
342+
def ceil(x: ndarray, /) -> ndarray:
343+
if np.issubdtype(x.dtype, np.integer):
344+
return x
345+
return np.ceil(x)
346+
347+
def floor(x: ndarray, /) -> ndarray:
348+
if np.issubdtype(x.dtype, np.integer):
349+
return x
350+
return np.floor(x)
351+
352+
def trunc(x: ndarray, /) -> ndarray:
353+
if np.issubdtype(x.dtype, np.integer):
354+
return x
355+
return np.trunc(x)
356+
339357
# from numpy import * doesn't overwrite these builtin names
340358
from numpy import abs, max, min, round
341359

@@ -347,4 +365,4 @@ def prod(
347365
'round', 'std', 'var', 'permute_dims', 'asarray', 'arange',
348366
'empty', 'empty_like', 'eye', 'full', 'full_like', 'linspace',
349367
'ones', 'ones_like', 'zeros', 'zeros_like', 'reshape', 'argsort',
350-
'sort', 'sum', 'prod']
368+
'sort', 'sum', 'prod', 'ceil', 'floor', 'trunc']

0 commit comments

Comments
 (0)