Describe the bug
Running ot.solve with tensors on the CPU allocates memory on the GPU (this is documented in get_backend_list), but it also seems to use the GPU, since the GPU's power draw (in Watts) increases. See the attached screencast:
Screencast from 08-03-2024 11:24:44.webm
Is this expected behavior?
Script
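```python
import torch
import ot

n_samples = 5_000

# All tensors are created on the CPU (torch's default device)
x = torch.randn(n_samples, 2)
y = torch.randn(n_samples, 2)

# Random weights normalized to sum to 1
a = torch.rand(n_samples)
a /= a.sum()
b = torch.rand(n_samples)
b /= b.sum()

# Cost matrix, then entropic OT solved on the CPU inputs
M = ot.dist(x, y)
res = ot.solve(M, a, b, reg=0.1, reg_type="entropy")
```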
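As a possible diagnostic (a sketch, not part of the original report), the script above can be instrumented to record CUDA memory as PyTorch reports it before importing POT, after importing it, and after ot.solve. The helper names cuda_memory_snapshot and diagnose are hypothetical. The sketch assumes torch.cuda.memory_allocated and torch.cuda.memory_reserved are a reasonable proxy for POT's GPU allocations; they only count memory managed by PyTorch's caching allocator, not the CUDA context itself, and nothing here measures power draw. It returns None when torch, POT, or CUDA is unavailable.

```python
import importlib.util


def cuda_memory_snapshot():
    """Return (allocated, reserved) CUDA bytes, or None if torch or CUDA is unavailable."""
    if importlib.util.find_spec("torch") is None:
        return None
    import torch

    if not torch.cuda.is_available():
        return None
    return torch.cuda.memory_allocated(), torch.cuda.memory_reserved()


def diagnose(n_samples=5_000):
    """Run the reproduction script and record CUDA memory at each stage.

    Returns None when torch or POT (the `ot` package) is not installed.
    """
    if importlib.util.find_spec("torch") is None or importlib.util.find_spec("ot") is None:
        return None
    report = {"before_import_ot": cuda_memory_snapshot()}

    import torch
    import ot

    # Importing POT may instantiate its available backends (see get_backend_list)
    report["after_import_ot"] = cuda_memory_snapshot()

    # Same inputs as the reproduction script, all created on the CPU
    x = torch.randn(n_samples, 2)
    y = torch.randn(n_samples, 2)
    a = torch.rand(n_samples)
    a /= a.sum()
    b = torch.rand(n_samples)
    b /= b.sum()
    M = ot.dist(x, y)
    res = ot.solve(M, a, b, reg=0.1, reg_type="entropy")

    report["after_solve"] = cuda_memory_snapshot()
    # Which backend POT selected for M, and which device the resulting plan lives on
    report["backend"] = type(ot.backend.get_backend(M)).__name__
    report["plan_device"] = str(res.plan.device)
    return report
```

On a machine with a GPU, calling diagnose() and seeing a nonzero after_import_ot value together with plan_device == "cpu" would suggest that the GPU memory comes from backend instantiation at import time rather than from the solve itself; it would not by itself explain the increased power draw.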