ot.solve uses GPU even though tensors are on CPU? #612

Open
@mathurinm

Description

Describe the bug

Running ot.solve with tensors on the CPU allocates memory on the GPU (this is documented in get_backend_list), but it also seems to use the GPU, since the GPU power draw (in watts) increases. See the attached screencast:
Screencast from 08-03-2024 11:24:44.webm

Is this expected?

Script

import torch
import ot

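n_samples = 5_000

# Two 2-D point clouds, created on the CPU (torch's default device)
x = torch.randn(n_samples, 2)
y = torch.randn(n_samples, 2)

# Random histograms normalized to sum to 1
a = torch.rand(n_samples)
a /= a.sum()
b = torch.rand(n_samples)
b /= b.sum()

# Pairwise squared Euclidean cost matrix (also on the CPU)
M = ot.dist(x, y)

# Entropy-regularized OT problem
res = ot.solve(M, a, b, reg=0.1, reg_type="entropy")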