@@ -81,7 +81,7 @@ def entropic_gromov_wasserstein(
81
81
q : array-like, shape (nt,), optional
82
82
Distribution in the target space.
83
83
If let to its default value None, uniform distribution is taken.
84
- loss_fun : string, optional
84
+ loss_fun : string, optional (default='square_loss')
85
85
Loss function used for the solver either 'square_loss' or 'kl_loss'
86
86
epsilon : float, optional
87
87
Regularization term >0
@@ -135,6 +135,9 @@ def entropic_gromov_wasserstein(
135
135
if solver not in ['PGD' , 'PPA' ]:
136
136
raise ValueError ("Unknown solver '%s'. Pick one in ['PGD', 'PPA']." % solver )
137
137
138
+ if loss_fun not in ('square_loss' , 'kl_loss' ):
139
+ raise ValueError (f"Unknown `loss_fun='{ loss_fun } '`. Use one of: { 'square_loss' , 'kl_loss' } ." )
140
+
138
141
C1 , C2 = list_to_array (C1 , C2 )
139
142
arr = [C1 , C2 ]
140
143
if p is not None :
@@ -411,6 +414,9 @@ def entropic_gromov_barycenters(
411
414
"Gromov-Wasserstein averaging of kernel and distance matrices."
412
415
International Conference on Machine Learning (ICML). 2016.
413
416
"""
417
+ if loss_fun not in ('square_loss' , 'kl_loss' ):
418
+ raise ValueError (f"Unknown `loss_fun='{ loss_fun } '`. Use one of: { 'square_loss' , 'kl_loss' } ." )
419
+
414
420
Cs = list_to_array (* Cs )
415
421
arr = [* Cs ]
416
422
if ps is not None :
@@ -459,7 +465,6 @@ def entropic_gromov_barycenters(
459
465
460
466
if loss_fun == 'square_loss' :
461
467
C = update_square_loss (p , lambdas , T , Cs )
462
-
463
468
elif loss_fun == 'kl_loss' :
464
469
C = update_kl_loss (p , lambdas , T , Cs )
465
470
0 commit comments