We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 3e05385 commit 8e67b40Copy full SHA for 8e67b40
ot/gromov/_semirelaxed.py
@@ -250,7 +250,7 @@ def semirelaxed_gromov_wasserstein2(C1, C2, p=None, loss_fun='square_loss', symm
250
251
elif loss_fun == 'kl_loss':
252
gC1 = nx.log(C1 + 1e-15) * nx.outer(p, p) - nx.dot(T, nx.dot(nx.log(C2 + 1e-15), T.T))
253
- gC2 = nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q)
+ gC2 = - nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q)
254
255
srgw = nx.set_gradients(srgw, (C1, C2), (gC1, gC2))
256
@@ -509,7 +509,7 @@ def semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p=None, loss_fun='square_lo
509
510
511
512
513
514
if isinstance(alpha, int) or isinstance(alpha, float):
515
srfgw_dist = nx.set_gradients(srfgw_dist, (C1, C2, M),
0 commit comments