87
87
# License: MIT License
88
88
89
89
import numpy as np
90
+ import os
90
91
import scipy
91
92
import scipy .linalg
92
- import scipy .special as special
93
93
from scipy .sparse import issparse , coo_matrix , csr_matrix
94
- import warnings
94
+ import scipy . special as special
95
95
import time
96
+ import warnings
97
+
98
+
99
+ DISABLE_TORCH_KEY = 'POT_BACKEND_DISABLE_PYTORCH'
100
+ DISABLE_JAX_KEY = 'POT_BACKEND_DISABLE_JAX'
101
+ DISABLE_CUPY_KEY = 'POT_BACKEND_DISABLE_CUPY'
102
+ DISABLE_TF_KEY = 'POT_BACKEND_DISABLE_TENSORFLOW'
103
+
96
104
97
- try :
98
- import torch
99
- torch_type = torch .Tensor
100
- except ImportError :
105
+ if not os .environ .get (DISABLE_TORCH_KEY , False ):
106
+ try :
107
+ import torch
108
+ torch_type = torch .Tensor
109
+ except ImportError :
110
+ torch = False
111
+ torch_type = float
112
+ else :
101
113
torch = False
102
114
torch_type = float
103
115
104
- try :
105
- import jax
106
- import jax .numpy as jnp
107
- import jax .scipy .special as jspecial
108
- from jax .lib import xla_bridge
109
- jax_type = jax .numpy .ndarray
110
- except ImportError :
116
+ if not os .environ .get (DISABLE_JAX_KEY , False ):
117
+ try :
118
+ import jax
119
+ import jax .numpy as jnp
120
+ import jax .scipy .special as jspecial
121
+ from jax .lib import xla_bridge
122
+ jax_type = jax .numpy .ndarray
123
+ except ImportError :
124
+ jax = False
125
+ jax_type = float
126
+ else :
111
127
jax = False
112
128
jax_type = float
113
129
114
- try :
115
- import cupy as cp
116
- import cupyx
117
- cp_type = cp .ndarray
118
- except ImportError :
130
+ if not os .environ .get (DISABLE_CUPY_KEY , False ):
131
+ try :
132
+ import cupy as cp
133
+ import cupyx
134
+ cp_type = cp .ndarray
135
+ except ImportError :
136
+ cp = False
137
+ cp_type = float
138
+ else :
119
139
cp = False
120
140
cp_type = float
121
141
122
- try :
123
- import tensorflow as tf
124
- import tensorflow .experimental .numpy as tnp
125
- tf_type = tf .Tensor
126
- except ImportError :
142
+ if not os .environ .get (DISABLE_TF_KEY , False ):
143
+ try :
144
+ import tensorflow as tf
145
+ import tensorflow .experimental .numpy as tnp
146
+ tf_type = tf .Tensor
147
+ except ImportError :
148
+ tf = False
149
+ tf_type = float
150
+ else :
127
151
tf = False
128
152
tf_type = float
129
153
132
156
133
157
134
158
# Mapping between argument types and the existing backend
135
- _BACKENDS = []
159
+ _BACKEND_IMPLEMENTATIONS = []
160
+ _BACKENDS = {}
136
161
137
162
138
- def register_backend ( backend ):
139
- _BACKENDS .append (backend )
163
+ def _register_backend_implementation ( backend_impl ):
164
+ _BACKEND_IMPLEMENTATIONS .append (backend_impl )
140
165
141
166
142
- def get_backend_list ():
143
- """Returns the list of available backends"""
144
- return _BACKENDS
167
+ def _get_backend_instance (backend_impl ):
168
+ if backend_impl .__name__ not in _BACKENDS :
169
+ _BACKENDS [backend_impl .__name__ ] = backend_impl ()
170
+ return _BACKENDS [backend_impl .__name__ ]
145
171
146
172
147
- def _check_args_backend (backend , args ):
148
- is_instance = set (isinstance (a , backend .__type__ ) for a in args )
173
+ def _check_args_backend (backend_impl , args ):
174
+ is_instance = set (isinstance (arg , backend_impl .__type__ ) for arg in args )
149
175
# check that all arguments matched or not the type
150
176
if len (is_instance ) == 1 :
151
177
return is_instance .pop ()
152
178
153
- # Oterwise return an error
154
- raise ValueError (str_type_error .format ([type (a ) for a in args ]))
179
+ # Otherwise return an error
180
+ raise ValueError (str_type_error .format ([type (arg ) for arg in args ]))
181
+
182
+
183
+ def get_backend_list ():
184
+ """Returns instances of all available backends.
185
+
186
+ Note that the function forces all detected implementations
187
+ to be instantiated even if specific backend was not use before.
188
+ Be careful as instantiation of the backend might lead to side effects,
189
+ like GPU memory pre-allocation. See the documentation for more details.
190
+ If you only need to know which implementations are available,
191
+ use `:py:func:`ot.backend.get_available_backend_implementations`,
192
+ which does not force instance of the backend object to be created.
193
+ """
194
+ return [
195
+ _get_backend_instance (backend_impl )
196
+ for backend_impl
197
+ in get_available_backend_implementations ()
198
+ ]
199
+
200
+
201
+ def get_available_backend_implementations ():
202
+ """Returns the list of available backend implementations."""
203
+ return _BACKEND_IMPLEMENTATIONS
155
204
156
205
157
206
def get_backend (* args ):
@@ -167,9 +216,9 @@ def get_backend(*args):
167
216
if not len (args ) > 0 :
168
217
raise ValueError (" The function takes at least one (non-None) parameter" )
169
218
170
- for backend in _BACKENDS :
171
- if _check_args_backend (backend , args ):
172
- return backend
219
+ for backend_impl in _BACKEND_IMPLEMENTATIONS :
220
+ if _check_args_backend (backend_impl , args ):
221
+ return _get_backend_instance ( backend_impl )
173
222
174
223
raise ValueError ("Unknown type of non implemented backend." )
175
224
@@ -1341,7 +1390,7 @@ def matmul(self, a, b):
1341
1390
return np .matmul (a , b )
1342
1391
1343
1392
1344
- register_backend (NumpyBackend () )
1393
+ _register_backend_implementation (NumpyBackend )
1345
1394
1346
1395
1347
1396
class JaxBackend (Backend ):
@@ -1710,7 +1759,7 @@ def matmul(self, a, b):
1710
1759
1711
1760
if jax :
1712
1761
# Only register jax backend if it is installed
1713
- register_backend (JaxBackend () )
1762
+ _register_backend_implementation (JaxBackend )
1714
1763
1715
1764
1716
1765
class TorchBackend (Backend ):
@@ -2193,7 +2242,7 @@ def matmul(self, a, b):
2193
2242
2194
2243
if torch :
2195
2244
# Only register torch backend if it is installed
2196
- register_backend (TorchBackend () )
2245
+ _register_backend_implementation (TorchBackend )
2197
2246
2198
2247
2199
2248
class CupyBackend (Backend ): # pragma: no cover
@@ -2586,7 +2635,7 @@ def matmul(self, a, b):
2586
2635
2587
2636
if cp :
2588
2637
# Only register cp backend if it is installed
2589
- register_backend (CupyBackend () )
2638
+ _register_backend_implementation (CupyBackend )
2590
2639
2591
2640
2592
2641
class TensorflowBackend (Backend ):
@@ -3006,4 +3055,4 @@ def matmul(self, a, b):
3006
3055
3007
3056
if tf :
3008
3057
# Only register tensorflow backend if it is installed
3009
- register_backend (TensorflowBackend () )
3058
+ _register_backend_implementation (TensorflowBackend )
0 commit comments