From 87eec29d44a2b377d9d0318d3b8e1669b478635b Mon Sep 17 00:00:00 2001 From: Pamphile Roy Date: Wed, 26 Apr 2023 17:48:08 +0200 Subject: [PATCH 1/3] ENH: add fallback_namespace --- array_api_compat/common/_helpers.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index e6adc948..d05e8b1c 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -53,7 +53,7 @@ def _check_api_version(api_version): if api_version is not None and api_version != '2021.12': raise ValueError("Only the 2021.12 version of the array API specification is currently supported") -def array_namespace(*xs, api_version=None, _use_compat=True): +def array_namespace(*xs, api_version=None, _use_compat=True, fallback_namespace=None): """ Get the array API compatible namespace for the arrays `xs`. @@ -72,7 +72,9 @@ def your_function(x, y): namespaces = set() for x in xs: if isinstance(x, (tuple, list)): - namespaces.add(array_namespace(*x, _use_compat=_use_compat)) + namespaces.add(array_namespace( + *x, _use_compat=_use_compat, fallback_namespace=fallback_namespace + )) elif hasattr(x, '__array_namespace__'): namespaces.add(x.__array_namespace__(api_version=api_version)) elif _is_numpy_array(x): @@ -99,6 +101,8 @@ def your_function(x, y): else: import torch namespaces.add(torch) + elif fallback_namespace is not None: + namespaces.add(fallback_namespace) else: # TODO: Support Python scalars? raise TypeError("The input is not a supported array type") From 3b9023841b6633522a7429c2fd49709fc4aef27d Mon Sep 17 00:00:00 2001 From: Pamphile Roy Date: Wed, 26 Apr 2023 17:48:35 +0200 Subject: [PATCH 2/3] TST: test fallback and multiple namespaces due to fallback --- tests/test_array_namespace.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/test_array_namespace.py b/tests/test_array_namespace.py index 806b1192..d2c3ae05 100644 --- a/tests/test_array_namespace.py +++ b/tests/test_array_namespace.py @@ -26,6 +26,22 @@ def test_array_namespace_multiple(): assert array_namespace(x, x) == array_namespace((x, x)) == \ array_namespace((x, x), x) == array_api_compat.numpy +def test_fallback_namespace(): + import numpy as np + import array_api_compat.numpy + + xp = array_api_compat.numpy + xp_ = array_namespace([1, 2], fallback_namespace=xp) + assert xp_ == xp + + xp_ = array_namespace([1, 2], np.asarray([1, 2]), fallback_namespace=xp) + assert xp_ == xp + + msg = 'Multiple namespaces' + with pytest.raises(TypeError, match=msg): + array_namespace([1, 2], np.asarray([1, 2]), fallback_namespace=np) + + def test_array_namespace_errors(): pytest.raises(TypeError, lambda: array_namespace([1])) pytest.raises(TypeError, lambda: array_namespace()) From 0acc2b019e30b8d176ec698bb76e1a27de999da8 Mon Sep 17 00:00:00 2001 From: Pamphile Roy Date: Thu, 27 Apr 2023 13:28:30 +0200 Subject: [PATCH 3/3] ENH: add conversion to Array API compatible namespace --- array_api_compat/common/_helpers.py | 11 +++++++++++ tests/test_array_namespace.py | 12 +++++++++++- 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index d05e8b1c..29886ca7 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -69,6 +69,17 @@ def your_function(x, y): api_version should be the newest version of the spec that you need support for (currently the compat library wrapped APIs only support v2021.12). """ + # convert fallback_namespace + if fallback_namespace is not None: + try: + x_ = fallback_namespace.asarray(1) + fallback_namespace = array_namespace( + x_, _use_compat=_use_compat + ) + except AttributeError as exc: + msg = "'fallback_namespace' must be an Array API compatible namespace" + raise TypeError(msg) from exc + namespaces = set() for x in xs: if isinstance(x, (tuple, list)): diff --git a/tests/test_array_namespace.py b/tests/test_array_namespace.py index d2c3ae05..059d4e8e 100644 --- a/tests/test_array_namespace.py +++ b/tests/test_array_namespace.py @@ -28,6 +28,7 @@ def test_array_namespace_multiple(): def test_fallback_namespace(): import numpy as np + import numpy.array_api import array_api_compat.numpy xp = array_api_compat.numpy @@ -37,9 +38,18 @@ def test_fallback_namespace(): xp_ = array_namespace([1, 2], np.asarray([1, 2]), fallback_namespace=xp) assert xp_ == xp + # convert to Array API compatible namespace + xp = array_api_compat.numpy + xp_ = array_namespace([1, 2], np.asarray([1, 2]), fallback_namespace=np) + assert xp_ == xp + msg = 'Multiple namespaces' with pytest.raises(TypeError, match=msg): - array_namespace([1, 2], np.asarray([1, 2]), fallback_namespace=np) + array_namespace([1, 2], numpy.array_api.asarray([1, 2]), fallback_namespace=np) + + msg = "'fallback_namespace' must be an Array API" + with pytest.raises(TypeError, match=msg): + array_namespace([1, 2], np.asarray([1, 2]), fallback_namespace="hop") def test_array_namespace_errors():