From b003715d0157ae6a99851aa87992a0aac8d58b73 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Fri, 18 Apr 2025 15:23:30 +0100 Subject: [PATCH] ENH: CuPy creation functions to respect device= parameter --- array_api_compat/common/_aliases.py | 46 +++++++++++++-------------- array_api_compat/common/_helpers.py | 49 +++++++++++++++++++---------- array_api_compat/cupy/_aliases.py | 3 +- 3 files changed, 58 insertions(+), 40 deletions(-) diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 27b2604b..99e937fb 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -8,7 +8,7 @@ from collections.abc import Sequence from typing import TYPE_CHECKING, Any, NamedTuple, cast -from ._helpers import _check_device, array_namespace +from ._helpers import _device_ctx, array_namespace from ._helpers import device as _get_device from ._helpers import is_cupy_namespace from ._typing import Array, Device, DType, Namespace @@ -33,8 +33,8 @@ def arange( device: Device | None = None, **kwargs: object, ) -> Array: - _check_device(xp, device) - return xp.arange(start, stop=stop, step=step, dtype=dtype, **kwargs) + with _device_ctx(xp, device): + return xp.arange(start, stop=stop, step=step, dtype=dtype, **kwargs) def empty( @@ -45,8 +45,8 @@ def empty( device: Device | None = None, **kwargs: object, ) -> Array: - _check_device(xp, device) - return xp.empty(shape, dtype=dtype, **kwargs) + with _device_ctx(xp, device): + return xp.empty(shape, dtype=dtype, **kwargs) def empty_like( @@ -58,8 +58,8 @@ def empty_like( device: Device | None = None, **kwargs: object, ) -> Array: - _check_device(xp, device) - return xp.empty_like(x, dtype=dtype, **kwargs) + with _device_ctx(xp, device, like=x): + return xp.empty_like(x, dtype=dtype, **kwargs) def eye( @@ -73,8 +73,8 @@ def eye( device: Device | None = None, **kwargs: object, ) -> Array: - _check_device(xp, device) - return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype, **kwargs) + with _device_ctx(xp, device): + return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype, **kwargs) def full( @@ -86,8 +86,8 @@ def full( device: Device | None = None, **kwargs: object, ) -> Array: - _check_device(xp, device) - return xp.full(shape, fill_value, dtype=dtype, **kwargs) + with _device_ctx(xp, device): + return xp.full(shape, fill_value, dtype=dtype, **kwargs) def full_like( @@ -100,8 +100,8 @@ def full_like( device: Device | None = None, **kwargs: object, ) -> Array: - _check_device(xp, device) - return xp.full_like(x, fill_value, dtype=dtype, **kwargs) + with _device_ctx(xp, device, like=x): + return xp.full_like(x, fill_value, dtype=dtype, **kwargs) def linspace( @@ -116,8 +116,8 @@ def linspace( endpoint: bool = True, **kwargs: object, ) -> Array: - _check_device(xp, device) - return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint, **kwargs) + with _device_ctx(xp, device): + return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint, **kwargs) def ones( @@ -128,8 +128,8 @@ def ones( device: Device | None = None, **kwargs: object, ) -> Array: - _check_device(xp, device) - return xp.ones(shape, dtype=dtype, **kwargs) + with _device_ctx(xp, device): + return xp.ones(shape, dtype=dtype, **kwargs) def ones_like( @@ -141,8 +141,8 @@ def ones_like( device: Device | None = None, **kwargs: object, ) -> Array: - _check_device(xp, device) - return xp.ones_like(x, dtype=dtype, **kwargs) + with _device_ctx(xp, device, like=x): + return xp.ones_like(x, dtype=dtype, **kwargs) def zeros( @@ -153,8 +153,8 @@ def zeros( device: Device | None = None, **kwargs: object, ) -> Array: - _check_device(xp, device) - return xp.zeros(shape, dtype=dtype, **kwargs) + with _device_ctx(xp, device): + return xp.zeros(shape, dtype=dtype, **kwargs) def zeros_like( @@ -166,8 +166,8 @@ def zeros_like( device: Device | None = None, **kwargs: object, ) -> Array: - _check_device(xp, device) - return xp.zeros_like(x, dtype=dtype, **kwargs) + with _device_ctx(xp, device, like=x): + return xp.zeros_like(x, dtype=dtype, **kwargs) # np.unique() is split into four functions in the array API: diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 37f31ec2..22119160 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -8,12 +8,13 @@ from __future__ import annotations +import contextlib import enum import inspect import math import sys import warnings -from collections.abc import Collection, Hashable +from collections.abc import Collection, Generator, Hashable from functools import lru_cache from typing import ( TYPE_CHECKING, @@ -669,26 +670,42 @@ def your_function(x, y): get_namespace = array_namespace -def _check_device(bare_xp: Namespace, device: Device) -> None: # pyright: ignore[reportUnusedFunction] - """ - Validate dummy device on device-less array backends. +def _device_ctx( + bare_xp: Namespace, device: Device, like: Array | None = None +) -> Generator[None]: + """Context manager which changes the current device in CuPy. - Notes - ----- - This function is also invoked by CuPy, which does have multiple devices - if there are multiple GPUs available. - However, CuPy multi-device support is currently impossible - without using the global device or a context manager: - - https://github.com/data-apis/array-api-compat/pull/293 + Used internally by array creation functions in common._aliases. """ - if bare_xp is sys.modules.get("numpy"): - if device not in ("cpu", None): + if device is None: + if like is None: + return contextlib.nullcontext() + device = _device(like) + + if bare_xp is sys.modules.get('numpy'): + if device != "cpu": raise ValueError(f"Unsupported device for NumPy: {device!r}") + return contextlib.nullcontext() - elif bare_xp is sys.modules.get("dask.array"): - if device not in ("cpu", _DASK_DEVICE, None): + if bare_xp is sys.modules.get('dask.array'): + if device not in ("cpu", _DASK_DEVICE): raise ValueError(f"Unsupported device for Dask: {device!r}") + return contextlib.nullcontext() + + if bare_xp is sys.modules.get('cupy'): + if not isinstance(device, bare_xp.cuda.Device): + raise TypeError(f"device is not a cupy.cuda.Device: {device!r}") + return device + + # PyTorch doesn't have a "current device" context manager and you + # can't use array creation functions from common._aliases. + raise AssertionError("unreachable") # pragma: nocover + + +def _check_device(bare_xp: Namespace, device: Device) -> None: + """Validate dummy device on device-less array backends.""" + with _device_ctx(bare_xp, device): + pass # Placeholder object to represent the dask device diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index 2752bd98..330f9bb9 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -80,7 +80,8 @@ def asarray( See the corresponding documentation in the array library and/or the array API specification for more details. """ - with cp.cuda.Device(device): + like = obj if isinstance(obj, cp.ndarray) else None + with _helpers._device_ctx(cp, device, like=like): if copy is None: return cp.asarray(obj, dtype=dtype, **kwargs) else: