From cf128458418433de71839efdbc564dd01904e14b Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Wed, 20 Sep 2023 14:17:32 +0200 Subject: [PATCH 1/3] backend None compatibility with test --- ot/backend.py | 6 +++++- test/test_backend.py | 9 +++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/ot/backend.py b/ot/backend.py index 7b2fe875f..a80c5ae73 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -157,11 +157,15 @@ def _check_args_backend(backend, args): def get_backend(*args): """Returns the proper backend for a list of input arrays + Accepts None entries in the arguments, and ignores them + Also raises TypeError if all arrays are not from the same backend """ + args = [arg for arg in args if arg is not None] # exclude None entries + # check that some arrays given if not len(args) > 0: - raise ValueError(" The function takes at least one parameter") + raise ValueError(" The function takes at least one (non-None) parameter") for backend in _BACKENDS: if _check_args_backend(backend, args): diff --git a/test/test_backend.py b/test/test_backend.py index f0571471c..bfca6139d 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -3,6 +3,7 @@ # Author: Remi Flamary # Nicolas Courty # +# # License: MIT License import ot @@ -753,3 +754,11 @@ def fun(a, b, d): [dl_dw, dl_db] = tape.gradient(manipulated_loss, [w, b]) assert nx.allclose(dl_dw, w) assert nx.allclose(dl_db, b) + + +def test_get_backend_none(): + a, b = np.zeros((2, 3)), None + nx = get_backend(a, b) + assert str(nx) == 'numpy' + with pytest.raises(ValueError): + get_backend(None, None) From 1a0d3a2e5e4a9983d8e862750c4bc71ba6c77969 Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Wed, 20 Sep 2023 14:55:03 +0200 Subject: [PATCH 2/3] reference PR in RELEASES.md --- RELEASES.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/RELEASES.md b/RELEASES.md index d0209e233..9b94ecaf3 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -1,5 +1,9 @@ # Releases +## 0.9.2dev + +Tweaked `get_backend` to ignore `None` inputs (PR # 525) + ## 0.9.1 *August 2023* From 10f69edcc57f017ddbc9df5bc5fa33569203b1d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Wed, 20 Sep 2023 15:10:27 +0200 Subject: [PATCH 3/3] Update RELEASES.md --- RELEASES.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/RELEASES.md b/RELEASES.md index 9b94ecaf3..1b98d51bf 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -2,7 +2,11 @@ ## 0.9.2dev -Tweaked `get_backend` to ignore `None` inputs (PR # 525) +#### New features ++ Tweaked `get_backend` to ignore `None` inputs (PR # 525) + +#### Closed issues + ## 0.9.1 *August 2023*