Skip to content

Add "multi device" support #59

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 26 commits into from
Oct 16, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
600df5e
Add "multi device" support
betatim Sep 4, 2024
325b9d0
Loop device through elementwise functions
betatim Sep 27, 2024
6cc7bac
Define __hash__
betatim Sep 27, 2024
bca670a
More device pass through
betatim Sep 27, 2024
426609f
Fix meshgrid
betatim Sep 27, 2024
727072f
Add testing and small typo fixes
betatim Oct 2, 2024
23f390e
Add a comment about atanh special casing
betatim Oct 3, 2024
405b7e7
Add conversion to NumPy test
betatim Oct 3, 2024
03e1ae7
Add multi-device support to sorting functions
betatim Oct 3, 2024
3bc8199
More multi-device support
betatim Oct 3, 2024
032f3bb
Formatting
betatim Oct 3, 2024
724e071
Add multi-device test for take
betatim Oct 3, 2024
e0b2a64
Multi-device support in linear algebra functions
betatim Oct 3, 2024
9323324
Multi-device support for array manipulation
betatim Oct 3, 2024
ff37de7
Add multi-device support for searching
betatim Oct 3, 2024
bae7482
Add multi-device support to stats and sets
betatim Oct 3, 2024
cca1785
Add multi-device support for utils
betatim Oct 3, 2024
a96c497
More FFT multi-device
betatim Oct 7, 2024
58334e5
Fix weird ruff error
betatim Oct 7, 2024
1c77ba0
Merge branch 'main' into multiple-devices
betatim Oct 7, 2024
9c5436c
New default version
betatim Oct 7, 2024
0dbabcc
Fix result device
betatim Oct 7, 2024
8e6365b
Make device= a required argument to create an Array
betatim Oct 16, 2024
635e14d
Merge branch 'main' into betatim-multiple-devices
asmeurer Oct 16, 2024
78def19
Add device check to repeat()
asmeurer Oct 16, 2024
33450f3
Use ValueError for different device errors
asmeurer Oct 16, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions array_api_strict/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,9 @@

__all__ += ["all", "any"]

from ._array_object import Device
__all__ += ["Device"]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd rather not add this to the __init__.py since it isn't part of the array API. If it is necessary to have some public APIs to create device objects we should make APIs that are more obviously array-api-strict specific (similar to the flags APIs).


# Helper functions that are not part of the standard

from ._flags import (
Expand Down
103 changes: 76 additions & 27 deletions array_api_strict/_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,25 @@

import numpy as np

# Placeholder object to represent the "cpu" device (the only device NumPy
# supports).
class _cpu_device:
class Device:
def __init__(self, device="CPU_DEVICE"):
if device not in ("CPU_DEVICE", "device1", "device2"):
raise ValueError(f"The device '{device}' is not a valid choice.")
self._device = device

def __repr__(self):
return "CPU_DEVICE"
return f"array_api_strict.Device('{self._device}')"

def __eq__(self, other):
if not isinstance(other, Device):
return False
return self._device == other._device

def __hash__(self):
return hash(("Device", self._device))


CPU_DEVICE = _cpu_device()
CPU_DEVICE = Device()

_default = object()

Expand All @@ -73,7 +85,7 @@ class Array:
# Use a custom constructor instead of __init__, as manually initializing
# this class is not supported API.
@classmethod
def _new(cls, x, /):
def _new(cls, x, /, device=None):
"""
This is a private method for initializing the array API Array
object.
Expand All @@ -95,6 +107,9 @@ def _new(cls, x, /):
)
obj._array = x
obj._dtype = _dtype
if device is None:
device = CPU_DEVICE
obj._device = device
return obj

# Prevent Array() from working
Expand All @@ -116,7 +131,11 @@ def __repr__(self: Array, /) -> str:
"""
Performs the operation __repr__.
"""
suffix = f", dtype={self.dtype})"
suffix = f", dtype={self.dtype}"
if self.device != CPU_DEVICE:
suffix += f", device={self.device})"
else:
suffix += ")"
if 0 in self.shape:
prefix = "empty("
mid = str(self.shape)
Expand All @@ -134,6 +153,8 @@ def __array__(self, dtype: None | np.dtype[Any] = None, copy: None | bool = None
will be present in other implementations.

"""
if self._device != CPU_DEVICE:
raise RuntimeError(f"Can not convert array on the '{self._device}' device to a Numpy array.")
# copy keyword is new in 2.0.0; for older versions don't use it
# retry without that keyword.
if np.__version__[0] < '2':
Expand Down Expand Up @@ -193,6 +214,15 @@ def _check_allowed_dtypes(self, other: bool | int | float | Array, dtype_categor

return other

def _check_device(self, other):
"""Check that other is on a device compatible with the current array"""
if isinstance(other, (int, complex, float, bool)):
return other
elif isinstance(other, Array):
if self.device != other.device:
raise RuntimeError(f"Arrays from two different devices ({self.device} and {other.device}) can not be combined.")
return other

# Helper function to match the type promotion rules in the spec
def _promote_scalar(self, scalar):
"""
Expand Down Expand Up @@ -468,23 +498,25 @@ def __add__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __add__.
"""
other = self._check_device(other)
other = self._check_allowed_dtypes(other, "numeric", "__add__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
res = self._array.__add__(other._array)
return self.__class__._new(res)
return self.__class__._new(res, device=self.device)

def __and__(self: Array, other: Union[int, bool, Array], /) -> Array:
"""
Performs the operation __and__.
"""
other = self._check_device(other)
other = self._check_allowed_dtypes(other, "integer or boolean", "__and__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
res = self._array.__and__(other._array)
return self.__class__._new(res)
return self.__class__._new(res, device=self.device)

def __array_namespace__(
self: Array, /, *, api_version: Optional[str] = None
Expand Down Expand Up @@ -568,14 +600,15 @@ def __eq__(self: Array, other: Union[int, float, bool, Array], /) -> Array:
"""
Performs the operation __eq__.
"""
other = self._check_device(other)
# Even though "all" dtypes are allowed, we still require them to be
# promotable with each other.
other = self._check_allowed_dtypes(other, "all", "__eq__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
res = self._array.__eq__(other._array)
return self.__class__._new(res)
return self.__class__._new(res, device=self.device)

def __float__(self: Array, /) -> float:
"""
Expand All @@ -593,23 +626,25 @@ def __floordiv__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __floordiv__.
"""
other = self._check_device(other)
other = self._check_allowed_dtypes(other, "real numeric", "__floordiv__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
res = self._array.__floordiv__(other._array)
return self.__class__._new(res)
return self.__class__._new(res, device=self.device)

def __ge__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __ge__.
"""
other = self._check_device(other)
other = self._check_allowed_dtypes(other, "real numeric", "__ge__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
res = self._array.__ge__(other._array)
return self.__class__._new(res)
return self.__class__._new(res, device=self.device)

def __getitem__(
self: Array,
Expand All @@ -625,19 +660,21 @@ def __getitem__(
"""
Performs the operation __getitem__.
"""
# XXX Does key have to be on the same device? Is there an exception for CPU_DEVICE?
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess this is underspecified. I would need to test what PyTorch and others do, but I would suspect that an implicit cross-device array key is not something that's intended to be supported, since that still would require an implicit device transfer.

# Note: Only indices required by the spec are allowed. See the
# docstring of _validate_index
self._validate_index(key)
if isinstance(key, Array):
# Indexing self._array with array_api_strict arrays can be erroneous
key = key._array
res = self._array.__getitem__(key)
return self._new(res)
return self._new(res, device=self.device)

def __gt__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __gt__.
"""
other = self._check_device(other)
other = self._check_allowed_dtypes(other, "real numeric", "__gt__")
if other is NotImplemented:
return other
Expand Down Expand Up @@ -671,7 +708,7 @@ def __invert__(self: Array, /) -> Array:
if self.dtype not in _integer_or_boolean_dtypes:
raise TypeError("Only integer or boolean dtypes are allowed in __invert__")
res = self._array.__invert__()
return self.__class__._new(res)
return self.__class__._new(res, device=self.device)

def __iter__(self: Array, /):
"""
Expand All @@ -686,85 +723,92 @@ def __iter__(self: Array, /):
# define __iter__, but it doesn't disallow it. The default Python
# behavior is to implement iter as a[0], a[1], ... when __getitem__ is
# implemented, which implies iteration on 1-D arrays.
return (Array._new(i) for i in self._array)
return (Array._new(i, device=self.device) for i in self._array)

def __le__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __le__.
"""
other = self._check_device(other)
other = self._check_allowed_dtypes(other, "real numeric", "__le__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
res = self._array.__le__(other._array)
return self.__class__._new(res)
return self.__class__._new(res, device=self.device)

def __lshift__(self: Array, other: Union[int, Array], /) -> Array:
"""
Performs the operation __lshift__.
"""
other = self._check_device(other)
other = self._check_allowed_dtypes(other, "integer", "__lshift__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
res = self._array.__lshift__(other._array)
return self.__class__._new(res)
return self.__class__._new(res, device=self.device)

def __lt__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __lt__.
"""
other = self._check_device(other)
other = self._check_allowed_dtypes(other, "real numeric", "__lt__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
res = self._array.__lt__(other._array)
return self.__class__._new(res)
return self.__class__._new(res, device=self.device)

def __matmul__(self: Array, other: Array, /) -> Array:
"""
Performs the operation __matmul__.
"""
other = self._check_device(other)
# matmul is not defined for scalars, but without this, we may get
# the wrong error message from asarray.
other = self._check_allowed_dtypes(other, "numeric", "__matmul__")
if other is NotImplemented:
return other
res = self._array.__matmul__(other._array)
return self.__class__._new(res)
return self.__class__._new(res, device=self.device)

def __mod__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __mod__.
"""
other = self._check_device(other)
other = self._check_allowed_dtypes(other, "real numeric", "__mod__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
res = self._array.__mod__(other._array)
return self.__class__._new(res)
return self.__class__._new(res, device=self.device)

def __mul__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __mul__.
"""
other = self._check_device(other)
other = self._check_allowed_dtypes(other, "numeric", "__mul__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
res = self._array.__mul__(other._array)
return self.__class__._new(res)
return self.__class__._new(res, device=self.device)

def __ne__(self: Array, other: Union[int, float, bool, Array], /) -> Array:
"""
Performs the operation __ne__.
"""
other = self._check_device(other)
other = self._check_allowed_dtypes(other, "all", "__ne__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
res = self._array.__ne__(other._array)
return self.__class__._new(res)
return self.__class__._new(res, device=self.device)

def __neg__(self: Array, /) -> Array:
"""
Expand All @@ -773,18 +817,19 @@ def __neg__(self: Array, /) -> Array:
if self.dtype not in _numeric_dtypes:
raise TypeError("Only numeric dtypes are allowed in __neg__")
res = self._array.__neg__()
return self.__class__._new(res)
return self.__class__._new(res, device=self.device)

def __or__(self: Array, other: Union[int, bool, Array], /) -> Array:
"""
Performs the operation __or__.
"""
other = self._check_device(other)
other = self._check_allowed_dtypes(other, "integer or boolean", "__or__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
res = self._array.__or__(other._array)
return self.__class__._new(res)
return self.__class__._new(res, device=self.device)

def __pos__(self: Array, /) -> Array:
"""
Expand All @@ -793,14 +838,15 @@ def __pos__(self: Array, /) -> Array:
if self.dtype not in _numeric_dtypes:
raise TypeError("Only numeric dtypes are allowed in __pos__")
res = self._array.__pos__()
return self.__class__._new(res)
return self.__class__._new(res, device=self.device)

def __pow__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __pow__.
"""
from ._elementwise_functions import pow

other = self._check_device(other)
other = self._check_allowed_dtypes(other, "numeric", "__pow__")
if other is NotImplemented:
return other
Expand Down Expand Up @@ -1154,8 +1200,11 @@ def __rxor__(self: Array, other: Union[int, bool, Array], /) -> Array:
def to_device(self: Array, device: Device, /, stream: None = None) -> Array:
if stream is not None:
raise ValueError("The stream argument to to_device() is not supported")
if device == CPU_DEVICE:
if device == self._device:
return self
elif isinstance(device, Device):
arr = np.asarray(self._array, copy=True)
return self.__class__._new(arr, device=device)
raise ValueError(f"Unsupported device {device!r}")

@property
Expand All @@ -1169,7 +1218,7 @@ def dtype(self) -> Dtype:

@property
def device(self) -> Device:
return CPU_DEVICE
return self._device

# Note: mT is new in array API spec (see matrix_transpose)
@property
Expand Down
Loading
Loading