diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index e6adc948..29886ca7 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -53,7 +53,7 @@ def _check_api_version(api_version): if api_version is not None and api_version != '2021.12': raise ValueError("Only the 2021.12 version of the array API specification is currently supported") -def array_namespace(*xs, api_version=None, _use_compat=True): +def array_namespace(*xs, api_version=None, _use_compat=True, fallback_namespace=None): """ Get the array API compatible namespace for the arrays `xs`. @@ -69,10 +69,23 @@ def your_function(x, y): api_version should be the newest version of the spec that you need support for (currently the compat library wrapped APIs only support v2021.12). """ + # convert fallback_namespace + if fallback_namespace is not None: + try: + x_ = fallback_namespace.asarray(1) + fallback_namespace = array_namespace( + x_, _use_compat=_use_compat + ) + except AttributeError as exc: + msg = "'fallback_namespace' must be an Array API compatible namespace" + raise TypeError(msg) from exc + namespaces = set() for x in xs: if isinstance(x, (tuple, list)): - namespaces.add(array_namespace(*x, _use_compat=_use_compat)) + namespaces.add(array_namespace( + *x, _use_compat=_use_compat, fallback_namespace=fallback_namespace + )) elif hasattr(x, '__array_namespace__'): namespaces.add(x.__array_namespace__(api_version=api_version)) elif _is_numpy_array(x): @@ -99,6 +112,8 @@ def your_function(x, y): else: import torch namespaces.add(torch) + elif fallback_namespace is not None: + namespaces.add(fallback_namespace) else: # TODO: Support Python scalars? raise TypeError("The input is not a supported array type") diff --git a/tests/test_array_namespace.py b/tests/test_array_namespace.py index 806b1192..059d4e8e 100644 --- a/tests/test_array_namespace.py +++ b/tests/test_array_namespace.py @@ -26,6 +26,32 @@ def test_array_namespace_multiple(): assert array_namespace(x, x) == array_namespace((x, x)) == \ array_namespace((x, x), x) == array_api_compat.numpy +def test_fallback_namespace(): + import numpy as np + import numpy.array_api + import array_api_compat.numpy + + xp = array_api_compat.numpy + xp_ = array_namespace([1, 2], fallback_namespace=xp) + assert xp_ == xp + + xp_ = array_namespace([1, 2], np.asarray([1, 2]), fallback_namespace=xp) + assert xp_ == xp + + # convert to Array API compatible namespace + xp = array_api_compat.numpy + xp_ = array_namespace([1, 2], np.asarray([1, 2]), fallback_namespace=np) + assert xp_ == xp + + msg = 'Multiple namespaces' + with pytest.raises(TypeError, match=msg): + array_namespace([1, 2], numpy.array_api.asarray([1, 2]), fallback_namespace=np) + + msg = "'fallback_namespace' must be an Array API" + with pytest.raises(TypeError, match=msg): + array_namespace([1, 2], np.asarray([1, 2]), fallback_namespace="hop") + + def test_array_namespace_errors(): pytest.raises(TypeError, lambda: array_namespace([1])) pytest.raises(TypeError, lambda: array_namespace())