Skip to content

`ot.gromov.gromov_wasserstein2` loss fails to backpropagate with torch CUDA tensors #351

Closed
@tbng

Description

@tbng

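As the title says, calling `loss.backward()` on the loss returned by `ot.gromov.gromov_wasserstein2` fails when the inputs are CUDA tensors. The following short snippet reproduces the error.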

import numpy as np
import ot
import torch
from ot.gromov import gromov_wasserstein2


def gw_pytorch_exam(C1, C2, a1, a2, device, n_iter=1000, lr=1e-2, loss_fn=None):
    """Optimise ``C1`` by projected gradient descent on a GW-type loss.

    ``C1`` is turned into a leaf tensor that requires grad. At each iteration
    it takes a plain gradient step of size ``lr`` and is then clipped back
    into ``[0, 1]``. ``C2``, ``a1`` and ``a2`` stay fixed throughout.

    Parameters
    ----------
    C1, C2 : array_like
        Initial matrix to optimise and the fixed second matrix.
    a1, a2 : array_like
        Weight vectors forwarded unchanged to the loss.
    device : torch.device
        Device on which every tensor is created.
    n_iter : int, optional
        Number of gradient steps (default 1000).
    lr : float, optional
        Step size (default 1e-2).
    loss_fn : callable, optional
        ``loss_fn(C1, C2, a1, a2)`` must return a scalar tensor that is
        differentiable w.r.t. ``C1``. Defaults to
        ``ot.gromov.gromov_wasserstein2``.

    Returns
    -------
    torch.Tensor
        The optimised ``C1``: the same leaf tensor, still with
        ``requires_grad=True``, with its gradient zeroed.
    """
    if loss_fn is None:
        # Looked up lazily so callers supplying their own loss need not have
        # the module-level POT import available.
        loss_fn = gromov_wasserstein2

    C1_torch = torch.tensor(C1, device=device, requires_grad=True)
    C2_torch = torch.tensor(C2, device=device)
    a1_torch = torch.tensor(a1, device=device)
    a2_torch = torch.tensor(a2, device=device)

    for _ in range(n_iter):
        # NOTE: with POT 0.8.1 and CUDA inputs, this backward() call raises
        # the device-mismatch RuntimeError reported in this issue.
        loss = loss_fn(C1_torch, C2_torch, a1_torch, a2_torch)
        loss.backward()
        with torch.no_grad():
            C1_torch -= C1_torch.grad * lr
            C1_torch.grad.zero_()
            # Project onto [0, 1] in place rather than reassigning ``.data``,
            # which would bypass autograd's version tracking.
            C1_torch.clamp_(0, 1)

    return C1_torch


# Use the first CUDA GPU when one is visible, otherwise fall back to the CPU
# (maybe the CPU fallback should be disabled to force GPU usage).
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

n = 10
C1 = np.eye(n)  # initial matrix handed to the optimiser
C2 = np.random.randn(n, n)  # fixed random second matrix
a = ot.unif(n)  # the same uniform weights are used for both sides
C1 = gw_pytorch_exam(C1, C2, a, a, device)

Running this code raises the following `RuntimeError`:

     36 a = ot.unif(n)
---> 37 C1 = gw_pytorch_exam(C1, C2, a, a, device)

<ipython-input-3-afc0d26d5054> in gw_pytorch_exam(C1, C2, a1, a2, device, n_iter, lr)
     16     for i in range(n_iter):
     17         loss = gromov_wasserstein2(C1_torch, C2_torch, a1_torch, a2_torch)
---> 18         loss.backward()
     19         with torch.no_grad():
     20             grad = C1_torch.grad

~/.local/miniconda3/envs/test_pot/lib/python3.8/site-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph)
    219                 retain_graph=retain_graph,
    220                 create_graph=create_graph)
--> 221         torch.autograd.backward(self, gradient, retain_graph, create_graph)
    222 
    223     def register_hook(self, hook):

~/.local/miniconda3/envs/test_pot/lib/python3.8/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
    128         retain_graph = create_graph
    129 
--> 130     Variable._execution_engine.run_backward(
    131         tensors, grad_tensors_, retain_graph, create_graph,
    132         allow_unreachable=True)  # allow_unreachable flag

~/.local/miniconda3/envs/test_pot/lib/python3.8/site-packages/torch/autograd/function.py in apply(self, *args)
     87     def apply(self, *args):
     88         # _forward_cls is defined by derived class
---> 89         return self._forward_cls.backward(self, *args)  # type: ignore
     90 
     91 

~/.local/miniconda3/envs/test_pot/lib/python3.8/site-packages/ot/backend.py in backward(ctx, grad_output)
   1381             def backward(ctx, grad_output):
   1382                 # the gradients are grad
-> 1383                 return (None, None) + tuple(g * grad_output for g in ctx.grads)
   1384 
   1385         self.ValFunction = ValFunction

~/.local/miniconda3/envs/test_pot/lib/python3.8/site-packages/ot/backend.py in <genexpr>(.0)
   1381             def backward(ctx, grad_output):
   1382                 # the gradients are grad
-> 1383                 return (None, None) + tuple(g * grad_output for g in ctx.grads)
   1384 
   1385         self.ValFunction = ValFunction

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

PyTorch: 1.7.0
POT: 0.8.1
CUDA: 10.1 on NVIDIA Tesla P100

Metadata

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions