Skip to content

Use a more robust implementation of clip #79

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 1, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 46 additions & 8 deletions array_api_strict/_elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from ._array_object import Array
from ._flags import requires_api_version
from ._creation_functions import asarray
from ._data_type_functions import broadcast_to, iinfo

from typing import Optional, Union

Expand Down Expand Up @@ -325,14 +326,51 @@ def clip(
if min is not None and max is not None and np.any(min > max):
raise ValueError("min must be less than or equal to max")

result = np.clip(x._array, min, max)
# Note: NumPy applies type promotion, but the standard specifies the
# return dtype should be the same as x
if result.dtype != x.dtype._np_dtype:
# TODO: I'm not completely sure this always gives the correct thing
# for integer dtypes. See https://github.com/numpy/numpy/issues/24976
result = result.astype(x.dtype._np_dtype)
return Array._new(result, device=x.device)
# np.clip does type promotion but the array API clip requires that the
# output have the same dtype as x. We do this instead of just downcasting
# the result of xp.clip() to handle some corner cases better (e.g.,
# avoiding uint64 -> float64 promotion).

# Note: cases where min or max overflow (integer) or round (float) in the
# wrong direction when downcasting to x.dtype are unspecified. This code
# just does whatever NumPy does when it downcasts in the assignment, but
# other behavior could be preferred, especially for integers. For example,
# this code produces:

# >>> clip(asarray(0, dtype=int8), asarray(128, dtype=int16), None)
# -128

# but an answer of 0 might be preferred. See
# https://github.com/numpy/numpy/issues/24976 for more discussion on this issue.

# At least handle the case of Python integers correctly (see
# https://github.com/numpy/numpy/pull/26892).
if type(min) is int and min <= iinfo(x.dtype).min:
min = None
if type(max) is int and max >= iinfo(x.dtype).max:
max = None

def _isscalar(a):
return isinstance(a, (int, float, type(None)))
min_shape = () if _isscalar(min) else min.shape
max_shape = () if _isscalar(max) else max.shape

result_shape = np.broadcast_shapes(x.shape, min_shape, max_shape)

out = asarray(broadcast_to(x, result_shape), copy=True)._array
device = x.device
x = x._array

if min is not None:
a = np.broadcast_to(np.asarray(min), result_shape)
ia = (out < a) | np.isnan(a)

out[ia] = a[ia]
if max is not None:
b = np.broadcast_to(np.asarray(max), result_shape)
ib = (out > b) | np.isnan(b)
out[ib] = b[ib]
return Array._new(out, device=device)

def conj(x: Array, /) -> Array:
"""
Expand Down
Loading