
Automatically use the correct device in xp.clip when Python number literals are passed as bounds #177

Closed
@ogrisel

Description


I would like the following not to fail with PyTorch:

>>> import array_api_compat.torch as xp
>>> data = xp.linspace(0, 1, num=5, device="mps")
>>> xp.clip(data, 0.1, 0.9)
Traceback (most recent call last):
  Cell In[4], line 1
    xp.clip(data, 0.1, 0.9)
  File ~/miniforge3/envs/dev/lib/python3.11/site-packages/array_api_compat/_internal.py:28 in wrapped_f
    return f(*args, xp=xp, **kwargs)
  File ~/miniforge3/envs/dev/lib/python3.11/site-packages/array_api_compat/common/_aliases.py:317 in clip
    ia = (out < a) | xp.isnan(a)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, mps:0 and cpu!

At the moment, using xp.clip with PyTorch tensors on a non-CPU device requires explicitly creating the bounds on the tensor's device, which is overly verbose:

>>> from array_api_compat import device
>>> device_ = device(data)
>>> xp.clip(data, xp.asarray(0.1, device=device_), xp.asarray(0.9, device=device_))
tensor([0.1000, 0.2500, 0.5000, 0.7500, 0.9000], device='mps:0')
