From c0b2eda0ccdd6da7e7f5296c9abf4e0e9720ec0e Mon Sep 17 00:00:00 2001
From: Nathan Cassereau
Date: Wed, 2 Mar 2022 10:20:23 +0100
Subject: [PATCH 1/2] Resolves Gromov-Wasserstein backward bug

---
 ot/gromov.py        | 12 ++++++---
 test/test_gromov.py | 60 ++++++++++++++++++++++++++-------------------
 2 files changed, 43 insertions(+), 29 deletions(-)

diff --git a/ot/gromov.py b/ot/gromov.py
index f5a1f9137..c5a82d11a 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -546,8 +546,10 @@ def df(G):
     gw = log_gw['gw_dist']

     if loss_fun == 'square_loss':
-        gC1 = nx.from_numpy(2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T))
-        gC2 = nx.from_numpy(2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T))
+        gC1 = 2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T)
+        gC2 = 2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T)
+        gC1 = nx.from_numpy(gC1, type_as=C10)
+        gC2 = nx.from_numpy(gC2, type_as=C10)
         gw = nx.set_gradients(gw, (p0, q0, C10, C20),
                               (log_gw['u'], log_gw['v'], gC1, gC2))

@@ -786,8 +788,10 @@ def df(G):
     log_fgw['T'] = T0

     if loss_fun == 'square_loss':
-        gC1 = nx.from_numpy(2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T))
-        gC2 = nx.from_numpy(2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T))
+        gC1 = 2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T)
+        gC2 = 2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T)
+        gC1 = nx.from_numpy(gC1, type_as=C10)
+        gC2 = nx.from_numpy(gC2, type_as=C10)
         fgw_dist = nx.set_gradients(fgw_dist, (p0, q0, C10, C20, M0),
                                     (log_fgw['u'], log_fgw['v'], alpha * gC1, alpha * gC2, (1 - alpha) * T0))

diff --git a/test/test_gromov.py b/test/test_gromov.py
index 329f99c38..0dcf2daaf 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -181,19 +181,24 @@ def test_gromov2_gradients():

     if torch:

-        p1 = torch.tensor(p, requires_grad=True)
-        q1 = torch.tensor(q, requires_grad=True)
-        C11 = torch.tensor(C1, requires_grad=True)
-        C12 = torch.tensor(C2, requires_grad=True)
+        devices = [torch.device("cpu")]
+        if torch.cuda.is_available():
+            devices.append(torch.device("cuda"))
+        for device in devices:
+            p1 = torch.tensor(p, requires_grad=True, device=device)
+            q1 = torch.tensor(q, requires_grad=True, device=device)
+            C11 = torch.tensor(C1, requires_grad=True, device=device)
+            C12 = torch.tensor(C2, requires_grad=True, device=device)

-        val = ot.gromov_wasserstein2(C11, C12, p1, q1)
+            val = ot.gromov_wasserstein2(C11, C12, p1, q1)

-        val.backward()
+            val.backward()

-        assert q1.shape == q1.grad.shape
-        assert p1.shape == p1.grad.shape
-        assert C11.shape == C11.grad.shape
-        assert C12.shape == C12.grad.shape
+            assert val.device == p1.device
+            assert q1.shape == q1.grad.shape
+            assert p1.shape == p1.grad.shape
+            assert C11.shape == C11.grad.shape
+            assert C12.shape == C12.grad.shape


 @pytest.skip_backend("jax", reason="test very slow with jax backend")
@@ -636,21 +641,26 @@ def test_fgw2_gradients():

     if torch:

-        p1 = torch.tensor(p, requires_grad=True)
-        q1 = torch.tensor(q, requires_grad=True)
-        C11 = torch.tensor(C1, requires_grad=True)
-        C12 = torch.tensor(C2, requires_grad=True)
-        M1 = torch.tensor(M, requires_grad=True)
-
-        val = ot.fused_gromov_wasserstein2(M1, C11, C12, p1, q1)
-
-        val.backward()
-
-        assert q1.shape == q1.grad.shape
-        assert p1.shape == p1.grad.shape
-        assert C11.shape == C11.grad.shape
-        assert C12.shape == C12.grad.shape
-        assert M1.shape == M1.grad.shape
+        devices = [torch.device("cpu")]
+        if torch.cuda.is_available():
+            devices.append(torch.device("cuda"))
+        for device in devices:
+            p1 = torch.tensor(p, requires_grad=True, device=device)
+            q1 = torch.tensor(q, requires_grad=True, device=device)
+            C11 = torch.tensor(C1, requires_grad=True, device=device)
+            C12 = torch.tensor(C2, requires_grad=True, device=device)
+            M1 = torch.tensor(M, requires_grad=True, device=device)
+
+            val = ot.fused_gromov_wasserstein2(M1, C11, C12, p1, q1)
+
+            val.backward()
+
+            assert val.device == p1.device
+            assert q1.shape == q1.grad.shape
+            assert p1.shape == p1.grad.shape
+            assert C11.shape == C11.grad.shape
+            assert C12.shape == C12.grad.shape
+            assert M1.shape == M1.grad.shape


 def test_fgw_barycenter(nx):

From 2d045de505a7bb467022683b8de91d7b0982409b Mon Sep 17 00:00:00 2001
From: Nathan Cassereau
Date: Wed, 2 Mar 2022 11:19:34 +0100
Subject: [PATCH 2/2] Release file updated

---
 RELEASES.md | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/RELEASES.md b/RELEASES.md
index c1068f3e3..18562e7b7 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -18,6 +18,9 @@
 - Fix bug in instantiating an `autograd` function `ValFunction` (Issue #337,
   PR #338)
 - Fix POT ABI compatibility with old and new numpy (Issue #346, PR #349)
+- Fix bug where `gromov_wasserstein2` does not perform backpropagation with CUDA
+  tensors (Issue #351, PR #352)
+

 ## 0.8.1.0
 *December 2021*
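
Below is a minimal sketch of the use case these commits enable: backpropagating through ot.gromov_wasserstein2 with CUDA tensors. It is not part of the patch; it assumes POT with the PyTorch backend installed, uses illustrative toy point clouds, and falls back to CPU when no GPU is available.

# Sketch: gradients of the Gromov-Wasserstein loss w.r.t. the cost matrices
# and marginals, computed on GPU when one is available (CPU otherwise).
import numpy as np
import torch
import ot

n = 30
rng = np.random.RandomState(42)
xs = rng.randn(n, 2)          # illustrative source point cloud
xt = xs[::-1].copy()          # illustrative target point cloud

p = ot.unif(n)
q = ot.unif(n)
C1 = ot.dist(xs, xs)
C2 = ot.dist(xt, xt)
C1 /= C1.max()
C2 /= C2.max()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
p1 = torch.tensor(p, requires_grad=True, device=device)
q1 = torch.tensor(q, requires_grad=True, device=device)
C11 = torch.tensor(C1, requires_grad=True, device=device)
C12 = torch.tensor(C2, requires_grad=True, device=device)

# With the type_as fix above, the gradients registered inside
# gromov_wasserstein2 live on the same device as the inputs, so
# backward() works with CUDA tensors as well as CPU ones.
val = ot.gromov_wasserstein2(C11, C12, p1, q1)
val.backward()

print(val.item(), C11.grad.shape, p1.grad.shape)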