Description
It seems that a DistributedDataParallel (DDP) PyTorch setup is not supported in OT, specifically for the emd2 computation.
Are there any workaround ideas for making this work, or any example of a multi-GPU setup with OT?
Ideally, I would like to get OT working with this PyTorch DDP setup:
https://github.com/pytorch/examples/blob/main/distributed/ddp/main.py
Many thanks
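
For context, this is roughly what each DDP process runs when the failure occurs (the data, sizes, and LOCAL_RANK handling here are only illustrative, not my exact code):

```python
import os
import torch
import torch.distributed as dist
import ot

def run():
    # Standard torchrun-style setup: one process per GPU.
    local_rank = int(os.environ["LOCAL_RANK"])
    dist.init_process_group(backend="nccl")
    device = torch.device(f"cuda:{local_rank}")

    # Two small point clouds on this rank's GPU (illustrative random data).
    xs = torch.randn(64, 2, device=device)
    xt = torch.randn(64, 2, device=device)
    a = torch.full((64,), 1.0 / 64, device=device)  # uniform source weights
    b = torch.full((64,), 1.0 / 64, device=device)  # uniform target weights

    M = ot.dist(xs, xt)       # cost matrix, stays on this rank's device
    loss = ot.emd2(a, b, M)   # fails inside get_backend() as shown below
    print(local_rank, loss.item())

if __name__ == "__main__":
    run()
```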
Example of the failing call in the DDP setup:
ot.emd2(a, b, dist)
File "/python3.8/site-packages/ot/lp/__init__.py", line 468, in emd2
nx = get_backend(M0, a0, b0)
File "/python3.8/site-packages/ot/backend.py", line 168, in get_backend
return TorchBackend()
File "/python3.8/site-packages/ot/backend.py", line 1517, in __init__
self.__type_list__.append(torch.tensor(1, dtype=torch.float32, device='cuda'))
RuntimeError: CUDA error: all CUDA-capable devices are busy or unavailable
My current workaround is changing
self.__type_list__.append(torch.tensor(1, dtype=torch.float32, device='cuda'))
to
self.__type_list__.append(torch.tensor(1, dtype=torch.float32, device=device_id))
i.e. passing the device id into the backend and rebuilding OT from source.
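
A possibly less invasive alternative (I have not tested it extensively) would be to pin each process to its own GPU before the first OT call, so that the plain 'cuda' device string inside TorchBackend resolves to that rank's device instead of a GPU the process cannot access:

```python
import os
import torch
import ot

local_rank = int(os.environ["LOCAL_RANK"])  # set by torchrun

# Make device='cuda' resolve to this rank's GPU for everything that follows,
# including the tensors TorchBackend() creates in ot.backend.get_backend().
torch.cuda.set_device(local_rank)

# ... then build a, b, M on torch.device(f"cuda:{local_rank}") and call
# ot.emd2(a, b, M) as before, without patching backend.py.
```

If that works, the patched line would not be needed, assuming the backend's 'cuda' tensor follows the process's current device.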