Skip to content

Commit 398fdf1

Browse files
committed
Add clip() wrapper for NumPy and CuPy
1 parent 090f570 commit 398fdf1

File tree

3 files changed

+54
-2
lines changed

3 files changed

+54
-2
lines changed

array_api_compat/common/_aliases.py

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from typing import NamedTuple
1313
import inspect
1414

15-
from ._helpers import _check_device
15+
from ._helpers import array_namespace, _check_device
1616

1717
# These functions are modified from the NumPy versions.
1818

@@ -264,6 +264,56 @@ def var(
264264
) -> ndarray:
265265
return xp.var(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs)
266266

267+
268+
# The min and max argument names in clip are different and not optional in numpy, and type
269+
# promotion behavior is different.
270+
def clip(
271+
x: ndarray,
272+
/,
273+
min: Optional[Union[int, float, ndarray]] = None,
274+
max: Optional[Union[int, float, ndarray]] = None,
275+
*,
276+
xp,
277+
# TODO: np.clip has other ufunc kwargs
278+
out: Optional[ndarray] = None,
279+
) -> ndarray:
280+
def _isscalar(a):
281+
return isinstance(a, (int, float, type(None)))
282+
min_shape = () if _isscalar(min) else min.shape
283+
max_shape = () if _isscalar(max) else max.shape
284+
result_shape = xp.broadcast_shapes(x.shape, min_shape, max_shape)
285+
286+
wrapped_xp = array_namespace(x)
287+
288+
# np.clip does type promotion but the array API clip requires that the
289+
# output have the same dtype as x. We do this instead of just downcasting
290+
# the result of xp.clip() to handle some corner cases better (e.g.,
291+
# avoiding uint64 -> float64 promotion).
292+
293+
# Note: cases where min or max overflow (integer) or round (float) in the
294+
# wrong direction when downcasting to x.dtype are unspecified. This code
295+
# just does whatever NumPy does when it downcasts in the assignment, but
296+
# other behavior could be preferred, especially for integers. For example,
297+
# this code produces:
298+
299+
# >>> clip(asarray(0, dtype=int8), asarray(128, dtype=int16), None)
300+
# -128
301+
302+
# but an answer of 0 might be preferred. See
303+
# https://github.com/numpy/numpy/issues/24976 for more discussion on this issue.
304+
if out is None:
305+
out = wrapped_xp.asarray(xp.broadcast_to(x, result_shape), copy=True)
306+
if min is not None:
307+
a = xp.broadcast_to(xp.asarray(min), result_shape)
308+
ia = (out < a) | xp.isnan(a)
309+
out[ia] = a[ia]
310+
if max is not None:
311+
b = xp.broadcast_to(xp.asarray(max), result_shape)
312+
ib = (out > b) | xp.isnan(b)
313+
out[ib] = b[ib]
314+
# Return a scalar for 0-D
315+
return out[()]
316+
267317
# Unlike transpose(), the axes argument to permute_dims() is required.
268318
def permute_dims(x: ndarray, /, axes: Tuple[int, ...], xp) -> ndarray:
269319
return xp.transpose(x, axes)
@@ -465,6 +515,6 @@ def isdtype(
465515
'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like',
466516
'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
467517
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
468-
'astype', 'std', 'var', 'permute_dims', 'reshape', 'argsort',
518+
'astype', 'std', 'var', 'clip', 'permute_dims', 'reshape', 'argsort',
469519
'sort', 'nonzero', 'sum', 'prod', 'ceil', 'floor', 'trunc',
470520
'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype']

array_api_compat/cupy/_aliases.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
astype = _aliases.astype
4848
std = get_xp(cp)(_aliases.std)
4949
var = get_xp(cp)(_aliases.var)
50+
clip = get_xp(cp)(_aliases.clip)
5051
permute_dims = get_xp(cp)(_aliases.permute_dims)
5152
reshape = get_xp(cp)(_aliases.reshape)
5253
argsort = get_xp(cp)(_aliases.argsort)

array_api_compat/numpy/_aliases.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
astype = _aliases.astype
4848
std = get_xp(np)(_aliases.std)
4949
var = get_xp(np)(_aliases.var)
50+
clip = get_xp(np)(_aliases.clip)
5051
permute_dims = get_xp(np)(_aliases.permute_dims)
5152
reshape = get_xp(np)(_aliases.reshape)
5253
argsort = get_xp(np)(_aliases.argsort)

0 commit comments

Comments
 (0)