@@ -315,7 +315,7 @@ def gromov_wasserstein2(C1, C2, p=None, q=None, loss_fun='square_loss', symmetri
315
315
gC2 = 2 * C2 * nx .outer (q , q ) - 2 * nx .dot (T .T , nx .dot (C1 , T ))
316
316
elif loss_fun == 'kl_loss' :
317
317
gC1 = nx .log (C1 + 1e-15 ) * nx .outer (p , p ) - nx .dot (T , nx .dot (nx .log (C2 + 1e-15 ), T .T ))
318
- gC2 = nx .dot (T .T , nx .dot (C1 , T )) / (C2 + 1e-15 ) + nx .outer (q , q )
318
+ gC2 = - nx .dot (T .T , nx .dot (C1 , T )) / (C2 + 1e-15 ) + nx .outer (q , q )
319
319
320
320
gw = nx .set_gradients (gw , (p , q , C1 , C2 ),
321
321
(log_gw ['u' ] - nx .mean (log_gw ['u' ]),
@@ -627,7 +627,7 @@ def fused_gromov_wasserstein2(M, C1, C2, p=None, q=None, loss_fun='square_loss',
627
627
gC2 = 2 * C2 * nx .outer (q , q ) - 2 * nx .dot (T .T , nx .dot (C1 , T ))
628
628
elif loss_fun == 'kl_loss' :
629
629
gC1 = nx .log (C1 + 1e-15 ) * nx .outer (p , p ) - nx .dot (T , nx .dot (nx .log (C2 + 1e-15 ), T .T ))
630
- gC2 = nx .dot (T .T , nx .dot (C1 , T )) / (C2 + 1e-15 ) + nx .outer (q , q )
630
+ gC2 = - nx .dot (T .T , nx .dot (C1 , T )) / (C2 + 1e-15 ) + nx .outer (q , q )
631
631
if isinstance (alpha , int ) or isinstance (alpha , float ):
632
632
fgw_dist = nx .set_gradients (fgw_dist , (p , q , C1 , C2 , M ),
633
633
(log_fgw ['u' ] - nx .mean (log_fgw ['u' ]),
0 commit comments