@@ -166,10 +166,6 @@ def df(G):
166
166
def df (G ):
167
167
return 0.5 * (gwggrad (constC , hC1 , hC2 , G , np_ ) + gwggrad (constCt , hC1t , hC2t , G , np_ ))
168
168
169
- # removed since 0.9.2
170
- #if loss_fun == 'kl_loss':
171
- # armijo = True # there is no closed form line-search with KL
172
-
173
169
if armijo :
174
170
def line_search (cost , G , deltaG , Mi , cost_G , ** kwargs ):
175
171
return line_search_armijo (cost , G , deltaG , Mi , cost_G , nx = np_ , ** kwargs )
@@ -478,10 +474,6 @@ def df(G):
478
474
def df (G ):
479
475
return 0.5 * (gwggrad (constC , hC1 , hC2 , G , np_ ) + gwggrad (constCt , hC1t , hC2t , G , np_ ))
480
476
481
- # removed since 0.9.2
482
- #if loss_fun == 'kl_loss':
483
- # armijo = True # there is no closed form line-search with KL
484
-
485
477
if armijo :
486
478
def line_search (cost , G , deltaG , Mi , cost_G , ** kwargs ):
487
479
return line_search_armijo (cost , G , deltaG , Mi , cost_G , nx = np_ , ** kwargs )
@@ -827,10 +819,6 @@ def gromov_barycenters(
827
819
else :
828
820
C = init_C
829
821
830
- # removed since 0.9.2
831
- #if loss_fun == 'kl_loss':
832
- # armijo = True
833
-
834
822
cpt = 0
835
823
err = 1
836
824
@@ -1005,16 +993,14 @@ def fgw_barycenters(
1005
993
else :
1006
994
if init_X is None :
1007
995
X = nx .zeros ((N , d ), type_as = ps [0 ])
996
+
1008
997
else :
1009
998
X = init_X
1010
999
1011
- T = [nx .outer (p , q ) for q in ps ]
1012
-
1013
1000
Ms = [dist (X , Ys [s ]) for s in range (len (Ys ))]
1014
1001
1015
- # removed since 0.9.2
1016
- #if loss_fun == 'kl_loss':
1017
- # armijo = True
1002
+ if warmstartT :
1003
+ T = [nx .outer (p , q ) for q in ps ]
1018
1004
1019
1005
cpt = 0
1020
1006
err_feature = 1
@@ -1030,11 +1016,19 @@ def fgw_barycenters(
1030
1016
Cprev = C
1031
1017
Xprev = X
1032
1018
1019
+ if warmstartT :
1020
+ T = [fused_gromov_wasserstein (
1021
+ Ms [s ], C , Cs [s ], p , ps [s ], loss_fun = loss_fun , alpha = alpha , armijo = armijo , symmetric = symmetric ,
1022
+ G0 = T [s ], max_iter = max_iter , tol_rel = 1e-5 , tol_abs = 0. , verbose = verbose , ** kwargs ) for s in range (S )]
1023
+ else :
1024
+ T = [fused_gromov_wasserstein (
1025
+ Ms [s ], C , Cs [s ], p , ps [s ], loss_fun = loss_fun , alpha = alpha , armijo = armijo , symmetric = symmetric ,
1026
+ G0 = None , max_iter = max_iter , tol_rel = 1e-5 , tol_abs = 0. , verbose = verbose , ** kwargs ) for s in range (S )]
1027
+ # T is N,ns
1033
1028
if not fixed_features :
1034
1029
Ys_temp = [y .T for y in Ys ]
1035
1030
X = update_feature_matrix (lambdas , Ys_temp , T , p ).T
1036
-
1037
- Ms = [dist (X , Ys [s ]) for s in range (len (Ys ))]
1031
+ Ms = [dist (X , Ys [s ]) for s in range (len (Ys ))]
1038
1032
1039
1033
if not fixed_structure :
1040
1034
T_temp = [t .T for t in T ]
@@ -1044,15 +1038,6 @@ def fgw_barycenters(
1044
1038
elif loss_fun == 'kl_loss' :
1045
1039
C = update_kl_loss (p , lambdas , T_temp , Cs )
1046
1040
1047
- if warmstartT :
1048
- T = [fused_gromov_wasserstein (
1049
- Ms [s ], C , Cs [s ], p , ps [s ], loss_fun = loss_fun , alpha = alpha , armijo = armijo , symmetric = symmetric ,
1050
- G0 = T [s ], max_iter = max_iter , tol_rel = 1e-5 , tol_abs = 0. , verbose = verbose , ** kwargs ) for s in range (S )]
1051
- else :
1052
- T = [fused_gromov_wasserstein (
1053
- Ms [s ], C , Cs [s ], p , ps [s ], loss_fun = loss_fun , alpha = alpha , armijo = armijo , symmetric = symmetric ,
1054
- G0 = None , max_iter = max_iter , tol_rel = 1e-5 , tol_abs = 0. , verbose = verbose , ** kwargs ) for s in range (S )]
1055
- # T is N,ns
1056
1041
err_feature = nx .norm (X - nx .reshape (Xprev , (N , d )))
1057
1042
err_structure = nx .norm (C - Cprev )
1058
1043
if log :
0 commit comments