From e43735a939f6957576d9d5b51256e25290cd26df Mon Sep 17 00:00:00 2001 From: Oleksii Kachaiev Date: Tue, 5 Sep 2023 20:13:32 +0200 Subject: [PATCH 01/11] Decouple backend registration from backend instance setup --- ot/backend.py | 111 ++++++++++++++++++++++++++++++++------------------ 1 file changed, 71 insertions(+), 40 deletions(-) diff --git a/ot/backend.py b/ot/backend.py index 7b2fe875f..af516e288 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -87,43 +87,67 @@ # License: MIT License import numpy as np +import os import scipy import scipy.linalg -import scipy.special as special from scipy.sparse import issparse, coo_matrix, csr_matrix -import warnings +import scipy.special as special import time +import warnings + -try: - import torch - torch_type = torch.Tensor -except ImportError: +DISABLE_TORCH_KEY = 'POT_BACKEND_DISABLE_PYTORCH' +DISABLE_JAX_KEY = 'POT_BACKEND_DISABLE_JAX' +DISABLE_CUPY_KEY = 'POT_BACKEND_DISABLE_CUPY' +DISABLE_TF_KEY = 'POT_BACKEND_DISABLE_TENSORFLOW' + + +if not os.environ.get(DISABLE_TORCH_KEY, False): + try: + import torch + torch_type = torch.Tensor + except ImportError: + torch = False + torch_type = float +else: torch = False torch_type = float -try: - import jax - import jax.numpy as jnp - import jax.scipy.special as jspecial - from jax.lib import xla_bridge - jax_type = jax.numpy.ndarray -except ImportError: +if not os.environ.get(DISABLE_JAX_KEY, False): + try: + import jax + import jax.numpy as jnp + import jax.scipy.special as jspecial + from jax.lib import xla_bridge + jax_type = jax.numpy.ndarray + except ImportError: + jax = False + jax_type = float +else: jax = False jax_type = float -try: - import cupy as cp - import cupyx - cp_type = cp.ndarray -except ImportError: +if not os.environ.get(DISABLE_CUPY_KEY, False): + try: + import cupy as cp + import cupyx + cp_type = cp.ndarray + except ImportError: + cp = False + cp_type = float +else: cp = False cp_type = float -try: - import tensorflow as tf - import tensorflow.experimental.numpy as tnp - tf_type = tf.Tensor -except ImportError: +if not os.environ.get(DISABLE_TF_KEY, False): + try: + import tensorflow as tf + import tensorflow.experimental.numpy as tnp + tf_type = tf.Tensor + except ImportError: + tf = False + tf_type = float +else: tf = False tf_type = float @@ -132,26 +156,33 @@ # Mapping between argument types and the existing backend -_BACKENDS = [] +_BACKEND_IMPLEMENTATIONS = [] +_BACKENDS = {} -def register_backend(backend): - _BACKENDS.append(backend) +def register_backend_implementation(backend_impl): + _BACKEND_IMPLEMENTATIONS.append(backend_impl) def get_backend_list(): """Returns the list of available backends""" - return _BACKENDS + return list(_BACKENDS.values()) + + +def _get_backend_instance(backend_impl): + if backend_impl.__name__ not in _BACKENDS: + _BACKENDS[backend_impl.__name__] = backend_impl() + return _BACKENDS[backend_impl.__name__] -def _check_args_backend(backend, args): - is_instance = set(isinstance(a, backend.__type__) for a in args) +def _check_args_backend(backend_impl, args): + is_instance = set(isinstance(arg, backend_impl.__type__) for arg in args) # check that all arguments matched or not the type if len(is_instance) == 1: return is_instance.pop() - # Oterwise return an error - raise ValueError(str_type_error.format([type(a) for a in args])) + # Otherwise return an error + raise ValueError(str_type_error.format([type(arg) for arg in args])) def get_backend(*args): @@ -160,12 +191,12 @@ def get_backend(*args): Also raises TypeError if all arrays are not from the same backend """ # check that some arrays given - if not len(args) > 0: + if len(args) == 0: raise ValueError(" The function takes at least one parameter") - for backend in _BACKENDS: - if _check_args_backend(backend, args): - return backend + for backend_impl in _BACKEND_IMPLEMENTATIONS: + if _check_args_backend(backend_impl, args): + return _get_backend_instance(backend_impl) raise ValueError("Unknown type of non implemented backend.") @@ -1337,7 +1368,7 @@ def matmul(self, a, b): return np.matmul(a, b) -register_backend(NumpyBackend()) +register_backend_implementation(NumpyBackend) class JaxBackend(Backend): @@ -1706,7 +1737,7 @@ def matmul(self, a, b): if jax: # Only register jax backend if it is installed - register_backend(JaxBackend()) + register_backend_implementation(JaxBackend) class TorchBackend(Backend): @@ -2189,7 +2220,7 @@ def matmul(self, a, b): if torch: # Only register torch backend if it is installed - register_backend(TorchBackend()) + register_backend_implementation(TorchBackend) class CupyBackend(Backend): # pragma: no cover @@ -2582,7 +2613,7 @@ def matmul(self, a, b): if cp: # Only register cp backend if it is installed - register_backend(CupyBackend()) + register_backend_implementation(CupyBackend) class TensorflowBackend(Backend): @@ -2995,4 +3026,4 @@ def matmul(self, a, b): if tf: # Only register tensorflow backend if it is installed - register_backend(TensorflowBackend()) + register_backend_implementation(TensorflowBackend) From e69fcfb023f332b385cbeffd8539d4e80a2a010b Mon Sep 17 00:00:00 2001 From: Oleksii Kachaiev Date: Tue, 5 Sep 2023 20:22:31 +0200 Subject: [PATCH 02/11] Make sure list of backends is properly loaded in tests --- ot/backend.py | 25 +++++++++++++++---------- test/conftest.py | 30 +++++++++++++++++++++++++++--- 2 files changed, 42 insertions(+), 13 deletions(-) diff --git a/ot/backend.py b/ot/backend.py index af516e288..e21b7c09a 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -160,15 +160,20 @@ _BACKENDS = {} -def register_backend_implementation(backend_impl): - _BACKEND_IMPLEMENTATIONS.append(backend_impl) - - def get_backend_list(): - """Returns the list of available backends""" + """Returns the list of already instantiated backends.""" return list(_BACKENDS.values()) +def get_available_backend_implementations(): + """Returns the list of available backend implementations.""" + return _BACKEND_IMPLEMENTATIONS + + +def _register_backend_implementation(backend_impl): + _BACKEND_IMPLEMENTATIONS.append(backend_impl) + + def _get_backend_instance(backend_impl): if backend_impl.__name__ not in _BACKENDS: _BACKENDS[backend_impl.__name__] = backend_impl() @@ -1368,7 +1373,7 @@ def matmul(self, a, b): return np.matmul(a, b) -register_backend_implementation(NumpyBackend) +_register_backend_implementation(NumpyBackend) class JaxBackend(Backend): @@ -1737,7 +1742,7 @@ def matmul(self, a, b): if jax: # Only register jax backend if it is installed - register_backend_implementation(JaxBackend) + _register_backend_implementation(JaxBackend) class TorchBackend(Backend): @@ -2220,7 +2225,7 @@ def matmul(self, a, b): if torch: # Only register torch backend if it is installed - register_backend_implementation(TorchBackend) + _register_backend_implementation(TorchBackend) class CupyBackend(Backend): # pragma: no cover @@ -2613,7 +2618,7 @@ def matmul(self, a, b): if cp: # Only register cp backend if it is installed - register_backend_implementation(CupyBackend) + _register_backend_implementation(CupyBackend) class TensorflowBackend(Backend): @@ -3026,4 +3031,4 @@ def matmul(self, a, b): if tf: # Only register tensorflow backend if it is installed - register_backend_implementation(TensorflowBackend) + _register_backend_implementation(TensorflowBackend) diff --git a/test/conftest.py b/test/conftest.py index c0db8abe2..5309a714c 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -4,19 +4,43 @@ # License: MIT License -import pytest -from ot.backend import jax, tf -from ot.backend import get_backend_list import functools +import pytest + +from ot.backend import ( + get_available_backend_implementations, + get_backend_list, + _get_backend_instance, + jax, + tf +) + if jax: from jax.config import config config.update("jax_enable_x64", True) if tf: + # make sure TF doesn't allocate entire GPU + import tensorflow as tf + physical_devices = tf.config.list_physical_devices('GPU') + for device in physical_devices: + try: + tf.config.experimental.set_memory_growth(device, True) + except Exception: + pass + + # allow numpy API for TF from tensorflow.python.ops.numpy_ops import np_config np_config.enable_numpy_behavior() + +# before taking list of backends, we need to make sure all +# available implementations are instantiated. looks somewhat hacky, +# but hopefully it won't be needed for a common library use +for backend_impl in get_available_backend_implementations(): + _get_backend_instance(backend_impl) + backend_list = get_backend_list() From b394350532dda152e87d8164f9f152807f89fdb8 Mon Sep 17 00:00:00 2001 From: Oleksii Kachaiev Date: Tue, 5 Sep 2023 20:32:11 +0200 Subject: [PATCH 03/11] Set JAX configuration to avoid pre-allocation in tests --- test/conftest.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/conftest.py b/test/conftest.py index 5309a714c..827be3656 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -5,6 +5,7 @@ # License: MIT License import functools +import os import pytest from ot.backend import ( @@ -19,6 +20,7 @@ if jax: from jax.config import config config.update("jax_enable_x64", True) + config.update("xla_python_client_preallocate", False) if tf: # make sure TF doesn't allocate entire GPU From 095ab3accbd2b1a7f972fd08b4ef42dc96f89821 Mon Sep 17 00:00:00 2001 From: Oleksii Kachaiev Date: Tue, 5 Sep 2023 20:32:42 +0200 Subject: [PATCH 04/11] Remove unused import --- test/conftest.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/conftest.py b/test/conftest.py index 827be3656..36b7ceac8 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -5,7 +5,6 @@ # License: MIT License import functools -import os import pytest from ot.backend import ( From c33416d1f5ba6310a550d4b94222ca57ff612479 Mon Sep 17 00:00:00 2001 From: Oleksii Kachaiev Date: Tue, 5 Sep 2023 20:48:41 +0200 Subject: [PATCH 05/11] Documentation section about backend support --- docs/source/quickstart.rst | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst index cd41a95d4..32422f983 100644 --- a/docs/source/quickstart.rst +++ b/docs/source/quickstart.rst @@ -961,6 +961,12 @@ List of compatible Backends - `Tensorflow `_ (all outputs differentiable w.r.t. inputs) - `Cupy `_ (no differentiation, GPU only) +Library automatically detect which backend are available to be used. Backend is instantiated lazily +only when needed to avoid unnecessary GPU memory allocations. It's also possible to disable import +of the corresponding backend library (for example, to speedup library import) with environment variable +`POT_BACKEND_DISABLE_` (e.g. `POT_BACKEND_DISABLE_TENSORFLOW`). Note that `numpy` backend +cannot be disabled. + List of compatible modules ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ From 5241f8bac8cc35d6ae3a9e13f8073d3e7ae024e4 Mon Sep 17 00:00:00 2001 From: Oleksii Kachaiev Date: Tue, 5 Sep 2023 20:53:50 +0200 Subject: [PATCH 06/11] Use enviornment variable for XLA --- test/conftest.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/conftest.py b/test/conftest.py index 36b7ceac8..37e41e44e 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -5,6 +5,7 @@ # License: MIT License import functools +import os import pytest from ot.backend import ( @@ -17,9 +18,9 @@ if jax: + os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false' from jax.config import config config.update("jax_enable_x64", True) - config.update("xla_python_client_preallocate", False) if tf: # make sure TF doesn't allocate entire GPU From 1a694bbfb801c718cb7ec83ed778ac291d4de8db Mon Sep 17 00:00:00 2001 From: Oleksii Kachaiev Date: Tue, 5 Sep 2023 21:52:18 +0200 Subject: [PATCH 07/11] Rewrite doc on backends --- docs/source/quickstart.rst | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst index 32422f983..8f888ed14 100644 --- a/docs/source/quickstart.rst +++ b/docs/source/quickstart.rst @@ -961,11 +961,12 @@ List of compatible Backends - `Tensorflow `_ (all outputs differentiable w.r.t. inputs) - `Cupy `_ (no differentiation, GPU only) -Library automatically detect which backend are available to be used. Backend is instantiated lazily -only when needed to avoid unnecessary GPU memory allocations. It's also possible to disable import -of the corresponding backend library (for example, to speedup library import) with environment variable -`POT_BACKEND_DISABLE_` (e.g. `POT_BACKEND_DISABLE_TENSORFLOW`). Note that `numpy` backend -cannot be disabled. +The library automatically detects which backends are available for use. A backend +is instantiated lazily only when necessary to prevent unwarranted GPU memory allocations. +You can also disable the import of a specific backend library (e.g., to accelerate +loading of `ot` library) using the environment variable `POT_BACKEND_DISABLE_`. +For instance, to disable TensorFlow, set `export POT_BACKEND_DISABLE_TENSORFLOW=1`. +It's important to note that the `numpy` backend cannot be disabled. List of compatible modules From 6f4be00fc19a2c4a2b418eeb845e33211002f6fb Mon Sep 17 00:00:00 2001 From: Oleksii Kachaiev Date: Wed, 6 Sep 2023 16:58:38 +0200 Subject: [PATCH 08/11] Update get_backend_list to instantiate all backend objects --- ot/backend.py | 33 +++++++++++++++++++++++---------- test/conftest.py | 14 +------------- 2 files changed, 24 insertions(+), 23 deletions(-) diff --git a/ot/backend.py b/ot/backend.py index e21b7c09a..084dc256a 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -160,16 +160,6 @@ _BACKENDS = {} -def get_backend_list(): - """Returns the list of already instantiated backends.""" - return list(_BACKENDS.values()) - - -def get_available_backend_implementations(): - """Returns the list of available backend implementations.""" - return _BACKEND_IMPLEMENTATIONS - - def _register_backend_implementation(backend_impl): _BACKEND_IMPLEMENTATIONS.append(backend_impl) @@ -190,6 +180,29 @@ def _check_args_backend(backend_impl, args): raise ValueError(str_type_error.format([type(arg) for arg in args])) +def get_backend_list(): + """Returns instances of all available backends. + + Note that the function forces all detected implementations + to be instantiated even if specific backend was not use before. + Be careful as instantiation of the backend might lead to side effects, + like GPU memory pre-allocation. See the documentation for more details. + If you only need to know which implementations are available, + use `:py:func:`ot.backend.get_available_backend_implementations`, + which does not force instance of the backend object to be created. + """ + return [ + _get_backend_instance(backend_impl) + for backend_impl + in get_available_backend_implementations() + ] + + +def get_available_backend_implementations(): + """Returns the list of available backend implementations.""" + return _BACKEND_IMPLEMENTATIONS + + def get_backend(*args): """Returns the proper backend for a list of input arrays diff --git a/test/conftest.py b/test/conftest.py index 37e41e44e..0303ed9f2 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -8,13 +8,7 @@ import os import pytest -from ot.backend import ( - get_available_backend_implementations, - get_backend_list, - _get_backend_instance, - jax, - tf -) +from ot.backend import get_backend_list, jax, tf if jax: @@ -37,12 +31,6 @@ np_config.enable_numpy_behavior() -# before taking list of backends, we need to make sure all -# available implementations are instantiated. looks somewhat hacky, -# but hopefully it won't be needed for a common library use -for backend_impl in get_available_backend_implementations(): - _get_backend_instance(backend_impl) - backend_list = get_backend_list() From 9571538be487d59ce53abc17852bac72b12cdf1f Mon Sep 17 00:00:00 2001 From: Oleksii Kachaiev Date: Wed, 6 Sep 2023 17:02:35 +0200 Subject: [PATCH 09/11] Add fix to the changelog --- RELEASES.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/RELEASES.md b/RELEASES.md index d0209e233..b8ec668d9 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -1,5 +1,11 @@ # Releases +## 0.9.2 + +#### Closed issues +- Lazily instantiate backends to avoid unnecessary GPU memory pre-allocations on package import (PR #520) + + ## 0.9.1 *August 2023* From 207a6937a0cfe79b8bf0efcec923ea01177bdc0a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Wed, 6 Sep 2023 17:36:48 +0200 Subject: [PATCH 10/11] Update RELEASES.md --- RELEASES.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/RELEASES.md b/RELEASES.md index b8ec668d9..3c635eb1b 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -3,7 +3,7 @@ ## 0.9.2 #### Closed issues -- Lazily instantiate backends to avoid unnecessary GPU memory pre-allocations on package import (PR #520) +- Lazily instantiate backends to avoid unnecessary GPU memory pre-allocations on package import (Issue #516, PR #520) ## 0.9.1 From 67a6129ccee6aeeac7a752a21e5c81291687ba37 Mon Sep 17 00:00:00 2001 From: Oleksii Kachaiev Date: Thu, 21 Sep 2023 10:49:06 +0200 Subject: [PATCH 11/11] Update docs/source/quickstart.rst MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Rémi Flamary --- docs/source/quickstart.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst index 8f888ed14..1f1c69398 100644 --- a/docs/source/quickstart.rst +++ b/docs/source/quickstart.rst @@ -964,7 +964,7 @@ List of compatible Backends The library automatically detects which backends are available for use. A backend is instantiated lazily only when necessary to prevent unwarranted GPU memory allocations. You can also disable the import of a specific backend library (e.g., to accelerate -loading of `ot` library) using the environment variable `POT_BACKEND_DISABLE_`. +loading of `ot` library) using the environment variable `POT_BACKEND_DISABLE_` with in (TORCH,TENSORFLOW,CUPY,JAX). For instance, to disable TensorFlow, set `export POT_BACKEND_DISABLE_TENSORFLOW=1`. It's important to note that the `numpy` backend cannot be disabled.