Skip to content

Commit a3fd883

Browse files
committed
Explicitly check for correct loss_fun value
1 parent 7ee082d commit a3fd883

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

ot/gromov/_bregman.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def entropic_gromov_wasserstein(
8181
q : array-like, shape (nt,), optional
8282
Distribution in the target space.
8383
If let to its default value None, uniform distribution is taken.
84-
loss_fun : string, optional
84+
loss_fun : string, optional (default='square_loss')
8585
Loss function used for the solver either 'square_loss' or 'kl_loss'
8686
epsilon : float, optional
8787
Regularization term >0
@@ -135,6 +135,9 @@ def entropic_gromov_wasserstein(
135135
if solver not in ['PGD', 'PPA']:
136136
raise ValueError("Unknown solver '%s'. Pick one in ['PGD', 'PPA']." % solver)
137137

138+
if loss_fun not in ('square_loss', 'kl_loss'):
139+
raise ValueError(f"Unknown `loss_fun='{loss_fun}'`. Use one of: {'square_loss', 'kl_loss'}.")
140+
138141
C1, C2 = list_to_array(C1, C2)
139142
arr = [C1, C2]
140143
if p is not None:
@@ -411,6 +414,9 @@ def entropic_gromov_barycenters(
411414
"Gromov-Wasserstein averaging of kernel and distance matrices."
412415
International Conference on Machine Learning (ICML). 2016.
413416
"""
417+
if loss_fun not in ('square_loss', 'kl_loss'):
418+
raise ValueError(f"Unknown `loss_fun='{loss_fun}'`. Use one of: {'square_loss', 'kl_loss'}.")
419+
414420
Cs = list_to_array(*Cs)
415421
arr = [*Cs]
416422
if ps is not None:
@@ -459,7 +465,6 @@ def entropic_gromov_barycenters(
459465

460466
if loss_fun == 'square_loss':
461467
C = update_square_loss(p, lambdas, T, Cs)
462-
463468
elif loss_fun == 'kl_loss':
464469
C = update_kl_loss(p, lambdas, T, Cs)
465470

0 commit comments

Comments
 (0)