@@ -160,15 +160,17 @@ def df(G):
160
160
161
161
def df (G ):
162
162
return 0.5 * (gwggrad (constC , hC1 , hC2 , G , np_ ) + gwggrad (constCt , hC1t , hC2t , G , np_ ))
163
- if loss_fun == 'kl_loss' :
164
- armijo = True # there is no closed form line-search with KL
163
+
164
+ # removed since 0.9.2
165
+ #if loss_fun == 'kl_loss':
166
+ # armijo = True # there is no closed form line-search with KL
165
167
166
168
if armijo :
167
169
def line_search (cost , G , deltaG , Mi , cost_G , ** kwargs ):
168
170
return line_search_armijo (cost , G , deltaG , Mi , cost_G , nx = np_ , ** kwargs )
169
171
else :
170
172
def line_search (cost , G , deltaG , Mi , cost_G , ** kwargs ):
171
- return solve_gromov_linesearch (G , deltaG , cost_G , C1 , C2 , M = 0. , reg = 1. , nx = np_ , ** kwargs )
173
+ return solve_gromov_linesearch (G , deltaG , cost_G , hC1 , hC2 , M = 0. , reg = 1. , nx = np_ , ** kwargs )
172
174
if log :
173
175
res , log = cg (p , q , 0. , 1. , f , df , G0 , line_search , log = True , numItermax = max_iter , stopThr = tol_rel , stopThr2 = tol_abs , ** kwargs )
174
176
log ['gw_dist' ] = nx .from_numpy (log ['loss' ][- 1 ], type_as = C10 )
@@ -296,9 +298,13 @@ def gromov_wasserstein2(C1, C2, p=None, q=None, loss_fun='square_loss', symmetri
296
298
if loss_fun == 'square_loss' :
297
299
gC1 = 2 * C1 * nx .outer (p , p ) - 2 * nx .dot (T , nx .dot (C2 , T .T ))
298
300
gC2 = 2 * C2 * nx .outer (q , q ) - 2 * nx .dot (T .T , nx .dot (C1 , T ))
299
- gw = nx .set_gradients (gw , (p , q , C1 , C2 ),
300
- (log_gw ['u' ] - nx .mean (log_gw ['u' ]),
301
- log_gw ['v' ] - nx .mean (log_gw ['v' ]), gC1 , gC2 ))
301
+ elif loss_fun == 'kl_loss' :
302
+ gC1 = nx .log (C1 + 1e-15 ) * nx .outer (p , p ) - nx .dot (T , nx .dot (nx .log (C2 + 1e-15 ), T .T ))
303
+ gC2 = nx .dot (T .T , nx .dot (C1 , T )) / (C2 + 1e-15 ) + nx .outer (q , q )
304
+
305
+ gw = nx .set_gradients (gw , (p , q , C1 , C2 ),
306
+ (log_gw ['u' ] - nx .mean (log_gw ['u' ]),
307
+ log_gw ['v' ] - nx .mean (log_gw ['v' ]), gC1 , gC2 ))
302
308
303
309
if log :
304
310
return gw , log_gw
@@ -449,15 +455,16 @@ def df(G):
449
455
def df (G ):
450
456
return 0.5 * (gwggrad (constC , hC1 , hC2 , G , np_ ) + gwggrad (constCt , hC1t , hC2t , G , np_ ))
451
457
452
- if loss_fun == 'kl_loss' :
453
- armijo = True # there is no closed form line-search with KL
458
+ # removed since 0.9.2
459
+ #if loss_fun == 'kl_loss':
460
+ # armijo = True # there is no closed form line-search with KL
454
461
455
462
if armijo :
456
463
def line_search (cost , G , deltaG , Mi , cost_G , ** kwargs ):
457
464
return line_search_armijo (cost , G , deltaG , Mi , cost_G , nx = np_ , ** kwargs )
458
465
else :
459
466
def line_search (cost , G , deltaG , Mi , cost_G , ** kwargs ):
460
- return solve_gromov_linesearch (G , deltaG , cost_G , C1 , C2 , M = (1 - alpha ) * M , reg = alpha , nx = np_ , ** kwargs )
467
+ return solve_gromov_linesearch (G , deltaG , cost_G , hC1 , hC2 , M = (1 - alpha ) * M , reg = alpha , nx = np_ , ** kwargs )
461
468
if log :
462
469
res , log = cg (p , q , (1 - alpha ) * M , alpha , f , df , G0 , line_search , log = True , numItermax = max_iter , stopThr = tol_rel , stopThr2 = tol_abs , ** kwargs )
463
470
log ['fgw_dist' ] = nx .from_numpy (log ['loss' ][- 1 ], type_as = C10 )
@@ -591,18 +598,20 @@ def fused_gromov_wasserstein2(M, C1, C2, p=None, q=None, loss_fun='square_loss',
591
598
if loss_fun == 'square_loss' :
592
599
gC1 = 2 * C1 * nx .outer (p , p ) - 2 * nx .dot (T , nx .dot (C2 , T .T ))
593
600
gC2 = 2 * C2 * nx .outer (q , q ) - 2 * nx .dot (T .T , nx .dot (C1 , T ))
594
- if isinstance (alpha , int ) or isinstance (alpha , float ):
595
- fgw_dist = nx .set_gradients (fgw_dist , (p , q , C1 , C2 , M ),
596
- (log_fgw ['u' ] - nx .mean (log_fgw ['u' ]),
597
- log_fgw ['v' ] - nx .mean (log_fgw ['v' ]),
598
- alpha * gC1 , alpha * gC2 , (1 - alpha ) * T ))
599
- else :
600
-
601
- fgw_dist = nx .set_gradients (fgw_dist , (p , q , C1 , C2 , M , alpha ),
602
- (log_fgw ['u' ] - nx .mean (log_fgw ['u' ]),
603
- log_fgw ['v' ] - nx .mean (log_fgw ['v' ]),
604
- alpha * gC1 , alpha * gC2 , (1 - alpha ) * T ,
605
- gw_term - lin_term ))
601
+ elif loss_fun == 'kl_loss' :
602
+ gC1 = nx .log (C1 + 1e-15 ) * nx .outer (p , p ) - nx .dot (T , nx .dot (nx .log (C2 + 1e-15 ), T .T ))
603
+ gC2 = nx .dot (T .T , nx .dot (C1 , T )) / (C2 + 1e-15 ) + nx .outer (q , q )
604
+ if isinstance (alpha , int ) or isinstance (alpha , float ):
605
+ fgw_dist = nx .set_gradients (fgw_dist , (p , q , C1 , C2 , M ),
606
+ (log_fgw ['u' ] - nx .mean (log_fgw ['u' ]),
607
+ log_fgw ['v' ] - nx .mean (log_fgw ['v' ]),
608
+ alpha * gC1 , alpha * gC2 , (1 - alpha ) * T ))
609
+ else :
610
+ fgw_dist = nx .set_gradients (fgw_dist , (p , q , C1 , C2 , M , alpha ),
611
+ (log_fgw ['u' ] - nx .mean (log_fgw ['u' ]),
612
+ log_fgw ['v' ] - nx .mean (log_fgw ['v' ]),
613
+ alpha * gC1 , alpha * gC2 , (1 - alpha ) * T ,
614
+ gw_term - lin_term ))
606
615
607
616
if log :
608
617
return fgw_dist , log_fgw
@@ -613,7 +622,7 @@ def fused_gromov_wasserstein2(M, C1, C2, p=None, q=None, loss_fun='square_loss',
613
622
def solve_gromov_linesearch (G , deltaG , cost_G , C1 , C2 , M , reg ,
614
623
alpha_min = None , alpha_max = None , nx = None , ** kwargs ):
615
624
"""
616
- Solve the linesearch in the FW iterations
625
+ Solve the linesearch in the FW iterations for any inner loss that decomposes as in Proposition 1 in :ref:`[12] <references-solve-linesearch>`.
617
626
618
627
Parameters
619
628
----------
@@ -625,9 +634,11 @@ def solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M, reg,
625
634
cost_G : float
626
635
Value of the cost at `G`
627
636
C1 : array-like (ns,ns), optional
628
- Structure matrix in the source domain.
637
+ Transformed Structure matrix in the source domain.
638
+ For the 'square_loss' and 'kl_loss', we provide hC1 from ot.gromov.init_matrix
629
639
C2 : array-like (nt,nt), optional
630
- Structure matrix in the target domain.
640
+ Transformed Structure matrix in the source domain.
641
+ For the 'square_loss' and 'kl_loss', we provide hC2 from ot.gromov.init_matrix
631
642
M : array-like (ns,nt)
632
643
Cost matrix between the features.
633
644
reg : float
@@ -649,11 +660,16 @@ def solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M, reg,
649
660
650
661
651
662
.. _references-solve-linesearch:
663
+
652
664
References
653
665
----------
654
666
.. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas
655
667
"Optimal Transport for structured data with application on graphs"
656
668
International Conference on Machine Learning (ICML). 2019.
669
+ .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
670
+ "Gromov-Wasserstein averaging of kernel and distance matrices."
671
+ International Conference on Machine Learning (ICML). 2016.
672
+
657
673
"""
658
674
if nx is None :
659
675
G , deltaG , C1 , C2 , M = list_to_array (G , deltaG , C1 , C2 , M )
@@ -664,8 +680,8 @@ def solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M, reg,
664
680
nx = get_backend (G , deltaG , C1 , C2 , M )
665
681
666
682
dot = nx .dot (nx .dot (C1 , deltaG ), C2 .T )
667
- a = - 2 * reg * nx .sum (dot * deltaG )
668
- b = nx .sum (M * deltaG ) - 2 * reg * (nx .sum (dot * G ) + nx .sum (nx .dot (nx .dot (C1 , G ), C2 .T ) * deltaG ))
683
+ a = - reg * nx .sum (dot * deltaG )
684
+ b = nx .sum (M * deltaG ) - reg * (nx .sum (dot * G ) + nx .sum (nx .dot (nx .dot (C1 , G ), C2 .T ) * deltaG ))
669
685
670
686
alpha = solve_1d_linesearch_quad (a , b )
671
687
if alpha_min is not None or alpha_max is not None :
@@ -776,8 +792,9 @@ def gromov_barycenters(
776
792
else :
777
793
C = init_C
778
794
779
- if loss_fun == 'kl_loss' :
780
- armijo = True
795
+ # removed since 0.9.2
796
+ #if loss_fun == 'kl_loss':
797
+ # armijo = True
781
798
782
799
cpt = 0
783
800
err = 1
@@ -960,8 +977,9 @@ def fgw_barycenters(
960
977
961
978
Ms = [dist (X , Ys [s ]) for s in range (len (Ys ))]
962
979
963
- if loss_fun == 'kl_loss' :
964
- armijo = True
980
+ # removed since 0.9.2
981
+ #if loss_fun == 'kl_loss':
982
+ # armijo = True
965
983
966
984
cpt = 0
967
985
err_feature = 1
0 commit comments