Skip to content

Commit 1facc35

Browse files
committed
BUG: make ceil,trunc,floor always respect view/copy semantics
Remove these functions from common/_aliases.py, add specific implementations for numpy < 2 and cupy.
1 parent c9cfc2c commit 1facc35

File tree

4 files changed

+46
-34
lines changed

4 files changed

+46
-34
lines changed

array_api_compat/common/_aliases.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -524,27 +524,6 @@ def nonzero(x: Array, /, xp: Namespace, **kwargs: object) -> tuple[Array, ...]:
524524
return xp.nonzero(x, **kwargs)
525525

526526

527-
# ceil, floor, and trunc return integers for integer inputs
528-
529-
530-
def ceil(x: Array, /, xp: Namespace, **kwargs: object) -> Array:
531-
if xp.issubdtype(x.dtype, xp.integer):
532-
return x
533-
return xp.ceil(x, **kwargs)
534-
535-
536-
def floor(x: Array, /, xp: Namespace, **kwargs: object) -> Array:
537-
if xp.issubdtype(x.dtype, xp.integer):
538-
return x
539-
return xp.floor(x, **kwargs)
540-
541-
542-
def trunc(x: Array, /, xp: Namespace, **kwargs: object) -> Array:
543-
if xp.issubdtype(x.dtype, xp.integer):
544-
return x
545-
return xp.trunc(x, **kwargs)
546-
547-
548527
# linear algebra functions
549528

550529

@@ -707,9 +686,6 @@ def iinfo(type_: DType | Array, /, xp: Namespace) -> Any:
707686
"argsort",
708687
"sort",
709688
"nonzero",
710-
"ceil",
711-
"floor",
712-
"trunc",
713689
"matmul",
714690
"matrix_transpose",
715691
"tensordot",

array_api_compat/cupy/_aliases.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,6 @@
5454
argsort = get_xp(cp)(_aliases.argsort)
5555
sort = get_xp(cp)(_aliases.sort)
5656
nonzero = get_xp(cp)(_aliases.nonzero)
57-
ceil = get_xp(cp)(_aliases.ceil)
58-
floor = get_xp(cp)(_aliases.floor)
59-
trunc = get_xp(cp)(_aliases.trunc)
6057
matmul = get_xp(cp)(_aliases.matmul)
6158
matrix_transpose = get_xp(cp)(_aliases.matrix_transpose)
6259
tensordot = get_xp(cp)(_aliases.tensordot)
@@ -123,6 +120,25 @@ def count_nonzero(
123120
return cp.expand_dims(result, axis)
124121
return result
125122

123+
# ceil, floor, and trunc return integers for integer inputs
124+
125+
def ceil(x: Array, /) -> Array:
126+
if cp.issubdtype(x.dtype, cp.integer):
127+
return x.copy()
128+
return cp.ceil(x)
129+
130+
131+
def floor(x: Array, /) -> Array:
132+
if cp.issubdtype(x.dtype, cp.integer):
133+
return x.copy()
134+
return cp.floor(x)
135+
136+
137+
def trunc(x: Array, /) -> Array:
138+
if cp.issubdtype(x.dtype, cp.integer):
139+
return x.copy()
140+
return cp.trunc(x)
141+
126142

127143
# take_along_axis: axis defaults to -1 but in cupy (and numpy) axis is a required arg
128144
def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1):
@@ -151,6 +167,6 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1):
151167
'atan2', 'atanh', 'bitwise_left_shift',
152168
'bitwise_invert', 'bitwise_right_shift',
153169
'bool', 'concat', 'count_nonzero', 'pow', 'sign',
154-
'take_along_axis']
170+
'ceil', 'floor', 'trunc', 'take_along_axis']
155171

156172
_all_ignore = ['cp', 'get_xp']

array_api_compat/dask/array/_aliases.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -134,9 +134,6 @@ def arange(
134134
matrix_transpose = get_xp(da)(_aliases.matrix_transpose)
135135
vecdot = get_xp(da)(_aliases.vecdot)
136136
nonzero = get_xp(da)(_aliases.nonzero)
137-
ceil = get_xp(np)(_aliases.ceil)
138-
floor = get_xp(np)(_aliases.floor)
139-
trunc = get_xp(np)(_aliases.trunc)
140137
matmul = get_xp(np)(_aliases.matmul)
141138
tensordot = get_xp(np)(_aliases.tensordot)
142139
sign = get_xp(np)(_aliases.sign)

array_api_compat/numpy/_aliases.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,6 @@
6363
argsort = get_xp(np)(_aliases.argsort)
6464
sort = get_xp(np)(_aliases.sort)
6565
nonzero = get_xp(np)(_aliases.nonzero)
66-
ceil = get_xp(np)(_aliases.ceil)
67-
floor = get_xp(np)(_aliases.floor)
68-
trunc = get_xp(np)(_aliases.trunc)
6966
matmul = get_xp(np)(_aliases.matmul)
7067
matrix_transpose = get_xp(np)(_aliases.matrix_transpose)
7168
tensordot = get_xp(np)(_aliases.tensordot)
@@ -145,6 +142,29 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1):
145142
return np.take_along_axis(x, indices, axis=axis)
146143

147144

145+
# ceil, floor, and trunc return integers for integer inputs in NumPy < 2
146+
147+
def ceil(x: Array, /) -> Array:
148+
if np.issubdtype(x.dtype, np.integer):
149+
if np.__version__ < '2':
150+
return x.copy()
151+
return np.ceil(x)
152+
153+
154+
def floor(x: Array, /) -> Array:
155+
if np.issubdtype(x.dtype, np.integer):
156+
if np.__version__ < '2':
157+
return x.copy()
158+
return np.floor(x)
159+
160+
161+
def trunc(x: Array, /) -> Array:
162+
if np.issubdtype(x.dtype, np.integer):
163+
if np.__version__ < '2':
164+
return x.copy()
165+
return np.trunc(x)
166+
167+
148168
# These functions are completely new here. If the library already has them
149169
# (i.e., numpy 2.0), use the library version instead of our wrapper.
150170
if hasattr(np, "vecdot"):
@@ -173,6 +193,9 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1):
173193
"atan",
174194
"atan2",
175195
"atanh",
196+
"ceil",
197+
"floor",
198+
"trunc",
176199
"bitwise_left_shift",
177200
"bitwise_invert",
178201
"bitwise_right_shift",

0 commit comments

Comments
 (0)