diff --git a/RELEASES.md b/RELEASES.md index 321d9a78b..915b5c34e 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -16,6 +16,7 @@ + Add exact line-search for `gromov_wasserstein` and `fused_gromov_wasserstein` with KL loss (PR #556) + Add KL loss to all semi-relaxed (Fused) Gromov-Wasserstein solvers (PR #559) + Further upgraded unbalanced OT solvers for more flexibility and future use (PR #551) ++ New API function `ot.solve_sample` for solving OT problems from empirical samples (PR #563) #### Closed issues - Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504) diff --git a/ot/__init__.py b/ot/__init__.py index f16b6fcfc..9c33e9feb 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -36,6 +36,7 @@ from . import solvers from . import gaussian + # OT functions from .lp import (emd, emd2, emd_1d, emd2_1d, wasserstein_1d, binary_search_circle, wasserstein_circle, @@ -50,7 +51,7 @@ gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2) from .weak import weak_optimal_transport from .factored import factored_optimal_transport -from .solvers import solve, solve_gromov +from .solvers import solve, solve_gromov, solve_sample # utils functions from .utils import dist, unif, tic, toc, toq @@ -65,7 +66,7 @@ 'sinkhorn_unbalanced2', 'sliced_wasserstein_distance', 'sliced_wasserstein_sphere', 'gromov_wasserstein', 'gromov_wasserstein2', 'gromov_barycenters', 'fused_gromov_wasserstein', 'fused_gromov_wasserstein2', 'max_sliced_wasserstein_distance', 'weak_optimal_transport', - 'factored_optimal_transport', 'solve', 'solve_gromov', + 'factored_optimal_transport', 'solve', 'solve_gromov','solve_sample', 'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath', 'solvers', 'binary_search_circle', 'wasserstein_circle', 'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif'] diff --git a/ot/bregman/_empirical.py b/ot/bregman/_empirical.py index 55c5ada3f..b84c3b389 100644 --- a/ot/bregman/_empirical.py +++ 
b/ot/bregman/_empirical.py @@ -11,12 +11,56 @@ import warnings -from ..utils import dist, list_to_array, unif +from ..utils import dist, list_to_array, unif, LazyTensor from ..backend import get_backend from ._sinkhorn import sinkhorn, sinkhorn2 +def get_sinkhorn_lazytensor(X_a, X_b, f, g, metric='sqeuclidean', reg=1e-1, nx=None): + r""" Get a LazyTensor of Sinkhorn solution from the dual potentials + + The returned LazyTensor is + :math:`\mathbf{T} = exp( \mathbf{f} \mathbf{1}_b^\top + \mathbf{1}_a \mathbf{g}^\top - \mathbf{C}/reg)`, where :math:`\mathbf{C}` is the pairwise metric matrix between samples :math:`\mathbf{X}_a` and :math:`\mathbf{X}_b`. + + Parameters + ---------- + X_a : array-like, shape (n_samples_a, dim) + samples in the source domain + X_b : array-like, shape (n_samples_b, dim) + samples in the target domain + f : array-like, shape (n_samples_a,) + First dual potentials (log space) + g : array-like, shape (n_samples_b,) + Second dual potentials (log space) + metric : str, default='sqeuclidean' + Metric used for the cost matrix computation + reg : float, default=1e-1 + Regularization term >0 + nx : Backend(), default=None + Numerical backend used + + + Returns + ------- + T : LazyTensor + Sinkhorn solution tensor + """ + + if nx is None: + nx = get_backend(X_a, X_b, f, g) + + shape = (X_a.shape[0], X_b.shape[0]) + + def func(i, j, X_a, X_b, f, g, metric, reg): + C = dist(X_a[i], X_b[j], metric=metric) + return nx.exp(f[i, None] + g[None, j] - C / reg) + + T = LazyTensor(shape, func, X_a=X_a, X_b=X_b, f=f, g=g, metric=metric, reg=reg) + + return T + + def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, isLazy=False, batchSize=100, verbose=False, log=False, warn=True, warmstart=None, **kwargs): @@ -198,6 +242,8 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', if log: dict_log["u"] = f dict_log["v"] = g + dict_log["niter"] = i_ot + dict_log["lazy_plan"] = 
get_sinkhorn_lazytensor(X_s, X_t, f, g, metric, reg) return (f, g, dict_log) else: return (f, g) diff --git a/ot/gaussian.py b/ot/gaussian.py index 708f9eb16..0ddb92013 100644 --- a/ot/gaussian.py +++ b/ot/gaussian.py @@ -249,7 +249,7 @@ def bures_wasserstein_distance(ms, mt, Cs, Ct, log=False): Cs12 = nx.sqrtm(Cs) B = nx.trace(Cs + Ct - 2 * nx.sqrtm(dots(Cs12, Ct, Cs12))) - W = nx.sqrt(nx.norm(ms - mt)**2 + B) + W = nx.sqrt(nx.maximum(nx.norm(ms - mt)**2 + B, 0)) if log: log = {} diff --git a/ot/solvers.py b/ot/solvers.py index 0313cf588..aed7e8ffe 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -7,11 +7,11 @@ # # License: MIT License -from .utils import OTResult -from .lp import emd2 +from .utils import OTResult, dist +from .lp import emd2, wasserstein_1d from .backend import get_backend from .unbalanced import mm_unbalanced, sinkhorn_knopp_unbalanced, lbfgsb_unbalanced -from .bregman import sinkhorn_log +from .bregman import sinkhorn_log, empirical_sinkhorn2 from .partial import partial_wasserstein_lagrange from .smooth import smooth_ot_dual from .gromov import (gromov_wasserstein2, fused_gromov_wasserstein2, @@ -20,12 +20,12 @@ entropic_semirelaxed_fused_gromov_wasserstein2, entropic_semirelaxed_gromov_wasserstein2) from .partial import partial_gromov_wasserstein2, entropic_partial_gromov_wasserstein2 - -#, entropic_gromov_wasserstein2, entropic_fused_gromov_wasserstein2 +from .gaussian import empirical_bures_wasserstein_distance +from .factored import factored_optimal_transport def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None, - unbalanced_type='KL', n_threads=1, max_iter=None, plan_init=None, + unbalanced_type='KL', method=None, n_threads=1, max_iter=None, plan_init=None, potentials_init=None, tol=None, verbose=False): r"""Solve the discrete optimal transport problem and return :any:`OTResult` object @@ -59,7 +59,11 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None, Unbalanced penalization weight 
:math:`\lambda_u`, by default None (balanced OT) unbalanced_type : str, optional - Type of unbalanced penalization function :math:`U` either "KL", "L2", "TV", by default "KL" + Type of unbalanced penalization function :math:`U` either "KL", "L2", + "TV", by default "KL" + method : str, optional + Method for solving the problem when multiple algorithms are available, + default None for automatic selection. n_threads : int, optional Number of OMP threads for exact OT solver, by default 1 max_iter : int, optional @@ -90,7 +94,7 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None, The following methods are available for solving the OT problems: - - **Classical exact OT problem** (default parameters): + - **Classical exact OT problem [1]** (default parameters) : .. math:: \min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F @@ -107,7 +111,7 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None, res = ot.solve(M, a, b) - - **Entropic regularized OT** (when ``reg!=None``): + - **Entropic regularized OT [2]** (when ``reg!=None``): .. math:: \min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F + \lambda R(\mathbf{T}) @@ -127,7 +131,7 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None, # or for original Sinkhorn paper formulation [2] res = ot.solve(M, a, b, reg=1.0, reg_type='entropy') - - **Quadratic regularized OT** (when ``reg!=None`` and ``reg_type="L2"``): + - **Quadratic regularized OT [17]** (when ``reg!=None`` and ``reg_type="L2"``): .. math:: \min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F + \lambda R(\mathbf{T}) @@ -144,7 +148,7 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None, res = ot.solve(M,a,b,reg=1.0,reg_type='L2') - - **Unbalanced OT** (when ``unbalanced!=None``): + - **Unbalanced OT [41]** (when ``unbalanced!=None``): .. 
math:: \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b}) @@ -154,14 +158,14 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None, .. code-block:: python # default is ``"KL"`` - res = ot.solve(M,a,b,reg=1.0,unbalanced=1.0) + res = ot.solve(M,a,b,unbalanced=1.0) # quadratic unbalanced OT - res = ot.solve(M,a,b,reg=1.0,unbalanced=1.0,unbalanced_type='L2') + res = ot.solve(M,a,b,unbalanced=1.0,unbalanced_type='L2') # TV = partial OT - res = ot.solve(M,a,b,reg=1.0,unbalanced=1.0,unbalanced_type='TV') + res = ot.solve(M,a,b,unbalanced=1.0,unbalanced_type='TV') - - **Regularized unbalanced regularized OT** (when ``unbalanced!=None`` and ``reg!=None``): + - **Regularized unbalanced regularized OT [34]** (when ``unbalanced!=None`` and ``reg!=None``): .. math:: \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_r R(\mathbf{T}) + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b}) @@ -182,6 +186,11 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None, References ---------- + .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. + (2011, December). Displacement interpolation using Lagrangian mass + transport. In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. + 158). ACM. + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 @@ -199,6 +208,10 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None, and MMD using Sinkhorn divergences. In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR. + .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + Unbalanced optimal transport through non-negative penalized + linear regression. NeurIPS. 
+ """ # detect backend @@ -413,9 +426,8 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, n_threads : int, optional Number of OMP threads for exact OT solver, by default 1 method : str, optional - Method for solving the problem, for entropic problems "PGD" is projected - gradient descent and "PPA" for proximal point, default None for - automatic selection ("PGD"). + Method for solving the problem when multiple algorithms are available, + default None for automatic selection. max_iter : int, optional Maximum number of iterations, by default None (default values in each solvers) @@ -601,6 +613,7 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, value_quad = None plan = None status = None + log = None loss_dict = {'l2': 'square_loss', 'kl': 'kl_loss'} @@ -845,6 +858,396 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, raise (NotImplementedError('Not implemented reg_type="{}" and unbalanced_type="{}"'.format(reg_type, unbalanced_type))) res = OTResult(potentials=potentials, value=value, - value_linear=value_linear, value_quad=value_quad, plan=plan, status=status, backend=nx) + value_linear=value_linear, value_quad=value_quad, plan=plan, status=status, backend=nx, log=log) return res + + +def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_type="KL", + unbalanced=None, + unbalanced_type='KL', lazy=False, batch_size=None, method=None, n_threads=1, max_iter=None, plan_init=None, rank=100, + potentials_init=None, X_init=None, tol=None, verbose=False): + r"""Solve the discrete optimal transport problem using the samples in the source and target domains. + + The function solves the following general optimal transport problem + + .. 
math:: + \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_r R(\mathbf{T}) + + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b}) + + where the cost matrix :math:`\mathbf{M}` is computed from the samples in the + source and target domains such that :math:`M_{i,j} = d(x_i,y_j)` where + :math:`d` is a metric (by default the squared Euclidean distance). + + The regularization is selected with `reg` (:math:`\lambda_r`) and `reg_type`. By + default ``reg=None`` and there is no regularization. The unbalanced marginal + penalization can be selected with `unbalanced` (:math:`\lambda_u`) and + `unbalanced_type`. By default ``unbalanced=None`` and the function + solves the exact optimal transport problem (respecting the marginals). + + Parameters + ---------- + X_a : array-like, shape (n_samples_a, dim) + samples in the source domain + X_b : array-like, shape (n_samples_b, dim) + samples in the target domain + a : array-like, shape (dim_a,), optional + Samples weights in the source domain (default is uniform) + b : array-like, shape (dim_b,), optional + Samples weights in the target domain (default is uniform) + reg : float, optional + Regularization weight :math:`\lambda_r`, by default None (no reg., exact + OT) + reg_type : str, optional + Type of regularization :math:`R` either "KL", "L2", "entropy", by default "KL" + unbalanced : float, optional + Unbalanced penalization weight :math:`\lambda_u`, by default None + (balanced OT) + unbalanced_type : str, optional + Type of unbalanced penalization function :math:`U` either "KL", "L2", "TV", by default "KL" + lazy : bool, optional + Return :any:`OTResult` object with a lazy plan to reduce memory cost when True, by + default False + batch_size : int, optional + Batch size for lazy solver, by default None (default values in each + solvers) + method : str, optional + Method for solving the problem, this can be used to select the solver + for unbalanced problems (see
:any:`ot.solve`), or to select a specific + large scale solver. + n_threads : int, optional + Number of OMP threads for exact OT solver, by default 1 + max_iter : int, optional + Maximum number of iterations, by default None (default values in each solvers) + plan_init : array_like, shape (dim_a, dim_b), optional + Initialization of the OT plan for iterative methods, by default None + potentials_init : (array_like(dim_a,),array_like(dim_b,)), optional + Initialization of the OT dual potentials for iterative methods, by default None + tol : float, optional + Tolerance for solution precision, by default None (default values in each solvers) + verbose : bool, optional + Print information in the solver, by default False + + Returns + ------- + + res : OTResult() + Result of the optimization problem. The information can be obtained as follows: + + - res.plan : OT plan :math:`\mathbf{T}` + - res.potentials : OT dual potentials + - res.value : Optimal value of the optimization problem + - res.value_linear : Linear OT loss with the optimal OT plan + + See :any:`OTResult` for more information. + + Notes + ----- + + The following methods are available for solving the OT problems: + + - **Classical exact OT problem [1]** (default parameters) : + + .. math:: + \min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F + + s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a} + + \mathbf{T}^T \mathbf{1} = \mathbf{b} + + \mathbf{T} \geq 0, M_{i,j} = d(x_i,y_j) + + + + can be solved with the following code: + + .. code-block:: python + + res = ot.solve_sample(xa, xb, a, b) + + # for uniform weights + res = ot.solve_sample(xa, xb) + + - **Entropic regularized OT [2]** (when ``reg!=None``): + + .. math:: + \min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F + \lambda R(\mathbf{T}) + + s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a} + + \mathbf{T}^T \mathbf{1} = \mathbf{b} + + \mathbf{T} \geq 0, M_{i,j} = d(x_i,y_j) + + can be solved with the following code: + + ..
code-block:: python + + # default is ``"KL"`` regularization (``reg_type="KL"``) + res = ot.solve_sample(xa, xb, a, b, reg=1.0) + # or for original Sinkhorn paper formulation [2] + res = ot.solve_sample(xa, xb, a, b, reg=1.0, reg_type='entropy') + + # lazy solver of memory complexity O(n) + res = ot.solve_sample(xa, xb, a, b, reg=1.0, lazy=True, batch_size=100) + # lazy OT plan + lazy_plan = res.lazy_plan + + - **Quadratic regularized OT [17]** (when ``reg!=None`` and ``reg_type="L2"``): + + .. math:: + \min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F + \lambda R(\mathbf{T}) + + s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a} + + \mathbf{T}^T \mathbf{1} = \mathbf{b} + + \mathbf{T} \geq 0, M_{i,j} = d(x_i,y_j) + + can be solved with the following code: + + .. code-block:: python + + res = ot.solve_sample(xa, xb, a, b, reg=1.0, reg_type='L2') + + - **Unbalanced OT [41]** (when ``unbalanced!=None``): + + .. math:: + \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b}) + + with M_{i,j} = d(x_i,y_j) + + can be solved with the following code: + + .. code-block:: python + + # default is ``"KL"`` + res = ot.solve_sample(xa, xb, a, b, unbalanced=1.0) + # quadratic unbalanced OT + res = ot.solve_sample(xa, xb, a, b, unbalanced=1.0,unbalanced_type='L2') + # TV = partial OT + res = ot.solve_sample(xa, xb, a, b, unbalanced=1.0,unbalanced_type='TV') + + + - **Regularized unbalanced regularized OT [34]** (when ``unbalanced!=None`` and ``reg!=None``): + + .. math:: + \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_r R(\mathbf{T}) + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b}) + + with M_{i,j} = d(x_i,y_j) + + can be solved with the following code: + + .. 
code-block:: python + + # default is ``"KL"`` for both + res = ot.solve_sample(xa, xb, a, b, reg=1.0, unbalanced=1.0) + # quadratic unbalanced OT with KL regularization + res = ot.solve_sample(xa, xb, a, b, reg=1.0, unbalanced=1.0,unbalanced_type='L2') + # both quadratic + res = ot.solve_sample(xa, xb, a, b, reg=1.0, reg_type='L2', + unbalanced=1.0, unbalanced_type='L2') + + + - **Factored OT [40]** (when ``method='factored'``): + + This method solves the following OT problem [40]_ + + .. math:: + \mathop{\arg \min}_\mu \quad W_2^2(\mu_a,\mu)+ W_2^2(\mu,\mu_b) + + where :math:`\mu` is a uniform weighted empirical distribution and :math:`\mu_a` and :math:`\mu_b` are the empirical measures associated + to the samples in the source and target domains, and :math:`W_2` is the + Wasserstein distance. This problem is solved using exact OT solvers for + `reg=None` and the Sinkhorn solver for `reg!=None`. The solution provides + two transport plans that can be used to recover a low rank OT plan between + the two distributions. + + .. code-block:: python + + res = ot.solve_sample(xa, xb, method='factored', rank=10) + + # recover the lazy low rank plan + factored_solution_lazy = res.lazy_plan + + # recover the full low rank plan + factored_solution = factored_solution_lazy[:] + + - **Gaussian Bures-Wasserstein [2]** (when ``method='gaussian'``): + + This method computes the Gaussian Bures-Wasserstein distance between two + Gaussian distributions estimated from the empirical distributions + + .. math:: + \mathcal{W}(\mu_s, \mu_t)_2^2= \left\lVert \mathbf{m}_s - \mathbf{m}_t \right\rVert^2 + \mathcal{B}(\Sigma_s, \Sigma_t)^{2} + + where : + + .. math:: + \mathcal{B}(\Sigma_s, \Sigma_t)^{2} = \text{Tr}\left(\Sigma_s + \Sigma_t - 2 \sqrt{\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2}} \right) + + The covariances and means are estimated from the data. + + ..
code-block:: python + + res = ot.solve_sample(xa, xb, method='gaussian') + + # recover the squared Gaussian Bures-Wasserstein distance + BW_dist = res.value + + - **Wasserstein 1d [1]** (when ``method='1D'``): + + This method computes the Wasserstein distance between two 1d distributions + estimated from the empirical distributions. For multivariate data the + distances are computed independently for each dimension. + + .. code-block:: python + + res = ot.solve_sample(xa, xb, method='1D') + + # recover the squared Wasserstein distances + W_dists = res.value + + + .. _references-solve-sample: + References + ---------- + + .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. + (2011, December). Displacement interpolation using Lagrangian mass + transport. In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. + 158). ACM. + + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation + of Optimal Transport, Advances in Neural Information Processing + Systems (NIPS) 26, 2013 + + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). + Scaling algorithms for unbalanced transport problems. + arXiv preprint arXiv:1607.05816. + + .. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse + Optimal Transport. Proceedings of the Twenty-First International + Conference on Artificial Intelligence and Statistics (AISTATS). + + .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, + A., & Peyré, G. (2019, April). Interpolating between optimal transport + and MMD using Sinkhorn divergences. In The 22nd International Conference + on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR. + + .. [40] Forrow, A., Hütter, J. C., Nitzan, M., Rigollet, P., Schiebinger, + G., & Weed, J. (2019, April). Statistical optimal transport via factored + couplings. In The 22nd International Conference on Artificial + Intelligence and Statistics (pp. 2454-2465). PMLR. + + .. 
[41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + Unbalanced optimal transport through non-negative penalized + linear regression. NeurIPS. + + + """ + + if method is not None and method.lower() in ['1d', 'gaussian', 'lowrank', 'factored']: + lazy0 = lazy + lazy = True + + if not lazy: # default non lazy solver calls ot.solve + + # compute cost matrix M and use solve function + M = dist(X_a, X_b, metric) + + res = solve(M, a, b, reg, reg_type, unbalanced, unbalanced_type, method, n_threads, max_iter, plan_init, potentials_init, tol, verbose) + + return res + + else: + + # Detect backend + nx = get_backend(X_a, X_b, a, b) + + # default values for solutions + potentials = None + value = None + value_linear = None + plan = None + lazy_plan = None + status = None + log = None + + method = method.lower() if method is not None else '' + + if method == '1d': # Wasserstein 1d (parallel on all dimensions) + if metric == 'sqeuclidean': + p = 2 + elif metric in ['euclidean', 'cityblock']: + p = 1 + else: + raise (NotImplementedError('Not implemented metric="{}"'.format(metric))) + + value = wasserstein_1d(X_a, X_b, a, b, p=p) + value_linear = value + + elif method == 'gaussian': # Gaussian Bures-Wasserstein + + if not metric.lower() in ['sqeuclidean']: + raise (NotImplementedError('Not implemented metric="{}"'.format(metric))) + + if reg is None: + reg = 1e-6 + + value, log = empirical_bures_wasserstein_distance(X_a, X_b, reg=reg, log=True) + value = value**2 # return the value (squared bures distance) + value_linear = value # return the value + + elif method == 'factored': # Factored OT + + if not metric.lower() in ['sqeuclidean']: + raise (NotImplementedError('Not implemented metric="{}"'.format(metric))) + + if max_iter is None: + max_iter = 100 + if tol is None: + tol = 1e-7 + if reg is None: + reg = 0 + + Q, R, X, log = factored_optimal_transport(X_a, X_b, reg=reg, r=rank, log=True, stopThr=tol, numItermax=max_iter, verbose=verbose) + log['X'] = 
X + + value_linear = log['costa'] + log['costb'] + value = value_linear # TODO add reg term + lazy_plan = log['lazy_plan'] + if not lazy0: # store plan if not lazy + plan = lazy_plan[:] + + elif reg is None or reg == 0: # exact OT + + if unbalanced is None: # balanced EMD solver not available for lazy + raise (NotImplementedError('Exact OT solver with lazy=True not implemented')) + + else: + raise (NotImplementedError('Non regularized solver with unbalanced_type="{}" not implemented'.format(unbalanced_type))) + + else: + if unbalanced is None: + + if max_iter is None: + max_iter = 1000 + if tol is None: + tol = 1e-9 + if batch_size is None: + batch_size = 100 + + value_linear, log = empirical_sinkhorn2(X_a, X_b, reg, a, b, metric=metric, numIterMax=max_iter, stopThr=tol, + isLazy=True, batchSize=batch_size, verbose=verbose, log=True) + # compute potentials + potentials = (log["u"], log["v"]) + lazy_plan = log['lazy_plan'] + + else: + raise (NotImplementedError('Not implemented unbalanced_type="{}" with regularization'.format(unbalanced_type))) + + res = OTResult(potentials=potentials, value=value, lazy_plan=lazy_plan, + value_linear=value_linear, plan=plan, status=status, backend=nx, log=log) + return res diff --git a/ot/utils.py b/ot/utils.py index 0936648ca..f64c2fea6 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -1115,6 +1115,14 @@ def status(self): else: raise NotImplementedError() + @property + def log(self): + """Dictionary containing potential information about the solver.""" + if self._log is not None: + return self._log + else: + raise NotImplementedError() + # Barycentric mappings ------------------------- # Return the displacement vectors as an array # that has the same shape as "xa"/"xb" (for samples) diff --git a/test/test_bregman.py b/test/test_bregman.py index 8627df3c6..67257f899 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -1078,10 +1078,10 @@ def test_lazy_empirical_sinkhorn(nx): sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, 
M_nx, 1)) f, g, log_es = ot.bregman.empirical_sinkhorn( - X_sb, X_tb, 0.1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True) + X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=5, log=True) f, g = nx.to_numpy(f), nx.to_numpy(g) - G_log = np.exp(f[:, None] + g[None, :] - M / 0.1) - sinkhorn_log, log_s = ot.sinkhorn(ab, bb, M_nx, 0.1, log=True) + G_log = np.exp(f[:, None] + g[None, :] - M / 1) + sinkhorn_log, log_s = ot.sinkhorn(ab, bb, M_nx, 1, log=True) sinkhorn_log = nx.to_numpy(sinkhorn_log) f, g = ot.bregman.empirical_sinkhorn( @@ -1091,10 +1091,14 @@ def test_lazy_empirical_sinkhorn(nx): sinkhorn_m = nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1)) loss_emp_sinkhorn, log = ot.bregman.empirical_sinkhorn2( - X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True) + X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=5, log=True) + G_lazy = nx.to_numpy(log['lazy_plan'][:]) loss_emp_sinkhorn = nx.to_numpy(loss_emp_sinkhorn) loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, M_nx, 1)) + loss_emp_sinkhorn = ot.bregman.empirical_sinkhorn2( + X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=False) + # check constraints np.testing.assert_allclose( sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05) # metric sqeuclidian @@ -1109,6 +1113,7 @@ def test_lazy_empirical_sinkhorn(nx): np.testing.assert_allclose( sinkhorn_m.sum(0), G_m.sum(0), atol=1e-05) # metric euclidian np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05) + np.testing.assert_allclose(G_log, G_lazy, atol=1e-05) def test_empirical_sinkhorn_divergence(nx): diff --git a/test/test_solvers.py b/test/test_solvers.py index f0f5b638f..c6e1a3770 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -23,6 +23,27 @@ lst_unbalanced_gromov = [None, 0.9] lst_alpha = [0, 0.4, 0.9, 1] +lst_method_params_solve_sample = [ + {'method': '1d'}, + {'method': '1d', 'metric': 'euclidean'}, + {'method': 'gaussian'}, + {'method': 'gaussian', 'reg': 1}, 
+ {'method': 'factored', 'rank': 10}, +] + +lst_parameters_solve_sample_NotImplemented = [ + {'method': '1d', 'metric': 'any other one'}, # fail 1d on weird metrics + {'method': 'gaussian', 'metric': 'euclidean'}, # fail gaussian on metric not euclidean + {'method': 'factored', 'metric': 'euclidean'}, # fail factored on metric not euclidean + {'lazy': True}, # fail lazy for non regularized + {'lazy': True, 'unbalanced': 1}, # fail lazy for non regularized unbalanced + {'lazy': True, 'reg': 1, 'unbalanced': 1}, # fail lazy for unbalanced and regularized +] + +# set readable ids for each param +lst_method_params_solve_sample = [pytest.param(param, id=str(param)) for param in lst_method_params_solve_sample] +lst_parameters_solve_sample_NotImplemented = [pytest.param(param, id=str(param)) for param in lst_parameters_solve_sample_NotImplemented] + def assert_allclose_sol(sol1, sol2): @@ -255,3 +276,118 @@ def test_solve_gromov_not_implemented(nx): ot.solve_gromov(Ca, Cb, reg=1, unbalanced_type='partial', unbalanced=1.5) with pytest.raises(NotImplementedError): ot.solve_gromov(Ca, Cb, reg=1, unbalanced_type='partial', unbalanced=0.5, symmetric=False) + + +def test_solve_sample(nx): + # test solve_sample when is_Lazy = False + n = 20 + X_s = np.reshape(1.0 * np.arange(n), (n, 1)) + X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) + + a = ot.utils.unif(X_s.shape[0]) + b = ot.utils.unif(X_t.shape[0]) + + M = ot.dist(X_s, X_t) + + # solve with ot.solve + sol00 = ot.solve(M, a, b) + + # solve unif weights + sol0 = ot.solve_sample(X_s, X_t) + + # solve signe weights + sol = ot.solve_sample(X_s, X_t, a, b) + + # check some attributes + sol.potentials + sol.sparse_plan + sol.marginals + sol.status + + assert_allclose_sol(sol0, sol) + assert_allclose_sol(sol0, sol00) + + # solve in backend + X_sb, X_tb, ab, bb = nx.from_numpy(X_s, X_t, a, b) + solb = ot.solve_sample(X_sb, X_tb, ab, bb) + + assert_allclose_sol(sol, solb) + + # test not implemented unbalanced and check raise + with 
pytest.raises(NotImplementedError): + sol0 = ot.solve_sample(X_s, X_t, unbalanced=1, unbalanced_type='cryptic divergence') + + # test not implemented reg_type and check raise + with pytest.raises(NotImplementedError): + sol0 = ot.solve_sample(X_s, X_t, reg=1, reg_type='cryptic divergence') + + +def test_solve_sample_lazy(nx): + # test solve_sample when is_Lazy = False + n = 20 + X_s = np.reshape(1.0 * np.arange(n), (n, 1)) + X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) + + a = ot.utils.unif(X_s.shape[0]) + b = ot.utils.unif(X_t.shape[0]) + + X_s, X_t, a, b = nx.from_numpy(X_s, X_t, a, b) + + M = ot.dist(X_s, X_t) + + # solve with ot.solve + sol00 = ot.solve(M, a, b, reg=1) + + sol0 = ot.solve_sample(X_s, X_t, a, b, reg=1) + + # solve signe weights + sol = ot.solve_sample(X_s, X_t, a, b, reg=1, lazy=True) + + assert_allclose_sol(sol0, sol00) + + np.testing.assert_allclose(sol0.plan, sol.lazy_plan[:], rtol=1e-5, atol=1e-5) + + +@pytest.mark.parametrize("method_params", lst_method_params_solve_sample) +def test_solve_sample_methods(nx, method_params): + + n_samples_s = 20 + n_samples_t = 7 + n_features = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n_samples_s, n_features) + y = rng.randn(n_samples_t, n_features) + a = ot.utils.unif(n_samples_s) + b = ot.utils.unif(n_samples_t) + + xb, yb, ab, bb = nx.from_numpy(x, y, a, b) + + sol = ot.solve_sample(x, y, **method_params) + solb = ot.solve_sample(xb, yb, ab, bb, **method_params) + + # check some attributes (no need ) + assert_allclose_sol(sol, solb) + + sol2 = ot.solve_sample(x, x, **method_params) + if method_params['method'] != 'factored': + np.testing.assert_allclose(sol2.value, 0) + + +@pytest.mark.parametrize("method_params", lst_parameters_solve_sample_NotImplemented) +def test_solve_sample_NotImplemented(nx, method_params): + + n_samples_s = 20 + n_samples_t = 7 + n_features = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n_samples_s, n_features) + y = rng.randn(n_samples_t, n_features) + a = 
ot.utils.unif(n_samples_s) + b = ot.utils.unif(n_samples_t) + + xb, yb, ab, bb = nx.from_numpy(x, y, a, b) + + with pytest.raises(NotImplementedError): + ot.solve_sample(xb, yb, ab, bb, **method_params) diff --git a/test/test_utils.py b/test/test_utils.py index 3a9d590ab..258a1c742 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -318,6 +318,9 @@ def test_cost_normalization(nx): M1 = nx.to_numpy(M) np.testing.assert_allclose(M1.max(), np.log(1 + np.log(1 + C)).max()) + with pytest.raises(ValueError): + ot.utils.cost_normalization(C1, 'error') + def test_check_params(): @@ -328,6 +331,16 @@ def test_check_params(): assert res0 is False +def test_check_random_state_error(): + with pytest.raises(ValueError): + ot.utils.check_random_state('error') + + +def test_get_parameter_pair_error(): + with pytest.raises(ValueError): + ot.utils.get_parameter_pair((1, 2, 3)) # not pair ;) + + def test_deprecated_func(): @ot.utils.deprecated('deprecated text for fun') @@ -408,7 +421,8 @@ def test_OTResult(): 'status', 'value', 'value_linear', - 'value_quad'] + 'value_quad', + 'log'] for at in lst_attributes: print(at) with pytest.raises(NotImplementedError):