diff --git a/array_api_strict/_elementwise_functions.py b/array_api_strict/_elementwise_functions.py index 7dc6c5c..8cec86a 100644 --- a/array_api_strict/_elementwise_functions.py +++ b/array_api_strict/_elementwise_functions.py @@ -14,6 +14,7 @@ from ._array_object import Array from ._flags import requires_api_version from ._creation_functions import asarray +from ._data_type_functions import broadcast_to, iinfo from typing import Optional, Union @@ -325,14 +326,51 @@ def clip( if min is not None and max is not None and np.any(min > max): raise ValueError("min must be less than or equal to max") - result = np.clip(x._array, min, max) - # Note: NumPy applies type promotion, but the standard specifies the - # return dtype should be the same as x - if result.dtype != x.dtype._np_dtype: - # TODO: I'm not completely sure this always gives the correct thing - # for integer dtypes. See https://github.com/numpy/numpy/issues/24976 - result = result.astype(x.dtype._np_dtype) - return Array._new(result, device=x.device) + # np.clip does type promotion but the array API clip requires that the + # output have the same dtype as x. We do this instead of just downcasting + # the result of xp.clip() to handle some corner cases better (e.g., + # avoiding uint64 -> float64 promotion). + + # Note: cases where min or max overflow (integer) or round (float) in the + # wrong direction when downcasting to x.dtype are unspecified. This code + # just does whatever NumPy does when it downcasts in the assignment, but + # other behavior could be preferred, especially for integers. For example, + # this code produces: + + # >>> clip(asarray(0, dtype=int8), asarray(128, dtype=int16), None) + # -128 + + # but an answer of 0 might be preferred. See + # https://github.com/numpy/numpy/issues/24976 for more discussion on this issue. + + # At least handle the case of Python integers correctly (see + # https://github.com/numpy/numpy/pull/26892). + if type(min) is int and min <= iinfo(x.dtype).min: + min = None + if type(max) is int and max >= iinfo(x.dtype).max: + max = None + + def _isscalar(a): + return isinstance(a, (int, float, type(None))) + min_shape = () if _isscalar(min) else min.shape + max_shape = () if _isscalar(max) else max.shape + + result_shape = np.broadcast_shapes(x.shape, min_shape, max_shape) + + out = asarray(broadcast_to(x, result_shape), copy=True)._array + device = x.device + x = x._array + + if min is not None: + a = np.broadcast_to(np.asarray(min), result_shape) + ia = (out < a) | np.isnan(a) + + out[ia] = a[ia] + if max is not None: + b = np.broadcast_to(np.asarray(max), result_shape) + ib = (out > b) | np.isnan(b) + out[ib] = b[ib] + return Array._new(out, device=device) def conj(x: Array, /) -> Array: """