Skip to content

Commit e559ac9

Browse files
committed
fix the sign of gradient for kl gromov
1 parent f1fe593 commit e559ac9

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

ot/gromov/_gw.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,7 @@ def gromov_wasserstein2(C1, C2, p=None, q=None, loss_fun='square_loss', symmetri
315315
gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T))
316316
elif loss_fun == 'kl_loss':
317317
gC1 = nx.log(C1 + 1e-15) * nx.outer(p, p) - nx.dot(T, nx.dot(nx.log(C2 + 1e-15), T.T))
318-
gC2 = nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q)
318+
gC2 = - nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q)
319319

320320
gw = nx.set_gradients(gw, (p, q, C1, C2),
321321
(log_gw['u'] - nx.mean(log_gw['u']),
@@ -627,7 +627,7 @@ def fused_gromov_wasserstein2(M, C1, C2, p=None, q=None, loss_fun='square_loss',
627627
gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T))
628628
elif loss_fun == 'kl_loss':
629629
gC1 = nx.log(C1 + 1e-15) * nx.outer(p, p) - nx.dot(T, nx.dot(nx.log(C2 + 1e-15), T.T))
630-
gC2 = nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q)
630+
gC2 = - nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q)
631631
if isinstance(alpha, int) or isinstance(alpha, float):
632632
fgw_dist = nx.set_gradients(fgw_dist, (p, q, C1, C2, M),
633633
(log_fgw['u'] - nx.mean(log_fgw['u']),

0 commit comments

Comments
 (0)