diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index c7916f50a..5cc34f38b 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -45,6 +45,7 @@ The contributors to this library are:
 * [Ronak Mehta](https://ronakrm.github.io) (Efficient Discrete Multi Marginal Optimal Transport Regularization)
 * [Xizheng Yu](https://github.com/x12hengyu) (Efficient Discrete Multi Marginal Optimal Transport Regularization)
 * [Sonia Mazelet](https://github.com/SoniaMaz8) (Template based GNN layers)
+* [Laurène David](https://github.com/laudavid) (Low-rank Sinkhorn)

 ## Acknowledgments

diff --git a/README.md b/README.md
index 939cc6158..a9a94c53e 100644
--- a/README.md
+++ b/README.md
@@ -347,3 +347,5 @@ distances between Gaussian distributions](https://hal.science/hal-03197398v2/fil
 [63] Li, J., Tang, J., Kong, L., Liu, H., Li, J., So, A. M. C., & Blanchet, J. (2022). [A Convergent Single-Loop Algorithm for Relaxation of Gromov-Wasserstein in Graph Data](https://openreview.net/pdf?id=0jxPyVWmiiF). In The Eleventh International Conference on Learning Representations.

 [64] Ma, X., Chu, X., Wang, Y., Lin, Y., Zhao, J., Ma, L., & Zhu, W. (2023). [Fused Gromov-Wasserstein Graph Mixup for Graph-level Classifications](https://openreview.net/pdf?id=uqkUguNu40). In Thirty-seventh Conference on Neural Information Processing Systems.
+
+[65] Scetbon, M., Cuturi, M., & Peyré, G. (2021). [Low-Rank Sinkhorn Factorization](https://arxiv.org/pdf/2103.04737.pdf). In International Conference on Machine Learning.
\ No newline at end of file

diff --git a/RELEASES.md b/RELEASES.md
index 3c428c521..b21e5b0dc 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -22,6 +22,7 @@
 + Add `fixed_structure` and `fixed_features` to entropic fgw barycenter solver (PR #578)
 + Add new BAPG solvers with KL projections for GW and FGW (PR #581)
 + Add Bures-Wasserstein barycenter in `ot.gaussian` and example (PR #582, PR #584)
++ Add support for [Low-Rank Sinkhorn Factorization](https://arxiv.org/pdf/2103.04737.pdf) (PR #568)

 #### Closed issues

diff --git a/ot/__init__.py b/ot/__init__.py
index 9c33e9feb..99d075e5a 100644
--- a/ot/__init__.py
+++ b/ot/__init__.py
@@ -35,6 +35,7 @@
 from . import factored
 from . import solvers
 from . import gaussian
+from . import lowrank

 # OT functions
@@ -52,6 +53,7 @@
 from .weak import weak_optimal_transport
 from .factored import factored_optimal_transport
 from .solvers import solve, solve_gromov, solve_sample
+from .lowrank import lowrank_sinkhorn

 # utils functions
 from .utils import dist, unif, tic, toc, toq
@@ -69,4 +71,4 @@
     'factored_optimal_transport', 'solve', 'solve_gromov','solve_sample',
     'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath', 'solvers',
     'binary_search_circle', 'wasserstein_circle',
-    'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif']
+    'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif', 'lowrank_sinkhorn']

diff --git a/ot/lowrank.py b/ot/lowrank.py
new file mode 100644
index 000000000..5c8f673cb
--- /dev/null
+++ b/ot/lowrank.py
@@ -0,0 +1,341 @@
+"""
+Low rank OT solvers
+"""
+
+# Author: Laurène David
+#
+# License: MIT License
+
+
+import warnings
+from .utils import unif, get_lowrank_lazytensor
+from .backend import get_backend
+
+
+def compute_lr_sqeuclidean_matrix(X_s, X_t, nx=None):
+    """
+    Compute the low rank decomposition of a squared Euclidean distance matrix.
+    This decomposition is exact for the squared Euclidean metric only; it will
+    not work for any other distance metric.
+
+    See "Section 3.5, proposition 1" in [65].
+
+    Parameters
+    ----------
+    X_s : array-like, shape (n_samples_a, dim)
+        samples in the source domain
+    X_t : array-like, shape (n_samples_b, dim)
+        samples in the target domain
+    nx : POT backend, default None
+
+    Returns
+    -------
+    M1 : array-like, shape (n_samples_a, dim+2)
+        First factor of the low rank decomposition of the distance matrix
+    M2 : array-like, shape (n_samples_b, dim+2)
+        Second factor of the low rank decomposition of the distance matrix
+
+    References
+    ----------
+    .. [65] Scetbon, M., Cuturi, M., & Peyré, G. (2021).
+        "Low-Rank Sinkhorn Factorization". In International Conference on Machine Learning.
+    """
+
+    if nx is None:
+        nx = get_backend(X_s, X_t)
+
+    ns = X_s.shape[0]
+    nt = X_t.shape[0]
+
+    # First low rank decomposition of the cost matrix (A)
+    array1 = nx.reshape(nx.sum(X_s**2, 1), (-1, 1))
+    array2 = nx.reshape(nx.ones(ns, type_as=X_s), (-1, 1))
+    M1 = nx.concatenate((array1, array2, -2 * X_s), axis=1)
+
+    # Second low rank decomposition of the cost matrix (B)
+    array1 = nx.reshape(nx.ones(nt, type_as=X_s), (-1, 1))
+    array2 = nx.reshape(nx.sum(X_t**2, 1), (-1, 1))
+    M2 = nx.concatenate((array1, array2, X_t), axis=1)
+
+    return M1, M2
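# Illustration (a minimal numpy sketch, not part of the patch): the factorization
# above is exact for the squared Euclidean metric, since
# M[i, j] = ||x_i||^2 + ||y_j||^2 - 2 <x_i, y_j> is a sum of dim + 2 rank-one
# terms, so M1 @ M2.T recovers the full cost matrix. Data values are illustrative.
import numpy as np
import ot

rng = np.random.RandomState(0)
X_s = rng.randn(50, 3)
X_t = rng.randn(40, 3)

M1, M2 = ot.lowrank.compute_lr_sqeuclidean_matrix(X_s, X_t)
M = ot.dist(X_s, X_t, metric="sqeuclidean")  # dense cost matrix, for comparison
np.testing.assert_allclose(M1 @ M2.T, M, atol=1e-8)  # exact up to float error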
+
+def _LR_Dykstra(eps1, eps2, eps3, p1, p2, alpha, stopThr, numItermax, warn, nx=None):
+    """
+    Implementation of the Dykstra algorithm for the Low Rank Sinkhorn OT solver.
+    This function is specific to lowrank_sinkhorn.
+
+    Parameters
+    ----------
+    eps1 : array-like, shape (n_samples_a, r)
+        First input parameter of the Dykstra algorithm
+    eps2 : array-like, shape (n_samples_b, r)
+        Second input parameter of the Dykstra algorithm
+    eps3 : array-like, shape (r,)
+        Third input parameter of the Dykstra algorithm
+    p1 : array-like, shape (n_samples_a,)
+        Sample weights in the source domain (same as "a" in lowrank_sinkhorn)
+    p2 : array-like, shape (n_samples_b,)
+        Sample weights in the target domain (same as "b" in lowrank_sinkhorn)
+    alpha : float
+        Lower bound for the weight vector g (same as "alpha" in lowrank_sinkhorn)
+    stopThr : float
+        Stop threshold on error
+    numItermax : int
+        Max number of iterations
+    warn : bool, optional
+        if True, raises a warning if the algorithm doesn't converge.
+    nx : POT backend, default None
+
+    Returns
+    -------
+    Q : array-like, shape (n_samples_a, r)
+        Dykstra update of the first low-rank matrix decomposition Q
+    R : array-like, shape (n_samples_b, r)
+        Dykstra update of the second low-rank matrix decomposition R
+    g : array-like, shape (r,)
+        Dykstra update of the weight vector g
+
+    References
+    ----------
+    .. [65] Scetbon, M., Cuturi, M., & Peyré, G. (2021).
+        "Low-Rank Sinkhorn Factorization". In International Conference on Machine Learning.
+    """
+
+    # POT backend if None
+    if nx is None:
+        nx = get_backend(eps1, eps2, eps3, p1, p2)
+
+    # ----------------- Initialisation of Dykstra algorithm -----------------
+    r = len(eps3)  # rank
+    g_ = nx.copy(eps3)  # \tilde{g}
+    q3_1, q3_2 = nx.ones(r, type_as=p1), nx.ones(r, type_as=p1)  # q^{(3)}_1, q^{(3)}_2
+    v1_, v2_ = nx.ones(r, type_as=p1), nx.ones(r, type_as=p1)  # \tilde{v}^{(1)}, \tilde{v}^{(2)}
+    q1, q2 = nx.ones(r, type_as=p1), nx.ones(r, type_as=p1)  # q^{(1)}, q^{(2)}
+    err = 1  # initial error
+
+    # --------------------- Dykstra algorithm -------------------------
+
+    # See Section 3.3 - "Algorithm 2 LR-Dykstra" in the paper
+
+    for ii in range(numItermax):
+        if err > stopThr:
+            # Compute u^{(1)} and u^{(2)}
+            u1 = p1 / nx.dot(eps1, v1_)
+            u2 = p2 / nx.dot(eps2, v2_)
+
+            # Compute g, q^{(3)}_1 and update \tilde{g}
+            g = nx.maximum(alpha, g_ * q3_1)
+            q3_1 = (g_ * q3_1) / g
+            g_ = nx.copy(g)
+
+            # Compute new value of g with \prod
+            prod1 = (v1_ * q1) * nx.dot(eps1.T, u1)
+            prod2 = (v2_ * q2) * nx.dot(eps2.T, u2)
+            g = (g_ * q3_2 * prod1 * prod2) ** (1 / 3)
+
+            # Compute v^{(1)} and v^{(2)}
+            v1 = g / nx.dot(eps1.T, u1)
+            v2 = g / nx.dot(eps2.T, u2)
+
+            # Compute q^{(1)}, q^{(2)} and q^{(3)}_2
+            q1 = (v1_ * q1) / v1
+            q2 = (v2_ * q2) / v2
+            q3_2 = (g_ * q3_2) / g
+
+            # Update values of \tilde{v}^{(1)}, \tilde{v}^{(2)} and \tilde{g}
+            v1_, v2_ = nx.copy(v1), nx.copy(v2)
+            g_ = nx.copy(g)
+
+            # Compute error (deviation of the current marginals from p1, p2)
+            err1 = nx.sum(nx.abs(u1 * (eps1 @ v1) - p1))
+            err2 = nx.sum(nx.abs(u2 * (eps2 @ v2) - p2))
+            err = err1 + err2
+
+        else:
+            break
+
+    else:
+        if warn:
+            warnings.warn(
+                "Dykstra's algorithm did not converge. You might want to "
+                "increase the number of iterations `numItermax`."
+            )
+
+    # Compute low rank matrices Q, R
+    Q = u1[:, None] * eps1 * v1[None, :]
+    R = u2[:, None] * eps2 * v2[None, :]
+
+    return Q, R, g
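# Illustration (a minimal sketch, not part of the patch): the Dykstra iterations
# above enforce the low-rank coupling constraints Q 1_r = a, R 1_r = b and
# Q^T 1_n = R^T 1_m = g, with g bounded below by alpha. This checks them through
# lowrank_sinkhorn (added below); the data and rank are illustrative only.
import numpy as np
import ot

n = 30
X_s = np.random.RandomState(1).randn(n, 2)
X_t = np.random.RandomState(2).randn(n, 2)
a = b = ot.unif(n)

Q, R, g = ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, reg=0.1, rank=5)
np.testing.assert_allclose(Q.sum(1), a, atol=1e-5)  # Q 1_r = a
np.testing.assert_allclose(R.sum(1), b, atol=1e-5)  # R 1_r = b
np.testing.assert_allclose(Q.sum(0), g, atol=1e-5)  # Q^T 1_n = g
np.testing.assert_allclose(R.sum(0), g, atol=1e-5)  # R^T 1_m = g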
+ + """ + + # POT backend if None + if nx is None: + nx = get_backend(eps1, eps2, eps3, p1, p2) + + # ----------------- Initialisation of Dykstra algorithm ----------------- + r = len(eps3) # rank + g_ = nx.copy(eps3) # \tilde{g} + q3_1, q3_2 = nx.ones(r, type_as=p1), nx.ones(r, type_as=p1) # q^{(3)}_1, q^{(3)}_2 + v1_, v2_ = nx.ones(r, type_as=p1), nx.ones(r, type_as=p1) # \tilde{v}^{(1)}, \tilde{v}^{(2)} + q1, q2 = nx.ones(r, type_as=p1), nx.ones(r, type_as=p1) # q^{(1)}, q^{(2)} + err = 1 # initial error + + # --------------------- Dykstra algorithm ------------------------- + + # See Section 3.3 - "Algorithm 2 LR-Dykstra" in paper + + for ii in range(numItermax): + if err > stopThr: + # Compute u^{(1)} and u^{(2)} + u1 = p1 / nx.dot(eps1, v1_) + u2 = p2 / nx.dot(eps2, v2_) + + # Compute g, g^{(3)}_1 and update \tilde{g} + g = nx.maximum(alpha, g_ * q3_1) + q3_1 = (g_ * q3_1) / g + g_ = nx.copy(g) + + # Compute new value of g with \prod + prod1 = (v1_ * q1) * nx.dot(eps1.T, u1) + prod2 = (v2_ * q2) * nx.dot(eps2.T, u2) + g = (g_ * q3_2 * prod1 * prod2) ** (1 / 3) + + # Compute v^{(1)} and v^{(2)} + v1 = g / nx.dot(eps1.T, u1) + v2 = g / nx.dot(eps2.T, u2) + + # Compute q^{(1)}, q^{(2)} and q^{(3)}_2 + q1 = (v1_ * q1) / v1 + q2 = (v2_ * q2) / v2 + q3_2 = (g_ * q3_2) / g + + # Update values of \tilde{v}^{(1)}, \tilde{v}^{(2)} and \tilde{g} + v1_, v2_ = nx.copy(v1), nx.copy(v2) + g_ = nx.copy(g) + + # Compute error + err1 = nx.sum(nx.abs(u1 * (eps1 @ v1) - p1)) + err2 = nx.sum(nx.abs(u2 * (eps2 @ v2) - p2)) + err = err1 + err2 + + else: + break + + else: + if warn: + warnings.warn( + "Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + ) + + # Compute low rank matrices Q, R + Q = u1[:, None] * eps1 * v1[None, :] + R = u2[:, None] * eps2 * v2[None, :] + + return Q, R, g + + +def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=None, alpha=None, + numItermax=1000, stopThr=1e-9, warn=True, log=False): + r""" + Solve the entropic regularization optimal transport problem under low-nonnegative rank constraints. + + The function solves the following optimization problem: + + .. math:: + \mathop{\inf_{(Q,R,g) \in \mathcal{C(a,b,r)}}} \langle C, Q\mathrm{diag}(1/g)R^T \rangle - + \mathrm{reg} \cdot H((Q,R,g)) + + where : + - :math:`C` is the (`dim_a`, `dim_b`) metric cost matrix + - :math:`H((Q,R,g))` is the values of the three respective entropies evaluated for each term. + - :math: `Q` and `R` are the low-rank matrix decomposition of the OT plan + - :math: `g` is the weight vector for the low-rank decomposition of the OT plan + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1) + - :math: `r` is the rank of the OT plan + - :math: `\mathcal{C(a,b,r)}` are the low-rank couplings of the OT problem + + + Parameters + ---------- + X_s : array-like, shape (n_samples_a, dim) + samples in the source domain + X_t : array-like, shape (n_samples_b, dim) + samples in the target domain + a : array-like, shape (n_samples_a,) + samples weights in the source domain + b : array-like, shape (n_samples_b,) + samples weights in the target domain + reg : float, optional + Regularization term >0 + rank: int, optional. Default is None. (>0) + Nonnegative rank of the OT plan. If None, min(ns, nt) is considered. + alpha: int, optional. Default is None. (>0 and <1/r) + Lower bound for the weight vector g. 
diff --git a/ot/solvers.py b/ot/solvers.py
index a41762a5c..40a03e974 100644
--- a/ot/solvers.py
+++ b/ot/solvers.py
@@ -22,6 +22,7 @@
 from .partial import partial_gromov_wasserstein2, entropic_partial_gromov_wasserstein2
 from .gaussian import empirical_bures_wasserstein_distance
 from .factored import factored_optimal_transport
+from .lowrank import lowrank_sinkhorn

 lst_method_lazy = ['1d', 'gaussian', 'lowrank', 'factored', 'geomloss', 'geomloss_auto', 'geomloss_tensorized', 'geomloss_online', 'geomloss_multiscale']

@@ -1248,6 +1249,25 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t
         if not lazy0:  # store plan if not lazy
             plan = lazy_plan[:]

+    elif method == "lowrank":
+
+        if metric.lower() not in ['sqeuclidean']:
+            raise NotImplementedError('Not implemented metric="{}"'.format(metric))
+
+        if max_iter is None:
+            max_iter = 1000
+        if tol is None:
+            tol = 1e-9
+        if reg is None:
+            reg = 0
+
+        Q, R, g, log = lowrank_sinkhorn(X_a, X_b, reg=reg, a=a, b=b, numItermax=max_iter, stopThr=tol, log=True)
+        value = log['value']
+        value_linear = log['value_linear']
+        lazy_plan = log['lazy_plan']
+        if not lazy0:  # store plan if not lazy
+            plan = lazy_plan[:]
+
     elif method.startswith('geomloss'):  # Geomloss solver for entropi OT

         split_method = method.split('_')
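# Usage illustration (a minimal sketch, not part of the patch) of the new
# method="lowrank" branch of ot.solve_sample: the returned solution carries a
# lazy low-rank plan, and only the default squared Euclidean metric is supported.
# Data values are illustrative only.
import numpy as np
import ot

rng = np.random.RandomState(0)
X_a, X_b = rng.randn(60, 2), rng.randn(50, 2)

sol = ot.solve_sample(X_a, X_b, method="lowrank", reg=0.1)
P = sol.lazy_plan[:]  # materialize the low-rank plan as a dense (60, 50) array
print(P.shape, sol.value, sol.value_linear)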
diff --git a/test/test_lowrank.py b/test/test_lowrank.py
new file mode 100644
index 000000000..65f76a77b
--- /dev/null
+++ b/test/test_lowrank.py
@@ -0,0 +1,88 @@
+""" Test for low rank sinkhorn solvers """
+
+# Author: Laurène DAVID
+#
+# License: MIT License
+
+import ot
+import numpy as np
+import pytest
+
+
+def test_compute_lr_sqeuclidean_matrix():
+    # test computation of low rank cost matrices M1 and M2
+    n = 100
+    X_s = np.reshape(1.0 * np.arange(2 * n), (n, 2))
+    X_t = np.reshape(1.0 * np.arange(2 * n), (n, 2))
+
+    M1, M2 = ot.lowrank.compute_lr_sqeuclidean_matrix(X_s, X_t)
+    M = ot.dist(X_s, X_t, metric="sqeuclidean")  # original cost matrix
+
+    np.testing.assert_allclose(np.dot(M1, M2.T), M, atol=1e-05)
+
+
+def test_lowrank_sinkhorn():
+    # test low rank sinkhorn
+    n = 100
+    a = ot.unif(n)
+    b = ot.unif(n)
+
+    X_s = np.reshape(1.0 * np.arange(n), (n, 1))
+    X_t = np.reshape(1.0 * np.arange(n), (n, 1))
+
+    Q, R, g, log = ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, reg=0.1, log=True)
+    P = log["lazy_plan"][:]
+    value_linear = log["value_linear"]
+
+    # check marginal constraints for P
+    np.testing.assert_allclose(a, P.sum(1), atol=1e-05)
+    np.testing.assert_allclose(b, P.sum(0), atol=1e-05)
+
+    # check that lazy_plan is equal to the fully computed plan
+    P_true = np.dot(Q, np.dot(np.diag(1 / g), R.T))
+    np.testing.assert_allclose(P, P_true, atol=1e-05)
+
+    # check that value_linear matches its original formula
+    M = ot.dist(X_s, X_t, metric="sqeuclidean")
+    value_linear_true = np.sum(M * P_true)
+    np.testing.assert_allclose(value_linear, value_linear_true, atol=1e-05)
+
+    # check warn parameter when the Dykstra algorithm doesn't converge
+    with pytest.warns(UserWarning):
+        ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, reg=0.1, stopThr=0, numItermax=1)
+
+
+@pytest.mark.parametrize(("alpha, rank"), ((0.8, 2), (0.5, 3), (0.2, 6)))
+def test_lowrank_sinkhorn_alpha_error(alpha, rank):
+    # Test that a ValueError is raised when alpha > 1/rank
+    n = 100
+    a = ot.unif(n)
+    b = ot.unif(n)
+
+    X_s = np.reshape(1.0 * np.arange(n), (n, 1))
+    X_t = np.reshape(1.0 * np.arange(0, n), (n, 1))
+
+    with pytest.raises(ValueError):
+        ot.lowrank.lowrank_sinkhorn(
+            X_s, X_t, a, b, reg=0.1, rank=rank, alpha=alpha, warn=False
+        )
+
+
+@pytest.skip_backend('tf')
+def test_lowrank_sinkhorn_backends(nx):
+    # Test low rank sinkhorn for different backends
+    n = 100
+    a = ot.unif(n)
+    b = ot.unif(n)
+
+    X_s = np.reshape(1.0 * np.arange(n), (n, 1))
+    X_t = np.reshape(1.0 * np.arange(0, n), (n, 1))
+
+    ab, bb, X_sb, X_tb = nx.from_numpy(a, b, X_s, X_t)
+
+    Q, R, g, log = ot.lowrank.lowrank_sinkhorn(X_sb, X_tb, ab, bb, reg=0.1, log=True)
+    lazy_plan = log["lazy_plan"]
+    P = lazy_plan[:]
+
+    np.testing.assert_allclose(ab, P.sum(1), atol=1e-05)
+    np.testing.assert_allclose(bb, P.sum(0), atol=1e-05)
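# Sanity-check illustration (a minimal sketch, not part of the test suite):
# because the low-rank plan is a feasible coupling of (a, b), its linear cost can
# only upper-bound the exact OT cost returned by ot.emd2 on the same squared
# Euclidean cost matrix. Data values are illustrative only.
import numpy as np
import ot

n = 100
a = b = ot.unif(n)
X_s = np.reshape(1.0 * np.arange(n), (n, 1))
X_t = np.reshape(1.0 * np.arange(n), (n, 1))

_, _, _, log = ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, reg=0.1, rank=10, log=True)
M = ot.dist(X_s, X_t, metric="sqeuclidean")
assert log["value_linear"] >= ot.emd2(a, b, M) - 1e-8  # feasible => cost >= optimum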
diff --git a/test/test_solvers.py b/test/test_solvers.py
index bf07b7af8..343220c45 100644
--- a/test/test_solvers.py
+++ b/test/test_solvers.py
@@ -30,12 +30,14 @@
     {'method': 'gaussian'},
     {'method': 'gaussian', 'reg': 1},
     {'method': 'factored', 'rank': 10},
+    {'method': 'lowrank', 'reg': 0.1}
 ]

 lst_parameters_solve_sample_NotImplemented = [
     {'method': '1d', 'metric': 'any other one'},  # fail 1d on weird metrics
     {'method': 'gaussian', 'metric': 'euclidean'},  # fail gaussian on metric not euclidean
     {'method': 'factored', 'metric': 'euclidean'},  # fail factored on metric not euclidean
+    {'method': 'lowrank', 'metric': 'euclidean'},  # fail lowrank on metric not euclidean
     {'lazy': True},  # fail lazy for non regularized
     {'lazy': True, 'unbalanced': 1},  # fail lazy for non regularized unbalanced
     {'lazy': True, 'reg': 1, 'unbalanced': 1},  # fail lazy for unbalanced and regularized
@@ -413,7 +415,7 @@ def test_solve_sample_methods(nx, method_params):
         assert_allclose_sol(sol, solb)

     sol2 = ot.solve_sample(x, x, **method_params)
-    if method_params['method'] != 'factored':
+    if method_params['method'] not in ['factored', 'lowrank']:
         np.testing.assert_allclose(sol2.value, 0)