Skip to content

Commit 8e67b40

Browse files
fix grad sign
1 parent 3e05385 commit 8e67b40

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

ot/gromov/_semirelaxed.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ def semirelaxed_gromov_wasserstein2(C1, C2, p=None, loss_fun='square_loss', symm
250250

251251
elif loss_fun == 'kl_loss':
252252
gC1 = nx.log(C1 + 1e-15) * nx.outer(p, p) - nx.dot(T, nx.dot(nx.log(C2 + 1e-15), T.T))
253-
gC2 = nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q)
253+
gC2 = - nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q)
254254

255255
srgw = nx.set_gradients(srgw, (C1, C2), (gC1, gC2))
256256

@@ -509,7 +509,7 @@ def semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p=None, loss_fun='square_lo
509509

510510
elif loss_fun == 'kl_loss':
511511
gC1 = nx.log(C1 + 1e-15) * nx.outer(p, p) - nx.dot(T, nx.dot(nx.log(C2 + 1e-15), T.T))
512-
gC2 = nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q)
512+
gC2 = - nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q)
513513

514514
if isinstance(alpha, int) or isinstance(alpha, float):
515515
srfgw_dist = nx.set_gradients(srfgw_dist, (C1, C2, M),

0 commit comments

Comments
 (0)