diff --git a/RELEASES.md b/RELEASES.md index 3b8513dbd..9b9d2f597 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -10,6 +10,7 @@ - Fix gpu compatibility of sr(F)GW solvers when `G0 is not None`(PR #596) - Fix doc and example for lowrank sinkhorn (PR #601) - Fix issue with empty weights for `ot.emd2` (PR #606, Issue #534) +- Fix a sign error regarding the gradient of `ot.gromov._gw.fused_gromov_wasserstein2` and `ot.gromov._gw.gromov_wasserstein2` for the kl loss (PR #610) ## 0.9.2 *December 2023* diff --git a/ot/gromov/_gw.py b/ot/gromov/_gw.py index 281ed5f0b..46e1ddfe8 100644 --- a/ot/gromov/_gw.py +++ b/ot/gromov/_gw.py @@ -315,7 +315,7 @@ def gromov_wasserstein2(C1, C2, p=None, q=None, loss_fun='square_loss', symmetri gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T)) elif loss_fun == 'kl_loss': gC1 = nx.log(C1 + 1e-15) * nx.outer(p, p) - nx.dot(T, nx.dot(nx.log(C2 + 1e-15), T.T)) - gC2 = nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q) + gC2 = - nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q) gw = nx.set_gradients(gw, (p, q, C1, C2), (log_gw['u'] - nx.mean(log_gw['u']), @@ -627,7 +627,7 @@ def fused_gromov_wasserstein2(M, C1, C2, p=None, q=None, loss_fun='square_loss', gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T)) elif loss_fun == 'kl_loss': gC1 = nx.log(C1 + 1e-15) * nx.outer(p, p) - nx.dot(T, nx.dot(nx.log(C2 + 1e-15), T.T)) - gC2 = nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q) + gC2 = - nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q) if isinstance(alpha, int) or isinstance(alpha, float): fgw_dist = nx.set_gradients(fgw_dist, (p, q, C1, C2, M), (log_fgw['u'] - nx.mean(log_fgw['u']),