diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 39f0b23d4..6f6a72737 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -48,7 +48,7 @@ The contributors to this library are: * [Camille Le Coz](https://www.linkedin.com/in/camille-le-coz-8593b91a1/) (EMD2 debug) * [Eduardo Fernandes Montesuma](https://eddardd.github.io/my-personal-blog/) (Free support sinkhorn barycenter) * [Theo Gnassounou](https://github.com/tgnassou) (OT between Gaussian distributions) -* [Clément Bonet](https://clbonet.github.io) (Wassertstein on circle, Spherical Sliced-Wasserstein) +* [Clément Bonet](https://clbonet.github.io) (Wasserstein on circle, Spherical Sliced-Wasserstein) * [Ronak Mehta](https://ronakrm.github.io) (Efficient Discrete Multi Marginal Optimal Transport Regularization) * [Xizheng Yu](https://github.com/x12hengyu) (Efficient Discrete Multi Marginal Optimal Transport Regularization) * [Sonia Mazelet](https://github.com/SoniaMaz8) (Template based GNN layers) diff --git a/README.md b/README.md index 7bbae9e8a..f64db8f56 100644 --- a/README.md +++ b/README.md @@ -51,7 +51,7 @@ POT provides the following generic OT solvers (links to examples): * [Efficient Discrete Multi Marginal Optimal Transport Regularization](https://pythonot.github.io/auto_examples/others/plot_demd_gradient_minimize.html) [50]. * [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/)/[Cupy](https://cupy.dev/)/[Tensorflow](https://www.tensorflow.org/) arrays. * [Smooth Strongly Convex Nearest Brenier Potentials](https://pythonot.github.io/auto_examples/others/plot_SSNB.html#sphx-glr-auto-examples-others-plot-ssnb-py) [58], with an extension to bounding potentials using [59]. 
-* Gaussian Mixture Model OT [69] +* [Gaussian Mixture Model OT](https://pythonot.github.io/auto_examples/others/plot_GMMOT_plan.html#sphx-glr-auto-examples-others-plot-gmmot-plan-py) [69]. * [Co-Optimal Transport](https://pythonot.github.io/auto_examples/others/plot_COOT.html) [49] and [unbalanced Co-Optimal Transport](https://pythonot.github.io/auto_examples/others/plot_learning_weights_with_COOT.html) [71]. * Fused unbalanced Gromov-Wasserstein [70]. diff --git a/RELEASES.md b/RELEASES.md index 0ddac599b..a0474eda0 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -6,6 +6,7 @@ - Implement CG solvers for partial FGW (PR #687) - Added feature `grad=last_step` for `ot.solvers.solve` (PR #693) - Automatic PR labeling and release file update check (PR #704) +- Reorganize sub-module `ot/lp/__init__.py` into separate files (PR #714) #### Closed issues - Fixed `ot.mapping` solvers which depended on deprecated `cvxpy` `ECOS` solver (PR #692, Issue #668) diff --git a/ot/gmm.py b/ot/gmm.py index cde2f8bbd..5c7a4c287 100644 --- a/ot/gmm.py +++ b/ot/gmm.py @@ -12,7 +12,7 @@ from .backend import get_backend from .lp import emd2, emd import numpy as np -from .lp import dist +from .utils import dist from .gaussian import bures_wasserstein_mapping diff --git a/ot/gromov/_partial.py b/ot/gromov/_partial.py index c6837f1d3..6672240d0 100644 --- a/ot/gromov/_partial.py +++ b/ot/gromov/_partial.py @@ -185,7 +185,7 @@ def partial_gromov_wasserstein( if m is None: m = min(np.sum(p), np.sum(q)) elif m < 0: - raise ValueError("Problem infeasible. Parameter m should be greater" " than 0.") + raise ValueError("Problem infeasible. Parameter m should be greater than 0.") elif m > min(np.sum(p), np.sum(q)): raise ValueError( "Problem infeasible. Parameter m should lower or" @@ -654,7 +654,7 @@ def partial_fused_gromov_wasserstein( if m is None: m = min(np.sum(p), np.sum(q)) elif m < 0: - raise ValueError("Problem infeasible. 
Parameter m should be greater" " than 0.") + raise ValueError("Problem infeasible. Parameter m should be greater than 0.") elif m > min(np.sum(p), np.sum(q)): raise ValueError( "Problem infeasible. Parameter m should lower or" @@ -1213,7 +1213,7 @@ def entropic_partial_gromov_wasserstein( if m is None: m = min(nx.sum(p), nx.sum(q)) elif m < 0: - raise ValueError("Problem infeasible. Parameter m should be greater" " than 0.") + raise ValueError("Problem infeasible. Parameter m should be greater than 0.") elif m > min(nx.sum(p), nx.sum(q)): raise ValueError( "Problem infeasible. Parameter m should lower or" diff --git a/ot/gromov/_quantized.py b/ot/gromov/_quantized.py index ac2db5d2d..f4a8fafa7 100644 --- a/ot/gromov/_quantized.py +++ b/ot/gromov/_quantized.py @@ -375,7 +375,7 @@ def get_graph_partition( raise ValueError( f""" Unknown `part_method='{part_method}'`. Use one of: - {'random', 'louvain', 'fluid', 'spectral', 'GW', 'FGW'}. + {"random", "louvain", "fluid", "spectral", "GW", "FGW"}. """ ) return nx.from_numpy(part, type_as=C0) @@ -447,7 +447,7 @@ def get_graph_representants(C, part, rep_method="pagerank", random_state=0, nx=N raise ValueError( f""" Unknown `rep_method='{rep_method}'`. Use one of: - {'random', 'pagerank'}. + {"random", "pagerank"}. """ ) @@ -953,7 +953,7 @@ def get_partition_and_representants_samples( else: raise ValueError( f""" - Unknown `method='{method}'`. Use one of: {'random', 'kmeans'} + Unknown `method='{method}'`. Use one of: {"random", "kmeans"} """ ) diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 2b93e84f3..e3cfce0fd 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -8,15 +8,18 @@ # # License: MIT License -import numpy as np -import warnings - from . 
import cvx -from .cvx import barycenter from .dmmot import dmmot_monge_1dgrid_loss, dmmot_monge_1dgrid_optimize +from ._network_simplex import emd, emd2 +from ._barycenter_solvers import ( + barycenter, + free_support_barycenter, + generalized_free_support_barycenter, +) +from ..utils import check_number_threads # import compiled emd -from .emd_wrap import emd_c, check_result, emd_1d_sorted +from .emd_wrap import emd_1d_sorted from .solver_1d import ( emd_1d, emd2_1d, @@ -26,9 +29,6 @@ semidiscrete_wasserstein2_unif_circle, ) -from ..utils import dist, list_to_array -from ..backend import get_backend - __all__ = [ "emd", "emd2", @@ -45,867 +45,5 @@ "semidiscrete_wasserstein2_unif_circle", "dmmot_monge_1dgrid_loss", "dmmot_monge_1dgrid_optimize", + "check_number_threads", ] - - -def check_number_threads(numThreads): - """Checks whether or not the requested number of threads has a valid value. - - Parameters - ---------- - numThreads : int or str - The requested number of threads, should either be a strictly positive integer or "max" or None - - Returns - ------- - numThreads : int - Corrected number of threads - """ - if (numThreads is None) or ( - isinstance(numThreads, str) and numThreads.lower() == "max" - ): - return -1 - if (not isinstance(numThreads, int)) or numThreads < 1: - raise ValueError( - 'numThreads should either be "max" or a strictly positive integer' - ) - return numThreads - - -def center_ot_dual(alpha0, beta0, a=None, b=None): - r"""Center dual OT potentials w.r.t. their weights - - The main idea of this function is to find unique dual potentials - that ensure some kind of centering/fairness. The main idea is to find dual potentials that lead to the same final objective value for both source and targets (see below for more details). It will help having - stability when multiple calling of the OT solver with small changes. - - Basically we add another constraint to the potential that will not - change the objective value but will ensure unicity. 
The constraint - is the following: - - .. math:: - \alpha^T \mathbf{a} = \beta^T \mathbf{b} - - in addition to the OT problem constraints. - - since :math:`\sum_i a_i=\sum_j b_j` this can be solved by adding/removing - a constant from both :math:`\alpha_0` and :math:`\beta_0`. - - .. math:: - c &= \frac{\beta_0^T \mathbf{b} - \alpha_0^T \mathbf{a}}{\mathbf{1}^T \mathbf{b} + \mathbf{1}^T \mathbf{a}} - - \alpha &= \alpha_0 + c - - \beta &= \beta_0 + c - - Parameters - ---------- - alpha0 : (ns,) numpy.ndarray, float64 - Source dual potential - beta0 : (nt,) numpy.ndarray, float64 - Target dual potential - a : (ns,) numpy.ndarray, float64 - Source histogram (uniform weight if empty list) - b : (nt,) numpy.ndarray, float64 - Target histogram (uniform weight if empty list) - - Returns - ------- - alpha : (ns,) numpy.ndarray, float64 - Source centered dual potential - beta : (nt,) numpy.ndarray, float64 - Target centered dual potential - - """ - # if no weights are provided, use uniform - if a is None: - a = np.ones(alpha0.shape[0]) / alpha0.shape[0] - if b is None: - b = np.ones(beta0.shape[0]) / beta0.shape[0] - - # compute constant that balances the weighted sums of the duals - c = (b.dot(beta0) - a.dot(alpha0)) / (a.sum() + b.sum()) - - # update duals - alpha = alpha0 + c - beta = beta0 - c - - return alpha, beta - - -def estimate_dual_null_weights(alpha0, beta0, a, b, M): - r"""Estimate feasible values for 0-weighted dual potentials - - The feasible values are computed efficiently but rather coarsely. - - .. warning:: - This function is necessary because the C++ solver in `emd_c` - discards all samples in the distributions with - zeros weights. This means that while the primal variable (transport - matrix) is exact, the solver only returns feasible dual potentials - on the samples with weights different from zero. - - First we compute the constraints violations: - - .. 
math:: - \mathbf{V} = \alpha + \beta^T - \mathbf{M} - - Next we compute the max amount of violation per row (:math:`\alpha`) and - columns (:math:`beta`) - - .. math:: - \mathbf{v^a}_i = \max_j \mathbf{V}_{i,j} - - \mathbf{v^b}_j = \max_i \mathbf{V}_{i,j} - - Finally we update the dual potential with 0 weights if a - constraint is violated - - .. math:: - \alpha_i = \alpha_i - \mathbf{v^a}_i \quad \text{ if } \mathbf{a}_i=0 \text{ and } \mathbf{v^a}_i>0 - - \beta_j = \beta_j - \mathbf{v^b}_j \quad \text{ if } \mathbf{b}_j=0 \text{ and } \mathbf{v^b}_j > 0 - - In the end the dual potentials are centered using function - :py:func:`ot.lp.center_ot_dual`. - - Note that all those updates do not change the objective value of the - solution but provide dual potentials that do not violate the constraints. - - Parameters - ---------- - alpha0 : (ns,) numpy.ndarray, float64 - Source dual potential - beta0 : (nt,) numpy.ndarray, float64 - Target dual potential - alpha0 : (ns,) numpy.ndarray, float64 - Source dual potential - beta0 : (nt,) numpy.ndarray, float64 - Target dual potential - a : (ns,) numpy.ndarray, float64 - Source distribution (uniform weights if empty list) - b : (nt,) numpy.ndarray, float64 - Target distribution (uniform weights if empty list) - M : (ns,nt) numpy.ndarray, float64 - Loss matrix (c-order array with type float64) - - Returns - ------- - alpha : (ns,) numpy.ndarray, float64 - Source corrected dual potential - beta : (nt,) numpy.ndarray, float64 - Target corrected dual potential - - """ - - # binary indexing of non-zeros weights - asel = a != 0 - bsel = b != 0 - - # compute dual constraints violation - constraint_violation = alpha0[:, None] + beta0[None, :] - M - - # Compute largest violation per line and columns - aviol = np.max(constraint_violation, 1) - bviol = np.max(constraint_violation, 0) - - # update corrects violation of - alpha_up = -1 * ~asel * np.maximum(aviol, 0) - beta_up = -1 * ~bsel * np.maximum(bviol, 0) - - alpha = alpha0 + 
alpha_up - beta = beta0 + beta_up - - return center_ot_dual(alpha, beta, a, b) - - -def emd( - a, - b, - M, - numItermax=100000, - log=False, - center_dual=True, - numThreads=1, - check_marginals=True, -): - r"""Solves the Earth Movers distance problem and returns the OT matrix - - - .. math:: - \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F - - s.t. \ \gamma \mathbf{1} = \mathbf{a} - - \gamma^T \mathbf{1} = \mathbf{b} - - \gamma \geq 0 - - where : - - - :math:`\mathbf{M}` is the metric cost matrix - - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights - - .. warning:: Note that the :math:`\mathbf{M}` matrix in numpy needs to be a C-order - numpy.array in float64 format. It will be converted if not in this - format - - .. note:: This function is backend-compatible and will work on arrays - from all compatible backends. But the algorithm uses the C++ CPU backend - which can lead to copy overhead on GPU arrays. - - .. note:: This function will cast the computed transport plan to the data type - of the provided input with the following priority: :math:`\mathbf{a}`, - then :math:`\mathbf{b}`, then :math:`\mathbf{M}` if marginals are not provided. - Casting to an integer tensor might result in a loss of precision. - If this behaviour is unwanted, please make sure to provide a - floating point input. - - .. note:: An error will be raised if the vectors :math:`\mathbf{a}` and :math:`\mathbf{b}` do not sum to the same value. - - Uses the algorithm proposed in :ref:`[1] `. - - Parameters - ---------- - a : (ns,) array-like, float - Source histogram (uniform weight if empty list) - b : (nt,) array-like, float - Target histogram (uniform weight if empty list) - M : (ns,nt) array-like, float - Loss matrix (c-order array in numpy with type float64) - numItermax : int, optional (default=100000) - The maximum number of iterations before stopping the optimization - algorithm if it has not converged. 
- log: bool, optional (default=False) - If True, returns a dictionary containing the cost and dual variables. - Otherwise returns only the optimal transportation matrix. - center_dual: boolean, optional (default=True) - If True, centers the dual potential using function - :py:func:`ot.lp.center_ot_dual`. - numThreads: int or "max", optional (default=1, i.e. OpenMP is not used) - If compiled with OpenMP, chooses the number of threads to parallelize. - "max" selects the highest number possible. - check_marginals: bool, optional (default=True) - If True, checks that the marginals mass are equal. If False, skips the - check. - - - Returns - ------- - gamma: array-like, shape (ns, nt) - Optimal transportation matrix for the given - parameters - log: dict, optional - If input log is true, a dictionary containing the - cost and dual variables and exit status - - - Examples - -------- - - Simple example with obvious solution. The function emd accepts lists and - perform automatic conversion to numpy arrays - - >>> import ot - >>> a=[.5,.5] - >>> b=[.5,.5] - >>> M=[[0.,1.],[1.,0.]] - >>> ot.emd(a, b, M) - array([[0.5, 0. ], - [0. , 0.5]]) - - - .. _references-emd: - References - ---------- - .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. (2011, - December). Displacement interpolation using Lagrangian mass transport. - In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. 158). ACM. 
- - See Also - -------- - ot.bregman.sinkhorn : Entropic regularized OT - ot.optim.cg : General regularized OT - """ - - a, b, M = list_to_array(a, b, M) - nx = get_backend(M, a, b) - - if len(a) != 0: - type_as = a - elif len(b) != 0: - type_as = b - else: - type_as = M - - # if empty array given then use uniform distributions - if len(a) == 0: - a = nx.ones((M.shape[0],), type_as=type_as) / M.shape[0] - if len(b) == 0: - b = nx.ones((M.shape[1],), type_as=type_as) / M.shape[1] - - # convert to numpy - M, a, b = nx.to_numpy(M, a, b) - - # ensure float64 - a = np.asarray(a, dtype=np.float64) - b = np.asarray(b, dtype=np.float64) - M = np.asarray(M, dtype=np.float64, order="C") - - # if empty array given then use uniform distributions - if len(a) == 0: - a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0] - if len(b) == 0: - b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] - - assert ( - a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1] - ), "Dimension mismatch, check dimensions of M with a and b" - - # ensure that same mass - if check_marginals: - np.testing.assert_almost_equal( - a.sum(0), - b.sum(0), - err_msg="a and b vector must have the same sum", - decimal=6, - ) - b = b * a.sum() / b.sum() - - asel = a != 0 - bsel = b != 0 - - numThreads = check_number_threads(numThreads) - - G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads) - - if center_dual: - u, v = center_ot_dual(u, v, a, b) - - if np.any(~asel) or np.any(~bsel): - u, v = estimate_dual_null_weights(u, v, a, b, M) - - result_code_string = check_result(result_code) - if not nx.is_floating_point(type_as): - warnings.warn( - "Input histogram consists of integer. The transport plan will be " - "casted accordingly, possibly resulting in a loss of precision. 
" - "If this behaviour is unwanted, please make sure your input " - "histogram consists of floating point elements.", - stacklevel=2, - ) - if log: - log = {} - log["cost"] = cost - log["u"] = nx.from_numpy(u, type_as=type_as) - log["v"] = nx.from_numpy(v, type_as=type_as) - log["warning"] = result_code_string - log["result_code"] = result_code - return nx.from_numpy(G, type_as=type_as), log - return nx.from_numpy(G, type_as=type_as) - - -def emd2( - a, - b, - M, - processes=1, - numItermax=100000, - log=False, - return_matrix=False, - center_dual=True, - numThreads=1, - check_marginals=True, -): - r"""Solves the Earth Movers distance problem and returns the loss - - .. math:: - \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F - - s.t. \ \gamma \mathbf{1} = \mathbf{a} - - \gamma^T \mathbf{1} = \mathbf{b} - - \gamma \geq 0 - - where : - - - :math:`\mathbf{M}` is the metric cost matrix - - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights - - .. note:: This function is backend-compatible and will work on arrays - from all compatible backends. But the algorithm uses the C++ CPU backend - which can lead to copy overhead on GPU arrays. - - .. note:: This function will cast the computed transport plan and - transportation loss to the data type of the provided input with the - following priority: :math:`\mathbf{a}`, then :math:`\mathbf{b}`, - then :math:`\mathbf{M}` if marginals are not provided. - Casting to an integer tensor might result in a loss of precision. - If this behaviour is unwanted, please make sure to provide a - floating point input. - - .. note:: An error will be raised if the vectors :math:`\mathbf{a}` and :math:`\mathbf{b}` do not sum to the same value. - - Uses the algorithm proposed in :ref:`[1] `. 
- - Parameters - ---------- - a : (ns,) array-like, float64 - Source histogram (uniform weight if empty list) - b : (nt,) array-like, float64 - Target histogram (uniform weight if empty list) - M : (ns,nt) array-like, float64 - Loss matrix (for numpy c-order array with type float64) - processes : int, optional (default=1) - Nb of processes used for multiple emd computation (deprecated) - numItermax : int, optional (default=100000) - The maximum number of iterations before stopping the optimization - algorithm if it has not converged. - log: boolean, optional (default=False) - If True, returns a dictionary containing dual - variables. Otherwise returns only the optimal transportation cost. - return_matrix: boolean, optional (default=False) - If True, returns the optimal transportation matrix in the log. - center_dual: boolean, optional (default=True) - If True, centers the dual potential using function - :py:func:`ot.lp.center_ot_dual`. - numThreads: int or "max", optional (default=1, i.e. OpenMP is not used) - If compiled with OpenMP, chooses the number of threads to parallelize. - "max" selects the highest number possible. - check_marginals: bool, optional (default=True) - If True, checks that the marginals mass are equal. If False, skips the - check. - - - Returns - ------- - W: float, array-like - Optimal transportation loss for the given parameters - log: dict - If input log is true, a dictionary containing dual - variables and exit status - - - Examples - -------- - - Simple example with obvious solution. The function emd accepts lists and - perform automatic conversion to numpy arrays - - - >>> import ot - >>> a=[.5,.5] - >>> b=[.5,.5] - >>> M=[[0.,1.],[1.,0.]] - >>> ot.emd2(a,b,M) - 0.0 - - - .. _references-emd2: - References - ---------- - .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. - (2011, December). Displacement interpolation using Lagrangian mass - transport. In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. - 158). ACM. 
- - See Also - -------- - ot.bregman.sinkhorn : Entropic regularized OT - ot.optim.cg : General regularized OT - """ - - a, b, M = list_to_array(a, b, M) - nx = get_backend(M, a, b) - - if len(a) != 0: - type_as = a - elif len(b) != 0: - type_as = b - else: - type_as = M - - # if empty array given then use uniform distributions - if len(a) == 0: - a = nx.ones((M.shape[0],), type_as=type_as) / M.shape[0] - if len(b) == 0: - b = nx.ones((M.shape[1],), type_as=type_as) / M.shape[1] - - # store original tensors - a0, b0, M0 = a, b, M - - # convert to numpy - M, a, b = nx.to_numpy(M, a, b) - - a = np.asarray(a, dtype=np.float64) - b = np.asarray(b, dtype=np.float64) - M = np.asarray(M, dtype=np.float64, order="C") - - assert ( - a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1] - ), "Dimension mismatch, check dimensions of M with a and b" - - # ensure that same mass - if check_marginals: - np.testing.assert_almost_equal( - a.sum(0), - b.sum(0, keepdims=True), - err_msg="a and b vector must have the same sum", - decimal=6, - ) - b = b * a.sum(0) / b.sum(0, keepdims=True) - - asel = a != 0 - - numThreads = check_number_threads(numThreads) - - if log or return_matrix: - - def f(b): - bsel = b != 0 - - G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads) - - if center_dual: - u, v = center_ot_dual(u, v, a, b) - - if np.any(~asel) or np.any(~bsel): - u, v = estimate_dual_null_weights(u, v, a, b, M) - - result_code_string = check_result(result_code) - log = {} - if not nx.is_floating_point(type_as): - warnings.warn( - "Input histogram consists of integer. The transport plan will be " - "casted accordingly, possibly resulting in a loss of precision. 
" - "If this behaviour is unwanted, please make sure your input " - "histogram consists of floating point elements.", - stacklevel=2, - ) - G = nx.from_numpy(G, type_as=type_as) - if return_matrix: - log["G"] = G - log["u"] = nx.from_numpy(u, type_as=type_as) - log["v"] = nx.from_numpy(v, type_as=type_as) - log["warning"] = result_code_string - log["result_code"] = result_code - cost = nx.set_gradients( - nx.from_numpy(cost, type_as=type_as), - (a0, b0, M0), - (log["u"] - nx.mean(log["u"]), log["v"] - nx.mean(log["v"]), G), - ) - return [cost, log] - else: - - def f(b): - bsel = b != 0 - G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads) - - if center_dual: - u, v = center_ot_dual(u, v, a, b) - - if np.any(~asel) or np.any(~bsel): - u, v = estimate_dual_null_weights(u, v, a, b, M) - - if not nx.is_floating_point(type_as): - warnings.warn( - "Input histogram consists of integer. The transport plan will be " - "casted accordingly, possibly resulting in a loss of precision. " - "If this behaviour is unwanted, please make sure your input " - "histogram consists of floating point elements.", - stacklevel=2, - ) - G = nx.from_numpy(G, type_as=type_as) - cost = nx.set_gradients( - nx.from_numpy(cost, type_as=type_as), - (a0, b0, M0), - ( - nx.from_numpy(u - np.mean(u), type_as=type_as), - nx.from_numpy(v - np.mean(v), type_as=type_as), - G, - ), - ) - - check_result(result_code) - return cost - - if len(b.shape) == 1: - return f(b) - nb = b.shape[1] - - if processes > 1: - warnings.warn( - "The 'processes' parameter has been deprecated. " - "Multiprocessing should be done outside of POT." 
- ) - res = list(map(f, [b[:, i].copy() for i in range(nb)])) - - return res - - -def free_support_barycenter( - measures_locations, - measures_weights, - X_init, - b=None, - weights=None, - numItermax=100, - stopThr=1e-7, - verbose=False, - log=None, - numThreads=1, -): - r""" - Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance), formally: - - .. math:: - \min_\mathbf{X} \quad \sum_{i=1}^N w_i W_2^2(\mathbf{b}, \mathbf{X}, \mathbf{a}_i, \mathbf{X}_i) - - where : - - - :math:`w \in \mathbb{(0, 1)}^{N}`'s are the barycenter weights and sum to one - - `measure_weights` denotes the :math:`\mathbf{a}_i \in \mathbb{R}^{k_i}`: empirical measures weights (on simplex) - - `measures_locations` denotes the :math:`\mathbf{X}_i \in \mathbb{R}^{k_i, d}`: empirical measures atoms locations - - :math:`\mathbf{b} \in \mathbb{R}^{k}` is the desired weights vector of the barycenter - - This problem is considered in :ref:`[20] ` (Algorithm 2). - There are two differences with the following codes: - - - we do not optimize over the weights - - we do not do line search for the locations updates, we use i.e. :math:`\theta = 1` in - :ref:`[20] ` (Algorithm 2). This can be seen as a discrete - implementation of the fixed-point algorithm of - :ref:`[43] ` proposed in the continuous setting. 
- - Parameters - ---------- - measures_locations : list of N (k_i,d) array-like - The discrete support of a measure supported on :math:`k_i` locations of a `d`-dimensional space - (:math:`k_i` can be different for each element of the list) - measures_weights : list of N (k_i,) array-like - Numpy arrays where each numpy array has :math:`k_i` non-negatives values summing to one - representing the weights of each discrete input measure - - X_init : (k,d) array-like - Initialization of the support locations (on `k` atoms) of the barycenter - b : (k,) array-like - Initialization of the weights of the barycenter (non-negatives, sum to 1) - weights : (N,) array-like - Initialization of the coefficients of the barycenter (non-negatives, sum to 1) - - numItermax : int, optional - Max number of iterations - stopThr : float, optional - Stop threshold on error (>0) - verbose : bool, optional - Print information along iterations - log : bool, optional - record log if True - numThreads: int or "max", optional (default=1, i.e. OpenMP is not used) - If compiled with OpenMP, chooses the number of threads to parallelize. - "max" selects the highest number possible. - - - Returns - ------- - X : (k,d) array-like - Support locations (on k atoms) of the barycenter - - - .. _references-free-support-barycenter: - - References - ---------- - .. [20] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014. - - .. [43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. 
- - """ - - nx = get_backend(*measures_locations, *measures_weights, X_init) - - iter_count = 0 - - N = len(measures_locations) - k = X_init.shape[0] - d = X_init.shape[1] - if b is None: - b = nx.ones((k,), type_as=X_init) / k - if weights is None: - weights = nx.ones((N,), type_as=X_init) / N - - X = X_init - - log_dict = {} - displacement_square_norms = [] - - displacement_square_norm = stopThr + 1.0 - - while displacement_square_norm > stopThr and iter_count < numItermax: - T_sum = nx.zeros((k, d), type_as=X_init) - - for measure_locations_i, measure_weights_i, weight_i in zip( - measures_locations, measures_weights, weights - ): - M_i = dist(X, measure_locations_i) - T_i = emd(b, measure_weights_i, M_i, numThreads=numThreads) - T_sum = T_sum + weight_i * 1.0 / b[:, None] * nx.dot( - T_i, measure_locations_i - ) - - displacement_square_norm = nx.sum((T_sum - X) ** 2) - if log: - displacement_square_norms.append(displacement_square_norm) - - X = T_sum - - if verbose: - print( - "iteration %d, displacement_square_norm=%f\n", - iter_count, - displacement_square_norm, - ) - - iter_count += 1 - - if log: - log_dict["displacement_square_norms"] = displacement_square_norms - return X, log_dict - else: - return X - - -def generalized_free_support_barycenter( - X_list, - a_list, - P_list, - n_samples_bary, - Y_init=None, - b=None, - weights=None, - numItermax=100, - stopThr=1e-7, - verbose=False, - log=None, - numThreads=1, - eps=0, -): - r""" - Solves the free support generalized Wasserstein barycenter problem: finding a barycenter (a discrete measure with - a fixed amount of points of uniform weights) whose respective projections fit the input measures. - More formally: - - .. 
math:: - \min_\gamma \quad \sum_{i=1}^p w_i W_2^2(\nu_i, \mathbf{P}_i\#\gamma) - - where : - - - :math:`\gamma = \sum_{l=1}^n b_l\delta_{y_l}` is the desired barycenter with each :math:`y_l \in \mathbb{R}^d` - - :math:`\mathbf{b} \in \mathbb{R}^{n}` is the desired weights vector of the barycenter - - The input measures are :math:`\nu_i = \sum_{j=1}^{k_i}a_{i,j}\delta_{x_{i,j}}` - - The :math:`\mathbf{a}_i \in \mathbb{R}^{k_i}` are the respective empirical measures weights (on the simplex) - - The :math:`\mathbf{X}_i \in \mathbb{R}^{k_i, d_i}` are the respective empirical measures atoms locations - - :math:`w = (w_1, \cdots w_p)` are the barycenter coefficients (on the simplex) - - Each :math:`\mathbf{P}_i \in \mathbb{R}^{d, d_i}`, and :math:`P_i\#\nu_i = \sum_{j=1}^{k_i}a_{i,j}\delta_{P_ix_{i,j}}` - - As show by :ref:`[42] `, - this problem can be re-written as a Wasserstein Barycenter problem, - which we solve using the free support method :ref:`[20] ` - (Algorithm 2). - - Parameters - ---------- - X_list : list of p (k_i,d_i) array-like - Discrete supports of the input measures: each consists of :math:`k_i` locations of a `d_i`-dimensional space - (:math:`k_i` can be different for each element of the list) - a_list : list of p (k_i,) array-like - Measure weights: each element is a vector (k_i) on the simplex - P_list : list of p (d_i,d) array-like - Each :math:`P_i` is a linear map :math:`\mathbb{R}^{d} \rightarrow \mathbb{R}^{d_i}` - n_samples_bary : int - Number of barycenter points - Y_init : (n_samples_bary,d) array-like - Initialization of the support locations (on `k` atoms) of the barycenter - b : (n_samples_bary,) array-like - Initialization of the weights of the barycenter measure (on the simplex) - weights : (p,) array-like - Initialization of the coefficients of the barycenter (on the simplex) - numItermax : int, optional - Max number of iterations - stopThr : float, optional - Stop threshold on error (>0) - verbose : bool, optional - Print information 
along iterations - log : bool, optional - record log if True - numThreads: int or "max", optional (default=1, i.e. OpenMP is not used) - If compiled with OpenMP, chooses the number of threads to parallelize. - "max" selects the highest number possible. - eps: Stability coefficient for the change of variable matrix inversion - If the :math:`\mathbf{P}_i^T` matrices don't span :math:`\mathbb{R}^d`, the problem is ill-defined and a matrix - inversion will fail. In this case one may set eps=1e-8 and get a solution anyway (which may make little sense) - - - Returns - ------- - Y : (n_samples_bary,d) array-like - Support locations (on n_samples_bary atoms) of the barycenter - - - .. _references-generalized-free-support-barycenter: - References - ---------- - .. [20] Cuturi, M. and Doucet, A.. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014. - - .. [42] Delon, J., Gozlan, N., and Saint-Dizier, A.. Generalized Wasserstein barycenters between probability measures living on different subspaces. arXiv preprint arXiv:2105.09755, 2021. 
- - """ - nx = get_backend(*X_list, *a_list, *P_list) - d = P_list[0].shape[1] - p = len(X_list) - - if weights is None: - weights = nx.ones(p, type_as=X_list[0]) / p - - # variable change matrix to reduce the problem to a Wasserstein Barycenter (WB) - A = eps * nx.eye( - d, type_as=X_list[0] - ) # if eps nonzero: will force the invertibility of A - for P_i, lambda_i in zip(P_list, weights): - A = A + lambda_i * P_i.T @ P_i - B = nx.inv(nx.sqrtm(A)) - - Z_list = [ - x @ Pi @ B.T for (x, Pi) in zip(X_list, P_list) - ] # change of variables -> (WB) problem on Z - - if Y_init is None: - Y_init = nx.randn(n_samples_bary, d, type_as=X_list[0]) - - if b is None: - b = nx.ones(n_samples_bary, type_as=X_list[0]) / n_samples_bary # not optimized - - out = free_support_barycenter( - Z_list, - a_list, - Y_init, - b, - numItermax=numItermax, - stopThr=stopThr, - verbose=verbose, - log=log, - numThreads=numThreads, - ) - - if log: # unpack - Y, log_dict = out - else: - Y = out - log_dict = None - Y = Y @ B.T # return to the Generalized WB formulation - - if log: - return Y, log_dict - else: - return Y diff --git a/ot/lp/_barycenter_solvers.py b/ot/lp/_barycenter_solvers.py new file mode 100644 index 000000000..8b64214d9 --- /dev/null +++ b/ot/lp/_barycenter_solvers.py @@ -0,0 +1,424 @@ +# -*- coding: utf-8 -*- +""" +OT Barycenter Solvers +""" + +# Author: Remi Flamary +# Eloi Tanguy +# +# License: MIT License + +from ..backend import get_backend +from ..utils import dist +from ._network_simplex import emd + +import numpy as np +import scipy as sp +import scipy.sparse as sps + +try: + import cvxopt # for cvxopt barycenter solver + from cvxopt import solvers, matrix, spmatrix +except ImportError: + cvxopt = False + + +def scipy_sparse_to_spmatrix(A): + """Efficient conversion from scipy sparse matrix to cvxopt sparse matrix""" + coo = A.tocoo() + SP = spmatrix(coo.data.tolist(), coo.row.tolist(), coo.col.tolist(), size=A.shape) + return SP + + +def barycenter(A, M, weights=None, 
verbose=False, log=False, solver="highs-ipm"): + r"""Compute the Wasserstein barycenter of distributions A + + The function solves the following optimization problem [16]: + + .. math:: + \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{1}(\mathbf{a},\mathbf{a}_i) + + where : + + - :math:`W_1(\cdot,\cdot)` is the Wasserstein distance (see ot.lp.emd) + - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}` + + The linear program is solved using the interior point solver from scipy.optimize. + If the cvxopt solver is installed it can use cvxopt + + Note that this problem does not scale well (both in memory and computational time). + + Parameters + ---------- + A : np.ndarray (d,n) + n training distributions a_i of size d + M : np.ndarray (d,d) + loss matrix for OT + reg : float + Regularization term >0 + weights : np.ndarray (n,) + Weights of each histogram a_i on the simplex (barycentric coordinates) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + solver : string, optional + the solver used, default 'highs-ipm' uses the lp solver from + scipy.optimize. None, or 'glpk' or 'mosek' use the solver from cvxopt. + + Returns + ------- + a : (d,) ndarray + Wasserstein barycenter + log : dict + log dictionary returned only if log==True in parameters + + + References + ---------- + + .. [16] Agueh, M., & Carlier, G. (2011). Barycenters in the Wasserstein space. SIAM Journal on Mathematical Analysis, 43(2), 904-924. 
+ + + """ + + if weights is None: + weights = np.ones(A.shape[1]) / A.shape[1] + else: + assert len(weights) == A.shape[1] + + n_distributions = A.shape[1] + n = A.shape[0] + + n2 = n * n + c = np.zeros((0)) + b_eq1 = np.zeros((0)) + for i in range(n_distributions): + c = np.concatenate((c, M.ravel() * weights[i])) + b_eq1 = np.concatenate((b_eq1, A[:, i])) + c = np.concatenate((c, np.zeros(n))) + + lst_idiag1 = [sps.kron(sps.eye(n), np.ones((1, n))) for i in range(n_distributions)] + # row constraints + A_eq1 = sps.hstack( + (sps.block_diag(lst_idiag1), sps.coo_matrix((n_distributions * n, n))) + ) + + # columns constraints + lst_idiag2 = [] + lst_eye = [] + for i in range(n_distributions): + if i == 0: + lst_idiag2.append(sps.kron(np.ones((1, n)), sps.eye(n))) + lst_eye.append(-sps.eye(n)) + else: + lst_idiag2.append(sps.kron(np.ones((1, n)), sps.eye(n - 1, n))) + lst_eye.append(-sps.eye(n - 1, n)) + + A_eq2 = sps.hstack((sps.block_diag(lst_idiag2), sps.vstack(lst_eye))) + b_eq2 = np.zeros((A_eq2.shape[0])) + + # full problem + A_eq = sps.vstack((A_eq1, A_eq2)) + b_eq = np.concatenate((b_eq1, b_eq2)) + + if not cvxopt or solver in ["interior-point", "highs", "highs-ipm", "highs-ds"]: + # cvxopt not installed or interior point + + if solver is None: + solver = "interior-point" + + options = {"disp": verbose} + sol = sp.optimize.linprog( + c, A_eq=A_eq, b_eq=b_eq, method=solver, options=options + ) + x = sol.x + b = x[-n:] + + else: + h = np.zeros((n_distributions * n2 + n)) + G = -sps.eye(n_distributions * n2 + n) + + sol = solvers.lp( + matrix(c), + scipy_sparse_to_spmatrix(G), + matrix(h), + A=scipy_sparse_to_spmatrix(A_eq), + b=matrix(b_eq), + solver=solver, + ) + + x = np.array(sol["x"]) + b = x[-n:].ravel() + + if log: + return b, sol + else: + return b + + +def free_support_barycenter( + measures_locations, + measures_weights, + X_init, + b=None, + weights=None, + numItermax=100, + stopThr=1e-7, + verbose=False, + log=None, + numThreads=1, +): + r""" + 
Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance), formally: + + .. math:: + \min_\mathbf{X} \quad \sum_{i=1}^N w_i W_2^2(\mathbf{b}, \mathbf{X}, \mathbf{a}_i, \mathbf{X}_i) + + where : + + - :math:`w \in (0, 1)^{N}` are the barycenter weights and sum to one + - `measures_weights` denotes the :math:`\mathbf{a}_i \in \mathbb{R}^{k_i}`: empirical measures weights (on simplex) + - `measures_locations` denotes the :math:`\mathbf{X}_i \in \mathbb{R}^{k_i, d}`: empirical measures atoms locations + - :math:`\mathbf{b} \in \mathbb{R}^{k}` is the desired weights vector of the barycenter + + This problem is considered in :ref:`[20] <references-free-support-barycenter>` (Algorithm 2). + The implementation below differs from it in two ways: + + - we do not optimize over the weights + - we do not do line search for the locations updates, i.e. we use :math:`\theta = 1` in + :ref:`[20] <references-free-support-barycenter>` (Algorithm 2). This can be seen as a discrete + implementation of the fixed-point algorithm of + :ref:`[43] <references-free-support-barycenter>` proposed in the continuous setting. 
+ + Parameters + ---------- + measures_locations : list of N (k_i,d) array-like + The discrete support of a measure supported on :math:`k_i` locations of a `d`-dimensional space + (:math:`k_i` can be different for each element of the list) + measures_weights : list of N (k_i,) array-like + Numpy arrays where each numpy array has :math:`k_i` non-negatives values summing to one + representing the weights of each discrete input measure + + X_init : (k,d) array-like + Initialization of the support locations (on `k` atoms) of the barycenter + b : (k,) array-like + Initialization of the weights of the barycenter (non-negatives, sum to 1) + weights : (N,) array-like + Initialization of the coefficients of the barycenter (non-negatives, sum to 1) + + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + numThreads: int or "max", optional (default=1, i.e. OpenMP is not used) + If compiled with OpenMP, chooses the number of threads to parallelize. + "max" selects the highest number possible. + + + Returns + ------- + X : (k,d) array-like + Support locations (on k atoms) of the barycenter + + + .. _references-free-support-barycenter: + + References + ---------- + .. [20] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014. + + .. [43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. 
+ + """ + + nx = get_backend(*measures_locations, *measures_weights, X_init) + + iter_count = 0 + + N = len(measures_locations) + k = X_init.shape[0] + d = X_init.shape[1] + if b is None: + b = nx.ones((k,), type_as=X_init) / k + if weights is None: + weights = nx.ones((N,), type_as=X_init) / N + + X = X_init + + log_dict = {} + displacement_square_norms = [] + + displacement_square_norm = stopThr + 1.0 + + while displacement_square_norm > stopThr and iter_count < numItermax: + T_sum = nx.zeros((k, d), type_as=X_init) + + for measure_locations_i, measure_weights_i, weight_i in zip( + measures_locations, measures_weights, weights + ): + M_i = dist(X, measure_locations_i) + T_i = emd(b, measure_weights_i, M_i, numThreads=numThreads) + T_sum = T_sum + weight_i * 1.0 / b[:, None] * nx.dot( + T_i, measure_locations_i + ) + + displacement_square_norm = nx.sum((T_sum - X) ** 2) + if log: + displacement_square_norms.append(displacement_square_norm) + + X = T_sum + + if verbose: + print( + "iteration %d, displacement_square_norm=%f\n", + iter_count, + displacement_square_norm, + ) + + iter_count += 1 + + if log: + log_dict["displacement_square_norms"] = displacement_square_norms + return X, log_dict + else: + return X + + +def generalized_free_support_barycenter( + X_list, + a_list, + P_list, + n_samples_bary, + Y_init=None, + b=None, + weights=None, + numItermax=100, + stopThr=1e-7, + verbose=False, + log=None, + numThreads=1, + eps=0, +): + r""" + Solves the free support generalized Wasserstein barycenter problem: finding a barycenter (a discrete measure with + a fixed amount of points of uniform weights) whose respective projections fit the input measures. + More formally: + + .. 
math:: + \min_\gamma \quad \sum_{i=1}^p w_i W_2^2(\nu_i, \mathbf{P}_i\#\gamma) + + where : + + - :math:`\gamma = \sum_{l=1}^n b_l\delta_{y_l}` is the desired barycenter with each :math:`y_l \in \mathbb{R}^d` + - :math:`\mathbf{b} \in \mathbb{R}^{n}` is the desired weights vector of the barycenter + - The input measures are :math:`\nu_i = \sum_{j=1}^{k_i}a_{i,j}\delta_{x_{i,j}}` + - The :math:`\mathbf{a}_i \in \mathbb{R}^{k_i}` are the respective empirical measures weights (on the simplex) + - The :math:`\mathbf{X}_i \in \mathbb{R}^{k_i, d_i}` are the respective empirical measures atoms locations + - :math:`w = (w_1, \cdots, w_p)` are the barycenter coefficients (on the simplex) + - Each :math:`\mathbf{P}_i \in \mathbb{R}^{d_i, d}`, and :math:`\mathbf{P}_i\#\gamma = \sum_{l=1}^{n}b_l\delta_{\mathbf{P}_i y_l}` + + As shown by :ref:`[42] <references-generalized-free-support-barycenter>`, + this problem can be re-written as a Wasserstein Barycenter problem, + which we solve using the free support method :ref:`[20] <references-generalized-free-support-barycenter>` + (Algorithm 2). + + Parameters + ---------- + X_list : list of p (k_i,d_i) array-like + Discrete supports of the input measures: each consists of :math:`k_i` locations of a `d_i`-dimensional space + (:math:`k_i` can be different for each element of the list) + a_list : list of p (k_i,) array-like + Measure weights: each element is a vector (k_i) on the simplex + P_list : list of p (d_i,d) array-like + Each :math:`P_i` is a linear map :math:`\mathbb{R}^{d} \rightarrow \mathbb{R}^{d_i}` + n_samples_bary : int + Number of barycenter points + Y_init : (n_samples_bary,d) array-like + Initialization of the support locations (on `n_samples_bary` atoms) of the barycenter + b : (n_samples_bary,) array-like + Initialization of the weights of the barycenter measure (on the simplex) + weights : (p,) array-like + Initialization of the coefficients of the barycenter (on the simplex) + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (>0) + verbose : bool, optional + Print information 
along iterations + log : bool, optional + record log if True + numThreads: int or "max", optional (default=1, i.e. OpenMP is not used) + If compiled with OpenMP, chooses the number of threads to parallelize. + "max" selects the highest number possible. + eps: Stability coefficient for the change of variable matrix inversion + If the :math:`\mathbf{P}_i^T` matrices don't span :math:`\mathbb{R}^d`, the problem is ill-defined and a matrix + inversion will fail. In this case one may set eps=1e-8 and get a solution anyway (which may make little sense) + + + Returns + ------- + Y : (n_samples_bary,d) array-like + Support locations (on n_samples_bary atoms) of the barycenter + + + .. _references-generalized-free-support-barycenter: + References + ---------- + .. [20] Cuturi, M. and Doucet, A.. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014. + + .. [42] Delon, J., Gozlan, N., and Saint-Dizier, A.. Generalized Wasserstein barycenters between probability measures living on different subspaces. arXiv preprint arXiv:2105.09755, 2021. 
+ + """ + nx = get_backend(*X_list, *a_list, *P_list) + d = P_list[0].shape[1] + p = len(X_list) + + if weights is None: + weights = nx.ones(p, type_as=X_list[0]) / p + + # variable change matrix to reduce the problem to a Wasserstein Barycenter (WB) + A = eps * nx.eye( + d, type_as=X_list[0] + ) # if eps nonzero: will force the invertibility of A + for P_i, lambda_i in zip(P_list, weights): + A = A + lambda_i * P_i.T @ P_i + B = nx.inv(nx.sqrtm(A)) + + Z_list = [ + x @ Pi @ B.T for (x, Pi) in zip(X_list, P_list) + ] # change of variables -> (WB) problem on Z + + if Y_init is None: + Y_init = nx.randn(n_samples_bary, d, type_as=X_list[0]) + + if b is None: + b = nx.ones(n_samples_bary, type_as=X_list[0]) / n_samples_bary # not optimized + + out = free_support_barycenter( + Z_list, + a_list, + Y_init, + b, + numItermax=numItermax, + stopThr=stopThr, + verbose=verbose, + log=log, + numThreads=numThreads, + ) + + if log: # unpack + Y, log_dict = out + else: + Y = out + log_dict = None + Y = Y @ B.T # return to the Generalized WB formulation + + if log: + return Y, log_dict + else: + return Y diff --git a/ot/lp/_network_simplex.py b/ot/lp/_network_simplex.py new file mode 100644 index 000000000..492e4c7ac --- /dev/null +++ b/ot/lp/_network_simplex.py @@ -0,0 +1,588 @@ +# -*- coding: utf-8 -*- +""" +Solvers for the original linear program OT problem. + +""" + +# Author: Remi Flamary +# +# License: MIT License + +import numpy as np +import warnings + +from ..utils import list_to_array, check_number_threads +from ..backend import get_backend +from .emd_wrap import emd_c, check_result + + +def center_ot_dual(alpha0, beta0, a=None, b=None): + r"""Center dual OT potentials w.r.t. their weights + + The main idea of this function is to find unique dual potentials + that ensure some kind of centering/fairness. The main idea is to find dual potentials that lead to the same final objective value for both source and targets (see below for more details). 
It helps keep + the potentials stable across repeated calls of the OT solver with small changes. + + Basically we add another constraint to the potential that will not + change the objective value but will ensure uniqueness. The constraint + is the following: + + .. math:: + \alpha^T \mathbf{a} = \beta^T \mathbf{b} + + in addition to the OT problem constraints. + + Since :math:`\sum_i a_i=\sum_j b_j` this can be solved by adding/removing + a constant from both :math:`\alpha_0` and :math:`\beta_0`. + + .. math:: + c &= \frac{\beta_0^T \mathbf{b} - \alpha_0^T \mathbf{a}}{\mathbf{1}^T \mathbf{b} + \mathbf{1}^T \mathbf{a}} + + \alpha &= \alpha_0 + c + + \beta &= \beta_0 - c + + Parameters + ---------- + alpha0 : (ns,) numpy.ndarray, float64 + Source dual potential + beta0 : (nt,) numpy.ndarray, float64 + Target dual potential + a : (ns,) numpy.ndarray, float64 + Source histogram (uniform weight if None) + b : (nt,) numpy.ndarray, float64 + Target histogram (uniform weight if None) + + Returns + ------- + alpha : (ns,) numpy.ndarray, float64 + Source centered dual potential + beta : (nt,) numpy.ndarray, float64 + Target centered dual potential + + """ + # if no weights are provided, use uniform + if a is None: + a = np.ones(alpha0.shape[0]) / alpha0.shape[0] + if b is None: + b = np.ones(beta0.shape[0]) / beta0.shape[0] + + # compute constant that balances the weighted sums of the duals + c = (b.dot(beta0) - a.dot(alpha0)) / (a.sum() + b.sum()) + + # update duals + alpha = alpha0 + c + beta = beta0 - c + + return alpha, beta + + +def estimate_dual_null_weights(alpha0, beta0, a, b, M): + r"""Estimate feasible values for 0-weighted dual potentials + + The feasible values are computed efficiently but rather coarsely. + + .. warning:: + This function is necessary because the C++ solver in `emd_c` + discards all samples in the distributions with + zero weights. 
This means that while the primal variable (transport + matrix) is exact, the solver only returns feasible dual potentials + on the samples with weights different from zero. + + First we compute the constraint violations: + + .. math:: + \mathbf{V} = \alpha + \beta^T - \mathbf{M} + + Next we compute the max amount of violation per row (:math:`\alpha`) and + column (:math:`\beta`) + + .. math:: + \mathbf{v^a}_i = \max_j \mathbf{V}_{i,j} + + \mathbf{v^b}_j = \max_i \mathbf{V}_{i,j} + + Finally we update the dual potential with 0 weights if a + constraint is violated + + .. math:: + \alpha_i = \alpha_i - \mathbf{v^a}_i \quad \text{ if } \mathbf{a}_i=0 \text{ and } \mathbf{v^a}_i>0 + + \beta_j = \beta_j - \mathbf{v^b}_j \quad \text{ if } \mathbf{b}_j=0 \text{ and } \mathbf{v^b}_j > 0 + + In the end the dual potentials are centered using function + :py:func:`ot.lp.center_ot_dual`. + + Note that all those updates do not change the objective value of the + solution but provide dual potentials that do not violate the constraints. 
+ + Parameters + ---------- + alpha0 : (ns,) numpy.ndarray, float64 + Source dual potential + beta0 : (nt,) numpy.ndarray, float64 + Target dual potential + alpha0 : (ns,) numpy.ndarray, float64 + Source dual potential + beta0 : (nt,) numpy.ndarray, float64 + Target dual potential + a : (ns,) numpy.ndarray, float64 + Source distribution (uniform weights if empty list) + b : (nt,) numpy.ndarray, float64 + Target distribution (uniform weights if empty list) + M : (ns,nt) numpy.ndarray, float64 + Loss matrix (c-order array with type float64) + + Returns + ------- + alpha : (ns,) numpy.ndarray, float64 + Source corrected dual potential + beta : (nt,) numpy.ndarray, float64 + Target corrected dual potential + + """ + + # binary indexing of non-zeros weights + asel = a != 0 + bsel = b != 0 + + # compute dual constraints violation + constraint_violation = alpha0[:, None] + beta0[None, :] - M + + # Compute largest violation per line and columns + aviol = np.max(constraint_violation, 1) + bviol = np.max(constraint_violation, 0) + + # update corrects violation of + alpha_up = -1 * ~asel * np.maximum(aviol, 0) + beta_up = -1 * ~bsel * np.maximum(bviol, 0) + + alpha = alpha0 + alpha_up + beta = beta0 + beta_up + + return center_ot_dual(alpha, beta, a, b) + + +def emd( + a, + b, + M, + numItermax=100000, + log=False, + center_dual=True, + numThreads=1, + check_marginals=True, +): + r"""Solves the Earth Movers distance problem and returns the OT matrix + + + .. math:: + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + s.t. \ \gamma \mathbf{1} = \mathbf{a} + + \gamma^T \mathbf{1} = \mathbf{b} + + \gamma \geq 0 + + where : + + - :math:`\mathbf{M}` is the metric cost matrix + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights + + .. warning:: Note that the :math:`\mathbf{M}` matrix in numpy needs to be a C-order + numpy.array in float64 format. It will be converted if not in this + format + + .. 
note:: This function is backend-compatible and will work on arrays + from all compatible backends. But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. + + .. note:: This function will cast the computed transport plan to the data type + of the provided input with the following priority: :math:`\mathbf{a}`, + then :math:`\mathbf{b}`, then :math:`\mathbf{M}` if marginals are not provided. + Casting to an integer tensor might result in a loss of precision. + If this behaviour is unwanted, please make sure to provide a + floating point input. + + .. note:: An error will be raised if the vectors :math:`\mathbf{a}` and :math:`\mathbf{b}` do not sum to the same value (unless `check_marginals` is False). + + Uses the algorithm proposed in :ref:`[1] <references-emd>`. + + Parameters + ---------- + a : (ns,) array-like, float + Source histogram (uniform weight if empty list) + b : (nt,) array-like, float + Target histogram (uniform weight if empty list) + M : (ns,nt) array-like, float + Loss matrix (c-order array in numpy with type float64) + numItermax : int, optional (default=100000) + The maximum number of iterations before stopping the optimization + algorithm if it has not converged. + log: bool, optional (default=False) + If True, returns a dictionary containing the cost and dual variables. + Otherwise returns only the optimal transportation matrix. + center_dual: boolean, optional (default=True) + If True, centers the dual potential using function + :py:func:`ot.lp.center_ot_dual`. + numThreads: int or "max", optional (default=1, i.e. OpenMP is not used) + If compiled with OpenMP, chooses the number of threads to parallelize. + "max" selects the highest number possible. + check_marginals: bool, optional (default=True) + If True, checks that the marginal masses are equal. If False, skips the + check. 
+ + + Returns + ------- + gamma: array-like, shape (ns, nt) + Optimal transportation matrix for the given + parameters + log: dict, optional + If input log is true, a dictionary containing the + cost and dual variables and exit status + + + Examples + -------- + + Simple example with obvious solution. The function emd accepts lists and + perform automatic conversion to numpy arrays + + >>> import ot + >>> a=[.5,.5] + >>> b=[.5,.5] + >>> M=[[0.,1.],[1.,0.]] + >>> ot.emd(a, b, M) + array([[0.5, 0. ], + [0. , 0.5]]) + + + .. _references-emd: + References + ---------- + .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. (2011, + December). Displacement interpolation using Lagrangian mass transport. + In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. 158). ACM. + + See Also + -------- + ot.bregman.sinkhorn : Entropic regularized OT + ot.optim.cg : General regularized OT + """ + + a, b, M = list_to_array(a, b, M) + nx = get_backend(M, a, b) + + if len(a) != 0: + type_as = a + elif len(b) != 0: + type_as = b + else: + type_as = M + + # if empty array given then use uniform distributions + if len(a) == 0: + a = nx.ones((M.shape[0],), type_as=type_as) / M.shape[0] + if len(b) == 0: + b = nx.ones((M.shape[1],), type_as=type_as) / M.shape[1] + + # convert to numpy + M, a, b = nx.to_numpy(M, a, b) + + # ensure float64 + a = np.asarray(a, dtype=np.float64) + b = np.asarray(b, dtype=np.float64) + M = np.asarray(M, dtype=np.float64, order="C") + + # if empty array given then use uniform distributions + if len(a) == 0: + a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0] + if len(b) == 0: + b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] + + assert ( + a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1] + ), "Dimension mismatch, check dimensions of M with a and b" + + # ensure that same mass + if check_marginals: + np.testing.assert_almost_equal( + a.sum(0), + b.sum(0), + err_msg="a and b vector must have the same sum", + decimal=6, + ) + b 
= b * a.sum() / b.sum() + + asel = a != 0 + bsel = b != 0 + + numThreads = check_number_threads(numThreads) + + G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads) + + if center_dual: + u, v = center_ot_dual(u, v, a, b) + + if np.any(~asel) or np.any(~bsel): + u, v = estimate_dual_null_weights(u, v, a, b, M) + + result_code_string = check_result(result_code) + if not nx.is_floating_point(type_as): + warnings.warn( + "Input histogram consists of integer. The transport plan will be " + "casted accordingly, possibly resulting in a loss of precision. " + "If this behaviour is unwanted, please make sure your input " + "histogram consists of floating point elements.", + stacklevel=2, + ) + if log: + log = {} + log["cost"] = cost + log["u"] = nx.from_numpy(u, type_as=type_as) + log["v"] = nx.from_numpy(v, type_as=type_as) + log["warning"] = result_code_string + log["result_code"] = result_code + return nx.from_numpy(G, type_as=type_as), log + return nx.from_numpy(G, type_as=type_as) + + +def emd2( + a, + b, + M, + processes=1, + numItermax=100000, + log=False, + return_matrix=False, + center_dual=True, + numThreads=1, + check_marginals=True, +): + r"""Solves the Earth Movers distance problem and returns the loss + + .. math:: + \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + s.t. \ \gamma \mathbf{1} = \mathbf{a} + + \gamma^T \mathbf{1} = \mathbf{b} + + \gamma \geq 0 + + where : + + - :math:`\mathbf{M}` is the metric cost matrix + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. + + .. 
note:: This function will cast the computed transport plan and + transportation loss to the data type of the provided input with the + following priority: :math:`\mathbf{a}`, then :math:`\mathbf{b}`, + then :math:`\mathbf{M}` if marginals are not provided. + Casting to an integer tensor might result in a loss of precision. + If this behaviour is unwanted, please make sure to provide a + floating point input. + + .. note:: An error will be raised if the vectors :math:`\mathbf{a}` and :math:`\mathbf{b}` do not sum to the same value (unless `check_marginals` is False). + + Uses the algorithm proposed in :ref:`[1] <references-emd2>`. + + Parameters + ---------- + a : (ns,) array-like, float64 + Source histogram (uniform weight if empty list) + b : (nt,) array-like, float64 + Target histogram (uniform weight if empty list) + M : (ns,nt) array-like, float64 + Loss matrix (for numpy c-order array with type float64) + processes : int, optional (default=1) + Number of processes used for multiple emd computation (deprecated) + numItermax : int, optional (default=100000) + The maximum number of iterations before stopping the optimization + algorithm if it has not converged. + log: boolean, optional (default=False) + If True, returns a dictionary containing dual + variables. Otherwise returns only the optimal transportation cost. + return_matrix: boolean, optional (default=False) + If True, returns the optimal transportation matrix in the log. + center_dual: boolean, optional (default=True) + If True, centers the dual potential using function + :py:func:`ot.lp.center_ot_dual`. + numThreads: int or "max", optional (default=1, i.e. OpenMP is not used) + If compiled with OpenMP, chooses the number of threads to parallelize. + "max" selects the highest number possible. + check_marginals: bool, optional (default=True) + If True, checks that the marginal masses are equal. If False, skips the + check. 
+ + + Returns + ------- + W: float, array-like + Optimal transportation loss for the given parameters + log: dict + If input log is true, a dictionary containing dual + variables and exit status + + + Examples + -------- + + Simple example with obvious solution. The function emd accepts lists and + perform automatic conversion to numpy arrays + + + >>> import ot + >>> a=[.5,.5] + >>> b=[.5,.5] + >>> M=[[0.,1.],[1.,0.]] + >>> ot.emd2(a,b,M) + 0.0 + + + .. _references-emd2: + References + ---------- + .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. + (2011, December). Displacement interpolation using Lagrangian mass + transport. In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. + 158). ACM. + + See Also + -------- + ot.bregman.sinkhorn : Entropic regularized OT + ot.optim.cg : General regularized OT + """ + + a, b, M = list_to_array(a, b, M) + nx = get_backend(M, a, b) + + if len(a) != 0: + type_as = a + elif len(b) != 0: + type_as = b + else: + type_as = M + + # if empty array given then use uniform distributions + if len(a) == 0: + a = nx.ones((M.shape[0],), type_as=type_as) / M.shape[0] + if len(b) == 0: + b = nx.ones((M.shape[1],), type_as=type_as) / M.shape[1] + + # store original tensors + a0, b0, M0 = a, b, M + + # convert to numpy + M, a, b = nx.to_numpy(M, a, b) + + a = np.asarray(a, dtype=np.float64) + b = np.asarray(b, dtype=np.float64) + M = np.asarray(M, dtype=np.float64, order="C") + + assert ( + a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1] + ), "Dimension mismatch, check dimensions of M with a and b" + + # ensure that same mass + if check_marginals: + np.testing.assert_almost_equal( + a.sum(0), + b.sum(0, keepdims=True), + err_msg="a and b vector must have the same sum", + decimal=6, + ) + b = b * a.sum(0) / b.sum(0, keepdims=True) + + asel = a != 0 + + numThreads = check_number_threads(numThreads) + + if log or return_matrix: + + def f(b): + bsel = b != 0 + + G, cost, u, v, result_code = emd_c(a, b, M, numItermax, 
numThreads) + + if center_dual: + u, v = center_ot_dual(u, v, a, b) + + if np.any(~asel) or np.any(~bsel): + u, v = estimate_dual_null_weights(u, v, a, b, M) + + result_code_string = check_result(result_code) + log = {} + if not nx.is_floating_point(type_as): + warnings.warn( + "Input histogram consists of integer. The transport plan will be " + "casted accordingly, possibly resulting in a loss of precision. " + "If this behaviour is unwanted, please make sure your input " + "histogram consists of floating point elements.", + stacklevel=2, + ) + G = nx.from_numpy(G, type_as=type_as) + if return_matrix: + log["G"] = G + log["u"] = nx.from_numpy(u, type_as=type_as) + log["v"] = nx.from_numpy(v, type_as=type_as) + log["warning"] = result_code_string + log["result_code"] = result_code + cost = nx.set_gradients( + nx.from_numpy(cost, type_as=type_as), + (a0, b0, M0), + (log["u"] - nx.mean(log["u"]), log["v"] - nx.mean(log["v"]), G), + ) + return [cost, log] + else: + + def f(b): + bsel = b != 0 + G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads) + + if center_dual: + u, v = center_ot_dual(u, v, a, b) + + if np.any(~asel) or np.any(~bsel): + u, v = estimate_dual_null_weights(u, v, a, b, M) + + if not nx.is_floating_point(type_as): + warnings.warn( + "Input histogram consists of integer. The transport plan will be " + "casted accordingly, possibly resulting in a loss of precision. " + "If this behaviour is unwanted, please make sure your input " + "histogram consists of floating point elements.", + stacklevel=2, + ) + G = nx.from_numpy(G, type_as=type_as) + cost = nx.set_gradients( + nx.from_numpy(cost, type_as=type_as), + (a0, b0, M0), + ( + nx.from_numpy(u - np.mean(u), type_as=type_as), + nx.from_numpy(v - np.mean(v), type_as=type_as), + G, + ), + ) + + check_result(result_code) + return cost + + if len(b.shape) == 1: + return f(b) + nb = b.shape[1] + + if processes > 1: + warnings.warn( + "The 'processes' parameter has been deprecated. 
" + "Multiprocessing should be done outside of POT." + ) + res = list(map(f, [b[:, i].copy() for i in range(nb)])) + + return res diff --git a/ot/lp/cvx.py b/ot/lp/cvx.py index 01f5e5d87..e88d15375 100644 --- a/ot/lp/cvx.py +++ b/ot/lp/cvx.py @@ -1,152 +1,21 @@ # -*- coding: utf-8 -*- """ -LP solvers for optimal transport using cvxopt +(DEPRECATED) LP solvers for optimal transport using cvxopt """ # Author: Remi Flamary # # License: MIT License -import numpy as np -import scipy as sp -import scipy.sparse as sps +import warnings +from ._barycenter_solvers import barycenter -try: - import cvxopt - from cvxopt import solvers, matrix, spmatrix -except ImportError: - cvxopt = False +__all__ = ["barycenter"] -def scipy_sparse_to_spmatrix(A): - """Efficient conversion from scipy sparse matrix to cvxopt sparse matrix""" - coo = A.tocoo() - SP = spmatrix(coo.data.tolist(), coo.row.tolist(), coo.col.tolist(), size=A.shape) - return SP - -def barycenter(A, M, weights=None, verbose=False, log=False, solver="highs-ipm"): - r"""Compute the Wasserstein barycenter of distributions A - - The function solves the following optimization problem [16]: - - .. math:: - \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{1}(\mathbf{a},\mathbf{a}_i) - - where : - - - :math:`W_1(\cdot,\cdot)` is the Wasserstein distance (see ot.emd.sinkhorn) - - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}` - - The linear program is solved using the interior point solver from scipy.optimize. - If cvxopt solver if installed it can use cvxopt - - Note that this problem do not scale well (both in memory and computational time). 
- - Parameters - ---------- - A : np.ndarray (d,n) - n training distributions a_i of size d - M : np.ndarray (d,d) - loss matrix for OT - reg : float - Regularization term >0 - weights : np.ndarray (n,) - Weights of each histogram a_i on the simplex (barycentric coordinates) - verbose : bool, optional - Print information along iterations - log : bool, optional - record log if True - solver : string, optional - the solver used, default 'interior-point' use the lp solver from - scipy.optimize. None, or 'glpk' or 'mosek' use the solver from cvxopt. - - Returns - ------- - a : (d,) ndarray - Wasserstein barycenter - log : dict - log dictionary return only if log==True in parameters - - - References - ---------- - - .. [16] Agueh, M., & Carlier, G. (2011). Barycenters in the Wasserstein space. SIAM Journal on Mathematical Analysis, 43(2), 904-924. - - - """ - - if weights is None: - weights = np.ones(A.shape[1]) / A.shape[1] - else: - assert len(weights) == A.shape[1] - - n_distributions = A.shape[1] - n = A.shape[0] - - n2 = n * n - c = np.zeros((0)) - b_eq1 = np.zeros((0)) - for i in range(n_distributions): - c = np.concatenate((c, M.ravel() * weights[i])) - b_eq1 = np.concatenate((b_eq1, A[:, i])) - c = np.concatenate((c, np.zeros(n))) - - lst_idiag1 = [sps.kron(sps.eye(n), np.ones((1, n))) for i in range(n_distributions)] - # row constraints - A_eq1 = sps.hstack( - (sps.block_diag(lst_idiag1), sps.coo_matrix((n_distributions * n, n))) - ) - - # columns constraints - lst_idiag2 = [] - lst_eye = [] - for i in range(n_distributions): - if i == 0: - lst_idiag2.append(sps.kron(np.ones((1, n)), sps.eye(n))) - lst_eye.append(-sps.eye(n)) - else: - lst_idiag2.append(sps.kron(np.ones((1, n)), sps.eye(n - 1, n))) - lst_eye.append(-sps.eye(n - 1, n)) - - A_eq2 = sps.hstack((sps.block_diag(lst_idiag2), sps.vstack(lst_eye))) - b_eq2 = np.zeros((A_eq2.shape[0])) - - # full problem - A_eq = sps.vstack((A_eq1, A_eq2)) - b_eq = np.concatenate((b_eq1, b_eq2)) - - if not cvxopt or 
solver in ["interior-point", "highs", "highs-ipm", "highs-ds"]: - # cvxopt not installed or interior point - - if solver is None: - solver = "interior-point" - - options = {"disp": verbose} - sol = sp.optimize.linprog( - c, A_eq=A_eq, b_eq=b_eq, method=solver, options=options - ) - x = sol.x - b = x[-n:] - - else: - h = np.zeros((n_distributions * n2 + n)) - G = -sps.eye(n_distributions * n2 + n) - - sol = solvers.lp( - matrix(c), - scipy_sparse_to_spmatrix(G), - matrix(h), - A=scipy_sparse_to_spmatrix(A_eq), - b=matrix(b_eq), - solver=solver, - ) - - x = np.array(sol["x"]) - b = x[-n:].ravel() - - if log: - return b, sol - else: - return b +warnings.warn( + "The module ot.lp.cvx is deprecated and will be removed in future versions. " + "The function `barycenter` was moved to ot.lp._barycenter_solvers and can " + "be imported via ot.lp." +) diff --git a/ot/partial.py b/ot/partial.py index c11ab228a..6b2304e08 100755 --- a/ot/partial.py +++ b/ot/partial.py @@ -126,7 +126,7 @@ def partial_wasserstein_lagrange( nx = get_backend(a, b, M) if nx.sum(a) > 1 + 1e-15 or nx.sum(b) > 1 + 1e-15: # 1e-15 for numerical errors - raise ValueError("Problem infeasible. Check that a and b are in the " "simplex") + raise ValueError("Problem infeasible. Check that a and b are in the simplex") if reg_m is None: reg_m = float(nx.max(M)) + 1 @@ -171,7 +171,7 @@ def partial_wasserstein_lagrange( if log_emd["warning"] is not None: raise ValueError( - "Error in the EMD resolution: try to increase the" " number of dummy points" + "Error in the EMD resolution: try to increase the number of dummy points" ) log_emd["cost"] = nx.sum(gamma * M0) log_emd["u"] = nx.from_numpy(log_emd["u"], type_as=a0) @@ -287,7 +287,7 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs): if m is None: return partial_wasserstein_lagrange(a, b, M, log=log, **kwargs) elif m < 0: - raise ValueError("Problem infeasible. 
Parameter m should be greater" " than 0.") + raise ValueError("Problem infeasible. Parameter m should be greater than 0.") elif m > nx.min(nx.stack((nx.sum(a), nx.sum(b)))): raise ValueError( "Problem infeasible. Parameter m should lower or" @@ -315,7 +315,7 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs): if log_emd["warning"] is not None: raise ValueError( - "Error in the EMD resolution: try to increase the" " number of dummy points" + "Error in the EMD resolution: try to increase the number of dummy points" ) log_emd["partial_w_dist"] = nx.sum(M * gamma) log_emd["u"] = log_emd["u"][: len(a)] @@ -522,7 +522,7 @@ def entropic_partial_wasserstein( if m is None: m = nx.min(nx.stack((nx.sum(a), nx.sum(b)))) * 1.0 if m < 0: - raise ValueError("Problem infeasible. Parameter m should be greater" " than 0.") + raise ValueError("Problem infeasible. Parameter m should be greater than 0.") if m > nx.min(nx.stack((nx.sum(a), nx.sum(b)))): raise ValueError( "Problem infeasible. Parameter m should lower or" @@ -780,7 +780,7 @@ def partial_gromov_wasserstein( if m is None: m = np.min((np.sum(p), np.sum(q))) elif m < 0: - raise ValueError("Problem infeasible. Parameter m should be greater" " than 0.") + raise ValueError("Problem infeasible. Parameter m should be greater than 0.") elif m > np.min((np.sum(p), np.sum(q))): raise ValueError( "Problem infeasible. Parameter m should lower or" @@ -1132,7 +1132,7 @@ def entropic_partial_gromov_wasserstein( if m is None: m = np.min((np.sum(p), np.sum(q))) elif m < 0: - raise ValueError("Problem infeasible. Parameter m should be greater" " than 0.") + raise ValueError("Problem infeasible. Parameter m should be greater than 0.") elif m > np.min((np.sum(p), np.sum(q))): raise ValueError( "Problem infeasible. 
Parameter m should lower or" diff --git a/ot/utils.py b/ot/utils.py index a2d328484..66ff7e354 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -517,7 +517,7 @@ def check_random_state(seed): if isinstance(seed, np.random.RandomState): return seed raise ValueError( - "{} cannot be used to seed a numpy.random.RandomState" " instance".format(seed) + "{} cannot be used to seed a numpy.random.RandomState instance".format(seed) ) @@ -787,7 +787,7 @@ def _update_doc(self, olddoc): def _is_deprecated(func): r"""Helper to check if func is wrapped by our deprecated decorator""" if sys.version_info < (3, 5): - raise NotImplementedError("This is only available for python3.5 " "or above") + raise NotImplementedError("This is only available for python3.5 or above") closures = getattr(func, "__closure__", []) if closures is None: closures = [] @@ -1341,3 +1341,27 @@ def proj_SDP(S, nx=None, vmin=0.0): Q = nx.einsum("ijk,ik->ijk", P, w) # Q[i] = P[i] @ diag(w[i]) # R[i] = Q[i] @ P[i].T return nx.einsum("ijk,ikl->ijl", Q, nx.transpose(P, (0, 2, 1))) + + +def check_number_threads(numThreads): + """Checks whether or not the requested number of threads has a valid value. 
+ + Parameters + ---------- + numThreads : int or str + The requested number of threads, should either be a strictly positive integer or "max" or None + + Returns + ------- + numThreads : int + Corrected number of threads + """ + if (numThreads is None) or ( + isinstance(numThreads, str) and numThreads.lower() == "max" + ): + return -1 + if (not isinstance(numThreads, int)) or numThreads < 1: + raise ValueError( + 'numThreads should either be "max" or a strictly positive integer' + ) + return numThreads diff --git a/test/test_ot.py b/test/test_ot.py index da0ec746e..f84f8773a 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -395,7 +395,7 @@ def test_generalised_free_support_barycenter_backends(nx): np.testing.assert_allclose(Y, nx.to_numpy(Y2)) -@pytest.mark.skipif(not ot.lp.cvx.cvxopt, reason="No cvxopt available") +@pytest.mark.skipif(not ot.lp._barycenter_solvers.cvxopt, reason="No cvxopt available") def test_lp_barycenter_cvxopt(): a1 = np.array([1.0, 0, 0])[:, None] a2 = np.array([0, 0, 1.0])[:, None]