diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 77175d0d..a152e4c0 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -8,6 +8,7 @@ from __future__ import annotations +import enum import inspect import math import sys @@ -485,6 +486,86 @@ def _check_api_version(api_version: str | None) -> None: ) +class _ClsToXPInfo(enum.Enum): + SCALAR = 0 + MAYBE_JAX_ZERO_GRADIENT = 1 + + +@lru_cache(100) +def _cls_to_namespace( + cls: type, + api_version: str | None, + use_compat: bool | None, +) -> tuple[Namespace | None, _ClsToXPInfo | None]: + if use_compat not in (None, True, False): + raise ValueError("use_compat must be None, True, or False") + _use_compat = use_compat in (None, True) + cls_ = cast(Hashable, cls) # Make mypy happy + + if ( + _issubclass_fast(cls_, "numpy", "ndarray") + or _issubclass_fast(cls_, "numpy", "generic") + ): + if use_compat is True: + _check_api_version(api_version) + from .. import numpy as xp + elif use_compat is False: + import numpy as xp # type: ignore[no-redef] + else: + # NumPy 2.0+ have __array_namespace__; however they are not + # yet fully array API compatible. + from .. import numpy as xp # type: ignore[no-redef] + return xp, _ClsToXPInfo.MAYBE_JAX_ZERO_GRADIENT + + # Note: this must happen _after_ the test for np.generic, + # because np.float64 and np.complex128 are subclasses of float and complex. + if issubclass(cls, int | float | complex | type(None)): + return None, _ClsToXPInfo.SCALAR + + if _issubclass_fast(cls_, "cupy", "ndarray"): + if _use_compat: + _check_api_version(api_version) + from .. import cupy as xp # type: ignore[no-redef] + else: + import cupy as xp # type: ignore[no-redef] + return xp, None + + if _issubclass_fast(cls_, "torch", "Tensor"): + if _use_compat: + _check_api_version(api_version) + from .. import torch as xp # type: ignore[no-redef] + else: + import torch as xp # type: ignore[no-redef] + return xp, None + + if _issubclass_fast(cls_, "dask.array", "Array"): + if _use_compat: + _check_api_version(api_version) + from ..dask import array as xp # type: ignore[no-redef] + else: + import dask.array as xp # type: ignore[no-redef] + return xp, None + + # Backwards compatibility for jax<0.4.32 + if _issubclass_fast(cls_, "jax", "Array"): + return _jax_namespace(api_version, use_compat), None + + return None, None + + +def _jax_namespace(api_version: str | None, use_compat: bool | None) -> Namespace: + if use_compat: + raise ValueError("JAX does not have an array-api-compat wrapper") + import jax.numpy as jnp + if not hasattr(jnp, "__array_namespace_info__"): + # JAX v0.4.32 and newer implements the array API directly in jax.numpy. + # For older JAX versions, it is available via jax.experimental.array_api. + # jnp.Array objects gain the __array_namespace__ method. + import jax.experimental.array_api # noqa: F401 + # Test api_version + return jnp.empty(0).__array_namespace__(api_version=api_version) + + def array_namespace( *xs: Array | complex | None, api_version: str | None = None, @@ -553,105 +634,40 @@ def your_function(x, y): is_pydata_sparse_array """ - if use_compat not in [None, True, False]: - raise ValueError("use_compat must be None, True, or False") - - _use_compat = use_compat in [None, True] - namespaces: set[Namespace] = set() for x in xs: - if is_numpy_array(x): - import numpy as np - - from .. import numpy as numpy_namespace - - if use_compat is True: - _check_api_version(api_version) - namespaces.add(numpy_namespace) - elif use_compat is False: - namespaces.add(np) - else: - # numpy 2.0+ have __array_namespace__, however, they are not yet fully array API - # compatible. - namespaces.add(numpy_namespace) - elif is_cupy_array(x): - if _use_compat: - _check_api_version(api_version) - from .. import cupy as cupy_namespace - - namespaces.add(cupy_namespace) - else: - import cupy as cp # pyright: ignore[reportMissingTypeStubs] - - namespaces.add(cp) - elif is_torch_array(x): - if _use_compat: - _check_api_version(api_version) - from .. import torch as torch_namespace - - namespaces.add(torch_namespace) - else: - import torch - - namespaces.add(torch) - elif is_dask_array(x): - if _use_compat: - _check_api_version(api_version) - from ..dask import array as dask_namespace - - namespaces.add(dask_namespace) - else: - import dask.array as da - - namespaces.add(da) - elif is_jax_array(x): - if use_compat is True: - _check_api_version(api_version) - raise ValueError("JAX does not have an array-api-compat wrapper") - elif use_compat is False: - import jax.numpy as jnp - else: - # JAX v0.4.32 and newer implements the array API directly in jax.numpy. - # For older JAX versions, it is available via jax.experimental.array_api. - import jax.numpy - - if hasattr(jax.numpy, "__array_api_version__"): - jnp = jax.numpy - else: - import jax.experimental.array_api as jnp # pyright: ignore[reportMissingImports] - namespaces.add(jnp) - elif is_pydata_sparse_array(x): - if use_compat is True: - _check_api_version(api_version) - raise ValueError("`sparse` does not have an array-api-compat wrapper") - else: - import sparse # pyright: ignore[reportMissingTypeStubs] - # `sparse` is already an array namespace. We do not have a wrapper - # submodule for it. - namespaces.add(sparse) - elif hasattr(x, "__array_namespace__"): - if use_compat is True: + xp, info = _cls_to_namespace(cast(Hashable, type(x)), api_version, use_compat) + if info is _ClsToXPInfo.SCALAR: + continue + + if ( + info is _ClsToXPInfo.MAYBE_JAX_ZERO_GRADIENT + and _is_jax_zero_gradient_array(x) + ): + xp = _jax_namespace(api_version, use_compat) + + if xp is None: + get_ns = getattr(x, "__array_namespace__", None) + if get_ns is None: + raise TypeError(f"{type(x).__name__} is not a supported array type") + if use_compat: raise ValueError( "The given array does not have an array-api-compat wrapper" ) - x = cast("SupportsArrayNamespace[Any]", x) - namespaces.add(x.__array_namespace__(api_version=api_version)) - elif isinstance(x, (bool, int, float, complex, type(None))): - continue - else: - # TODO: Support Python scalars? - raise TypeError(f"{type(x).__name__} is not a supported array type") + xp = get_ns(api_version=api_version) - if not namespaces: - raise TypeError("Unrecognized array input") + namespaces.add(xp) - if len(namespaces) != 1: + try: + (xp,) = namespaces + return xp + except ValueError: + if not namespaces: + raise TypeError( + "array_namespace requires at least one non-scalar array input" + ) raise TypeError(f"Multiple namespaces for array inputs: {namespaces}") - (xp,) = namespaces - - return xp - # backwards compatibility alias get_namespace = array_namespace diff --git a/tests/test_array_namespace.py b/tests/test_array_namespace.py index 2fbb0339..311efc37 100644 --- a/tests/test_array_namespace.py +++ b/tests/test_array_namespace.py @@ -8,42 +8,41 @@ import array_api_compat from array_api_compat import array_namespace -from ._helpers import import_, all_libraries, wrapped_libraries +from ._helpers import all_libraries, wrapped_libraries, xfail + @pytest.mark.parametrize("use_compat", [True, False, None]) -@pytest.mark.parametrize("api_version", [None, "2021.12", "2022.12", "2023.12"]) +@pytest.mark.parametrize( + "api_version", [None, "2021.12", "2022.12", "2023.12", "2024.12"] +) @pytest.mark.parametrize("library", all_libraries) -def test_array_namespace(library, api_version, use_compat): - xp = import_(library) +def test_array_namespace(request, library, api_version, use_compat): + xp = pytest.importorskip(library) array = xp.asarray([1.0, 2.0, 3.0]) if use_compat and library not in wrapped_libraries: pytest.raises(ValueError, lambda: array_namespace(array, use_compat=use_compat)) return - if library == "ndonnx" and api_version in ("2021.12", "2022.12"): - pytest.skip("Unsupported API version") + if (library == "sparse" and api_version in ("2023.12", "2024.12")) or ( + library == "jax.numpy" and api_version in ("2021.12", "2022.12", "2023.12") + ): + xfail(request, "Unsupported API version") with warnings.catch_warnings(): warnings.simplefilter('ignore', UserWarning) namespace = array_namespace(array, api_version=api_version, use_compat=use_compat) if use_compat is False or use_compat is None and library not in wrapped_libraries: - if library == "jax.numpy" and use_compat is None: - import jax.numpy - if hasattr(jax.numpy, "__array_api_version__"): - # JAX v0.4.32 or later uses jax.numpy directly - assert namespace == jax.numpy - else: - # JAX v0.4.31 or earlier uses jax.experimental.array_api - import jax.experimental.array_api - assert namespace == jax.experimental.array_api + if library == "jax.numpy" and not hasattr(xp, "__array_api_version__"): + # Backwards compatibility for JAX <0.4.32 + import jax.experimental.array_api + assert namespace == jax.experimental.array_api else: assert namespace == xp + elif library == "dask.array": + assert namespace == array_api_compat.dask.array else: - if library == "dask.array": - assert namespace == array_api_compat.dask.array - else: - assert namespace == getattr(array_api_compat, library) + assert namespace == getattr(array_api_compat, library) if library == "numpy": # check that the same namespace is returned for NumPy scalars @@ -55,20 +54,20 @@ def test_array_namespace(library, api_version, use_compat): ) assert scalar_namespace == namespace - # Check that array_namespace works even if jax.experimental.array_api - # hasn't been imported yet (it monkeypatches __array_namespace__ - # onto JAX arrays, but we should support them regardless). The only way to - # do this is to use a subprocess, since we cannot un-import it and another - # test probably already imported it. - if library == "jax.numpy" and sys.version_info >= (3, 9): - code = f"""\ + +def test_jax_backwards_compat(): + """On JAX <0.4.32, test that array_namespace works even if + jax.experimental.array_api has not been imported yet. + """ + pytest.importorskip("jax") + code = """\ import sys import jax.numpy import array_api_compat -array = jax.numpy.asarray([1.0, 2.0, 3.0]) +array = jax.numpy.asarray([1.0, 2.0, 3.0]) assert 'jax.experimental.array_api' not in sys.modules -namespace = array_api_compat.array_namespace(array, api_version={api_version!r}) +namespace = array_api_compat.array_namespace(array) if hasattr(jax.numpy, '__array_api_version__'): assert namespace == jax.numpy @@ -76,14 +75,16 @@ def test_array_namespace(library, api_version, use_compat): import jax.experimental.array_api assert namespace == jax.experimental.array_api """ - subprocess.run([sys.executable, "-c", code], check=True) + subprocess.check_call([sys.executable, "-c", code]) + def test_jax_zero_gradient(): - jax = import_("jax") + jax = pytest.importorskip("jax") jx = jax.numpy.arange(4) jax_zero = jax.vmap(jax.grad(jax.numpy.float32, allow_int=True))(jx) assert array_namespace(jax_zero) is array_namespace(jx) + def test_array_namespace_errors(): pytest.raises(TypeError, lambda: array_namespace([1])) pytest.raises(TypeError, lambda: array_namespace()) @@ -92,43 +93,31 @@ def test_array_namespace_errors(): pytest.raises(TypeError, lambda: array_namespace((x, x))) pytest.raises(TypeError, lambda: array_namespace(x, (x, x))) -def test_array_namespace_errors_torch(): - torch = import_("torch") - y = torch.asarray([1, 2]) - x = np.asarray([1, 2]) - pytest.raises(TypeError, lambda: array_namespace(x, y)) -def test_api_version_torch(): - torch = import_("torch") - x = torch.asarray([1, 2]) - torch_ = import_("torch", wrapper=True) - with warnings.catch_warnings(): - warnings.simplefilter('ignore', UserWarning) - assert array_namespace(x, api_version="2023.12") == torch_ - assert array_namespace(x, api_version=None) == torch_ - assert array_namespace(x) == torch_ - # Should issue a warning - with warnings.catch_warnings(record=True) as w: - assert array_namespace(x, api_version="2021.12") == torch_ - assert len(w) == 1 - assert "2021.12" in str(w[0].message) - - # Should issue a warning - with warnings.catch_warnings(record=True) as w: - assert array_namespace(x, api_version="2022.12") == torch_ - assert len(w) == 1 - assert "2022.12" in str(w[0].message) - - pytest.raises(ValueError, lambda: array_namespace(x, api_version="2020.12")) +@pytest.mark.parametrize("library", all_libraries) +def test_array_namespace_many_args(library): + xp = pytest.importorskip(library) + a = xp.asarray(1) + b = xp.asarray(2) + assert array_namespace(a, b) is array_namespace(a) + + +def test_array_namespace_mismatch(): + xp = pytest.importorskip("array_api_strict") + with pytest.raises(TypeError, match="Multiple namespaces"): + array_namespace(np.asarray(1), xp.asarray(1)) + def test_get_namespace(): # Backwards compatible wrapper assert array_api_compat.get_namespace is array_namespace -def test_python_scalars(): - torch = import_("torch") - a = torch.asarray([1, 2]) - xp = import_("torch", wrapper=True) + +@pytest.mark.parametrize("library", all_libraries) +def test_python_scalars(library): + xp = pytest.importorskip(library) + a = xp.asarray([1, 2]) + xp = array_namespace(a) pytest.raises(TypeError, lambda: array_namespace(1)) pytest.raises(TypeError, lambda: array_namespace(1.0)) @@ -136,8 +125,8 @@ def test_python_scalars(): pytest.raises(TypeError, lambda: array_namespace(True)) pytest.raises(TypeError, lambda: array_namespace(None)) - assert array_namespace(a, 1) == xp - assert array_namespace(a, 1.0) == xp - assert array_namespace(a, 1j) == xp - assert array_namespace(a, True) == xp - assert array_namespace(a, None) == xp + assert array_namespace(a, 1) is xp + assert array_namespace(a, 1.0) is xp + assert array_namespace(a, 1j) is xp + assert array_namespace(a, True) is xp + assert array_namespace(a, None) is xp