@@ -120,6 +120,25 @@ def count_nonzero(
120
120
return cp .expand_dims (result , axis )
121
121
return result
122
122
123
+ # ceil, floor, and trunc return integers for integer inputs
124
+
125
+ def ceil (x : Array , / ) -> Array :
126
+ if cp .issubdtype (x .dtype , cp .integer ):
127
+ return x .copy ()
128
+ return cp .ceil (x )
129
+
130
+
131
+ def floor (x : Array , / ) -> Array :
132
+ if cp .issubdtype (x .dtype , cp .integer ):
133
+ return x .copy ()
134
+ return cp .floor (x )
135
+
136
+
137
+ def trunc (x : Array , / ) -> Array :
138
+ if cp .issubdtype (x .dtype , cp .integer ):
139
+ return x .copy ()
140
+ return cp .trunc (x )
141
+
123
142
124
143
# take_along_axis: axis defaults to -1 but in cupy (and numpy) axis is a required arg
125
144
def take_along_axis (x : Array , indices : Array , / , * , axis : int = - 1 ):
@@ -148,6 +167,6 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1):
148
167
'atan2' , 'atanh' , 'bitwise_left_shift' ,
149
168
'bitwise_invert' , 'bitwise_right_shift' ,
150
169
'bool' , 'concat' , 'count_nonzero' , 'pow' , 'sign' ,
151
- 'take_along_axis' ]
170
+ 'ceil' , 'floor' , 'trunc' , ' take_along_axis' ]
152
171
153
172
_all_ignore = ['cp' , 'get_xp' ]
0 commit comments