Skip to content

ENH: New function default_dtype #310

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api-reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
broadcast_shapes
cov
create_diagonal
default_dtype
expand_dims
isclose
kron
Expand Down
2 changes: 2 additions & 0 deletions src/array_api_extra/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
broadcast_shapes,
cov,
create_diagonal,
default_dtype,
expand_dims,
kron,
nunique,
Expand All @@ -27,6 +28,7 @@
"broadcast_shapes",
"cov",
"create_diagonal",
"default_dtype",
"expand_dims",
"isclose",
"kron",
Expand Down
46 changes: 41 additions & 5 deletions src/array_api_extra/_lib/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import warnings
from collections.abc import Callable, Sequence
from types import ModuleType, NoneType
from typing import cast, overload
from typing import Literal, cast, overload

from ._at import at
from ._utils import _compat, _helpers
Expand All @@ -16,7 +16,7 @@
meta_namespace,
ndindex,
)
from ._utils._typing import Array
from ._utils._typing import Array, Device, DType

__all__ = [
"apply_where",
Expand Down Expand Up @@ -438,6 +438,44 @@ def create_diagonal(
return xp.reshape(diag, (*batch_dims, n, n))


def default_dtype(
xp: ModuleType,
kind: Literal[
"real floating", "complex floating", "integral", "indexing"
] = "real floating",
*,
device: Device | None = None,
) -> DType:
"""
Return the default dtype for the given namespace and device.

This is a convenience shorthand for
``xp.__array_namespace_info__().default_dtypes(device=device)[kind]``.

Parameters
----------
xp : array_namespace
The standard-compatible namespace for which to get the default dtype.
kind : {'real floating', 'complex floating', 'integral', 'indexing'}, optional
The kind of dtype to return. Default is 'real floating'.
device : Device, optional
The device for which to get the default dtype. Default: current device.

Returns
-------
dtype
The default dtype for the given namespace, kind, and device.
"""
dtypes = xp.__array_namespace_info__().default_dtypes(device=device)
try:
return dtypes[kind]
except KeyError as e:
domain = ("real floating", "complex floating", "integral", "indexing")
assert set(dtypes) == set(domain), f"Non-compliant namespace: {dtypes}"
msg = f"Unknown kind '{kind}'. Expected one of {domain}."
raise ValueError(msg) from e


def expand_dims(
a: Array, /, *, axis: int | tuple[int, ...] = (0,), xp: ModuleType | None = None
) -> Array:
Expand Down Expand Up @@ -728,9 +766,7 @@ def nunique(x: Array, /, *, xp: ModuleType | None = None) -> Array:
x = xp.reshape(x, (-1,))
x = xp.sort(x)
mask = x != xp.roll(x, -1)
default_int = xp.__array_namespace_info__().default_dtypes(
device=_compat.device(x)
)["integral"]
default_int = default_dtype(xp, "integral", device=_compat.device(x))
return xp.maximum(
# Special cases:
# - array is size 0
Expand Down
39 changes: 29 additions & 10 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,16 +149,8 @@ def xp(

if library.like(Backend.JAX):
_setup_jax(library)

elif library == Backend.TORCH_GPU:
import torch.cuda

if not torch.cuda.is_available():
pytest.skip("no CUDA device available")
xp.set_default_device("cuda")

elif library == Backend.TORCH: # CPU
xp.set_default_device("cpu")
elif library.like(Backend.TORCH):
_setup_torch(library)

yield xp

Expand All @@ -179,6 +171,24 @@ def _setup_jax(library: Backend) -> None:
jax.config.update("jax_default_device", device)


def _setup_torch(library: Backend) -> None:
import torch

# This is already the default, but some tests or env variables may change it.
# TODO test both float32 and float64, like in scipy.
torch.set_default_dtype(torch.float32)

if library == Backend.TORCH_GPU:
import torch.cuda

if not torch.cuda.is_available():
pytest.skip("no CUDA device available")
torch.set_default_device("cuda")
else:
assert library == Backend.TORCH
torch.set_default_device("cpu")


@pytest.fixture(params=[Backend.DASK]) # Can select the test with `pytest -k dask`
def da(
request: pytest.FixtureRequest, monkeypatch: pytest.MonkeyPatch
Expand All @@ -201,6 +211,15 @@ def jnp(
return xp


@pytest.fixture(params=[Backend.TORCH, Backend.TORCH_GPU])
def torch(request: pytest.FixtureRequest) -> ModuleType: # numpydoc ignore=PR01,RT01
"""Variant of the `xp` fixture that only yields torch."""
xp = pytest.importorskip("torch")
xp = array_namespace(xp.empty(0))
_setup_torch(request.param)
return xp


@pytest.fixture
def device(
library: Backend, xp: ModuleType
Expand Down
34 changes: 34 additions & 0 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
broadcast_shapes,
cov,
create_diagonal,
default_dtype,
expand_dims,
isclose,
kron,
Expand Down Expand Up @@ -517,6 +518,39 @@ def test_xp(self, xp: ModuleType):
xp_assert_equal(y, xp.asarray([[1, 0], [0, 2]]))


@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no __array_namespace_info__")
class TestDefaultDType:
def test_basic(self, xp: ModuleType):
assert default_dtype(xp) == xp.empty(0).dtype

def test_kind(self, xp: ModuleType):
assert default_dtype(xp, "real floating") == xp.empty(0).dtype
assert default_dtype(xp, "complex floating") == (xp.empty(0) * 1j).dtype
assert default_dtype(xp, "integral") == xp.int64
assert default_dtype(xp, "indexing") == xp.int64

with pytest.raises(ValueError, match="Unknown kind"):
_ = default_dtype(xp, "foo") # type: ignore[arg-type] # pyright: ignore[reportArgumentType]

def test_device(self, xp: ModuleType, device: Device):
# Note: at the moment there are no known namespaces with
# device-specific default dtypes.
assert default_dtype(xp, device=None) == xp.empty(0).dtype
assert default_dtype(xp, device=device) == xp.empty(0).dtype

def test_torch(self, torch: ModuleType):
xp = torch
xp.set_default_dtype(xp.float64)
assert default_dtype(xp) == xp.float64
assert default_dtype(xp, "real floating") == xp.float64
assert default_dtype(xp, "complex floating") == xp.complex128

xp.set_default_dtype(xp.float32)
assert default_dtype(xp) == xp.float32
assert default_dtype(xp, "real floating") == xp.float32
assert default_dtype(xp, "complex floating") == xp.complex64


class TestExpandDims:
def test_single_axis(self, xp: ModuleType):
"""Trivial case where xpx.expand_dims doesn't add anything to xp.expand_dims"""
Expand Down