Skip to content

Commit 4bcf0dd

Browse files
committed
BUG: cupy: ceil/trunc/floor return integers for integer inputs
1 parent c4bab35 commit 4bcf0dd

File tree

1 file changed

+20
-1
lines changed

1 file changed

+20
-1
lines changed

array_api_compat/cupy/_aliases.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,25 @@ def count_nonzero(
120120
return cp.expand_dims(result, axis)
121121
return result
122122

123+
# ceil, floor, and trunc return integers for integer inputs
124+
125+
def ceil(x: Array, /) -> Array:
126+
if cp.issubdtype(x.dtype, cp.integer):
127+
return x.copy()
128+
return cp.ceil(x)
129+
130+
131+
def floor(x: Array, /) -> Array:
132+
if cp.issubdtype(x.dtype, cp.integer):
133+
return x.copy()
134+
return cp.floor(x)
135+
136+
137+
def trunc(x: Array, /) -> Array:
138+
if cp.issubdtype(x.dtype, cp.integer):
139+
return x.copy()
140+
return cp.trunc(x)
141+
123142

124143
# take_along_axis: axis defaults to -1 but in cupy (and numpy) axis is a required arg
125144
def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1):
@@ -148,6 +167,6 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1):
148167
'atan2', 'atanh', 'bitwise_left_shift',
149168
'bitwise_invert', 'bitwise_right_shift',
150169
'bool', 'concat', 'count_nonzero', 'pow', 'sign',
151-
'take_along_axis']
170+
'ceil', 'floor', 'trunc', 'take_along_axis']
152171

153172
_all_ignore = ['cp', 'get_xp']

0 commit comments

Comments
 (0)