Skip to content

Commit cf12845

Browse files
committed
backend None compatibility with test
1 parent 064898d commit cf12845

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

ot/backend.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,11 +157,15 @@ def _check_args_backend(backend, args):
157157
def get_backend(*args):
158158
"""Returns the proper backend for a list of input arrays
159159
160+
Accepts None entries in the arguments, and ignores them
161+
160162
Also raises TypeError if all arrays are not from the same backend
161163
"""
164+
args = [arg for arg in args if arg is not None] # exclude None entries
165+
162166
# check that some arrays given
163167
if not len(args) > 0:
164-
raise ValueError(" The function takes at least one parameter")
168+
raise ValueError(" The function takes at least one (non-None) parameter")
165169

166170
for backend in _BACKENDS:
167171
if _check_args_backend(backend, args):

test/test_backend.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# Author: Remi Flamary <remi.flamary@polytechnique.edu>
44
# Nicolas Courty <ncourty@irisa.fr>
55
#
6+
#
67
# License: MIT License
78

89
import ot
@@ -753,3 +754,11 @@ def fun(a, b, d):
753754
[dl_dw, dl_db] = tape.gradient(manipulated_loss, [w, b])
754755
assert nx.allclose(dl_dw, w)
755756
assert nx.allclose(dl_db, b)
757+
758+
759+
def test_get_backend_none():
760+
a, b = np.zeros((2, 3)), None
761+
nx = get_backend(a, b)
762+
assert str(nx) == 'numpy'
763+
with pytest.raises(ValueError):
764+
get_backend(None, None)

0 commit comments

Comments
 (0)