@@ -81,7 +81,7 @@ def entropic_gromov_wasserstein(
81
81
q : array-like, shape (nt,), optional
82
82
Distribution in the target space.
83
83
If let to its default value None, uniform distribution is taken.
84
- loss_fun : string, optional (default='square_loss')
84
+ loss_fun : string, optional (default='square_loss')
85
85
Loss function used for the solver either 'square_loss' or 'kl_loss'
86
86
epsilon : float, optional
87
87
Regularization term >0
@@ -92,7 +92,7 @@ def entropic_gromov_wasserstein(
92
92
G0: array-like, shape (ns,nt), optional
93
93
If None the initial transport plan of the solver is pq^T.
94
94
Otherwise G0 will be used as initial transport of the solver. G0 is not
95
- required to satisfy marginal constraints but we strongly recommand it
95
+ required to satisfy marginal constraints but we strongly recommend it
96
96
to correctly estimate the GW distance.
97
97
max_iter : int, optional
98
98
Max number of iterations
@@ -283,7 +283,7 @@ def entropic_gromov_wasserstein2(
283
283
q : array-like, shape (nt,), optional
284
284
Distribution in the target space.
285
285
If let to its default value None, uniform distribution is taken.
286
- loss_fun : string, optional
286
+ loss_fun : string, optional (default='square_loss')
287
287
Loss function used for the solver either 'square_loss' or 'kl_loss'
288
288
epsilon : float, optional
289
289
Regularization term >0
@@ -376,7 +376,7 @@ def entropic_gromov_barycenters(
376
376
lambdas : list of float, optional
377
377
List of the `S` spaces' weights.
378
378
If let to its default value None, uniform weights are taken.
379
- loss_fun : string, optional (default='square_loss')
379
+ loss_fun : string, optional (default='square_loss')
380
380
Loss function used for the solver either 'square_loss' or 'kl_loss'
381
381
epsilon : float, optional
382
382
Regularization term >0
@@ -555,21 +555,21 @@ def entropic_fused_gromov_wasserstein(
555
555
q : array-like, shape (nt,), optional
556
556
Distribution in the target space.
557
557
If let to its default value None, uniform distribution is taken.
558
- loss_fun : string, optional
558
+ loss_fun : string, optional (default='square_loss')
559
559
Loss function used for the solver either 'square_loss' or 'kl_loss'
560
560
epsilon : float, optional
561
561
Regularization term >0
562
562
symmetric : bool, optional
563
563
Either C1 and C2 are to be assumed symmetric or not.
564
564
If let to its default None value, a symmetry test will be conducted.
565
- Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymetric ).
565
+ Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric ).
566
566
alpha : float, optional
567
567
Trade-off parameter (0 < alpha < 1)
568
568
G0: array-like, shape (ns,nt), optional
569
569
If None the initial transport plan of the solver is pq^T.
570
570
Otherwise G0 will be used as initial transport of the solver. G0 is not
571
- required to satisfy marginal constraints but we strongly recommand it
572
- to correcly estimate the GW distance.
571
+ required to satisfy marginal constraints but we strongly recommend it
572
+ to correctly estimate the GW distance.
573
573
max_iter : int, optional
574
574
Max number of iterations
575
575
tol : float, optional
@@ -616,6 +616,9 @@ def entropic_fused_gromov_wasserstein(
616
616
if solver not in ['PGD' , 'PPA' ]:
617
617
raise ValueError ("Unknown solver '%s'. Pick one in ['PGD', 'PPA']." % solver )
618
618
619
+ if loss_fun not in ('square_loss' , 'kl_loss' ):
620
+ raise ValueError (f"Unknown `loss_fun='{ loss_fun } '`. Use one of: { 'square_loss' , 'kl_loss' } ." )
621
+
619
622
M , C1 , C2 = list_to_array (M , C1 , C2 )
620
623
arr = [M , C1 , C2 ]
621
624
if p is not None :
@@ -767,7 +770,7 @@ def entropic_fused_gromov_wasserstein2(
767
770
q : array-like, shape (nt,), optional
768
771
Distribution in the target space.
769
772
If let to its default value None, uniform distribution is taken.
770
- loss_fun : string, optional
773
+ loss_fun : string, optional (default='square_loss')
771
774
Loss function used for the solver either 'square_loss' or 'kl_loss'
772
775
epsilon : float, optional
773
776
Regularization term >0
@@ -780,8 +783,8 @@ def entropic_fused_gromov_wasserstein2(
780
783
G0: array-like, shape (ns,nt), optional
781
784
If None the initial transport plan of the solver is pq^T.
782
785
Otherwise G0 will be used as initial transport of the solver. G0 is not
783
- required to satisfy marginal constraints but we strongly recommand it
784
- to correcly estimate the GW distance.
786
+ required to satisfy marginal constraints but we strongly recommend it
787
+ to correctly estimate the GW distance.
785
788
max_iter : int, optional
786
789
Max number of iterations
787
790
tol : float, optional
@@ -862,8 +865,8 @@ def entropic_fused_gromov_barycenters(
862
865
lambdas : list of float, optional
863
866
List of the `S` spaces' weights.
864
867
If let to its default value None, uniform weights are taken.
865
- loss_fun : callable , optional
866
- tensor-matrix multiplication function based on specific loss function
868
+ loss_fun : string , optional (default='square_loss')
869
+ Loss function used for the solver either 'square_loss' or 'kl_loss'
867
870
epsilon : float, optional
868
871
Regularization term >0
869
872
symmetric : bool, optional.
@@ -912,6 +915,9 @@ def entropic_fused_gromov_barycenters(
912
915
"Optimal Transport for structured data with application on graphs"
913
916
International Conference on Machine Learning (ICML). 2019.
914
917
"""
918
+ if loss_fun not in ('square_loss' , 'kl_loss' ):
919
+ raise ValueError (f"Unknown `loss_fun='{ loss_fun } '`. Use one of: { 'square_loss' , 'kl_loss' } ." )
920
+
915
921
Cs = list_to_array (* Cs )
916
922
Ys = list_to_array (* Ys )
917
923
arr = [* Cs , * Ys ]
@@ -982,7 +988,6 @@ def entropic_fused_gromov_barycenters(
982
988
983
989
if loss_fun == 'square_loss' :
984
990
C = update_square_loss (p , lambdas , T , Cs )
985
-
986
991
elif loss_fun == 'kl_loss' :
987
992
C = update_kl_loss (p , lambdas , T , Cs )
988
993
0 commit comments