
Commit 9412f0a

[MRG] Gromov_Wasserstein2 not performing backward properly on GPU (#352)
* Resolves gromov wasserstein backward bug
* release file updated

1 parent 1781472

3 files changed: +46 -29 lines


RELEASES.md

Lines changed: 3 additions & 0 deletions

@@ -18,6 +18,9 @@
 - Fix bug in instantiating an `autograd` function `ValFunction` (Issue #337,
   PR #338)
 - Fix POT ABI compatibility with old and new numpy (Issue #346, PR #349)
+- Fix bug where gromov_wasserstein2 does not perform backpropagation with CUDA
+  tensors (Issue #351, PR #352)
+
 
 ## 0.8.1.0
 *December 2021*
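
For context, a hypothetical reproduction of the failure this entry describes (Issue #351), assuming PyTorch with a CUDA device; the sizes, seed, and use of ot.dist/ot.unif are illustrative, not taken from the issue itself:

import numpy as np
import torch
import ot

rng = np.random.RandomState(42)
xs, xt = rng.randn(10, 2), rng.randn(10, 2)

device = torch.device("cuda")  # the bug only appeared off-CPU
C1 = torch.tensor(ot.dist(xs, xs), device=device, requires_grad=True)
C2 = torch.tensor(ot.dist(xt, xt), device=device, requires_grad=True)
p = torch.tensor(ot.unif(10), device=device, requires_grad=True)
q = torch.tensor(ot.unif(10), device=device, requires_grad=True)

val = ot.gromov_wasserstein2(C1, C2, p, q)
val.backward()          # mixed CPU/CUDA tensors before this commit
print(C1.grad.device)   # matches the input device after the fix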

ot/gromov.py

Lines changed: 8 additions & 4 deletions

@@ -546,8 +546,10 @@ def df(G):
     gw = log_gw['gw_dist']
 
     if loss_fun == 'square_loss':
-        gC1 = nx.from_numpy(2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T))
-        gC2 = nx.from_numpy(2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T))
+        gC1 = 2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T)
+        gC2 = 2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T)
+        gC1 = nx.from_numpy(gC1, type_as=C10)
+        gC2 = nx.from_numpy(gC2, type_as=C10)
         gw = nx.set_gradients(gw, (p0, q0, C10, C20),
                               (log_gw['u'], log_gw['v'], gC1, gC2))
 
@@ -786,8 +788,10 @@ def df(G):
     log_fgw['T'] = T0
 
     if loss_fun == 'square_loss':
-        gC1 = nx.from_numpy(2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T))
-        gC2 = nx.from_numpy(2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T))
+        gC1 = 2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T)
+        gC2 = 2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T)
+        gC1 = nx.from_numpy(gC1, type_as=C10)
+        gC2 = nx.from_numpy(gC2, type_as=C10)
         fgw_dist = nx.set_gradients(fgw_dist, (p0, q0, C10, C20, M0),
                                     (log_fgw['u'], log_fgw['v'], alpha * gC1, alpha * gC2, (1 - alpha) * T0))
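
Both hunks fix the same pattern: gC1 and gC2 are computed as NumPy arrays, and the old bare nx.from_numpy(...) rebuilt them with the backend's defaults, i.e. on CPU, regardless of where C10 and C20 live; type_as=C10 pins the result to the dtype and device of the user's input. A standalone PyTorch sketch of that pitfall (plain torch calls here stand in for POT's backend conversion, which is an assumption about its internals):

import numpy as np
import torch

ref = torch.ones(2, 2, dtype=torch.float32)
if torch.cuda.is_available():
    ref = ref.cuda()

grad = np.eye(2)  # a gradient computed on the NumPy side (float64)

bare = torch.from_numpy(grad)  # always CPU, keeps NumPy's float64
typed = torch.as_tensor(grad, dtype=ref.dtype, device=ref.device)

assert bare.device.type == "cpu"
assert typed.device == ref.device and typed.dtype == ref.dtype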

test/test_gromov.py

Lines changed: 35 additions & 25 deletions

@@ -181,19 +181,24 @@ def test_gromov2_gradients():
 
     if torch:
 
-        p1 = torch.tensor(p, requires_grad=True)
-        q1 = torch.tensor(q, requires_grad=True)
-        C11 = torch.tensor(C1, requires_grad=True)
-        C12 = torch.tensor(C2, requires_grad=True)
+        devices = [torch.device("cpu")]
+        if torch.cuda.is_available():
+            devices.append(torch.device("cuda"))
+        for device in devices:
+            p1 = torch.tensor(p, requires_grad=True, device=device)
+            q1 = torch.tensor(q, requires_grad=True, device=device)
+            C11 = torch.tensor(C1, requires_grad=True, device=device)
+            C12 = torch.tensor(C2, requires_grad=True, device=device)
 
-        val = ot.gromov_wasserstein2(C11, C12, p1, q1)
+            val = ot.gromov_wasserstein2(C11, C12, p1, q1)
 
-        val.backward()
+            val.backward()
 
-        assert q1.shape == q1.grad.shape
-        assert p1.shape == p1.grad.shape
-        assert C11.shape == C11.grad.shape
-        assert C12.shape == C12.grad.shape
+            assert val.device == p1.device
+            assert q1.shape == q1.grad.shape
+            assert p1.shape == p1.grad.shape
+            assert C11.shape == C11.grad.shape
+            assert C12.shape == C12.grad.shape
 
 
 @pytest.skip_backend("jax", reason="test very slow with jax backend")
@@ -636,21 +641,26 @@ def test_fgw2_gradients():
 
     if torch:
 
-        p1 = torch.tensor(p, requires_grad=True)
-        q1 = torch.tensor(q, requires_grad=True)
-        C11 = torch.tensor(C1, requires_grad=True)
-        C12 = torch.tensor(C2, requires_grad=True)
-        M1 = torch.tensor(M, requires_grad=True)
-
-        val = ot.fused_gromov_wasserstein2(M1, C11, C12, p1, q1)
-
-        val.backward()
-
-        assert q1.shape == q1.grad.shape
-        assert p1.shape == p1.grad.shape
-        assert C11.shape == C11.grad.shape
-        assert C12.shape == C12.grad.shape
-        assert M1.shape == M1.grad.shape
+        devices = [torch.device("cpu")]
+        if torch.cuda.is_available():
+            devices.append(torch.device("cuda"))
+        for device in devices:
+            p1 = torch.tensor(p, requires_grad=True, device=device)
+            q1 = torch.tensor(q, requires_grad=True, device=device)
+            C11 = torch.tensor(C1, requires_grad=True, device=device)
+            C12 = torch.tensor(C2, requires_grad=True, device=device)
+            M1 = torch.tensor(M, requires_grad=True, device=device)
+
+            val = ot.fused_gromov_wasserstein2(M1, C11, C12, p1, q1)
+
+            val.backward()
+
+            assert val.device == p1.device
+            assert q1.shape == q1.grad.shape
+            assert p1.shape == p1.grad.shape
+            assert C11.shape == C11.grad.shape
+            assert C12.shape == C12.grad.shape
+            assert M1.shape == M1.grad.shape
 
 
 def test_fgw_barycenter(nx):
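
The explicit devices loop keeps each test self-contained and silently covers only CPU when CUDA is unavailable. A hypothetical alternative, sketched here only to show the trade-off (not what the commit does), would parametrize over devices so each device reports as its own test case:

import pytest
import torch

DEVICES = [torch.device("cpu")]
if torch.cuda.is_available():
    DEVICES.append(torch.device("cuda"))

@pytest.mark.parametrize("device", DEVICES)
def test_gromov2_gradients_on(device):
    # same body as test_gromov2_gradients above, with every tensor
    # created via torch.tensor(..., device=device)
    ...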
