From 590e4d714f73d6596cd14614c93b1c15e7426c51 Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Mon, 20 Jan 2025 11:43:10 +0100 Subject: [PATCH 01/27] ot.lp reorganise to avoid def in __init__ --- CONTRIBUTORS.md | 2 +- RELEASES.md | 2 + ot/lp/__init__.py | 876 +-------------------------------------- ot/lp/barycenter.py | 266 ++++++++++++ ot/lp/network_simplex.py | 612 +++++++++++++++++++++++++++ 5 files changed, 887 insertions(+), 871 deletions(-) create mode 100644 ot/lp/barycenter.py create mode 100644 ot/lp/network_simplex.py diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 39f0b23d4..6f6a72737 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -48,7 +48,7 @@ The contributors to this library are: * [Camille Le Coz](https://www.linkedin.com/in/camille-le-coz-8593b91a1/) (EMD2 debug) * [Eduardo Fernandes Montesuma](https://eddardd.github.io/my-personal-blog/) (Free support sinkhorn barycenter) * [Theo Gnassounou](https://github.com/tgnassou) (OT between Gaussian distributions) -* [Clément Bonet](https://clbonet.github.io) (Wassertstein on circle, Spherical Sliced-Wasserstein) +* [Clément Bonet](https://clbonet.github.io) (Wasserstein on circle, Spherical Sliced-Wasserstein) * [Ronak Mehta](https://ronakrm.github.io) (Efficient Discrete Multi Marginal Optimal Transport Regularization) * [Xizheng Yu](https://github.com/x12hengyu) (Efficient Discrete Multi Marginal Optimal Transport Regularization) * [Sonia Mazelet](https://github.com/SoniaMaz8) (Template based GNN layers) diff --git a/RELEASES.md b/RELEASES.md index 0ddac599b..e29be544e 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -6,6 +6,8 @@ - Implement CG solvers for partial FGW (PR #687) - Added feature `grad=last_step` for `ot.solvers.solve` (PR #693) - Automatic PR labeling and release file update check (PR #704) +- Implement fixed-point solver for OT barycenters with generic cost functions + (generalizes `ot.lp.free_support_barycenter`). (PR #???) 
#### Closed issues - Fixed `ot.mapping` solvers which depended on deprecated `cvxpy` `ECOS` solver (PR #692, Issue #668) diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 2b93e84f3..d11a5ee41 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -8,15 +8,17 @@ # # License: MIT License -import numpy as np -import warnings - from . import cvx from .cvx import barycenter from .dmmot import dmmot_monge_1dgrid_loss, dmmot_monge_1dgrid_optimize +from .network_simplex import emd, emd2 +from .barycenter import ( + free_support_barycenter, + generalized_free_support_barycenter +) # import compiled emd -from .emd_wrap import emd_c, check_result, emd_1d_sorted +from .emd_wrap import emd_1d_sorted from .solver_1d import ( emd_1d, emd2_1d, @@ -26,9 +28,6 @@ semidiscrete_wasserstein2_unif_circle, ) -from ..utils import dist, list_to_array -from ..backend import get_backend - __all__ = [ "emd", "emd2", @@ -46,866 +45,3 @@ "dmmot_monge_1dgrid_loss", "dmmot_monge_1dgrid_optimize", ] - - -def check_number_threads(numThreads): - """Checks whether or not the requested number of threads has a valid value. - - Parameters - ---------- - numThreads : int or str - The requested number of threads, should either be a strictly positive integer or "max" or None - - Returns - ------- - numThreads : int - Corrected number of threads - """ - if (numThreads is None) or ( - isinstance(numThreads, str) and numThreads.lower() == "max" - ): - return -1 - if (not isinstance(numThreads, int)) or numThreads < 1: - raise ValueError( - 'numThreads should either be "max" or a strictly positive integer' - ) - return numThreads - - -def center_ot_dual(alpha0, beta0, a=None, b=None): - r"""Center dual OT potentials w.r.t. their weights - - The main idea of this function is to find unique dual potentials - that ensure some kind of centering/fairness. The main idea is to find dual potentials that lead to the same final objective value for both source and targets (see below for more details). 
It will help having - stability when multiple calling of the OT solver with small changes. - - Basically we add another constraint to the potential that will not - change the objective value but will ensure unicity. The constraint - is the following: - - .. math:: - \alpha^T \mathbf{a} = \beta^T \mathbf{b} - - in addition to the OT problem constraints. - - since :math:`\sum_i a_i=\sum_j b_j` this can be solved by adding/removing - a constant from both :math:`\alpha_0` and :math:`\beta_0`. - - .. math:: - c &= \frac{\beta_0^T \mathbf{b} - \alpha_0^T \mathbf{a}}{\mathbf{1}^T \mathbf{b} + \mathbf{1}^T \mathbf{a}} - - \alpha &= \alpha_0 + c - - \beta &= \beta_0 + c - - Parameters - ---------- - alpha0 : (ns,) numpy.ndarray, float64 - Source dual potential - beta0 : (nt,) numpy.ndarray, float64 - Target dual potential - a : (ns,) numpy.ndarray, float64 - Source histogram (uniform weight if empty list) - b : (nt,) numpy.ndarray, float64 - Target histogram (uniform weight if empty list) - - Returns - ------- - alpha : (ns,) numpy.ndarray, float64 - Source centered dual potential - beta : (nt,) numpy.ndarray, float64 - Target centered dual potential - - """ - # if no weights are provided, use uniform - if a is None: - a = np.ones(alpha0.shape[0]) / alpha0.shape[0] - if b is None: - b = np.ones(beta0.shape[0]) / beta0.shape[0] - - # compute constant that balances the weighted sums of the duals - c = (b.dot(beta0) - a.dot(alpha0)) / (a.sum() + b.sum()) - - # update duals - alpha = alpha0 + c - beta = beta0 - c - - return alpha, beta - - -def estimate_dual_null_weights(alpha0, beta0, a, b, M): - r"""Estimate feasible values for 0-weighted dual potentials - - The feasible values are computed efficiently but rather coarsely. - - .. warning:: - This function is necessary because the C++ solver in `emd_c` - discards all samples in the distributions with - zeros weights. 
This means that while the primal variable (transport - matrix) is exact, the solver only returns feasible dual potentials - on the samples with weights different from zero. - - First we compute the constraints violations: - - .. math:: - \mathbf{V} = \alpha + \beta^T - \mathbf{M} - - Next we compute the max amount of violation per row (:math:`\alpha`) and - columns (:math:`beta`) - - .. math:: - \mathbf{v^a}_i = \max_j \mathbf{V}_{i,j} - - \mathbf{v^b}_j = \max_i \mathbf{V}_{i,j} - - Finally we update the dual potential with 0 weights if a - constraint is violated - - .. math:: - \alpha_i = \alpha_i - \mathbf{v^a}_i \quad \text{ if } \mathbf{a}_i=0 \text{ and } \mathbf{v^a}_i>0 - - \beta_j = \beta_j - \mathbf{v^b}_j \quad \text{ if } \mathbf{b}_j=0 \text{ and } \mathbf{v^b}_j > 0 - - In the end the dual potentials are centered using function - :py:func:`ot.lp.center_ot_dual`. - - Note that all those updates do not change the objective value of the - solution but provide dual potentials that do not violate the constraints. 
- - Parameters - ---------- - alpha0 : (ns,) numpy.ndarray, float64 - Source dual potential - beta0 : (nt,) numpy.ndarray, float64 - Target dual potential - alpha0 : (ns,) numpy.ndarray, float64 - Source dual potential - beta0 : (nt,) numpy.ndarray, float64 - Target dual potential - a : (ns,) numpy.ndarray, float64 - Source distribution (uniform weights if empty list) - b : (nt,) numpy.ndarray, float64 - Target distribution (uniform weights if empty list) - M : (ns,nt) numpy.ndarray, float64 - Loss matrix (c-order array with type float64) - - Returns - ------- - alpha : (ns,) numpy.ndarray, float64 - Source corrected dual potential - beta : (nt,) numpy.ndarray, float64 - Target corrected dual potential - - """ - - # binary indexing of non-zeros weights - asel = a != 0 - bsel = b != 0 - - # compute dual constraints violation - constraint_violation = alpha0[:, None] + beta0[None, :] - M - - # Compute largest violation per line and columns - aviol = np.max(constraint_violation, 1) - bviol = np.max(constraint_violation, 0) - - # update corrects violation of - alpha_up = -1 * ~asel * np.maximum(aviol, 0) - beta_up = -1 * ~bsel * np.maximum(bviol, 0) - - alpha = alpha0 + alpha_up - beta = beta0 + beta_up - - return center_ot_dual(alpha, beta, a, b) - - -def emd( - a, - b, - M, - numItermax=100000, - log=False, - center_dual=True, - numThreads=1, - check_marginals=True, -): - r"""Solves the Earth Movers distance problem and returns the OT matrix - - - .. math:: - \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F - - s.t. \ \gamma \mathbf{1} = \mathbf{a} - - \gamma^T \mathbf{1} = \mathbf{b} - - \gamma \geq 0 - - where : - - - :math:`\mathbf{M}` is the metric cost matrix - - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights - - .. warning:: Note that the :math:`\mathbf{M}` matrix in numpy needs to be a C-order - numpy.array in float64 format. It will be converted if not in this - format - - .. 
note:: This function is backend-compatible and will work on arrays - from all compatible backends. But the algorithm uses the C++ CPU backend - which can lead to copy overhead on GPU arrays. - - .. note:: This function will cast the computed transport plan to the data type - of the provided input with the following priority: :math:`\mathbf{a}`, - then :math:`\mathbf{b}`, then :math:`\mathbf{M}` if marginals are not provided. - Casting to an integer tensor might result in a loss of precision. - If this behaviour is unwanted, please make sure to provide a - floating point input. - - .. note:: An error will be raised if the vectors :math:`\mathbf{a}` and :math:`\mathbf{b}` do not sum to the same value. - - Uses the algorithm proposed in :ref:`[1] `. - - Parameters - ---------- - a : (ns,) array-like, float - Source histogram (uniform weight if empty list) - b : (nt,) array-like, float - Target histogram (uniform weight if empty list) - M : (ns,nt) array-like, float - Loss matrix (c-order array in numpy with type float64) - numItermax : int, optional (default=100000) - The maximum number of iterations before stopping the optimization - algorithm if it has not converged. - log: bool, optional (default=False) - If True, returns a dictionary containing the cost and dual variables. - Otherwise returns only the optimal transportation matrix. - center_dual: boolean, optional (default=True) - If True, centers the dual potential using function - :py:func:`ot.lp.center_ot_dual`. - numThreads: int or "max", optional (default=1, i.e. OpenMP is not used) - If compiled with OpenMP, chooses the number of threads to parallelize. - "max" selects the highest number possible. - check_marginals: bool, optional (default=True) - If True, checks that the marginals mass are equal. If False, skips the - check. 
- - - Returns - ------- - gamma: array-like, shape (ns, nt) - Optimal transportation matrix for the given - parameters - log: dict, optional - If input log is true, a dictionary containing the - cost and dual variables and exit status - - - Examples - -------- - - Simple example with obvious solution. The function emd accepts lists and - perform automatic conversion to numpy arrays - - >>> import ot - >>> a=[.5,.5] - >>> b=[.5,.5] - >>> M=[[0.,1.],[1.,0.]] - >>> ot.emd(a, b, M) - array([[0.5, 0. ], - [0. , 0.5]]) - - - .. _references-emd: - References - ---------- - .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. (2011, - December). Displacement interpolation using Lagrangian mass transport. - In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. 158). ACM. - - See Also - -------- - ot.bregman.sinkhorn : Entropic regularized OT - ot.optim.cg : General regularized OT - """ - - a, b, M = list_to_array(a, b, M) - nx = get_backend(M, a, b) - - if len(a) != 0: - type_as = a - elif len(b) != 0: - type_as = b - else: - type_as = M - - # if empty array given then use uniform distributions - if len(a) == 0: - a = nx.ones((M.shape[0],), type_as=type_as) / M.shape[0] - if len(b) == 0: - b = nx.ones((M.shape[1],), type_as=type_as) / M.shape[1] - - # convert to numpy - M, a, b = nx.to_numpy(M, a, b) - - # ensure float64 - a = np.asarray(a, dtype=np.float64) - b = np.asarray(b, dtype=np.float64) - M = np.asarray(M, dtype=np.float64, order="C") - - # if empty array given then use uniform distributions - if len(a) == 0: - a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0] - if len(b) == 0: - b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] - - assert ( - a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1] - ), "Dimension mismatch, check dimensions of M with a and b" - - # ensure that same mass - if check_marginals: - np.testing.assert_almost_equal( - a.sum(0), - b.sum(0), - err_msg="a and b vector must have the same sum", - decimal=6, - ) - b 
= b * a.sum() / b.sum() - - asel = a != 0 - bsel = b != 0 - - numThreads = check_number_threads(numThreads) - - G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads) - - if center_dual: - u, v = center_ot_dual(u, v, a, b) - - if np.any(~asel) or np.any(~bsel): - u, v = estimate_dual_null_weights(u, v, a, b, M) - - result_code_string = check_result(result_code) - if not nx.is_floating_point(type_as): - warnings.warn( - "Input histogram consists of integer. The transport plan will be " - "casted accordingly, possibly resulting in a loss of precision. " - "If this behaviour is unwanted, please make sure your input " - "histogram consists of floating point elements.", - stacklevel=2, - ) - if log: - log = {} - log["cost"] = cost - log["u"] = nx.from_numpy(u, type_as=type_as) - log["v"] = nx.from_numpy(v, type_as=type_as) - log["warning"] = result_code_string - log["result_code"] = result_code - return nx.from_numpy(G, type_as=type_as), log - return nx.from_numpy(G, type_as=type_as) - - -def emd2( - a, - b, - M, - processes=1, - numItermax=100000, - log=False, - return_matrix=False, - center_dual=True, - numThreads=1, - check_marginals=True, -): - r"""Solves the Earth Movers distance problem and returns the loss - - .. math:: - \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F - - s.t. \ \gamma \mathbf{1} = \mathbf{a} - - \gamma^T \mathbf{1} = \mathbf{b} - - \gamma \geq 0 - - where : - - - :math:`\mathbf{M}` is the metric cost matrix - - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights - - .. note:: This function is backend-compatible and will work on arrays - from all compatible backends. But the algorithm uses the C++ CPU backend - which can lead to copy overhead on GPU arrays. - - .. 
note:: This function will cast the computed transport plan and - transportation loss to the data type of the provided input with the - following priority: :math:`\mathbf{a}`, then :math:`\mathbf{b}`, - then :math:`\mathbf{M}` if marginals are not provided. - Casting to an integer tensor might result in a loss of precision. - If this behaviour is unwanted, please make sure to provide a - floating point input. - - .. note:: An error will be raised if the vectors :math:`\mathbf{a}` and :math:`\mathbf{b}` do not sum to the same value. - - Uses the algorithm proposed in :ref:`[1] `. - - Parameters - ---------- - a : (ns,) array-like, float64 - Source histogram (uniform weight if empty list) - b : (nt,) array-like, float64 - Target histogram (uniform weight if empty list) - M : (ns,nt) array-like, float64 - Loss matrix (for numpy c-order array with type float64) - processes : int, optional (default=1) - Nb of processes used for multiple emd computation (deprecated) - numItermax : int, optional (default=100000) - The maximum number of iterations before stopping the optimization - algorithm if it has not converged. - log: boolean, optional (default=False) - If True, returns a dictionary containing dual - variables. Otherwise returns only the optimal transportation cost. - return_matrix: boolean, optional (default=False) - If True, returns the optimal transportation matrix in the log. - center_dual: boolean, optional (default=True) - If True, centers the dual potential using function - :py:func:`ot.lp.center_ot_dual`. - numThreads: int or "max", optional (default=1, i.e. OpenMP is not used) - If compiled with OpenMP, chooses the number of threads to parallelize. - "max" selects the highest number possible. - check_marginals: bool, optional (default=True) - If True, checks that the marginals mass are equal. If False, skips the - check. 
- - - Returns - ------- - W: float, array-like - Optimal transportation loss for the given parameters - log: dict - If input log is true, a dictionary containing dual - variables and exit status - - - Examples - -------- - - Simple example with obvious solution. The function emd accepts lists and - perform automatic conversion to numpy arrays - - - >>> import ot - >>> a=[.5,.5] - >>> b=[.5,.5] - >>> M=[[0.,1.],[1.,0.]] - >>> ot.emd2(a,b,M) - 0.0 - - - .. _references-emd2: - References - ---------- - .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. - (2011, December). Displacement interpolation using Lagrangian mass - transport. In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. - 158). ACM. - - See Also - -------- - ot.bregman.sinkhorn : Entropic regularized OT - ot.optim.cg : General regularized OT - """ - - a, b, M = list_to_array(a, b, M) - nx = get_backend(M, a, b) - - if len(a) != 0: - type_as = a - elif len(b) != 0: - type_as = b - else: - type_as = M - - # if empty array given then use uniform distributions - if len(a) == 0: - a = nx.ones((M.shape[0],), type_as=type_as) / M.shape[0] - if len(b) == 0: - b = nx.ones((M.shape[1],), type_as=type_as) / M.shape[1] - - # store original tensors - a0, b0, M0 = a, b, M - - # convert to numpy - M, a, b = nx.to_numpy(M, a, b) - - a = np.asarray(a, dtype=np.float64) - b = np.asarray(b, dtype=np.float64) - M = np.asarray(M, dtype=np.float64, order="C") - - assert ( - a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1] - ), "Dimension mismatch, check dimensions of M with a and b" - - # ensure that same mass - if check_marginals: - np.testing.assert_almost_equal( - a.sum(0), - b.sum(0, keepdims=True), - err_msg="a and b vector must have the same sum", - decimal=6, - ) - b = b * a.sum(0) / b.sum(0, keepdims=True) - - asel = a != 0 - - numThreads = check_number_threads(numThreads) - - if log or return_matrix: - - def f(b): - bsel = b != 0 - - G, cost, u, v, result_code = emd_c(a, b, M, numItermax, 
numThreads) - - if center_dual: - u, v = center_ot_dual(u, v, a, b) - - if np.any(~asel) or np.any(~bsel): - u, v = estimate_dual_null_weights(u, v, a, b, M) - - result_code_string = check_result(result_code) - log = {} - if not nx.is_floating_point(type_as): - warnings.warn( - "Input histogram consists of integer. The transport plan will be " - "casted accordingly, possibly resulting in a loss of precision. " - "If this behaviour is unwanted, please make sure your input " - "histogram consists of floating point elements.", - stacklevel=2, - ) - G = nx.from_numpy(G, type_as=type_as) - if return_matrix: - log["G"] = G - log["u"] = nx.from_numpy(u, type_as=type_as) - log["v"] = nx.from_numpy(v, type_as=type_as) - log["warning"] = result_code_string - log["result_code"] = result_code - cost = nx.set_gradients( - nx.from_numpy(cost, type_as=type_as), - (a0, b0, M0), - (log["u"] - nx.mean(log["u"]), log["v"] - nx.mean(log["v"]), G), - ) - return [cost, log] - else: - - def f(b): - bsel = b != 0 - G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads) - - if center_dual: - u, v = center_ot_dual(u, v, a, b) - - if np.any(~asel) or np.any(~bsel): - u, v = estimate_dual_null_weights(u, v, a, b, M) - - if not nx.is_floating_point(type_as): - warnings.warn( - "Input histogram consists of integer. The transport plan will be " - "casted accordingly, possibly resulting in a loss of precision. " - "If this behaviour is unwanted, please make sure your input " - "histogram consists of floating point elements.", - stacklevel=2, - ) - G = nx.from_numpy(G, type_as=type_as) - cost = nx.set_gradients( - nx.from_numpy(cost, type_as=type_as), - (a0, b0, M0), - ( - nx.from_numpy(u - np.mean(u), type_as=type_as), - nx.from_numpy(v - np.mean(v), type_as=type_as), - G, - ), - ) - - check_result(result_code) - return cost - - if len(b.shape) == 1: - return f(b) - nb = b.shape[1] - - if processes > 1: - warnings.warn( - "The 'processes' parameter has been deprecated. 
" - "Multiprocessing should be done outside of POT." - ) - res = list(map(f, [b[:, i].copy() for i in range(nb)])) - - return res - - -def free_support_barycenter( - measures_locations, - measures_weights, - X_init, - b=None, - weights=None, - numItermax=100, - stopThr=1e-7, - verbose=False, - log=None, - numThreads=1, -): - r""" - Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance), formally: - - .. math:: - \min_\mathbf{X} \quad \sum_{i=1}^N w_i W_2^2(\mathbf{b}, \mathbf{X}, \mathbf{a}_i, \mathbf{X}_i) - - where : - - - :math:`w \in \mathbb{(0, 1)}^{N}`'s are the barycenter weights and sum to one - - `measure_weights` denotes the :math:`\mathbf{a}_i \in \mathbb{R}^{k_i}`: empirical measures weights (on simplex) - - `measures_locations` denotes the :math:`\mathbf{X}_i \in \mathbb{R}^{k_i, d}`: empirical measures atoms locations - - :math:`\mathbf{b} \in \mathbb{R}^{k}` is the desired weights vector of the barycenter - - This problem is considered in :ref:`[20] ` (Algorithm 2). - There are two differences with the following codes: - - - we do not optimize over the weights - - we do not do line search for the locations updates, we use i.e. :math:`\theta = 1` in - :ref:`[20] ` (Algorithm 2). This can be seen as a discrete - implementation of the fixed-point algorithm of - :ref:`[43] ` proposed in the continuous setting. 
- - Parameters - ---------- - measures_locations : list of N (k_i,d) array-like - The discrete support of a measure supported on :math:`k_i` locations of a `d`-dimensional space - (:math:`k_i` can be different for each element of the list) - measures_weights : list of N (k_i,) array-like - Numpy arrays where each numpy array has :math:`k_i` non-negatives values summing to one - representing the weights of each discrete input measure - - X_init : (k,d) array-like - Initialization of the support locations (on `k` atoms) of the barycenter - b : (k,) array-like - Initialization of the weights of the barycenter (non-negatives, sum to 1) - weights : (N,) array-like - Initialization of the coefficients of the barycenter (non-negatives, sum to 1) - - numItermax : int, optional - Max number of iterations - stopThr : float, optional - Stop threshold on error (>0) - verbose : bool, optional - Print information along iterations - log : bool, optional - record log if True - numThreads: int or "max", optional (default=1, i.e. OpenMP is not used) - If compiled with OpenMP, chooses the number of threads to parallelize. - "max" selects the highest number possible. - - - Returns - ------- - X : (k,d) array-like - Support locations (on k atoms) of the barycenter - - - .. _references-free-support-barycenter: - - References - ---------- - .. [20] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014. - - .. [43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. 
- - """ - - nx = get_backend(*measures_locations, *measures_weights, X_init) - - iter_count = 0 - - N = len(measures_locations) - k = X_init.shape[0] - d = X_init.shape[1] - if b is None: - b = nx.ones((k,), type_as=X_init) / k - if weights is None: - weights = nx.ones((N,), type_as=X_init) / N - - X = X_init - - log_dict = {} - displacement_square_norms = [] - - displacement_square_norm = stopThr + 1.0 - - while displacement_square_norm > stopThr and iter_count < numItermax: - T_sum = nx.zeros((k, d), type_as=X_init) - - for measure_locations_i, measure_weights_i, weight_i in zip( - measures_locations, measures_weights, weights - ): - M_i = dist(X, measure_locations_i) - T_i = emd(b, measure_weights_i, M_i, numThreads=numThreads) - T_sum = T_sum + weight_i * 1.0 / b[:, None] * nx.dot( - T_i, measure_locations_i - ) - - displacement_square_norm = nx.sum((T_sum - X) ** 2) - if log: - displacement_square_norms.append(displacement_square_norm) - - X = T_sum - - if verbose: - print( - "iteration %d, displacement_square_norm=%f\n", - iter_count, - displacement_square_norm, - ) - - iter_count += 1 - - if log: - log_dict["displacement_square_norms"] = displacement_square_norms - return X, log_dict - else: - return X - - -def generalized_free_support_barycenter( - X_list, - a_list, - P_list, - n_samples_bary, - Y_init=None, - b=None, - weights=None, - numItermax=100, - stopThr=1e-7, - verbose=False, - log=None, - numThreads=1, - eps=0, -): - r""" - Solves the free support generalized Wasserstein barycenter problem: finding a barycenter (a discrete measure with - a fixed amount of points of uniform weights) whose respective projections fit the input measures. - More formally: - - .. 
math:: - \min_\gamma \quad \sum_{i=1}^p w_i W_2^2(\nu_i, \mathbf{P}_i\#\gamma) - - where : - - - :math:`\gamma = \sum_{l=1}^n b_l\delta_{y_l}` is the desired barycenter with each :math:`y_l \in \mathbb{R}^d` - - :math:`\mathbf{b} \in \mathbb{R}^{n}` is the desired weights vector of the barycenter - - The input measures are :math:`\nu_i = \sum_{j=1}^{k_i}a_{i,j}\delta_{x_{i,j}}` - - The :math:`\mathbf{a}_i \in \mathbb{R}^{k_i}` are the respective empirical measures weights (on the simplex) - - The :math:`\mathbf{X}_i \in \mathbb{R}^{k_i, d_i}` are the respective empirical measures atoms locations - - :math:`w = (w_1, \cdots w_p)` are the barycenter coefficients (on the simplex) - - Each :math:`\mathbf{P}_i \in \mathbb{R}^{d, d_i}`, and :math:`P_i\#\nu_i = \sum_{j=1}^{k_i}a_{i,j}\delta_{P_ix_{i,j}}` - - As show by :ref:`[42] `, - this problem can be re-written as a Wasserstein Barycenter problem, - which we solve using the free support method :ref:`[20] ` - (Algorithm 2). - - Parameters - ---------- - X_list : list of p (k_i,d_i) array-like - Discrete supports of the input measures: each consists of :math:`k_i` locations of a `d_i`-dimensional space - (:math:`k_i` can be different for each element of the list) - a_list : list of p (k_i,) array-like - Measure weights: each element is a vector (k_i) on the simplex - P_list : list of p (d_i,d) array-like - Each :math:`P_i` is a linear map :math:`\mathbb{R}^{d} \rightarrow \mathbb{R}^{d_i}` - n_samples_bary : int - Number of barycenter points - Y_init : (n_samples_bary,d) array-like - Initialization of the support locations (on `k` atoms) of the barycenter - b : (n_samples_bary,) array-like - Initialization of the weights of the barycenter measure (on the simplex) - weights : (p,) array-like - Initialization of the coefficients of the barycenter (on the simplex) - numItermax : int, optional - Max number of iterations - stopThr : float, optional - Stop threshold on error (>0) - verbose : bool, optional - Print information 
along iterations - log : bool, optional - record log if True - numThreads: int or "max", optional (default=1, i.e. OpenMP is not used) - If compiled with OpenMP, chooses the number of threads to parallelize. - "max" selects the highest number possible. - eps: Stability coefficient for the change of variable matrix inversion - If the :math:`\mathbf{P}_i^T` matrices don't span :math:`\mathbb{R}^d`, the problem is ill-defined and a matrix - inversion will fail. In this case one may set eps=1e-8 and get a solution anyway (which may make little sense) - - - Returns - ------- - Y : (n_samples_bary,d) array-like - Support locations (on n_samples_bary atoms) of the barycenter - - - .. _references-generalized-free-support-barycenter: - References - ---------- - .. [20] Cuturi, M. and Doucet, A.. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014. - - .. [42] Delon, J., Gozlan, N., and Saint-Dizier, A.. Generalized Wasserstein barycenters between probability measures living on different subspaces. arXiv preprint arXiv:2105.09755, 2021. 
- - """ - nx = get_backend(*X_list, *a_list, *P_list) - d = P_list[0].shape[1] - p = len(X_list) - - if weights is None: - weights = nx.ones(p, type_as=X_list[0]) / p - - # variable change matrix to reduce the problem to a Wasserstein Barycenter (WB) - A = eps * nx.eye( - d, type_as=X_list[0] - ) # if eps nonzero: will force the invertibility of A - for P_i, lambda_i in zip(P_list, weights): - A = A + lambda_i * P_i.T @ P_i - B = nx.inv(nx.sqrtm(A)) - - Z_list = [ - x @ Pi @ B.T for (x, Pi) in zip(X_list, P_list) - ] # change of variables -> (WB) problem on Z - - if Y_init is None: - Y_init = nx.randn(n_samples_bary, d, type_as=X_list[0]) - - if b is None: - b = nx.ones(n_samples_bary, type_as=X_list[0]) / n_samples_bary # not optimized - - out = free_support_barycenter( - Z_list, - a_list, - Y_init, - b, - numItermax=numItermax, - stopThr=stopThr, - verbose=verbose, - log=log, - numThreads=numThreads, - ) - - if log: # unpack - Y, log_dict = out - else: - Y = out - log_dict = None - Y = Y @ B.T # return to the Generalized WB formulation - - if log: - return Y, log_dict - else: - return Y diff --git a/ot/lp/barycenter.py b/ot/lp/barycenter.py new file mode 100644 index 000000000..5468fb4eb --- /dev/null +++ b/ot/lp/barycenter.py @@ -0,0 +1,266 @@ + +def free_support_barycenter( + measures_locations, + measures_weights, + X_init, + b=None, + weights=None, + numItermax=100, + stopThr=1e-7, + verbose=False, + log=None, + numThreads=1, +): + r""" + Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance), formally: + + .. 
math:: + \min_\mathbf{X} \quad \sum_{i=1}^N w_i W_2^2(\mathbf{b}, \mathbf{X}, \mathbf{a}_i, \mathbf{X}_i) + + where : + + - :math:`w \in (0, 1)^{N}` are the barycenter weights and sum to one + - `measures_weights` denotes the :math:`\mathbf{a}_i \in \mathbb{R}^{k_i}`: empirical measures weights (on simplex) + - `measures_locations` denotes the :math:`\mathbf{X}_i \in \mathbb{R}^{k_i, d}`: empirical measures atoms locations + - :math:`\mathbf{b} \in \mathbb{R}^{k}` is the desired weights vector of the barycenter + + This problem is considered in :ref:`[20] <references-free-support-barycenter>` (Algorithm 2). + This implementation differs from it in two ways: + + - we do not optimize over the weights + - we do not do line search for the locations updates, i.e. we use :math:`\theta = 1` in + :ref:`[20] <references-free-support-barycenter>` (Algorithm 2). This can be seen as a discrete + implementation of the fixed-point algorithm of + :ref:`[43] <references-free-support-barycenter>` proposed in the continuous setting. + + Parameters + ---------- + measures_locations : list of N (k_i,d) array-like + The discrete support of a measure supported on :math:`k_i` locations of a `d`-dimensional space + (:math:`k_i` can be different for each element of the list) + measures_weights : list of N (k_i,) array-like + Numpy arrays where each numpy array has :math:`k_i` non-negative values summing to one + representing the weights of each discrete input measure + + X_init : (k,d) array-like + Initialization of the support locations (on `k` atoms) of the barycenter + b : (k,) array-like + Weights of the barycenter atoms (non-negative, sum to 1), kept fixed during the iterations + weights : (N,) array-like + Coefficients of the barycenter (non-negative, sum to 1), one per input measure + + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + numThreads: int or "max", optional (default=1, i.e. 
OpenMP is not used) + If compiled with OpenMP, chooses the number of threads to parallelize. + "max" selects the highest number possible. + + + Returns + ------- + X : (k,d) array-like + Support locations (on k atoms) of the barycenter + + + .. _references-free-support-barycenter: + + References + ---------- + .. [20] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014. + + .. [43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. + + """ + + nx = get_backend(*measures_locations, *measures_weights, X_init) + + iter_count = 0 + + N = len(measures_locations) + k = X_init.shape[0] + d = X_init.shape[1] + if b is None: + b = nx.ones((k,), type_as=X_init) / k + if weights is None: + weights = nx.ones((N,), type_as=X_init) / N + + X = X_init + + log_dict = {} + displacement_square_norms = [] + + displacement_square_norm = stopThr + 1.0 + + while displacement_square_norm > stopThr and iter_count < numItermax: + T_sum = nx.zeros((k, d), type_as=X_init) + + for measure_locations_i, measure_weights_i, weight_i in zip( + measures_locations, measures_weights, weights + ): + M_i = dist(X, measure_locations_i) + T_i = emd(b, measure_weights_i, M_i, numThreads=numThreads) + T_sum = T_sum + weight_i * 1.0 / b[:, None] * nx.dot( + T_i, measure_locations_i + ) + + displacement_square_norm = nx.sum((T_sum - X) ** 2) + if log: + displacement_square_norms.append(displacement_square_norm) + + X = T_sum + + if verbose: + print( + "iteration %d, displacement_square_norm=%f\n", + iter_count, + displacement_square_norm, + ) + + iter_count += 1 + + if log: + log_dict["displacement_square_norms"] = displacement_square_norms + return X, log_dict + else: + return X + + +def generalized_free_support_barycenter( + X_list, + a_list, + P_list, + n_samples_bary, + Y_init=None, + b=None, + weights=None, + 
numItermax=100, + stopThr=1e-7, + verbose=False, + log=None, + numThreads=1, + eps=0, +): + r""" + Solves the free support generalized Wasserstein barycenter problem: finding a barycenter (a discrete measure with + a fixed amount of points of uniform weights) whose respective projections fit the input measures. + More formally: + + .. math:: + \min_\gamma \quad \sum_{i=1}^p w_i W_2^2(\nu_i, \mathbf{P}_i\#\gamma) + + where : + + - :math:`\gamma = \sum_{l=1}^n b_l\delta_{y_l}` is the desired barycenter with each :math:`y_l \in \mathbb{R}^d` + - :math:`\mathbf{b} \in \mathbb{R}^{n}` is the desired weights vector of the barycenter + - The input measures are :math:`\nu_i = \sum_{j=1}^{k_i}a_{i,j}\delta_{x_{i,j}}` + - The :math:`\mathbf{a}_i \in \mathbb{R}^{k_i}` are the respective empirical measures weights (on the simplex) + - The :math:`\mathbf{X}_i \in \mathbb{R}^{k_i, d_i}` are the respective empirical measures atoms locations + - :math:`w = (w_1, \cdots w_p)` are the barycenter coefficients (on the simplex) + - Each :math:`\mathbf{P}_i \in \mathbb{R}^{d, d_i}`, and :math:`P_i\#\nu_i = \sum_{j=1}^{k_i}a_{i,j}\delta_{P_ix_{i,j}}` + + As shown by :ref:`[42] <references-generalized-free-support-barycenter>`, + this problem can be re-written as a Wasserstein Barycenter problem, + which we solve using the free support method :ref:`[20] <references-generalized-free-support-barycenter>` + (Algorithm 2). 
+ + Parameters + ---------- + X_list : list of p (k_i,d_i) array-like + Discrete supports of the input measures: each consists of :math:`k_i` locations of a `d_i`-dimensional space + (:math:`k_i` can be different for each element of the list) + a_list : list of p (k_i,) array-like + Measure weights: each element is a vector (k_i) on the simplex + P_list : list of p (d_i,d) array-like + Each :math:`P_i` is a linear map :math:`\mathbb{R}^{d} \rightarrow \mathbb{R}^{d_i}` + n_samples_bary : int + Number of barycenter points + Y_init : (n_samples_bary,d) array-like + Initialization of the support locations (on `k` atoms) of the barycenter + b : (n_samples_bary,) array-like + Initialization of the weights of the barycenter measure (on the simplex) + weights : (p,) array-like + Initialization of the coefficients of the barycenter (on the simplex) + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + numThreads: int or "max", optional (default=1, i.e. OpenMP is not used) + If compiled with OpenMP, chooses the number of threads to parallelize. + "max" selects the highest number possible. + eps: Stability coefficient for the change of variable matrix inversion + If the :math:`\mathbf{P}_i^T` matrices don't span :math:`\mathbb{R}^d`, the problem is ill-defined and a matrix + inversion will fail. In this case one may set eps=1e-8 and get a solution anyway (which may make little sense) + + + Returns + ------- + Y : (n_samples_bary,d) array-like + Support locations (on n_samples_bary atoms) of the barycenter + + + .. _references-generalized-free-support-barycenter: + References + ---------- + .. [20] Cuturi, M. and Doucet, A.. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014. + + .. [42] Delon, J., Gozlan, N., and Saint-Dizier, A.. 
Generalized Wasserstein barycenters between probability measures living on different subspaces. arXiv preprint arXiv:2105.09755, 2021. + + """ + nx = get_backend(*X_list, *a_list, *P_list) + d = P_list[0].shape[1] + p = len(X_list) + + if weights is None: + weights = nx.ones(p, type_as=X_list[0]) / p + + # variable change matrix to reduce the problem to a Wasserstein Barycenter (WB) + A = eps * nx.eye( + d, type_as=X_list[0] + ) # if eps nonzero: will force the invertibility of A + for P_i, lambda_i in zip(P_list, weights): + A = A + lambda_i * P_i.T @ P_i + B = nx.inv(nx.sqrtm(A)) + + Z_list = [ + x @ Pi @ B.T for (x, Pi) in zip(X_list, P_list) + ] # change of variables -> (WB) problem on Z + + if Y_init is None: + Y_init = nx.randn(n_samples_bary, d, type_as=X_list[0]) + + if b is None: + b = nx.ones(n_samples_bary, type_as=X_list[0]) / n_samples_bary # not optimized + + out = free_support_barycenter( + Z_list, + a_list, + Y_init, + b, + numItermax=numItermax, + stopThr=stopThr, + verbose=verbose, + log=log, + numThreads=numThreads, + ) + + if log: # unpack + Y, log_dict = out + else: + Y = out + log_dict = None + Y = Y @ B.T # return to the Generalized WB formulation + + if log: + return Y, log_dict + else: + return Y diff --git a/ot/lp/network_simplex.py b/ot/lp/network_simplex.py new file mode 100644 index 000000000..0e820fec6 --- /dev/null +++ b/ot/lp/network_simplex.py @@ -0,0 +1,612 @@ +# -*- coding: utf-8 -*- +""" +Solvers for the original linear program OT problem. + +""" + +# Author: Remi Flamary +# +# License: MIT License + +import numpy as np +import warnings + +from ..utils import list_to_array +from ..backend import get_backend +from .emd_wrap import emd_c, check_result + + +def check_number_threads(numThreads): + """Checks whether or not the requested number of threads has a valid value. 
+ + Parameters + ---------- + numThreads : int or str + The requested number of threads, should either be a strictly positive integer or "max" or None + + Returns + ------- + numThreads : int + Corrected number of threads + """ + if (numThreads is None) or ( + isinstance(numThreads, str) and numThreads.lower() == "max" + ): + return -1 + if (not isinstance(numThreads, int)) or numThreads < 1: + raise ValueError( + 'numThreads should either be "max" or a strictly positive integer' + ) + return numThreads + + +def center_ot_dual(alpha0, beta0, a=None, b=None): + r"""Center dual OT potentials w.r.t. their weights + + The main idea of this function is to find unique dual potentials + that ensure some kind of centering/fairness. The main idea is to find dual potentials that lead to the same final objective value for both source and targets (see below for more details). This improves + stability when the OT solver is called multiple times with small changes. + + Basically we add another constraint to the potential that will not + change the objective value but will ensure uniqueness. The constraint + is the following: + + .. math:: + \alpha^T \mathbf{a} = \beta^T \mathbf{b} + + in addition to the OT problem constraints. + + Since :math:`\sum_i a_i=\sum_j b_j` this can be solved by adding/removing + a constant from both :math:`\alpha_0` and :math:`\beta_0`. + + .. 
math:: + c &= \frac{\beta_0^T \mathbf{b} - \alpha_0^T \mathbf{a}}{\mathbf{1}^T \mathbf{b} + \mathbf{1}^T \mathbf{a}} + + \alpha &= \alpha_0 + c + + \beta &= \beta_0 - c + + Parameters + ---------- + alpha0 : (ns,) numpy.ndarray, float64 + Source dual potential + beta0 : (nt,) numpy.ndarray, float64 + Target dual potential + a : (ns,) numpy.ndarray, float64 + Source histogram (uniform weight if None) + b : (nt,) numpy.ndarray, float64 + Target histogram (uniform weight if None) + + Returns + ------- + alpha : (ns,) numpy.ndarray, float64 + Source centered dual potential + beta : (nt,) numpy.ndarray, float64 + Target centered dual potential + + """ + # if no weights are provided, use uniform + if a is None: + a = np.ones(alpha0.shape[0]) / alpha0.shape[0] + if b is None: + b = np.ones(beta0.shape[0]) / beta0.shape[0] + + # compute constant that balances the weighted sums of the duals + c = (b.dot(beta0) - a.dot(alpha0)) / (a.sum() + b.sum()) + + # update duals + alpha = alpha0 + c + beta = beta0 - c + + return alpha, beta + + +def estimate_dual_null_weights(alpha0, beta0, a, b, M): + r"""Estimate feasible values for 0-weighted dual potentials + + The feasible values are computed efficiently but rather coarsely. + + .. warning:: + This function is necessary because the C++ solver in `emd_c` + discards all samples in the distributions with + zero weights. This means that while the primal variable (transport + matrix) is exact, the solver only returns feasible dual potentials + on the samples with weights different from zero. + + First we compute the constraint violations: + + .. math:: + \mathbf{V} = \alpha + \beta^T - \mathbf{M} + + Next we compute the max amount of violation per row (:math:`\alpha`) and + columns (:math:`\beta`) + + .. math:: + \mathbf{v^a}_i = \max_j \mathbf{V}_{i,j} + + \mathbf{v^b}_j = \max_i \mathbf{V}_{i,j} + + Finally we update the dual potential with 0 weights if a + constraint is violated + + .. 
math:: + \alpha_i = \alpha_i - \mathbf{v^a}_i \quad \text{ if } \mathbf{a}_i=0 \text{ and } \mathbf{v^a}_i>0 + + \beta_j = \beta_j - \mathbf{v^b}_j \quad \text{ if } \mathbf{b}_j=0 \text{ and } \mathbf{v^b}_j > 0 + + In the end the dual potentials are centered using function + :py:func:`ot.lp.center_ot_dual`. + + Note that all those updates do not change the objective value of the + solution but provide dual potentials that do not violate the constraints. + + Parameters + ---------- + alpha0 : (ns,) numpy.ndarray, float64 + Source dual potential + beta0 : (nt,) numpy.ndarray, float64 + Target dual potential + alpha0 : (ns,) numpy.ndarray, float64 + Source dual potential + beta0 : (nt,) numpy.ndarray, float64 + Target dual potential + a : (ns,) numpy.ndarray, float64 + Source distribution (uniform weights if empty list) + b : (nt,) numpy.ndarray, float64 + Target distribution (uniform weights if empty list) + M : (ns,nt) numpy.ndarray, float64 + Loss matrix (c-order array with type float64) + + Returns + ------- + alpha : (ns,) numpy.ndarray, float64 + Source corrected dual potential + beta : (nt,) numpy.ndarray, float64 + Target corrected dual potential + + """ + + # binary indexing of non-zeros weights + asel = a != 0 + bsel = b != 0 + + # compute dual constraints violation + constraint_violation = alpha0[:, None] + beta0[None, :] - M + + # Compute largest violation per line and columns + aviol = np.max(constraint_violation, 1) + bviol = np.max(constraint_violation, 0) + + # update corrects violation of + alpha_up = -1 * ~asel * np.maximum(aviol, 0) + beta_up = -1 * ~bsel * np.maximum(bviol, 0) + + alpha = alpha0 + alpha_up + beta = beta0 + beta_up + + return center_ot_dual(alpha, beta, a, b) + + +def emd( + a, + b, + M, + numItermax=100000, + log=False, + center_dual=True, + numThreads=1, + check_marginals=True, +): + r"""Solves the Earth Movers distance problem and returns the OT matrix + + + .. 
math:: + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + s.t. \ \gamma \mathbf{1} = \mathbf{a} + + \gamma^T \mathbf{1} = \mathbf{b} + + \gamma \geq 0 + + where : + + - :math:`\mathbf{M}` is the metric cost matrix + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights + + .. warning:: Note that the :math:`\mathbf{M}` matrix in numpy needs to be a C-order + numpy.array in float64 format. It will be converted if not in this + format + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. + + .. note:: This function will cast the computed transport plan to the data type + of the provided input with the following priority: :math:`\mathbf{a}`, + then :math:`\mathbf{b}`, then :math:`\mathbf{M}` if marginals are not provided. + Casting to an integer tensor might result in a loss of precision. + If this behaviour is unwanted, please make sure to provide a + floating point input. + + .. note:: An error will be raised if the vectors :math:`\mathbf{a}` and :math:`\mathbf{b}` do not sum to the same value. + + Uses the algorithm proposed in :ref:`[1] <references-emd>`. + + Parameters + ---------- + a : (ns,) array-like, float + Source histogram (uniform weight if empty list) + b : (nt,) array-like, float + Target histogram (uniform weight if empty list) + M : (ns,nt) array-like, float + Loss matrix (c-order array in numpy with type float64) + numItermax : int, optional (default=100000) + The maximum number of iterations before stopping the optimization + algorithm if it has not converged. + log: bool, optional (default=False) + If True, returns a dictionary containing the cost and dual variables. + Otherwise returns only the optimal transportation matrix. + center_dual: boolean, optional (default=True) + If True, centers the dual potential using function + :py:func:`ot.lp.center_ot_dual`. 
+ numThreads: int or "max", optional (default=1, i.e. OpenMP is not used) + If compiled with OpenMP, chooses the number of threads to parallelize. + "max" selects the highest number possible. + check_marginals: bool, optional (default=True) + If True, checks that the marginal masses are equal. If False, skips the + check. + + + Returns + ------- + gamma: array-like, shape (ns, nt) + Optimal transportation matrix for the given + parameters + log: dict, optional + If input log is true, a dictionary containing the + cost and dual variables and exit status + + + Examples + -------- + + Simple example with obvious solution. The function emd accepts lists and + performs automatic conversion to numpy arrays + + >>> import ot + >>> a=[.5,.5] + >>> b=[.5,.5] + >>> M=[[0.,1.],[1.,0.]] + >>> ot.emd(a, b, M) + array([[0.5, 0. ], + [0. , 0.5]]) + + + .. _references-emd: + References + ---------- + .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. (2011, + December). Displacement interpolation using Lagrangian mass transport. + In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. 158). ACM. 
+ + See Also + -------- + ot.bregman.sinkhorn : Entropic regularized OT + ot.optim.cg : General regularized OT + """ + + a, b, M = list_to_array(a, b, M) + nx = get_backend(M, a, b) + + if len(a) != 0: + type_as = a + elif len(b) != 0: + type_as = b + else: + type_as = M + + # if empty array given then use uniform distributions + if len(a) == 0: + a = nx.ones((M.shape[0],), type_as=type_as) / M.shape[0] + if len(b) == 0: + b = nx.ones((M.shape[1],), type_as=type_as) / M.shape[1] + + # convert to numpy + M, a, b = nx.to_numpy(M, a, b) + + # ensure float64 + a = np.asarray(a, dtype=np.float64) + b = np.asarray(b, dtype=np.float64) + M = np.asarray(M, dtype=np.float64, order="C") + + # if empty array given then use uniform distributions + if len(a) == 0: + a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0] + if len(b) == 0: + b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] + + assert ( + a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1] + ), "Dimension mismatch, check dimensions of M with a and b" + + # ensure that same mass + if check_marginals: + np.testing.assert_almost_equal( + a.sum(0), + b.sum(0), + err_msg="a and b vector must have the same sum", + decimal=6, + ) + b = b * a.sum() / b.sum() + + asel = a != 0 + bsel = b != 0 + + numThreads = check_number_threads(numThreads) + + G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads) + + if center_dual: + u, v = center_ot_dual(u, v, a, b) + + if np.any(~asel) or np.any(~bsel): + u, v = estimate_dual_null_weights(u, v, a, b, M) + + result_code_string = check_result(result_code) + if not nx.is_floating_point(type_as): + warnings.warn( + "Input histogram consists of integer. The transport plan will be " + "casted accordingly, possibly resulting in a loss of precision. 
" + "If this behaviour is unwanted, please make sure your input " + "histogram consists of floating point elements.", + stacklevel=2, + ) + if log: + log = {} + log["cost"] = cost + log["u"] = nx.from_numpy(u, type_as=type_as) + log["v"] = nx.from_numpy(v, type_as=type_as) + log["warning"] = result_code_string + log["result_code"] = result_code + return nx.from_numpy(G, type_as=type_as), log + return nx.from_numpy(G, type_as=type_as) + + +def emd2( + a, + b, + M, + processes=1, + numItermax=100000, + log=False, + return_matrix=False, + center_dual=True, + numThreads=1, + check_marginals=True, +): + r"""Solves the Earth Movers distance problem and returns the loss + + .. math:: + \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + s.t. \ \gamma \mathbf{1} = \mathbf{a} + + \gamma^T \mathbf{1} = \mathbf{b} + + \gamma \geq 0 + + where : + + - :math:`\mathbf{M}` is the metric cost matrix + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. + + .. note:: This function will cast the computed transport plan and + transportation loss to the data type of the provided input with the + following priority: :math:`\mathbf{a}`, then :math:`\mathbf{b}`, + then :math:`\mathbf{M}` if marginals are not provided. + Casting to an integer tensor might result in a loss of precision. + If this behaviour is unwanted, please make sure to provide a + floating point input. + + .. note:: An error will be raised if the vectors :math:`\mathbf{a}` and :math:`\mathbf{b}` do not sum to the same value. + + Uses the algorithm proposed in :ref:`[1] `. 
+ + Parameters + ---------- + a : (ns,) array-like, float64 + Source histogram (uniform weight if empty list) + b : (nt,) array-like, float64 + Target histogram (uniform weight if empty list) + M : (ns,nt) array-like, float64 + Loss matrix (for numpy c-order array with type float64) + processes : int, optional (default=1) + Nb of processes used for multiple emd computation (deprecated) + numItermax : int, optional (default=100000) + The maximum number of iterations before stopping the optimization + algorithm if it has not converged. + log: boolean, optional (default=False) + If True, returns a dictionary containing dual + variables. Otherwise returns only the optimal transportation cost. + return_matrix: boolean, optional (default=False) + If True, returns the optimal transportation matrix in the log. + center_dual: boolean, optional (default=True) + If True, centers the dual potential using function + :py:func:`ot.lp.center_ot_dual`. + numThreads: int or "max", optional (default=1, i.e. OpenMP is not used) + If compiled with OpenMP, chooses the number of threads to parallelize. + "max" selects the highest number possible. + check_marginals: bool, optional (default=True) + If True, checks that the marginals mass are equal. If False, skips the + check. + + + Returns + ------- + W: float, array-like + Optimal transportation loss for the given parameters + log: dict + If input log is true, a dictionary containing dual + variables and exit status + + + Examples + -------- + + Simple example with obvious solution. The function emd accepts lists and + perform automatic conversion to numpy arrays + + + >>> import ot + >>> a=[.5,.5] + >>> b=[.5,.5] + >>> M=[[0.,1.],[1.,0.]] + >>> ot.emd2(a,b,M) + 0.0 + + + .. _references-emd2: + References + ---------- + .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. + (2011, December). Displacement interpolation using Lagrangian mass + transport. In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. + 158). ACM. 
+ + See Also + -------- + ot.bregman.sinkhorn : Entropic regularized OT + ot.optim.cg : General regularized OT + """ + + a, b, M = list_to_array(a, b, M) + nx = get_backend(M, a, b) + + if len(a) != 0: + type_as = a + elif len(b) != 0: + type_as = b + else: + type_as = M + + # if empty array given then use uniform distributions + if len(a) == 0: + a = nx.ones((M.shape[0],), type_as=type_as) / M.shape[0] + if len(b) == 0: + b = nx.ones((M.shape[1],), type_as=type_as) / M.shape[1] + + # store original tensors + a0, b0, M0 = a, b, M + + # convert to numpy + M, a, b = nx.to_numpy(M, a, b) + + a = np.asarray(a, dtype=np.float64) + b = np.asarray(b, dtype=np.float64) + M = np.asarray(M, dtype=np.float64, order="C") + + assert ( + a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1] + ), "Dimension mismatch, check dimensions of M with a and b" + + # ensure that same mass + if check_marginals: + np.testing.assert_almost_equal( + a.sum(0), + b.sum(0, keepdims=True), + err_msg="a and b vector must have the same sum", + decimal=6, + ) + b = b * a.sum(0) / b.sum(0, keepdims=True) + + asel = a != 0 + + numThreads = check_number_threads(numThreads) + + if log or return_matrix: + + def f(b): + bsel = b != 0 + + G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads) + + if center_dual: + u, v = center_ot_dual(u, v, a, b) + + if np.any(~asel) or np.any(~bsel): + u, v = estimate_dual_null_weights(u, v, a, b, M) + + result_code_string = check_result(result_code) + log = {} + if not nx.is_floating_point(type_as): + warnings.warn( + "Input histogram consists of integer. The transport plan will be " + "casted accordingly, possibly resulting in a loss of precision. 
" + "If this behaviour is unwanted, please make sure your input " + "histogram consists of floating point elements.", + stacklevel=2, + ) + G = nx.from_numpy(G, type_as=type_as) + if return_matrix: + log["G"] = G + log["u"] = nx.from_numpy(u, type_as=type_as) + log["v"] = nx.from_numpy(v, type_as=type_as) + log["warning"] = result_code_string + log["result_code"] = result_code + cost = nx.set_gradients( + nx.from_numpy(cost, type_as=type_as), + (a0, b0, M0), + (log["u"] - nx.mean(log["u"]), log["v"] - nx.mean(log["v"]), G), + ) + return [cost, log] + else: + + def f(b): + bsel = b != 0 + G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads) + + if center_dual: + u, v = center_ot_dual(u, v, a, b) + + if np.any(~asel) or np.any(~bsel): + u, v = estimate_dual_null_weights(u, v, a, b, M) + + if not nx.is_floating_point(type_as): + warnings.warn( + "Input histogram consists of integer. The transport plan will be " + "casted accordingly, possibly resulting in a loss of precision. " + "If this behaviour is unwanted, please make sure your input " + "histogram consists of floating point elements.", + stacklevel=2, + ) + G = nx.from_numpy(G, type_as=type_as) + cost = nx.set_gradients( + nx.from_numpy(cost, type_as=type_as), + (a0, b0, M0), + ( + nx.from_numpy(u - np.mean(u), type_as=type_as), + nx.from_numpy(v - np.mean(v), type_as=type_as), + G, + ), + ) + + check_result(result_code) + return cost + + if len(b.shape) == 1: + return f(b) + nb = b.shape[1] + + if processes > 1: + warnings.warn( + "The 'processes' parameter has been deprecated. " + "Multiprocessing should be done outside of POT." 
+ ) + res = list(map(f, [b[:, i].copy() for i in range(nb)])) + + return res From 109edb7534653c767d490703cfd631aad55a6592 Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Mon, 20 Jan 2025 11:53:38 +0100 Subject: [PATCH 02/27] pr number + enabled pre-commit --- RELEASES.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/RELEASES.md b/RELEASES.md index e29be544e..2eae33215 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -7,7 +7,7 @@ - Added feature `grad=last_step` for `ot.solvers.solve` (PR #693) - Automatic PR labeling and release file update check (PR #704) - Implement fixed-point solver for OT barycenters with generic cost functions - (generalizes `ot.lp.free_support_barycenter`). (PR #???) + (generalizes `ot.lp.free_support_barycenter`). (PR #714) #### Closed issues - Fixed `ot.mapping` solvers which depended on deprecated `cvxpy` `ECOS` solver (PR #692, Issue #668) From 0957904c9d4fb2bdba58a357899077192c1ee52d Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Mon, 20 Jan 2025 11:57:45 +0100 Subject: [PATCH 03/27] added barycenter.py imports --- ot/lp/barycenter.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ot/lp/barycenter.py b/ot/lp/barycenter.py index 5468fb4eb..b1411abe1 100644 --- a/ot/lp/barycenter.py +++ b/ot/lp/barycenter.py @@ -1,3 +1,7 @@ +from ..backend import get_backend +from ..utils import dist +from .network_simplex import emd + def free_support_barycenter( measures_locations, From 818b3e7a278af75ad5a95c50f3a599775193a768 Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Mon, 20 Jan 2025 12:10:21 +0100 Subject: [PATCH 04/27] fixed wrong import in ot.gmm --- ot/gmm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ot/gmm.py b/ot/gmm.py index cde2f8bbd..5c7a4c287 100644 --- a/ot/gmm.py +++ b/ot/gmm.py @@ -12,7 +12,7 @@ from .backend import get_backend from .lp import emd2, emd import numpy as np -from .lp import dist +from .utils import dist from .gaussian import bures_wasserstein_mapping From 
08c2285cafe4a1ee6517e799a043af3251031a6e Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Mon, 20 Jan 2025 12:24:20 +0100 Subject: [PATCH 05/27] ruff fix attempt --- README.md | 7 ++++++- ot/gromov/_partial.py | 6 +++--- ot/gromov/_quantized.py | 6 +++--- ot/lp/__init__.py | 6 +++--- ot/lp/{barycenter.py => barycenter_solvers.py} | 0 ot/partial.py | 14 +++++++------- ot/utils.py | 4 ++-- 7 files changed, 24 insertions(+), 19 deletions(-) rename ot/lp/{barycenter.py => barycenter_solvers.py} (100%) diff --git a/README.md b/README.md index 7bbae9e8a..dd9622d9d 100644 --- a/README.md +++ b/README.md @@ -51,10 +51,11 @@ POT provides the following generic OT solvers (links to examples): * [Efficient Discrete Multi Marginal Optimal Transport Regularization](https://pythonot.github.io/auto_examples/others/plot_demd_gradient_minimize.html) [50]. * [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/)/[Cupy](https://cupy.dev/)/[Tensorflow](https://www.tensorflow.org/) arrays. * [Smooth Strongly Convex Nearest Brenier Potentials](https://pythonot.github.io/auto_examples/others/plot_SSNB.html#sphx-glr-auto-examples-others-plot-ssnb-py) [58], with an extension to bounding potentials using [59]. -* Gaussian Mixture Model OT [69] +* [Gaussian Mixture Model OT](https://pythonot.github.io/auto_examples/others/plot_GMMOT_plan.html#sphx-glr-auto-examples-others-plot-gmmot-plan-py) [69]. * [Co-Optimal Transport](https://pythonot.github.io/auto_examples/others/plot_COOT.html) [49] and [unbalanced Co-Optimal Transport](https://pythonot.github.io/auto_examples/others/plot_learning_weights_with_COOT.html) [71]. * Fused unbalanced Gromov-Wasserstein [70]. +* OT Barycenters for generic transport costs []. POT provides the following Machine Learning related solvers: @@ -391,3 +392,7 @@ Artificial Intelligence. 
[72] Thibault Séjourné, François-Xavier Vialard, and Gabriel Peyré (2021). [The Unbalanced Gromov Wasserstein Distance: Conic Formulation and Relaxation](https://proceedings.neurips.cc/paper/2021/file/4990974d150d0de5e6e15a1454fe6b0f-Paper.pdf). Neural Information Processing Systems (NeurIPS). [73] Séjourné, T., Vialard, F. X., & Peyré, G. (2022). [Faster Unbalanced Optimal Transport: Translation Invariant Sinkhorn and 1-D Frank-Wolfe](https://proceedings.mlr.press/v151/sejourne22a.html). In International Conference on Artificial Intelligence and Statistics (pp. 4995-5021). PMLR. + +[74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). [Computing +Barycentres of Measures for Generic Transport +Costs](https://arxiv.org/abs/2501.04016). arXiv preprint 2501.04016 (2024) \ No newline at end of file diff --git a/ot/gromov/_partial.py b/ot/gromov/_partial.py index c6837f1d3..6672240d0 100644 --- a/ot/gromov/_partial.py +++ b/ot/gromov/_partial.py @@ -185,7 +185,7 @@ def partial_gromov_wasserstein( if m is None: m = min(np.sum(p), np.sum(q)) elif m < 0: - raise ValueError("Problem infeasible. Parameter m should be greater" " than 0.") + raise ValueError("Problem infeasible. Parameter m should be greater than 0.") elif m > min(np.sum(p), np.sum(q)): raise ValueError( "Problem infeasible. Parameter m should lower or" @@ -654,7 +654,7 @@ def partial_fused_gromov_wasserstein( if m is None: m = min(np.sum(p), np.sum(q)) elif m < 0: - raise ValueError("Problem infeasible. Parameter m should be greater" " than 0.") + raise ValueError("Problem infeasible. Parameter m should be greater than 0.") elif m > min(np.sum(p), np.sum(q)): raise ValueError( "Problem infeasible. Parameter m should lower or" @@ -1213,7 +1213,7 @@ def entropic_partial_gromov_wasserstein( if m is None: m = min(nx.sum(p), nx.sum(q)) elif m < 0: - raise ValueError("Problem infeasible. Parameter m should be greater" " than 0.") + raise ValueError("Problem infeasible. 
Parameter m should be greater than 0.") elif m > min(nx.sum(p), nx.sum(q)): raise ValueError( "Problem infeasible. Parameter m should lower or" diff --git a/ot/gromov/_quantized.py b/ot/gromov/_quantized.py index ac2db5d2d..f4a8fafa7 100644 --- a/ot/gromov/_quantized.py +++ b/ot/gromov/_quantized.py @@ -375,7 +375,7 @@ def get_graph_partition( raise ValueError( f""" Unknown `part_method='{part_method}'`. Use one of: - {'random', 'louvain', 'fluid', 'spectral', 'GW', 'FGW'}. + {"random", "louvain", "fluid", "spectral", "GW", "FGW"}. """ ) return nx.from_numpy(part, type_as=C0) @@ -447,7 +447,7 @@ def get_graph_representants(C, part, rep_method="pagerank", random_state=0, nx=N raise ValueError( f""" Unknown `rep_method='{rep_method}'`. Use one of: - {'random', 'pagerank'}. + {"random", "pagerank"}. """ ) @@ -953,7 +953,7 @@ def get_partition_and_representants_samples( else: raise ValueError( f""" - Unknown `method='{method}'`. Use one of: {'random', 'kmeans'} + Unknown `method='{method}'`. 
Use one of: {"random", "kmeans"} """ ) diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index d11a5ee41..b29029243 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -12,9 +12,9 @@ from .cvx import barycenter from .dmmot import dmmot_monge_1dgrid_loss, dmmot_monge_1dgrid_optimize from .network_simplex import emd, emd2 -from .barycenter import ( - free_support_barycenter, - generalized_free_support_barycenter +from .barycenter_solvers import ( + free_support_barycenter, + generalized_free_support_barycenter, ) # import compiled emd diff --git a/ot/lp/barycenter.py b/ot/lp/barycenter_solvers.py similarity index 100% rename from ot/lp/barycenter.py rename to ot/lp/barycenter_solvers.py diff --git a/ot/partial.py b/ot/partial.py index c11ab228a..6b2304e08 100755 --- a/ot/partial.py +++ b/ot/partial.py @@ -126,7 +126,7 @@ def partial_wasserstein_lagrange( nx = get_backend(a, b, M) if nx.sum(a) > 1 + 1e-15 or nx.sum(b) > 1 + 1e-15: # 1e-15 for numerical errors - raise ValueError("Problem infeasible. Check that a and b are in the " "simplex") + raise ValueError("Problem infeasible. Check that a and b are in the simplex") if reg_m is None: reg_m = float(nx.max(M)) + 1 @@ -171,7 +171,7 @@ def partial_wasserstein_lagrange( if log_emd["warning"] is not None: raise ValueError( - "Error in the EMD resolution: try to increase the" " number of dummy points" + "Error in the EMD resolution: try to increase the number of dummy points" ) log_emd["cost"] = nx.sum(gamma * M0) log_emd["u"] = nx.from_numpy(log_emd["u"], type_as=a0) @@ -287,7 +287,7 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs): if m is None: return partial_wasserstein_lagrange(a, b, M, log=log, **kwargs) elif m < 0: - raise ValueError("Problem infeasible. Parameter m should be greater" " than 0.") + raise ValueError("Problem infeasible. Parameter m should be greater than 0.") elif m > nx.min(nx.stack((nx.sum(a), nx.sum(b)))): raise ValueError( "Problem infeasible. 
Parameter m should lower or" @@ -315,7 +315,7 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs): if log_emd["warning"] is not None: raise ValueError( - "Error in the EMD resolution: try to increase the" " number of dummy points" + "Error in the EMD resolution: try to increase the number of dummy points" ) log_emd["partial_w_dist"] = nx.sum(M * gamma) log_emd["u"] = log_emd["u"][: len(a)] @@ -522,7 +522,7 @@ def entropic_partial_wasserstein( if m is None: m = nx.min(nx.stack((nx.sum(a), nx.sum(b)))) * 1.0 if m < 0: - raise ValueError("Problem infeasible. Parameter m should be greater" " than 0.") + raise ValueError("Problem infeasible. Parameter m should be greater than 0.") if m > nx.min(nx.stack((nx.sum(a), nx.sum(b)))): raise ValueError( "Problem infeasible. Parameter m should lower or" @@ -780,7 +780,7 @@ def partial_gromov_wasserstein( if m is None: m = np.min((np.sum(p), np.sum(q))) elif m < 0: - raise ValueError("Problem infeasible. Parameter m should be greater" " than 0.") + raise ValueError("Problem infeasible. Parameter m should be greater than 0.") elif m > np.min((np.sum(p), np.sum(q))): raise ValueError( "Problem infeasible. Parameter m should lower or" @@ -1132,7 +1132,7 @@ def entropic_partial_gromov_wasserstein( if m is None: m = np.min((np.sum(p), np.sum(q))) elif m < 0: - raise ValueError("Problem infeasible. Parameter m should be greater" " than 0.") + raise ValueError("Problem infeasible. Parameter m should be greater than 0.") elif m > np.min((np.sum(p), np.sum(q))): raise ValueError( "Problem infeasible. 
Parameter m should lower or" diff --git a/ot/utils.py b/ot/utils.py index a2d328484..42673ecd6 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -517,7 +517,7 @@ def check_random_state(seed): if isinstance(seed, np.random.RandomState): return seed raise ValueError( - "{} cannot be used to seed a numpy.random.RandomState" " instance".format(seed) + "{} cannot be used to seed a numpy.random.RandomState instance".format(seed) ) @@ -787,7 +787,7 @@ def _update_doc(self, olddoc): def _is_deprecated(func): r"""Helper to check if func is wrapped by our deprecated decorator""" if sys.version_info < (3, 5): - raise NotImplementedError("This is only available for python3.5 " "or above") + raise NotImplementedError("This is only available for python3.5 or above") closures = getattr(func, "__closure__", []) if closures is None: closures = [] From f26851586a7c03d4707a8ed710b8047f9acfc78c Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Mon, 20 Jan 2025 13:33:23 +0100 Subject: [PATCH 06/27] removed ot bar contribs -> only o.lp reorganisation in this PR --- README.md | 5 ----- RELEASES.md | 3 +-- 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/README.md b/README.md index dd9622d9d..f64db8f56 100644 --- a/README.md +++ b/README.md @@ -55,7 +55,6 @@ POT provides the following generic OT solvers (links to examples): * [Co-Optimal Transport](https://pythonot.github.io/auto_examples/others/plot_COOT.html) [49] and [unbalanced Co-Optimal Transport](https://pythonot.github.io/auto_examples/others/plot_learning_weights_with_COOT.html) [71]. * Fused unbalanced Gromov-Wasserstein [70]. -* OT Barycenters for generic transport costs []. POT provides the following Machine Learning related solvers: @@ -392,7 +391,3 @@ Artificial Intelligence. [72] Thibault Séjourné, François-Xavier Vialard, and Gabriel Peyré (2021). [The Unbalanced Gromov Wasserstein Distance: Conic Formulation and Relaxation](https://proceedings.neurips.cc/paper/2021/file/4990974d150d0de5e6e15a1454fe6b0f-Paper.pdf). 
Neural Information Processing Systems (NeurIPS). [73] Séjourné, T., Vialard, F. X., & Peyré, G. (2022). [Faster Unbalanced Optimal Transport: Translation Invariant Sinkhorn and 1-D Frank-Wolfe](https://proceedings.mlr.press/v151/sejourne22a.html). In International Conference on Artificial Intelligence and Statistics (pp. 4995-5021). PMLR. - -[74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). [Computing -Barycentres of Measures for Generic Transport -Costs](https://arxiv.org/abs/2501.04016). arXiv preprint 2501.04016 (2024) \ No newline at end of file diff --git a/RELEASES.md b/RELEASES.md index 2eae33215..1550b479f 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -6,8 +6,7 @@ - Implement CG solvers for partial FGW (PR #687) - Added feature `grad=last_step` for `ot.solvers.solve` (PR #693) - Automatic PR labeling and release file update check (PR #704) -- Implement fixed-point solver for OT barycenters with generic cost functions - (generalizes `ot.lp.free_support_barycenter`). (PR #714) +- Moved functions from `ot/lp/__init__.py` to separate modules. 
(PR #714) #### Closed issues - Fixed `ot.mapping` solvers which depended on deprecated `cvxpy` `ECOS` solver (PR #692, Issue #668) From 8f24cb95f28e8c1e3f80cb6e72e768f1b45cc2dc Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Mon, 20 Jan 2025 13:39:19 +0100 Subject: [PATCH 07/27] add check_number_threads to ot/lp/__init__.py __all__ --- ot/lp/__init__.py | 2 ++ ot/lp/network_simplex.py | 26 +------------------------- ot/utils.py | 24 ++++++++++++++++++++++++ 3 files changed, 27 insertions(+), 25 deletions(-) diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index b29029243..548200123 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -16,6 +16,7 @@ free_support_barycenter, generalized_free_support_barycenter, ) +from ..utils import check_number_threads # import compiled emd from .emd_wrap import emd_1d_sorted @@ -44,4 +45,5 @@ "semidiscrete_wasserstein2_unif_circle", "dmmot_monge_1dgrid_loss", "dmmot_monge_1dgrid_optimize", + "check_number_threads", ] diff --git a/ot/lp/network_simplex.py b/ot/lp/network_simplex.py index 0e820fec6..492e4c7ac 100644 --- a/ot/lp/network_simplex.py +++ b/ot/lp/network_simplex.py @@ -11,35 +11,11 @@ import numpy as np import warnings -from ..utils import list_to_array +from ..utils import list_to_array, check_number_threads from ..backend import get_backend from .emd_wrap import emd_c, check_result -def check_number_threads(numThreads): - """Checks whether or not the requested number of threads has a valid value. 
- - Parameters - ---------- - numThreads : int or str - The requested number of threads, should either be a strictly positive integer or "max" or None - - Returns - ------- - numThreads : int - Corrected number of threads - """ - if (numThreads is None) or ( - isinstance(numThreads, str) and numThreads.lower() == "max" - ): - return -1 - if (not isinstance(numThreads, int)) or numThreads < 1: - raise ValueError( - 'numThreads should either be "max" or a strictly positive integer' - ) - return numThreads - - def center_ot_dual(alpha0, beta0, a=None, b=None): r"""Center dual OT potentials w.r.t. their weights diff --git a/ot/utils.py b/ot/utils.py index 42673ecd6..66ff7e354 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -1341,3 +1341,27 @@ def proj_SDP(S, nx=None, vmin=0.0): Q = nx.einsum("ijk,ik->ijk", P, w) # Q[i] = P[i] @ diag(w[i]) # R[i] = Q[i] @ P[i].T return nx.einsum("ijk,ikl->ijl", Q, nx.transpose(P, (0, 2, 1))) + + +def check_number_threads(numThreads): + """Checks whether or not the requested number of threads has a valid value. 
+ + Parameters + ---------- + numThreads : int or str + The requested number of threads, should either be a strictly positive integer or "max" or None + + Returns + ------- + numThreads : int + Corrected number of threads + """ + if (numThreads is None) or ( + isinstance(numThreads, str) and numThreads.lower() == "max" + ): + return -1 + if (not isinstance(numThreads, int)) or numThreads < 1: + raise ValueError( + 'numThreads should either be "max" or a strictly positive integer' + ) + return numThreads From 3e3b4445f4c1edf588c8d58bb218ccadd5ad0111 Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Mon, 20 Jan 2025 13:41:29 +0100 Subject: [PATCH 08/27] update releases --- RELEASES.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/RELEASES.md b/RELEASES.md index 1550b479f..7d138c9c6 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -6,7 +6,7 @@ - Implement CG solvers for partial FGW (PR #687) - Added feature `grad=last_step` for `ot.solvers.solve` (PR #693) - Automatic PR labeling and release file update check (PR #704) -- Moved functions from `ot/lp/__init__.py` to separate modules. (PR #714) +- Reorganize sub-module `ot/lp/__init__.py` into separate files. 
(PR #714) (PR #714) #### Closed issues - Fixed `ot.mapping` solvers which depended on deprecated `cvxpy` `ECOS` solver (PR #692, Issue #668) From 566a0fc1e3171cd16cd22b58b926a58cc3c9a2cb Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Mon, 20 Jan 2025 15:07:46 +0100 Subject: [PATCH 09/27] made barycenter_solvers and network_simplex hidden + deprecated ot.lp.cvx --- RELEASES.md | 2 +- ot/lp/__init__.py | 6 +- ...nter_solvers.py => _barycenter_solvers.py} | 156 +++++++++++++++++- ...network_simplex.py => _network_simplex.py} | 0 ot/lp/cvx.py | 148 +---------------- 5 files changed, 163 insertions(+), 149 deletions(-) rename ot/lp/{barycenter_solvers.py => _barycenter_solvers.py} (69%) rename ot/lp/{network_simplex.py => _network_simplex.py} (100%) diff --git a/RELEASES.md b/RELEASES.md index 7d138c9c6..a0474eda0 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -6,7 +6,7 @@ - Implement CG solvers for partial FGW (PR #687) - Added feature `grad=last_step` for `ot.solvers.solve` (PR #693) - Automatic PR labeling and release file update check (PR #704) -- Reorganize sub-module `ot/lp/__init__.py` into separate files. (PR #714) (PR #714) +- Reorganize sub-module `ot/lp/__init__.py` into separate files (PR #714) #### Closed issues - Fixed `ot.mapping` solvers which depended on deprecated `cvxpy` `ECOS` solver (PR #692, Issue #668) diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 548200123..e3cfce0fd 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -9,10 +9,10 @@ # License: MIT License from . 
import cvx -from .cvx import barycenter from .dmmot import dmmot_monge_1dgrid_loss, dmmot_monge_1dgrid_optimize -from .network_simplex import emd, emd2 -from .barycenter_solvers import ( +from ._network_simplex import emd, emd2 +from ._barycenter_solvers import ( + barycenter, free_support_barycenter, generalized_free_support_barycenter, ) diff --git a/ot/lp/barycenter_solvers.py b/ot/lp/_barycenter_solvers.py similarity index 69% rename from ot/lp/barycenter_solvers.py rename to ot/lp/_barycenter_solvers.py index b1411abe1..8b64214d9 100644 --- a/ot/lp/barycenter_solvers.py +++ b/ot/lp/_barycenter_solvers.py @@ -1,6 +1,160 @@ +# -*- coding: utf-8 -*- +""" +OT Barycenter Solvers +""" + +# Author: Remi Flamary +# Eloi Tanguy +# +# License: MIT License + from ..backend import get_backend from ..utils import dist -from .network_simplex import emd +from ._network_simplex import emd + +import numpy as np +import scipy as sp +import scipy.sparse as sps + +try: + import cvxopt # for cvxopt barycenter solver + from cvxopt import solvers, matrix, spmatrix +except ImportError: + cvxopt = False + + +def scipy_sparse_to_spmatrix(A): + """Efficient conversion from scipy sparse matrix to cvxopt sparse matrix""" + coo = A.tocoo() + SP = spmatrix(coo.data.tolist(), coo.row.tolist(), coo.col.tolist(), size=A.shape) + return SP + + +def barycenter(A, M, weights=None, verbose=False, log=False, solver="highs-ipm"): + r"""Compute the Wasserstein barycenter of distributions A + + The function solves the following optimization problem [16]: + + .. math:: + \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{1}(\mathbf{a},\mathbf{a}_i) + + where : + + - :math:`W_1(\cdot,\cdot)` is the Wasserstein distance (see ot.emd.sinkhorn) + - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}` + + The linear program is solved using the interior point solver from scipy.optimize. 
+ If cvxopt solver if installed it can use cvxopt + + Note that this problem do not scale well (both in memory and computational time). + + Parameters + ---------- + A : np.ndarray (d,n) + n training distributions a_i of size d + M : np.ndarray (d,d) + loss matrix for OT + reg : float + Regularization term >0 + weights : np.ndarray (n,) + Weights of each histogram a_i on the simplex (barycentric coordinates) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + solver : string, optional + the solver used, default 'interior-point' use the lp solver from + scipy.optimize. None, or 'glpk' or 'mosek' use the solver from cvxopt. + + Returns + ------- + a : (d,) ndarray + Wasserstein barycenter + log : dict + log dictionary return only if log==True in parameters + + + References + ---------- + + .. [16] Agueh, M., & Carlier, G. (2011). Barycenters in the Wasserstein space. SIAM Journal on Mathematical Analysis, 43(2), 904-924. + + + """ + + if weights is None: + weights = np.ones(A.shape[1]) / A.shape[1] + else: + assert len(weights) == A.shape[1] + + n_distributions = A.shape[1] + n = A.shape[0] + + n2 = n * n + c = np.zeros((0)) + b_eq1 = np.zeros((0)) + for i in range(n_distributions): + c = np.concatenate((c, M.ravel() * weights[i])) + b_eq1 = np.concatenate((b_eq1, A[:, i])) + c = np.concatenate((c, np.zeros(n))) + + lst_idiag1 = [sps.kron(sps.eye(n), np.ones((1, n))) for i in range(n_distributions)] + # row constraints + A_eq1 = sps.hstack( + (sps.block_diag(lst_idiag1), sps.coo_matrix((n_distributions * n, n))) + ) + + # columns constraints + lst_idiag2 = [] + lst_eye = [] + for i in range(n_distributions): + if i == 0: + lst_idiag2.append(sps.kron(np.ones((1, n)), sps.eye(n))) + lst_eye.append(-sps.eye(n)) + else: + lst_idiag2.append(sps.kron(np.ones((1, n)), sps.eye(n - 1, n))) + lst_eye.append(-sps.eye(n - 1, n)) + + A_eq2 = sps.hstack((sps.block_diag(lst_idiag2), sps.vstack(lst_eye))) + b_eq2 = 
np.zeros((A_eq2.shape[0])) + + # full problem + A_eq = sps.vstack((A_eq1, A_eq2)) + b_eq = np.concatenate((b_eq1, b_eq2)) + + if not cvxopt or solver in ["interior-point", "highs", "highs-ipm", "highs-ds"]: + # cvxopt not installed or interior point + + if solver is None: + solver = "interior-point" + + options = {"disp": verbose} + sol = sp.optimize.linprog( + c, A_eq=A_eq, b_eq=b_eq, method=solver, options=options + ) + x = sol.x + b = x[-n:] + + else: + h = np.zeros((n_distributions * n2 + n)) + G = -sps.eye(n_distributions * n2 + n) + + sol = solvers.lp( + matrix(c), + scipy_sparse_to_spmatrix(G), + matrix(h), + A=scipy_sparse_to_spmatrix(A_eq), + b=matrix(b_eq), + solver=solver, + ) + + x = np.array(sol["x"]) + b = x[-n:].ravel() + + if log: + return b, sol + else: + return b def free_support_barycenter( diff --git a/ot/lp/network_simplex.py b/ot/lp/_network_simplex.py similarity index 100% rename from ot/lp/network_simplex.py rename to ot/lp/_network_simplex.py diff --git a/ot/lp/cvx.py b/ot/lp/cvx.py index 01f5e5d87..b2269b8b4 100644 --- a/ot/lp/cvx.py +++ b/ot/lp/cvx.py @@ -1,152 +1,12 @@ # -*- coding: utf-8 -*- """ -LP solvers for optimal transport using cvxopt +(DEPRECATED) LP solvers for optimal transport using cvxopt """ # Author: Remi Flamary # # License: MIT License -import numpy as np -import scipy as sp -import scipy.sparse as sps - -try: - import cvxopt - from cvxopt import solvers, matrix, spmatrix -except ImportError: - cvxopt = False - - -def scipy_sparse_to_spmatrix(A): - """Efficient conversion from scipy sparse matrix to cvxopt sparse matrix""" - coo = A.tocoo() - SP = spmatrix(coo.data.tolist(), coo.row.tolist(), coo.col.tolist(), size=A.shape) - return SP - - -def barycenter(A, M, weights=None, verbose=False, log=False, solver="highs-ipm"): - r"""Compute the Wasserstein barycenter of distributions A - - The function solves the following optimization problem [16]: - - .. 
math:: - \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{1}(\mathbf{a},\mathbf{a}_i) - - where : - - - :math:`W_1(\cdot,\cdot)` is the Wasserstein distance (see ot.emd.sinkhorn) - - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}` - - The linear program is solved using the interior point solver from scipy.optimize. - If cvxopt solver if installed it can use cvxopt - - Note that this problem do not scale well (both in memory and computational time). - - Parameters - ---------- - A : np.ndarray (d,n) - n training distributions a_i of size d - M : np.ndarray (d,d) - loss matrix for OT - reg : float - Regularization term >0 - weights : np.ndarray (n,) - Weights of each histogram a_i on the simplex (barycentric coordinates) - verbose : bool, optional - Print information along iterations - log : bool, optional - record log if True - solver : string, optional - the solver used, default 'interior-point' use the lp solver from - scipy.optimize. None, or 'glpk' or 'mosek' use the solver from cvxopt. - - Returns - ------- - a : (d,) ndarray - Wasserstein barycenter - log : dict - log dictionary return only if log==True in parameters - - - References - ---------- - - .. [16] Agueh, M., & Carlier, G. (2011). Barycenters in the Wasserstein space. SIAM Journal on Mathematical Analysis, 43(2), 904-924. 
- - - """ - - if weights is None: - weights = np.ones(A.shape[1]) / A.shape[1] - else: - assert len(weights) == A.shape[1] - - n_distributions = A.shape[1] - n = A.shape[0] - - n2 = n * n - c = np.zeros((0)) - b_eq1 = np.zeros((0)) - for i in range(n_distributions): - c = np.concatenate((c, M.ravel() * weights[i])) - b_eq1 = np.concatenate((b_eq1, A[:, i])) - c = np.concatenate((c, np.zeros(n))) - - lst_idiag1 = [sps.kron(sps.eye(n), np.ones((1, n))) for i in range(n_distributions)] - # row constraints - A_eq1 = sps.hstack( - (sps.block_diag(lst_idiag1), sps.coo_matrix((n_distributions * n, n))) - ) - - # columns constraints - lst_idiag2 = [] - lst_eye = [] - for i in range(n_distributions): - if i == 0: - lst_idiag2.append(sps.kron(np.ones((1, n)), sps.eye(n))) - lst_eye.append(-sps.eye(n)) - else: - lst_idiag2.append(sps.kron(np.ones((1, n)), sps.eye(n - 1, n))) - lst_eye.append(-sps.eye(n - 1, n)) - - A_eq2 = sps.hstack((sps.block_diag(lst_idiag2), sps.vstack(lst_eye))) - b_eq2 = np.zeros((A_eq2.shape[0])) - - # full problem - A_eq = sps.vstack((A_eq1, A_eq2)) - b_eq = np.concatenate((b_eq1, b_eq2)) - - if not cvxopt or solver in ["interior-point", "highs", "highs-ipm", "highs-ds"]: - # cvxopt not installed or interior point - - if solver is None: - solver = "interior-point" - - options = {"disp": verbose} - sol = sp.optimize.linprog( - c, A_eq=A_eq, b_eq=b_eq, method=solver, options=options - ) - x = sol.x - b = x[-n:] - - else: - h = np.zeros((n_distributions * n2 + n)) - G = -sps.eye(n_distributions * n2 + n) - - sol = solvers.lp( - matrix(c), - scipy_sparse_to_spmatrix(G), - matrix(h), - A=scipy_sparse_to_spmatrix(A_eq), - b=matrix(b_eq), - solver=solver, - ) - - x = np.array(sol["x"]) - b = x[-n:].ravel() - - if log: - return b, sol - else: - return b +print( + "The module ot.lp.cvx is deprecated and will be removed in future versions. The function `barycenter` was moved to ot.lp._barycenter_solvers and can be importer via ot.lp." 
+) From 5c35d586ef1b6adf3b5b7d77edb8d90a504904bd Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Mon, 20 Jan 2025 15:10:34 +0100 Subject: [PATCH 10/27] fix ref to lp.cvx in test --- test/test_ot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_ot.py b/test/test_ot.py index da0ec746e..f84f8773a 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -395,7 +395,7 @@ def test_generalised_free_support_barycenter_backends(nx): np.testing.assert_allclose(Y, nx.to_numpy(Y2)) -@pytest.mark.skipif(not ot.lp.cvx.cvxopt, reason="No cvxopt available") +@pytest.mark.skipif(not ot.lp._barycenter_solvers.cvxopt, reason="No cvxopt available") def test_lp_barycenter_cvxopt(): a1 = np.array([1.0, 0, 0])[:, None] a2 = np.array([0, 0, 1.0])[:, None] From 8ffb06190ce085af685676ac3072335ef5364680 Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Mon, 20 Jan 2025 15:23:50 +0100 Subject: [PATCH 11/27] lp.cvx now imports barycenter and gives a warnings.warning --- ot/lp/cvx.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/ot/lp/cvx.py b/ot/lp/cvx.py index b2269b8b4..4f7846341 100644 --- a/ot/lp/cvx.py +++ b/ot/lp/cvx.py @@ -7,6 +7,11 @@ # # License: MIT License -print( - "The module ot.lp.cvx is deprecated and will be removed in future versions. The function `barycenter` was moved to ot.lp._barycenter_solvers and can be importer via ot.lp." +import warnings + + +warnings.warn( + "The module ot.lp.cvx is deprecated and will be removed in future versions." + "The function `barycenter` was moved to ot.lp._barycenter_solvers and can" + "be importer via ot.lp." 
) From 26748eb0602305ed5d115ad1d7a3b43f352ff06c Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Mon, 20 Jan 2025 15:28:04 +0100 Subject: [PATCH 12/27] cvx import barycenter --- ot/lp/cvx.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ot/lp/cvx.py b/ot/lp/cvx.py index 4f7846341..e88d15375 100644 --- a/ot/lp/cvx.py +++ b/ot/lp/cvx.py @@ -8,6 +8,10 @@ # License: MIT License import warnings +from ._barycenter_solvers import barycenter + + +__all__ = ["barycenter"] warnings.warn( From 081e4eb14285a50f23891cb398472d42da70e724 Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Mon, 20 Jan 2025 16:52:19 +0100 Subject: [PATCH 13/27] added fixed-point barycenter function to ot.lp._barycenter_solvers_ --- CONTRIBUTORS.md | 2 +- README.md | 4 ++ RELEASES.md | 2 + ot/lp/_barycenter_solvers.py | 87 ++++++++++++++++++++++++++++++++++++ 4 files changed, 94 insertions(+), 1 deletion(-) diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 6f6a72737..fc1ecc313 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -44,7 +44,7 @@ The contributors to this library are: * [Cédric Vincent-Cuaz](https://github.com/cedricvincentcuaz) (Graph Dictionary Learning, FGW, semi-relaxed FGW, quantized FGW, partial FGW) * [Eloi Tanguy](https://github.com/eloitanguy) (Generalized Wasserstein - Barycenters, GMMOT) + Barycenters, GMMOT, Barycenters for General Transport Costs) * [Camille Le Coz](https://www.linkedin.com/in/camille-le-coz-8593b91a1/) (EMD2 debug) * [Eduardo Fernandes Montesuma](https://eddardd.github.io/my-personal-blog/) (Free support sinkhorn barycenter) * [Theo Gnassounou](https://github.com/tgnassou) (OT between Gaussian distributions) diff --git a/README.md b/README.md index f64db8f56..9a8e5b371 100644 --- a/README.md +++ b/README.md @@ -391,3 +391,7 @@ Artificial Intelligence. [72] Thibault Séjourné, François-Xavier Vialard, and Gabriel Peyré (2021). 
[The Unbalanced Gromov Wasserstein Distance: Conic Formulation and Relaxation](https://proceedings.neurips.cc/paper/2021/file/4990974d150d0de5e6e15a1454fe6b0f-Paper.pdf). Neural Information Processing Systems (NeurIPS). [73] Séjourné, T., Vialard, F. X., & Peyré, G. (2022). [Faster Unbalanced Optimal Transport: Translation Invariant Sinkhorn and 1-D Frank-Wolfe](https://proceedings.mlr.press/v151/sejourne22a.html). In International Conference on Artificial Intelligence and Statistics (pp. 4995-5021). PMLR. + +[74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). [Computing +Barycentres of Measures for Generic Transport +Costs](https://arxiv.org/abs/2501.04016). arXiv preprint 2501.04016 (2024) diff --git a/RELEASES.md b/RELEASES.md index a0474eda0..2a6867484 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -7,6 +7,8 @@ - Added feature `grad=last_step` for `ot.solvers.solve` (PR #693) - Automatic PR labeling and release file update check (PR #704) - Reorganize sub-module `ot/lp/__init__.py` into separate files (PR #714) +- Implement fixed-point solver for OT barycenters with generic cost functions + (generalizes `ot.lp.free_support_barycenter`). (PR #715) #### Closed issues - Fixed `ot.mapping` solvers which depended on deprecated `cvxpy` `ECOS` solver (PR #692, Issue #668) diff --git a/ot/lp/_barycenter_solvers.py b/ot/lp/_barycenter_solvers.py index 8b64214d9..7e801caa6 100644 --- a/ot/lp/_barycenter_solvers.py +++ b/ot/lp/_barycenter_solvers.py @@ -422,3 +422,90 @@ def generalized_free_support_barycenter( return Y, log_dict else: return Y + + +class StoppingCriterionReached(Exception): + pass + + +def solve_OT_barycenter_fixed_point( + X_init, + Y_list, + b_list, + cost_list, + B, + max_its=5, + stop_threshold=1e-5, + log=False, +): + """ + Solves the OT barycenter problem using the fixed point algorithm, iterating + the function B on plans between the current barycentre and the measures. 
+ + Parameters + ---------- + X_init : array-like + Array of shape (n, d) representing initial barycentre points. + Y_list : list of array-like + List of K arrays of measure positions, each of shape (m_k, d_k). + b_list : list of array-like + List of K arrays of measure weights, each of shape (m_k). + cost_list : list of callable + List of K cost functions R^(n, d) x R^(m_k, d_k) -> R_+^(n, m_k). + B : callable + Function from R^d_1 x ... x R^d_K to R^d accepting a list of K arrays of shape (n, d_K), computing the ground barycentre. + max_its : int, optional + Maximum number of iterations (default is 5). + stop_threshold : float, optional + If the iterations move less than this, terminate (default is 1e-5). + log : bool, optional + Whether to return the log dictionary (default is False). + + Returns + ------- + X : array-like + Array of shape (n, d) representing barycentre points. + log_dict : list of array-like, optional + log containing the exit status, list of iterations and list of + displacements if log is True. 
+ """ + nx = get_backend(X_init, Y_list[0]) + K = len(Y_list) + n = X_init.shape[0] + a = nx.ones(n) / n + X_list = [X_init] if log else [] # store the iterations + X = X_init + dX_list = [] # store the displacement squared norms + exit_status = "Unknown" + + try: + for _ in range(max_its): + pi_list = [ # compute the pairwise transport plans + emd(a, b_list[k], cost_list[k](X, Y_list[k])) for k in range(K) + ] + Y_perm = [] + for k in range(K): # compute barycentric projections + Y_perm.append(n * pi_list[k] @ Y_list[k]) + X_next = B(Y_perm) + + if log: + X_list.append(X_next) + + # stationary criterion: move less than the threshold + dX = nx.sum((X - X_next) ** 2) + X = X_next + + if log: + dX_list.append(dX) + + if dX < stop_threshold: + exit_status = "Stationary Point" + raise StoppingCriterionReached + + exit_status = "Max iterations reached" + raise StoppingCriterionReached + + except StoppingCriterionReached: + if log: + return X, {"X_list": X_list, "exit_status": exit_status, "dX_list": dX_list} + return X From 59520198b25a6dd3e2c9f8a403e1846bd77e0995 Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Mon, 20 Jan 2025 18:06:32 +0100 Subject: [PATCH 14/27] ot bar demo --- .../plot_barycenter_generic_cost.py | 167 ++++++++++++++++++ ...lot_generalized_free_support_barycenter.py | 2 +- examples/others/plot_GMMOT_plan.py | 2 +- examples/others/plot_GMM_flow.py | 2 +- examples/others/plot_SSNB.py | 2 +- ot/gmm.py | 4 +- ot/lp/__init__.py | 3 +- ot/lp/_barycenter_solvers.py | 2 +- ot/mapping.py | 2 +- 9 files changed, 177 insertions(+), 9 deletions(-) create mode 100644 examples/barycenters/plot_barycenter_generic_cost.py diff --git a/examples/barycenters/plot_barycenter_generic_cost.py b/examples/barycenters/plot_barycenter_generic_cost.py new file mode 100644 index 000000000..14779fdff --- /dev/null +++ b/examples/barycenters/plot_barycenter_generic_cost.py @@ -0,0 +1,167 @@ +# -*- coding: utf-8 -*- +""" +===================================== +OT Barycenter with 
Generic Costs Demo +===================================== + +This example illustrates the computation of an Optimal Transport for a ground +cost that is not a power of a norm. We take the example of ground costs +:math:`c_k(x, y) = |P_k(x)-y|^2`, where :math:`P_k` is the (non-linear) +projection onto a circle k. This is an example of the fixed-point barycenter +solver introduced in [74] which generalises [20]. + +The ground barycenter function :math:`B(y_1, ..., y_K)` = \mathrm{argmin}_{x \in +\mathbb{R}^2} \sum_k \lambda_k c_k(x, y_k) is computed by gradient descent over +:math:`x` with Pytorch. + +[74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). [Computing +Barycentres of Measures for Generic Transport +Costs](https://arxiv.org/abs/2501.04016). arXiv preprint 2501.04016 (2024) + +[20] Cuturi, M. and Doucet, A. (2014) [Fast Computation of Wasserstein +Barycenters](http://proceedings.mlr.press/v32/cuturi14.html). International +Conference in Machine Learning + +""" + +# Author: Eloi Tanguy +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 1 + +# %% Generate data +import torch +from torch.optim import Adam +from ot.utils import dist +import numpy as np +from ot.lp import free_support_barycenter_generic_costs +import matplotlib.pyplot as plt + + +torch.manual_seed(42) + +n = 100 # number of points of the of the barycentre +d = 2 # dimensions of the original measure +K = 4 # number of measures to barycentre +m = 50 # number of points of the measures +b_list = [torch.ones(m) / m] * K # weights of the 4 measures +weights = torch.ones(K) / K # weights for the barycentre +stop_threshold = 1e-20 # stop threshold for B and for fixed-point algo + + +# map R^2 -> R^2 projection onto circle +def proj_circle(X, origin, radius): + diffs = X - origin[None, :] + norms = torch.norm(diffs, dim=1) + return origin[None, :] + radius * diffs / norms[:, None] + + +# circles on which to project +origin1 = torch.tensor([-1.0, -1.0]) +origin2 = 
torch.tensor([-1.0, 2.0]) +origin3 = torch.tensor([2.0, 2.0]) +origin4 = torch.tensor([2.0, -1.0]) +r = np.sqrt(2) +P_list = [ + lambda X: proj_circle(X, origin1, r), + lambda X: proj_circle(X, origin2, r), + lambda X: proj_circle(X, origin3, r), + lambda X: proj_circle(X, origin4, r), +] + +# measures to barycentre are projections of different random circles +# onto the K circles +Y_list = [] +for k in range(K): + t = torch.rand(m) * 2 * np.pi + X_temp = 0.5 * torch.stack([torch.cos(t), torch.sin(t)], axis=1) + X_temp = X_temp + torch.tensor([0.5, 0.5])[None, :] + Y_list.append(P_list[k](X_temp)) + + +# %% Define costs and ground barycenter function +# cost_list[k] is a function taking x (n, d) and y (n_k, d_k) and returning a +# (n, n_k) matrix of costs +def c1(x, y): + return dist(P_list[0](x), y) + + +def c2(x, y): + return dist(P_list[1](x), y) + + +def c3(x, y): + return dist(P_list[2](x), y) + + +def c4(x, y): + return dist(P_list[3](x), y) + + +cost_list = [c1, c2, c3, c4] + + +# batched total ground cost function for candidate points x (n, d) +# for computation of the ground barycenter B with gradient descent +def C(x, y): + """ + Computes the barycenter cost for candidate points x (n, d) and + measure supports y: List(n, d_k). + """ + n = x.shape[0] + K = len(y) + out = torch.zeros(n) + for k in range(K): + out += (1 / K) * torch.sum((P_list[k](x) - y[k]) ** 2, axis=1) + return out + + +# ground barycenter function +def B(y, its=150, lr=1, stop_threshold=stop_threshold): + """ + Computes the ground barycenter for measure supports y: List(n, d_k). 
+ Output: (n, d) array + """ + x = torch.randn(n, d) + x.requires_grad_(True) + opt = Adam([x], lr=lr) + for _ in range(its): + x_prev = x.data.clone() + opt.zero_grad() + loss = torch.sum(C(x, y)) + loss.backward() + opt.step() + diff = torch.sum((x.data - x_prev) ** 2) + if diff < stop_threshold: + break + return x + + +# %% Compute the barycenter measure +fixed_point_its = 10 +X_init = torch.rand(n, d) +X_bar = free_support_barycenter_generic_costs( + X_init, + Y_list, + b_list, + cost_list, + B, + max_its=fixed_point_its, + stop_threshold=stop_threshold, +) + +# %% Plot Barycenter (Iteration 10) +alpha = 0.5 +labels = ["circle 1", "circle 2", "circle 3", "circle 4"] +for Y, label in zip(Y_list, labels): + plt.scatter(*(Y.numpy()).T, alpha=alpha, label=label) +plt.scatter(*(X_bar.detach().numpy()).T, label="Barycenter", c="black", alpha=alpha) +plt.axis("equal") +plt.xlim(-0.3, 1.3) +plt.ylim(-0.3, 1.3) +plt.axis("off") +plt.legend() +plt.tight_layout() + +# %% diff --git a/examples/barycenters/plot_generalized_free_support_barycenter.py b/examples/barycenters/plot_generalized_free_support_barycenter.py index 5b3572bd4..b21c66f13 100644 --- a/examples/barycenters/plot_generalized_free_support_barycenter.py +++ b/examples/barycenters/plot_generalized_free_support_barycenter.py @@ -14,7 +14,7 @@ """ -# Author: Eloi Tanguy +# Author: Eloi Tanguy # # License: MIT License diff --git a/examples/others/plot_GMMOT_plan.py b/examples/others/plot_GMMOT_plan.py index 7742d496e..4964ddd66 100644 --- a/examples/others/plot_GMMOT_plan.py +++ b/examples/others/plot_GMMOT_plan.py @@ -16,7 +16,7 @@ """ -# Author: Eloi Tanguy +# Author: Eloi Tanguy # Remi Flamary # Julie Delon # diff --git a/examples/others/plot_GMM_flow.py b/examples/others/plot_GMM_flow.py index beb675755..dc26ff3ce 100644 --- a/examples/others/plot_GMM_flow.py +++ b/examples/others/plot_GMM_flow.py @@ -10,7 +10,7 @@ """ -# Author: Eloi Tanguy +# Author: Eloi Tanguy # Remi Flamary # Julie Delon # diff --git 
a/examples/others/plot_SSNB.py b/examples/others/plot_SSNB.py index fbc343a8a..e167b1ee4 100644 --- a/examples/others/plot_SSNB.py +++ b/examples/others/plot_SSNB.py @@ -38,7 +38,7 @@ 2017. """ -# Author: Eloi Tanguy +# Author: Eloi Tanguy # License: MIT License # sphinx_gallery_thumbnail_number = 3 diff --git a/ot/gmm.py b/ot/gmm.py index 5c7a4c287..d99d4e5db 100644 --- a/ot/gmm.py +++ b/ot/gmm.py @@ -3,8 +3,8 @@ Optimal transport for Gaussian Mixtures """ -# Author: Eloi Tanguy -# Remi Flamary +# Author: Eloi Tanguy +# Remi Flamary # Julie Delon # # License: MIT License diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index e3cfce0fd..974679440 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -8,13 +8,13 @@ # # License: MIT License -from . import cvx from .dmmot import dmmot_monge_1dgrid_loss, dmmot_monge_1dgrid_optimize from ._network_simplex import emd, emd2 from ._barycenter_solvers import ( barycenter, free_support_barycenter, generalized_free_support_barycenter, + free_support_barycenter_generic_costs, ) from ..utils import check_number_threads @@ -46,4 +46,5 @@ "dmmot_monge_1dgrid_loss", "dmmot_monge_1dgrid_optimize", "check_number_threads", + "free_support_barycenter_generic_costs", ] diff --git a/ot/lp/_barycenter_solvers.py b/ot/lp/_barycenter_solvers.py index 7e801caa6..e45092caa 100644 --- a/ot/lp/_barycenter_solvers.py +++ b/ot/lp/_barycenter_solvers.py @@ -428,7 +428,7 @@ class StoppingCriterionReached(Exception): pass -def solve_OT_barycenter_fixed_point( +def free_support_barycenter_generic_costs( X_init, Y_list, b_list, diff --git a/ot/mapping.py b/ot/mapping.py index 1ec55cb95..d2a05809c 100644 --- a/ot/mapping.py +++ b/ot/mapping.py @@ -7,7 +7,7 @@ use it you need to explicitly import :mod:`ot.mapping` """ -# Author: Eloi Tanguy +# Author: Eloi Tanguy # Remi Flamary # # License: MIT License From 3e8421eb6dca94900bbca636a3594ff413cf5925 Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Tue, 21 Jan 2025 11:35:53 +0100 Subject: [PATCH 
15/27] ot bar doc --- .../plot_barycenter_generic_cost.py | 10 +- ot/lp/_barycenter_solvers.py | 100 ++++++++++++++---- 2 files changed, 87 insertions(+), 23 deletions(-) diff --git a/examples/barycenters/plot_barycenter_generic_cost.py b/examples/barycenters/plot_barycenter_generic_cost.py index 14779fdff..3e5ba38fe 100644 --- a/examples/barycenters/plot_barycenter_generic_cost.py +++ b/examples/barycenters/plot_barycenter_generic_cost.py @@ -6,9 +6,9 @@ This example illustrates the computation of an Optimal Transport for a ground cost that is not a power of a norm. We take the example of ground costs -:math:`c_k(x, y) = |P_k(x)-y|^2`, where :math:`P_k` is the (non-linear) +:math:`c_k(x, y) = \|P_k(x)-y\|_2^2`, where :math:`P_k` is the (non-linear) projection onto a circle k. This is an example of the fixed-point barycenter -solver introduced in [74] which generalises [20]. +solver introduced in [74] which generalises [20] and [43]. The ground barycenter function :math:`B(y_1, ..., y_K)` = \mathrm{argmin}_{x \in \mathbb{R}^2} \sum_k \lambda_k c_k(x, y_k) is computed by gradient descent over @@ -22,6 +22,8 @@ Barycenters](http://proceedings.mlr.press/v32/cuturi14.html). International Conference in Machine Learning +[43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. 
+ """ # Author: Eloi Tanguy @@ -147,8 +149,8 @@ def B(y, its=150, lr=1, stop_threshold=stop_threshold): b_list, cost_list, B, - max_its=fixed_point_its, - stop_threshold=stop_threshold, + numItermax=fixed_point_its, + stopThr=stop_threshold, ) # %% Plot Barycenter (Iteration 10) diff --git a/ot/lp/_barycenter_solvers.py b/ot/lp/_barycenter_solvers.py index e45092caa..a04d4de05 100644 --- a/ot/lp/_barycenter_solvers.py +++ b/ot/lp/_barycenter_solvers.py @@ -430,33 +430,78 @@ class StoppingCriterionReached(Exception): def free_support_barycenter_generic_costs( X_init, - Y_list, - b_list, + measure_locations, + measure_weights, cost_list, B, - max_its=5, - stop_threshold=1e-5, + numItermax=5, + stopThr=1e-5, log=False, ): - """ - Solves the OT barycenter problem using the fixed point algorithm, iterating - the function B on plans between the current barycentre and the measures. + r""" + Solves the OT barycenter problem for generic costs using the fixed point + algorithm, iterating the ground barycenter function B on transport plans + between the current barycentre and the measures. + + The problem finds an optimal barycenter support `X` of given size (n, d) + (enforced by the initialisation), minimising a sum of pairwise transport + costs for the costs :math:`c_k`: + + .. math:: + \min_{X} \sum_{k=1}^K \mathcal{T}_{c_k}(X, a, Y_k, b_k), + + where: + + - :math:`X` (n, d) is the barycentre support, + - :math:`a` (n) is the (fixed) barycentre weights, + - :math:`Y_k` (m_k, d_k) is the k-th measure support (`measure_locations[k]`), + - :math:`b_k` (m_k) is the k-th measure weights (`measure_weights[k]`), + - :math:`c_k: \mathbb{R}^{n\times d}\times\mathbb{R}^{m_k\times d_k} \rightarrow \mathbb{R}_+^{n\times m_k}` is the k-th cost function (which computes the pairwise cost matrix) + - :math:`\mathcal{T}_{c_k}(X, a, Y_k, b)` is the OT cost between the barycentre measure and the k-th measure with respect to the cost :math:`c_k`: + + .. 
math:: + \mathcal{T}_{c_k}(X, a, Y_k, b_k) = \min_\pi \quad \langle \pi, c_k(X, Y_k) \rangle_F + + s.t. \ \pi \mathbf{1} = \mathbf{a} + + \pi^T \mathbf{1} = \mathbf{b_k} + + \pi \geq 0 + + in other words, :math:`\mathcal{T}_{c_k}(X, a, Y_k, b)` is `ot.emd2(a, b_k, + c_k(X, Y_k))`. + + The algorithm requires a given ground barycentre function `B` which computes + a solution of the following minimisation problem given :math:`(y_1, \cdots, + y_K) \in \mathbb{R}^{d_1}\times\cdots\times\mathbb{R}^{d_K}`: + + .. math:: + B(y_1, \cdots, y_K) = \mathrm{argmin}_{x \in \mathbb{R}^d} \sum_{k=1}^K c_k(x, y_k), + + where :math:`c_k(x, y_k) \in \mathbb{R}_+` is the cost between the points + :math:`x` and :math:`y_k`. The function :math:`B:\mathbb{R}^{d_1}\times + \cdots\times\mathbb{R}^{d_K} \longrightarrow \mathbb{R}^d` is an input to + this function, and for certain costs it can be computed explicitly of + through a numerical solver. + + This function implements [74] Algorithm 2, which generalises [20] and [43] + to general costs and includes convergence guarantees, including for discrete measures. Parameters ---------- X_init : array-like Array of shape (n, d) representing initial barycentre points. - Y_list : list of array-like + measure_locations : list of array-like List of K arrays of measure positions, each of shape (m_k, d_k). - b_list : list of array-like + measure_weights : list of array-like List of K arrays of measure weights, each of shape (m_k). cost_list : list of callable - List of K cost functions R^(n, d) x R^(m_k, d_k) -> R_+^(n, m_k). + List of K cost functions :math:`c_k: \mathbb{R}^{n\times d}\times\mathbb{R}^{m_k\times d_k} \rightarrow \mathbb{R}_+^{n\times m_k}`. B : callable - Function from R^d_1 x ... x R^d_K to R^d accepting a list of K arrays of shape (n, d_K), computing the ground barycentre. 
- max_its : int, optional + Function from :math:`\mathbb{R}^{d_1} \times\cdots \times \mathbb{R}^{d_K}` to :math:`\mathbb{R}^d` accepting a list of K arrays of shape (n\times d_K), computing the ground barycentre. + numItermax : int, optional Maximum number of iterations (default is 5). - stop_threshold : float, optional + stopThr : float, optional If the iterations move less than this, terminate (default is 1e-5). log : bool, optional Whether to return the log dictionary (default is False). @@ -468,9 +513,25 @@ def free_support_barycenter_generic_costs( log_dict : list of array-like, optional log containing the exit status, list of iterations and list of displacements if log is True. + + .. _references-free-support-barycenter-generic-costs: + + References + ---------- + .. [74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). [Computing Barycentres of Measures for Generic Transport Costs](https://arxiv.org/abs/2501.04016). arXiv preprint 2501.04016 (2024) + + .. [20] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014. + + .. [43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. + + See Also + -------- + ot.lp.free_support_barycenter : Free support solver for the case where + :math:`c_k(x,y) = \|x-y\|_2^2`. + ot.lp.generalized_free_support_barycenter : Free support solver for the case where :math:`c_k(x,y) = \|P_kx-y\|_2^2` with :math:`P_k` linear. 
""" - nx = get_backend(X_init, Y_list[0]) - K = len(Y_list) + nx = get_backend(X_init, measure_locations[0]) + K = len(measure_locations) n = X_init.shape[0] a = nx.ones(n) / n X_list = [X_init] if log else [] # store the iterations @@ -479,13 +540,14 @@ def free_support_barycenter_generic_costs( exit_status = "Unknown" try: - for _ in range(max_its): + for _ in range(numItermax): pi_list = [ # compute the pairwise transport plans - emd(a, b_list[k], cost_list[k](X, Y_list[k])) for k in range(K) + emd(a, measure_weights[k], cost_list[k](X, measure_locations[k])) + for k in range(K) ] Y_perm = [] for k in range(K): # compute barycentric projections - Y_perm.append(n * pi_list[k] @ Y_list[k]) + Y_perm.append(n * pi_list[k] @ measure_locations[k]) X_next = B(Y_perm) if log: @@ -498,7 +560,7 @@ def free_support_barycenter_generic_costs( if log: dX_list.append(dX) - if dX < stop_threshold: + if dX < stopThr: exit_status = "Stationary Point" raise StoppingCriterionReached From ccf608a19e515b8f3b664792532f6c1b5136ca5f Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Tue, 21 Jan 2025 15:08:00 +0100 Subject: [PATCH 16/27] doc fixes + ot bar coverage --- .../plot_barycenter_generic_cost.py | 46 +++++---- ot/lp/_barycenter_solvers.py | 61 +++++++----- test/test_ot.py | 95 +++++++++++++++++++ 3 files changed, 161 insertions(+), 41 deletions(-) diff --git a/examples/barycenters/plot_barycenter_generic_cost.py b/examples/barycenters/plot_barycenter_generic_cost.py index 3e5ba38fe..e5e5af73a 100644 --- a/examples/barycenters/plot_barycenter_generic_cost.py +++ b/examples/barycenters/plot_barycenter_generic_cost.py @@ -10,19 +10,20 @@ projection onto a circle k. This is an example of the fixed-point barycenter solver introduced in [74] which generalises [20] and [43]. 
-The ground barycenter function :math:`B(y_1, ..., y_K)` = \mathrm{argmin}_{x \in -\mathbb{R}^2} \sum_k \lambda_k c_k(x, y_k) is computed by gradient descent over +The ground barycenter function :math:`B(y_1, ..., y_K) = \mathrm{argmin}_{x \in +\mathbb{R}^2} \sum_k \lambda_k c_k(x, y_k)` is computed by gradient descent over :math:`x` with Pytorch. -[74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). [Computing -Barycentres of Measures for Generic Transport -Costs](https://arxiv.org/abs/2501.04016). arXiv preprint 2501.04016 (2024) +[74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing +Barycentres of Measures for Generic Transport Costs. +arXiv preprint 2501.04016 (2024) -[20] Cuturi, M. and Doucet, A. (2014) [Fast Computation of Wasserstein -Barycenters](http://proceedings.mlr.press/v32/cuturi14.html). International -Conference in Machine Learning +[20] Cuturi, M. and Doucet, A. (2014) Fast Computation of Wasserstein +Barycenters. International Conference in Machine Learning -[43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. +[43] Álvarez-Esteban, Pedro C., et al. A fixed-point approach to barycenters in +Wasserstein space. Journal of Mathematical Analysis and Applications 441.2 +(2016): 744-762. 
""" @@ -32,7 +33,8 @@ # sphinx_gallery_thumbnail_number = 1 -# %% Generate data +# %% +# Generate data import torch from torch.optim import Adam from ot.utils import dist @@ -43,7 +45,7 @@ torch.manual_seed(42) -n = 100 # number of points of the of the barycentre +n = 200 # number of points of the of the barycentre d = 2 # dimensions of the original measure K = 4 # number of measures to barycentre m = 50 # number of points of the measures @@ -82,7 +84,8 @@ def proj_circle(X, origin, radius): Y_list.append(P_list[k](X_temp)) -# %% Define costs and ground barycenter function +# %% +# Define costs and ground barycenter function # cost_list[k] is a function taking x (n, d) and y (n_k, d_k) and returning a # (n, n_k) matrix of costs def c1(x, y): @@ -140,25 +143,30 @@ def B(y, its=150, lr=1, stop_threshold=stop_threshold): return x -# %% Compute the barycenter measure -fixed_point_its = 10 +# %% +# Compute the barycenter measure +fixed_point_its = 3 X_init = torch.rand(n, d) X_bar = free_support_barycenter_generic_costs( - X_init, Y_list, b_list, + X_init, cost_list, B, numItermax=fixed_point_its, stopThr=stop_threshold, ) -# %% Plot Barycenter (Iteration 10) -alpha = 0.5 +# %% +# Plot Barycenter (Iteration 3) +alpha = 0.4 +s = 80 labels = ["circle 1", "circle 2", "circle 3", "circle 4"] for Y, label in zip(Y_list, labels): - plt.scatter(*(Y.numpy()).T, alpha=alpha, label=label) -plt.scatter(*(X_bar.detach().numpy()).T, label="Barycenter", c="black", alpha=alpha) + plt.scatter(*(Y.numpy()).T, alpha=alpha, label=label, s=s) +plt.scatter( + *(X_bar.detach().numpy()).T, label="Barycenter", c="black", alpha=alpha, s=s +) plt.axis("equal") plt.xlim(-0.3, 1.3) plt.ylim(-0.3, 1.3) diff --git a/ot/lp/_barycenter_solvers.py b/ot/lp/_barycenter_solvers.py index a04d4de05..445a996df 100644 --- a/ot/lp/_barycenter_solvers.py +++ b/ot/lp/_barycenter_solvers.py @@ -429,11 +429,12 @@ class StoppingCriterionReached(Exception): def free_support_barycenter_generic_costs( - X_init, 
measure_locations, measure_weights, + X_init, cost_list, B, + a=None, numItermax=5, stopThr=1e-5, log=False, @@ -441,7 +442,7 @@ def free_support_barycenter_generic_costs( r""" Solves the OT barycenter problem for generic costs using the fixed point algorithm, iterating the ground barycenter function B on transport plans - between the current barycentre and the measures. + between the current barycenter and the measures. The problem finds an optimal barycenter support `X` of given size (n, d) (enforced by the initialisation), minimising a sum of pairwise transport @@ -452,12 +453,13 @@ def free_support_barycenter_generic_costs( where: - - :math:`X` (n, d) is the barycentre support, - - :math:`a` (n) is the (fixed) barycentre weights, - - :math:`Y_k` (m_k, d_k) is the k-th measure support (`measure_locations[k]`), + - :math:`X` (n, d) is the barycenter support, + - :math:`a` (n) is the (fixed) barycenter weights, + - :math:`Y_k` (m_k, d_k) is the k-th measure support + (`measure_locations[k]`), - :math:`b_k` (m_k) is the k-th measure weights (`measure_weights[k]`), - :math:`c_k: \mathbb{R}^{n\times d}\times\mathbb{R}^{m_k\times d_k} \rightarrow \mathbb{R}_+^{n\times m_k}` is the k-th cost function (which computes the pairwise cost matrix) - - :math:`\mathcal{T}_{c_k}(X, a, Y_k, b)` is the OT cost between the barycentre measure and the k-th measure with respect to the cost :math:`c_k`: + - :math:`\mathcal{T}_{c_k}(X, a, Y_k, b)` is the OT cost between the barycenter measure and the k-th measure with respect to the cost :math:`c_k`: .. math:: \mathcal{T}_{c_k}(X, a, Y_k, b_k) = \min_\pi \quad \langle \pi, c_k(X, Y_k) \rangle_F @@ -471,9 +473,10 @@ def free_support_barycenter_generic_costs( in other words, :math:`\mathcal{T}_{c_k}(X, a, Y_k, b)` is `ot.emd2(a, b_k, c_k(X, Y_k))`. 
- The algorithm requires a given ground barycentre function `B` which computes - a solution of the following minimisation problem given :math:`(y_1, \cdots, - y_K) \in \mathbb{R}^{d_1}\times\cdots\times\mathbb{R}^{d_K}`: + The algorithm requires a given ground barycenter function `B` which computes + (broadcasted of `n`) solutions of the following minimisation problem given + :math:`(Y_1, \cdots, Y_K) \in + \mathbb{R}^{n\times d_1}\times\cdots\times\mathbb{R}^{n\times d_K}`: .. math:: B(y_1, \cdots, y_K) = \mathrm{argmin}_{x \in \mathbb{R}^d} \sum_{k=1}^K c_k(x, y_k), @@ -482,23 +485,32 @@ def free_support_barycenter_generic_costs( :math:`x` and :math:`y_k`. The function :math:`B:\mathbb{R}^{d_1}\times \cdots\times\mathbb{R}^{d_K} \longrightarrow \mathbb{R}^d` is an input to this function, and for certain costs it can be computed explicitly of - through a numerical solver. + through a numerical solver. The input function B takes a list of K arrays of + shape (n, d_k) and returns an array of shape (n, d). This function implements [74] Algorithm 2, which generalises [20] and [43] - to general costs and includes convergence guarantees, including for discrete measures. + to general costs and includes convergence guarantees, including for discrete + measures. Parameters ---------- - X_init : array-like - Array of shape (n, d) representing initial barycentre points. measure_locations : list of array-like List of K arrays of measure positions, each of shape (m_k, d_k). measure_weights : list of array-like List of K arrays of measure weights, each of shape (m_k). + X_init : array-like + Array of shape (n, d) representing initial barycenter points. cost_list : list of callable - List of K cost functions :math:`c_k: \mathbb{R}^{n\times d}\times\mathbb{R}^{m_k\times d_k} \rightarrow \mathbb{R}_+^{n\times m_k}`. + List of K cost functions :math:`c_k: \mathbb{R}^{n\times + d}\times\mathbb{R}^{m_k\times d_k} \rightarrow \mathbb{R}_+^{n\times + m_k}`. 
B : callable - Function from :math:`\mathbb{R}^{d_1} \times\cdots \times \mathbb{R}^{d_K}` to :math:`\mathbb{R}^d` accepting a list of K arrays of shape (n\times d_K), computing the ground barycentre. + Function List(array(n, d_k)) -> array(n, d) accepting a list of K arrays + of shape (n\times d_K), computing the ground barycenters (broadcasted + over n). + a : array-like, optional + Array of shape (n,) representing weights of the barycenter + measure.Defaults to uniform. numItermax : int, optional Maximum number of iterations (default is 5). stopThr : float, optional @@ -509,7 +521,7 @@ def free_support_barycenter_generic_costs( Returns ------- X : array-like - Array of shape (n, d) representing barycentre points. + Array of shape (n, d) representing barycenter points. log_dict : list of array-like, optional log containing the exit status, list of iterations and list of displacements if log is True. @@ -518,22 +530,27 @@ def free_support_barycenter_generic_costs( References ---------- - .. [74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). [Computing Barycentres of Measures for Generic Transport Costs](https://arxiv.org/abs/2501.04016). arXiv preprint 2501.04016 (2024) + .. [74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing + barycenters of Measures for Generic Transport Costs. arXiv preprint + 2501.04016 (2024) - .. [20] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014. + .. [20] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein + barycenters." International Conference on Machine Learning. 2014. - .. [43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. + .. [43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to + barycenters in Wasserstein space." 
Journal of Mathematical Analysis and + Applications 441.2 (2016): 744-762. See Also -------- - ot.lp.free_support_barycenter : Free support solver for the case where - :math:`c_k(x,y) = \|x-y\|_2^2`. + ot.lp.free_support_barycenter : Free support solver for the case where :math:`c_k(x,y) = \|x-y\|_2^2`. ot.lp.generalized_free_support_barycenter : Free support solver for the case where :math:`c_k(x,y) = \|P_kx-y\|_2^2` with :math:`P_k` linear. """ nx = get_backend(X_init, measure_locations[0]) K = len(measure_locations) n = X_init.shape[0] - a = nx.ones(n) / n + if a is None: + a = nx.ones(n, type_as=X_init) / n X_list = [X_init] if log else [] # store the iterations X = X_init dX_list = [] # store the displacement squared norms diff --git a/test/test_ot.py b/test/test_ot.py index f84f8773a..4916d71aa 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -13,6 +13,8 @@ from ot.datasets import make_1D_gauss as gauss from ot.backend import torch, tf +# import ot.lp._barycenter_solvers # TODO: remove this import + def test_emd_dimension_and_mass_mismatch(): # test emd and emd2 for dimension mismatch @@ -395,6 +397,99 @@ def test_generalised_free_support_barycenter_backends(nx): np.testing.assert_allclose(Y, nx.to_numpy(Y2)) +def test_free_support_barycenter_generic_costs(): + measures_locations = [ + np.array([-1.0]).reshape((1, 1)), + np.array([1.0]).reshape((1, 1)), + ] + measures_weights = [np.array([1.0]), np.array([1.0])] + + X_init = np.array([-12.0]).reshape((1, 1)) + + # obvious barycenter location between two Diracs + bar_locations = np.array([0.0]).reshape((1, 1)) + + def cost(x, y): + return ot.dist(x, y) + + cost_list = [cost, cost] + + def B(y): + out = 0 + for yk in y: + out += yk / len(y) + return out + + X = ot.lp.free_support_barycenter_generic_costs( + measures_locations, measures_weights, X_init, cost_list, B + ) + + np.testing.assert_allclose(X, bar_locations, rtol=1e-5, atol=1e-7) + + # test with log and specific weights + X2, log = 
ot.lp.free_support_barycenter_generic_costs( + measures_locations, + measures_weights, + X_init, + cost_list, + B, + a=ot.unif(1), + log=True, + ) + + assert "X_list" in log + assert "exit_status" in log + assert "dX_list" in log + + np.testing.assert_allclose(X, X2, rtol=1e-5, atol=1e-7) + + # test with one iteration for Max Iterations Reached + X3, log2 = ot.lp.free_support_barycenter_generic_costs( + measures_locations, + measures_weights, + X_init, + cost_list, + B, + numItermax=1, + log=True, + ) + assert log2["exit_status"] == "Max iterations reached" + + +def test_free_support_barycenter_generic_costs_backends(nx): + measures_locations = [ + np.array([-1.0]).reshape((1, 1)), + np.array([1.0]).reshape((1, 1)), + ] + measures_weights = [np.array([1.0]), np.array([1.0])] + X_init = np.array([-12.0]).reshape((1, 1)) + + def cost(x, y): + return ot.dist(x, y) + + cost_list = [cost, cost] + + def B(y): + out = 0 + for yk in y: + out += yk / len(y) + return out + + X = ot.lp.free_support_barycenter_generic_costs( + measures_locations, measures_weights, X_init, cost_list, B + ) + + measures_locations2 = nx.from_numpy(*measures_locations) + measures_weights2 = nx.from_numpy(*measures_weights) + X_init2 = nx.from_numpy(X_init) + + X2 = ot.lp.free_support_barycenter_generic_costs( + measures_locations2, measures_weights2, X_init2, cost_list, B + ) + + np.testing.assert_allclose(X, nx.to_numpy(X2)) + + @pytest.mark.skipif(not ot.lp._barycenter_solvers.cvxopt, reason="No cvxopt available") def test_lp_barycenter_cvxopt(): a1 = np.array([1.0, 0, 0])[:, None] From 37b9c80cad43f3b71768a265a4c57ef57734e06c Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Tue, 21 Jan 2025 15:46:20 +0100 Subject: [PATCH 17/27] python 3.13 in test workflow + added ggmot barycenter (WIP) --- .github/workflows/build_tests.yml | 2 +- ot/gmm.py | 114 +++++++++++++++++++++++++++++- 2 files changed, 114 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build_tests.yml 
b/.github/workflows/build_tests.yml index 4356daa2b..52b4e1d99 100644 --- a/.github/workflows/build_tests.yml +++ b/.github/workflows/build_tests.yml @@ -47,7 +47,7 @@ jobs: strategy: max-parallel: 4 matrix: - python-version: ["3.9", "3.10", "3.11", "3.12"] + python-version: ["3.9", "3.10", "3.11", "3.12, "3.13"] steps: - uses: actions/checkout@v4 diff --git a/ot/gmm.py b/ot/gmm.py index d99d4e5db..bf4e700d3 100644 --- a/ot/gmm.py +++ b/ot/gmm.py @@ -13,7 +13,7 @@ from .lp import emd2, emd import numpy as np from .utils import dist -from .gaussian import bures_wasserstein_mapping +from .gaussian import bures_wasserstein_mapping, bures_wasserstein_barycenter def gaussian_logpdf(x, m, C): @@ -440,3 +440,115 @@ def Tk0k1(k0, k1): ] ) return nx.sum(mat, axis=(0, 1)) + + +def solve_gmm_barycenter_fixed_point( + means, + covs, + means_list, + covs_list, + b_list, + weights, + max_its=300, + log=False, + barycentric_proj_method="euclidean", +): + r""" + Solves the GMM OT barycenter problem using the fixed point algorithm. + + Parameters + ---------- + means : array-like + Initial (n, d) GMM means. + covs : array-like + Initial (n, d, d) GMM covariances. + means_list : list of array-like + List of K (m_k, d) GMM means. + covs_list : list of array-like + List of K (m_k, d, d) GMM covariances. + b_list : list of array-like + List of K (m_k) arrays of weights. + weights : array-like + Array (K,) of the barycentre coefficients. + max_its : int, optional + Maximum number of iterations (default is 300). + log : bool, optional + Whether to return the list of iterations (default is False). + barycentric_proj_method : str, optional + Method to project the barycentre weights: 'euclidean' (default) or 'bures'. + + Returns + ------- + means : array-like + (n, d) barycentre GMM means. + covs : array-like + (n, d, d) barycentre GMM covariances. + log_dict : dict, optional + Dictionary containing the list of iterations if log is True. 
+ """ + nx = get_backend(means, covs[0], means_list[0], covs_list[0]) + K = len(means_list) + n = means.shape[0] + d = means.shape[1] + means_its = [means.copy()] + covs_its = [covs.copy()] + a = nx.ones(n, type_as=means) / n + + for _ in range(max_its): + pi_list = [ + gmm_ot_plan(means, means_list[k], covs, covs_list[k], a, b_list[k]) + for k in range(K) + ] + + means_selection, covs_selection = None, None + # in the euclidean case, the selection of Gaussians from each K sources + # comes from a barycentric projection is a convex combination of the + # selected means and covariances, which can be computed without a + # for loop on i + if barycentric_proj_method == "euclidean": + means_selection = nx.zeros((n, K, d), type_as=means) + covs_selection = nx.zeros((n, K, d, d), type_as=means) + + for k in range(K): + means_selection[:, k, :] = n * pi_list[k] @ means_list[k] + covs_selection[:, k, :, :] = ( + nx.einsum("ij,jab->iab", pi_list[k], covs_list[k]) * n + ) + + # each component i of the barycentre will be a Bures barycentre of the + # selected components of the K GMMs. In the 'bures' barycentric + # projection option, the selected components are also Bures barycentres. 
+ for i in range(n): + # means_slice_i (K, d) is the selected means, each comes from a + # Gaussian barycentre along the disintegration of pi_k at i + # covs_slice_i (K, d, d) are the selected covariances + means_selection_i = [] + covs_selection_i = [] + + # use previous computation (convex combination) + if barycentric_proj_method == "euclidean": + means_selection_i = means_selection[i] + covs_selection_i = covs_selection[i] + + # compute Bures barycentre of the selected components + elif barycentric_proj_method == "bures": + w = (1 / a[i]) * pi_list[k][i, :] + for k in range(K): + m, C = bures_wasserstein_barycenter(means_list[k], covs_list[k], w) + means_selection_i.append(m) + covs_selection_i.append(C) + + else: + raise ValueError("Unknown barycentric_proj_method") + + means[i], covs[i] = bures_wasserstein_barycenter( + means_selection_i, covs_selection_i, weights + ) + + if log: + means_its.append(means.copy()) + covs_its.append(covs.copy()) + + if log: + return means, covs, {"means_its": means_its, "covs_its": covs_its} + return means, covs From a20d3f0656e0e64c0dc4b7a74e94cc9a407c9bd9 Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Tue, 21 Jan 2025 16:06:43 +0100 Subject: [PATCH 18/27] fixed github action file --- .github/workflows/build_tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build_tests.yml b/.github/workflows/build_tests.yml index 52b4e1d99..a8e27b323 100644 --- a/.github/workflows/build_tests.yml +++ b/.github/workflows/build_tests.yml @@ -47,7 +47,7 @@ jobs: strategy: max-parallel: 4 matrix: - python-version: ["3.9", "3.10", "3.11", "3.12, "3.13"] + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] steps: - uses: actions/checkout@v4 From 0b6217b00188f4f01bc80f5de7ba838e039cb39e Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Tue, 21 Jan 2025 17:19:56 +0100 Subject: [PATCH 19/27] ot bar doc + test coverage --- .github/workflows/build_tests.yml | 2 +- ot/gmm.py | 103 
++++++++++++++++++++---------- ot/lp/_barycenter_solvers.py | 4 +- test/test_gmm.py | 54 +++++++++++++++- 4 files changed, 124 insertions(+), 39 deletions(-) diff --git a/.github/workflows/build_tests.yml b/.github/workflows/build_tests.yml index a8e27b323..4356daa2b 100644 --- a/.github/workflows/build_tests.yml +++ b/.github/workflows/build_tests.yml @@ -47,7 +47,7 @@ jobs: strategy: max-parallel: 4 matrix: - python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] + python-version: ["3.9", "3.10", "3.11", "3.12"] steps: - uses: actions/checkout@v4 diff --git a/ot/gmm.py b/ot/gmm.py index bf4e700d3..214720d1e 100644 --- a/ot/gmm.py +++ b/ot/gmm.py @@ -442,36 +442,50 @@ def Tk0k1(k0, k1): return nx.sum(mat, axis=(0, 1)) -def solve_gmm_barycenter_fixed_point( - means, - covs, +def gmm_barycenter_fixed_point( means_list, covs_list, - b_list, + w_list, + means_init, + covs_init, weights, - max_its=300, + w_bar=None, + iterations=100, log=False, barycentric_proj_method="euclidean", ): r""" - Solves the GMM OT barycenter problem using the fixed point algorithm. + Solves the Gaussian Mixture Model OT barycenter problem (defined in [69]) + using the fixed point algorithm (proposed in [74]). The + weights of the barycenter are not optimized, and stay the same as the input + `w_list` or are initialized to uniform. + + The algorithm uses barycentric projections of GMM-OT plans, and these can be + computed either through Bures Barycenters (slow but accurate, + barycentric_proj_method='bures') or by convex combination (fast, + barycentric_proj_method='euclidean', default). + + This is a special case of the generic free-support barycenter solver + `ot.lp.free_support_barycenter_generic_costs`. Parameters ---------- - means : array-like - Initial (n, d) GMM means. - covs : array-like - Initial (n, d, d) GMM covariances. means_list : list of array-like List of K (m_k, d) GMM means. covs_list : list of array-like List of K (m_k, d, d) GMM covariances. 
- b_list : list of array-like + w_list : list of array-like List of K (m_k) arrays of weights. + means_init : array-like + Initial (n, d) GMM means. + covs_init : array-like + Initial (n, d, d) GMM covariances. weights : array-like Array (K,) of the barycentre coefficients. - max_its : int, optional - Maximum number of iterations (default is 300). + w_bar : array-like, optional + Initial weights (n) of the barycentre GMM. If None, initialized to uniform. + iterations : int, optional + Number of iterations (default is 100). log : bool, optional Whether to return the list of iterations (default is False). barycentric_proj_method : str, optional @@ -485,30 +499,46 @@ def solve_gmm_barycenter_fixed_point( (n, d, d) barycentre GMM covariances. log_dict : dict, optional Dictionary containing the list of iterations if log is True. + + References + ---------- + .. [69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance in the space of Gaussian mixture models. SIAM Journal on Imaging Sciences, 13(2), 936-970. + + .. [74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing barycenters of Measures for Generic Transport Costs. arXiv preprint 2501.04016 (2024) + + See Also + -------- + ot.lp.free_support_barycenter_generic_costs : Compute barycenter of measures for generic transport costs. 
""" - nx = get_backend(means, covs[0], means_list[0], covs_list[0]) + nx = get_backend( + means_init, covs_init, means_list[0], covs_list[0], w_list[0], weights + ) K = len(means_list) - n = means.shape[0] - d = means.shape[1] - means_its = [means.copy()] - covs_its = [covs.copy()] - a = nx.ones(n, type_as=means) / n + n = means_init.shape[0] + d = means_init.shape[1] + means_its = [nx.copy(means_init)] + covs_its = [nx.copy(covs_init)] + means, covs = means_init, covs_init + + if w_bar is None: + w_bar = nx.ones(n, type_as=means) / n - for _ in range(max_its): + for _ in range(iterations): pi_list = [ - gmm_ot_plan(means, means_list[k], covs, covs_list[k], a, b_list[k]) + gmm_ot_plan(means, means_list[k], covs, covs_list[k], w_bar, w_list[k]) for k in range(K) ] + # filled in the euclidean case means_selection, covs_selection = None, None + # in the euclidean case, the selection of Gaussians from each K sources - # comes from a barycentric projection is a convex combination of the - # selected means and covariances, which can be computed without a - # for loop on i + # comes from a barycentric projection: it is a convex combination of the + # selected means and covariances, which can be computed without a + # for loop on i = 0, ..., n -1 if barycentric_proj_method == "euclidean": means_selection = nx.zeros((n, K, d), type_as=means) covs_selection = nx.zeros((n, K, d, d), type_as=means) - for k in range(K): means_selection[:, k, :] = n * pi_list[k] @ means_list[k] covs_selection[:, k, :, :] = ( @@ -519,24 +549,27 @@ def solve_gmm_barycenter_fixed_point( # selected components of the K GMMs. In the 'bures' barycentric # projection option, the selected components are also Bures barycentres. 
for i in range(n): - # means_slice_i (K, d) is the selected means, each comes from a + # means_selection_i (K, d) is the selected means, each comes from a # Gaussian barycentre along the disintegration of pi_k at i - # covs_slice_i (K, d, d) are the selected covariances - means_selection_i = [] - covs_selection_i = [] + # covs_selection_i (K, d, d) are the selected covariances + means_selection_i = None + covs_selection_i = None # use previous computation (convex combination) if barycentric_proj_method == "euclidean": means_selection_i = means_selection[i] covs_selection_i = covs_selection[i] - # compute Bures barycentre of the selected components + # compute Bures barycentre of certain components to get the + # selection at i elif barycentric_proj_method == "bures": - w = (1 / a[i]) * pi_list[k][i, :] + means_selection_i = nx.zeros((K, d), type_as=means) + covs_selection_i = nx.zeros((K, d, d), type_as=means) for k in range(K): + w = (1 / w_bar[i]) * pi_list[k][i, :] m, C = bures_wasserstein_barycenter(means_list[k], covs_list[k], w) - means_selection_i.append(m) - covs_selection_i.append(C) + means_selection_i[k] = m + covs_selection_i[k] = C else: raise ValueError("Unknown barycentric_proj_method") @@ -546,8 +579,8 @@ def solve_gmm_barycenter_fixed_point( ) if log: - means_its.append(means.copy()) - covs_its.append(covs.copy()) + means_its.append(nx.copy(means)) + covs_its.append(nx.copy(covs)) if log: return means, covs, {"means_its": means_its, "covs_its": covs_its} diff --git a/ot/lp/_barycenter_solvers.py b/ot/lp/_barycenter_solvers.py index 445a996df..9589121bd 100644 --- a/ot/lp/_barycenter_solvers.py +++ b/ot/lp/_barycenter_solvers.py @@ -435,7 +435,7 @@ def free_support_barycenter_generic_costs( cost_list, B, a=None, - numItermax=5, + numItermax=100, stopThr=1e-5, log=False, ): @@ -512,7 +512,7 @@ def free_support_barycenter_generic_costs( Array of shape (n,) representing weights of the barycenter measure.Defaults to uniform. 
numItermax : int, optional - Maximum number of iterations (default is 5). + Maximum number of iterations (default is 100). stopThr : float, optional If the iterations move less than this, terminate (default is 1e-5). log : bool, optional diff --git a/test/test_gmm.py b/test/test_gmm.py index 5f1a92965..629a68d57 100644 --- a/test/test_gmm.py +++ b/test/test_gmm.py @@ -1,6 +1,6 @@ """Tests for module gaussian""" -# Author: Eloi Tanguy +# Author: Eloi Tanguy # Remi Flamary # Julie Delon # @@ -17,6 +17,7 @@ gmm_ot_plan, gmm_ot_apply_map, gmm_ot_plan_density, + gmm_barycenter_fixed_point, ) try: @@ -193,3 +194,54 @@ def test_gmm_ot_plan_density(nx): with pytest.raises(AssertionError): gmm_ot_plan_density(x[:, 1:], y, m_s, m_t, C_s, C_t, w_s, w_t) + + +@pytest.skip_backend("tf") # skips because of array assignment +@pytest.skip_backend("jax") +def test_gmm_barycenter_fixed_point(nx): + m_s, m_t, C_s, C_t, w_s, w_t = get_gmms(nx) + means_list = [m_s, m_t] + covs_list = [C_s, C_t] + w_list = [w_s, w_t] + n_iter = 3 + n = m_s.shape[0] # number of components of barycenter + means_init = m_s + covs_init = C_s + weights = nx.ones(2, type_as=m_s) / 2 # barycenter coefficients + + # with euclidean barycentric projections + means, covs = gmm_barycenter_fixed_point( + means_list, covs_list, w_list, means_init, covs_init, weights, iterations=n_iter + ) + + # with bures barycentric projections and assigned weights to uniform + means_bures_proj, covs_bures_proj, log = gmm_barycenter_fixed_point( + means_list, + covs_list, + w_list, + means_init, + covs_init, + weights, + iterations=n_iter, + w_bar=nx.ones(n, type_as=m_s) / n, + barycentric_proj_method="bures", + log=True, + ) + + assert "means_its" in log + assert "covs_its" in log + + assert np.allclose(means, means_bures_proj, atol=1e-6) + assert np.allclose(covs, covs_bures_proj, atol=1e-6) + + with pytest.raises(ValueError): + gmm_barycenter_fixed_point( + means_list, + covs_list, + w_list, + means_init, + covs_init, + weights, 
+ iterations=n_iter, + barycentric_proj_method="unknown", + ) From 21bf86b944f2ce6cb71f381718c50095ca485850 Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Tue, 21 Jan 2025 17:52:11 +0100 Subject: [PATCH 20/27] examples: ot bar with projections onto circles + gmm ot bar --- README.md | 4 +- ...t_free_support_barycenter_generic_cost.py} | 8 +- examples/barycenters/plot_gmm_barycenter.py | 144 ++++++++++++++++++ 3 files changed, 149 insertions(+), 7 deletions(-) rename examples/barycenters/{plot_barycenter_generic_cost.py => plot_free_support_barycenter_generic_cost.py} (96%) create mode 100644 examples/barycenters/plot_gmm_barycenter.py diff --git a/README.md b/README.md index 9a8e5b371..9266c99c6 100644 --- a/README.md +++ b/README.md @@ -392,6 +392,4 @@ Artificial Intelligence. [73] Séjourné, T., Vialard, F. X., & Peyré, G. (2022). [Faster Unbalanced Optimal Transport: Translation Invariant Sinkhorn and 1-D Frank-Wolfe](https://proceedings.mlr.press/v151/sejourne22a.html). In International Conference on Artificial Intelligence and Statistics (pp. 4995-5021). PMLR. -[74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). [Computing -Barycentres of Measures for Generic Transport -Costs](https://arxiv.org/abs/2501.04016). arXiv preprint 2501.04016 (2024) +[74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). [Computing Barycentres of Measures for Generic Transport Costs](https://arxiv.org/abs/2501.04016). 
arXiv preprint 2501.04016 (2024) diff --git a/examples/barycenters/plot_barycenter_generic_cost.py b/examples/barycenters/plot_free_support_barycenter_generic_cost.py similarity index 96% rename from examples/barycenters/plot_barycenter_generic_cost.py rename to examples/barycenters/plot_free_support_barycenter_generic_cost.py index e5e5af73a..55a75b157 100644 --- a/examples/barycenters/plot_barycenter_generic_cost.py +++ b/examples/barycenters/plot_free_support_barycenter_generic_cost.py @@ -4,8 +4,8 @@ OT Barycenter with Generic Costs Demo ===================================== -This example illustrates the computation of an Optimal Transport for a ground -cost that is not a power of a norm. We take the example of ground costs +This example illustrates the computation of an Optimal Transport Barycenter for +a ground cost that is not a power of a norm. We take the example of ground costs :math:`c_k(x, y) = \|P_k(x)-y\|_2^2`, where :math:`P_k` is the (non-linear) projection onto a circle k. This is an example of the fixed-point barycenter solver introduced in [74] which generalises [20] and [43]. @@ -15,8 +15,8 @@ :math:`x` with Pytorch. [74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing -Barycentres of Measures for Generic Transport Costs. -arXiv preprint 2501.04016 (2024) +Barycentres of Measures for Generic Transport Costs. arXiv preprint 2501.04016 +(2024) [20] Cuturi, M. and Doucet, A. (2014) Fast Computation of Wasserstein Barycenters. 
InternationalConference in Machine Learning diff --git a/examples/barycenters/plot_gmm_barycenter.py b/examples/barycenters/plot_gmm_barycenter.py new file mode 100644 index 000000000..07792c0dd --- /dev/null +++ b/examples/barycenters/plot_gmm_barycenter.py @@ -0,0 +1,144 @@ +# -*- coding: utf-8 -*- +""" +===================================== +Gaussian Mixture Model OT Barycenters +===================================== + +This example illustrates the computation of a barycenter between Gaussian +Mixtures in the sense of GMM-OT [69]. This computation is done using the +fixed-point method for OT barycenters with generic costs [74], for which POT +provides a general solver, and a specific GMM solver. Note that this is a +'free-support' method, implying that the number of components of the barycenter +GMM and their weights are fixed. + +The idea behind GMM-OT barycenters is to see the GMMs as discrete measures over +the space of Gaussian distributions :math:`\mathcal{N}` (or equivalently the +Bures-Wasserstein manifold), and to compute barycenters with respect to the +2-Wasserstein distance between measures in :math:`\mathcal{P}(\mathcal{N})`: a +gaussian mixture is a finite combination of Diracs on specific gaussians, and +two mixtures are compared with the 2-Wasserstein distance on this space with +ground cost the squared Bures distance between gaussians. + +[69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance in the space +of Gaussian mixture models. SIAM Journal on Imaging Sciences, 13(2), 936-970. + +[74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing +Barycentres of Measures for Generic Transport Costs. 
arXiv preprint 2501.04016 +(2024) + +""" + +# Author: Eloi Tanguy +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 1 + +# %% +# Generate data +import numpy as np +import matplotlib.pyplot as plt +from matplotlib.patches import Ellipse +import ot +from ot.gmm import gmm_barycenter_fixed_point + + +K = 3 # number of GMMs +d = 2 # dimension +n = 6 # number of components of the desired barycenter + + +def get_random_gmm(K, d, seed=0, min_cov_eig=1, cov_scale=1e-2): + rng = np.random.RandomState(seed=seed) + means = rng.randn(K, d) + P = rng.randn(K, d, d) * cov_scale + # C[k] = P[k] @ P[k]^T + min_cov_eig * I + covariances = np.einsum("kab,kcb->kac", P, P) + covariances += min_cov_eig * np.array([np.eye(d) for _ in range(K)]) + weights = rng.random(K) + weights /= np.sum(weights) + return means, covariances, weights + + +m_list = [5, 6, 7] # number of components in each GMM +offsets = [np.array([-3, 0]), np.array([2, 0]), np.array([0, 4])] +means_list = [] # list of means for each GMM +covs_list = [] # list of covariances for each GMM +w_list = [] # list of weights for each GMM + +# generate GMMs +for k in range(K): + means, covs, b = get_random_gmm( + m_list[k], d, seed=k, min_cov_eig=0.25, cov_scale=0.5 + ) + means = means / 2 + offsets[k][None, :] + means_list.append(means) + covs_list.append(covs) + w_list.append(b) + +# %% +# Compute the barycenter using the fixed-point method +init_means, init_covs, _ = get_random_gmm(n, d, seed=0) +weights = ot.unif(K) # barycenter coefficients +means_bar, covs_bar, log = gmm_barycenter_fixed_point( + means_list, + covs_list, + w_list, + init_means, + init_covs, + weights, + iterations=3, + log=True, +) + + +# %% +# Define plotting functions + + +# draw a covariance ellipse +def draw_cov(mu, C, color=None, label=None, nstd=1, alpha=0.5, ax=None): + def eigsorted(cov): + vals, vecs = np.linalg.eigh(cov) + order = vals.argsort()[::-1].copy() + return vals[order], vecs[:, order] + + vals, vecs = eigsorted(C) + 
theta = np.degrees(np.arctan2(*vecs[:, 0][::-1])) + w, h = 2 * nstd * np.sqrt(vals) + ell = Ellipse( + xy=(mu[0], mu[1]), + width=w, + height=h, + alpha=alpha, + angle=theta, + facecolor=color, + edgecolor=color, + label=label, + fill=True, + ) + if ax is None: + ax = plt.gca() + ax.add_artist(ell) + + +# draw a gmm as a set of ellipses with weights shown in alpha value +def draw_gmm(ms, Cs, ws, color=None, nstd=0.5, alpha=1, label=None, ax=None): + for k in range(ms.shape[0]): + draw_cov( + ms[k], Cs[k], color, label if k == 0 else None, nstd, alpha * ws[k], ax=ax + ) + + +# %% +# Plot the results +fig, ax = plt.subplots(figsize=(6, 6)) +axis = [-4, 4, -2, 6] +ax.set_title("Fixed Point Barycenter (3 Iterations)", fontsize=16) +for k in range(K): + draw_gmm(means_list[k], covs_list[k], w_list[k], color="C0", ax=ax) +draw_gmm(means_bar, covs_bar, ot.unif(n), color="C1", ax=ax) +ax.axis(axis) +ax.axis("off") + +# %% From 0820e513e3415a1aa03abb6cd6a9acb27a7096d9 Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Tue, 21 Jan 2025 18:03:59 +0100 Subject: [PATCH 21/27] releases + readme + docs update --- README.md | 2 ++ RELEASES.md | 3 ++- examples/barycenters/plot_gmm_barycenter.py | 2 +- ot/lp/_barycenter_solvers.py | 27 ++++++++++++--------- 4 files changed, 20 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index 9266c99c6..48a4a87fe 100644 --- a/README.md +++ b/README.md @@ -55,6 +55,8 @@ POT provides the following generic OT solvers (links to examples): * [Co-Optimal Transport](https://pythonot.github.io/auto_examples/others/plot_COOT.html) [49] and [unbalanced Co-Optimal Transport](https://pythonot.github.io/auto_examples/others/plot_learning_weights_with_COOT.html) [71]. * Fused unbalanced Gromov-Wasserstein [70]. 
+* [Optimal Transport Barycenters for Generic Costs](https://pythonot.github.io/auto_examples/barycenters/plot_free_support_barycenter_generic_cost.html) [74] +* [Barycenters between Gaussian Mixture Models](https://pythonot.github.io/auto_examples/barycenters/plot_gmm_barycenter.html) [69, 74] POT provides the following Machine Learning related solvers: diff --git a/RELEASES.md b/RELEASES.md index ff8496bef..add09378c 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -8,7 +8,8 @@ - Automatic PR labeling and release file update check (PR #704) - Reorganize sub-module `ot/lp/__init__.py` into separate files (PR #714) - Implement fixed-point solver for OT barycenters with generic cost functions - (generalizes `ot.lp.free_support_barycenter`). (PR #715) + (generalizes `ot.lp.free_support_barycenter`), with example. (PR #715) +- Implement fixed-point solver for barycenters between GMMs (PR #715), with example. #### Closed issues - Fixed `ot.mapping` solvers which depended on deprecated `cvxpy` `ECOS` solver (PR #692, Issue #668) diff --git a/examples/barycenters/plot_gmm_barycenter.py b/examples/barycenters/plot_gmm_barycenter.py index 07792c0dd..84d0ee638 100644 --- a/examples/barycenters/plot_gmm_barycenter.py +++ b/examples/barycenters/plot_gmm_barycenter.py @@ -16,7 +16,7 @@ Bures-Wasserstein manifold), and to compute barycenters with respect to the 2-Wasserstein distance between measures in :math:`\mathcal{P}(\mathcal{N})`: a gaussian mixture is a finite combination of Diracs on specific gaussians, and -two mixtures are compared with the 2-Wasserstein distance on this space with +two mixtures are compared with the 2-Wasserstein distance on this space, where ground cost the squared Bures distance between gaussians. [69] Delon, J., & Desolneux, A. (2020). 
A Wasserstein-type distance in the space diff --git a/ot/lp/_barycenter_solvers.py b/ot/lp/_barycenter_solvers.py index 9589121bd..5e53c66d2 100644 --- a/ot/lp/_barycenter_solvers.py +++ b/ot/lp/_barycenter_solvers.py @@ -458,7 +458,9 @@ def free_support_barycenter_generic_costs( - :math:`Y_k` (m_k, d_k) is the k-th measure support (`measure_locations[k]`), - :math:`b_k` (m_k) is the k-th measure weights (`measure_weights[k]`), - - :math:`c_k: \mathbb{R}^{n\times d}\times\mathbb{R}^{m_k\times d_k} \rightarrow \mathbb{R}_+^{n\times m_k}` is the k-th cost function (which computes the pairwise cost matrix) + - :math:`c_k: \mathbb{R}^{n\times d}\times\mathbb{R}^{m_k\times d_k} + \rightarrow \mathbb{R}_+^{n\times m_k}` is the k-th cost function + (which computes the pairwise cost matrix) - :math:`\mathcal{T}_{c_k}(X, a, Y_k, b)` is the OT cost between the barycenter measure and the k-th measure with respect to the cost :math:`c_k`: .. math:: @@ -475,18 +477,19 @@ def free_support_barycenter_generic_costs( The algorithm requires a given ground barycenter function `B` which computes (broadcasted of `n`) solutions of the following minimisation problem given - :math:`(Y_1, \cdots, Y_K) \in - \mathbb{R}^{n\times d_1}\times\cdots\times\mathbb{R}^{n\times d_K}`: + :math:`(Y_1, \cdots, Y_K) \in \mathbb{R}^{n\times + d_1}\times\cdots\times\mathbb{R}^{n\times d_K}`: .. math:: B(y_1, \cdots, y_K) = \mathrm{argmin}_{x \in \mathbb{R}^d} \sum_{k=1}^K c_k(x, y_k), where :math:`c_k(x, y_k) \in \mathbb{R}_+` is the cost between the points - :math:`x` and :math:`y_k`. The function :math:`B:\mathbb{R}^{d_1}\times - \cdots\times\mathbb{R}^{d_K} \longrightarrow \mathbb{R}^d` is an input to - this function, and for certain costs it can be computed explicitly of - through a numerical solver. The input function B takes a list of K arrays of - shape (n, d_k) and returns an array of shape (n, d). + :math:`x` and :math:`y_k`. 
The function :math:`B:\mathbb{R}^{n\times + d_1}\times \cdots\times\mathbb{R}^{n\times d_K} \longrightarrow + \mathbb{R}^{n\times d}` is an input to this function, and for certain costs + it can be computed explicitly of through a numerical solver. The input + function B takes a list of K arrays of shape (n, d_k) and returns an array + of shape (n, d). This function implements [74] Algorithm 2, which generalises [20] and [43] to general costs and includes convergence guarantees, including for discrete @@ -526,8 +529,6 @@ def free_support_barycenter_generic_costs( log containing the exit status, list of iterations and list of displacements if log is True. - .. _references-free-support-barycenter-generic-costs: - References ---------- .. [74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing @@ -543,8 +544,10 @@ def free_support_barycenter_generic_costs( See Also -------- - ot.lp.free_support_barycenter : Free support solver for the case where :math:`c_k(x,y) = \|x-y\|_2^2`. - ot.lp.generalized_free_support_barycenter : Free support solver for the case where :math:`c_k(x,y) = \|P_kx-y\|_2^2` with :math:`P_k` linear. + ot.lp.free_support_barycenter : Free support solver for the case where + :math:`c_k(x,y) = \|x-y\|_2^2`. ot.lp.generalized_free_support_barycenter : + Free support solver for the case where :math:`c_k(x,y) = \|P_kx-y\|_2^2` + with :math:`P_k` linear. 
""" nx = get_backend(X_init, measure_locations[0]) K = len(measure_locations) From 6bd4af8b9c280798c2d5d8b617d611340589fdc7 Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Wed, 12 Mar 2025 15:15:45 +0100 Subject: [PATCH 22/27] ref fix --- README.md | 4 ++-- .../plot_free_support_barycenter_generic_cost.py | 4 ++-- examples/barycenters/plot_gmm_barycenter.py | 6 ++---- ot/gmm.py | 4 ++-- ot/lp/_barycenter_solvers.py | 4 ++-- 5 files changed, 10 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 124c5d809..a7f1ff830 100644 --- a/README.md +++ b/README.md @@ -54,8 +54,8 @@ POT provides the following generic OT solvers (links to examples): * [Co-Optimal Transport](https://pythonot.github.io/auto_examples/others/plot_COOT.html) [49] and [unbalanced Co-Optimal Transport](https://pythonot.github.io/auto_examples/others/plot_learning_weights_with_COOT.html) [71]. * Fused unbalanced Gromov-Wasserstein [70]. -* [Optimal Transport Barycenters for Generic Costs](https://pythonot.github.io/auto_examples/barycenters/plot_free_support_barycenter_generic_cost.html) [74] -* [Barycenters between Gaussian Mixture Models](https://pythonot.github.io/auto_examples/barycenters/plot_gmm_barycenter.html) [69, 74] +* [Optimal Transport Barycenters for Generic Costs](https://pythonot.github.io/auto_examples/barycenters/plot_free_support_barycenter_generic_cost.html) [76] +* [Barycenters between Gaussian Mixture Models](https://pythonot.github.io/auto_examples/barycenters/plot_gmm_barycenter.html) [69, 76] POT provides the following Machine Learning related solvers: diff --git a/examples/barycenters/plot_free_support_barycenter_generic_cost.py b/examples/barycenters/plot_free_support_barycenter_generic_cost.py index 55a75b157..47e2c9236 100644 --- a/examples/barycenters/plot_free_support_barycenter_generic_cost.py +++ b/examples/barycenters/plot_free_support_barycenter_generic_cost.py @@ -8,13 +8,13 @@ a ground cost that is not a power of a norm. 
We take the example of ground costs :math:`c_k(x, y) = \|P_k(x)-y\|_2^2`, where :math:`P_k` is the (non-linear) projection onto a circle k. This is an example of the fixed-point barycenter -solver introduced in [74] which generalises [20] and [43]. +solver introduced in [76] which generalises [20] and [43]. The ground barycenter function :math:`B(y_1, ..., y_K) = \mathrm{argmin}_{x \in \mathbb{R}^2} \sum_k \lambda_k c_k(x, y_k)` is computed by gradient descent over :math:`x` with Pytorch. -[74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing +[76] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing Barycentres of Measures for Generic Transport Costs. arXiv preprint 2501.04016 (2024) diff --git a/examples/barycenters/plot_gmm_barycenter.py b/examples/barycenters/plot_gmm_barycenter.py index 84d0ee638..f379a9914 100644 --- a/examples/barycenters/plot_gmm_barycenter.py +++ b/examples/barycenters/plot_gmm_barycenter.py @@ -6,7 +6,7 @@ This example illustrates the computation of a barycenter between Gaussian Mixtures in the sense of GMM-OT [69]. This computation is done using the -fixed-point method for OT barycenters with generic costs [74], for which POT +fixed-point method for OT barycenters with generic costs [76], for which POT provides a general solver, and a specific GMM solver. Note that this is a 'free-support' method, implying that the number of components of the barycenter GMM and their weights are fixed. @@ -22,7 +22,7 @@ [69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance in the space of Gaussian mixture models. SIAM Journal on Imaging Sciences, 13(2), 936-970. -[74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing +[76] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing Barycentres of Measures for Generic Transport Costs. 
arXiv preprint 2501.04016 (2024) @@ -140,5 +140,3 @@ def draw_gmm(ms, Cs, ws, color=None, nstd=0.5, alpha=1, label=None, ax=None): draw_gmm(means_bar, covs_bar, ot.unif(n), color="C1", ax=ax) ax.axis(axis) ax.axis("off") - -# %% diff --git a/ot/gmm.py b/ot/gmm.py index 214720d1e..a065c73b0 100644 --- a/ot/gmm.py +++ b/ot/gmm.py @@ -456,7 +456,7 @@ def gmm_barycenter_fixed_point( ): r""" Solves the Gaussian Mixture Model OT barycenter problem (defined in [69]) - using the fixed point algorithm (proposed in [74]). The + using the fixed point algorithm (proposed in [76]). The weights of the barycenter are not optimized, and stay the same as the input `w_list` or are initialized to uniform. @@ -504,7 +504,7 @@ def gmm_barycenter_fixed_point( ---------- .. [69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance in the space of Gaussian mixture models. SIAM Journal on Imaging Sciences, 13(2), 936-970. - .. [74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing barycenters of Measures for Generic Transport Costs. arXiv preprint 2501.04016 (2024) + .. [76] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing barycenters of Measures for Generic Transport Costs. arXiv preprint 2501.04016 (2024) See Also -------- diff --git a/ot/lp/_barycenter_solvers.py b/ot/lp/_barycenter_solvers.py index 61b4fce49..f803d23db 100644 --- a/ot/lp/_barycenter_solvers.py +++ b/ot/lp/_barycenter_solvers.py @@ -495,7 +495,7 @@ def free_support_barycenter_generic_costs( function B takes a list of K arrays of shape (n, d_k) and returns an array of shape (n, d). - This function implements [74] Algorithm 2, which generalises [20] and [43] + This function implements [76] Algorithm 2, which generalises [20] and [43] to general costs and includes convergence guarantees, including for discrete measures. @@ -535,7 +535,7 @@ def free_support_barycenter_generic_costs( References ---------- - .. [74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). 
Computing + .. [76] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing barycenters of Measures for Generic Transport Costs. arXiv preprint 2501.04016 (2024) From 51722bf65f1be26a453f5602f07d1ecb4752c896 Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Mon, 17 Mar 2025 19:54:14 +0100 Subject: [PATCH 23/27] implementation comments --- ot/lp/_barycenter_solvers.py | 133 +++++++++++++++++++++++------------ test/test_ot.py | 99 +++++++++++++++++++++++--- 2 files changed, 178 insertions(+), 54 deletions(-) diff --git a/ot/lp/_barycenter_solvers.py b/ot/lp/_barycenter_solvers.py index f803d23db..725af26c4 100644 --- a/ot/lp/_barycenter_solvers.py +++ b/ot/lp/_barycenter_solvers.py @@ -199,14 +199,12 @@ def free_support_barycenter( measures_weights : list of N (k_i,) array-like Numpy arrays where each numpy array has :math:`k_i` non-negatives values summing to one representing the weights of each discrete input measure - X_init : (k,d) array-like Initialization of the support locations (on `k` atoms) of the barycenter b : (k,) array-like Initialization of the weights of the barycenter (non-negatives, sum to 1) weights : (N,) array-like Initialization of the coefficients of the barycenter (non-negatives, sum to 1) - numItermax : int, optional Max number of iterations stopThr : float, optional @@ -219,13 +217,11 @@ def free_support_barycenter( If compiled with OpenMP, chooses the number of threads to parallelize. "max" selects the highest number possible. - Returns ------- X : (k,d) array-like Support locations (on k atoms) of the barycenter - .. 
_references-free-support-barycenter: References ---------- @@ -428,20 +424,20 @@ def generalized_free_support_barycenter( return Y -class StoppingCriterionReached(Exception): - pass - - def free_support_barycenter_generic_costs( measure_locations, measure_weights, X_init, cost_list, - B, + ground_bary=None, a=None, numItermax=100, stopThr=1e-5, log=False, + ground_bary_lr=1e-2, + ground_bary_numItermax=100, + ground_bary_stopThr=1e-5, + ground_bary_solver="SGD", ): r""" Solves the OT barycenter problem for generic costs using the fixed point @@ -507,14 +503,15 @@ def free_support_barycenter_generic_costs( List of K arrays of measure weights, each of shape (m_k). X_init : array-like Array of shape (n, d) representing initial barycenter points. - cost_list : list of callable + cost_list : list of callable or callable List of K cost functions :math:`c_k: \mathbb{R}^{n\times d}\times\mathbb{R}^{m_k\times d_k} \rightarrow \mathbb{R}_+^{n\times - m_k}`. - B : callable + m_k}`. If cost_list is a single callable, the same cost is used K times. + ground_bary : callable or None, optional Function List(array(n, d_k)) -> array(n, d) accepting a list of K arrays of shape (n\times d_K), computing the ground barycenters (broadcasted - over n). + over n). If not provided, done with Adam on PyTorch (requires PyTorch + backend) a : array-like, optional Array of shape (n,) representing weights of the barycenter measure.Defaults to uniform. @@ -524,6 +521,16 @@ def free_support_barycenter_generic_costs( If the iterations move less than this, terminate (default is 1e-5). log : bool, optional Whether to return the log dictionary (default is False). + ground_bary_lr : float, optional + Learning rate for the ground barycenter solver (if auto is used). + ground_bary_numItermax : int, optional + Maximum number of iterations for the ground barycenter solver (if auto + is used). + ground_bary_stopThr : float, optional + Stop threshold for the ground barycenter solver (if auto is used). 
+ ground_bary_solver : str, optional + Solver for auto ground bary solver (torch SGD or Adam). Default is + "SGD". Returns ------- @@ -549,49 +556,85 @@ def free_support_barycenter_generic_costs( See Also -------- ot.lp.free_support_barycenter : Free support solver for the case where - :math:`c_k(x,y) = \|x-y\|_2^2`. ot.lp.generalized_free_support_barycenter : - Free support solver for the case where :math:`c_k(x,y) = \|P_kx-y\|_2^2` - with :math:`P_k` linear. + :math:`c_k(x,y) = \lambda_k\|x-y\|_2^2`. + ot.lp.generalized_free_support_barycenter : Free support solver for the case + where :math:`c_k(x,y) = \|P_kx-y\|_2^2` with :math:`P_k` linear. """ nx = get_backend(X_init, measure_locations[0]) K = len(measure_locations) n = X_init.shape[0] if a is None: a = nx.ones(n, type_as=X_init) / n + if callable(cost_list): # use the given cost for all K pairs + cost_list = [cost_list] * K + auto_ground_bary = False + + if ground_bary is None: + auto_ground_bary = True + assert str(nx) == "torch", ( + f"Backend {str(nx)} is not compatible with ground_bary=None, it" + "must be provided if not using PyTorch backend" + ) + try: + import torch + from torch.optim import Adam, SGD + + def ground_bary(y, x_init): + x = x_init.clone().detach().requires_grad_(True) + solver = Adam if ground_bary_solver == "Adam" else SGD + opt = solver([x], lr=ground_bary_lr) + for _ in range(ground_bary_numItermax): + x_prev = x.data.clone() + opt.zero_grad() + # inefficient cost computation but compatible + # with the choice of cost_list[k] giving the cost matrix + loss = torch.sum( + torch.stack( + [torch.diag(cost_list[k](x, y[k])) for k in range(K)] + ) + ) + loss.backward() + opt.step() + diff = torch.sum((x.data - x_prev) ** 2) + if diff < ground_bary_stopThr: + break + return x.detach() + + except ImportError: + raise ImportError("PyTorch is required to use ground_bary=None") + X_list = [X_init] if log else [] # store the iterations X = X_init dX_list = [] # store the displacement squared 
norms - exit_status = "Unknown" - - try: - for _ in range(numItermax): - pi_list = [ # compute the pairwise transport plans - emd(a, measure_weights[k], cost_list[k](X, measure_locations[k])) - for k in range(K) - ] - Y_perm = [] - for k in range(K): # compute barycentric projections - Y_perm.append(n * pi_list[k] @ measure_locations[k]) - X_next = B(Y_perm) - - if log: - X_list.append(X_next) + exit_status = "Max iterations reached" + + for _ in range(numItermax): + pi_list = [ # compute the pairwise transport plans + emd(a, measure_weights[k], cost_list[k](X, measure_locations[k])) + for k in range(K) + ] + Y_perm = [] + for k in range(K): # compute barycentric projections + Y_perm.append(n * pi_list[k] @ measure_locations[k]) + if auto_ground_bary: # use previous position as initialization + X_next = ground_bary(Y_perm, X) + else: + X_next = ground_bary(Y_perm) - # stationary criterion: move less than the threshold - dX = nx.sum((X - X_next) ** 2) - X = X_next + if log: + X_list.append(X_next) - if log: - dX_list.append(dX) + # stationary criterion: move less than the threshold + dX = nx.sum((X - X_next) ** 2) + X = X_next - if dX < stopThr: - exit_status = "Stationary Point" - raise StoppingCriterionReached + if log: + dX_list.append(dX) - exit_status = "Max iterations reached" - raise StoppingCriterionReached + if dX < stopThr: + exit_status = "Stationary Point" + break - except StoppingCriterionReached: - if log: - return X, {"X_list": X_list, "exit_status": exit_status, "dX_list": dX_list} - return X + if log: + return X, {"X_list": X_list, "exit_status": exit_status, "dX_list": dX_list} + return X diff --git a/test/test_ot.py b/test/test_ot.py index 4916d71aa..22612fa4a 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -13,8 +13,6 @@ from ot.datasets import make_1D_gauss as gauss from ot.backend import torch, tf -# import ot.lp._barycenter_solvers # TODO: remove this import - def test_emd_dimension_and_mass_mismatch(): # test emd and emd2 for dimension 
mismatch @@ -414,14 +412,14 @@ def cost(x, y): cost_list = [cost, cost] - def B(y): + def ground_bary(y): out = 0 for yk in y: out += yk / len(y) return out X = ot.lp.free_support_barycenter_generic_costs( - measures_locations, measures_weights, X_init, cost_list, B + measures_locations, measures_weights, X_init, cost_list, ground_bary ) np.testing.assert_allclose(X, bar_locations, rtol=1e-5, atol=1e-7) @@ -432,7 +430,7 @@ def B(y): measures_weights, X_init, cost_list, - B, + ground_bary, a=ot.unif(1), log=True, ) @@ -449,12 +447,95 @@ def B(y): measures_weights, X_init, cost_list, - B, + ground_bary, numItermax=1, log=True, ) assert log2["exit_status"] == "Max iterations reached" + # test with a single callable cost + X3, log3 = ot.lp.free_support_barycenter_generic_costs( + measures_locations, + measures_weights, + X_init, + cost, + ground_bary, + numItermax=1, + log=True, + ) + + # test with no ground_bary but in numpy: requires pytorch backend + with pytest.raises(AssertionError): + ot.lp.free_support_barycenter_generic_costs( + measures_locations, + measures_weights, + X_init, + cost_list, + ground_bary=None, + numItermax=1, + ) + + +@pytest.mark.skipif(not torch, reason="No torch available") +def test_free_support_barycenter_generic_costs_auto_ground_bary(): + measures_locations = [ + torch.tensor([1.0]).reshape((1, 1)), + torch.tensor([2.0]).reshape((1, 1)), + ] + measures_weights = [torch.tensor([1.0]), torch.tensor([1.0])] + + X_init = torch.tensor([1.2]).reshape((1, 1)) + + def cost(x, y): + return ot.dist(x, y) + + cost_list = [cost, cost] + + def ground_bary(y): + out = 0 + for yk in y: + out += yk / len(y) + return out + + X = ot.lp.free_support_barycenter_generic_costs( + measures_locations, + measures_weights, + X_init, + cost_list, + ground_bary, + numItermax=1, + ) + + X2, log2 = ot.lp.free_support_barycenter_generic_costs( + measures_locations, + measures_weights, + X_init, + cost_list, + ground_bary=None, + ground_bary_lr=1e-2, + 
ground_bary_stopThr=1e-20, + ground_bary_numItermax=50, + numItermax=10, + log=True, + ) + + np.testing.assert_allclose(X2.numpy(), X.numpy(), rtol=1e-4, atol=1e-4) + + X3 = ot.lp.free_support_barycenter_generic_costs( + measures_locations, + measures_weights, + X_init, + cost_list, + ground_bary=None, + ground_bary_lr=1e-2, + ground_bary_stopThr=1e-20, + ground_bary_numItermax=50, + numItermax=10, + ground_bary_solver="Adam", + ) + + np.testing.assert_allclose(X2.numpy(), X3.numpy(), rtol=1e-3, atol=1e-3) + def test_free_support_barycenter_generic_costs_backends(nx): measures_locations = [ @@ -469,14 +550,14 @@ def cost(x, y): cost_list = [cost, cost] - def B(y): + def ground_bary(y): out = 0 for yk in y: out += yk / len(y) return out X = ot.lp.free_support_barycenter_generic_costs( - measures_locations, measures_weights, X_init, cost_list, B + measures_locations, measures_weights, X_init, cost_list, ground_bary ) measures_locations2 = nx.from_numpy(*measures_locations) @@ -484,7 +565,7 @@ def B(y): X_init2 = nx.from_numpy(X_init) X2 = ot.lp.free_support_barycenter_generic_costs( - measures_locations2, measures_weights2, X_init2, cost_list, B + measures_locations2, measures_weights2, X_init2, cost_list, ground_bary ) np.testing.assert_allclose(X, nx.to_numpy(X2)) From f2269ac029c132482b5c454668fcb60443831c23 Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Fri, 6 Jun 2025 10:52:21 +0200 Subject: [PATCH 24/27] (WIP) added true barycenter fixed-point algorithm with updated tests and examples --- ...ot_free_support_barycenter_generic_cost.py | 108 ++++++- ot/lp/__init__.py | 2 + ot/lp/_barycenter_solvers.py | 291 +++++++++++++++++- test/test_ot.py | 190 +++++++++++- 4 files changed, 553 insertions(+), 38 deletions(-) diff --git a/examples/barycenters/plot_free_support_barycenter_generic_cost.py b/examples/barycenters/plot_free_support_barycenter_generic_cost.py index 47e2c9236..2886432ae 100644 --- a/examples/barycenters/plot_free_support_barycenter_generic_cost.py 
+++ b/examples/barycenters/plot_free_support_barycenter_generic_cost.py @@ -6,14 +6,29 @@ This example illustrates the computation of an Optimal Transport Barycenter for a ground cost that is not a power of a norm. We take the example of ground costs -:math:`c_k(x, y) = \|P_k(x)-y\|_2^2`, where :math:`P_k` is the (non-linear) -projection onto a circle k. This is an example of the fixed-point barycenter -solver introduced in [76] which generalises [20] and [43]. +:math:`c_k(x, y) = \lambda_k\|P_k(x)-y\|_2^2`, where :math:`P_k` is the +(non-linear) projection onto a circle k, and :math:`(\lambda_k)` are weights. A +barycenter is defined ([76]) as a minimiser of the energy :math:`V(\mu) = \sum_k +\mathcal{T}_{c_k}(\mu, \nu_k)` where :math:`\mu` is a candidate barycenter +measure, the measures :math:`\nu_k` are the target measures and +:math:`\mathcal{T}_{c_k}` is the OT cost for ground cost :math:`c_k`. This is an +example of the fixed-point barycenter solver introduced in [76] which +generalises [20] and [43]. The ground barycenter function :math:`B(y_1, ..., y_K) = \mathrm{argmin}_{x \in \mathbb{R}^2} \sum_k \lambda_k c_k(x, y_k)` is computed by gradient descent over :math:`x` with Pytorch. +We compare two algorithms from [76]: the first ([76], Algorithm 2, +'true_fixed_point' in POT) has convergence guarantees but the iterations may +increase in support size and thus require more computational resources. The +second ([76], Algorithm 3, 'L2_barycentric_proj' in POT) is a simplified +heuristic that imposes a fixed support size for the barycenter and fixed +weights. + +We initialise both algorithms with a support size of 136, computing a barycenter +between measures with uniform weights and 50 points. + [76] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing Barycentres of Measures for Generic Transport Costs. 
arXiv preprint 2501.04016 (2024) @@ -36,16 +51,18 @@ # %% # Generate data import torch +import ot from torch.optim import Adam from ot.utils import dist import numpy as np from ot.lp import free_support_barycenter_generic_costs import matplotlib.pyplot as plt +from time import time torch.manual_seed(42) -n = 200 # number of points of the of the barycentre +n = 136 # number of points of the of the barycentre d = 2 # dimensions of the original measure K = 4 # number of measures to barycentre m = 50 # number of points of the measures @@ -128,7 +145,7 @@ def B(y, its=150, lr=1, stop_threshold=stop_threshold): Computes the ground barycenter for measure supports y: List(n, d_k). Output: (n, d) array """ - x = torch.randn(n, d) + x = torch.randn(y[0].shape[0], d) x.requires_grad_(True) opt = Adam([x], lr=lr) for _ in range(its): @@ -144,10 +161,30 @@ def B(y, its=150, lr=1, stop_threshold=stop_threshold): # %% -# Compute the barycenter measure +# Compute the barycenter measure with the true fixed-point algorithm +fixed_point_its = 3 +torch.manual_seed(42) +X_init = torch.rand(n, d) +t0 = time() +X_bar, a_bar = free_support_barycenter_generic_costs( + Y_list, + b_list, + X_init, + cost_list, + B, + numItermax=fixed_point_its, + stopThr=stop_threshold, + method="true_fixed_point", +) +dt_true_fixed_point = time() - t0 + +# %% +# Compute the barycenter measure with the barycentric (default) algorithm fixed_point_its = 3 +torch.manual_seed(42) X_init = torch.rand(n, d) -X_bar = free_support_barycenter_generic_costs( +t0 = time() +X_bar2 = free_support_barycenter_generic_costs( Y_list, b_list, X_init, @@ -156,22 +193,61 @@ def B(y, its=150, lr=1, stop_threshold=stop_threshold): numItermax=fixed_point_its, stopThr=stop_threshold, ) +dt_barycentric = time() - t0 # %% -# Plot Barycenter (Iteration 3) +# Plot Barycenters (Iteration 3) alpha = 0.4 s = 80 labels = ["circle 1", "circle 2", "circle 3", "circle 4"] + + +# Compute barycenter energies +def V(X, a): + v = 0 + for k in 
range(K): + v += (1 / K) * ot.emd2(a, b_list[k], cost_list[k](X, Y_list[k])) + return v + + +fig, axes = plt.subplots(1, 2, figsize=(12, 6)) + +# Plot for the true fixed-point algorithm for Y, label in zip(Y_list, labels): - plt.scatter(*(Y.numpy()).T, alpha=alpha, label=label, s=s) -plt.scatter( - *(X_bar.detach().numpy()).T, label="Barycenter", c="black", alpha=alpha, s=s + axes[0].scatter(*(Y.numpy()).T, alpha=alpha, label=label, s=s) +axes[0].scatter( + *(X_bar.detach().numpy()).T, + label="Barycenter", + c="black", + alpha=alpha * a_bar.numpy() / np.max(a_bar.numpy()), + s=s, +) +axes[0].set_title( + "True Fixed-Point Algorithm\n" + f"Support size: {a_bar.shape[0]}\n" + f"Barycenter cost: {V(X_bar, a_bar).item():.6f}\n" + f"Computation time {dt_true_fixed_point:.4f}s" ) -plt.axis("equal") -plt.xlim(-0.3, 1.3) -plt.ylim(-0.3, 1.3) -plt.axis("off") -plt.legend() +axes[0].axis("equal") +axes[0].axis("off") +axes[0].legend() + +# Plot for the heuristic algorithm +for Y, label in zip(Y_list, labels): + axes[1].scatter(*(Y.numpy()).T, alpha=alpha, label=label, s=s) +axes[1].scatter( + *(X_bar2.detach().numpy()).T, label="Barycenter", c="black", alpha=alpha, s=s +) +axes[1].set_title( + "Heuristic Barycentric Algorithm\n" + f"Support size: {X_bar2.shape[0]}\n" + f"Barycenter cost: {V(X_bar2, torch.ones(n) / n).item():.6f}\n" + f"Computation time {dt_barycentric:.4f}s" +) +axes[1].axis("equal") +axes[1].axis("off") +axes[1].legend() + plt.tight_layout() # %% diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 974679440..03aeb958a 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -15,6 +15,7 @@ free_support_barycenter, generalized_free_support_barycenter, free_support_barycenter_generic_costs, + NorthWestMMGluing, ) from ..utils import check_number_threads @@ -47,4 +48,5 @@ "dmmot_monge_1dgrid_optimize", "check_number_threads", "free_support_barycenter_generic_costs", + "NorthWestMMGluing", ] diff --git a/ot/lp/_barycenter_solvers.py 
b/ot/lp/_barycenter_solvers.py index 725af26c4..9cd04ca2d 100644 --- a/ot/lp/_barycenter_solvers.py +++ b/ot/lp/_barycenter_solvers.py @@ -10,7 +10,7 @@ from ..backend import get_backend from ..utils import dist -from ._network_simplex import emd +from ._network_simplex import emd, emd2 import numpy as np import scipy as sp @@ -432,12 +432,14 @@ def free_support_barycenter_generic_costs( ground_bary=None, a=None, numItermax=100, + method="L2_barycentric_proj", stopThr=1e-5, log=False, ground_bary_lr=1e-2, ground_bary_numItermax=100, ground_bary_stopThr=1e-5, ground_bary_solver="SGD", + clean_measure=False, ): r""" Solves the OT barycenter problem for generic costs using the fixed point @@ -491,7 +493,22 @@ def free_support_barycenter_generic_costs( function B takes a list of K arrays of shape (n, d_k) and returns an array of shape (n, d). - This function implements [76] Algorithm 2, which generalises [20] and [43] + This function implements two algorithms: + + - Algorithm 2 from [76] when `method=true_fixed_point` is used, which may + increase the support size of the barycenter at each iteration, with a + maximum final size of :math:`N_0 + T\sum_k n_k - TK` for T iterations and + an initial support size of :math:`N_0`. The computation of the iterates is + done using the North West Corner multi-marginal gluing method. This method + has convergence guarantees [76]. + + - Algorithm 3 from [76] when `method=L2_barycentric_proj` is used, which is + a heuristic simplification which fixes the weights and support size of the + barycenter by performing barycentric projections of the pair-wise OT + matrices. This method is substantially faster than the first one, but does + not have convergence guarantees. (Default) + + The implemented methods ([76] Algorithms 2 and 3), generalises [20] and [43] to general costs and includes convergence guarantees, including for discrete measures. 
@@ -511,12 +528,16 @@ def free_support_barycenter_generic_costs( Function List(array(n, d_k)) -> array(n, d) accepting a list of K arrays of shape (n\times d_K), computing the ground barycenters (broadcasted over n). If not provided, done with Adam on PyTorch (requires PyTorch - backend) + backend), inefficiently using the cost functions in `cost_list`. a : array-like, optional Array of shape (n,) representing weights of the barycenter measure.Defaults to uniform. numItermax : int, optional Maximum number of iterations (default is 100). + method : str, optional + Barycentre method: 'L2_barycentric_proj' (default) for Euclidean + barycentric projection, or 'true_fixed_point' for iterates using the + North West Corner multi-marginal gluing method. stopThr : float, optional If the iterations move less than this, terminate (default is 1e-5). log : bool, optional @@ -531,6 +552,10 @@ def free_support_barycenter_generic_costs( ground_bary_solver : str, optional Solver for auto ground bary solver (torch SGD or Adam). Default is "SGD". + clean_measure : bool, optional + For method=='true_fixed_point', whether to clean the discrete measure + (X, a) at each iteration to remove duplicate points and sum their + weights (default is False). Returns ------- @@ -557,9 +582,16 @@ def free_support_barycenter_generic_costs( -------- ot.lp.free_support_barycenter : Free support solver for the case where :math:`c_k(x,y) = \lambda_k\|x-y\|_2^2`. + ot.lp.generalized_free_support_barycenter : Free support solver for the case where :math:`c_k(x,y) = \|P_kx-y\|_2^2` with :math:`P_k` linear. + + ot.lp.NorthWestMMGluing : gluing method used in the `true_fixed_point` method. 
""" + assert method in [ + "L2_barycentric_proj", + "true_fixed_point", + ], "Method must be 'L2_barycentric_proj' or 'true_fixed_point'" nx = get_backend(X_init, measure_locations[0]) K = len(measure_locations) n = X_init.shape[0] @@ -604,8 +636,9 @@ def ground_bary(y, x_init): raise ImportError("PyTorch is required to use ground_bary=None") X_list = [X_init] if log else [] # store the iterations + a_list = [nx.copy(a)] if log and method == "true_fixed_point" else [] X = X_init - dX_list = [] # store the displacement squared norms + diff_list = [] # store the displacement squared norms exit_status = "Max iterations reached" for _ in range(numItermax): @@ -614,27 +647,255 @@ def ground_bary(y, x_init): for k in range(K) ] Y_perm = [] - for k in range(K): # compute barycentric projections - Y_perm.append(n * pi_list[k] @ measure_locations[k]) - if auto_ground_bary: # use previous position as initialization - X_next = ground_bary(Y_perm, X) - else: + + if method == "L2_barycentric_proj": + a_next = a # barycentre weights are fixed + for k in range(K): # L2 barycentric projection of pi_k + Y_perm.append((1 / a[:, None]) * pi_list[k] @ measure_locations[k]) + if auto_ground_bary: # use previous position as initialization + X_next = ground_bary(Y_perm, X) + else: + X_next = ground_bary(Y_perm) + + elif method == "true_fixed_point": + # North West Corner gluing of pi_k + J, a_next = NorthWestMMGluing(pi_list) + # J is a (N, K) array of indices, w is a (N,) array of weights + # Each Y_perm[k] is a (N, d_k) array of some points in Y_list[k] + Y_perm = [measure_locations[k][J[:, k]] for k in range(K)] + # warm start impossible due to possible size mismatch X_next = ground_bary(Y_perm) + if clean_measure and method == "true_fixed_point": + # clean the discrete measure (X, a) to remove duplicates + X_next, a_next = _clean_discrete_measure(X_next, a_next) + if log: X_list.append(X_next) + if method == "true_fixed_point": + a_list.append(a_next) # stationary criterion: move 
less than the threshold - dX = nx.sum((X - X_next) ** 2) - X = X_next + diff = emd2(a, a_next, dist(X, X_next)) if log: - dX_list.append(dX) + diff_list.append(diff) + + X = X_next + a = a_next - if dX < stopThr: + if diff < stopThr * nx.sum(X**2) / X.shape[0]: exit_status = "Stationary Point" break + if log: + log_dict = { + "X_list": X_list, + "exit_status": exit_status, + "a_list": a_list, + "diff_list": diff_list, + } + if method == "true_fixed_point": + return X, a, log_dict + else: + return X, log_dict + + if method == "true_fixed_point": + return X, a + else: + return X + + +def to_int_array(x): + """ + Converts an array to an integer type array. + """ + nx = get_backend(x) + if str(nx) == "numpy": + return x.astype(int) + + if str(nx) == "torch": + return x.to(int) + + if str(nx) == "jax": + return x.astype(int) + + if str(nx) == "cupy": + return x.astype(int) + + if str(nx) == "tf": + import tensorflow as tf + + return tf.cast(x, tf.int32) + + raise TypeError(f"Unsupported backend {str(nx)}") + + +def NorthWestMMGluing(pi_list, log=False): + r""" + Glue transport plans :math:`(\pi_1, ..., \pi_K)` which have a common first + marginal using the (multi-marginal) North-West Corner method. Writing the + marginals of each :math:`\pi_k\in \mathbb{R}^{n\times n_k}` as :math:`a \in + \mathbb{R}^n` and :math:`b_k \in \mathbb{R}^{n_k}`, the output represents a + particular K-marginal transport plan :math:`\rho \in + \mathbb{R}^{n_1\times\cdots\times n_K}` whose k-th marginal is :math:`b_k`. + This K-plan is such that there exists a K+1-marginal transport plan + :math:`\gamma \in \mathbb{R}^{n\times n_1 \times \cdots \times n_K}` such + that :math:`\sum_i\gamma_{i,j_1,\cdots,j_K} = \rho_{j_1, \cdots, j_K}` and + with Einstein summation convention, :math:`\gamma_{i, j_1, \cdots, j_K} = + [\pi_k]_{i, j_k}` for all :math:`k=1,\cdots,K`.
+ + Instead of outputting the full K-multi-marginal plan :math:`\rho`, this + function provides an array `J` of shape (N, K) where each `J[i]` is of the + form `(J[i, 1], ..., J[i, K])` with each `J[i, k]` between 0 and + :math:`n_k-1`, and a weight vector `w` of size N, such that the K-plan + :math:`\rho` writes: + + .. math:: + \rho_{j_1, \cdots, j_K} = 1\left(\exists i \text{ s.t. } (j_1, \cdots, j_K) = (J[i, 1], \cdots, J[i, K])\right)\ w_i. + + This representation is useful for its memory efficiency, as it avoids + storing the full K-marginal plan. + + If `log=True`, the function computes the full K+1-marginal transport plan + :math:`\gamma` and stores it in log_dict['gamma']. Note that this option is + extremely costly in memory. + + Parameters + ---------- + pi_list : list of arrays (n, n_k) + List of transport plans. + + log : bool, optional + If True, return a log dictionary (computationally expensive). + + Returns + ------- + J : array (N, K) + The indices (J[i, 1], ..., J[i, K]) of the K-plan rho. + w : array (N,) + The weights w_i of the K-plan rho. + log_dict : dict, optional + If log=True, a dictionary containing the full K+1-marginal transport + plan under the key 'gamma'. + """ + nx = get_backend(pi_list[0]) + a = nx.sum(pi_list[0], axis=1) # common first marginal a in Delta_n + nk_list = [pi.shape[1] for pi in pi_list] # list of n_k + K = len(pi_list) + n = pi_list[0].shape[0] # number of points in the first marginal + gamma = None + + log_dict = {} + if log: # n x n_1 x ...
x n_K tensor + gamma = nx.zeros([n] + nk_list, type_as=pi_list[0]) + + gamma_weights = {} # dict of (j_1, ..., j_K) : weight + P_list = [nx.copy(pi) for pi in pi_list] # copy of the transport plans + + # jjs is a list of K lists of size m_k + # checks if each jj_idx[k] is < m_k + # this is to avoid over-shooting the while loop due to numerical + # imprecision in the conditions "x > 0" + def jj_idx_in_range(jj_idx, jjs): + out = True + for k in range(K): + out = out and jj_idx[k] < len(jjs[k]) + return out + + for i in range(n): + # jjs[k] is the list of indices j in [0, n_k - 1] such that Pk[i, j] >0 + jjs = [nx.to_numpy(nx.where(P[i, :] > 0)[0]) for P in P_list] + # list [0, ..., 0] of size K for use with jjs: current indices in jjs + jj_idx = [0] * K + u = a[i] # mass at i, will decrease to 0 as we fill gamma[i, :] + + # while there is mass to add to gamma[i, :] + while u > 0 and jj_idx_in_range(jj_idx, jjs): + # current multi-index j_1 ... j_K + jj = tuple(jjs[k][jj_idx[k]] for k in range(K)) + # min transport plan value: min_k pi_k[i, j_k] + v = nx.min(nx.stack([P_list[k][i, jj[k]] for k in range(K)])) + if log: # assign mass v to gamma[i, j_1, ..., j_K] + gamma[(i,) + jj] = v + if jj in gamma_weights: + gamma_weights[jj] += v + else: + gamma_weights[jj] = v + u -= v # at i, we u-v mass left to assign + for k in range(K): # update plan copies Pk + P_list[k][i, jj[k]] -= v # Pk[i, j_k] has v less mass left + if P_list[k][i, jj[k]] == 0: + # move to next index in jjs[k] if Pk[i, j_k] is empty + jj_idx[k] += 1 + + log_dict["gamma"] = gamma + J = list(gamma_weights.keys()) # list of multi-indices (j_1, ..., j_K) + J = to_int_array(nx.from_numpy(np.array(J), type_as=pi_list[0])) + w = nx.stack(list(gamma_weights.values())) if log: - return X, {"X_list": X_list, "exit_status": exit_status, "dX_list": dX_list} - return X + return J, w, log_dict + return J, w + + +def _clean_discrete_measure(X, a, tol=1e-10): + r""" + Simplifies a discrete measure by consolidating 
duplicate points and summing + their weights. Given a discrete measure with support X (n, d) and weights a + (n), returns a points Y (m, d) and weights b (m) such that Y is the set of + unique points in X and b is the sum of weights in a for each point in Y + + Parameters + ---------- + X : array-like + Array of shape (n, d) representing the support points of the discrete + measure. + a : array-like + Array of shape (n,) representing the weights associated with the support + points. + tol : float, optional + Tolerance for determining uniqueness of points in `X`. Points closer + than `tol` are considered identical. Default is 1e-10. + + Returns + ------- + Y : array-like + Array of shape (m, d) representing the unique support points of the + discrete measure. + b : array-like + Array of shape (m,) representing the summed weights for each unique + point in `Y`. + """ + nx = get_backend(X, a) + D = dist(X, X) + # each D[I[k], J[k]] < tol so X[I[k]] = X[J[k]] + idxI, idxJ = nx.where(D < tol) + idxI = nx.to_numpy(idxI) + idxJ = nx.to_numpy(idxJ) + # keep only the cases I[k] <= J[k] to avoid pairs (i, j) (j, i) with i != j + mask = idxI <= idxJ + idxI, idxJ = idxI[mask], idxJ[mask] + X_idx_to_Y_idx = {} # X[i] = Y[X_idx_to_Y_idx[i]] + # indices of unique points in X, at the end, Y := X[unique_X_idx] + unique_X_idx = [] + + b = [] + for i, j in zip(idxI, idxJ): + if i not in X_idx_to_Y_idx: # i is a new point + unique_X_idx.append(i) + X_idx_to_Y_idx[i] = len(unique_X_idx) - 1 + b.append(a[i]) + # j is a duplicate of i + if j not in X_idx_to_Y_idx: + X_idx_to_Y_idx[j] = X_idx_to_Y_idx[i] + b[X_idx_to_Y_idx[i]] += a[j] + + else: # i is not new, check if j is known + if j not in X_idx_to_Y_idx: + b[X_idx_to_Y_idx[i]] += a[j] + X_idx_to_Y_idx[j] = X_idx_to_Y_idx[i] + + # create the unique points array Y + Y = X[tuple(unique_X_idx), :] + b = nx.from_numpy(np.array(b), type_as=X) + return Y, b diff --git a/test/test_ot.py b/test/test_ot.py index 22612fa4a..55043c0a3 100644 --- 
a/test/test_ot.py +++ b/test/test_ot.py @@ -11,7 +11,7 @@ import ot from ot.datasets import make_1D_gauss as gauss -from ot.backend import torch, tf +from ot.backend import torch, tf, get_backend def test_emd_dimension_and_mass_mismatch(): @@ -437,7 +437,7 @@ def ground_bary(y): assert "X_list" in log assert "exit_status" in log - assert "dX_list" in log + assert "diff_list" in log np.testing.assert_allclose(X, X2, rtol=1e-5, atol=1e-7) @@ -475,6 +475,49 @@ def ground_bary(y): numItermax=1, ) + # test with unknown method + with pytest.raises(AssertionError): + ot.lp.free_support_barycenter_generic_costs( + measures_locations, + measures_weights, + X_init, + cost_list, + ground_bary, + numItermax=1, + method="unknown_method", + ) + + # test true fixed-point method + X4, a4, log4 = ot.lp.free_support_barycenter_generic_costs( + measures_locations, + measures_weights, + X_init, + cost_list, + ground_bary, + numItermax=3, + method="true_fixed_point", + log=True, + ) + + assert "a_list" in log4 + assert X4.shape[0] == a4.shape[0] == 1 + np.testing.assert_allclose(a4, ot.unif(1), rtol=1e-5, atol=1e-7) + np.testing.assert_allclose(X, X4, rtol=1e-5, atol=1e-7) + + # test with measure cleaning and no log + X5, a5 = ot.lp.free_support_barycenter_generic_costs( + measures_locations, + measures_weights, + X_init, + cost_list, + ground_bary, + numItermax=3, + method="true_fixed_point", + clean_measure=True, + ) + np.testing.assert_allclose(a5, ot.unif(1), rtol=1e-5, atol=1e-7) + np.testing.assert_allclose(X, X5, rtol=1e-5, atol=1e-7) + @pytest.mark.skipif(not torch, reason="No torch available") def test_free_support_barycenter_generic_costs_auto_ground_bary(): @@ -512,9 +555,9 @@ def ground_bary(y): X_init, cost_list, ground_bary=None, - ground_bary_lr=1e-2, + ground_bary_lr=2e-2, ground_bary_stopThr=1e-20, - ground_bary_numItermax=50, + ground_bary_numItermax=100, numItermax=10, log=True, ) @@ -529,7 +572,7 @@ def ground_bary(y): ground_bary=None, ground_bary_lr=1e-2, 
ground_bary_stopThr=1e-20, - ground_bary_numItermax=50, + ground_bary_numItermax=100, numItermax=10, ground_bary_solver="Adam", ) @@ -557,7 +600,12 @@ def ground_bary(y): return out X = ot.lp.free_support_barycenter_generic_costs( - measures_locations, measures_weights, X_init, cost_list, ground_bary + measures_locations, + measures_weights, + X_init, + cost_list, + ground_bary, + method="L2_barycentric_proj", ) measures_locations2 = nx.from_numpy(*measures_locations) @@ -565,11 +613,139 @@ def ground_bary(y): X_init2 = nx.from_numpy(X_init) X2 = ot.lp.free_support_barycenter_generic_costs( - measures_locations2, measures_weights2, X_init2, cost_list, ground_bary + measures_locations2, + measures_weights2, + X_init2, + cost_list, + ground_bary, + method="L2_barycentric_proj", ) np.testing.assert_allclose(X, nx.to_numpy(X2)) + X, a = ot.lp.free_support_barycenter_generic_costs( + measures_locations, + measures_weights, + X_init, + cost_list, + ground_bary, + method="true_fixed_point", + ) + + measures_locations2 = nx.from_numpy(*measures_locations) + measures_weights2 = nx.from_numpy(*measures_weights) + X_init2 = nx.from_numpy(X_init) + + X2, a2 = ot.lp.free_support_barycenter_generic_costs( + measures_locations2, + measures_weights2, + X_init2, + cost_list, + ground_bary, + method="true_fixed_point", + ) + + np.testing.assert_allclose(a, nx.to_numpy(a2)) + np.testing.assert_allclose(X, nx.to_numpy(X2)) + + +def verify_gluing_validity(gamma, J, w, pi_list): + """ + Test the validity of the North-West gluing. 
+ """ + nx = get_backend(gamma) + K = len(pi_list) + n = pi_list[0].shape[0] + nk_list = [pi.shape[1] for pi in pi_list] + + # Check first marginal + a = nx.sum(gamma, axis=tuple(range(1, K + 1))) + assert nx.allclose(a, nx.sum(pi_list[0], axis=1)) + + # Check other marginals + for k in range(K): + b_k = nx.sum(gamma, axis=tuple(i for i in range(K + 1) if i != k + 1)) + assert nx.allclose(b_k, nx.sum(pi_list[k], axis=0)) + + # Check bi-marginals + for k in range(K): + gamma_0k = nx.sum(gamma, axis=tuple(i for i in range(1, K + 1) if i != k + 1)) + assert nx.allclose(gamma_0k, pi_list[k]) + + # Check that N <= n + sum_k n_k - K + N = J.shape[0] + n_k_sum = sum(nk_list) + assert N <= n + n_k_sum - K, f"N={N}, n={n}, sum(n_k)={n_k_sum}, K={K}" + + # Check that w is on the simplex + w_sum = nx.sum(w) + assert nx.allclose(w_sum, 1), f"Sum of weights w is not 1: {w_sum}" + + # Check that gamma_1...K and (J, w) are consistent + rho = nx.zeros(nk_list, type_as=gamma) + for i in range(N): + jj = J[i] + rho[tuple(jj)] += w[i] + + gamma_1toK = nx.sum(gamma, axis=0) + assert nx.allclose(rho, gamma_1toK), "rho and gamma_1...K are not consistent" + + +def test_north_west_mm_gluing(): + rng = np.random.RandomState(0) + n = 7 + nk_list = [5, 6, 4] + a = rng.rand(n) + a = a / np.sum(a) + b_list = [rng.rand(nk) for nk in nk_list] + b_list = [b / np.sum(b) for b in b_list] + M_list = [rng.rand(n, nk) for nk in nk_list] + pi_list = [ot.emd(a, b, M) for b, M in zip(b_list, M_list)] + J, w, log_dict = ot.lp.NorthWestMMGluing(pi_list, log=True) + # Test the validity of the gluing + gamma = log_dict["gamma"] + verify_gluing_validity(gamma, J, w, pi_list) + + # test without log + J2, w2 = ot.lp.NorthWestMMGluing(pi_list, log=False) + np.testing.assert_allclose(J, J2) + np.testing.assert_allclose(w, w2) + + +def test_north_west_mm_gluing_backends(nx): + rng = np.random.RandomState(0) + n = 7 + nk_list = [5, 6, 4] + a = rng.rand(n) + a = a / np.sum(a) + b_list = [rng.rand(nk) for nk in 
nk_list] + b_list = [b / np.sum(b) for b in b_list] + M_list = [rng.rand(n, nk) for nk in nk_list] + pi_list = [ot.emd(a, b, M) for b, M in zip(b_list, M_list)] + + pi_list2 = [nx.from_numpy(pi) for pi in pi_list] + J, w, log_dict = ot.lp.NorthWestMMGluing(pi_list2, log=True) + gamma = log_dict["gamma"] + + # Test equality with numpy solution + J_np, w_np, log_dict_np = ot.lp.NorthWestMMGluing(pi_list, log=True) + gamma_np = log_dict_np["gamma"] + np.testing.assert_allclose(J, J_np) + np.testing.assert_allclose(w, w_np) + np.testing.assert_allclose(gamma, gamma_np) + + +def test_clean_discrete_measure(nx): + a = nx.ones(3) / 3.0 + X = nx.from_numpy(np.array([[1.0, 1.0], [1.0, 1.0], [2.0, 2.0]])) + X_clean, a_clean = ot.lp._barycenter_solvers._clean_discrete_measure(X, a) + a_true = nx.from_numpy(np.array([2 / 3, 1 / 3])) + X_true = nx.from_numpy(np.array([[1.0, 1.0], [2.0, 2.0]])) + assert a_clean.shape == a_true.shape + assert X_clean.shape == X_true.shape + np.testing.assert_allclose(a_clean, a_true) + np.testing.assert_allclose(X_clean, X_true) + @pytest.mark.skipif(not ot.lp._barycenter_solvers.cvxopt, reason="No cvxopt available") def test_lp_barycenter_cvxopt(): From 939b93d87d79c2ee18d9f892b31acee5a725c254 Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Fri, 6 Jun 2025 16:06:11 +0200 Subject: [PATCH 25/27] test and fixes --- ...ot_free_support_barycenter_generic_cost.py | 39 +++++++++-- examples/barycenters/plot_gmm_barycenter.py | 2 + ot/lp/_barycenter_solvers.py | 38 +++++------ test/test_ot.py | 65 +++++++++++++++++++ 4 files changed, 118 insertions(+), 26 deletions(-) diff --git a/examples/barycenters/plot_free_support_barycenter_generic_cost.py b/examples/barycenters/plot_free_support_barycenter_generic_cost.py index 2886432ae..536303a58 100644 --- a/examples/barycenters/plot_free_support_barycenter_generic_cost.py +++ b/examples/barycenters/plot_free_support_barycenter_generic_cost.py @@ -162,11 +162,11 @@ def B(y, its=150, lr=1, 
stop_threshold=stop_threshold): # %% # Compute the barycenter measure with the true fixed-point algorithm -fixed_point_its = 3 +fixed_point_its = 5 torch.manual_seed(42) X_init = torch.rand(n, d) t0 = time() -X_bar, a_bar = free_support_barycenter_generic_costs( +X_bar, a_bar, log_dict = free_support_barycenter_generic_costs( Y_list, b_list, X_init, @@ -175,16 +175,18 @@ def B(y, its=150, lr=1, stop_threshold=stop_threshold): numItermax=fixed_point_its, stopThr=stop_threshold, method="true_fixed_point", + log=True, + clean_measure=True, ) dt_true_fixed_point = time() - t0 # %% # Compute the barycenter measure with the barycentric (default) algorithm -fixed_point_its = 3 +fixed_point_its = 5 torch.manual_seed(42) X_init = torch.rand(n, d) t0 = time() -X_bar2 = free_support_barycenter_generic_costs( +X_bar2, log_dict2 = free_support_barycenter_generic_costs( Y_list, b_list, X_init, @@ -192,6 +194,7 @@ def B(y, its=150, lr=1, stop_threshold=stop_threshold): B, numItermax=fixed_point_its, stopThr=stop_threshold, + log=True, ) dt_barycentric = time() - t0 @@ -251,3 +254,31 @@ def V(X, a): plt.tight_layout() # %% +# Plot energy convergence +fig, axes = plt.subplots(1, 2, figsize=(8, 4)) + +V_list = [V(X, a).item() for (X, a) in zip(log_dict["X_list"], log_dict["a_list"])] +V_list2 = [V(X, torch.ones(n) / n).item() for X in log_dict2["X_list"]] + +# Plot for True Fixed-Point Algorithm +axes[0].plot(V_list, lw=5, alpha=0.6) +axes[0].scatter(range(len(V_list)), V_list, color="blue", alpha=0.8, s=100) +axes[0].set_title("True Fixed-Point Algorithm") +axes[0].set_xlabel("Iteration") +axes[0].set_ylabel("Barycenter Energy") +axes[0].set_yscale("log") +axes[0].xaxis.set_major_locator(plt.MaxNLocator(integer=True)) + +# Plot for Heuristic Barycentric Algorithm +axes[1].plot(V_list2, lw=5, alpha=0.6) +axes[1].scatter(range(len(V_list2)), V_list2, color="blue", alpha=0.8, s=100) +axes[1].set_title("Heuristic Barycentric Algorithm") +axes[1].set_xlabel("Iteration") 
+axes[1].set_ylabel("Barycenter Energy") +axes[1].set_yscale("log") +axes[1].xaxis.set_major_locator(plt.MaxNLocator(integer=True)) + +plt.tight_layout() +plt.show() + +# %% diff --git a/examples/barycenters/plot_gmm_barycenter.py b/examples/barycenters/plot_gmm_barycenter.py index f379a9914..c2e799a64 100644 --- a/examples/barycenters/plot_gmm_barycenter.py +++ b/examples/barycenters/plot_gmm_barycenter.py @@ -140,3 +140,5 @@ def draw_gmm(ms, Cs, ws, color=None, nstd=0.5, alpha=1, label=None, ax=None): draw_gmm(means_bar, covs_bar, ot.unif(n), color="C1", ax=ax) ax.axis(axis) ax.axis("off") + +# %% diff --git a/ot/lp/_barycenter_solvers.py b/ot/lp/_barycenter_solvers.py index 9cd04ca2d..b7d7e0b16 100644 --- a/ot/lp/_barycenter_solvers.py +++ b/ot/lp/_barycenter_solvers.py @@ -688,25 +688,25 @@ def ground_bary(y, x_init): exit_status = "Stationary Point" break - if log: - log_dict = { - "X_list": X_list, - "exit_status": exit_status, - "a_list": a_list, - "diff_list": diff_list, - } - if method == "true_fixed_point": - return X, a, log_dict - else: - return X, log_dict - + if log: + log_dict = { + "X_list": X_list, + "exit_status": exit_status, + "a_list": a_list, + "diff_list": diff_list, + } if method == "true_fixed_point": - return X, a + return X, a, log_dict else: - return X + return X, log_dict + + if method == "true_fixed_point": + return X, a + else: + return X -def to_int_array(x): +def _to_int_array(x): """ Converts an array to an integer type array. 
""" @@ -728,8 +728,6 @@ def to_int_array(x): return tf.cast(x, tf.int32) - raise TypeError(f"Unsupported backend {str(nx)}") - def NorthWestMMGluing(pi_list, log=False): r""" @@ -831,7 +829,7 @@ def jj_idx_in_range(jj_idx, jjs): log_dict["gamma"] = gamma J = list(gamma_weights.keys()) # list of multi-indices (j_1, ..., j_K) - J = to_int_array(nx.from_numpy(np.array(J), type_as=pi_list[0])) + J = _to_int_array(nx.from_numpy(np.array(J), type_as=pi_list[0])) w = nx.stack(list(gamma_weights.values())) if log: return J, w, log_dict @@ -885,10 +883,6 @@ def _clean_discrete_measure(X, a, tol=1e-10): unique_X_idx.append(i) X_idx_to_Y_idx[i] = len(unique_X_idx) - 1 b.append(a[i]) - # j is a duplicate of i - if j not in X_idx_to_Y_idx: - X_idx_to_Y_idx[j] = X_idx_to_Y_idx[i] - b[X_idx_to_Y_idx[i]] += a[j] else: # i is not new, check if j is known if j not in X_idx_to_Y_idx: diff --git a/test/test_ot.py b/test/test_ot.py index 55043c0a3..1735e9d07 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -518,6 +518,20 @@ def ground_bary(y): np.testing.assert_allclose(a5, ot.unif(1), rtol=1e-5, atol=1e-7) np.testing.assert_allclose(X, X5, rtol=1e-5, atol=1e-7) + # test with (too) lax convergence criterion + # for Stationary Point exit status + X6, log6 = ot.lp.free_support_barycenter_generic_costs( + [np.array([-1.0]).reshape((1, 1))], + measures_weights, + X_init, + cost_list, + ground_bary, + numItermax=3, + stopThr=1e20, + log=True, + ) + assert log6["exit_status"] == "Stationary Point" + @pytest.mark.skipif(not torch, reason="No torch available") def test_free_support_barycenter_generic_costs_auto_ground_bary(): @@ -579,6 +593,17 @@ def ground_bary(y): np.testing.assert_allclose(X2.numpy(), X3.numpy(), rtol=1e-3, atol=1e-3) + # test with (too) lax convergence criterion for ground barycenter + ot.lp.free_support_barycenter_generic_costs( + measures_locations, + measures_weights, + X_init, + cost_list, + ground_bary=None, + numItermax=1, + ground_bary_stopThr=100, + ) + def 
test_free_support_barycenter_generic_costs_backends(nx): measures_locations = [ @@ -711,6 +736,16 @@ def test_north_west_mm_gluing(): np.testing.assert_allclose(J, J2) np.testing.assert_allclose(w, w2) + # test setting with highly non-injective plans + n = 6 + a = ot.unif(n) + b_list = [a] * 3 + pi_list = [a[:, None] @ a[None, :]] * 3 + J, w, log_dict = ot.lp.NorthWestMMGluing(pi_list, log=True) + # Test the validity of the gluing + gamma = log_dict["gamma"] + verify_gluing_validity(gamma, J, w, pi_list) + def test_north_west_mm_gluing_backends(nx): rng = np.random.RandomState(0) @@ -746,6 +781,36 @@ def test_clean_discrete_measure(nx): np.testing.assert_allclose(a_clean, a_true) np.testing.assert_allclose(X_clean, X_true) + a = nx.ones(3) / 3.0 + X = nx.from_numpy(np.array([[1.0, 1.0], [2.0, 2.0], [1.0, 1.0]])) + X_clean, a_clean = ot.lp._barycenter_solvers._clean_discrete_measure(X, a) + a_true = nx.from_numpy(np.array([2 / 3, 1 / 3])) + X_true = nx.from_numpy(np.array([[1.0, 1.0], [2.0, 2.0]])) + assert a_clean.shape == a_true.shape + assert X_clean.shape == X_true.shape + np.testing.assert_allclose(a_clean, a_true) + np.testing.assert_allclose(X_clean, X_true) + + n = 5 + a = nx.ones(n) / n + v = nx.from_numpy(np.array([1.0, 2.0, 3.0])) + X = nx.stack([v] * n, axis=0) + X_clean, a_clean = ot.lp._barycenter_solvers._clean_discrete_measure(X, a) + a_true = np.array([1.0]) + X_true = np.array([1.0, 2.0, 3.0]).reshape(1, 3) + assert a_clean.shape == a_true.shape + assert X_clean.shape == X_true.shape + np.testing.assert_allclose(a_clean, a_true) + np.testing.assert_allclose(X_clean, X_true) + + +def test_to_int_array(nx): + a_np = np.array([1.0, 2.0, 3.0]) + a = nx.from_numpy(a_np) + a_int = ot.lp._barycenter_solvers._to_int_array(a) + a_np_int = a_np.astype(int) + np.testing.assert_allclose(nx.to_numpy(a_int), a_np_int) + @pytest.mark.skipif(not ot.lp._barycenter_solvers.cvxopt, reason="No cvxopt available") def test_lp_barycenter_cvxopt(): From 
6f47b29c683ac01fd9d7e8475c333a1dcb446f91 Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Fri, 6 Jun 2025 16:39:25 +0200 Subject: [PATCH 26/27] no jax or tf support for free_support_generic_costs due to array assignment --- test/test_ot.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/test_ot.py b/test/test_ot.py index 1735e9d07..d523d3248 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -605,6 +605,8 @@ def ground_bary(y): ) +@pytest.skip_backend("tf") # skips because of array assignment +@pytest.skip_backend("jax") def test_free_support_barycenter_generic_costs_backends(nx): measures_locations = [ np.array([-1.0]).reshape((1, 1)), @@ -747,6 +749,8 @@ def test_north_west_mm_gluing(): verify_gluing_validity(gamma, J, w, pi_list) +@pytest.skip_backend("tf") # skips because of array assignment +@pytest.skip_backend("jax") def test_north_west_mm_gluing_backends(nx): rng = np.random.RandomState(0) n = 7 From 9a344e18bc9d875cccdc6892908c60885dc68c9d Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Fri, 6 Jun 2025 17:11:56 +0200 Subject: [PATCH 27/27] updated gmm bar colours --- examples/barycenters/plot_gmm_barycenter.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/barycenters/plot_gmm_barycenter.py b/examples/barycenters/plot_gmm_barycenter.py index c2e799a64..6dd0ad8be 100644 --- a/examples/barycenters/plot_gmm_barycenter.py +++ b/examples/barycenters/plot_gmm_barycenter.py @@ -132,12 +132,14 @@ def draw_gmm(ms, Cs, ws, color=None, nstd=0.5, alpha=1, label=None, ax=None): # %% # Plot the results +c_list = ["#7ED321", "#4A90E2", "#9013FE", "#F5A623"] +c_bar = "#D0021B" fig, ax = plt.subplots(figsize=(6, 6)) axis = [-4, 4, -2, 6] ax.set_title("Fixed Point Barycenter (3 Iterations)", fontsize=16) for k in range(K): - draw_gmm(means_list[k], covs_list[k], w_list[k], color="C0", ax=ax) -draw_gmm(means_bar, covs_bar, ot.unif(n), color="C1", ax=ax) + draw_gmm(means_list[k], covs_list[k], w_list[k], color=c_list[k], ax=ax) 
+draw_gmm(means_bar, covs_bar, ot.unif(n), color=c_bar, ax=ax) ax.axis(axis) ax.axis("off")