Description
As the title says. The following is a short snippet that reproduces the error.
import numpy as np
import ot
import torch
from ot.gromov import gromov_wasserstein2

def gw_pytorch_exam(C1, C2, a1, a2, device, n_iter=1000, lr=1e-2):
    # put all four inputs on the requested device; only C1 is optimized
    C1_torch = torch.tensor(C1, device=device, requires_grad=True)
    C2_torch = torch.tensor(C2, device=device)
    a1_torch = torch.tensor(a1, device=device)
    a2_torch = torch.tensor(a2, device=device)
    for i in range(n_iter):
        loss = gromov_wasserstein2(C1_torch, C2_torch, a1_torch, a2_torch)
        loss.backward()
        with torch.no_grad():
            # plain projected gradient step on C1, clamped back to [0, 1]
            grad = C1_torch.grad
            C1_torch -= grad * lr
            C1_torch.grad.zero_()
            C1_torch.data = torch.clamp(C1_torch, 0, 1)
    return C1_torch

if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")  # maybe should disable this fallback to force GPU usage

n = 10
C1 = np.eye(n)
C2 = np.random.randn(n, n)
a = ot.unif(n)
C1 = gw_pytorch_exam(C1, C2, a, a, device)
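For what it's worth, the stray CPU tensor does not come from the inputs: all four tensors handed to gromov_wasserstein2 are created directly on cuda:0. A minimal sanity check (diagnostic only, not part of the repro) confirms this:

import numpy as np
import ot
import torch

# Diagnostic sketch: verify that every input to gromov_wasserstein2 really
# lives on cuda:0, so the cpu tensor in the error below must be created
# somewhere inside POT itself.
device = torch.device("cuda:0")
n = 10
inputs = {
    "C1": torch.tensor(np.eye(n), device=device, requires_grad=True),
    "C2": torch.tensor(np.random.randn(n, n), device=device),
    "a1": torch.tensor(ot.unif(n), device=device),
    "a2": torch.tensor(ot.unif(n), device=device),
}
for name, t in inputs.items():
    print(name, t.device)  # prints "cuda:0" for all four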
Running the reproduction snippet raises the following RuntimeError:
     36 a = ot.unif(n)
---> 37 C1 = gw_pytorch_exam(C1, C2, a, a, device)

<ipython-input-3-afc0d26d5054> in gw_pytorch_exam(C1, C2, a1, a2, device, n_iter, lr)
     16     for i in range(n_iter):
     17         loss = gromov_wasserstein2(C1_torch, C2_torch, a1_torch, a2_torch)
---> 18         loss.backward()
     19         with torch.no_grad():
     20             grad = C1_torch.grad

~/.local/miniconda3/envs/test_pot/lib/python3.8/site-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph)
    219                 retain_graph=retain_graph,
    220                 create_graph=create_graph)
--> 221         torch.autograd.backward(self, gradient, retain_graph, create_graph)
    222 
    223     def register_hook(self, hook):

~/.local/miniconda3/envs/test_pot/lib/python3.8/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
    128         retain_graph = create_graph
    129 
--> 130     Variable._execution_engine.run_backward(
    131         tensors, grad_tensors_, retain_graph, create_graph,
    132         allow_unreachable=True)  # allow_unreachable flag

~/.local/miniconda3/envs/test_pot/lib/python3.8/site-packages/torch/autograd/function.py in apply(self, *args)
     87     def apply(self, *args):
     88         # _forward_cls is defined by derived class
---> 89         return self._forward_cls.backward(self, *args)  # type: ignore
     90 
     91 

~/.local/miniconda3/envs/test_pot/lib/python3.8/site-packages/ot/backend.py in backward(ctx, grad_output)
   1381             def backward(ctx, grad_output):
   1382                 # the gradients are grad
-> 1383                 return (None, None) + tuple(g * grad_output for g in ctx.grads)
   1384 
   1385         self.ValFunction = ValFunction

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
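Judging from the last frame, the failure happens when the gradients cached in ctx.grads are multiplied by grad_output, so one of the two sides apparently ends up on the CPU while the other is on cuda:0. Assuming it is the cached gradients that land on the CPU (an unverified guess on my side), a local edit of the backward shown in the traceback, moving each cached gradient onto grad_output's device before the multiplication, should avoid the crash:

# Hypothetical local workaround for the backward at ot/backend.py line 1383
# (the frame above, POT 0.8.1); a sketch only, under the assumption that the
# tensors in ctx.grads are the ones stuck on the CPU.
def backward(ctx, grad_output):
    # move each cached gradient onto grad_output's device before multiplying
    return (None, None) + tuple(
        g.to(grad_output.device) * grad_output for g in ctx.grads
    )

In the meantime the snippet seems to run when device is forced to torch.device("cpu"), which also suggests that only the CUDA path is affected.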
Environment
PyTorch: 1.7.0
POT: 0.8.1
CUDA: 10.1 on an NVIDIA Tesla P100