Skip to content

Commit 9e74f2e

Browse files
eloitanguyrflamary
andauthored
[MRG] get_backend compatibility with None entries (#525)
* backend None compatibility with test * reference PR in RELEASES.md * Update RELEASES.md --------- Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
1 parent 064898d commit 9e74f2e

File tree

3 files changed

+22
-1
lines changed

3 files changed

+22
-1
lines changed

RELEASES.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,13 @@
11
# Releases
22

3+
## 0.9.2dev
4+
5+
#### New features
6+
+ Tweaked `get_backend` to ignore `None` inputs (PR # 525)
7+
8+
#### Closed issues
9+
10+
311
## 0.9.1
412
*August 2023*
513

ot/backend.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,11 +157,15 @@ def _check_args_backend(backend, args):
157157
def get_backend(*args):
158158
"""Returns the proper backend for a list of input arrays
159159
160+
Accepts None entries in the arguments, and ignores them
161+
160162
Also raises TypeError if all arrays are not from the same backend
161163
"""
164+
args = [arg for arg in args if arg is not None] # exclude None entries
165+
162166
# check that some arrays given
163167
if not len(args) > 0:
164-
raise ValueError(" The function takes at least one parameter")
168+
raise ValueError(" The function takes at least one (non-None) parameter")
165169

166170
for backend in _BACKENDS:
167171
if _check_args_backend(backend, args):

test/test_backend.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# Author: Remi Flamary <remi.flamary@polytechnique.edu>
44
# Nicolas Courty <ncourty@irisa.fr>
55
#
6+
#
67
# License: MIT License
78

89
import ot
@@ -753,3 +754,11 @@ def fun(a, b, d):
753754
[dl_dw, dl_db] = tape.gradient(manipulated_loss, [w, b])
754755
assert nx.allclose(dl_dw, w)
755756
assert nx.allclose(dl_db, b)
757+
758+
759+
def test_get_backend_none():
760+
a, b = np.zeros((2, 3)), None
761+
nx = get_backend(a, b)
762+
assert str(nx) == 'numpy'
763+
with pytest.raises(ValueError):
764+
get_backend(None, None)

0 commit comments

Comments
 (0)