-
Notifications
You must be signed in to change notification settings - Fork 11
Add "multi device" support #59
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 6 commits
Commits
Show all changes
26 commits
Select commit
Hold shift + click to select a range
600df5e
Add "multi device" support
betatim 325b9d0
Loop device through elementwise functions
betatim 6cc7bac
Define __hash__
betatim bca670a
More device pass through
betatim 426609f
Fix meshgrid
betatim 727072f
Add testing and small typo fixes
betatim 23f390e
Add a comment about atanh special casing
betatim 405b7e7
Add conversion to NumPy test
betatim 03e1ae7
Add multi-device support to sorting functions
betatim 3bc8199
More multi-device support
betatim 032f3bb
Formatting
betatim 724e071
Add multi-device test for take
betatim e0b2a64
Multi-device support in linear algebra functions
betatim 9323324
Multi-device support for array manipulation
betatim ff37de7
Add multi-device support for searching
betatim bae7482
Add multi-device support to stats and sets
betatim cca1785
Add multi-device support for utils
betatim a96c497
More FFT multi-device
betatim 58334e5
Fix weird ruff error
betatim 1c77ba0
Merge branch 'main' into multiple-devices
betatim 9c5436c
New default version
betatim 0dbabcc
Fix result device
betatim 8e6365b
Make device= a required argument to create an Array
betatim 635e14d
Merge branch 'main' into betatim-multiple-devices
asmeurer 78def19
Add device check to repeat()
asmeurer 33450f3
Use ValueError for different device errors
asmeurer File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -43,13 +43,25 @@ | |
|
||
import numpy as np | ||
|
||
# Placeholder object to represent the "cpu" device (the only device NumPy | ||
# supports). | ||
class _cpu_device: | ||
class Device: | ||
def __init__(self, device="CPU_DEVICE"): | ||
if device not in ("CPU_DEVICE", "device1", "device2"): | ||
raise ValueError(f"The device '{device}' is not a valid choice.") | ||
self._device = device | ||
|
||
def __repr__(self): | ||
return "CPU_DEVICE" | ||
return f"array_api_strict.Device('{self._device}')" | ||
|
||
def __eq__(self, other): | ||
if not isinstance(other, Device): | ||
return False | ||
return self._device == other._device | ||
|
||
def __hash__(self): | ||
return hash(("Device", self._device)) | ||
|
||
|
||
CPU_DEVICE = _cpu_device() | ||
CPU_DEVICE = Device() | ||
|
||
_default = object() | ||
|
||
|
@@ -73,7 +85,7 @@ class Array: | |
# Use a custom constructor instead of __init__, as manually initializing | ||
# this class is not supported API. | ||
@classmethod | ||
def _new(cls, x, /): | ||
def _new(cls, x, /, device=None): | ||
""" | ||
This is a private method for initializing the array API Array | ||
object. | ||
|
@@ -95,6 +107,9 @@ def _new(cls, x, /): | |
) | ||
obj._array = x | ||
obj._dtype = _dtype | ||
if device is None: | ||
device = CPU_DEVICE | ||
obj._device = device | ||
return obj | ||
|
||
# Prevent Array() from working | ||
|
@@ -116,7 +131,11 @@ def __repr__(self: Array, /) -> str: | |
""" | ||
Performs the operation __repr__. | ||
""" | ||
suffix = f", dtype={self.dtype})" | ||
suffix = f", dtype={self.dtype}" | ||
if self.device != CPU_DEVICE: | ||
suffix += f", device={self.device})" | ||
else: | ||
suffix += ")" | ||
if 0 in self.shape: | ||
prefix = "empty(" | ||
mid = str(self.shape) | ||
|
@@ -134,6 +153,8 @@ def __array__(self, dtype: None | np.dtype[Any] = None, copy: None | bool = None | |
will be present in other implementations. | ||
|
||
""" | ||
if self._device != CPU_DEVICE: | ||
raise RuntimeError(f"Can not convert array on the '{self._device}' device to a Numpy array.") | ||
# copy keyword is new in 2.0.0; for older versions don't use it | ||
# retry without that keyword. | ||
if np.__version__[0] < '2': | ||
|
@@ -193,6 +214,15 @@ def _check_allowed_dtypes(self, other: bool | int | float | Array, dtype_categor | |
|
||
return other | ||
|
||
def _check_device(self, other): | ||
"""Check that other is on a device compatible with the current array""" | ||
if isinstance(other, (int, complex, float, bool)): | ||
return other | ||
elif isinstance(other, Array): | ||
if self.device != other.device: | ||
raise RuntimeError(f"Arrays from two different devices ({self.device} and {other.device}) can not be combined.") | ||
return other | ||
betatim marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
# Helper function to match the type promotion rules in the spec | ||
def _promote_scalar(self, scalar): | ||
""" | ||
|
@@ -468,23 +498,25 @@ def __add__(self: Array, other: Union[int, float, Array], /) -> Array: | |
""" | ||
Performs the operation __add__. | ||
""" | ||
other = self._check_device(other) | ||
other = self._check_allowed_dtypes(other, "numeric", "__add__") | ||
if other is NotImplemented: | ||
return other | ||
self, other = self._normalize_two_args(self, other) | ||
res = self._array.__add__(other._array) | ||
return self.__class__._new(res) | ||
return self.__class__._new(res, device=self.device) | ||
|
||
def __and__(self: Array, other: Union[int, bool, Array], /) -> Array: | ||
""" | ||
Performs the operation __and__. | ||
""" | ||
other = self._check_device(other) | ||
other = self._check_allowed_dtypes(other, "integer or boolean", "__and__") | ||
if other is NotImplemented: | ||
return other | ||
self, other = self._normalize_two_args(self, other) | ||
res = self._array.__and__(other._array) | ||
return self.__class__._new(res) | ||
return self.__class__._new(res, device=self.device) | ||
|
||
def __array_namespace__( | ||
self: Array, /, *, api_version: Optional[str] = None | ||
|
@@ -568,14 +600,15 @@ def __eq__(self: Array, other: Union[int, float, bool, Array], /) -> Array: | |
""" | ||
Performs the operation __eq__. | ||
""" | ||
other = self._check_device(other) | ||
# Even though "all" dtypes are allowed, we still require them to be | ||
# promotable with each other. | ||
other = self._check_allowed_dtypes(other, "all", "__eq__") | ||
if other is NotImplemented: | ||
return other | ||
self, other = self._normalize_two_args(self, other) | ||
res = self._array.__eq__(other._array) | ||
return self.__class__._new(res) | ||
return self.__class__._new(res, device=self.device) | ||
|
||
def __float__(self: Array, /) -> float: | ||
""" | ||
|
@@ -593,23 +626,25 @@ def __floordiv__(self: Array, other: Union[int, float, Array], /) -> Array: | |
""" | ||
Performs the operation __floordiv__. | ||
""" | ||
other = self._check_device(other) | ||
other = self._check_allowed_dtypes(other, "real numeric", "__floordiv__") | ||
if other is NotImplemented: | ||
return other | ||
self, other = self._normalize_two_args(self, other) | ||
res = self._array.__floordiv__(other._array) | ||
return self.__class__._new(res) | ||
return self.__class__._new(res, device=self.device) | ||
|
||
def __ge__(self: Array, other: Union[int, float, Array], /) -> Array: | ||
""" | ||
Performs the operation __ge__. | ||
""" | ||
other = self._check_device(other) | ||
other = self._check_allowed_dtypes(other, "real numeric", "__ge__") | ||
if other is NotImplemented: | ||
return other | ||
self, other = self._normalize_two_args(self, other) | ||
res = self._array.__ge__(other._array) | ||
return self.__class__._new(res) | ||
return self.__class__._new(res, device=self.device) | ||
|
||
def __getitem__( | ||
self: Array, | ||
|
@@ -625,19 +660,21 @@ def __getitem__( | |
""" | ||
Performs the operation __getitem__. | ||
""" | ||
# XXX Does key have to be on the same device? Is there an exception for CPU_DEVICE? | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess this is underspecified. I would need to test what PyTorch and others do, but I would suspect that an implicit cross-device array |
||
# Note: Only indices required by the spec are allowed. See the | ||
# docstring of _validate_index | ||
self._validate_index(key) | ||
if isinstance(key, Array): | ||
# Indexing self._array with array_api_strict arrays can be erroneous | ||
key = key._array | ||
res = self._array.__getitem__(key) | ||
return self._new(res) | ||
return self._new(res, device=self.device) | ||
|
||
def __gt__(self: Array, other: Union[int, float, Array], /) -> Array: | ||
""" | ||
Performs the operation __gt__. | ||
""" | ||
other = self._check_device(other) | ||
other = self._check_allowed_dtypes(other, "real numeric", "__gt__") | ||
if other is NotImplemented: | ||
return other | ||
|
@@ -671,7 +708,7 @@ def __invert__(self: Array, /) -> Array: | |
if self.dtype not in _integer_or_boolean_dtypes: | ||
raise TypeError("Only integer or boolean dtypes are allowed in __invert__") | ||
res = self._array.__invert__() | ||
return self.__class__._new(res) | ||
return self.__class__._new(res, device=self.device) | ||
|
||
def __iter__(self: Array, /): | ||
""" | ||
|
@@ -686,85 +723,92 @@ def __iter__(self: Array, /): | |
# define __iter__, but it doesn't disallow it. The default Python | ||
# behavior is to implement iter as a[0], a[1], ... when __getitem__ is | ||
# implemented, which implies iteration on 1-D arrays. | ||
return (Array._new(i) for i in self._array) | ||
return (Array._new(i, device=self.device) for i in self._array) | ||
|
||
def __le__(self: Array, other: Union[int, float, Array], /) -> Array: | ||
""" | ||
Performs the operation __le__. | ||
""" | ||
other = self._check_device(other) | ||
other = self._check_allowed_dtypes(other, "real numeric", "__le__") | ||
if other is NotImplemented: | ||
return other | ||
self, other = self._normalize_two_args(self, other) | ||
res = self._array.__le__(other._array) | ||
return self.__class__._new(res) | ||
return self.__class__._new(res, device=self.device) | ||
|
||
def __lshift__(self: Array, other: Union[int, Array], /) -> Array: | ||
""" | ||
Performs the operation __lshift__. | ||
""" | ||
other = self._check_device(other) | ||
other = self._check_allowed_dtypes(other, "integer", "__lshift__") | ||
if other is NotImplemented: | ||
return other | ||
self, other = self._normalize_two_args(self, other) | ||
res = self._array.__lshift__(other._array) | ||
return self.__class__._new(res) | ||
return self.__class__._new(res, device=self.device) | ||
|
||
def __lt__(self: Array, other: Union[int, float, Array], /) -> Array: | ||
""" | ||
Performs the operation __lt__. | ||
""" | ||
other = self._check_device(other) | ||
other = self._check_allowed_dtypes(other, "real numeric", "__lt__") | ||
if other is NotImplemented: | ||
return other | ||
self, other = self._normalize_two_args(self, other) | ||
res = self._array.__lt__(other._array) | ||
return self.__class__._new(res) | ||
return self.__class__._new(res, device=self.device) | ||
|
||
def __matmul__(self: Array, other: Array, /) -> Array: | ||
""" | ||
Performs the operation __matmul__. | ||
""" | ||
other = self._check_device(other) | ||
# matmul is not defined for scalars, but without this, we may get | ||
# the wrong error message from asarray. | ||
other = self._check_allowed_dtypes(other, "numeric", "__matmul__") | ||
if other is NotImplemented: | ||
return other | ||
res = self._array.__matmul__(other._array) | ||
return self.__class__._new(res) | ||
return self.__class__._new(res, device=self.device) | ||
|
||
def __mod__(self: Array, other: Union[int, float, Array], /) -> Array: | ||
""" | ||
Performs the operation __mod__. | ||
""" | ||
other = self._check_device(other) | ||
other = self._check_allowed_dtypes(other, "real numeric", "__mod__") | ||
if other is NotImplemented: | ||
return other | ||
self, other = self._normalize_two_args(self, other) | ||
res = self._array.__mod__(other._array) | ||
return self.__class__._new(res) | ||
return self.__class__._new(res, device=self.device) | ||
|
||
def __mul__(self: Array, other: Union[int, float, Array], /) -> Array: | ||
""" | ||
Performs the operation __mul__. | ||
""" | ||
other = self._check_device(other) | ||
other = self._check_allowed_dtypes(other, "numeric", "__mul__") | ||
if other is NotImplemented: | ||
return other | ||
self, other = self._normalize_two_args(self, other) | ||
res = self._array.__mul__(other._array) | ||
return self.__class__._new(res) | ||
return self.__class__._new(res, device=self.device) | ||
|
||
def __ne__(self: Array, other: Union[int, float, bool, Array], /) -> Array: | ||
""" | ||
Performs the operation __ne__. | ||
""" | ||
other = self._check_device(other) | ||
other = self._check_allowed_dtypes(other, "all", "__ne__") | ||
if other is NotImplemented: | ||
return other | ||
self, other = self._normalize_two_args(self, other) | ||
res = self._array.__ne__(other._array) | ||
return self.__class__._new(res) | ||
return self.__class__._new(res, device=self.device) | ||
|
||
def __neg__(self: Array, /) -> Array: | ||
""" | ||
|
@@ -773,18 +817,19 @@ def __neg__(self: Array, /) -> Array: | |
if self.dtype not in _numeric_dtypes: | ||
raise TypeError("Only numeric dtypes are allowed in __neg__") | ||
res = self._array.__neg__() | ||
return self.__class__._new(res) | ||
return self.__class__._new(res, device=self.device) | ||
|
||
def __or__(self: Array, other: Union[int, bool, Array], /) -> Array: | ||
""" | ||
Performs the operation __or__. | ||
""" | ||
other = self._check_device(other) | ||
other = self._check_allowed_dtypes(other, "integer or boolean", "__or__") | ||
if other is NotImplemented: | ||
return other | ||
self, other = self._normalize_two_args(self, other) | ||
res = self._array.__or__(other._array) | ||
return self.__class__._new(res) | ||
return self.__class__._new(res, device=self.device) | ||
|
||
def __pos__(self: Array, /) -> Array: | ||
""" | ||
|
@@ -793,14 +838,15 @@ def __pos__(self: Array, /) -> Array: | |
if self.dtype not in _numeric_dtypes: | ||
raise TypeError("Only numeric dtypes are allowed in __pos__") | ||
res = self._array.__pos__() | ||
return self.__class__._new(res) | ||
return self.__class__._new(res, device=self.device) | ||
|
||
def __pow__(self: Array, other: Union[int, float, Array], /) -> Array: | ||
""" | ||
Performs the operation __pow__. | ||
""" | ||
from ._elementwise_functions import pow | ||
|
||
other = self._check_device(other) | ||
other = self._check_allowed_dtypes(other, "numeric", "__pow__") | ||
if other is NotImplemented: | ||
return other | ||
|
@@ -1154,8 +1200,11 @@ def __rxor__(self: Array, other: Union[int, bool, Array], /) -> Array: | |
def to_device(self: Array, device: Device, /, stream: None = None) -> Array: | ||
if stream is not None: | ||
raise ValueError("The stream argument to to_device() is not supported") | ||
if device == CPU_DEVICE: | ||
if device == self._device: | ||
return self | ||
elif isinstance(device, Device): | ||
arr = np.asarray(self._array, copy=True) | ||
return self.__class__._new(arr, device=device) | ||
raise ValueError(f"Unsupported device {device!r}") | ||
|
||
@property | ||
|
@@ -1169,7 +1218,7 @@ def dtype(self) -> Dtype: | |
|
||
@property | ||
def device(self) -> Device: | ||
return CPU_DEVICE | ||
return self._device | ||
|
||
# Note: mT is new in array API spec (see matrix_transpose) | ||
@property | ||
|
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd rather not add this to the
__init__.py
since it isn't part of the array API. If it is necessary to have some public APIs to create device objects we should make APIs that are more obviously array-api-strict specific (similar to the flags APIs).