From 68c1f6fa958a76939f17cec3f5dbdd17fb1559f2 Mon Sep 17 00:00:00 2001
From: crusaderky
Date: Mon, 2 Jun 2025 15:25:02 +0100
Subject: [PATCH 1/2] default_dtype

---
 docs/api-reference.md              |  1 +
 src/array_api_extra/__init__.py    |  2 ++
 src/array_api_extra/_lib/_funcs.py | 46 ++++++++++++++++++++++++++----
 tests/conftest.py                  | 38 +++++++++++++++++-------
 tests/test_funcs.py                | 34 ++++++++++++++++++++++
 5 files changed, 106 insertions(+), 15 deletions(-)

diff --git a/docs/api-reference.md b/docs/api-reference.md
index ee33a819..8e9375d0 100644
--- a/docs/api-reference.md
+++ b/docs/api-reference.md
@@ -12,6 +12,7 @@
    broadcast_shapes
    cov
    create_diagonal
+   default_dtype
    expand_dims
    isclose
    kron
diff --git a/src/array_api_extra/__init__.py b/src/array_api_extra/__init__.py
index 1c55a6d1..5cfe8594 100644
--- a/src/array_api_extra/__init__.py
+++ b/src/array_api_extra/__init__.py
@@ -8,6 +8,7 @@
     broadcast_shapes,
     cov,
     create_diagonal,
+    default_dtype,
     expand_dims,
     kron,
     nunique,
@@ -27,6 +28,7 @@
     "broadcast_shapes",
     "cov",
     "create_diagonal",
+    "default_dtype",
     "expand_dims",
     "isclose",
     "kron",
diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py
index 645f7a1b..be703fb5 100644
--- a/src/array_api_extra/_lib/_funcs.py
+++ b/src/array_api_extra/_lib/_funcs.py
@@ -4,7 +4,7 @@
 import warnings
 from collections.abc import Callable, Sequence
 from types import ModuleType, NoneType
-from typing import cast, overload
+from typing import Literal, cast, overload
 
 from ._at import at
 from ._utils import _compat, _helpers
@@ -16,7 +16,7 @@
     meta_namespace,
     ndindex,
 )
-from ._utils._typing import Array
+from ._utils._typing import Array, Device, DType
 
 __all__ = [
     "apply_where",
@@ -438,6 +438,44 @@ def create_diagonal(
     return xp.reshape(diag, (*batch_dims, n, n))
 
 
+def default_dtype(
+    xp: ModuleType,
+    kind: Literal[
+        "real floating", "complex floating", "integral", "indexing"
+    ] = "real floating",
+    *,
+    device: Device | None = None,
+) -> DType:
+    """
+    Return the default dtype for the given namespace and device.
+
+    This is a convenience shorthand for
+    ``xp.__array_namespace_info__().default_dtypes(device=device)[kind]``.
+
+    Parameters
+    ----------
+    xp : array_namespace
+        The standard-compatible namespace for which to get the default dtype.
+    kind : {'real floating', 'complex floating', 'integral', 'indexing'}, optional
+        The kind of dtype to return. Default is 'real floating'.
+    device : Device, optional
+        The device for which to get the default dtype. Default: current device.
+
+    Returns
+    -------
+    dtype
+        The default dtype for the given namespace, kind, and device.
+    """
+    dtypes = xp.__array_namespace_info__().default_dtypes(device=device)
+    try:
+        return dtypes[kind]
+    except KeyError as e:
+        domain = ("real floating", "complex floating", "integral", "indexing")
+        assert set(dtypes) == set(domain), f"Non-compliant namespace: {dtypes}"
+        msg = f"Unknown kind '{kind}'. Expected one of {domain}."
+        raise ValueError(msg) from e
+
+
 def expand_dims(
     a: Array, /, *, axis: int | tuple[int, ...] = (0,), xp: ModuleType | None = None
 ) -> Array:
@@ -728,9 +766,7 @@ def nunique(x: Array, /, *, xp: ModuleType | None = None) -> Array:
         x = xp.reshape(x, (-1,))
     x = xp.sort(x)
     mask = x != xp.roll(x, -1)
-    default_int = xp.__array_namespace_info__().default_dtypes(
-        device=_compat.device(x)
-    )["integral"]
+    default_int = default_dtype(xp, "integral", device=_compat.device(x))
     return xp.maximum(
         # Special cases:
         # - array is size 0
diff --git a/tests/conftest.py b/tests/conftest.py
index bb72139b..29297918 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -149,16 +149,8 @@ def xp(
 
     if library.like(Backend.JAX):
         _setup_jax(library)
-
-    elif library == Backend.TORCH_GPU:
-        import torch.cuda
-
-        if not torch.cuda.is_available():
-            pytest.skip("no CUDA device available")
-        xp.set_default_device("cuda")
-
-    elif library == Backend.TORCH:  # CPU
-        xp.set_default_device("cpu")
+    elif library.like(Backend.TORCH):
+        _setup_torch(library)
 
     yield xp
 
@@ -179,6 +171,23 @@ def _setup_jax(library: Backend) -> None:
     jax.config.update("jax_default_device", device)
 
 
+def _setup_torch(library: Backend) -> None:
+    import torch
+
+    # This is already the default, but some tests or env variables may change it.
+    # TODO test both float32 and float64, like in scipy.
+    torch.set_default_dtype(torch.float32)
+
+    if library == Backend.TORCH_GPU:
+        import torch.cuda
+
+        if not torch.cuda.is_available():
+            pytest.skip("no CUDA device available")
+        torch.set_default_device("cuda")
+    else:  # TORCH
+        torch.set_default_device("cpu")
+
+
 @pytest.fixture(params=[Backend.DASK])  # Can select the test with `pytest -k dask`
 def da(
     request: pytest.FixtureRequest, monkeypatch: pytest.MonkeyPatch
@@ -201,6 +210,15 @@ def jnp(
     return xp
 
 
+@pytest.fixture(params=[Backend.TORCH, Backend.TORCH_GPU])
+def torch(request: pytest.FixtureRequest) -> ModuleType:  # numpydoc ignore=PR01,RT01
+    """Variant of the `xp` fixture that only yields torch."""
+    xp = pytest.importorskip("torch")
+    xp = array_namespace(xp.empty(0))
+    _setup_torch(request.param)
+    return xp
+
+
 @pytest.fixture
 def device(
     library: Backend, xp: ModuleType
diff --git a/tests/test_funcs.py b/tests/test_funcs.py
index 7c138442..b89c7441 100644
--- a/tests/test_funcs.py
+++ b/tests/test_funcs.py
@@ -17,6 +17,7 @@
     broadcast_shapes,
     cov,
     create_diagonal,
+    default_dtype,
     expand_dims,
     isclose,
     kron,
@@ -517,6 +518,39 @@ def test_xp(self, xp: ModuleType):
         xp_assert_equal(y, xp.asarray([[1, 0], [0, 2]]))
 
 
+@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no __array_namespace_info__")
+class TestDefaultDType:
+    def test_basic(self, xp: ModuleType):
+        assert default_dtype(xp) == xp.empty(0).dtype
+
+    def test_kind(self, xp: ModuleType):
+        assert default_dtype(xp, "real floating") == xp.empty(0).dtype
+        assert default_dtype(xp, "complex floating") == (xp.empty(0) * 1j).dtype
+        assert default_dtype(xp, "integral") == xp.int64
+        assert default_dtype(xp, "indexing") == xp.int64
+
+        with pytest.raises(ValueError, match="Unknown kind"):
+            _ = default_dtype(xp, "foo")  # type: ignore[arg-type]  # pyright: ignore[reportArgumentType]
+
+    def test_device(self, xp: ModuleType, device: Device):
+        # Note: at the moment there are no known namespaces with
+        # device-specific default dtypes.
+        assert default_dtype(xp, device=None) == xp.empty(0).dtype
+        assert default_dtype(xp, device=device) == xp.empty(0).dtype
+
+    def test_torch(self, torch: ModuleType):
+        xp = torch
+        xp.set_default_dtype(xp.float64)
+        assert default_dtype(xp) == xp.float64
+        assert default_dtype(xp, "real floating") == xp.float64
+        assert default_dtype(xp, "complex floating") == xp.complex128
+
+        xp.set_default_dtype(xp.float32)
+        assert default_dtype(xp) == xp.float32
+        assert default_dtype(xp, "real floating") == xp.float32
+        assert default_dtype(xp, "complex floating") == xp.complex64
+
+
 class TestExpandDims:
     def test_single_axis(self, xp: ModuleType):
         """Trivial case where xpx.expand_dims doesn't add anything to xp.expand_dims"""

From 900ab1c116563532b57eea12b26601decb08121f Mon Sep 17 00:00:00 2001
From: crusaderky
Date: Tue, 3 Jun 2025 08:20:00 +0100
Subject: [PATCH 2/2] tweak comment in conftest

---
 tests/conftest.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/tests/conftest.py b/tests/conftest.py
index 29297918..82a0acbb 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -184,7 +184,8 @@ def _setup_torch(library: Backend) -> None:
         if not torch.cuda.is_available():
             pytest.skip("no CUDA device available")
         torch.set_default_device("cuda")
-    else:  # TORCH
+    else:
+        assert library == Backend.TORCH
        torch.set_default_device("cpu")
 
 
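
Usage sketch (editor's note, not part of either patch): how the new
``default_dtype`` helper is expected to behave with these patches applied and
NumPy >= 2.0 as the namespace. The printed dtypes assume NumPy's defaults
(float64 / complex128 / int64); other namespaces and devices may differ, as
the tests above exercise.

    import numpy as np

    from array_api_extra import default_dtype

    # `kind` defaults to "real floating".
    print(default_dtype(np))                      # float64
    print(default_dtype(np, "complex floating"))  # complex128
    print(default_dtype(np, "integral"))          # int64
    print(default_dtype(np, "indexing"))          # int64

    # Unknown kinds raise ValueError rather than leaking the KeyError:
    try:
        default_dtype(np, "boolean")
    except ValueError as exc:
        print(exc)  # Unknown kind 'boolean'. Expected one of (...)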