From e559ac967068f11767f4127f75aa96f118029b66 Mon Sep 17 00:00:00 2001 From: KrzakalaPaul <62666596+KrzakalaPaul@users.noreply.github.com> Date: Fri, 1 Mar 2024 14:31:24 +0100 Subject: [PATCH 1/3] fix the sign of gradient for kl gromov --- ot/gromov/_gw.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ot/gromov/_gw.py b/ot/gromov/_gw.py index 3d7a47480..6bc613cb0 100644 --- a/ot/gromov/_gw.py +++ b/ot/gromov/_gw.py @@ -315,7 +315,7 @@ def gromov_wasserstein2(C1, C2, p=None, q=None, loss_fun='square_loss', symmetri gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T)) elif loss_fun == 'kl_loss': gC1 = nx.log(C1 + 1e-15) * nx.outer(p, p) - nx.dot(T, nx.dot(nx.log(C2 + 1e-15), T.T)) - gC2 = nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q) + gC2 = - nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q) gw = nx.set_gradients(gw, (p, q, C1, C2), (log_gw['u'] - nx.mean(log_gw['u']), @@ -627,7 +627,7 @@ def fused_gromov_wasserstein2(M, C1, C2, p=None, q=None, loss_fun='square_loss', gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T)) elif loss_fun == 'kl_loss': gC1 = nx.log(C1 + 1e-15) * nx.outer(p, p) - nx.dot(T, nx.dot(nx.log(C2 + 1e-15), T.T)) - gC2 = nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q) + gC2 = - nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q) if isinstance(alpha, int) or isinstance(alpha, float): fgw_dist = nx.set_gradients(fgw_dist, (p, q, C1, C2, M), (log_fgw['u'] - nx.mean(log_fgw['u']), From b0d62a64053066e666acf2206297301c94a1dec2 Mon Sep 17 00:00:00 2001 From: KrzakalaPaul <62666596+KrzakalaPaul@users.noreply.github.com> Date: Mon, 4 Mar 2024 17:54:54 +0100 Subject: [PATCH 2/3] releases updated --- RELEASES.md | 1 + 1 file changed, 1 insertion(+) diff --git a/RELEASES.md b/RELEASES.md index a695d6a70..3d8a48d38 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -9,6 +9,7 @@ - Fixed an issue with cost correction for mismatched labels in `ot.da.BaseTransport` fit methods. This fix addresses the original issue introduced PR #587 (PR #593) - Fix gpu compatibility of sr(F)GW solvers when `G0 is not None`(PR #596) - Fix doc and example for lowrank sinkhorn (PR #601) +- Fix a sign error regarding the gradient of `ot.gromov._gw.fused_gromov_wasserstein2` and `ot.gromov._gw.gromov_wasserstein2` for the kl loss ## 0.9.2 *December 2023* From 7cecd40f100cb4bbdcf30124160e15862253360f Mon Sep 17 00:00:00 2001 From: KrzakalaPaul <62666596+KrzakalaPaul@users.noreply.github.com> Date: Mon, 4 Mar 2024 18:12:01 +0100 Subject: [PATCH 3/3] add PR ref --- RELEASES.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/RELEASES.md b/RELEASES.md index 342380d75..9b9d2f597 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -10,7 +10,7 @@ - Fix gpu compatibility of sr(F)GW solvers when `G0 is not None`(PR #596) - Fix doc and example for lowrank sinkhorn (PR #601) - Fix issue with empty weights for `ot.emd2` (PR #606, Issue #534) -- Fix a sign error regarding the gradient of `ot.gromov._gw.fused_gromov_wasserstein2` and `ot.gromov._gw.gromov_wasserstein2` for the kl loss +- Fix a sign error regarding the gradient of `ot.gromov._gw.fused_gromov_wasserstein2` and `ot.gromov._gw.gromov_wasserstein2` for the kl loss (PR #610) ## 0.9.2 *December 2023*