Skip to content

Commit b003715

Browse files
committed
ENH: CuPy creation functions to respect device= parameter
1 parent cddc9ef commit b003715

File tree

3 files changed

+58
-40
lines changed

3 files changed

+58
-40
lines changed

array_api_compat/common/_aliases.py

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from collections.abc import Sequence
99
from typing import TYPE_CHECKING, Any, NamedTuple, cast
1010

11-
from ._helpers import _check_device, array_namespace
11+
from ._helpers import _device_ctx, array_namespace
1212
from ._helpers import device as _get_device
1313
from ._helpers import is_cupy_namespace
1414
from ._typing import Array, Device, DType, Namespace
@@ -33,8 +33,8 @@ def arange(
3333
device: Device | None = None,
3434
**kwargs: object,
3535
) -> Array:
36-
_check_device(xp, device)
37-
return xp.arange(start, stop=stop, step=step, dtype=dtype, **kwargs)
36+
with _device_ctx(xp, device):
37+
return xp.arange(start, stop=stop, step=step, dtype=dtype, **kwargs)
3838

3939

4040
def empty(
@@ -45,8 +45,8 @@ def empty(
4545
device: Device | None = None,
4646
**kwargs: object,
4747
) -> Array:
48-
_check_device(xp, device)
49-
return xp.empty(shape, dtype=dtype, **kwargs)
48+
with _device_ctx(xp, device):
49+
return xp.empty(shape, dtype=dtype, **kwargs)
5050

5151

5252
def empty_like(
@@ -58,8 +58,8 @@ def empty_like(
5858
device: Device | None = None,
5959
**kwargs: object,
6060
) -> Array:
61-
_check_device(xp, device)
62-
return xp.empty_like(x, dtype=dtype, **kwargs)
61+
with _device_ctx(xp, device, like=x):
62+
return xp.empty_like(x, dtype=dtype, **kwargs)
6363

6464

6565
def eye(
@@ -73,8 +73,8 @@ def eye(
7373
device: Device | None = None,
7474
**kwargs: object,
7575
) -> Array:
76-
_check_device(xp, device)
77-
return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype, **kwargs)
76+
with _device_ctx(xp, device):
77+
return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype, **kwargs)
7878

7979

8080
def full(
@@ -86,8 +86,8 @@ def full(
8686
device: Device | None = None,
8787
**kwargs: object,
8888
) -> Array:
89-
_check_device(xp, device)
90-
return xp.full(shape, fill_value, dtype=dtype, **kwargs)
89+
with _device_ctx(xp, device):
90+
return xp.full(shape, fill_value, dtype=dtype, **kwargs)
9191

9292

9393
def full_like(
@@ -100,8 +100,8 @@ def full_like(
100100
device: Device | None = None,
101101
**kwargs: object,
102102
) -> Array:
103-
_check_device(xp, device)
104-
return xp.full_like(x, fill_value, dtype=dtype, **kwargs)
103+
with _device_ctx(xp, device, like=x):
104+
return xp.full_like(x, fill_value, dtype=dtype, **kwargs)
105105

106106

107107
def linspace(
@@ -116,8 +116,8 @@ def linspace(
116116
endpoint: bool = True,
117117
**kwargs: object,
118118
) -> Array:
119-
_check_device(xp, device)
120-
return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint, **kwargs)
119+
with _device_ctx(xp, device):
120+
return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint, **kwargs)
121121

122122

123123
def ones(
@@ -128,8 +128,8 @@ def ones(
128128
device: Device | None = None,
129129
**kwargs: object,
130130
) -> Array:
131-
_check_device(xp, device)
132-
return xp.ones(shape, dtype=dtype, **kwargs)
131+
with _device_ctx(xp, device):
132+
return xp.ones(shape, dtype=dtype, **kwargs)
133133

134134

135135
def ones_like(
@@ -141,8 +141,8 @@ def ones_like(
141141
device: Device | None = None,
142142
**kwargs: object,
143143
) -> Array:
144-
_check_device(xp, device)
145-
return xp.ones_like(x, dtype=dtype, **kwargs)
144+
with _device_ctx(xp, device, like=x):
145+
return xp.ones_like(x, dtype=dtype, **kwargs)
146146

147147

148148
def zeros(
@@ -153,8 +153,8 @@ def zeros(
153153
device: Device | None = None,
154154
**kwargs: object,
155155
) -> Array:
156-
_check_device(xp, device)
157-
return xp.zeros(shape, dtype=dtype, **kwargs)
156+
with _device_ctx(xp, device):
157+
return xp.zeros(shape, dtype=dtype, **kwargs)
158158

159159

160160
def zeros_like(
@@ -166,8 +166,8 @@ def zeros_like(
166166
device: Device | None = None,
167167
**kwargs: object,
168168
) -> Array:
169-
_check_device(xp, device)
170-
return xp.zeros_like(x, dtype=dtype, **kwargs)
169+
with _device_ctx(xp, device, like=x):
170+
return xp.zeros_like(x, dtype=dtype, **kwargs)
171171

172172

173173
# np.unique() is split into four functions in the array API:

array_api_compat/common/_helpers.py

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,13 @@
88

99
from __future__ import annotations
1010

11+
import contextlib
1112
import enum
1213
import inspect
1314
import math
1415
import sys
1516
import warnings
16-
from collections.abc import Collection, Hashable
17+
from collections.abc import Collection, Generator, Hashable
1718
from functools import lru_cache
1819
from typing import (
1920
TYPE_CHECKING,
@@ -669,26 +670,42 @@ def your_function(x, y):
669670
get_namespace = array_namespace
670671

671672

672-
def _check_device(bare_xp: Namespace, device: Device) -> None: # pyright: ignore[reportUnusedFunction]
673-
"""
674-
Validate dummy device on device-less array backends.
673+
def _device_ctx(
674+
bare_xp: Namespace, device: Device, like: Array | None = None
675+
) -> Generator[None]:
676+
"""Context manager which changes the current device in CuPy.
675677
676-
Notes
677-
-----
678-
This function is also invoked by CuPy, which does have multiple devices
679-
if there are multiple GPUs available.
680-
However, CuPy multi-device support is currently impossible
681-
without using the global device or a context manager:
682-
683-
https://github.com/data-apis/array-api-compat/pull/293
678+
Used internally by array creation functions in common._aliases.
684679
"""
685-
if bare_xp is sys.modules.get("numpy"):
686-
if device not in ("cpu", None):
680+
if device is None:
681+
if like is None:
682+
return contextlib.nullcontext()
683+
device = _device(like)
684+
685+
if bare_xp is sys.modules.get('numpy'):
686+
if device != "cpu":
687687
raise ValueError(f"Unsupported device for NumPy: {device!r}")
688+
return contextlib.nullcontext()
688689

689-
elif bare_xp is sys.modules.get("dask.array"):
690-
if device not in ("cpu", _DASK_DEVICE, None):
690+
if bare_xp is sys.modules.get('dask.array'):
691+
if device not in ("cpu", _DASK_DEVICE):
691692
raise ValueError(f"Unsupported device for Dask: {device!r}")
693+
return contextlib.nullcontext()
694+
695+
if bare_xp is sys.modules.get('cupy'):
696+
if not isinstance(device, bare_xp.cuda.Device):
697+
raise TypeError(f"device is not a cupy.cuda.Device: {device!r}")
698+
return device
699+
700+
# PyTorch doesn't have a "current device" context manager and you
701+
# can't use array creation functions from common._aliases.
702+
raise AssertionError("unreachable") # pragma: nocover
703+
704+
705+
def _check_device(bare_xp: Namespace, device: Device) -> None:
706+
"""Validate dummy device on device-less array backends."""
707+
with _device_ctx(bare_xp, device):
708+
pass
692709

693710

694711
# Placeholder object to represent the dask device

array_api_compat/cupy/_aliases.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,8 @@ def asarray(
8080
See the corresponding documentation in the array library and/or the array API
8181
specification for more details.
8282
"""
83-
with cp.cuda.Device(device):
83+
like = obj if isinstance(obj, cp.ndarray) else None
84+
with _helpers._device_ctx(cp, device, like=like):
8485
if copy is None:
8586
return cp.asarray(obj, dtype=dtype, **kwargs)
8687
else:

0 commit comments

Comments
 (0)