diff --git a/RELEASES.md b/RELEASES.md index 9b9d2f597..9734ab21a 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -11,6 +11,8 @@ - Fix doc and example for lowrank sinkhorn (PR #601) - Fix issue with empty weights for `ot.emd2` (PR #606, Issue #534) - Fix a sign error regarding the gradient of `ot.gromov._gw.fused_gromov_wasserstein2` and `ot.gromov._gw.gromov_wasserstein2` for the kl loss (PR #610) +- Fix same sign error for sr(F)GW conditional gradient solvers (PR #611) + ## 0.9.2 *December 2023* diff --git a/ot/gromov/_semirelaxed.py b/ot/gromov/_semirelaxed.py index c37ba2bf4..0137a8ed8 100644 --- a/ot/gromov/_semirelaxed.py +++ b/ot/gromov/_semirelaxed.py @@ -250,7 +250,7 @@ def semirelaxed_gromov_wasserstein2(C1, C2, p=None, loss_fun='square_loss', symm elif loss_fun == 'kl_loss': gC1 = nx.log(C1 + 1e-15) * nx.outer(p, p) - nx.dot(T, nx.dot(nx.log(C2 + 1e-15), T.T)) - gC2 = nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q) + gC2 = - nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q) srgw = nx.set_gradients(srgw, (C1, C2), (gC1, gC2)) @@ -509,7 +509,7 @@ def semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p=None, loss_fun='square_lo elif loss_fun == 'kl_loss': gC1 = nx.log(C1 + 1e-15) * nx.outer(p, p) - nx.dot(T, nx.dot(nx.log(C2 + 1e-15), T.T)) - gC2 = nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q) + gC2 = - nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q) if isinstance(alpha, int) or isinstance(alpha, float): srfgw_dist = nx.set_gradients(srfgw_dist, (C1, C2, M),