
Commit 5ab00dd

kachayev and rflamary authored
[Fix] Lazily instantiate backends to avoid unnecessary GPU memory pre-allocations (#520)
* Decouple backend registration from backend instance setup
* Make sure list of backends is properly loaded in tests
* Set JAX configuration to avoid pre-allocation in tests
* Remove unused import
* Documentation section about backend support
* Use environment variable for XLA
* Rewrite doc on backends
* Update get_backend_list to instantiate all backend objects
* Add fix to the changelog
* Update RELEASES.md
* Update docs/source/quickstart.rst

Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
1 parent 4cf4492 commit 5ab00dd

File tree

4 files changed: +115 / -44 lines


RELEASES.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -8,6 +8,7 @@

 #### Closed issues
 - Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504)
+- Lazily instantiate backends to avoid unnecessary GPU memory pre-allocations on package import (Issue #516, PR #520)


 ## 0.9.1
```

docs/source/quickstart.rst

Lines changed: 7 additions & 0 deletions
```diff
@@ -961,6 +961,13 @@ List of compatible Backends
 - `Tensorflow <https://www.tensorflow.org/>`_ (all outputs differentiable w.r.t. inputs)
 - `Cupy <https://cupy.dev/>`_ (no differentiation, GPU only)

+The library automatically detects which backends are available for use. A backend
+is instantiated lazily only when necessary, to prevent unwarranted GPU memory allocations.
+You can also disable the import of a specific backend library (e.g., to accelerate
+the loading of the `ot` library) using the environment variable `POT_BACKEND_DISABLE_<NAME>`,
+with <NAME> in (TORCH, TENSORFLOW, CUPY, JAX).
+For instance, to disable TensorFlow, set `export POT_BACKEND_DISABLE_TENSORFLOW=1`.
+Note that the `numpy` backend cannot be disabled.
+

 List of compatible modules
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
```
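To make the documented variable concrete, here is a minimal sketch of disabling one backend from Python rather than the shell; the key point (grounded in the guarded imports below) is that the variable must be set before the first `import ot`:

```python
# Sketch: disable the TensorFlow backend import before loading POT.
# The variable must be set before the first `import ot`, because the
# guarded imports in ot/backend.py run at module import time.
import os
os.environ['POT_BACKEND_DISABLE_TENSORFLOW'] = '1'

from ot.backend import tf  # noqa: E402

print(tf)  # False when the import was disabled (or tensorflow is absent)
```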

ot/backend.py

Lines changed: 90 additions & 41 deletions
```diff
@@ -87,43 +87,67 @@
 # License: MIT License

 import numpy as np
+import os
 import scipy
 import scipy.linalg
-import scipy.special as special
 from scipy.sparse import issparse, coo_matrix, csr_matrix
-import warnings
+import scipy.special as special
 import time
+import warnings
+
+
+DISABLE_TORCH_KEY = 'POT_BACKEND_DISABLE_PYTORCH'
+DISABLE_JAX_KEY = 'POT_BACKEND_DISABLE_JAX'
+DISABLE_CUPY_KEY = 'POT_BACKEND_DISABLE_CUPY'
+DISABLE_TF_KEY = 'POT_BACKEND_DISABLE_TENSORFLOW'
+

-try:
-    import torch
-    torch_type = torch.Tensor
-except ImportError:
+if not os.environ.get(DISABLE_TORCH_KEY, False):
+    try:
+        import torch
+        torch_type = torch.Tensor
+    except ImportError:
+        torch = False
+        torch_type = float
+else:
     torch = False
     torch_type = float

-try:
-    import jax
-    import jax.numpy as jnp
-    import jax.scipy.special as jspecial
-    from jax.lib import xla_bridge
-    jax_type = jax.numpy.ndarray
-except ImportError:
+if not os.environ.get(DISABLE_JAX_KEY, False):
+    try:
+        import jax
+        import jax.numpy as jnp
+        import jax.scipy.special as jspecial
+        from jax.lib import xla_bridge
+        jax_type = jax.numpy.ndarray
+    except ImportError:
+        jax = False
+        jax_type = float
+else:
     jax = False
     jax_type = float

-try:
-    import cupy as cp
-    import cupyx
-    cp_type = cp.ndarray
-except ImportError:
+if not os.environ.get(DISABLE_CUPY_KEY, False):
+    try:
+        import cupy as cp
+        import cupyx
+        cp_type = cp.ndarray
+    except ImportError:
+        cp = False
+        cp_type = float
+else:
     cp = False
     cp_type = float

-try:
-    import tensorflow as tf
-    import tensorflow.experimental.numpy as tnp
-    tf_type = tf.Tensor
-except ImportError:
+if not os.environ.get(DISABLE_TF_KEY, False):
+    try:
+        import tensorflow as tf
+        import tensorflow.experimental.numpy as tnp
+        tf_type = tf.Tensor
+    except ImportError:
+        tf = False
+        tf_type = float
+else:
     tf = False
     tf_type = float
```
```diff
@@ -132,26 +156,51 @@


 # Mapping between argument types and the existing backend
-_BACKENDS = []
+_BACKEND_IMPLEMENTATIONS = []
+_BACKENDS = {}


-def register_backend(backend):
-    _BACKENDS.append(backend)
+def _register_backend_implementation(backend_impl):
+    _BACKEND_IMPLEMENTATIONS.append(backend_impl)


-def get_backend_list():
-    """Returns the list of available backends"""
-    return _BACKENDS
+def _get_backend_instance(backend_impl):
+    if backend_impl.__name__ not in _BACKENDS:
+        _BACKENDS[backend_impl.__name__] = backend_impl()
+    return _BACKENDS[backend_impl.__name__]


-def _check_args_backend(backend, args):
-    is_instance = set(isinstance(a, backend.__type__) for a in args)
+def _check_args_backend(backend_impl, args):
+    is_instance = set(isinstance(arg, backend_impl.__type__) for arg in args)
     # check that all arguments matched or not the type
     if len(is_instance) == 1:
         return is_instance.pop()

-    # Oterwise return an error
-    raise ValueError(str_type_error.format([type(a) for a in args]))
+    # Otherwise return an error
+    raise ValueError(str_type_error.format([type(arg) for arg in args]))
+
+
+def get_backend_list():
+    """Returns instances of all available backends.
+
+    Note that this function forces all detected implementations
+    to be instantiated, even if a specific backend was not used before.
+    Be careful, as instantiating a backend might lead to side effects,
+    like GPU memory pre-allocation. See the documentation for more details.
+    If you only need to know which implementations are available,
+    use :py:func:`ot.backend.get_available_backend_implementations`,
+    which does not force instances of the backend objects to be created.
+    """
+    return [
+        _get_backend_instance(backend_impl)
+        for backend_impl
+        in get_available_backend_implementations()
+    ]
+
+
+def get_available_backend_implementations():
+    """Returns the list of available backend implementations."""
+    return _BACKEND_IMPLEMENTATIONS


 def get_backend(*args):
```
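The split between registration and instantiation gives two inspection paths; a short sketch (assuming at least the always-available NumPy backend) of how they differ:

```python
import ot.backend as backend

# Lists registered implementation classes without creating instances,
# so no GPU memory can be touched here.
impls = backend.get_available_backend_implementations()
print([impl.__name__ for impl in impls])  # always includes 'NumpyBackend'

# Forces every implementation to be instantiated (and cached), which may
# trigger side effects such as GPU memory pre-allocation.
instances = backend.get_backend_list()
```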
```diff
@@ -167,9 +216,9 @@ def get_backend(*args):
     if not len(args) > 0:
         raise ValueError(" The function takes at least one (non-None) parameter")

-    for backend in _BACKENDS:
-        if _check_args_backend(backend, args):
-            return backend
+    for backend_impl in _BACKEND_IMPLEMENTATIONS:
+        if _check_args_backend(backend_impl, args):
+            return _get_backend_instance(backend_impl)

     raise ValueError("Unknown type of non implemented backend.")
```
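A usage sketch of the dispatch path, showing that a matching backend is built on first request and then served from the cache:

```python
import numpy as np
import ot.backend as backend

a = np.random.rand(5)
nx = backend.get_backend(a)   # instantiates NumpyBackend on first use
nx2 = backend.get_backend(a)  # subsequent calls reuse the cached instance
print(type(nx).__name__, nx is nx2)  # NumpyBackend True
```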

```diff
@@ -1341,7 +1390,7 @@ def matmul(self, a, b):
         return np.matmul(a, b)


-register_backend(NumpyBackend())
+_register_backend_implementation(NumpyBackend)


 class JaxBackend(Backend):
@@ -1710,7 +1759,7 @@ def matmul(self, a, b):

 if jax:
     # Only register jax backend if it is installed
-    register_backend(JaxBackend())
+    _register_backend_implementation(JaxBackend)


 class TorchBackend(Backend):
@@ -2193,7 +2242,7 @@ def matmul(self, a, b):

 if torch:
     # Only register torch backend if it is installed
-    register_backend(TorchBackend())
+    _register_backend_implementation(TorchBackend)


 class CupyBackend(Backend):  # pragma: no cover
@@ -2586,7 +2635,7 @@ def matmul(self, a, b):

 if cp:
     # Only register cp backend if it is installed
-    register_backend(CupyBackend())
+    _register_backend_implementation(CupyBackend)


 class TensorflowBackend(Backend):
@@ -3006,4 +3055,4 @@ def matmul(self, a, b):

 if tf:
     # Only register tensorflow backend if it is installed
-    register_backend(TensorflowBackend())
+    _register_backend_implementation(TensorflowBackend)
```
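Because registration now stores classes instead of instances, a plain import leaves the instance cache empty; a quick check of that behavior, inspecting the private cache purely for illustration:

```python
import numpy as np
import ot.backend as backend

# Immediately after import, implementations are registered but the
# instance cache is still empty: nothing has been constructed yet.
print(backend._BACKENDS)  # {}

backend.get_backend(np.zeros(3))
print(list(backend._BACKENDS))  # ['NumpyBackend']
```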

test/conftest.py

Lines changed: 17 additions & 3 deletions
```diff
@@ -4,19 +4,33 @@

 # License: MIT License

-import pytest
-from ot.backend import jax, tf
-from ot.backend import get_backend_list
 import functools
+import os
+import pytest
+
+from ot.backend import get_backend_list, jax, tf
+

 if jax:
+    os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
     from jax.config import config
     config.update("jax_enable_x64", True)

 if tf:
+    # make sure TF doesn't allocate entire GPU
+    import tensorflow as tf
+    physical_devices = tf.config.list_physical_devices('GPU')
+    for device in physical_devices:
+        try:
+            tf.config.experimental.set_memory_growth(device, True)
+        except Exception:
+            pass
+
+    # allow numpy API for TF
     from tensorflow.python.ops.numpy_ops import np_config
     np_config.enable_numpy_behavior()

+
 backend_list = get_backend_list()
```
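The two GPU-memory knobs used in this conftest are not test-specific; a minimal sketch (assuming JAX and TensorFlow are installed with GPU support) of applying the same settings in user code:

```python
import os

# JAX: must be set before JAX initializes its XLA client, otherwise the
# default behavior pre-allocates most of the GPU memory up front.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'

import tensorflow as tf  # noqa: E402

# TensorFlow: grow GPU memory on demand instead of claiming it all.
for device in tf.config.list_physical_devices('GPU'):
    tf.config.experimental.set_memory_growth(device, True)
```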
