@@ -56,7 +56,6 @@ def semirelaxed_gromov_wasserstein(C1, C2, p=None, loss_fun='square_loss', symme
56
56
If left to its default value None, uniform distribution is taken.
57
57
loss_fun : str
58
58
loss function used for the solver either 'square_loss' or 'kl_loss'.
59
- 'kl_loss' is not implemented yet and will raise an error.
60
59
symmetric : bool, optional
61
60
Either C1 and C2 are to be assumed symmetric or not.
62
61
If left to its default None value, a symmetry test will be conducted.
@@ -92,8 +91,6 @@ def semirelaxed_gromov_wasserstein(C1, C2, p=None, loss_fun='square_loss', symme
92
91
"Semi-relaxed Gromov-Wasserstein divergence and applications on graphs"
93
92
International Conference on Learning Representations (ICLR), 2022.
94
93
"""
95
- if loss_fun == 'kl_loss' :
96
- raise NotImplementedError ()
97
94
arr = [C1 , C2 ]
98
95
if p is not None :
99
96
arr .append (list_to_array (p ))
@@ -139,7 +136,7 @@ def df(G):
139
136
return 0.5 * (gwggrad (constC + marginal_product_1 , hC1 , hC2 , G , nx ) + gwggrad (constCt + marginal_product_2 , hC1t , hC2t , G , nx ))
140
137
141
138
def line_search (cost , G , deltaG , Mi , cost_G , ** kwargs ):
142
- return solve_semirelaxed_gromov_linesearch (G , deltaG , cost_G , C1 , C2 , ones_p , M = 0. , reg = 1. , nx = nx , ** kwargs )
139
+ return solve_semirelaxed_gromov_linesearch (G , deltaG , cost_G , hC1 , hC2 , ones_p , M = 0. , reg = 1. , fC2t = fC2t , nx = nx , ** kwargs )
143
140
144
141
if log :
145
142
res , log = semirelaxed_cg (p , q , 0. , 1. , f , df , G0 , line_search , log = True , numItermax = max_iter , stopThr = tol_rel , stopThr2 = tol_abs , ** kwargs )
@@ -190,7 +187,6 @@ def semirelaxed_gromov_wasserstein2(C1, C2, p=None, loss_fun='square_loss', symm
190
187
If left to its default value None, uniform distribution is taken.
191
188
loss_fun : str
192
189
loss function used for the solver either 'square_loss' or 'kl_loss'.
193
- 'kl_loss' is not implemented yet and will raise an error.
194
190
symmetric : bool, optional
195
191
Either C1 and C2 are to be assumed symmetric or not.
196
192
If left to its default None value, a symmetry test will be conducted.
@@ -243,7 +239,12 @@ def semirelaxed_gromov_wasserstein2(C1, C2, p=None, loss_fun='square_loss', symm
243
239
if loss_fun == 'square_loss' :
244
240
gC1 = 2 * C1 * nx .outer (p , p ) - 2 * nx .dot (T , nx .dot (C2 , T .T ))
245
241
gC2 = 2 * C2 * nx .outer (q , q ) - 2 * nx .dot (T .T , nx .dot (C1 , T ))
246
- srgw = nx .set_gradients (srgw , (C1 , C2 ), (gC1 , gC2 ))
242
+
243
+ elif loss_fun == 'kl_loss' :
244
+ gC1 = nx .log (C1 + 1e-15 ) * nx .outer (p , p ) - nx .dot (T , nx .dot (nx .log (C2 + 1e-15 ), T .T ))
245
+ gC2 = nx .dot (T .T , nx .dot (C1 , T )) / (C2 + 1e-15 ) + nx .outer (q , q )
246
+
247
+ srgw = nx .set_gradients (srgw , (C1 , C2 ), (gC1 , gC2 ))
247
248
248
249
if log :
249
250
return srgw , log_srgw
@@ -291,7 +292,6 @@ def semirelaxed_fused_gromov_wasserstein(
291
292
If left to its default value None, uniform distribution is taken.
292
293
loss_fun : str
293
294
loss function used for the solver either 'square_loss' or 'kl_loss'.
294
- 'kl_loss' is not implemented yet and will raise an error.
295
295
symmetric : bool, optional
296
296
Either C1 and C2 are to be assumed symmetric or not.
297
297
If left to its default None value, a symmetry test will be conducted.
@@ -332,9 +332,6 @@ def semirelaxed_fused_gromov_wasserstein(
332
332
"Semi-relaxed Gromov-Wasserstein divergence and applications on graphs"
333
333
International Conference on Learning Representations (ICLR), 2022.
334
334
"""
335
- if loss_fun == 'kl_loss' :
336
- raise NotImplementedError ()
337
-
338
335
arr = [M , C1 , C2 ]
339
336
if p is not None :
340
337
arr .append (list_to_array (p ))
@@ -382,7 +379,7 @@ def df(G):
382
379
383
380
def line_search (cost , G , deltaG , Mi , cost_G , ** kwargs ):
384
381
return solve_semirelaxed_gromov_linesearch (
385
- G , deltaG , cost_G , C1 , C2 , ones_p , M = (1 - alpha ) * M , reg = alpha , nx = nx , ** kwargs )
382
+ G , deltaG , cost_G , hC1 , hC2 , ones_p , M = (1 - alpha ) * M , reg = alpha , fC2t = fC2t , nx = nx , ** kwargs )
386
383
387
384
if log :
388
385
res , log = semirelaxed_cg (p , q , (1 - alpha ) * M , alpha , f , df , G0 , line_search , log = True , numItermax = max_iter , stopThr = tol_rel , stopThr2 = tol_abs , ** kwargs )
@@ -434,7 +431,6 @@ def semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p=None, loss_fun='square_lo
434
431
If left to its default value None, uniform distribution is taken.
435
432
loss_fun : str, optional
436
433
loss function used for the solver either 'square_loss' or 'kl_loss'.
437
- 'kl_loss' is not implemented yet and will raise an error.
438
434
symmetric : bool, optional
439
435
Either C1 and C2 are to be assumed symmetric or not.
440
436
If left to its default None value, a symmetry test will be conducted.
@@ -494,15 +490,20 @@ def semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p=None, loss_fun='square_lo
494
490
if loss_fun == 'square_loss' :
495
491
gC1 = 2 * C1 * nx .outer (p , p ) - 2 * nx .dot (T , nx .dot (C2 , T .T ))
496
492
gC2 = 2 * C2 * nx .outer (q , q ) - 2 * nx .dot (T .T , nx .dot (C1 , T ))
497
- if isinstance (alpha , int ) or isinstance (alpha , float ):
498
- srfgw_dist = nx .set_gradients (srfgw_dist , (C1 , C2 , M ),
499
- (alpha * gC1 , alpha * gC2 , (1 - alpha ) * T ))
500
- else :
501
- lin_term = nx .sum (T * M )
502
- srgw_term = (srfgw_dist - (1 - alpha ) * lin_term ) / alpha
503
- srfgw_dist = nx .set_gradients (srfgw_dist , (C1 , C2 , M , alpha ),
504
- (alpha * gC1 , alpha * gC2 , (1 - alpha ) * T ,
505
- srgw_term - lin_term ))
493
+
494
+ elif loss_fun == 'kl_loss' :
495
+ gC1 = nx .log (C1 + 1e-15 ) * nx .outer (p , p ) - nx .dot (T , nx .dot (nx .log (C2 + 1e-15 ), T .T ))
496
+ gC2 = nx .dot (T .T , nx .dot (C1 , T )) / (C2 + 1e-15 ) + nx .outer (q , q )
497
+
498
+ if isinstance (alpha , int ) or isinstance (alpha , float ):
499
+ srfgw_dist = nx .set_gradients (srfgw_dist , (C1 , C2 , M ),
500
+ (alpha * gC1 , alpha * gC2 , (1 - alpha ) * T ))
501
+ else :
502
+ lin_term = nx .sum (T * M )
503
+ srgw_term = (srfgw_dist - (1 - alpha ) * lin_term ) / alpha
504
+ srfgw_dist = nx .set_gradients (srfgw_dist , (C1 , C2 , M , alpha ),
505
+ (alpha * gC1 , alpha * gC2 , (1 - alpha ) * T ,
506
+ srgw_term - lin_term ))
506
507
507
508
if log :
508
509
return srfgw_dist , log_fgw
@@ -511,7 +512,7 @@ def semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p=None, loss_fun='square_lo
511
512
512
513
513
514
def solve_semirelaxed_gromov_linesearch (G , deltaG , cost_G , C1 , C2 , ones_p ,
514
- M , reg , alpha_min = None , alpha_max = None , nx = None , ** kwargs ):
515
+ M , reg , fC2t = None , alpha_min = None , alpha_max = None , nx = None , ** kwargs ):
515
516
"""
516
517
Solve the linesearch in the Conditional Gradient iterations for the semi-relaxed Gromov-Wasserstein divergence.
517
518
@@ -524,16 +525,22 @@ def solve_semirelaxed_gromov_linesearch(G, deltaG, cost_G, C1, C2, ones_p,
524
525
Difference between the optimal map found by linearization in the FW algorithm and the value at a given iteration
525
526
cost_G : float
526
527
Value of the cost at `G`
527
- C1 : array-like (ns,ns)
528
- Structure matrix in the source domain.
529
- C2 : array-like (nt,nt)
530
- Structure matrix in the target domain.
528
+ C1 : array-like (ns,ns), optional
529
+ Transformed Structure matrix in the source domain.
530
+ Note that for the 'square_loss' and 'kl_loss', we provide hC1 from ot.gromov.init_matrix_semirelaxed
531
+ C2 : array-like (nt,nt), optional
532
+ Transformed Structure matrix in the target domain.
533
+ Note that for the 'square_loss' and 'kl_loss', we provide hC2 from ot.gromov.init_matrix_semirelaxed
531
534
ones_p: array-like (ns,1)
532
535
Array of ones of size ns
533
536
M : array-like (ns,nt)
534
537
Cost matrix between the features.
535
538
reg : float
536
539
Regularization parameter.
540
+ fC2t: array-like (nt,nt), optional
541
+ Transformed Structure matrix in the target domain.
542
+ Note that for the 'square_loss' and 'kl_loss', we provide fC2t from ot.gromov.init_matrix_semirelaxed.
543
+ If fC2t is not provided, it defaults to the fC2t corresponding to the 'square_loss'.
537
544
alpha_min : float, optional
538
545
Minimum value for alpha
539
546
alpha_max : float, optional
@@ -565,11 +572,14 @@ def solve_semirelaxed_gromov_linesearch(G, deltaG, cost_G, C1, C2, ones_p,
565
572
566
573
qG , qdeltaG = nx .sum (G , 0 ), nx .sum (deltaG , 0 )
567
574
dot = nx .dot (nx .dot (C1 , deltaG ), C2 .T )
568
- C2t_square = C2 .T ** 2
569
- dot_qG = nx .dot (nx .outer (ones_p , qG ), C2t_square )
570
- dot_qdeltaG = nx .dot (nx .outer (ones_p , qdeltaG ), C2t_square )
571
- a = reg * nx .sum ((dot_qdeltaG - 2 * dot ) * deltaG )
572
- b = nx .sum (M * deltaG ) + reg * (nx .sum ((dot_qdeltaG - 2 * dot ) * G ) + nx .sum ((dot_qG - 2 * nx .dot (nx .dot (C1 , G ), C2 .T )) * deltaG ))
575
+ if fC2t is None :
576
+ fC2t = C2 .T ** 2
577
+ dot_qG = nx .dot (nx .outer (ones_p , qG ), fC2t )
578
+ dot_qdeltaG = nx .dot (nx .outer (ones_p , qdeltaG ), fC2t )
579
+
580
+ a = reg * nx .sum ((dot_qdeltaG - dot ) * deltaG )
581
+ b = nx .sum (M * deltaG ) + reg * (nx .sum ((dot_qdeltaG - dot ) * G ) + nx .sum ((dot_qG - nx .dot (nx .dot (C1 , G ), C2 .T )) * deltaG ))
582
+
573
583
alpha = solve_1d_linesearch_quad (a , b )
574
584
if alpha_min is not None or alpha_max is not None :
575
585
alpha = np .clip (alpha , alpha_min , alpha_max )
@@ -620,7 +630,6 @@ def entropic_semirelaxed_gromov_wasserstein(
620
630
If left to its default value None, uniform distribution is taken.
621
631
loss_fun : str
622
632
loss function used for the solver either 'square_loss' or 'kl_loss'.
623
- 'kl_loss' is not implemented yet and will raise an error.
624
633
epsilon : float
625
634
Regularization term >0
626
635
symmetric : bool, optional
@@ -655,8 +664,6 @@ def entropic_semirelaxed_gromov_wasserstein(
655
664
"Semi-relaxed Gromov-Wasserstein divergence and applications on graphs"
656
665
International Conference on Learning Representations (ICLR), 2022.
657
666
"""
658
- if loss_fun == 'kl_loss' :
659
- raise NotImplementedError ()
660
667
arr = [C1 , C2 ]
661
668
if p is not None :
662
669
arr .append (list_to_array (p ))
@@ -777,7 +784,6 @@ def entropic_semirelaxed_gromov_wasserstein2(
777
784
If left to its default value None, uniform distribution is taken.
778
785
loss_fun : str
779
786
loss function used for the solver either 'square_loss' or 'kl_loss'.
780
- 'kl_loss' is not implemented yet and will raise an error.
781
787
epsilon : float
782
788
Regularization term >0
783
789
symmetric : bool, optional
@@ -869,7 +875,6 @@ def entropic_semirelaxed_fused_gromov_wasserstein(
869
875
If left to its default value None, uniform distribution is taken.
870
876
loss_fun : str
871
877
loss function used for the solver either 'square_loss' or 'kl_loss'.
872
- 'kl_loss' is not implemented yet and will raise an error.
873
878
epsilon : float
874
879
Regularization term >0
875
880
symmetric : bool, optional
@@ -907,8 +912,6 @@ def entropic_semirelaxed_fused_gromov_wasserstein(
907
912
"Semi-relaxed Gromov-Wasserstein divergence and applications on graphs"
908
913
International Conference on Learning Representations (ICLR), 2022.
909
914
"""
910
- if loss_fun == 'kl_loss' :
911
- raise NotImplementedError ()
912
915
arr = [M , C1 , C2 ]
913
916
if p is not None :
914
917
arr .append (list_to_array (p ))
@@ -1032,7 +1035,6 @@ def entropic_semirelaxed_fused_gromov_wasserstein2(
1032
1035
If left to its default value None, uniform distribution is taken.
1033
1036
loss_fun : str, optional
1034
1037
loss function used for the solver either 'square_loss' or 'kl_loss'.
1035
- 'kl_loss' is not implemented yet and will raise an error.
1036
1038
epsilon : float
1037
1039
Regularization term >0
1038
1040
symmetric : bool, optional
0 commit comments