@@ -81,7 +81,7 @@ def entropic_gromov_wasserstein(
81
81
q : array-like, shape (nt,), optional
82
82
Distribution in the target space.
83
83
If let to its default value None, uniform distribution is taken.
84
- loss_fun : string, optional
84
+ loss_fun : string, optional (default='square_loss')
85
85
Loss function used for the solver either 'square_loss' or 'kl_loss'
86
86
epsilon : float, optional
87
87
Regularization term >0
@@ -92,8 +92,8 @@ def entropic_gromov_wasserstein(
92
92
G0: array-like, shape (ns,nt), optional
93
93
If None the initial transport plan of the solver is pq^T.
94
94
Otherwise G0 will be used as initial transport of the solver. G0 is not
95
- required to satisfy marginal constraints but we strongly recommand it
96
- to correcly estimate the GW distance.
95
+ required to satisfy marginal constraints but we strongly recommend it
96
+ to correctly estimate the GW distance.
97
97
max_iter : int, optional
98
98
Max number of iterations
99
99
tol : float, optional
@@ -135,6 +135,9 @@ def entropic_gromov_wasserstein(
135
135
if solver not in ['PGD' , 'PPA' ]:
136
136
raise ValueError ("Unknown solver '%s'. Pick one in ['PGD', 'PPA']." % solver )
137
137
138
+ if loss_fun not in ('square_loss' , 'kl_loss' ):
139
+ raise ValueError (f"Unknown `loss_fun='{ loss_fun } '`. Use one of: { 'square_loss' , 'kl_loss' } ." )
140
+
138
141
C1 , C2 = list_to_array (C1 , C2 )
139
142
arr = [C1 , C2 ]
140
143
if p is not None :
@@ -280,7 +283,7 @@ def entropic_gromov_wasserstein2(
280
283
q : array-like, shape (nt,), optional
281
284
Distribution in the target space.
282
285
If let to its default value None, uniform distribution is taken.
283
- loss_fun : string, optional
286
+ loss_fun : string, optional (default='square_loss')
284
287
Loss function used for the solver either 'square_loss' or 'kl_loss'
285
288
epsilon : float, optional
286
289
Regularization term >0
@@ -373,8 +376,8 @@ def entropic_gromov_barycenters(
373
376
lambdas : list of float, optional
374
377
List of the `S` spaces' weights.
375
378
If let to its default value None, uniform weights are taken.
376
- loss_fun : callable , optional
377
- tensor-matrix multiplication function based on specific loss function
379
+ loss_fun : string , optional (default='square_loss')
380
+ Loss function used for the solver either 'square_loss' or 'kl_loss'
378
381
epsilon : float, optional
379
382
Regularization term >0
380
383
symmetric : bool, optional.
@@ -411,6 +414,9 @@ def entropic_gromov_barycenters(
411
414
"Gromov-Wasserstein averaging of kernel and distance matrices."
412
415
International Conference on Machine Learning (ICML). 2016.
413
416
"""
417
+ if loss_fun not in ('square_loss' , 'kl_loss' ):
418
+ raise ValueError (f"Unknown `loss_fun='{ loss_fun } '`. Use one of: { 'square_loss' , 'kl_loss' } ." )
419
+
414
420
Cs = list_to_array (* Cs )
415
421
arr = [* Cs ]
416
422
if ps is not None :
@@ -459,7 +465,6 @@ def entropic_gromov_barycenters(
459
465
460
466
if loss_fun == 'square_loss' :
461
467
C = update_square_loss (p , lambdas , T , Cs )
462
-
463
468
elif loss_fun == 'kl_loss' :
464
469
C = update_kl_loss (p , lambdas , T , Cs )
465
470
@@ -550,21 +555,21 @@ def entropic_fused_gromov_wasserstein(
550
555
q : array-like, shape (nt,), optional
551
556
Distribution in the target space.
552
557
If let to its default value None, uniform distribution is taken.
553
- loss_fun : string, optional
558
+ loss_fun : string, optional (default='square_loss')
554
559
Loss function used for the solver either 'square_loss' or 'kl_loss'
555
560
epsilon : float, optional
556
561
Regularization term >0
557
562
symmetric : bool, optional
558
563
Either C1 and C2 are to be assumed symmetric or not.
559
564
If let to its default None value, a symmetry test will be conducted.
560
- Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymetric ).
565
+ Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric ).
561
566
alpha : float, optional
562
567
Trade-off parameter (0 < alpha < 1)
563
568
G0: array-like, shape (ns,nt), optional
564
569
If None the initial transport plan of the solver is pq^T.
565
570
Otherwise G0 will be used as initial transport of the solver. G0 is not
566
- required to satisfy marginal constraints but we strongly recommand it
567
- to correcly estimate the GW distance.
571
+ required to satisfy marginal constraints but we strongly recommend it
572
+ to correctly estimate the GW distance.
568
573
max_iter : int, optional
569
574
Max number of iterations
570
575
tol : float, optional
@@ -611,6 +616,9 @@ def entropic_fused_gromov_wasserstein(
611
616
if solver not in ['PGD' , 'PPA' ]:
612
617
raise ValueError ("Unknown solver '%s'. Pick one in ['PGD', 'PPA']." % solver )
613
618
619
+ if loss_fun not in ('square_loss' , 'kl_loss' ):
620
+ raise ValueError (f"Unknown `loss_fun='{ loss_fun } '`. Use one of: { 'square_loss' , 'kl_loss' } ." )
621
+
614
622
M , C1 , C2 = list_to_array (M , C1 , C2 )
615
623
arr = [M , C1 , C2 ]
616
624
if p is not None :
@@ -762,7 +770,7 @@ def entropic_fused_gromov_wasserstein2(
762
770
q : array-like, shape (nt,), optional
763
771
Distribution in the target space.
764
772
If let to its default value None, uniform distribution is taken.
765
- loss_fun : string, optional
773
+ loss_fun : string, optional (default='square_loss')
766
774
Loss function used for the solver either 'square_loss' or 'kl_loss'
767
775
epsilon : float, optional
768
776
Regularization term >0
@@ -775,8 +783,8 @@ def entropic_fused_gromov_wasserstein2(
775
783
G0: array-like, shape (ns,nt), optional
776
784
If None the initial transport plan of the solver is pq^T.
777
785
Otherwise G0 will be used as initial transport of the solver. G0 is not
778
- required to satisfy marginal constraints but we strongly recommand it
779
- to correcly estimate the GW distance.
786
+ required to satisfy marginal constraints but we strongly recommend it
787
+ to correctly estimate the GW distance.
780
788
max_iter : int, optional
781
789
Max number of iterations
782
790
tol : float, optional
@@ -857,8 +865,8 @@ def entropic_fused_gromov_barycenters(
857
865
lambdas : list of float, optional
858
866
List of the `S` spaces' weights.
859
867
If let to its default value None, uniform weights are taken.
860
- loss_fun : callable , optional
861
- tensor-matrix multiplication function based on specific loss function
868
+ loss_fun : string , optional (default='square_loss')
869
+ Loss function used for the solver either 'square_loss' or 'kl_loss'
862
870
epsilon : float, optional
863
871
Regularization term >0
864
872
symmetric : bool, optional.
@@ -907,6 +915,9 @@ def entropic_fused_gromov_barycenters(
907
915
"Optimal Transport for structured data with application on graphs"
908
916
International Conference on Machine Learning (ICML). 2019.
909
917
"""
918
+ if loss_fun not in ('square_loss' , 'kl_loss' ):
919
+ raise ValueError (f"Unknown `loss_fun='{ loss_fun } '`. Use one of: { 'square_loss' , 'kl_loss' } ." )
920
+
910
921
Cs = list_to_array (* Cs )
911
922
Ys = list_to_array (* Ys )
912
923
arr = [* Cs , * Ys ]
@@ -977,7 +988,6 @@ def entropic_fused_gromov_barycenters(
977
988
978
989
if loss_fun == 'square_loss' :
979
990
C = update_square_loss (p , lambdas , T , Cs )
980
-
981
991
elif loss_fun == 'kl_loss' :
982
992
C = update_kl_loss (p , lambdas , T , Cs )
983
993
@@ -1004,7 +1014,6 @@ def entropic_fused_gromov_barycenters(
1004
1014
print ('{:5d}|{:8e}|' .format (cpt , err_feature ))
1005
1015
1006
1016
cpt += 1
1007
- print ('Y type:' , type (Y ))
1008
1017
if log :
1009
1018
log_ ['T' ] = T # from target to Ys
1010
1019
log_ ['p' ] = p
0 commit comments