diff --git a/README.md b/README.md
index a9a94c53e..3c9b212ce 100644
--- a/README.md
+++ b/README.md
@@ -194,7 +194,8 @@ The numerous contributors to this library are listed [here](CONTRIBUTORS.md).
 
 POT has benefited from the financing or manpower from the following partners:
 
-ANRCNRS3IA
+ANRCNRS3IAHi!PARIS
+
 
 ## Contributions and code of conduct
diff --git a/docs/source/_static/images/logo_hiparis.png b/docs/source/_static/images/logo_hiparis.png
new file mode 100644
index 000000000..1ce6dfb5a
Binary files /dev/null and b/docs/source/_static/images/logo_hiparis.png differ
diff --git a/examples/others/plot_lowrank_sinkhorn.py b/examples/others/plot_lowrank_sinkhorn.py
new file mode 100644
index 000000000..ece35b295
--- /dev/null
+++ b/examples/others/plot_lowrank_sinkhorn.py
@@ -0,0 +1,129 @@
+# -*- coding: utf-8 -*-
+"""
+========================================
+Low rank Sinkhorn
+========================================
+
+This example illustrates the computation of Low Rank Sinkhorn [65].
+
+[65] Scetbon, M., Cuturi, M., & Peyré, G. (2021).
+"Low-rank Sinkhorn Factorization". In International Conference on Machine Learning.
+"""
+
+# Author: Laurène David
+#
+# License: MIT License
+#
+# sphinx_gallery_thumbnail_number = 2
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot.plot
+from ot.datasets import make_1D_gauss as gauss
+
+##############################################################################
+# Generate data
+# -------------
+
+#%% parameters
+
+n = 100
+m = 120
+
+# Gaussian distributions
+a = gauss(n, m=int(n / 3), s=25 / np.sqrt(2)) + 1.5 * gauss(n, m=int(5 * n / 6), s=15 / np.sqrt(2))
+a = a / np.sum(a)
+
+b = 2 * gauss(m, m=int(m / 5), s=30 / np.sqrt(2)) + gauss(m, m=int(m / 2), s=35 / np.sqrt(2))
+b = b / np.sum(b)
+
+# Source and target distributions
+X = np.arange(n).reshape(-1, 1)
+Y = np.arange(m).reshape(-1, 1)
+
+
+##############################################################################
+# Solve low rank Sinkhorn
+# -----------------------
+
+#%%
+# Solve low rank sinkhorn
+Q, R, g, log = ot.lowrank_sinkhorn(X, Y, a, b, rank=10, init="random", gamma_init="rescale", rescale_cost=True, warn=False, log=True)
+P = log["lazy_plan"][:]
+
+ot.plot.plot1D_mat(a, b, P, 'OT matrix Low rank')
+
+
+##############################################################################
+# Sinkhorn vs Low Rank Sinkhorn
+# -----------------------------
+# Compare Sinkhorn and low rank Sinkhorn with different regularizations and ranks.
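+
+#%% Sanity check on the low rank factorization
+
+# Before the comparison, a quick check added as an illustrative sketch (it is
+# not needed by the example itself): the plan materialized above factors
+# exactly as P = Q @ diag(1/g) @ R.T, so rebuilding it from Q, R and g
+# should match the lazy plan up to numerical precision.
+P_factored = np.dot(Q * (1 / g)[None, :], R.T)
+np.testing.assert_allclose(P_factored, P, atol=1e-10)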
+
+#%% Sinkhorn
+
+# Compute the cost matrix for Sinkhorn OT
+M = ot.dist(X, Y)
+M = M / np.max(M)
+
+# Solve Sinkhorn with different regularizations using ot.solve
+list_reg = [0.05, 0.005, 0.001]
+list_P_Sin = []
+
+for reg in list_reg:
+    P = ot.solve(M, a, b, reg=reg, max_iter=2000, tol=1e-8).plan
+    list_P_Sin.append(P)
+
+#%% Low rank sinkhorn
+
+# Solve low rank Sinkhorn with different ranks using ot.solve_sample
+list_rank = [3, 10, 50]
+list_P_LR = []
+
+for rank in list_rank:
+    P = ot.solve_sample(X, Y, a, b, method='lowrank', rank=rank).plan
+    P = P[:]
+    list_P_LR.append(P)
+
+
+#%%
+
+# Plot Sinkhorn vs low rank Sinkhorn
+pl.figure(1, figsize=(10, 4))
+
+pl.subplot(1, 3, 1)
+pl.imshow(list_P_Sin[0], interpolation='nearest')
+pl.axis('off')
+pl.title('Sinkhorn (reg=0.05)')
+
+pl.subplot(1, 3, 2)
+pl.imshow(list_P_Sin[1], interpolation='nearest')
+pl.axis('off')
+pl.title('Sinkhorn (reg=0.005)')
+
+pl.subplot(1, 3, 3)
+pl.imshow(list_P_Sin[2], interpolation='nearest')
+pl.axis('off')
+pl.title('Sinkhorn (reg=0.001)')
+pl.show()
+
+
+#%%
+
+pl.figure(2, figsize=(10, 4))
+
+pl.subplot(1, 3, 1)
+pl.imshow(list_P_LR[0], interpolation='nearest')
+pl.axis('off')
+pl.title('Low rank (rank=3)')
+
+pl.subplot(1, 3, 2)
+pl.imshow(list_P_LR[1], interpolation='nearest')
+pl.axis('off')
+pl.title('Low rank (rank=10)')
+
+pl.subplot(1, 3, 3)
+pl.imshow(list_P_LR[2], interpolation='nearest')
+pl.axis('off')
+pl.title('Low rank (rank=50)')
+
+pl.tight_layout()
+pl.show()
diff --git a/ot/lowrank.py b/ot/lowrank.py
index 5c8f673cb..f6c1469bd 100644
--- a/ot/lowrank.py
+++ b/ot/lowrank.py
@@ -8,14 +8,142 @@
 import warnings
-from .utils import unif, get_lowrank_lazytensor
+from .utils import unif, dist, get_lowrank_lazytensor
 from .backend import get_backend
+from .bregman import sinkhorn
+
+# check whether sklearn is installed (optional dependency, missing in the
+# linux-minimal-deps build)
+try:
+    import sklearn.cluster
+    sklearn_import = True
+except ImportError:
+    sklearn_import = False
 
-def compute_lr_sqeuclidean_matrix(X_s, X_t, nx=None):
+
+def _init_lr_sinkhorn(X_s, X_t, a, b, rank, init, reg_init, random_state, nx=None):
+    """
+    Implementation of different initialization strategies for the low rank sinkhorn solver (Q, R, g).
+    This function is specific to lowrank_sinkhorn.
+
+    Parameters
+    ----------
+    X_s : array-like, shape (n_samples_a, dim)
+        samples in the source domain
+    X_t : array-like, shape (n_samples_b, dim)
+        samples in the target domain
+    a : array-like, shape (n_samples_a,)
+        sample weights in the source domain
+    b : array-like, shape (n_samples_b,)
+        sample weights in the target domain
+    rank : int
+        Nonnegative rank of the OT plan.
+    init : str
+        Initialization strategy for Q, R and g: 'random', 'deterministic' or 'kmeans'
+    reg_init : float, optional.
+        Regularization term for the 'kmeans' init.
+    random_state : int, optional.
+        Random state for the 'random' or 'kmeans' init strategies.
+    nx : optional, default is None
+        POT backend
+
+
+    Returns
+    -------
+    Q : array-like, shape (n_samples_a, r)
+        Init for the first low-rank matrix decomposition of the OT plan (Q)
+    R : array-like, shape (n_samples_b, r)
+        Init for the second low-rank matrix decomposition of the OT plan (R)
+    g : array-like, shape (r, )
+        Init for the weight vector of the low-rank decomposition of the OT plan (g)
+
+
+    References
+    ----------
+    .. [65] Scetbon, M., Cuturi, M., & Peyré, G. (2021).
+        "Low-rank Sinkhorn Factorization". In International Conference on Machine Learning.
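+
+    Examples
+    --------
+    A minimal sketch of a call (illustrative only; the arguments mirror the
+    parameters documented above):
+
+    >>> Q, R, g = _init_lr_sinkhorn(X_s, X_t, a, b, rank=10, init="kmeans",
+    ...                             reg_init=1e-1, random_state=49)  # doctest: +SKIP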
+
+    """
+
+    if nx is None:
+        nx = get_backend(X_s, X_t, a, b)
+
+    ns = X_s.shape[0]
+    nt = X_t.shape[0]
+    r = rank
+
+    if init == "random":
+        nx.seed(seed=random_state)
+
+        # Init g
+        g = nx.abs(nx.randn(r, type_as=X_s)) + 1
+        g = g / nx.sum(g)
+
+        # Init Q
+        Q = nx.abs(nx.randn(ns, r, type_as=X_s)) + 1
+        Q = (Q.T * (a / nx.sum(Q, axis=1))).T
+
+        # Init R
+        R = nx.abs(nx.randn(nt, r, type_as=X_s)) + 1
+        R = (R.T * (b / nx.sum(R, axis=1))).T
+
+    elif init == "deterministic":
+        # Init g
+        g = nx.ones(r, type_as=X_s) / r
+
+        lambda_1 = min(nx.min(a), nx.min(g), nx.min(b)) / 2
+        a1 = nx.arange(start=1, stop=ns + 1, type_as=X_s)
+        a1 = a1 / nx.sum(a1)
+        a2 = (a - lambda_1 * a1) / (1 - lambda_1)
+
+        b1 = nx.arange(start=1, stop=nt + 1, type_as=X_s)
+        b1 = b1 / nx.sum(b1)
+        b2 = (b - lambda_1 * b1) / (1 - lambda_1)
+
+        g1 = nx.arange(start=1, stop=r + 1, type_as=X_s)
+        g1 = g1 / nx.sum(g1)
+        g2 = (g - lambda_1 * g1) / (1 - lambda_1)
+
+        # Init Q
+        Q1 = lambda_1 * nx.dot(a1[:, None], nx.reshape(g1, (1, -1)))
+        Q2 = (1 - lambda_1) * nx.dot(a2[:, None], nx.reshape(g2, (1, -1)))
+        Q = Q1 + Q2
+
+        # Init R
+        R1 = lambda_1 * nx.dot(b1[:, None], nx.reshape(g1, (1, -1)))
+        R2 = (1 - lambda_1) * nx.dot(b2[:, None], nx.reshape(g2, (1, -1)))
+        R = R1 + R2
+
+    elif init == "kmeans":
+        if sklearn_import:
+            # Init g
+            g = nx.ones(r, type_as=X_s) / r
+
+            # Init Q
+            kmeans_Xs = sklearn.cluster.KMeans(n_clusters=r, random_state=random_state, n_init="auto")
+            kmeans_Xs.fit(X_s)
+            Z_Xs = nx.from_numpy(kmeans_Xs.cluster_centers_)
+            C_Xs = dist(X_s, Z_Xs)  # shape (ns, r)
+            C_Xs = C_Xs / nx.max(C_Xs)
+            Q = sinkhorn(a, g, C_Xs, reg=reg_init, numItermax=10000, stopThr=1e-3)
+
+            # Init R
+            kmeans_Xt = sklearn.cluster.KMeans(n_clusters=r, random_state=random_state, n_init="auto")
+            kmeans_Xt.fit(X_t)
+            Z_Xt = nx.from_numpy(kmeans_Xt.cluster_centers_)
+            C_Xt = dist(X_t, Z_Xt)  # shape (nt, r)
+            C_Xt = C_Xt / nx.max(C_Xt)
+            R = sinkhorn(b, g, C_Xt, reg=reg_init, numItermax=10000, stopThr=1e-3)
+
+        else:
+            raise ImportError("Scikit-learn should be installed to use the 'kmeans' init.")
+
+    else:
+        raise (NotImplementedError('Not implemented init="{}"'.format(init)))
+
+    return Q, R, g
+
+
+def compute_lr_sqeuclidean_matrix(X_s, X_t, rescale_cost, nx=None):
     """
     Compute the low rank decomposition of a squared euclidean distance matrix.
-    This function won't work for any other distance metric.
+    This function won't work for other distance metrics.
 
     See "Section 3.5, proposition 1"
 
@@ -25,7 +153,10 @@ def compute_lr_sqeuclidean_matrix(X_s, X_t, nx=None):
         samples in the source domain
     X_t : array-like, shape (n_samples_b, dim)
         samples in the target domain
-    nx : POT backend, default none
+    rescale_cost : bool
+        Rescale the low rank factorization of the sqeuclidean cost matrix
+    nx : optional, default is None
+        POT backend
 
 
     Returns
@@ -37,9 +168,9 @@ def compute_lr_sqeuclidean_matrix(X_s, X_t, nx=None):
 
 
     References
     ----------
     .. [65] Scetbon, M., Cuturi, M., & Peyré, G. (2021).
-        "Low-rank Sinkhorn factorization". In International Conference on Machine Learning.
+        "Low-rank Sinkhorn Factorization". In International Conference on Machine Learning.
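+
+    Examples
+    --------
+    Sketch of the factorization property (illustrative only): without
+    rescaling, the factors reconstruct the full squared Euclidean cost.
+
+    >>> M1, M2 = compute_lr_sqeuclidean_matrix(X_s, X_t, rescale_cost=False)  # doctest: +SKIP
+    >>> # np.dot(M1, M2.T) then equals ot.dist(X_s, X_t, metric="sqeuclidean")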
""" if nx is None: @@ -50,14 +181,18 @@ def compute_lr_sqeuclidean_matrix(X_s, X_t, nx=None): # First low rank decomposition of the cost matrix (A) array1 = nx.reshape(nx.sum(X_s**2, 1), (-1, 1)) - array2 = nx.reshape(nx.ones(ns, type_as=X_s), (-1, 1)) + array2 = nx.ones((ns, 1), type_as=X_s) M1 = nx.concatenate((array1, array2, -2 * X_s), axis=1) # Second low rank decomposition of the cost matrix (B) - array1 = nx.reshape(nx.ones(nt, type_as=X_s), (-1, 1)) + array1 = nx.ones((nt, 1), type_as=X_s) array2 = nx.reshape(nx.sum(X_t**2, 1), (-1, 1)) M2 = nx.concatenate((array1, array2, X_t), axis=1) + if rescale_cost is True: + M1 = M1 / nx.sqrt(nx.max(M1)) + M2 = M2 / nx.sqrt(nx.max(M2)) + return M1, M2 @@ -103,7 +238,7 @@ def _LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, stopThr, numItermax, warn, nx=N References ---------- .. [65] Scetbon, M., Cuturi, M., & Peyré, G. (2021). - "Low-rank Sinkhorn factorization". In International Conference on Machine Learning. + "Low-rank Sinkhorn Factorization". In International Conference on Machine Learning. """ @@ -163,7 +298,7 @@ def _LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, stopThr, numItermax, warn, nx=N else: if warn: warnings.warn( - "Sinkhorn did not converge. You might want to " + "Dykstra did not converge. You might want to " "increase the number of iterations `numItermax` " ) @@ -174,10 +309,12 @@ def _LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, stopThr, numItermax, warn, nx=N return Q, R, g -def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=None, alpha=None, - numItermax=1000, stopThr=1e-9, warn=True, log=False): +def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=None, alpha=1e-10, rescale_cost=True, + init="random", reg_init=1e-1, seed_init=49, gamma_init="rescale", + numItermax=2000, stopThr=1e-7, warn=True, log=False): r""" - Solve the entropic regularization optimal transport problem under low-nonnegative rank constraints. + Solve the entropic regularization optimal transport problem under low-nonnegative rank constraints + on the couplings. The function solves the following optimization problem: @@ -207,14 +344,26 @@ def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=None, alpha=None, samples weights in the target domain reg : float, optional Regularization term >0 - rank: int, optional. Default is None. (>0) + rank : int, optional. Default is None. (>0) Nonnegative rank of the OT plan. If None, min(ns, nt) is considered. - alpha: int, optional. Default is None. (>0 and <1/r) - Lower bound for the weight vector g. If None, 1e-10 is considered - numItermax : int, optional - Max number of iterations - stopThr : float, optional - Stop threshold on error (>0) + alpha : int, optional. Default is 1e-10. (>0 and <1/r) + Lower bound for the weight vector g. + rescale_cost : bool, optional. Default is False + Rescale the low rank factorization of the sqeuclidean cost matrix + init : str, optional. Default is 'random'. + Initialization strategy for the low rank couplings. 'random', 'deterministic' or 'kmeans' + reg_init : float, optional. Default is 1e-1. (>0) + Regularization term for a 'kmeans' init. If None, 1 is considered. + seed_init : int, optional. Default is 49. (>0) + Random state for a 'random' or 'kmeans' init strategy. + gamma_init : str, optional. Default is "rescale". + Initialization strategy for gamma. 
+        Gamma is the step size of the Mirror Descent optimization scheme used
+        to compute the low-rank couplings (Q, R and g): 'rescale' re-estimates
+        it at each iteration, while 'theory' uses the constant value from
+        "Section 3.4, proposition 4" in the paper.
+    numItermax : int, optional. Default is 2000.
+        Max number of iterations for the Dykstra algorithm
+    stopThr : float, optional. Default is 1e-7.
+        Stop threshold on error (>0) in Dykstra
     warn : bool, optional
-        if True, raises a warning if the algorithm doesn't convergence.
+        if True, raises a warning if the algorithm doesn't converge.
     log : bool, optional
 
 
     Returns
     -------
-    lazy_plan : LazyTensor()
-        OT plan in a LazyTensor object of shape (shape_plan)
-        See :any:`LazyTensor` for more information.
-    value : float
-        Optimal value of the optimization problem
-    value_linear : float
-        Linear OT loss with the optimal OT
     Q : array-like, shape (n_samples_a, r)
         First low-rank matrix decomposition of the OT plan
-    R: array-like, shape (n_samples_b, r)
+    R : array-like, shape (n_samples_b, r)
         Second low-rank matrix decomposition of the OT plan
     g : array-like, shape (r, )
         Weight vector for the low-rank decomposition of the OT plan
+    log : dict (lazy_plan, value and value_linear)
+        log dictionary, returned only if log is True
 
 
     References
     ----------
-    .. [65] Scetbon, M., Cuturi, M., & Peyré, G (2021).
-        "Low-Rank Sinkhorn Factorization" arXiv preprint arXiv:2103.04737.
+    .. [65] Scetbon, M., Cuturi, M., & Peyré, G. (2021).
+        "Low-rank Sinkhorn Factorization". In International Conference on Machine Learning.
 
     """
 
@@ -259,59 +403,70 @@ def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=None, alpha=None,
 
-    r = rank
     if rank is None:
         r = min(ns, nt)
+    else:
+        r = min(ns, nt, rank)
 
-    if alpha is None:
-        alpha = 1e-10
+    if r <= 0:
+        raise ValueError("The rank parameter must be strictly positive")
 
-    # Dykstra algorithm won't converge if 1/rank < alpha (alpha is the lower bound for 1/rank)
-    # (see "Section 3.2: The Low-rank OT Problem (LOT)" in the paper)
+    # Dykstra won't converge if 1/rank < alpha (see "Section 3.2" in the paper)
     if 1 / r < alpha:
         raise ValueError("alpha ({a}) should be smaller than 1/rank ({r}) for the Dykstra algorithm to converge.".format(
-            a=alpha, r=1 / rank))
+            a=alpha, r=1 / r))
 
-    if r <= 0:
-        raise ValueError("The rank parameter cannot have a negative value")
+    # Low rank decomposition of the sqeuclidean cost matrix
+    M1, M2 = compute_lr_sqeuclidean_matrix(X_s, X_t, rescale_cost, nx)
 
-    # Low rank decomposition of the sqeuclidean cost matrix (A, B)
-    M1, M2 = compute_lr_sqeuclidean_matrix(X_s, X_t, nx=None)
+    # Initialize the low rank matrices Q, R, g
+    Q, R, g = _init_lr_sinkhorn(X_s, X_t, a, b, r, init, reg_init, seed_init, nx=nx)
 
-    # Compute gamma (see "Section 3.4, proposition 4" in the paper)
-    L = nx.sqrt(
-        3 * (2 / (alpha**4)) * ((nx.norm(M1) * nx.norm(M2)) ** 2) +
-        (reg + (2 / (alpha**3)) * (nx.norm(M1) * nx.norm(M2))) ** 2
-    )
-    gamma = 1 / (2 * L)
+    # Gamma initialization (see "Section 3.4, proposition 4" in the paper)
+    if gamma_init == "theory":
+        L = nx.sqrt(
+            3 * (2 / (alpha**4)) * ((nx.norm(M1) * nx.norm(M2)) ** 2) +
+            (reg + (2 / (alpha**3)) * (nx.norm(M1) * nx.norm(M2))) ** 2
+        )
+        gamma = 1 / (2 * L)
 
-    # Initialize the low rank matrices Q, R, g
-    Q = nx.ones((ns, r), type_as=a)
-    R = nx.ones((nt, r), type_as=a)
-    g = nx.ones(r, type_as=a)
-    k = 100
+    if gamma_init not in ["rescale", "theory"]:
+        raise (NotImplementedError('Not implemented gamma_init="{}"'.format(gamma_init)))
 
     # -------------------------- Low rank algorithm ------------------------------
     # see "Section 3.3, Algorithm 3 LOT" in the paper
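+    # Each iteration performs one mirror descent step: the gradients of the
+    # objective w.r.t. Q, R and g are evaluated through the low-rank cost
+    # factors (M1, M2), exponentiated into the kernels eps1, eps2 and eps3,
+    # and the coupling constraints are then enforced by Dykstra's algorithm.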
 
-    for ii in range(k):
-        # Compute the C*R dot matrix using the lr decomposition of C
-        CR_ = nx.dot(M2.T, R)
-        CR = nx.dot(M1, CR_)
+    for ii in range(100):
+        # Compute the product C @ R using the low-rank decomposition of C
+        CR = nx.dot(M2.T, R)
+        CR_ = nx.dot(M1, CR)
+        diag_g = (1 / g)[None, :]
+        CR_g = CR_ * diag_g
 
-        # Compute the C.t * Q dot matrix using the lr decomposition of C
-        CQ_ = nx.dot(M1.T, Q)
-        CQ = nx.dot(M2, CQ_)
+        # Compute the product C.T @ Q using the low-rank decomposition of C
+        CQ = nx.dot(M1.T, Q)
+        CQ_ = nx.dot(M2, CQ)
+        CQ_g = CQ_ * diag_g
 
-        diag_g = (1 / g)[None, :]
+        # Compute omega, the diagonal of Q.T @ C @ R
+        omega = nx.diag(nx.dot(Q.T, CR_))
+
+        # Rescale gamma at each iteration
+        if gamma_init == "rescale":
+            norm_1 = nx.max(nx.abs(CR_g + reg * nx.log(Q))) ** 2
+            norm_2 = nx.max(nx.abs(CQ_g + reg * nx.log(R))) ** 2
+            norm_3 = nx.max(nx.abs(omega * diag_g)) ** 2
+            gamma = 10 / max(norm_1, norm_2, norm_3)
 
-        eps1 = nx.exp(-gamma * (CR * diag_g) - ((gamma * reg) - 1) * nx.log(Q))
-        eps2 = nx.exp(-gamma * (CQ * diag_g) - ((gamma * reg) - 1) * nx.log(R))
-        omega = nx.diag(nx.dot(Q.T, CR))
-        eps3 = nx.exp(gamma * omega / (g**2) - (gamma * reg - 1) * nx.log(g))
+        eps1 = nx.exp(-gamma * CR_g - ((gamma * reg) - 1) * nx.log(Q))
+        eps2 = nx.exp(-gamma * CQ_g - ((gamma * reg) - 1) * nx.log(R))
+        eps3 = nx.exp((gamma * omega / (g**2)) - (gamma * reg - 1) * nx.log(g))
 
+        # LR Dykstra algorithm
         Q, R, g = _LR_Dysktra(
             eps1, eps2, eps3, a, b, alpha, stopThr, numItermax, warn, nx
         )
         Q = Q + 1e-16
         R = R + 1e-16
+        g = g + 1e-16
 
     # ----------------- Compute lazy_plan, value and value_linear ------------------
     # see "Section 3.2: The Low-rank OT Problem" in the paper
 
@@ -324,7 +479,7 @@ def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=None, alpha=None,
     v2 = nx.dot(R, (v1.T * diag_g).T)
     value_linear = nx.sum(nx.diag(nx.dot(M2.T, v2)))
 
-    # Compute value with entropy reg (entropy of Q, R, g must be computed separatly, see "Section 3.2" in the paper)
+    # Compute value with entropy reg (see "Section 3.2" in the paper)
     reg_Q = nx.sum(Q * nx.log(Q + 1e-16))  # entropy for Q
     reg_g = nx.sum(g * nx.log(g + 1e-16))  # entropy for g
     reg_R = nx.sum(R * nx.log(R + 1e-16))  # entropy for R
diff --git a/ot/solvers.py b/ot/solvers.py
index 40a03e974..c4c0c79ed 100644
--- a/ot/solvers.py
+++ b/ot/solvers.py
@@ -1173,6 +1173,10 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t
         Unbalanced optimal transport through non-negative penalized
         linear regression. NeurIPS.
 
+    .. [65] Scetbon, M., Cuturi, M., & Peyré, G. (2021).
+        "Low-rank Sinkhorn Factorization". In International Conference on
+        Machine Learning.
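+
+    A minimal sketch of the new 'lowrank' path through this solver
+    (an illustration; data and parameter values are arbitrary)::
+
+        res = ot.solve_sample(X_a, X_b, method="lowrank", rank=10, reg=0.1)
+        P = res.lazy_plan[:]  # materialize the low-rank plan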
+
 
     """
 
@@ -1255,13 +1259,13 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t
             raise (NotImplementedError('Not implemented metric="{}"'.format(metric)))
 
         if max_iter is None:
-            max_iter = 1000
+            max_iter = 2000
         if tol is None:
-            tol = 1e-9
+            tol = 1e-7
         if reg is None:
             reg = 0
 
-        Q, R, g, log = lowrank_sinkhorn(X_a, X_b, reg=reg, a=a, b=b, numItermax=max_iter, stopThr=tol, log=True)
+        Q, R, g, log = lowrank_sinkhorn(X_a, X_b, rank=rank, reg=reg, a=a, b=b, numItermax=max_iter, stopThr=tol, log=True)
         value = log['value']
         value_linear = log['value_linear']
         lazy_plan = log['lazy_plan']
diff --git a/test/test_lowrank.py b/test/test_lowrank.py
index 65f76a77b..60b2d633f 100644
--- a/test/test_lowrank.py
+++ b/test/test_lowrank.py
@@ -7,6 +7,7 @@
 import ot
 import numpy as np
 import pytest
+from ot.lowrank import sklearn_import  # check whether sklearn is installed
 
 
 def test_compute_lr_sqeuclidean_matrix():
@@ -15,7 +16,7 @@ def test_compute_lr_sqeuclidean_matrix():
     X_s = np.reshape(1.0 * np.arange(2 * n), (n, 2))
     X_t = np.reshape(1.0 * np.arange(2 * n), (n, 2))
 
-    M1, M2 = ot.lowrank.compute_lr_sqeuclidean_matrix(X_s, X_t)
+    M1, M2 = ot.lowrank.compute_lr_sqeuclidean_matrix(X_s, X_t, rescale_cost=False)
     M = ot.dist(X_s, X_t, metric="sqeuclidean")  # original cost matrix
 
     np.testing.assert_allclose(np.dot(M1, M2.T), M, atol=1e-05)
@@ -30,7 +31,7 @@ def test_lowrank_sinkhorn():
     X_s = np.reshape(1.0 * np.arange(n), (n, 1))
     X_t = np.reshape(1.0 * np.arange(n), (n, 1))
 
-    Q, R, g, log = ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, reg=0.1, log=True)
+    Q, R, g, log = ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, reg=0.1, log=True, rescale_cost=False)
     P = log["lazy_plan"][:]
     value_linear = log["value_linear"]
 
@@ -52,6 +53,30 @@ def test_lowrank_sinkhorn():
         ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, reg=0.1, stopThr=0, numItermax=1)
 
 
+@pytest.mark.parametrize(("init"), ("random", "deterministic", "kmeans"))
+def test_lowrank_sinkhorn_init(init):
+    # test the different init strategies of the low rank solver
+    n = 100
+    a = ot.unif(n)
+    b = ot.unif(n)
+
+    X_s = np.reshape(1.0 * np.arange(n), (n, 1))
+    X_t = np.reshape(1.0 * np.arange(n), (n, 1))
+
+    # the 'kmeans' init requires scikit-learn: expect an ImportError without it
+    if init in ["random", "deterministic"] or ((init == "kmeans") and sklearn_import):
+        Q, R, g, log = ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, reg=0.1, init=init, log=True)
+        P = log["lazy_plan"][:]
+
+        # check constraints for P
+        np.testing.assert_allclose(a, P.sum(1), atol=1e-05)
+        np.testing.assert_allclose(b, P.sum(0), atol=1e-05)
+
+    else:
+        with pytest.raises(ImportError):
+            Q, R, g = ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, reg=0.1, init=init)
+
+
 @pytest.mark.parametrize(("alpha, rank"), ((0.8, 2), (0.5, 3), (0.2, 6)))
 def test_lowrank_sinkhorn_alpha_error(alpha, rank):
-    # Test warning for value of alpha
+    # test that an invalid (alpha, rank) pair raises a ValueError
     n = 100
     a = ot.unif(n)
     b = ot.unif(n)
 
     X_s = np.reshape(1.0 * np.arange(0, n), (n, 1))
     X_t = np.reshape(1.0 * np.arange(0, n), (n, 1))
 
     with pytest.raises(ValueError):
-        ot.lowrank.lowrank_sinkhorn(
-            X_s, X_t, a, b, reg=0.1, rank=rank, alpha=alpha, warn=False
-        )
+        ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, reg=0.1, rank=rank, alpha=alpha, warn=False)
+
+
+@pytest.mark.parametrize(("gamma_init"), ("rescale", "theory"))
+def test_lowrank_sinkhorn_gamma_init(gamma_init):
+    # test the different gamma_init strategies of the low rank solver
+    n = 100
+    a = ot.unif(n)
+    b = ot.unif(n)
+
+    X_s = np.reshape(1.0 * np.arange(n), (n, 1))
+    X_t = np.reshape(1.0 * np.arange(n), (n, 1))
+
+    Q, R, g, log = ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b,
+                                               reg=0.1, gamma_init=gamma_init, log=True)
+    P = log["lazy_plan"][:]
+
+    # check constraints for P
+    np.testing.assert_allclose(a, P.sum(1), atol=1e-05)
+    np.testing.assert_allclose(b, P.sum(0), atol=1e-05)
 
 
 @pytest.skip_backend('tf')
diff --git a/test/test_solvers.py b/test/test_solvers.py
index 343220c45..164989811 100644
--- a/test/test_solvers.py
+++ b/test/test_solvers.py
@@ -30,7 +30,7 @@
     {'method': 'gaussian'},
     {'method': 'gaussian', 'reg': 1},
     {'method': 'factored', 'rank': 10},
-    {'method': 'lowrank', 'reg': 0.1}
+    {'method': 'lowrank', 'rank': 10}
 ]
 
 lst_parameters_solve_sample_NotImplemented = [
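
A minimal end-to-end sketch of the API exercised by this patch (illustrative
data only; the entry points are the ones added or modified above):

    import numpy as np
    import ot

    n, m = 100, 120
    a, b = ot.unif(n), ot.unif(m)
    X_s = np.arange(n, dtype=np.float64).reshape(-1, 1)
    X_t = np.arange(m, dtype=np.float64).reshape(-1, 1)

    # direct call to the low rank solver
    Q, R, g, log = ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, reg=0.1, rank=10, log=True)

    # same path through the generic sampled-OT interface
    res = ot.solve_sample(X_s, X_t, a, b, method="lowrank", reg=0.1, rank=10)
    P = res.lazy_plan[:]

    # the plan satisfies the marginal constraints checked in the tests above
    np.testing.assert_allclose(P.sum(1), a, atol=1e-5)
    np.testing.assert_allclose(P.sum(0), b, atol=1e-5)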