From f49f6b4f34ddd3a2313e1df00c487bd7f47df845 Mon Sep 17 00:00:00 2001 From: laudavid Date: Tue, 24 Oct 2023 14:54:42 +0200 Subject: [PATCH 01/22] new file for lr sinkhorn --- ot/lowrank.py | 171 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 171 insertions(+) create mode 100644 ot/lowrank.py diff --git a/ot/lowrank.py b/ot/lowrank.py new file mode 100644 index 000000000..ba46cd1ed --- /dev/null +++ b/ot/lowrank.py @@ -0,0 +1,171 @@ +################################################################################################################# +############################################## WORK IN PROGRESS ################################################# +################################################################################################################# + + +from ot.utils import unif, list_to_array +from ot.backend import get_backend +from ot.datasets import make_1D_gauss as gauss + + + +################################## LR-DYSKTRA ALGORITHM ########################################## + +def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, dykstra_w): + """ + Implementation of the Dykstra algorithm for low rank Sinkhorn + """ + + # get dykstra parameters + q3_1, q3_2, v1_, v2_, q1, q2 = dykstra_w + + # POT backend + eps1, eps2, eps3, p1, p2 = list_to_array(eps1, eps2, eps3, p1, p2) + q3_1, q3_2, v1_, v2_, q1, q2 = list_to_array(q3_1, q3_2, v1_, v2_, q1, q2) + + nx = get_backend(eps1, eps2, eps3, p1, p2, q3_1, q3_2, v1_, v2_, q1, q2) + + # ------- Dykstra algorithm ------ + g_ = eps3 + + u1 = p1 / nx.dot(eps1, v1_) + u2 = p2 / nx.dot(eps2, v2_) + + g = nx.maximum(alpha, g_ * q3_1) + q3_1 = (g_ * q3_1) / g + g_ = g + + prod1 = ((v1_ * q1) * nx.dot(eps1.T, u1)) + prod2 = ((v2_ * q2) * nx.dot(eps2.T, u2)) + g = (g_ * q3_2 * prod1 * prod2)**(1/3) + + v1 = g / nx.dot(eps1.T,u1) + v2 = g / nx.dot(eps2.T,u2) + + q1 = (v1_ * q1) / v1 + q2 = (v2_ * q2) / v2 + q3_2 = (g_ * q3_2) / g + + v1_, v2_ = v1, v2 + g_ = g + + # Compute error + err1 = nx.sum(nx.abs(u1 * (eps1 @ v1) - p1)) + err2 = nx.sum(nx.abs(u2 * (eps2 @ v2) - p2)) + err = err1 + err2 + + # Compute low rank matrices Q, R + Q = u1[:,None] * eps1 * v1[None,:] + R = u2[:,None] * eps2 * v2[None,:] + + dykstra_w = [q3_1, q3_2, v1_, v2_, q1, q2] + + return Q, R, g, err, dykstra_w + + + +#################################### LOW RANK SINKHORN ALGORITHM ######################################### + + +def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=4, metric='sqeuclidean', alpha=1e-10, numIterMax=10000, stopThr=1e-20): + r''' + Solve the entropic regularization optimal transport problem under low-nonnegative low rank constraints + + Parameters + ---------- + X_s : array-like, shape (n_samples_a, dim) + samples in the source domain + X_t : array-like, shape (n_samples_b, dim) + samples in the target domain + reg : float + Regularization term >0 + a : array-like, shape (n_samples_a,) + samples weights in the source domain + b : array-like, shape (n_samples_b,) + samples weights in the target domain + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (>0) + + Returns + ------- + Q : array-like, shape (n_samples_a, r) + First low-rank matrix decomposition of the OT plan + R: array-like, shape (n_samples_b, r) + Second low-rank matrix decomposition of the OT plan + g : array-like, shape (r, ) + ... + + ''' + + X_s, X_t = list_to_array(X_s, X_t) + nx = get_backend(X_s, X_t) + + ns, nt = X_s.shape[0], X_t.shape[0] + if a is None: + a = nx.from_numpy(unif(ns), type_as=X_s) + if b is None: + b = nx.from_numpy(unif(nt), type_as=X_s) + + M = ot.dist(X_s,X_t, metric=metric) + + # Compute rank + r = min(ns, nt, r) + + # Compute gamma + L = nx.sqrt((2/(alpha**4))*nx.norm(M)**2 + (reg + (2/(alpha**3))*nx.norm(M))**2) + gamma = 1/(2*L) + + # Initialisation + Q, R, g = nx.ones((ns,r)), nx.ones((nt,r)), nx.ones(r) + q3_1, q3_2 = nx.ones(r), nx.ones(r) + v1_, v2_ = nx.ones(r), nx.ones(r) + q1, q2 = nx.ones(r), nx.ones(r) + dykstra_w = [q3_1, q3_2, v1_, v2_, q1, q2] + n_iter = 0 + err = 1 + + while n_iter < numIterMax: + if err > stopThr: + n_iter = n_iter + 1 + + CR = nx.dot(M,R) + C_t_Q = nx.dot(M.T,Q) + diag_g = (1/g)[:,None] + + eps1 = nx.exp(-gamma*(nx.dot(CR,diag_g)) - ((gamma*reg)-1)*nx.log(Q)) + eps2 = nx.exp(-gamma*(nx.dot(C_t_Q,diag_g)) - ((gamma*reg)-1)*nx.log(R)) + omega = nx.diag(nx.dot(Q.T, CR)) + eps3 = nx.exp(gamma*omega/(g**2) - (gamma*reg - 1)*nx.log(g)) + + Q, R, g, err, dykstra_w = LR_Dysktra(eps1, eps2, eps3, a, b, alpha, dykstra_w) + else: + break + + return Q, R, g + + + + + +############################################################################ +## Test with X_s, X_t from ot.datasets +############################################################################# + +import numpy as np +import ot + +Xs, _ = ot.datasets.make_data_classif('3gauss', n=1000) +Xt, _ = ot.datasets.make_data_classif('3gauss2', n=1500) + + +Q, R, g = lowrank_sinkhorn(Xs,Xt,reg=0.1) +M = ot.dist(Xs,Xt) +P = np.dot(Q,np.dot(np.diag(1/g),R.T)) + +print(np.sum(P)) + + + + From 3c4b50fdb660f27cc080618edb664d17086d93a9 Mon Sep 17 00:00:00 2001 From: laudavid Date: Tue, 24 Oct 2023 16:47:21 +0200 Subject: [PATCH 02/22] lr sinkhorn, solve_sample, OTResultLazy --- ot/lowrank.py | 40 +++++++------ ot/solvers.py | 161 ++++++++++++++++++++++++++++++++++++++++++++++++++ ot/utils.py | 90 ++++++++++++++++++++++++++++ 3 files changed, 272 insertions(+), 19 deletions(-) diff --git a/ot/lowrank.py b/ot/lowrank.py index ba46cd1ed..a1c73bdf3 100644 --- a/ot/lowrank.py +++ b/ot/lowrank.py @@ -2,8 +2,10 @@ ############################################## WORK IN PROGRESS ################################################# ################################################################################################################# +## Implementation of the LR-Dykstra algorithm and low rank sinkhorn algorithms -from ot.utils import unif, list_to_array + +from ot.utils import unif, list_to_array, dist from ot.backend import get_backend from ot.datasets import make_1D_gauss as gauss @@ -11,13 +13,13 @@ ################################## LR-DYSKTRA ALGORITHM ########################################## -def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, dykstra_w): +def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, dykstra_p): """ Implementation of the Dykstra algorithm for low rank Sinkhorn """ # get dykstra parameters - q3_1, q3_2, v1_, v2_, q1, q2 = dykstra_w + q3_1, q3_2, v1_, v2_, q1, q2 = dykstra_p # POT backend eps1, eps2, eps3, p1, p2 = list_to_array(eps1, eps2, eps3, p1, p2) @@ -58,18 +60,18 @@ def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, dykstra_w): Q = u1[:,None] * eps1 * v1[None,:] R = u2[:,None] * eps2 * v2[None,:] - dykstra_w = [q3_1, q3_2, v1_, v2_, q1, q2] + dykstra_p = [q3_1, q3_2, v1_, v2_, q1, q2] - return Q, R, g, err, dykstra_w + return Q, R, g, err, dykstra_p #################################### LOW RANK SINKHORN ALGORITHM ######################################### -def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=4, metric='sqeuclidean', alpha=1e-10, numIterMax=10000, stopThr=1e-20): +def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=2, metric='sqeuclidean', alpha=1e-10, numIterMax=10000, stopThr=1e-20): r''' - Solve the entropic regularization optimal transport problem under low-nonnegative low rank constraints + Solve the entropic regularization optimal transport problem under low-nonnegative rank constraints on the feasible couplings. Parameters ---------- @@ -95,7 +97,7 @@ def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=4, metric='sqeuclidean', R: array-like, shape (n_samples_b, r) Second low-rank matrix decomposition of the OT plan g : array-like, shape (r, ) - ... + Third low-rank matrix decomposition of the OT plan ''' @@ -108,7 +110,7 @@ def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=4, metric='sqeuclidean', if b is None: b = nx.from_numpy(unif(nt), type_as=X_s) - M = ot.dist(X_s,X_t, metric=metric) + M = dist(X_s,X_t, metric=metric) # Compute rank r = min(ns, nt, r) @@ -122,7 +124,7 @@ def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=4, metric='sqeuclidean', q3_1, q3_2 = nx.ones(r), nx.ones(r) v1_, v2_ = nx.ones(r), nx.ones(r) q1, q2 = nx.ones(r), nx.ones(r) - dykstra_w = [q3_1, q3_2, v1_, v2_, q1, q2] + dykstra_p = [q3_1, q3_2, v1_, v2_, q1, q2] n_iter = 0 err = 1 @@ -139,7 +141,7 @@ def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=4, metric='sqeuclidean', omega = nx.diag(nx.dot(Q.T, CR)) eps3 = nx.exp(gamma*omega/(g**2) - (gamma*reg - 1)*nx.log(g)) - Q, R, g, err, dykstra_w = LR_Dysktra(eps1, eps2, eps3, a, b, alpha, dykstra_w) + Q, R, g, err, dykstra_p = LR_Dysktra(eps1, eps2, eps3, a, b, alpha, dykstra_p) else: break @@ -153,18 +155,18 @@ def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=4, metric='sqeuclidean', ## Test with X_s, X_t from ot.datasets ############################################################################# -import numpy as np -import ot +# import numpy as np +# import ot -Xs, _ = ot.datasets.make_data_classif('3gauss', n=1000) -Xt, _ = ot.datasets.make_data_classif('3gauss2', n=1500) +# Xs, _ = ot.datasets.make_data_classif('3gauss', n=1000) +# Xt, _ = ot.datasets.make_data_classif('3gauss2', n=1500) -Q, R, g = lowrank_sinkhorn(Xs,Xt,reg=0.1) -M = ot.dist(Xs,Xt) -P = np.dot(Q,np.dot(np.diag(1/g),R.T)) +# Q, R, g = lowrank_sinkhorn(Xs,Xt,reg=0.1) +# M = ot.dist(Xs,Xt) +# P = np.dot(Q,np.dot(np.diag(1/g),R.T)) -print(np.sum(P)) +# print(np.sum(P)) diff --git a/ot/solvers.py b/ot/solvers.py index 0313cf588..9c2746c25 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -848,3 +848,164 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, value_linear=value_linear, value_quad=value_quad, plan=plan, status=status, backend=nx) return res + + + + + + +################################## WORK IN PROGRESS ##################################### + +## Implementation of the ot.solve_sample function +## Function isn't complete, still work in progress for reg == 0 / reg is None case (and unbalanced cases) + + +from .utils import unif, list_to_array, dist, OTResultLazy +from .bregman import empirical_sinkhorn + + +def solve_sample(X_s, X_t, a=None, b=None, metric='sqeuclidean', reg=None, reg_type="KL", unbalanced=None, + unbalanced_type='KL', is_Lazy=False, batch_size=None, n_threads=1, max_iter=None, plan_init=None, + potentials_init=None, tol=None, verbose=False): + + r"""Solve the discrete optimal transport problem using the samples in the source and target domains. + It returns either a :any:`OTResult` or :any:`OTResultLazy` object. + + The function solves the following general optimal transport problem + + .. math:: + \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_r R(\mathbf{T}) + + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b}) + + The regularization is selected with `reg` (:math:`\lambda_r`) and `reg_type`. By + default ``reg=None`` and there is no regularization. The unbalanced marginal + penalization can be selected with `unbalanced` (:math:`\lambda_u`) and + `unbalanced_type`. By default ``unbalanced=None`` and the function + solves the exact optimal transport problem (respecting the marginals). + + Parameters + ---------- + X_s : array-like, shape (n_samples_a, dim) + samples in the source domain + X_t : array-like, shape (n_samples_b, dim) + samples in the target domain + a : array-like, shape (dim_a,), optional + Samples weights in the source domain (default is uniform) + b : array-like, shape (dim_b,), optional + Samples weights in the source domain (default is uniform) + reg : float, optional + Regularization weight :math:`\lambda_r`, by default None (no reg., exact + OT) + reg_type : str, optional + Type of regularization :math:`R` either "KL", "L2", "entropy", by default "KL" + unbalanced : float, optional + Unbalanced penalization weight :math:`\lambda_u`, by default None + (balanced OT) + unbalanced_type : str, optional + Type of unbalanced penalization function :math:`U` either "KL", "L2", "TV", by default "KL" + is_Lazy : bool, optional + Return :any:`OTResultlazy` object to reduce memory cost when True, by default False + n_threads : int, optional + Number of OMP threads for exact OT solver, by default 1 + max_iter : int, optional + Maximum number of iteration, by default None (default values in each solvers) + plan_init : array_like, shape (dim_a, dim_b), optional + Initialization of the OT plan for iterative methods, by default None + potentials_init : (array_like(dim_a,),array_like(dim_b,)), optional + Initialization of the OT dual potentials for iterative methods, by default None + tol : _type_, optional + Tolerance for solution precision, by default None (default values in each solvers) + verbose : bool, optional + Print information in the solver, by default False + + Returns + ------- + + res_lazy : OTResultLazy() + Result of the optimization problem. This class only returns a partial OT plan and the OT dual potentials to reduce memory costs. + The information can be obtained as follows: + + - res.lazy_plan : OT plan computed on a subsample of X_s and X_t :math:`\mathbf{T}` + - res.potentials : OT dual potentials + + See :any:`OTResultLazy` for more information. + + res : OTResult() + Result of the optimization problem. The information can be obtained as follows: + + - res.plan : OT plan :math:`\mathbf{T}` + - res.potentials : OT dual potentials + - res.value : Optimal value of the optimization problem + - res.value_linear : Linear OT loss with the optimal OT plan + + See :any:`OTResult` for more information. + + + """ + + X_s, X_t = list_to_array(X_s,X_t) + + # detect backend + arr = [X_s,X_t] + if a is not None: + arr.append(a) + if b is not None: + arr.append(b) + nx = get_backend(*arr) + + # create uniform weights if not given + ns, nt = X_s.shape[0], X_t.shape[0] + if a is None: + a = nx.from_numpy(unif(ns), type_as=X_s) + if b is None: + b = nx.from_numpy(unif(nt), type_as=X_s) + + # default values for solutions + potentials = None + lazy_plan = None + + if max_iter is None: + max_iter = 1000 + if tol is None: + tol = 1e-9 + if batch_size is None: + batch_size = 100 + + if is_Lazy: + ################# WIP #################### + if reg is None or reg == 0: # EMD solver for isLazy ? + if unbalanced is None: # not sure "unbalanced" parameter is needed here ? (since we won't compute value) + pass + elif unbalanced_type.lower() in ['kl', 'l2']: + pass + elif unbalanced_type.lower() == 'tv': + pass + pass + ############################################# + + else: + # compute potentials + u, v, log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric='sqeuclidean', numIterMax=max_iter, stopThr=tol, + isLazy=True, batchSize=batch_size, verbose=verbose, log=True) + potentials = (log["u"], log["v"]) + + # compute lazy_plan + ns_lazy, nt_lazy = 100, 100 # size of the lazy_plan (subplan) + M = dist(X_s[:ns_lazy,:], X_t[:nt_lazy,:], metric) + K = nx.exp(M / (-reg)) + lazy_plan = u[:ns_lazy].reshape((-1, 1)) * K * v[:nt_lazy].reshape((1, -1)) + + res_lazy = OTResultLazy(potentials=potentials, lazy_plan=lazy_plan, backend=nx) + return res_lazy + + else: + # compute cost matrix M and use solve function + M = dist(X_s, X_t, metric) + + res = solve(M, a, b, reg, reg_type, unbalanced, unbalanced_type, n_threads, max_iter, plan_init, potentials_init, tol, verbose) + return res + + + + diff --git a/ot/utils.py b/ot/utils.py index 8cbb0db25..d570b9f30 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -938,3 +938,93 @@ def citation(self): url = {http://jmlr.org/papers/v22/20-451.html} } """ + + + +############################## WORK IN PROGRESS #################################### + +## Implementation of the OTResultLazy class for ot.solve_sample() with potentials and lazy_plan as attributes + +class OTResultLazy: + def __init__(self, potentials=None, lazy_plan=None, backend=None): + + self._potentials = potentials + self._lazy_plan = lazy_plan + self._backend = backend if backend is not None else NumpyBackend() + + + # Dual potentials -------------------------------------------- + + def __repr__(self): + s = 'OTResultLazy(' + if self._lazy_plan is not None: + s += 'lazy_plan={}(shape={}),'.format(self._lazy_plan.__class__.__name__, self._lazy_plan.shape) + + if s[-1] != '(': + s = s[:-1] + ')' + else: + s = s + ')' + return s + + @property + def potentials(self): + """Dual potentials, i.e. Lagrange multipliers for the marginal constraints. + + This pair of arrays has the same shape, numerical type + and properties as the input weights "a" and "b". + """ + if self._potentials is not None: + return self._potentials + else: + raise NotImplementedError() + + @property + def potential_a(self): + """First dual potential, associated to the "source" measure "a".""" + if self._potentials is not None: + return self._potentials[0] + else: + raise NotImplementedError() + + @property + def potential_b(self): + """Second dual potential, associated to the "target" measure "b".""" + if self._potentials is not None: + return self._potentials[1] + else: + raise NotImplementedError() + + # Transport plan ------------------------------------------- + @property + def lazy_plan(self): + """A subset of the Transport plan, encoded as a dense array.""" + + if self._lazy_plan is not None: + return self._lazy_plan + else: + raise NotImplementedError() + + @property + def citation(self): + """Appropriate citation(s) for this result, in plain text and BibTex formats.""" + + # The string below refers to the POT library: + # successor methods may concatenate the relevant references + # to the original definitions, solvers and underlying numerical backends. + return """POT library: + + POT Python Optimal Transport library, Journal of Machine Learning Research, 22(78):1−8, 2021. + Website: https://pythonot.github.io/ + Rémi Flamary, Nicolas Courty, Alexandre Gramfort, Mokhtar Z. Alaya, Aurélie Boisbunon, Stanislas Chambon, Laetitia Chapel, Adrien Corenflos, Kilian Fatras, Nemo Fournier, Léo Gautheron, Nathalie T.H. Gayraud, Hicham Janati, Alain Rakotomamonjy, Ievgen Redko, Antoine Rolet, Antony Schutz, Vivien Seguy, Danica J. Sutherland, Romain Tavenard, Alexander Tong, Titouan Vayer; + + @article{flamary2021pot, + author = {R{\'e}mi Flamary and Nicolas Courty and Alexandre Gramfort and Mokhtar Z. Alaya and Aur{\'e}lie Boisbunon and Stanislas Chambon and Laetitia Chapel and Adrien Corenflos and Kilian Fatras and Nemo Fournier and L{\'e}o Gautheron and Nathalie T.H. Gayraud and Hicham Janati and Alain Rakotomamonjy and Ievgen Redko and Antoine Rolet and Antony Schutz and Vivien Seguy and Danica J. Sutherland and Romain Tavenard and Alexander Tong and Titouan Vayer}, + title = {{POT}: {Python} {Optimal} {Transport}}, + journal = {Journal of Machine Learning Research}, + year = {2021}, + volume = {22}, + number = {78}, + pages = {1-8}, + url = {http://jmlr.org/papers/v22/20-451.html} + } + """ \ No newline at end of file From 3034e575c55d2ce56499be6849e1906fe52f0573 Mon Sep 17 00:00:00 2001 From: laudavid Date: Wed, 25 Oct 2023 17:39:08 +0200 Subject: [PATCH 03/22] add test functions + small modif lr_sin/solve_sample --- ot/lowrank.py | 97 ++++++++++++++++++++++++++++------------- ot/solvers.py | 47 +++++++++++--------- test/test_lowrank.py | 84 ++++++++++++++++++++++++++++++++++++ test/test_solvers.py | 100 +++++++++++++++++++++++++++++++++++++++++++ test/test_utils.py | 27 ++++++++++++ 5 files changed, 304 insertions(+), 51 deletions(-) create mode 100644 test/test_lowrank.py diff --git a/ot/lowrank.py b/ot/lowrank.py index a1c73bdf3..22ff8b754 100644 --- a/ot/lowrank.py +++ b/ot/lowrank.py @@ -4,10 +4,9 @@ ## Implementation of the LR-Dykstra algorithm and low rank sinkhorn algorithms - -from ot.utils import unif, list_to_array, dist -from ot.backend import get_backend -from ot.datasets import make_1D_gauss as gauss +import warnings +from .utils import unif, list_to_array, dist +from .backend import get_backend @@ -15,7 +14,7 @@ def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, dykstra_p): """ - Implementation of the Dykstra algorithm for low rank Sinkhorn + Implementation of the Dykstra algorithm for low rank sinkhorn """ # get dykstra parameters @@ -69,9 +68,12 @@ def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, dykstra_p): #################################### LOW RANK SINKHORN ALGORITHM ######################################### -def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=2, metric='sqeuclidean', alpha=1e-10, numIterMax=10000, stopThr=1e-20): +def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=2, metric='sqeuclidean', alpha="auto", + numItermax=10000, stopThr=1e-9, warn=True, verbose=False): r''' - Solve the entropic regularization optimal transport problem under low-nonnegative rank constraints on the feasible couplings. + Solve the entropic regularization optimal transport problem under low-nonnegative rank constraints. + + This function returns the two low-rank matrix decomposition of the OT plan (Q,R), as well as the weight vector g. Parameters ---------- @@ -79,17 +81,22 @@ def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=2, metric='sqeuclidean', samples in the source domain X_t : array-like, shape (n_samples_b, dim) samples in the target domain - reg : float - Regularization term >0 a : array-like, shape (n_samples_a,) samples weights in the source domain b : array-like, shape (n_samples_b,) samples weights in the target domain + reg : float, optional + Regularization term >0 + rank: int, optional + Nonnegative rank of the OT plan + alpha: int, optional + Lower bound for the weight vector g (>0 and <1/r) numItermax : int, optional Max number of iterations stopThr : float, optional Stop threshold on error (>0) + Returns ------- Q : array-like, shape (n_samples_a, r) @@ -97,7 +104,14 @@ def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=2, metric='sqeuclidean', R: array-like, shape (n_samples_b, r) Second low-rank matrix decomposition of the OT plan g : array-like, shape (r, ) - Third low-rank matrix decomposition of the OT plan + Weight vector for the low-rank decomposition of the OT plan + + + References + ---------- + + .. Scetbon, M., Cuturi, M., & Peyré, G (2021). + Low-Rank Sinkhorn Factorization. arXiv preprint arXiv:2103.04737. ''' @@ -110,13 +124,22 @@ def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=2, metric='sqeuclidean', if b is None: b = nx.from_numpy(unif(nt), type_as=X_s) + # Compute cost matrix M = dist(X_s,X_t, metric=metric) - + # Compute rank - r = min(ns, nt, r) + rank = min(ns, nt, rank) + r = rank + + if alpha == 'auto': + alpha = 1.0 / (r + 1) + + if (1/r < alpha) or (alpha < 0): + warnings.warn("The provided alpha value might lead to instabilities.") + # Compute gamma - L = nx.sqrt((2/(alpha**4))*nx.norm(M)**2 + (reg + (2/(alpha**3))*nx.norm(M))**2) + L = nx.sqrt((2/(alpha**4))*(nx.norm(M)**2) + (reg + (2/(alpha**3))*(nx.norm(M))**2)) gamma = 1/(2*L) # Initialisation @@ -125,25 +148,34 @@ def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=2, metric='sqeuclidean', v1_, v2_ = nx.ones(r), nx.ones(r) q1, q2 = nx.ones(r), nx.ones(r) dykstra_p = [q3_1, q3_2, v1_, v2_, q1, q2] - n_iter = 0 err = 1 - while n_iter < numIterMax: - if err > stopThr: - n_iter = n_iter + 1 - - CR = nx.dot(M,R) - C_t_Q = nx.dot(M.T,Q) - diag_g = (1/g)[:,None] - - eps1 = nx.exp(-gamma*(nx.dot(CR,diag_g)) - ((gamma*reg)-1)*nx.log(Q)) - eps2 = nx.exp(-gamma*(nx.dot(C_t_Q,diag_g)) - ((gamma*reg)-1)*nx.log(R)) - omega = nx.diag(nx.dot(Q.T, CR)) - eps3 = nx.exp(gamma*omega/(g**2) - (gamma*reg - 1)*nx.log(g)) - - Q, R, g, err, dykstra_p = LR_Dysktra(eps1, eps2, eps3, a, b, alpha, dykstra_p) - else: + for ii in range(numItermax): + CR = nx.dot(M,R) + C_t_Q = nx.dot(M.T,Q) + diag_g = (1/g)[:,None] + + eps1 = nx.exp(-gamma*(nx.dot(CR,diag_g)) - ((gamma*reg)-1)*nx.log(Q)) + eps2 = nx.exp(-gamma*(nx.dot(C_t_Q,diag_g)) - ((gamma*reg)-1)*nx.log(R)) + omega = nx.diag(nx.dot(Q.T, CR)) + eps3 = nx.exp(gamma*omega/(g**2) - (gamma*reg - 1)*nx.log(g)) + + Q, R, g, err, dykstra_p = LR_Dysktra(eps1, eps2, eps3, a, b, alpha, dykstra_p) + + if err < stopThr: break + + if verbose: + if ii % 200 == 0: + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(ii, err)) + + else: + if warn: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") return Q, R, g @@ -161,8 +193,13 @@ def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=2, metric='sqeuclidean', # Xs, _ = ot.datasets.make_data_classif('3gauss', n=1000) # Xt, _ = ot.datasets.make_data_classif('3gauss2', n=1500) +# ns = Xs.shape[0] +# nt = Xt.shape[0] + +# a = unif(ns) +# b = unif(nt) -# Q, R, g = lowrank_sinkhorn(Xs,Xt,reg=0.1) +# Q, R, g = lowrank_sinkhorn(Xs, Xt, reg=0.1, metric='euclidean', verbose=True, numItermax=100) # M = ot.dist(Xs,Xt) # P = np.dot(Q,np.dot(np.diag(1/g),R.T)) diff --git a/ot/solvers.py b/ot/solvers.py index 9c2746c25..c176969ca 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -926,7 +926,7 @@ def solve_sample(X_s, X_t, a=None, b=None, metric='sqeuclidean', reg=None, reg_t Result of the optimization problem. This class only returns a partial OT plan and the OT dual potentials to reduce memory costs. The information can be obtained as follows: - - res.lazy_plan : OT plan computed on a subsample of X_s and X_t :math:`\mathbf{T}` + - res.lazy_plan : OT plan computed on a subsample of X_s and X_t - res.potentials : OT dual potentials See :any:`OTResultLazy` for more information. @@ -975,29 +975,34 @@ def solve_sample(X_s, X_t, a=None, b=None, metric='sqeuclidean', reg=None, reg_t if is_Lazy: ################# WIP #################### if reg is None or reg == 0: # EMD solver for isLazy ? - if unbalanced is None: # not sure "unbalanced" parameter is needed here ? (since we won't compute value) - pass - elif unbalanced_type.lower() in ['kl', 'l2']: - pass - elif unbalanced_type.lower() == 'tv': - pass - pass + + if unbalanced is None: # balanced EMD solver for isLazy ? + raise (NotImplementedError('Not implemented balanced with no regularization')) + + else: + raise (NotImplementedError('Not implemented unbalanced_type="{}" with no regularization'.format(unbalanced_type))) + + ############################################# else: - # compute potentials - u, v, log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric='sqeuclidean', numIterMax=max_iter, stopThr=tol, - isLazy=True, batchSize=batch_size, verbose=verbose, log=True) - potentials = (log["u"], log["v"]) - - # compute lazy_plan - ns_lazy, nt_lazy = 100, 100 # size of the lazy_plan (subplan) - M = dist(X_s[:ns_lazy,:], X_t[:nt_lazy,:], metric) - K = nx.exp(M / (-reg)) - lazy_plan = u[:ns_lazy].reshape((-1, 1)) * K * v[:nt_lazy].reshape((1, -1)) - - res_lazy = OTResultLazy(potentials=potentials, lazy_plan=lazy_plan, backend=nx) - return res_lazy + if unbalanced is None: + u, v, log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric='sqeuclidean', numIterMax=max_iter, stopThr=tol, + isLazy=True, batchSize=batch_size, verbose=verbose, log=True) + # compute potentials + potentials = (log["u"], log["v"]) + + # compute lazy_plan + ns_lazy, nt_lazy = 100, 100 # size of the lazy_plan (subplan) + M = dist(X_s[:ns_lazy,:], X_t[:nt_lazy,:], metric) + K = nx.exp(M / (-reg)) + lazy_plan = u[:ns_lazy].reshape((-1, 1)) * K * v[:nt_lazy].reshape((1, -1)) + + res_lazy = OTResultLazy(potentials=potentials, lazy_plan=lazy_plan, backend=nx) + return res_lazy + + else: + raise (NotImplementedError('Not implemented unbalanced_type="{}" with regularization'.format(unbalanced_type))) else: # compute cost matrix M and use solve function diff --git a/test/test_lowrank.py b/test/test_lowrank.py new file mode 100644 index 000000000..6e1f24067 --- /dev/null +++ b/test/test_lowrank.py @@ -0,0 +1,84 @@ +##################################################################################################### +####################################### WORK IN PROGRESS ############################################ +##################################################################################################### + + +""" Test for low rank sinkhorn solvers """ + +import ot +import numpy as np +import pytest +from itertools import product + + +def test_LR_Dykstra(): + # test for LR_Dykstra algorithm ? catch nan values ? + pass + + +@pytest.mark.parametrize("verbose, warn", product([True, False], [True, False])) +def test_lowrank_sinkhorn(verbose, warn): + # test low rank sinkhorn + n = 100 + a = ot.unif(n) + b = ot.unif(n) + + X_s = np.reshape(1.0 * np.arange(n), (n, 1)) + X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) + + Q_sqe, R_sqe, g_sqe = ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, 0.1) + P_sqe = np.dot(Q_sqe,np.dot(np.diag(1/g_sqe),R_sqe.T)) + + Q_m, R_m, g_m = ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, 0.1, metric='euclidean') + P_m = np.dot(Q_m,np.dot(np.diag(1/g_m),R_m.T)) + + # check constraints + np.testing.assert_allclose( + a, P_sqe.sum(1), atol=1e-05) # metric sqeuclidian + np.testing.assert_allclose( + b, P_sqe.sum(0), atol=1e-05) # metric sqeuclidian + np.testing.assert_allclose( + a, P_m.sum(1), atol=1e-05) # metric euclidian + np.testing.assert_allclose( + b, P_m.sum(0), atol=1e-05) # metric euclidian + + with pytest.warns(UserWarning): + ot.lowrank.lowrank_sinkhorn(X_s, X_t, 0.1, a, b, stopThr=0, numItermax=1) + + + +@pytest.mark.parametrize(("alpha, rank"),((0.8,2),(0.5,3),(0.2,4))) +def test_lowrank_sinkhorn_alpha_warning(alpha,rank): + # test warning for value of alpha + n = 100 + a = ot.unif(n) + b = ot.unif(n) + + X_s = np.reshape(1.0 * np.arange(n), (n, 1)) + X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) + + with pytest.warns(UserWarning): + ot.lowrank.lowrank_sinkhorn(X_s, X_t, 0.1, a, b, r=rank, alpha=alpha, warn=False) + + + +def test_lowrank_sinkhorn_backends(nx): + # test low rank sinkhorn for different backends + n = 100 + a = ot.unif(n) + b = ot.unif(n) + + X_s = np.reshape(1.0 * np.arange(n), (n, 1)) + X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) + + ab, bb, X_sb, X_tb = nx.from_numpy(a, b, X_s, X_t) + + Q, R, g = nx.to_numpy(ot.lowrank.lowrank_sinkhorn(X_sb, X_tb, ab, bb, 0.1)) + P = np.dot(Q,np.dot(np.diag(1/g),R.T)) + + np.testing.assert_allclose(a, P.sum(1), atol=1e-05) + np.testing.assert_allclose(b, P.sum(0), atol=1e-05) + + + + diff --git a/test/test_solvers.py b/test/test_solvers.py index f0f5b638f..5a05d54cf 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -255,3 +255,103 @@ def test_solve_gromov_not_implemented(nx): ot.solve_gromov(Ca, Cb, reg=1, unbalanced_type='partial', unbalanced=1.5) with pytest.raises(NotImplementedError): ot.solve_gromov(Ca, Cb, reg=1, unbalanced_type='partial', unbalanced=0.5, symmetric=False) + + + + +########################################################################################################### +############################################ WORK IN PROGRESS ############################################# +########################################################################################################### + +def assert_allclose_sol_sample(sol1, sol2): + # test attributes of OTResultLazy class + lst_attr = ['potentials','potential_a', 'potential_b', 'lazy_plan'] + + nx1 = sol1._backend if sol1._backend is not None else ot.backend.NumpyBackend() + nx2 = sol2._backend if sol2._backend is not None else ot.backend.NumpyBackend() + + for attr in lst_attr: + try: + np.allclose(nx1.to_numpy(getattr(sol1, attr)), nx2.to_numpy(getattr(sol2, attr))) + except NotImplementedError: + pass + + +@pytest.mark.parametrize("reg,reg_type,unbalanced,unbalanced_type", itertools.product(lst_reg, lst_reg_type, lst_unbalanced, lst_unbalanced_type)) +def test_solve_sample(nx): + # test solve_sample when is_Lazy = False + n = 100 + X_s = np.reshape(1.0 * np.arange(n), (n, 1)) + X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) + + a = ot.utils.unif(X_s.shape[0]) + b = ot.utils.unif(X_t.shape[0]) + + # solve unif weights + sol0 = ot.solve_sample(X_s, X_t) + + # solve signe weights + sol = ot.solve_sample(X_s, X_t, a, b) + + # check some attributes + sol.potentials + sol.sparse_plan + sol.marginals + sol.status + + assert_allclose_sol(sol0, sol) + + # solve in backend + X_sb, X_tb, ab, bb = nx.from_numpy(X_s, X_t, a, b) + solb = ot.solve_sample(X_sb, X_tb, ab, bb) + + assert_allclose_sol(sol, solb) + + # test not implemented unbalanced and check raise + with pytest.raises(NotImplementedError): + sol0 = ot.solve_sample(X_s, X_t, unbalanced=1, unbalanced_type='cryptic divergence') + + # test not implemented reg_type and check raise + with pytest.raises(NotImplementedError): + sol0 = ot.solve_sample(X_s, X_t, reg=1, reg_type='cryptic divergence') + + + +def test_lazy_solve_sample(nx): + # test solve_sample when is_Lazy = True + n = 100 + X_s = np.reshape(1.0 * np.arange(n), (n, 1)) + X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) + + a = ot.utils.unif(X_s.shape[0]) + b = ot.utils.unif(X_t.shape[0]) + + # solve unif weights + sol0 = ot.solve_sample(X_s, X_t, reg=0.1, is_Lazy=True) # reg != 0 or None since no implementation yet for is_Lazy=True + + # solve signe weights + sol = ot.solve_sample(X_s, X_t, a, b, reg=0.1, is_Lazy=True) + + # check some attributes + sol.potentials + sol.lazy_plan + + assert_allclose_sol_sample(sol0, sol) + + # solve in backend + X_sb, X_tb, ab, bb = nx.from_numpy(X_s, X_t, a, b) + solb = ot.solve_sample(X_sb, X_tb, ab, bb, reg=0.1, is_Lazy=True) + + assert_allclose_sol_sample(sol, solb) + + # test not implemented reg==0 (or None) + balanced and check raise + with pytest.raises(NotImplementedError): + sol0 = ot.solve_sample(X_s, X_t, is_Lazy=True) # reg == 0 (or None) + unbalanced= None are default + + # test not implemented reg==0 (or None) + unbalanced_type and check raise + with pytest.raises(NotImplementedError): + sol0 = ot.solve_sample(X_s, X_t, unbalanced_type="kl", is_Lazy=True) # reg == 0 (or None) is default + + # test not implemented reg != 0 + unbalanced_type and check raise + with pytest.raises(NotImplementedError): + sol0 = ot.solve_sample(X_s, X_t, reg=0.1, unbalanced_type="kl", is_Lazy=True) \ No newline at end of file diff --git a/test/test_utils.py b/test/test_utils.py index 40324518e..a14be460e 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -401,3 +401,30 @@ def test_get_coordinate_circle(): x_p = ot.utils.get_coordinate_circle(x) np.testing.assert_allclose(u[0], x_p) + + + +############################################################################################## +##################################### WORK IN PROGRESS ####################################### +############################################################################################## + +# test function for OTResultLazy + +def test_OTResultLazy(): + + res_lazy = ot.utils.OTResultLazy() + + # test print + print(res_lazy) + + # tets get citation + print(res_lazy.citation) + + lst_attributes = ['lazy_plan', + 'potential_a', + 'potential_b', + 'potentials'] + + for at in lst_attributes: + with pytest.raises(NotImplementedError): + getattr(res_lazy, at) \ No newline at end of file From 085863aef96f0d19e740879dfae158a762275a67 Mon Sep 17 00:00:00 2001 From: laudavid Date: Thu, 26 Oct 2023 10:49:23 +0200 Subject: [PATCH 04/22] add import to __init__ --- ot/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ot/__init__.py b/ot/__init__.py index f16b6fcfc..cb00f4553 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -35,6 +35,7 @@ from . import factored from . import solvers from . import gaussian +from . import lowrank # OT functions from .lp import (emd, emd2, emd_1d, emd2_1d, wasserstein_1d, @@ -50,7 +51,8 @@ gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2) from .weak import weak_optimal_transport from .factored import factored_optimal_transport -from .solvers import solve, solve_gromov +from .solvers import solve, solve_gromov, solve_sample +from .lowrank import lowrank_sinkhorn # utils functions from .utils import dist, unif, tic, toc, toq From 9becafc305fd6b2cc5390b0de16bae015bd41121 Mon Sep 17 00:00:00 2001 From: laudavid Date: Fri, 3 Nov 2023 11:38:40 +0100 Subject: [PATCH 05/22] modify low rank, remove solve_sample,OTResultLazy --- ot/__init__.py | 4 +- ot/lowrank.py | 200 ++++++++++++++++++++++++------------------- ot/solvers.py | 160 ---------------------------------- ot/utils.py | 89 ------------------- test/test_lowrank.py | 54 ++++++------ test/test_solvers.py | 97 --------------------- test/test_utils.py | 25 ------ 7 files changed, 142 insertions(+), 487 deletions(-) diff --git a/ot/__init__.py b/ot/__init__.py index cb00f4553..4aba450af 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -51,7 +51,7 @@ gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2) from .weak import weak_optimal_transport from .factored import factored_optimal_transport -from .solvers import solve, solve_gromov, solve_sample +from .solvers import solve, solve_gromov from .lowrank import lowrank_sinkhorn # utils functions @@ -70,4 +70,4 @@ 'factored_optimal_transport', 'solve', 'solve_gromov', 'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath', 'solvers', 'binary_search_circle', 'wasserstein_circle', - 'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif'] + 'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif', 'lowrank_sinkhorn'] diff --git a/ot/lowrank.py b/ot/lowrank.py index 22ff8b754..b3fce8de0 100644 --- a/ot/lowrank.py +++ b/ot/lowrank.py @@ -1,78 +1,88 @@ +""" +Low rank OT solvers +""" + +# Author: Laurène David +# +# License: MIT License + + + ################################################################################################################# ############################################## WORK IN PROGRESS ################################################# ################################################################################################################# -## Implementation of the LR-Dykstra algorithm and low rank sinkhorn algorithms import warnings -from .utils import unif, list_to_array, dist -from .backend import get_backend +from ot.utils import unif +from ot.backend import get_backend ################################## LR-DYSKTRA ALGORITHM ########################################## -def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, dykstra_p): +def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, dykstra_p, stopThr, nx=None): """ - Implementation of the Dykstra algorithm for low rank sinkhorn + Implementation of the Dykstra algorithm for the Low rank sinkhorn solver + """ + # Get dykstra parameters + g, q3_1, q3_2, v1_, v2_, q1, q2, u1, u2, v1, v2 = dykstra_p + g_ = eps3.copy() + err = 1 - # get dykstra parameters - q3_1, q3_2, v1_, v2_, q1, q2 = dykstra_p + # POT backend if needed + if nx is None: + nx = get_backend(eps1, eps2, eps3, p1, p2, + g, q3_1, q3_2, v1_, v2_, q1, q2, u1, u2) - # POT backend - eps1, eps2, eps3, p1, p2 = list_to_array(eps1, eps2, eps3, p1, p2) - q3_1, q3_2, v1_, v2_, q1, q2 = list_to_array(q3_1, q3_2, v1_, v2_, q1, q2) - - nx = get_backend(eps1, eps2, eps3, p1, p2, q3_1, q3_2, v1_, v2_, q1, q2) - - # ------- Dykstra algorithm ------ - g_ = eps3 - u1 = p1 / nx.dot(eps1, v1_) - u2 = p2 / nx.dot(eps2, v2_) + # ------- Dykstra algorithm ------ + while err > stopThr : + u1 = p1 / nx.dot(eps1, v1_) + u2 = p2 / nx.dot(eps2, v2_) - g = nx.maximum(alpha, g_ * q3_1) - q3_1 = (g_ * q3_1) / g - g_ = g + g = nx.maximum(alpha, g_ * q3_1) + q3_1 = (g_ * q3_1) / g + g_ = g.copy() - prod1 = ((v1_ * q1) * nx.dot(eps1.T, u1)) - prod2 = ((v2_ * q2) * nx.dot(eps2.T, u2)) - g = (g_ * q3_2 * prod1 * prod2)**(1/3) + prod1 = ((v1_ * q1) * nx.dot(eps1.T, u1)) + prod2 = ((v2_ * q2) * nx.dot(eps2.T, u2)) + g = (g_ * q3_2 * prod1 * prod2)**(1/3) - v1 = g / nx.dot(eps1.T,u1) - v2 = g / nx.dot(eps2.T,u2) + v1 = g / nx.dot(eps1.T,u1) + v2 = g / nx.dot(eps2.T,u2) + q1 = (v1_ * q1) / v1 - q1 = (v1_ * q1) / v1 - q2 = (v2_ * q2) / v2 - q3_2 = (g_ * q3_2) / g - - v1_, v2_ = v1, v2 - g_ = g + q2 = (v2_ * q2) / v2 + q3_2 = (g_ * q3_2) / g + + v1_, v2_ = v1.copy(), v2.copy() + g_ = g.copy() - # Compute error - err1 = nx.sum(nx.abs(u1 * (eps1 @ v1) - p1)) - err2 = nx.sum(nx.abs(u2 * (eps2 @ v2) - p2)) - err = err1 + err2 + # Compute error + err1 = nx.sum(nx.abs(u1 * (eps1 @ v1) - p1)) + err2 = nx.sum(nx.abs(u2 * (eps2 @ v2) - p2)) + err = err1 + err2 # Compute low rank matrices Q, R Q = u1[:,None] * eps1 * v1[None,:] R = u2[:,None] * eps2 * v2[None,:] - dykstra_p = [q3_1, q3_2, v1_, v2_, q1, q2] + dykstra_p = [g, q3_1, q3_2, v1_, v2_, q1, q2, u1, u2, v1, v2] - return Q, R, g, err, dykstra_p + return Q, R, dykstra_p #################################### LOW RANK SINKHORN ALGORITHM ######################################### -def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=2, metric='sqeuclidean', alpha="auto", - numItermax=10000, stopThr=1e-9, warn=True, verbose=False): +def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=2, alpha="auto", + numItermax=1000, stopThr=1e-9, warn=True, verbose=False): #stopThr = 1e-9 + r''' Solve the entropic regularization optimal transport problem under low-nonnegative rank constraints. - This function returns the two low-rank matrix decomposition of the OT plan (Q,R), as well as the weight vector g. Parameters @@ -95,6 +105,9 @@ def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=2, metric='sqeuclidea Max number of iterations stopThr : float, optional Stop threshold on error (>0) + warn: + + verbose: Returns @@ -109,73 +122,87 @@ def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=2, metric='sqeuclidea References ---------- - .. Scetbon, M., Cuturi, M., & Peyré, G (2021). Low-Rank Sinkhorn Factorization. arXiv preprint arXiv:2103.04737. ''' - X_s, X_t = list_to_array(X_s, X_t) nx = get_backend(X_s, X_t) - ns, nt = X_s.shape[0], X_t.shape[0] if a is None: - a = nx.from_numpy(unif(ns), type_as=X_s) + a = unif(ns, type_as=X_s) if b is None: - b = nx.from_numpy(unif(nt), type_as=X_s) + b = unif(nt, type_as=X_t) - # Compute cost matrix - M = dist(X_s,X_t, metric=metric) + d = X_s.shape[1] + + # First low rank decomposition of the cost matrix (A) + M1 = nx.zeros((ns,(d+2))) + M1[:,0] = [nx.norm(X_s[i,:])**2 for i in range(ns)] + M1[:,1] = nx.ones(ns) + M1[:,2:] = -2*X_s + + # Second low rank decomposition of the cost matrix (B) + M2 = nx.zeros((nt,(d+2))) + M2[:,0] = nx.ones(nt) + M2[:,1] = [nx.norm(X_t[i,:])**2 for i in range(nt)] + M2[:,2:] = X_t # Compute rank rank = min(ns, nt, rank) r = rank + # Alpha: lower bound for 1/rank if alpha == 'auto': - alpha = 1.0 / (r + 1) + alpha = 1e-3 # no convergence with alpha = 1 / (r+1) if (1/r < alpha) or (alpha < 0): - warnings.warn("The provided alpha value might lead to instabilities.") + warnings.warn("The provided alpha value might lead to instabilities.") - # Compute gamma - L = nx.sqrt((2/(alpha**4))*(nx.norm(M)**2) + (reg + (2/(alpha**3))*(nx.norm(M))**2)) + L = nx.sqrt(3*(2/(alpha**4))*((nx.norm(M1)*nx.norm(M2))**2) + (reg + (2/(alpha**3))*(nx.norm(M1)*nx.norm(M2)))**2) gamma = 1/(2*L) - # Initialisation + # Initialize the low rank matrices Q, R, g Q, R, g = nx.ones((ns,r)), nx.ones((nt,r)), nx.ones(r) + + # Initialize parameters for Dykstra algorithm q3_1, q3_2 = nx.ones(r), nx.ones(r) + u1, u2 = nx.ones(ns), nx.ones(nt) + v1, v2 = nx.ones(r), nx.ones(r) v1_, v2_ = nx.ones(r), nx.ones(r) q1, q2 = nx.ones(r), nx.ones(r) - dykstra_p = [q3_1, q3_2, v1_, v2_, q1, q2] - err = 1 + dykstra_p = [g, q3_1, q3_2, v1_, v2_, q1, q2, u1, u2, v1, v2] + - for ii in range(numItermax): - CR = nx.dot(M,R) - C_t_Q = nx.dot(M.T,Q) + for ii in range(numItermax): + CR_ = nx.dot(M2.T, R) + CR = nx.dot(M1, CR_) + + CQ_ = nx.dot(M1.T, Q) + CQ = nx.dot(M2, CQ_) + diag_g = (1/g)[:,None] eps1 = nx.exp(-gamma*(nx.dot(CR,diag_g)) - ((gamma*reg)-1)*nx.log(Q)) - eps2 = nx.exp(-gamma*(nx.dot(C_t_Q,diag_g)) - ((gamma*reg)-1)*nx.log(R)) + eps2 = nx.exp(-gamma*(nx.dot(CQ,diag_g)) - ((gamma*reg)-1)*nx.log(R)) omega = nx.diag(nx.dot(Q.T, CR)) eps3 = nx.exp(gamma*omega/(g**2) - (gamma*reg - 1)*nx.log(g)) - Q, R, g, err, dykstra_p = LR_Dysktra(eps1, eps2, eps3, a, b, alpha, dykstra_p) - - if err < stopThr: - break - - if verbose: - if ii % 200 == 0: - print( - '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(ii, err)) - - else: - if warn: - warnings.warn("Sinkhorn did not converge. You might want to " - "increase the number of iterations `numItermax` " - "or the regularization parameter `reg`.") + Q, R, dykstra_p = LR_Dysktra(eps1, eps2, eps3, a, b, alpha, dykstra_p, stopThr, nx) + g = dykstra_p[0] + + # if verbose: + # if ii % 200 == 0: + # print( + # '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + # print('{:5d}|{:8e}|'.format(ii, err)) + + # else: + # if warn: + # warnings.warn("Sinkhorn did not converge. You might want to " + # "increase the number of iterations `numItermax` " + # "or the regularization parameter `reg`.") return Q, R, g @@ -187,24 +214,23 @@ def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=2, metric='sqeuclidea ## Test with X_s, X_t from ot.datasets ############################################################################# -# import numpy as np -# import ot - -# Xs, _ = ot.datasets.make_data_classif('3gauss', n=1000) -# Xt, _ = ot.datasets.make_data_classif('3gauss2', n=1500) +import numpy as np +import ot -# ns = Xs.shape[0] -# nt = Xt.shape[0] +Xs, _ = ot.datasets.make_data_classif('3gauss', n=1000) +Xt, _ = ot.datasets.make_data_classif('3gauss2', n=1500) -# a = unif(ns) -# b = unif(nt) +ns = Xs.shape[0] +nt = Xt.shape[0] -# Q, R, g = lowrank_sinkhorn(Xs, Xt, reg=0.1, metric='euclidean', verbose=True, numItermax=100) -# M = ot.dist(Xs,Xt) -# P = np.dot(Q,np.dot(np.diag(1/g),R.T)) +a = unif(ns) +b = unif(nt) -# print(np.sum(P)) +Q, R, g = lowrank_sinkhorn(Xs, Xt, reg=0.1, verbose=True, numItermax=20) +M = ot.dist(Xs,Xt) +P = np.dot(Q,np.dot(np.diag(1/g),R.T)) +print(np.sum(P)) diff --git a/ot/solvers.py b/ot/solvers.py index c176969ca..8d6e10a5f 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -854,163 +854,3 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, -################################## WORK IN PROGRESS ##################################### - -## Implementation of the ot.solve_sample function -## Function isn't complete, still work in progress for reg == 0 / reg is None case (and unbalanced cases) - - -from .utils import unif, list_to_array, dist, OTResultLazy -from .bregman import empirical_sinkhorn - - -def solve_sample(X_s, X_t, a=None, b=None, metric='sqeuclidean', reg=None, reg_type="KL", unbalanced=None, - unbalanced_type='KL', is_Lazy=False, batch_size=None, n_threads=1, max_iter=None, plan_init=None, - potentials_init=None, tol=None, verbose=False): - - r"""Solve the discrete optimal transport problem using the samples in the source and target domains. - It returns either a :any:`OTResult` or :any:`OTResultLazy` object. - - The function solves the following general optimal transport problem - - .. math:: - \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_r R(\mathbf{T}) + - \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + - \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b}) - - The regularization is selected with `reg` (:math:`\lambda_r`) and `reg_type`. By - default ``reg=None`` and there is no regularization. The unbalanced marginal - penalization can be selected with `unbalanced` (:math:`\lambda_u`) and - `unbalanced_type`. By default ``unbalanced=None`` and the function - solves the exact optimal transport problem (respecting the marginals). - - Parameters - ---------- - X_s : array-like, shape (n_samples_a, dim) - samples in the source domain - X_t : array-like, shape (n_samples_b, dim) - samples in the target domain - a : array-like, shape (dim_a,), optional - Samples weights in the source domain (default is uniform) - b : array-like, shape (dim_b,), optional - Samples weights in the source domain (default is uniform) - reg : float, optional - Regularization weight :math:`\lambda_r`, by default None (no reg., exact - OT) - reg_type : str, optional - Type of regularization :math:`R` either "KL", "L2", "entropy", by default "KL" - unbalanced : float, optional - Unbalanced penalization weight :math:`\lambda_u`, by default None - (balanced OT) - unbalanced_type : str, optional - Type of unbalanced penalization function :math:`U` either "KL", "L2", "TV", by default "KL" - is_Lazy : bool, optional - Return :any:`OTResultlazy` object to reduce memory cost when True, by default False - n_threads : int, optional - Number of OMP threads for exact OT solver, by default 1 - max_iter : int, optional - Maximum number of iteration, by default None (default values in each solvers) - plan_init : array_like, shape (dim_a, dim_b), optional - Initialization of the OT plan for iterative methods, by default None - potentials_init : (array_like(dim_a,),array_like(dim_b,)), optional - Initialization of the OT dual potentials for iterative methods, by default None - tol : _type_, optional - Tolerance for solution precision, by default None (default values in each solvers) - verbose : bool, optional - Print information in the solver, by default False - - Returns - ------- - - res_lazy : OTResultLazy() - Result of the optimization problem. This class only returns a partial OT plan and the OT dual potentials to reduce memory costs. - The information can be obtained as follows: - - - res.lazy_plan : OT plan computed on a subsample of X_s and X_t - - res.potentials : OT dual potentials - - See :any:`OTResultLazy` for more information. - - res : OTResult() - Result of the optimization problem. The information can be obtained as follows: - - - res.plan : OT plan :math:`\mathbf{T}` - - res.potentials : OT dual potentials - - res.value : Optimal value of the optimization problem - - res.value_linear : Linear OT loss with the optimal OT plan - - See :any:`OTResult` for more information. - - - """ - - X_s, X_t = list_to_array(X_s,X_t) - - # detect backend - arr = [X_s,X_t] - if a is not None: - arr.append(a) - if b is not None: - arr.append(b) - nx = get_backend(*arr) - - # create uniform weights if not given - ns, nt = X_s.shape[0], X_t.shape[0] - if a is None: - a = nx.from_numpy(unif(ns), type_as=X_s) - if b is None: - b = nx.from_numpy(unif(nt), type_as=X_s) - - # default values for solutions - potentials = None - lazy_plan = None - - if max_iter is None: - max_iter = 1000 - if tol is None: - tol = 1e-9 - if batch_size is None: - batch_size = 100 - - if is_Lazy: - ################# WIP #################### - if reg is None or reg == 0: # EMD solver for isLazy ? - - if unbalanced is None: # balanced EMD solver for isLazy ? - raise (NotImplementedError('Not implemented balanced with no regularization')) - - else: - raise (NotImplementedError('Not implemented unbalanced_type="{}" with no regularization'.format(unbalanced_type))) - - - ############################################# - - else: - if unbalanced is None: - u, v, log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric='sqeuclidean', numIterMax=max_iter, stopThr=tol, - isLazy=True, batchSize=batch_size, verbose=verbose, log=True) - # compute potentials - potentials = (log["u"], log["v"]) - - # compute lazy_plan - ns_lazy, nt_lazy = 100, 100 # size of the lazy_plan (subplan) - M = dist(X_s[:ns_lazy,:], X_t[:nt_lazy,:], metric) - K = nx.exp(M / (-reg)) - lazy_plan = u[:ns_lazy].reshape((-1, 1)) * K * v[:nt_lazy].reshape((1, -1)) - - res_lazy = OTResultLazy(potentials=potentials, lazy_plan=lazy_plan, backend=nx) - return res_lazy - - else: - raise (NotImplementedError('Not implemented unbalanced_type="{}" with regularization'.format(unbalanced_type))) - - else: - # compute cost matrix M and use solve function - M = dist(X_s, X_t, metric) - - res = solve(M, a, b, reg, reg_type, unbalanced, unbalanced_type, n_threads, max_iter, plan_init, potentials_init, tol, verbose) - return res - - - - diff --git a/ot/utils.py b/ot/utils.py index d570b9f30..01944f56b 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -939,92 +939,3 @@ def citation(self): } """ - - -############################## WORK IN PROGRESS #################################### - -## Implementation of the OTResultLazy class for ot.solve_sample() with potentials and lazy_plan as attributes - -class OTResultLazy: - def __init__(self, potentials=None, lazy_plan=None, backend=None): - - self._potentials = potentials - self._lazy_plan = lazy_plan - self._backend = backend if backend is not None else NumpyBackend() - - - # Dual potentials -------------------------------------------- - - def __repr__(self): - s = 'OTResultLazy(' - if self._lazy_plan is not None: - s += 'lazy_plan={}(shape={}),'.format(self._lazy_plan.__class__.__name__, self._lazy_plan.shape) - - if s[-1] != '(': - s = s[:-1] + ')' - else: - s = s + ')' - return s - - @property - def potentials(self): - """Dual potentials, i.e. Lagrange multipliers for the marginal constraints. - - This pair of arrays has the same shape, numerical type - and properties as the input weights "a" and "b". - """ - if self._potentials is not None: - return self._potentials - else: - raise NotImplementedError() - - @property - def potential_a(self): - """First dual potential, associated to the "source" measure "a".""" - if self._potentials is not None: - return self._potentials[0] - else: - raise NotImplementedError() - - @property - def potential_b(self): - """Second dual potential, associated to the "target" measure "b".""" - if self._potentials is not None: - return self._potentials[1] - else: - raise NotImplementedError() - - # Transport plan ------------------------------------------- - @property - def lazy_plan(self): - """A subset of the Transport plan, encoded as a dense array.""" - - if self._lazy_plan is not None: - return self._lazy_plan - else: - raise NotImplementedError() - - @property - def citation(self): - """Appropriate citation(s) for this result, in plain text and BibTex formats.""" - - # The string below refers to the POT library: - # successor methods may concatenate the relevant references - # to the original definitions, solvers and underlying numerical backends. - return """POT library: - - POT Python Optimal Transport library, Journal of Machine Learning Research, 22(78):1−8, 2021. - Website: https://pythonot.github.io/ - Rémi Flamary, Nicolas Courty, Alexandre Gramfort, Mokhtar Z. Alaya, Aurélie Boisbunon, Stanislas Chambon, Laetitia Chapel, Adrien Corenflos, Kilian Fatras, Nemo Fournier, Léo Gautheron, Nathalie T.H. Gayraud, Hicham Janati, Alain Rakotomamonjy, Ievgen Redko, Antoine Rolet, Antony Schutz, Vivien Seguy, Danica J. Sutherland, Romain Tavenard, Alexander Tong, Titouan Vayer; - - @article{flamary2021pot, - author = {R{\'e}mi Flamary and Nicolas Courty and Alexandre Gramfort and Mokhtar Z. Alaya and Aur{\'e}lie Boisbunon and Stanislas Chambon and Laetitia Chapel and Adrien Corenflos and Kilian Fatras and Nemo Fournier and L{\'e}o Gautheron and Nathalie T.H. Gayraud and Hicham Janati and Alain Rakotomamonjy and Ievgen Redko and Antoine Rolet and Antony Schutz and Vivien Seguy and Danica J. Sutherland and Romain Tavenard and Alexander Tong and Titouan Vayer}, - title = {{POT}: {Python} {Optimal} {Transport}}, - journal = {Journal of Machine Learning Research}, - year = {2021}, - volume = {22}, - number = {78}, - pages = {1-8}, - url = {http://jmlr.org/papers/v22/20-451.html} - } - """ \ No newline at end of file diff --git a/test/test_lowrank.py b/test/test_lowrank.py index 6e1f24067..7d90ce9ef 100644 --- a/test/test_lowrank.py +++ b/test/test_lowrank.py @@ -16,34 +16,34 @@ def test_LR_Dykstra(): pass -@pytest.mark.parametrize("verbose, warn", product([True, False], [True, False])) -def test_lowrank_sinkhorn(verbose, warn): - # test low rank sinkhorn - n = 100 - a = ot.unif(n) - b = ot.unif(n) - - X_s = np.reshape(1.0 * np.arange(n), (n, 1)) - X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) - - Q_sqe, R_sqe, g_sqe = ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, 0.1) - P_sqe = np.dot(Q_sqe,np.dot(np.diag(1/g_sqe),R_sqe.T)) - - Q_m, R_m, g_m = ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, 0.1, metric='euclidean') - P_m = np.dot(Q_m,np.dot(np.diag(1/g_m),R_m.T)) - - # check constraints - np.testing.assert_allclose( - a, P_sqe.sum(1), atol=1e-05) # metric sqeuclidian - np.testing.assert_allclose( - b, P_sqe.sum(0), atol=1e-05) # metric sqeuclidian - np.testing.assert_allclose( - a, P_m.sum(1), atol=1e-05) # metric euclidian - np.testing.assert_allclose( - b, P_m.sum(0), atol=1e-05) # metric euclidian +# @pytest.mark.parametrize("verbose, warn", product([True, False], [True, False])) +# def test_lowrank_sinkhorn(verbose, warn): +# # test low rank sinkhorn +# n = 100 +# a = ot.unif(n) +# b = ot.unif(n) + +# X_s = np.reshape(1.0 * np.arange(n), (n, 1)) +# X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) + +# Q_sqe, R_sqe, g_sqe = ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, 0.1) +# P_sqe = np.dot(Q_sqe,np.dot(np.diag(1/g_sqe),R_sqe.T)) + +# Q_m, R_m, g_m = ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, 0.1, metric='euclidean') +# P_m = np.dot(Q_m,np.dot(np.diag(1/g_m),R_m.T)) + +# # check constraints +# np.testing.assert_allclose( +# a, P_sqe.sum(1), atol=1e-05) # metric sqeuclidian +# np.testing.assert_allclose( +# b, P_sqe.sum(0), atol=1e-05) # metric sqeuclidian +# np.testing.assert_allclose( +# a, P_m.sum(1), atol=1e-05) # metric euclidian +# np.testing.assert_allclose( +# b, P_m.sum(0), atol=1e-05) # metric euclidian - with pytest.warns(UserWarning): - ot.lowrank.lowrank_sinkhorn(X_s, X_t, 0.1, a, b, stopThr=0, numItermax=1) +# with pytest.warns(UserWarning): +# ot.lowrank.lowrank_sinkhorn(X_s, X_t, 0.1, a, b, stopThr=0, numItermax=1) diff --git a/test/test_solvers.py b/test/test_solvers.py index 5a05d54cf..e845ac7c2 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -258,100 +258,3 @@ def test_solve_gromov_not_implemented(nx): - -########################################################################################################### -############################################ WORK IN PROGRESS ############################################# -########################################################################################################### - -def assert_allclose_sol_sample(sol1, sol2): - # test attributes of OTResultLazy class - lst_attr = ['potentials','potential_a', 'potential_b', 'lazy_plan'] - - nx1 = sol1._backend if sol1._backend is not None else ot.backend.NumpyBackend() - nx2 = sol2._backend if sol2._backend is not None else ot.backend.NumpyBackend() - - for attr in lst_attr: - try: - np.allclose(nx1.to_numpy(getattr(sol1, attr)), nx2.to_numpy(getattr(sol2, attr))) - except NotImplementedError: - pass - - -@pytest.mark.parametrize("reg,reg_type,unbalanced,unbalanced_type", itertools.product(lst_reg, lst_reg_type, lst_unbalanced, lst_unbalanced_type)) -def test_solve_sample(nx): - # test solve_sample when is_Lazy = False - n = 100 - X_s = np.reshape(1.0 * np.arange(n), (n, 1)) - X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) - - a = ot.utils.unif(X_s.shape[0]) - b = ot.utils.unif(X_t.shape[0]) - - # solve unif weights - sol0 = ot.solve_sample(X_s, X_t) - - # solve signe weights - sol = ot.solve_sample(X_s, X_t, a, b) - - # check some attributes - sol.potentials - sol.sparse_plan - sol.marginals - sol.status - - assert_allclose_sol(sol0, sol) - - # solve in backend - X_sb, X_tb, ab, bb = nx.from_numpy(X_s, X_t, a, b) - solb = ot.solve_sample(X_sb, X_tb, ab, bb) - - assert_allclose_sol(sol, solb) - - # test not implemented unbalanced and check raise - with pytest.raises(NotImplementedError): - sol0 = ot.solve_sample(X_s, X_t, unbalanced=1, unbalanced_type='cryptic divergence') - - # test not implemented reg_type and check raise - with pytest.raises(NotImplementedError): - sol0 = ot.solve_sample(X_s, X_t, reg=1, reg_type='cryptic divergence') - - - -def test_lazy_solve_sample(nx): - # test solve_sample when is_Lazy = True - n = 100 - X_s = np.reshape(1.0 * np.arange(n), (n, 1)) - X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) - - a = ot.utils.unif(X_s.shape[0]) - b = ot.utils.unif(X_t.shape[0]) - - # solve unif weights - sol0 = ot.solve_sample(X_s, X_t, reg=0.1, is_Lazy=True) # reg != 0 or None since no implementation yet for is_Lazy=True - - # solve signe weights - sol = ot.solve_sample(X_s, X_t, a, b, reg=0.1, is_Lazy=True) - - # check some attributes - sol.potentials - sol.lazy_plan - - assert_allclose_sol_sample(sol0, sol) - - # solve in backend - X_sb, X_tb, ab, bb = nx.from_numpy(X_s, X_t, a, b) - solb = ot.solve_sample(X_sb, X_tb, ab, bb, reg=0.1, is_Lazy=True) - - assert_allclose_sol_sample(sol, solb) - - # test not implemented reg==0 (or None) + balanced and check raise - with pytest.raises(NotImplementedError): - sol0 = ot.solve_sample(X_s, X_t, is_Lazy=True) # reg == 0 (or None) + unbalanced= None are default - - # test not implemented reg==0 (or None) + unbalanced_type and check raise - with pytest.raises(NotImplementedError): - sol0 = ot.solve_sample(X_s, X_t, unbalanced_type="kl", is_Lazy=True) # reg == 0 (or None) is default - - # test not implemented reg != 0 + unbalanced_type and check raise - with pytest.raises(NotImplementedError): - sol0 = ot.solve_sample(X_s, X_t, reg=0.1, unbalanced_type="kl", is_Lazy=True) \ No newline at end of file diff --git a/test/test_utils.py b/test/test_utils.py index a14be460e..bbadec65a 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -403,28 +403,3 @@ def test_get_coordinate_circle(): np.testing.assert_allclose(u[0], x_p) - -############################################################################################## -##################################### WORK IN PROGRESS ####################################### -############################################################################################## - -# test function for OTResultLazy - -def test_OTResultLazy(): - - res_lazy = ot.utils.OTResultLazy() - - # test print - print(res_lazy) - - # tets get citation - print(res_lazy.citation) - - lst_attributes = ['lazy_plan', - 'potential_a', - 'potential_b', - 'potentials'] - - for at in lst_attributes: - with pytest.raises(NotImplementedError): - getattr(res_lazy, at) \ No newline at end of file From 6ea251c89ecf52603eb81c798a0769e9a2cb9f54 Mon Sep 17 00:00:00 2001 From: laudavid Date: Tue, 24 Oct 2023 14:54:42 +0200 Subject: [PATCH 06/22] new file for lr sinkhorn --- ot/lowrank.py | 171 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 171 insertions(+) create mode 100644 ot/lowrank.py diff --git a/ot/lowrank.py b/ot/lowrank.py new file mode 100644 index 000000000..ba46cd1ed --- /dev/null +++ b/ot/lowrank.py @@ -0,0 +1,171 @@ +################################################################################################################# +############################################## WORK IN PROGRESS ################################################# +################################################################################################################# + + +from ot.utils import unif, list_to_array +from ot.backend import get_backend +from ot.datasets import make_1D_gauss as gauss + + + +################################## LR-DYSKTRA ALGORITHM ########################################## + +def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, dykstra_w): + """ + Implementation of the Dykstra algorithm for low rank Sinkhorn + """ + + # get dykstra parameters + q3_1, q3_2, v1_, v2_, q1, q2 = dykstra_w + + # POT backend + eps1, eps2, eps3, p1, p2 = list_to_array(eps1, eps2, eps3, p1, p2) + q3_1, q3_2, v1_, v2_, q1, q2 = list_to_array(q3_1, q3_2, v1_, v2_, q1, q2) + + nx = get_backend(eps1, eps2, eps3, p1, p2, q3_1, q3_2, v1_, v2_, q1, q2) + + # ------- Dykstra algorithm ------ + g_ = eps3 + + u1 = p1 / nx.dot(eps1, v1_) + u2 = p2 / nx.dot(eps2, v2_) + + g = nx.maximum(alpha, g_ * q3_1) + q3_1 = (g_ * q3_1) / g + g_ = g + + prod1 = ((v1_ * q1) * nx.dot(eps1.T, u1)) + prod2 = ((v2_ * q2) * nx.dot(eps2.T, u2)) + g = (g_ * q3_2 * prod1 * prod2)**(1/3) + + v1 = g / nx.dot(eps1.T,u1) + v2 = g / nx.dot(eps2.T,u2) + + q1 = (v1_ * q1) / v1 + q2 = (v2_ * q2) / v2 + q3_2 = (g_ * q3_2) / g + + v1_, v2_ = v1, v2 + g_ = g + + # Compute error + err1 = nx.sum(nx.abs(u1 * (eps1 @ v1) - p1)) + err2 = nx.sum(nx.abs(u2 * (eps2 @ v2) - p2)) + err = err1 + err2 + + # Compute low rank matrices Q, R + Q = u1[:,None] * eps1 * v1[None,:] + R = u2[:,None] * eps2 * v2[None,:] + + dykstra_w = [q3_1, q3_2, v1_, v2_, q1, q2] + + return Q, R, g, err, dykstra_w + + + +#################################### LOW RANK SINKHORN ALGORITHM ######################################### + + +def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=4, metric='sqeuclidean', alpha=1e-10, numIterMax=10000, stopThr=1e-20): + r''' + Solve the entropic regularization optimal transport problem under low-nonnegative low rank constraints + + Parameters + ---------- + X_s : array-like, shape (n_samples_a, dim) + samples in the source domain + X_t : array-like, shape (n_samples_b, dim) + samples in the target domain + reg : float + Regularization term >0 + a : array-like, shape (n_samples_a,) + samples weights in the source domain + b : array-like, shape (n_samples_b,) + samples weights in the target domain + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (>0) + + Returns + ------- + Q : array-like, shape (n_samples_a, r) + First low-rank matrix decomposition of the OT plan + R: array-like, shape (n_samples_b, r) + Second low-rank matrix decomposition of the OT plan + g : array-like, shape (r, ) + ... + + ''' + + X_s, X_t = list_to_array(X_s, X_t) + nx = get_backend(X_s, X_t) + + ns, nt = X_s.shape[0], X_t.shape[0] + if a is None: + a = nx.from_numpy(unif(ns), type_as=X_s) + if b is None: + b = nx.from_numpy(unif(nt), type_as=X_s) + + M = ot.dist(X_s,X_t, metric=metric) + + # Compute rank + r = min(ns, nt, r) + + # Compute gamma + L = nx.sqrt((2/(alpha**4))*nx.norm(M)**2 + (reg + (2/(alpha**3))*nx.norm(M))**2) + gamma = 1/(2*L) + + # Initialisation + Q, R, g = nx.ones((ns,r)), nx.ones((nt,r)), nx.ones(r) + q3_1, q3_2 = nx.ones(r), nx.ones(r) + v1_, v2_ = nx.ones(r), nx.ones(r) + q1, q2 = nx.ones(r), nx.ones(r) + dykstra_w = [q3_1, q3_2, v1_, v2_, q1, q2] + n_iter = 0 + err = 1 + + while n_iter < numIterMax: + if err > stopThr: + n_iter = n_iter + 1 + + CR = nx.dot(M,R) + C_t_Q = nx.dot(M.T,Q) + diag_g = (1/g)[:,None] + + eps1 = nx.exp(-gamma*(nx.dot(CR,diag_g)) - ((gamma*reg)-1)*nx.log(Q)) + eps2 = nx.exp(-gamma*(nx.dot(C_t_Q,diag_g)) - ((gamma*reg)-1)*nx.log(R)) + omega = nx.diag(nx.dot(Q.T, CR)) + eps3 = nx.exp(gamma*omega/(g**2) - (gamma*reg - 1)*nx.log(g)) + + Q, R, g, err, dykstra_w = LR_Dysktra(eps1, eps2, eps3, a, b, alpha, dykstra_w) + else: + break + + return Q, R, g + + + + + +############################################################################ +## Test with X_s, X_t from ot.datasets +############################################################################# + +import numpy as np +import ot + +Xs, _ = ot.datasets.make_data_classif('3gauss', n=1000) +Xt, _ = ot.datasets.make_data_classif('3gauss2', n=1500) + + +Q, R, g = lowrank_sinkhorn(Xs,Xt,reg=0.1) +M = ot.dist(Xs,Xt) +P = np.dot(Q,np.dot(np.diag(1/g),R.T)) + +print(np.sum(P)) + + + + From 965e4d69113f6fe8eab106412b652dabdbc05712 Mon Sep 17 00:00:00 2001 From: laudavid Date: Tue, 24 Oct 2023 16:47:21 +0200 Subject: [PATCH 07/22] lr sinkhorn, solve_sample, OTResultLazy --- ot/lowrank.py | 40 +++++++------ ot/solvers.py | 161 ++++++++++++++++++++++++++++++++++++++++++++++++++ ot/utils.py | 3 +- 3 files changed, 183 insertions(+), 21 deletions(-) diff --git a/ot/lowrank.py b/ot/lowrank.py index ba46cd1ed..a1c73bdf3 100644 --- a/ot/lowrank.py +++ b/ot/lowrank.py @@ -2,8 +2,10 @@ ############################################## WORK IN PROGRESS ################################################# ################################################################################################################# +## Implementation of the LR-Dykstra algorithm and low rank sinkhorn algorithms -from ot.utils import unif, list_to_array + +from ot.utils import unif, list_to_array, dist from ot.backend import get_backend from ot.datasets import make_1D_gauss as gauss @@ -11,13 +13,13 @@ ################################## LR-DYSKTRA ALGORITHM ########################################## -def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, dykstra_w): +def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, dykstra_p): """ Implementation of the Dykstra algorithm for low rank Sinkhorn """ # get dykstra parameters - q3_1, q3_2, v1_, v2_, q1, q2 = dykstra_w + q3_1, q3_2, v1_, v2_, q1, q2 = dykstra_p # POT backend eps1, eps2, eps3, p1, p2 = list_to_array(eps1, eps2, eps3, p1, p2) @@ -58,18 +60,18 @@ def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, dykstra_w): Q = u1[:,None] * eps1 * v1[None,:] R = u2[:,None] * eps2 * v2[None,:] - dykstra_w = [q3_1, q3_2, v1_, v2_, q1, q2] + dykstra_p = [q3_1, q3_2, v1_, v2_, q1, q2] - return Q, R, g, err, dykstra_w + return Q, R, g, err, dykstra_p #################################### LOW RANK SINKHORN ALGORITHM ######################################### -def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=4, metric='sqeuclidean', alpha=1e-10, numIterMax=10000, stopThr=1e-20): +def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=2, metric='sqeuclidean', alpha=1e-10, numIterMax=10000, stopThr=1e-20): r''' - Solve the entropic regularization optimal transport problem under low-nonnegative low rank constraints + Solve the entropic regularization optimal transport problem under low-nonnegative rank constraints on the feasible couplings. Parameters ---------- @@ -95,7 +97,7 @@ def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=4, metric='sqeuclidean', R: array-like, shape (n_samples_b, r) Second low-rank matrix decomposition of the OT plan g : array-like, shape (r, ) - ... + Third low-rank matrix decomposition of the OT plan ''' @@ -108,7 +110,7 @@ def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=4, metric='sqeuclidean', if b is None: b = nx.from_numpy(unif(nt), type_as=X_s) - M = ot.dist(X_s,X_t, metric=metric) + M = dist(X_s,X_t, metric=metric) # Compute rank r = min(ns, nt, r) @@ -122,7 +124,7 @@ def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=4, metric='sqeuclidean', q3_1, q3_2 = nx.ones(r), nx.ones(r) v1_, v2_ = nx.ones(r), nx.ones(r) q1, q2 = nx.ones(r), nx.ones(r) - dykstra_w = [q3_1, q3_2, v1_, v2_, q1, q2] + dykstra_p = [q3_1, q3_2, v1_, v2_, q1, q2] n_iter = 0 err = 1 @@ -139,7 +141,7 @@ def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=4, metric='sqeuclidean', omega = nx.diag(nx.dot(Q.T, CR)) eps3 = nx.exp(gamma*omega/(g**2) - (gamma*reg - 1)*nx.log(g)) - Q, R, g, err, dykstra_w = LR_Dysktra(eps1, eps2, eps3, a, b, alpha, dykstra_w) + Q, R, g, err, dykstra_p = LR_Dysktra(eps1, eps2, eps3, a, b, alpha, dykstra_p) else: break @@ -153,18 +155,18 @@ def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=4, metric='sqeuclidean', ## Test with X_s, X_t from ot.datasets ############################################################################# -import numpy as np -import ot +# import numpy as np +# import ot -Xs, _ = ot.datasets.make_data_classif('3gauss', n=1000) -Xt, _ = ot.datasets.make_data_classif('3gauss2', n=1500) +# Xs, _ = ot.datasets.make_data_classif('3gauss', n=1000) +# Xt, _ = ot.datasets.make_data_classif('3gauss2', n=1500) -Q, R, g = lowrank_sinkhorn(Xs,Xt,reg=0.1) -M = ot.dist(Xs,Xt) -P = np.dot(Q,np.dot(np.diag(1/g),R.T)) +# Q, R, g = lowrank_sinkhorn(Xs,Xt,reg=0.1) +# M = ot.dist(Xs,Xt) +# P = np.dot(Q,np.dot(np.diag(1/g),R.T)) -print(np.sum(P)) +# print(np.sum(P)) diff --git a/ot/solvers.py b/ot/solvers.py index 0313cf588..9c2746c25 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -848,3 +848,164 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, value_linear=value_linear, value_quad=value_quad, plan=plan, status=status, backend=nx) return res + + + + + + +################################## WORK IN PROGRESS ##################################### + +## Implementation of the ot.solve_sample function +## Function isn't complete, still work in progress for reg == 0 / reg is None case (and unbalanced cases) + + +from .utils import unif, list_to_array, dist, OTResultLazy +from .bregman import empirical_sinkhorn + + +def solve_sample(X_s, X_t, a=None, b=None, metric='sqeuclidean', reg=None, reg_type="KL", unbalanced=None, + unbalanced_type='KL', is_Lazy=False, batch_size=None, n_threads=1, max_iter=None, plan_init=None, + potentials_init=None, tol=None, verbose=False): + + r"""Solve the discrete optimal transport problem using the samples in the source and target domains. + It returns either a :any:`OTResult` or :any:`OTResultLazy` object. + + The function solves the following general optimal transport problem + + .. math:: + \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_r R(\mathbf{T}) + + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b}) + + The regularization is selected with `reg` (:math:`\lambda_r`) and `reg_type`. By + default ``reg=None`` and there is no regularization. The unbalanced marginal + penalization can be selected with `unbalanced` (:math:`\lambda_u`) and + `unbalanced_type`. By default ``unbalanced=None`` and the function + solves the exact optimal transport problem (respecting the marginals). + + Parameters + ---------- + X_s : array-like, shape (n_samples_a, dim) + samples in the source domain + X_t : array-like, shape (n_samples_b, dim) + samples in the target domain + a : array-like, shape (dim_a,), optional + Samples weights in the source domain (default is uniform) + b : array-like, shape (dim_b,), optional + Samples weights in the source domain (default is uniform) + reg : float, optional + Regularization weight :math:`\lambda_r`, by default None (no reg., exact + OT) + reg_type : str, optional + Type of regularization :math:`R` either "KL", "L2", "entropy", by default "KL" + unbalanced : float, optional + Unbalanced penalization weight :math:`\lambda_u`, by default None + (balanced OT) + unbalanced_type : str, optional + Type of unbalanced penalization function :math:`U` either "KL", "L2", "TV", by default "KL" + is_Lazy : bool, optional + Return :any:`OTResultlazy` object to reduce memory cost when True, by default False + n_threads : int, optional + Number of OMP threads for exact OT solver, by default 1 + max_iter : int, optional + Maximum number of iteration, by default None (default values in each solvers) + plan_init : array_like, shape (dim_a, dim_b), optional + Initialization of the OT plan for iterative methods, by default None + potentials_init : (array_like(dim_a,),array_like(dim_b,)), optional + Initialization of the OT dual potentials for iterative methods, by default None + tol : _type_, optional + Tolerance for solution precision, by default None (default values in each solvers) + verbose : bool, optional + Print information in the solver, by default False + + Returns + ------- + + res_lazy : OTResultLazy() + Result of the optimization problem. This class only returns a partial OT plan and the OT dual potentials to reduce memory costs. + The information can be obtained as follows: + + - res.lazy_plan : OT plan computed on a subsample of X_s and X_t :math:`\mathbf{T}` + - res.potentials : OT dual potentials + + See :any:`OTResultLazy` for more information. + + res : OTResult() + Result of the optimization problem. The information can be obtained as follows: + + - res.plan : OT plan :math:`\mathbf{T}` + - res.potentials : OT dual potentials + - res.value : Optimal value of the optimization problem + - res.value_linear : Linear OT loss with the optimal OT plan + + See :any:`OTResult` for more information. + + + """ + + X_s, X_t = list_to_array(X_s,X_t) + + # detect backend + arr = [X_s,X_t] + if a is not None: + arr.append(a) + if b is not None: + arr.append(b) + nx = get_backend(*arr) + + # create uniform weights if not given + ns, nt = X_s.shape[0], X_t.shape[0] + if a is None: + a = nx.from_numpy(unif(ns), type_as=X_s) + if b is None: + b = nx.from_numpy(unif(nt), type_as=X_s) + + # default values for solutions + potentials = None + lazy_plan = None + + if max_iter is None: + max_iter = 1000 + if tol is None: + tol = 1e-9 + if batch_size is None: + batch_size = 100 + + if is_Lazy: + ################# WIP #################### + if reg is None or reg == 0: # EMD solver for isLazy ? + if unbalanced is None: # not sure "unbalanced" parameter is needed here ? (since we won't compute value) + pass + elif unbalanced_type.lower() in ['kl', 'l2']: + pass + elif unbalanced_type.lower() == 'tv': + pass + pass + ############################################# + + else: + # compute potentials + u, v, log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric='sqeuclidean', numIterMax=max_iter, stopThr=tol, + isLazy=True, batchSize=batch_size, verbose=verbose, log=True) + potentials = (log["u"], log["v"]) + + # compute lazy_plan + ns_lazy, nt_lazy = 100, 100 # size of the lazy_plan (subplan) + M = dist(X_s[:ns_lazy,:], X_t[:nt_lazy,:], metric) + K = nx.exp(M / (-reg)) + lazy_plan = u[:ns_lazy].reshape((-1, 1)) * K * v[:nt_lazy].reshape((1, -1)) + + res_lazy = OTResultLazy(potentials=potentials, lazy_plan=lazy_plan, backend=nx) + return res_lazy + + else: + # compute cost matrix M and use solve function + M = dist(X_s, X_t, metric) + + res = solve(M, a, b, reg, reg_type, unbalanced, unbalanced_type, n_threads, max_iter, plan_init, potentials_init, tol, verbose) + return res + + + + diff --git a/ot/utils.py b/ot/utils.py index 0936648ca..2f4cfc9e7 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -1168,7 +1168,6 @@ def citation(self): } """ - class LazyTensor(object): """ A lazy tensor is a tensor that is not stored in memory. Instead, it is defined by a function that computes its values on the fly from slices. @@ -1233,4 +1232,4 @@ def __getitem__(self, key): return self._getitem(*k, **self.kwargs) def __repr__(self): - return "LazyTensor(shape={},attributes=({}))".format(self.shape, ','.join(self.kwargs.keys())) + return "LazyTensor(shape={},attributes=({}))".format(self.shape, ','.join(self.kwargs.keys())) \ No newline at end of file From fd5e26d86e484f310f55a792bca89de13bd7340f Mon Sep 17 00:00:00 2001 From: laudavid Date: Wed, 25 Oct 2023 17:39:08 +0200 Subject: [PATCH 08/22] add test functions + small modif lr_sin/solve_sample --- ot/lowrank.py | 97 ++++++++++++++++++++++++++++------------- ot/solvers.py | 47 +++++++++++--------- test/test_lowrank.py | 84 ++++++++++++++++++++++++++++++++++++ test/test_solvers.py | 100 +++++++++++++++++++++++++++++++++++++++++++ test/test_utils.py | 1 + 5 files changed, 278 insertions(+), 51 deletions(-) create mode 100644 test/test_lowrank.py diff --git a/ot/lowrank.py b/ot/lowrank.py index a1c73bdf3..22ff8b754 100644 --- a/ot/lowrank.py +++ b/ot/lowrank.py @@ -4,10 +4,9 @@ ## Implementation of the LR-Dykstra algorithm and low rank sinkhorn algorithms - -from ot.utils import unif, list_to_array, dist -from ot.backend import get_backend -from ot.datasets import make_1D_gauss as gauss +import warnings +from .utils import unif, list_to_array, dist +from .backend import get_backend @@ -15,7 +14,7 @@ def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, dykstra_p): """ - Implementation of the Dykstra algorithm for low rank Sinkhorn + Implementation of the Dykstra algorithm for low rank sinkhorn """ # get dykstra parameters @@ -69,9 +68,12 @@ def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, dykstra_p): #################################### LOW RANK SINKHORN ALGORITHM ######################################### -def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=2, metric='sqeuclidean', alpha=1e-10, numIterMax=10000, stopThr=1e-20): +def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=2, metric='sqeuclidean', alpha="auto", + numItermax=10000, stopThr=1e-9, warn=True, verbose=False): r''' - Solve the entropic regularization optimal transport problem under low-nonnegative rank constraints on the feasible couplings. + Solve the entropic regularization optimal transport problem under low-nonnegative rank constraints. + + This function returns the two low-rank matrix decomposition of the OT plan (Q,R), as well as the weight vector g. Parameters ---------- @@ -79,17 +81,22 @@ def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=2, metric='sqeuclidean', samples in the source domain X_t : array-like, shape (n_samples_b, dim) samples in the target domain - reg : float - Regularization term >0 a : array-like, shape (n_samples_a,) samples weights in the source domain b : array-like, shape (n_samples_b,) samples weights in the target domain + reg : float, optional + Regularization term >0 + rank: int, optional + Nonnegative rank of the OT plan + alpha: int, optional + Lower bound for the weight vector g (>0 and <1/r) numItermax : int, optional Max number of iterations stopThr : float, optional Stop threshold on error (>0) + Returns ------- Q : array-like, shape (n_samples_a, r) @@ -97,7 +104,14 @@ def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=2, metric='sqeuclidean', R: array-like, shape (n_samples_b, r) Second low-rank matrix decomposition of the OT plan g : array-like, shape (r, ) - Third low-rank matrix decomposition of the OT plan + Weight vector for the low-rank decomposition of the OT plan + + + References + ---------- + + .. Scetbon, M., Cuturi, M., & Peyré, G (2021). + Low-Rank Sinkhorn Factorization. arXiv preprint arXiv:2103.04737. ''' @@ -110,13 +124,22 @@ def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=2, metric='sqeuclidean', if b is None: b = nx.from_numpy(unif(nt), type_as=X_s) + # Compute cost matrix M = dist(X_s,X_t, metric=metric) - + # Compute rank - r = min(ns, nt, r) + rank = min(ns, nt, rank) + r = rank + + if alpha == 'auto': + alpha = 1.0 / (r + 1) + + if (1/r < alpha) or (alpha < 0): + warnings.warn("The provided alpha value might lead to instabilities.") + # Compute gamma - L = nx.sqrt((2/(alpha**4))*nx.norm(M)**2 + (reg + (2/(alpha**3))*nx.norm(M))**2) + L = nx.sqrt((2/(alpha**4))*(nx.norm(M)**2) + (reg + (2/(alpha**3))*(nx.norm(M))**2)) gamma = 1/(2*L) # Initialisation @@ -125,25 +148,34 @@ def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=2, metric='sqeuclidean', v1_, v2_ = nx.ones(r), nx.ones(r) q1, q2 = nx.ones(r), nx.ones(r) dykstra_p = [q3_1, q3_2, v1_, v2_, q1, q2] - n_iter = 0 err = 1 - while n_iter < numIterMax: - if err > stopThr: - n_iter = n_iter + 1 - - CR = nx.dot(M,R) - C_t_Q = nx.dot(M.T,Q) - diag_g = (1/g)[:,None] - - eps1 = nx.exp(-gamma*(nx.dot(CR,diag_g)) - ((gamma*reg)-1)*nx.log(Q)) - eps2 = nx.exp(-gamma*(nx.dot(C_t_Q,diag_g)) - ((gamma*reg)-1)*nx.log(R)) - omega = nx.diag(nx.dot(Q.T, CR)) - eps3 = nx.exp(gamma*omega/(g**2) - (gamma*reg - 1)*nx.log(g)) - - Q, R, g, err, dykstra_p = LR_Dysktra(eps1, eps2, eps3, a, b, alpha, dykstra_p) - else: + for ii in range(numItermax): + CR = nx.dot(M,R) + C_t_Q = nx.dot(M.T,Q) + diag_g = (1/g)[:,None] + + eps1 = nx.exp(-gamma*(nx.dot(CR,diag_g)) - ((gamma*reg)-1)*nx.log(Q)) + eps2 = nx.exp(-gamma*(nx.dot(C_t_Q,diag_g)) - ((gamma*reg)-1)*nx.log(R)) + omega = nx.diag(nx.dot(Q.T, CR)) + eps3 = nx.exp(gamma*omega/(g**2) - (gamma*reg - 1)*nx.log(g)) + + Q, R, g, err, dykstra_p = LR_Dysktra(eps1, eps2, eps3, a, b, alpha, dykstra_p) + + if err < stopThr: break + + if verbose: + if ii % 200 == 0: + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(ii, err)) + + else: + if warn: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") return Q, R, g @@ -161,8 +193,13 @@ def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=2, metric='sqeuclidean', # Xs, _ = ot.datasets.make_data_classif('3gauss', n=1000) # Xt, _ = ot.datasets.make_data_classif('3gauss2', n=1500) +# ns = Xs.shape[0] +# nt = Xt.shape[0] + +# a = unif(ns) +# b = unif(nt) -# Q, R, g = lowrank_sinkhorn(Xs,Xt,reg=0.1) +# Q, R, g = lowrank_sinkhorn(Xs, Xt, reg=0.1, metric='euclidean', verbose=True, numItermax=100) # M = ot.dist(Xs,Xt) # P = np.dot(Q,np.dot(np.diag(1/g),R.T)) diff --git a/ot/solvers.py b/ot/solvers.py index 9c2746c25..c176969ca 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -926,7 +926,7 @@ def solve_sample(X_s, X_t, a=None, b=None, metric='sqeuclidean', reg=None, reg_t Result of the optimization problem. This class only returns a partial OT plan and the OT dual potentials to reduce memory costs. The information can be obtained as follows: - - res.lazy_plan : OT plan computed on a subsample of X_s and X_t :math:`\mathbf{T}` + - res.lazy_plan : OT plan computed on a subsample of X_s and X_t - res.potentials : OT dual potentials See :any:`OTResultLazy` for more information. @@ -975,29 +975,34 @@ def solve_sample(X_s, X_t, a=None, b=None, metric='sqeuclidean', reg=None, reg_t if is_Lazy: ################# WIP #################### if reg is None or reg == 0: # EMD solver for isLazy ? - if unbalanced is None: # not sure "unbalanced" parameter is needed here ? (since we won't compute value) - pass - elif unbalanced_type.lower() in ['kl', 'l2']: - pass - elif unbalanced_type.lower() == 'tv': - pass - pass + + if unbalanced is None: # balanced EMD solver for isLazy ? + raise (NotImplementedError('Not implemented balanced with no regularization')) + + else: + raise (NotImplementedError('Not implemented unbalanced_type="{}" with no regularization'.format(unbalanced_type))) + + ############################################# else: - # compute potentials - u, v, log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric='sqeuclidean', numIterMax=max_iter, stopThr=tol, - isLazy=True, batchSize=batch_size, verbose=verbose, log=True) - potentials = (log["u"], log["v"]) - - # compute lazy_plan - ns_lazy, nt_lazy = 100, 100 # size of the lazy_plan (subplan) - M = dist(X_s[:ns_lazy,:], X_t[:nt_lazy,:], metric) - K = nx.exp(M / (-reg)) - lazy_plan = u[:ns_lazy].reshape((-1, 1)) * K * v[:nt_lazy].reshape((1, -1)) - - res_lazy = OTResultLazy(potentials=potentials, lazy_plan=lazy_plan, backend=nx) - return res_lazy + if unbalanced is None: + u, v, log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric='sqeuclidean', numIterMax=max_iter, stopThr=tol, + isLazy=True, batchSize=batch_size, verbose=verbose, log=True) + # compute potentials + potentials = (log["u"], log["v"]) + + # compute lazy_plan + ns_lazy, nt_lazy = 100, 100 # size of the lazy_plan (subplan) + M = dist(X_s[:ns_lazy,:], X_t[:nt_lazy,:], metric) + K = nx.exp(M / (-reg)) + lazy_plan = u[:ns_lazy].reshape((-1, 1)) * K * v[:nt_lazy].reshape((1, -1)) + + res_lazy = OTResultLazy(potentials=potentials, lazy_plan=lazy_plan, backend=nx) + return res_lazy + + else: + raise (NotImplementedError('Not implemented unbalanced_type="{}" with regularization'.format(unbalanced_type))) else: # compute cost matrix M and use solve function diff --git a/test/test_lowrank.py b/test/test_lowrank.py new file mode 100644 index 000000000..6e1f24067 --- /dev/null +++ b/test/test_lowrank.py @@ -0,0 +1,84 @@ +##################################################################################################### +####################################### WORK IN PROGRESS ############################################ +##################################################################################################### + + +""" Test for low rank sinkhorn solvers """ + +import ot +import numpy as np +import pytest +from itertools import product + + +def test_LR_Dykstra(): + # test for LR_Dykstra algorithm ? catch nan values ? + pass + + +@pytest.mark.parametrize("verbose, warn", product([True, False], [True, False])) +def test_lowrank_sinkhorn(verbose, warn): + # test low rank sinkhorn + n = 100 + a = ot.unif(n) + b = ot.unif(n) + + X_s = np.reshape(1.0 * np.arange(n), (n, 1)) + X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) + + Q_sqe, R_sqe, g_sqe = ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, 0.1) + P_sqe = np.dot(Q_sqe,np.dot(np.diag(1/g_sqe),R_sqe.T)) + + Q_m, R_m, g_m = ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, 0.1, metric='euclidean') + P_m = np.dot(Q_m,np.dot(np.diag(1/g_m),R_m.T)) + + # check constraints + np.testing.assert_allclose( + a, P_sqe.sum(1), atol=1e-05) # metric sqeuclidian + np.testing.assert_allclose( + b, P_sqe.sum(0), atol=1e-05) # metric sqeuclidian + np.testing.assert_allclose( + a, P_m.sum(1), atol=1e-05) # metric euclidian + np.testing.assert_allclose( + b, P_m.sum(0), atol=1e-05) # metric euclidian + + with pytest.warns(UserWarning): + ot.lowrank.lowrank_sinkhorn(X_s, X_t, 0.1, a, b, stopThr=0, numItermax=1) + + + +@pytest.mark.parametrize(("alpha, rank"),((0.8,2),(0.5,3),(0.2,4))) +def test_lowrank_sinkhorn_alpha_warning(alpha,rank): + # test warning for value of alpha + n = 100 + a = ot.unif(n) + b = ot.unif(n) + + X_s = np.reshape(1.0 * np.arange(n), (n, 1)) + X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) + + with pytest.warns(UserWarning): + ot.lowrank.lowrank_sinkhorn(X_s, X_t, 0.1, a, b, r=rank, alpha=alpha, warn=False) + + + +def test_lowrank_sinkhorn_backends(nx): + # test low rank sinkhorn for different backends + n = 100 + a = ot.unif(n) + b = ot.unif(n) + + X_s = np.reshape(1.0 * np.arange(n), (n, 1)) + X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) + + ab, bb, X_sb, X_tb = nx.from_numpy(a, b, X_s, X_t) + + Q, R, g = nx.to_numpy(ot.lowrank.lowrank_sinkhorn(X_sb, X_tb, ab, bb, 0.1)) + P = np.dot(Q,np.dot(np.diag(1/g),R.T)) + + np.testing.assert_allclose(a, P.sum(1), atol=1e-05) + np.testing.assert_allclose(b, P.sum(0), atol=1e-05) + + + + diff --git a/test/test_solvers.py b/test/test_solvers.py index f0f5b638f..5a05d54cf 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -255,3 +255,103 @@ def test_solve_gromov_not_implemented(nx): ot.solve_gromov(Ca, Cb, reg=1, unbalanced_type='partial', unbalanced=1.5) with pytest.raises(NotImplementedError): ot.solve_gromov(Ca, Cb, reg=1, unbalanced_type='partial', unbalanced=0.5, symmetric=False) + + + + +########################################################################################################### +############################################ WORK IN PROGRESS ############################################# +########################################################################################################### + +def assert_allclose_sol_sample(sol1, sol2): + # test attributes of OTResultLazy class + lst_attr = ['potentials','potential_a', 'potential_b', 'lazy_plan'] + + nx1 = sol1._backend if sol1._backend is not None else ot.backend.NumpyBackend() + nx2 = sol2._backend if sol2._backend is not None else ot.backend.NumpyBackend() + + for attr in lst_attr: + try: + np.allclose(nx1.to_numpy(getattr(sol1, attr)), nx2.to_numpy(getattr(sol2, attr))) + except NotImplementedError: + pass + + +@pytest.mark.parametrize("reg,reg_type,unbalanced,unbalanced_type", itertools.product(lst_reg, lst_reg_type, lst_unbalanced, lst_unbalanced_type)) +def test_solve_sample(nx): + # test solve_sample when is_Lazy = False + n = 100 + X_s = np.reshape(1.0 * np.arange(n), (n, 1)) + X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) + + a = ot.utils.unif(X_s.shape[0]) + b = ot.utils.unif(X_t.shape[0]) + + # solve unif weights + sol0 = ot.solve_sample(X_s, X_t) + + # solve signe weights + sol = ot.solve_sample(X_s, X_t, a, b) + + # check some attributes + sol.potentials + sol.sparse_plan + sol.marginals + sol.status + + assert_allclose_sol(sol0, sol) + + # solve in backend + X_sb, X_tb, ab, bb = nx.from_numpy(X_s, X_t, a, b) + solb = ot.solve_sample(X_sb, X_tb, ab, bb) + + assert_allclose_sol(sol, solb) + + # test not implemented unbalanced and check raise + with pytest.raises(NotImplementedError): + sol0 = ot.solve_sample(X_s, X_t, unbalanced=1, unbalanced_type='cryptic divergence') + + # test not implemented reg_type and check raise + with pytest.raises(NotImplementedError): + sol0 = ot.solve_sample(X_s, X_t, reg=1, reg_type='cryptic divergence') + + + +def test_lazy_solve_sample(nx): + # test solve_sample when is_Lazy = True + n = 100 + X_s = np.reshape(1.0 * np.arange(n), (n, 1)) + X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) + + a = ot.utils.unif(X_s.shape[0]) + b = ot.utils.unif(X_t.shape[0]) + + # solve unif weights + sol0 = ot.solve_sample(X_s, X_t, reg=0.1, is_Lazy=True) # reg != 0 or None since no implementation yet for is_Lazy=True + + # solve signe weights + sol = ot.solve_sample(X_s, X_t, a, b, reg=0.1, is_Lazy=True) + + # check some attributes + sol.potentials + sol.lazy_plan + + assert_allclose_sol_sample(sol0, sol) + + # solve in backend + X_sb, X_tb, ab, bb = nx.from_numpy(X_s, X_t, a, b) + solb = ot.solve_sample(X_sb, X_tb, ab, bb, reg=0.1, is_Lazy=True) + + assert_allclose_sol_sample(sol, solb) + + # test not implemented reg==0 (or None) + balanced and check raise + with pytest.raises(NotImplementedError): + sol0 = ot.solve_sample(X_s, X_t, is_Lazy=True) # reg == 0 (or None) + unbalanced= None are default + + # test not implemented reg==0 (or None) + unbalanced_type and check raise + with pytest.raises(NotImplementedError): + sol0 = ot.solve_sample(X_s, X_t, unbalanced_type="kl", is_Lazy=True) # reg == 0 (or None) is default + + # test not implemented reg != 0 + unbalanced_type and check raise + with pytest.raises(NotImplementedError): + sol0 = ot.solve_sample(X_s, X_t, reg=0.1, unbalanced_type="kl", is_Lazy=True) \ No newline at end of file diff --git a/test/test_utils.py b/test/test_utils.py index 3a9d590ab..942f403ce 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -569,3 +569,4 @@ def test_lowrank_LazyTensor(nx): T = ot.utils.get_lowrank_lazytensor(X1, X2, diag_d, nx=nx) np.testing.assert_allclose(nx.to_numpy(T[:]), nx.to_numpy(T0)) + \ No newline at end of file From 3df3b77de2605d233224e0ccefa6ee127af9f040 Mon Sep 17 00:00:00 2001 From: laudavid Date: Thu, 26 Oct 2023 10:49:23 +0200 Subject: [PATCH 09/22] add import to __init__ --- ot/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ot/__init__.py b/ot/__init__.py index f16b6fcfc..cb00f4553 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -35,6 +35,7 @@ from . import factored from . import solvers from . import gaussian +from . import lowrank # OT functions from .lp import (emd, emd2, emd_1d, emd2_1d, wasserstein_1d, @@ -50,7 +51,8 @@ gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2) from .weak import weak_optimal_transport from .factored import factored_optimal_transport -from .solvers import solve, solve_gromov +from .solvers import solve, solve_gromov, solve_sample +from .lowrank import lowrank_sinkhorn # utils functions from .utils import dist, unif, tic, toc, toq From ab5475b894207f66e017beb07e97ff8da0d381aa Mon Sep 17 00:00:00 2001 From: laudavid Date: Fri, 3 Nov 2023 15:30:51 +0100 Subject: [PATCH 10/22] remove test solve_sample --- test/test_solvers.py | 105 ------------------------------------------- 1 file changed, 105 deletions(-) diff --git a/test/test_solvers.py b/test/test_solvers.py index ff5719251..5e398d732 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -256,108 +256,3 @@ def test_solve_gromov_not_implemented(nx): with pytest.raises(NotImplementedError): ot.solve_gromov(Ca, Cb, reg=1, unbalanced_type='partial', unbalanced=0.5, symmetric=False) - - - - - - - - - -########################################################################################################### -############################################ WORK IN PROGRESS ############################################# -########################################################################################################### - -def assert_allclose_sol_sample(sol1, sol2): - # test attributes of OTResultLazy class - lst_attr = ['potentials','potential_a', 'potential_b', 'lazy_plan'] - - nx1 = sol1._backend if sol1._backend is not None else ot.backend.NumpyBackend() - nx2 = sol2._backend if sol2._backend is not None else ot.backend.NumpyBackend() - - for attr in lst_attr: - try: - np.allclose(nx1.to_numpy(getattr(sol1, attr)), nx2.to_numpy(getattr(sol2, attr))) - except NotImplementedError: - pass - - -@pytest.mark.parametrize("reg,reg_type,unbalanced,unbalanced_type", itertools.product(lst_reg, lst_reg_type, lst_unbalanced, lst_unbalanced_type)) -def test_solve_sample(nx): - # test solve_sample when is_Lazy = False - n = 100 - X_s = np.reshape(1.0 * np.arange(n), (n, 1)) - X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) - - a = ot.utils.unif(X_s.shape[0]) - b = ot.utils.unif(X_t.shape[0]) - - # solve unif weights - sol0 = ot.solve_sample(X_s, X_t) - - # solve signe weights - sol = ot.solve_sample(X_s, X_t, a, b) - - # check some attributes - sol.potentials - sol.sparse_plan - sol.marginals - sol.status - - assert_allclose_sol(sol0, sol) - - # solve in backend - X_sb, X_tb, ab, bb = nx.from_numpy(X_s, X_t, a, b) - solb = ot.solve_sample(X_sb, X_tb, ab, bb) - - assert_allclose_sol(sol, solb) - - # test not implemented unbalanced and check raise - with pytest.raises(NotImplementedError): - sol0 = ot.solve_sample(X_s, X_t, unbalanced=1, unbalanced_type='cryptic divergence') - - # test not implemented reg_type and check raise - with pytest.raises(NotImplementedError): - sol0 = ot.solve_sample(X_s, X_t, reg=1, reg_type='cryptic divergence') - - - -def test_lazy_solve_sample(nx): - # test solve_sample when is_Lazy = True - n = 100 - X_s = np.reshape(1.0 * np.arange(n), (n, 1)) - X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) - - a = ot.utils.unif(X_s.shape[0]) - b = ot.utils.unif(X_t.shape[0]) - - # solve unif weights - sol0 = ot.solve_sample(X_s, X_t, reg=0.1, is_Lazy=True) # reg != 0 or None since no implementation yet for is_Lazy=True - - # solve signe weights - sol = ot.solve_sample(X_s, X_t, a, b, reg=0.1, is_Lazy=True) - - # check some attributes - sol.potentials - sol.lazy_plan - - assert_allclose_sol_sample(sol0, sol) - - # solve in backend - X_sb, X_tb, ab, bb = nx.from_numpy(X_s, X_t, a, b) - solb = ot.solve_sample(X_sb, X_tb, ab, bb, reg=0.1, is_Lazy=True) - - assert_allclose_sol_sample(sol, solb) - - # test not implemented reg==0 (or None) + balanced and check raise - with pytest.raises(NotImplementedError): - sol0 = ot.solve_sample(X_s, X_t, is_Lazy=True) # reg == 0 (or None) + unbalanced= None are default - - # test not implemented reg==0 (or None) + unbalanced_type and check raise - with pytest.raises(NotImplementedError): - sol0 = ot.solve_sample(X_s, X_t, unbalanced_type="kl", is_Lazy=True) # reg == 0 (or None) is default - - # test not implemented reg != 0 + unbalanced_type and check raise - with pytest.raises(NotImplementedError): - sol0 = ot.solve_sample(X_s, X_t, reg=0.1, unbalanced_type="kl", is_Lazy=True) \ No newline at end of file From f1c8cdd9ca88568b035d691a234497cf9f54fab5 Mon Sep 17 00:00:00 2001 From: laudavid Date: Wed, 8 Nov 2023 11:08:25 +0100 Subject: [PATCH 11/22] add value, value_linear, lazy_plan --- ot/lowrank.py | 245 +++++++++++++++++++++++++++++--------------------- 1 file changed, 144 insertions(+), 101 deletions(-) diff --git a/ot/lowrank.py b/ot/lowrank.py index d583f4741..cd060bdd2 100644 --- a/ot/lowrank.py +++ b/ot/lowrank.py @@ -14,17 +14,45 @@ import warnings -from ot.utils import unif +from ot.utils import unif, LazyTensor from ot.backend import get_backend +def compute_lr_cost_matrix(X_s, X_t, nx=None): + """ + Compute low rank decomposition of a sqeuclidean cost matrix. + This function won't work for other metrics. + + See Proposition 1 of the low rank sinkhorn paper + """ + + if nx is None: + nx = get_backend(X_s,X_t) + + ns = X_s.shape[0] + nt = X_t.shape[0] + d = X_s.shape[1] + + # First low rank decomposition of the cost matrix (A) + M1 = nx.zeros((ns,(d+2))) + M1[:,0] = [nx.norm(X_s[i,:])**2 for i in range(ns)] + M1[:,1] = nx.ones(ns) + M1[:,2:] = -2*X_s + + # Second low rank decomposition of the cost matrix (B) + M2 = nx.zeros((nt,(d+2))) + M2[:,0] = nx.ones(nt) + M2[:,1] = [nx.norm(X_t[i,:])**2 for i in range(nt)] + M2[:,2:] = X_t + + return M1, M2 + -################################## LR-DYSKTRA ALGORITHM ########################################## -def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, dykstra_p, stopThr, nx=None): +def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, dykstra_p, stopThr, numItermax, warn, nx=None): """ - Implementation of the Dykstra algorithm for the Low rank sinkhorn solver - + Implementation of the Dykstra algorithm for the Low Rank sinkhorn OT solver. + """ # Get dykstra parameters g, q3_1, q3_2, v1_, v2_, q1, q2, u1, u2, v1, v2 = dykstra_p @@ -36,34 +64,44 @@ def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, dykstra_p, stopThr, nx=None): nx = get_backend(eps1, eps2, eps3, p1, p2, g, q3_1, q3_2, v1_, v2_, q1, q2, u1, u2) - - # ------- Dykstra algorithm ------ - while err > stopThr : - u1 = p1 / nx.dot(eps1, v1_) - u2 = p2 / nx.dot(eps2, v2_) - - g = nx.maximum(alpha, g_ * q3_1) - q3_1 = (g_ * q3_1) / g - g_ = g.copy() - - prod1 = ((v1_ * q1) * nx.dot(eps1.T, u1)) - prod2 = ((v2_ * q2) * nx.dot(eps2.T, u2)) - g = (g_ * q3_2 * prod1 * prod2)**(1/3) - - v1 = g / nx.dot(eps1.T,u1) - v2 = g / nx.dot(eps2.T,u2) - q1 = (v1_ * q1) / v1 - - q2 = (v2_ * q2) / v2 - q3_2 = (g_ * q3_2) / g - - v1_, v2_ = v1.copy(), v2.copy() - g_ = g.copy() - - # Compute error - err1 = nx.sum(nx.abs(u1 * (eps1 @ v1) - p1)) - err2 = nx.sum(nx.abs(u2 * (eps2 @ v2) - p2)) - err = err1 + err2 + # ------------- Dykstra algorithm ---------------- + # see "Algorithm 2 LR-Dykstra" in paper + for ii in range(numItermax): + if err > stopThr: + u1 = p1 / nx.dot(eps1, v1_) + u2 = p2 / nx.dot(eps2, v2_) + + g = nx.maximum(alpha, g_ * q3_1) + q3_1 = (g_ * q3_1) / g + g_ = g.copy() + + prod1 = ((v1_ * q1) * nx.dot(eps1.T, u1)) + prod2 = ((v2_ * q2) * nx.dot(eps2.T, u2)) + g = (g_ * q3_2 * prod1 * prod2)**(1/3) + + v1 = g / nx.dot(eps1.T,u1) + v2 = g / nx.dot(eps2.T,u2) + q1 = (v1_ * q1) / v1 + + q2 = (v2_ * q2) / v2 + q3_2 = (g_ * q3_2) / g + + v1_, v2_ = v1.copy(), v2.copy() + g_ = g.copy() + + # Compute error + err1 = nx.sum(nx.abs(u1 * (eps1 @ v1) - p1)) + err2 = nx.sum(nx.abs(u2 * (eps2 @ v2) - p2)) + err = err1 + err2 + + else: + break + + else: + if warn: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") # Compute low rank matrices Q, R Q = u1[:,None] * eps1 * v1[None,:] @@ -80,12 +118,28 @@ def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, dykstra_p, stopThr, nx=None): def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=2, alpha="auto", - numItermax=1000, stopThr=1e-9, warn=True, verbose=False): #stopThr = 1e-9 + numItermax=1000, stopThr=1e-9, warn=True, shape_plan="auto"): r''' Solve the entropic regularization optimal transport problem under low-nonnegative rank constraints. - This function returns the two low-rank matrix decomposition of the OT plan (Q,R), as well as the weight vector g. + + The function solves the following optimization problem: + + .. math:: + \mathop{\inf}_{Q,R,g \in \mathcal{C(a,b,r)}} \langle C, Q\mathrm{diag}(1/g)R^T \rangle - + \mathrm{reg} \cdot H((Q,R,g)) + + where : + - :math:`C` is the (`dim_a`, `dim_b`) metric cost matrix + - :math:`H` is the entropic regularization term + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target + weights (histograms, both sum to 1) + (add r, C(a,b,r), Q, R, g) !!! + + The entropy H is to be understood as that of the values of the three respective + entropies evaluated for each term. + Parameters ---------- X_s : array-like, shape (n_samples_a, dim) @@ -106,13 +160,21 @@ def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=2, alpha="auto", Max number of iterations stopThr : float, optional Stop threshold on error (>0) - warn: - - verbose: + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. + shape_plan : tuple + Shape of the lazy_plan Returns ------- + lazy_plan : + OT plan in a LazyTensor object + value : + Optimal value of the optimization problem, if reg=0 it will return the full value + if reg != 0, will return LazyTensor object + value_linear : + Linear OT loss with the optimal OT Q : array-like, shape (n_samples_a, r) First low-rank matrix decomposition of the OT plan R: array-like, shape (n_samples_b, r) @@ -127,44 +189,44 @@ def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=2, alpha="auto", Low-Rank Sinkhorn Factorization. arXiv preprint arXiv:2103.04737. ''' + # POT backend nx = get_backend(X_s, X_t) ns, nt = X_s.shape[0], X_t.shape[0] + + # Initialize weights a, b if a is None: a = unif(ns, type_as=X_s) if b is None: b = unif(nt, type_as=X_t) - - d = X_s.shape[1] - - # First low rank decomposition of the cost matrix (A) - M1 = nx.zeros((ns,(d+2))) - M1[:,0] = [nx.norm(X_s[i,:])**2 for i in range(ns)] - M1[:,1] = nx.ones(ns) - M1[:,2:] = -2*X_s - # Second low rank decomposition of the cost matrix (B) - M2 = nx.zeros((nt,(d+2))) - M2[:,0] = nx.ones(nt) - M2[:,1] = [nx.norm(X_t[i,:])**2 for i in range(nt)] - M2[:,2:] = X_t - - # Compute rank + # Low rank decomposition of the sqeuclidean cost matrix (M1, M2) + M1, M2 = compute_lr_cost_matrix(X_s, X_t, nx=None) + + # Compute rank (not sure ?) rank = min(ns, nt, rank) r = rank - # Alpha: lower bound for 1/rank + # Check values of alpha, the lower bound for 1/rank (see ) if alpha == 'auto': - alpha = 1e-3 # no convergence with alpha = 1 / (r+1) - + alpha = 1e-10 + if (1/r < alpha) or (alpha < 0): warnings.warn("The provided alpha value might lead to instabilities.") - # Compute gamma + + # Compute gamma (see Proposition 4 of low rank sinkhorn paper) L = nx.sqrt(3*(2/(alpha**4))*((nx.norm(M1)*nx.norm(M2))**2) + (reg + (2/(alpha**3))*(nx.norm(M1)*nx.norm(M2)))**2) gamma = 1/(2*L) + + # Shape_plan default + if shape_plan == "auto": + shape_plan = (ns,nt) - # Initialize the low rank matrices Q, R, g + + # ----------- Initialisation of LR sinkhorn + Dykstra -------------- + + # Initialize the low rank matrices Q, R, g (not sure ?) Q, R, g = nx.ones((ns,r)), nx.ones((nt,r)), nx.ones(r) # Initialize parameters for Dykstra algorithm @@ -174,74 +236,55 @@ def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=2, alpha="auto", v1_, v2_ = nx.ones(r), nx.ones(r) q1, q2 = nx.ones(r), nx.ones(r) dykstra_p = [g, q3_1, q3_2, v1_, v2_, q1, q2, u1, u2, v1, v2] + k = 100 # not specified in paper ? + + # ----------------- Low rank algorithm ------------------ - for ii in range(numItermax): + for ii in range(k): + # Compute the C*R dot matrix using the lr decomposition of C CR_ = nx.dot(M2.T, R) CR = nx.dot(M1, CR_) + # Compute the C.t * Q dot matrix using the lr decomposition of C CQ_ = nx.dot(M1.T, Q) CQ = nx.dot(M2, CQ_) - diag_g = (1/g)[:,None] + #diag_g = (1/g)[:,None] + diag_g = nx.diag(1/g) eps1 = nx.exp(-gamma*(nx.dot(CR,diag_g)) - ((gamma*reg)-1)*nx.log(Q)) eps2 = nx.exp(-gamma*(nx.dot(CQ,diag_g)) - ((gamma*reg)-1)*nx.log(R)) omega = nx.diag(nx.dot(Q.T, CR)) eps3 = nx.exp(gamma*omega/(g**2) - (gamma*reg - 1)*nx.log(g)) - Q, R, dykstra_p = LR_Dysktra(eps1, eps2, eps3, a, b, alpha, dykstra_p, stopThr, nx) + Q, R, dykstra_p = LR_Dysktra(eps1, eps2, eps3, a, b, alpha, dykstra_p, stopThr, numItermax, warn, nx) g = dykstra_p[0] - # if verbose: - # if ii % 200 == 0: - # print( - # '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - # print('{:5d}|{:8e}|'.format(ii, err)) - # else: - # if warn: - # warnings.warn("Sinkhorn did not converge. You might want to " - # "increase the number of iterations `numItermax` " - # "or the regularization parameter `reg`.") + # ----------------- Compute lazy_plan, value and value_linear ------------------ - # Compute OT value using trace formula for scalar product + # Compute lazy plan (using LazyTensor class) + plan1 = Q + plan2 = nx.dot(nx.diag(1/g),R.T) # low memory cost since shape r*m + compute_plan = lambda i,j,P1,P2: nx.dot(P1[i,:], P2[:,j]) # function for LazyTensor + lazy_plan = LazyTensor(shape_plan, compute_plan, P1=plan1, P2=plan2) + + # Compute value_linear (using trace formula) v1 = nx.dot(Q.T,M1) v2 = nx.dot(R,nx.dot(diag_g.T,v1)) - value_linear = nx.sum(nx.diag(nx.dot(v2,M2.T))) # compute Trace - - #value = value_linear + reg * nx.sum(plan * nx.log(plan + 1e-16)) - - #value - - return value_linear, Q, R, g - - - - - -############################################################################ -## Test with X_s, X_t from ot.datasets -############################################################################# - -import numpy as np -import ot - -Xs, _ = ot.datasets.make_data_classif('3gauss', n=1000) -Xt, _ = ot.datasets.make_data_classif('3gauss2', n=1500) + value_linear = nx.sum(nx.diag(nx.dot(M2.T, v2))) -ns = Xs.shape[0] -nt = Xt.shape[0] + # Compute value with entropy reg + reg_Q = nx.sum(Q * nx.log(Q + 1e-16)) + reg_g = nx.sum(g * nx.log(g + 1e-16)) + reg_R = nx.sum(R * nx.log(R + 1e-16)) + value = value_linear + reg * (reg_Q + reg_g + reg_R) -a = unif(ns) -b = unif(nt) + return value, value_linear, lazy_plan, Q, R, g -Q, R, g = lowrank_sinkhorn(Xs, Xt, reg=0.1, verbose=True, numItermax=20) -M = ot.dist(Xs,Xt) -P = np.dot(Q,np.dot(np.diag(1/g),R.T)) -print(np.sum(P)) From df01cff10cc449ba3bde5f1a1809a335f5f1e9e8 Mon Sep 17 00:00:00 2001 From: laudavid Date: Wed, 8 Nov 2023 12:43:48 +0100 Subject: [PATCH 12/22] add comments to lr algorithm --- ot/lowrank.py | 161 ++++++++++++++++++++++++++------------------------ 1 file changed, 83 insertions(+), 78 deletions(-) diff --git a/ot/lowrank.py b/ot/lowrank.py index cd060bdd2..4c7f4c2bb 100644 --- a/ot/lowrank.py +++ b/ot/lowrank.py @@ -7,12 +7,6 @@ # License: MIT License - -################################################################################################################# -############################################## WORK IN PROGRESS ################################################# -################################################################################################################# - - import warnings from ot.utils import unif, LazyTensor from ot.backend import get_backend @@ -23,7 +17,12 @@ def compute_lr_cost_matrix(X_s, X_t, nx=None): Compute low rank decomposition of a sqeuclidean cost matrix. This function won't work for other metrics. - See Proposition 1 of the low rank sinkhorn paper + See "Section 3.5, proposition 1" of the paper + + References + ---------- + .. Scetbon, M., Cuturi, M., & Peyré, G (2021). + Low-Rank Sinkhorn Factorization. arXiv preprint arXiv:2103.04737. """ if nx is None: @@ -49,43 +48,62 @@ def compute_lr_cost_matrix(X_s, X_t, nx=None): -def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, dykstra_p, stopThr, numItermax, warn, nx=None): +def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, stopThr, numItermax, warn, nx=None): """ Implementation of the Dykstra algorithm for the Low Rank sinkhorn OT solver. + References + ---------- + .. Scetbon, M., Cuturi, M., & Peyré, G (2021). + Low-Rank Sinkhorn Factorization. arXiv preprint arXiv:2103.04737. + """ - # Get dykstra parameters - g, q3_1, q3_2, v1_, v2_, q1, q2, u1, u2, v1, v2 = dykstra_p - g_ = eps3.copy() - err = 1 - # POT backend if needed + # POT backend if None if nx is None: - nx = get_backend(eps1, eps2, eps3, p1, p2, - g, q3_1, q3_2, v1_, v2_, q1, q2, u1, u2) + nx = get_backend(eps1, eps2, eps3, p1, p2) + + + # ----------------- Initialisation of Dykstra algorithm ----------------- + r = len(eps3) # rank + g_ = eps3.copy() # \tilde{g} + q3_1, q3_2 = nx.ones(r), nx.ones(r) # q^{(3)}_1, q^{(3)}_2 + v1_, v2_ = nx.ones(r), nx.ones(r) # \tilde{v}^{(1)}, \tilde{v}^{(2)} + q1, q2 = nx.ones(r), nx.ones(r) # q^{(1)}, q^{(2)} + err = 1 # initial error + - # ------------- Dykstra algorithm ---------------- - # see "Algorithm 2 LR-Dykstra" in paper + # --------------------- Dykstra algorithm ------------------------- + + # See Section 3.3 - "Algorithm 2 LR-Dykstra" in paper + for ii in range(numItermax): if err > stopThr: + + # Compute u^{(1)} and u^{(2)} u1 = p1 / nx.dot(eps1, v1_) u2 = p2 / nx.dot(eps2, v2_) + # Compute g, g^{(3)}_1 and update \tilde{g} g = nx.maximum(alpha, g_ * q3_1) q3_1 = (g_ * q3_1) / g g_ = g.copy() + # Compute new value of g with \prod prod1 = ((v1_ * q1) * nx.dot(eps1.T, u1)) prod2 = ((v2_ * q2) * nx.dot(eps2.T, u2)) g = (g_ * q3_2 * prod1 * prod2)**(1/3) + # Compute v^{(1)} and v^{(2)} v1 = g / nx.dot(eps1.T,u1) v2 = g / nx.dot(eps2.T,u2) - q1 = (v1_ * q1) / v1 + # Compute q^{(1)}, q^{(2)} and q^{(3)}_2 + q1 = (v1_ * q1) / v1 q2 = (v2_ * q2) / v2 q3_2 = (g_ * q3_2) / g + # Update values of \tilde{v}^{(1)}, \tilde{v}^{(2)} and \tilde{g} v1_, v2_ = v1.copy(), v2.copy() g_ = g.copy() @@ -100,16 +118,13 @@ def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, dykstra_p, stopThr, numItermax, else: if warn: warnings.warn("Sinkhorn did not converge. You might want to " - "increase the number of iterations `numItermax` " - "or the regularization parameter `reg`.") + "increase the number of iterations `numItermax` ") # Compute low rank matrices Q, R Q = u1[:,None] * eps1 * v1[None,:] R = u2[:,None] * eps2 * v2[None,:] - dykstra_p = [g, q3_1, q3_2, v1_, v2_, q1, q2, u1, u2, v1, v2] - - return Q, R, dykstra_p + return Q, R, g @@ -118,7 +133,7 @@ def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, dykstra_p, stopThr, numItermax, def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=2, alpha="auto", - numItermax=1000, stopThr=1e-9, warn=True, shape_plan="auto"): + numItermax=10000, stopThr=1e-9, warn=True, shape_plan="auto"): r''' Solve the entropic regularization optimal transport problem under low-nonnegative rank constraints. @@ -126,18 +141,20 @@ def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=2, alpha="auto", The function solves the following optimization problem: .. math:: - \mathop{\inf}_{Q,R,g \in \mathcal{C(a,b,r)}} \langle C, Q\mathrm{diag}(1/g)R^T \rangle - + \mathop{\inf_{(Q,R,g) \in \mathcal{C(a,b,r)}}} \langle C, Q\mathrm{diag}(1/g)R^T \rangle - \mathrm{reg} \cdot H((Q,R,g)) - + where : - :math:`C` is the (`dim_a`, `dim_b`) metric cost matrix - - :math:`H` is the entropic regularization term - - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target - weights (histograms, both sum to 1) - (add r, C(a,b,r), Q, R, g) !!! - - The entropy H is to be understood as that of the values of the three respective - entropies evaluated for each term. + - :math:`H((Q,R,g))` is the values of the three respective entropies evaluated for each term. + - :math: `Q` and `R` are the low-rank matrix decomposition of the OT plan + - :math: `g` is the weight vector for the low-rank decomposition of the OT plan + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1) + - :math: `r` is the rank of the OT plan + - :math: `\mathcal{C(a,b,r)}` are the low-rank couplings of the OT problem + \mathcal{C(a,b,r)} = \mathcal{C_1(a,b,r)} \cap \mathcal{C_2(r)} with + \mathcal{C_1(a,b,r)} = \{ (Q,R,g) s.t Q\mathbb{1}_r = a, R^T \mathbb{1}_m = b \} + \mathcal{C_2(r)} = \{ (Q,R,g) s.t Q\mathbb{1}_n = R^T \mathbb{1}_m = g \} Parameters @@ -152,9 +169,9 @@ def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=2, alpha="auto", samples weights in the target domain reg : float, optional Regularization term >0 - rank: int, optional + rank: int, default "auto" Nonnegative rank of the OT plan - alpha: int, optional + alpha: int, default "auto" Lower bound for the weight vector g (>0 and <1/r) numItermax : int, optional Max number of iterations @@ -168,12 +185,12 @@ def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=2, alpha="auto", Returns ------- - lazy_plan : - OT plan in a LazyTensor object - value : - Optimal value of the optimization problem, if reg=0 it will return the full value - if reg != 0, will return LazyTensor object - value_linear : + lazy_plan : LazyTensor() + OT plan in a LazyTensor object of shape (shape_plan) + See :any:`LazyTensor` for more information. + value : float + Optimal value of the optimization problem, + value_linear : float Linear OT loss with the optimal OT Q : array-like, shape (n_samples_a, r) First low-rank matrix decomposition of the OT plan @@ -200,46 +217,36 @@ def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=2, alpha="auto", if b is None: b = unif(nt, type_as=X_t) - # Low rank decomposition of the sqeuclidean cost matrix (M1, M2) - M1, M2 = compute_lr_cost_matrix(X_s, X_t, nx=None) - - # Compute rank (not sure ?) - rank = min(ns, nt, rank) - r = rank - - # Check values of alpha, the lower bound for 1/rank (see ) + # Compute rank (see Section 3.1, def 1) + if rank == "auto": + r = min(ns, nt) + + # Check values of alpha, the lower bound for 1/rank + # (see "Section 3.2: The Low-rank OT Problem (LOT)" in the paper) if alpha == 'auto': alpha = 1e-10 if (1/r < alpha) or (alpha < 0): - warnings.warn("The provided alpha value might lead to instabilities.") + warnings.warn("The provided alpha value might lead to instabilities.") + # Default value for shape tensor parameter in LazyTensor + if shape_plan == "auto": + shape_plan = (ns,nt) - # Compute gamma (see Proposition 4 of low rank sinkhorn paper) + # Low rank decomposition of the sqeuclidean cost matrix (A, B) + M1, M2 = compute_lr_cost_matrix(X_s, X_t, nx=None) + + # Compute gamma (see "Section 3.4, proposition 4" in the paper) L = nx.sqrt(3*(2/(alpha**4))*((nx.norm(M1)*nx.norm(M2))**2) + (reg + (2/(alpha**3))*(nx.norm(M1)*nx.norm(M2)))**2) gamma = 1/(2*L) - - # Shape_plan default - if shape_plan == "auto": - shape_plan = (ns,nt) - - - # ----------- Initialisation of LR sinkhorn + Dykstra -------------- - # Initialize the low rank matrices Q, R, g (not sure ?) + # Initialize the low rank matrices Q, R, g Q, R, g = nx.ones((ns,r)), nx.ones((nt,r)), nx.ones(r) - - # Initialize parameters for Dykstra algorithm - q3_1, q3_2 = nx.ones(r), nx.ones(r) - u1, u2 = nx.ones(ns), nx.ones(nt) - v1, v2 = nx.ones(r), nx.ones(r) - v1_, v2_ = nx.ones(r), nx.ones(r) - q1, q2 = nx.ones(r), nx.ones(r) - dykstra_p = [g, q3_1, q3_2, v1_, v2_, q1, q2, u1, u2, v1, v2] k = 100 # not specified in paper ? - # ----------------- Low rank algorithm ------------------ + # -------------------------- Low rank algorithm ------------------------------ + # see "Section 3.3, Algorithm 3 LOT" in the paper for ii in range(k): # Compute the C*R dot matrix using the lr decomposition of C @@ -250,24 +257,22 @@ def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=2, alpha="auto", CQ_ = nx.dot(M1.T, Q) CQ = nx.dot(M2, CQ_) - #diag_g = (1/g)[:,None] diag_g = nx.diag(1/g) - + eps1 = nx.exp(-gamma*(nx.dot(CR,diag_g)) - ((gamma*reg)-1)*nx.log(Q)) eps2 = nx.exp(-gamma*(nx.dot(CQ,diag_g)) - ((gamma*reg)-1)*nx.log(R)) omega = nx.diag(nx.dot(Q.T, CR)) eps3 = nx.exp(gamma*omega/(g**2) - (gamma*reg - 1)*nx.log(g)) - Q, R, dykstra_p = LR_Dysktra(eps1, eps2, eps3, a, b, alpha, dykstra_p, stopThr, numItermax, warn, nx) - g = dykstra_p[0] - + Q, R, g = LR_Dysktra(eps1, eps2, eps3, a, b, alpha, stopThr, numItermax, warn, nx) # ----------------- Compute lazy_plan, value and value_linear ------------------ + # see "Section 3.2: The Low-rank OT Problem" in the paper # Compute lazy plan (using LazyTensor class) plan1 = Q - plan2 = nx.dot(nx.diag(1/g),R.T) # low memory cost since shape r*m + plan2 = nx.dot(nx.diag(1/g),R.T) # low memory cost since shape (r*m) compute_plan = lambda i,j,P1,P2: nx.dot(P1[i,:], P2[:,j]) # function for LazyTensor lazy_plan = LazyTensor(shape_plan, compute_plan, P1=plan1, P2=plan2) @@ -276,10 +281,10 @@ def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=2, alpha="auto", v2 = nx.dot(R,nx.dot(diag_g.T,v1)) value_linear = nx.sum(nx.diag(nx.dot(M2.T, v2))) - # Compute value with entropy reg - reg_Q = nx.sum(Q * nx.log(Q + 1e-16)) - reg_g = nx.sum(g * nx.log(g + 1e-16)) - reg_R = nx.sum(R * nx.log(R + 1e-16)) + # Compute value with entropy reg (entropy of Q, R, g must be computed separatly) + reg_Q = nx.sum(Q * nx.log(Q + 1e-16)) # entropy for Q + reg_g = nx.sum(g * nx.log(g + 1e-16)) # entropy for g + reg_R = nx.sum(R * nx.log(R + 1e-16)) # entropy for R value = value_linear + reg * (reg_Q + reg_g + reg_R) return value, value_linear, lazy_plan, Q, R, g From 5bc9de96719de312bd41ef4df9f8df499fbe4e1b Mon Sep 17 00:00:00 2001 From: laudavid Date: Thu, 9 Nov 2023 17:06:51 +0100 Subject: [PATCH 13/22] modify test functions + add comments to lowrank --- ot/lowrank.py | 6 ++-- test/test_lowrank.py | 72 ++++++++++++++++++++------------------------ 2 files changed, 37 insertions(+), 41 deletions(-) diff --git a/ot/lowrank.py b/ot/lowrank.py index 4c7f4c2bb..af9f06d62 100644 --- a/ot/lowrank.py +++ b/ot/lowrank.py @@ -229,7 +229,7 @@ def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=2, alpha="auto", if (1/r < alpha) or (alpha < 0): warnings.warn("The provided alpha value might lead to instabilities.") - # Default value for shape tensor parameter in LazyTensor + # Default value for shape tensor parameter in LazyTensor if shape_plan == "auto": shape_plan = (ns,nt) @@ -245,6 +245,7 @@ def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=2, alpha="auto", k = 100 # not specified in paper ? + # -------------------------- Low rank algorithm ------------------------------ # see "Section 3.3, Algorithm 3 LOT" in the paper @@ -267,6 +268,7 @@ def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=2, alpha="auto", Q, R, g = LR_Dysktra(eps1, eps2, eps3, a, b, alpha, stopThr, numItermax, warn, nx) + # ----------------- Compute lazy_plan, value and value_linear ------------------ # see "Section 3.2: The Low-rank OT Problem" in the paper @@ -281,7 +283,7 @@ def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=2, alpha="auto", v2 = nx.dot(R,nx.dot(diag_g.T,v1)) value_linear = nx.sum(nx.diag(nx.dot(M2.T, v2))) - # Compute value with entropy reg (entropy of Q, R, g must be computed separatly) + # Compute value with entropy reg (entropy of Q, R, g must be computed separatly, see "Section 3.2" in the paper) reg_Q = nx.sum(Q * nx.log(Q + 1e-16)) # entropy for Q reg_g = nx.sum(g * nx.log(g + 1e-16)) # entropy for g reg_R = nx.sum(R * nx.log(R + 1e-16)) # entropy for R diff --git a/test/test_lowrank.py b/test/test_lowrank.py index 7d90ce9ef..7d326388b 100644 --- a/test/test_lowrank.py +++ b/test/test_lowrank.py @@ -11,45 +11,39 @@ from itertools import product -def test_LR_Dykstra(): - # test for LR_Dykstra algorithm ? catch nan values ? - pass - - -# @pytest.mark.parametrize("verbose, warn", product([True, False], [True, False])) -# def test_lowrank_sinkhorn(verbose, warn): -# # test low rank sinkhorn -# n = 100 -# a = ot.unif(n) -# b = ot.unif(n) - -# X_s = np.reshape(1.0 * np.arange(n), (n, 1)) -# X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) - -# Q_sqe, R_sqe, g_sqe = ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, 0.1) -# P_sqe = np.dot(Q_sqe,np.dot(np.diag(1/g_sqe),R_sqe.T)) - -# Q_m, R_m, g_m = ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, 0.1, metric='euclidean') -# P_m = np.dot(Q_m,np.dot(np.diag(1/g_m),R_m.T)) - -# # check constraints -# np.testing.assert_allclose( -# a, P_sqe.sum(1), atol=1e-05) # metric sqeuclidian -# np.testing.assert_allclose( -# b, P_sqe.sum(0), atol=1e-05) # metric sqeuclidian -# np.testing.assert_allclose( -# a, P_m.sum(1), atol=1e-05) # metric euclidian -# np.testing.assert_allclose( -# b, P_m.sum(0), atol=1e-05) # metric euclidian + +################################################## WORK IN PROGRESS ####################################################### + +# Add test functions for each function in lowrank.py file ? + +def test_lowrank_sinkhorn(verbose, warn): + # test low rank sinkhorn + n = 100 + a = ot.unif(n) + b = ot.unif(n) + + X_s = np.reshape(1.0 * np.arange(n), (n, 1)) + X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) + + # what to test for value, value_linear, Q, R and g ? + value, value_linear, lazy_plan, Q, R, g = ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, 0.1) + P = lazy_plan[:] # default shape for lazy_plan in lowrank_sinkhorn is (ns, nt) + + # check constraints for P + np.testing.assert_allclose( + a, P.sum(1), atol=1e-05) + np.testing.assert_allclose( + b, P.sum(0), atol=1e-05) -# with pytest.warns(UserWarning): -# ot.lowrank.lowrank_sinkhorn(X_s, X_t, 0.1, a, b, stopThr=0, numItermax=1) + # check warn parameter when Dykstra algorithm doesn't converge + with pytest.warns(UserWarning): + ot.lowrank.lowrank_sinkhorn(X_s, X_t, 0.1, a, b, stopThr=0, numItermax=1) @pytest.mark.parametrize(("alpha, rank"),((0.8,2),(0.5,3),(0.2,4))) def test_lowrank_sinkhorn_alpha_warning(alpha,rank): - # test warning for value of alpha + # Test warning for value of alpha n = 100 a = ot.unif(n) b = ot.unif(n) @@ -58,12 +52,12 @@ def test_lowrank_sinkhorn_alpha_warning(alpha,rank): X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) with pytest.warns(UserWarning): - ot.lowrank.lowrank_sinkhorn(X_s, X_t, 0.1, a, b, r=rank, alpha=alpha, warn=False) + ot.lowrank.lowrank_sinkhorn(X_s, X_t, 0.1, a, b, r=rank, alpha=alpha, warn=False) # remove warning for lack of convergence def test_lowrank_sinkhorn_backends(nx): - # test low rank sinkhorn for different backends + # Test low rank sinkhorn for different backends n = 100 a = ot.unif(n) b = ot.unif(n) @@ -73,11 +67,11 @@ def test_lowrank_sinkhorn_backends(nx): ab, bb, X_sb, X_tb = nx.from_numpy(a, b, X_s, X_t) - Q, R, g = nx.to_numpy(ot.lowrank.lowrank_sinkhorn(X_sb, X_tb, ab, bb, 0.1)) - P = np.dot(Q,np.dot(np.diag(1/g),R.T)) + value, value_linear, lazy_plan, Q, R, g = ot.lowrank.lowrank_sinkhorn(X_sb, X_tb, ab, bb, 0.1) + P = lazy_plan[:] # default shape for lazy_plan in lowrank_sinkhorn is (ns, nt) - np.testing.assert_allclose(a, P.sum(1), atol=1e-05) - np.testing.assert_allclose(b, P.sum(0), atol=1e-05) + np.testing.assert_allclose(ab, P.sum(1), atol=1e-05) + np.testing.assert_allclose(bb, P.sum(0), atol=1e-05) From 6040e6ffacb91673b4628118c23cc5b2ab44e22d Mon Sep 17 00:00:00 2001 From: laudavid Date: Thu, 9 Nov 2023 17:09:40 +0100 Subject: [PATCH 14/22] modify __init__ with lowrank --- ot/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ot/__init__.py b/ot/__init__.py index 3a4f21083..4aba450af 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -36,7 +36,6 @@ from . import solvers from . import gaussian from . import lowrank -from . import lowrank # OT functions from .lp import (emd, emd2, emd_1d, emd2_1d, wasserstein_1d, From a7fdffd4656ecd6b5e80d37392d040734eefd462 Mon Sep 17 00:00:00 2001 From: laudavid Date: Tue, 14 Nov 2023 14:01:29 +0100 Subject: [PATCH 15/22] debug lowrank + test --- ot/lowrank.py | 6 +++--- test/test_lowrank.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/ot/lowrank.py b/ot/lowrank.py index af9f06d62..f2160e2ce 100644 --- a/ot/lowrank.py +++ b/ot/lowrank.py @@ -8,8 +8,8 @@ import warnings -from ot.utils import unif, LazyTensor -from ot.backend import get_backend +from .utils import unif, LazyTensor +from .backend import get_backend def compute_lr_cost_matrix(X_s, X_t, nx=None): @@ -218,6 +218,7 @@ def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=2, alpha="auto", b = unif(nt, type_as=X_t) # Compute rank (see Section 3.1, def 1) + r = rank if rank == "auto": r = min(ns, nt) @@ -294,4 +295,3 @@ def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=2, alpha="auto", - diff --git a/test/test_lowrank.py b/test/test_lowrank.py index 7d326388b..b1aa64e32 100644 --- a/test/test_lowrank.py +++ b/test/test_lowrank.py @@ -16,7 +16,7 @@ # Add test functions for each function in lowrank.py file ? -def test_lowrank_sinkhorn(verbose, warn): +def test_lowrank_sinkhorn(): # test low rank sinkhorn n = 100 a = ot.unif(n) From d90c186668e66d0c0ec8b496e08a589089a436f2 Mon Sep 17 00:00:00 2001 From: laudavid Date: Tue, 14 Nov 2023 14:13:32 +0100 Subject: [PATCH 16/22] debug test function low_rank --- test/test_lowrank.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_lowrank.py b/test/test_lowrank.py index b1aa64e32..22c6c2f88 100644 --- a/test/test_lowrank.py +++ b/test/test_lowrank.py @@ -23,7 +23,7 @@ def test_lowrank_sinkhorn(): b = ot.unif(n) X_s = np.reshape(1.0 * np.arange(n), (n, 1)) - X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) + X_t = np.reshape(1.0 * np.arange(n), (n, 1)) # what to test for value, value_linear, Q, R and g ? value, value_linear, lazy_plan, Q, R, g = ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, 0.1) @@ -52,7 +52,7 @@ def test_lowrank_sinkhorn_alpha_warning(alpha,rank): X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) with pytest.warns(UserWarning): - ot.lowrank.lowrank_sinkhorn(X_s, X_t, 0.1, a, b, r=rank, alpha=alpha, warn=False) # remove warning for lack of convergence + ot.lowrank.lowrank_sinkhorn(X_s, X_t, 0.1, a, b, rank=rank, alpha=alpha, warn=False) # remove warning for lack of convergence From ea3a3e0948ec60c7c8b94df1ca36497aa8138408 Mon Sep 17 00:00:00 2001 From: laudavid Date: Tue, 14 Nov 2023 14:31:35 +0100 Subject: [PATCH 17/22] error test --- ot/lowrank.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ot/lowrank.py b/ot/lowrank.py index f2160e2ce..c6653f6c8 100644 --- a/ot/lowrank.py +++ b/ot/lowrank.py @@ -132,7 +132,7 @@ def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, stopThr, numItermax, warn, nx=No #################################### LOW RANK SINKHORN ALGORITHM ######################################### -def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=2, alpha="auto", +def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank="auto", alpha="auto", numItermax=10000, stopThr=1e-9, warn=True, shape_plan="auto"): r''' From 165e8f5352dc8688a7f9df16e915d8c0df2d597f Mon Sep 17 00:00:00 2001 From: laudavid Date: Wed, 15 Nov 2023 16:03:06 +0100 Subject: [PATCH 18/22] final debug of lowrank + add new test functions --- ot/lowrank.py | 15 +++++----- test/test_lowrank.py | 66 +++++++++++++++++++++++++++----------------- 2 files changed, 48 insertions(+), 33 deletions(-) diff --git a/ot/lowrank.py b/ot/lowrank.py index c6653f6c8..ed526cf1c 100644 --- a/ot/lowrank.py +++ b/ot/lowrank.py @@ -171,7 +171,7 @@ def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank="auto", alpha="auto", Regularization term >0 rank: int, default "auto" Nonnegative rank of the OT plan - alpha: int, default "auto" + alpha: int, default "auto" (1e-10) Lower bound for the weight vector g (>0 and <1/r) numItermax : int, optional Max number of iterations @@ -221,14 +221,14 @@ def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank="auto", alpha="auto", r = rank if rank == "auto": r = min(ns, nt) + + if alpha == "auto": + alpha = 1e-10 - # Check values of alpha, the lower bound for 1/rank + # Dykstra algorithm won't converge if 1/rank < alpha (alpha is the lower bound for 1/rank) # (see "Section 3.2: The Low-rank OT Problem (LOT)" in the paper) - if alpha == 'auto': - alpha = 1e-10 - - if (1/r < alpha) or (alpha < 0): - warnings.warn("The provided alpha value might lead to instabilities.") + if 1/r < alpha : + raise ValueError("alpha ({a}) should be smaller than 1/rank ({r}) for the Dykstra algorithm to converge.".format(a=alpha,r=1/rank)) # Default value for shape tensor parameter in LazyTensor if shape_plan == "auto": @@ -295,3 +295,4 @@ def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank="auto", alpha="auto", + diff --git a/test/test_lowrank.py b/test/test_lowrank.py index 22c6c2f88..820c147bf 100644 --- a/test/test_lowrank.py +++ b/test/test_lowrank.py @@ -8,13 +8,23 @@ import ot import numpy as np import pytest -from itertools import product ################################################## WORK IN PROGRESS ####################################################### -# Add test functions for each function in lowrank.py file ? +def test_compute_lr_cost_matrix(): + # test computation of low rank cost matrices M1 and M2 + n = 100 + X_s = np.reshape(1.0 * np.arange(2*n), (n, 2)) + X_t = np.reshape(1.0 * np.arange(2*n), (n, 2)) + + M1, M2 = ot.lowrank.compute_lr_cost_matrix(X_s, X_t) + M = ot.dist(X_s, X_t, metric="sqeuclidean") # original cost matrix + + np.testing.assert_allclose( + np.dot(M1,M2.T), M, atol=1e-05) + def test_lowrank_sinkhorn(): # test low rank sinkhorn @@ -27,21 +37,29 @@ def test_lowrank_sinkhorn(): # what to test for value, value_linear, Q, R and g ? value, value_linear, lazy_plan, Q, R, g = ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, 0.1) + P = lazy_plan[:] # default shape for lazy_plan in lowrank_sinkhorn is (ns, nt) # check constraints for P - np.testing.assert_allclose( - a, P.sum(1), atol=1e-05) - np.testing.assert_allclose( - b, P.sum(0), atol=1e-05) + np.testing.assert_allclose(a, P.sum(1), atol=1e-05) + np.testing.assert_allclose(b, P.sum(0), atol=1e-05) + + # check if lazy_plan is equal to the fully computed plan + P_true = np.dot(Q,np.dot(np.diag(1/g),R.T)) + np.testing.assert_allclose(P, P_true, atol=1e-05) + + # check if value_linear is correct with its original formula + M = ot.dist(X_s, X_t, metric="sqeuclidean") + value_linear_true = np.sum(M * P_true) + np.testing.assert_allclose(value_linear, value_linear_true, atol=1e-05) # check warn parameter when Dykstra algorithm doesn't converge with pytest.warns(UserWarning): - ot.lowrank.lowrank_sinkhorn(X_s, X_t, 0.1, a, b, stopThr=0, numItermax=1) + ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, reg=0.1, stopThr=0, numItermax=1) -@pytest.mark.parametrize(("alpha, rank"),((0.8,2),(0.5,3),(0.2,4))) +@pytest.mark.parametrize(("alpha, rank"),((0.8,2),(0.5,3),(0.2,6))) def test_lowrank_sinkhorn_alpha_warning(alpha,rank): # Test warning for value of alpha n = 100 @@ -51,28 +69,24 @@ def test_lowrank_sinkhorn_alpha_warning(alpha,rank): X_s = np.reshape(1.0 * np.arange(n), (n, 1)) X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) - with pytest.warns(UserWarning): - ot.lowrank.lowrank_sinkhorn(X_s, X_t, 0.1, a, b, rank=rank, alpha=alpha, warn=False) # remove warning for lack of convergence - - - -def test_lowrank_sinkhorn_backends(nx): - # Test low rank sinkhorn for different backends - n = 100 - a = ot.unif(n) - b = ot.unif(n) + with pytest.raises(ValueError): + ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, reg=0.1, rank=rank, alpha=alpha, warn=False) - X_s = np.reshape(1.0 * np.arange(n), (n, 1)) - X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) - ab, bb, X_sb, X_tb = nx.from_numpy(a, b, X_s, X_t) - - value, value_linear, lazy_plan, Q, R, g = ot.lowrank.lowrank_sinkhorn(X_sb, X_tb, ab, bb, 0.1) - P = lazy_plan[:] # default shape for lazy_plan in lowrank_sinkhorn is (ns, nt) - np.testing.assert_allclose(ab, P.sum(1), atol=1e-05) - np.testing.assert_allclose(bb, P.sum(0), atol=1e-05) +# def test_lowrank_sinkhorn_backends(nx): +# # Test low rank sinkhorn for different backends +# n = 100 +# a = ot.unif(n) +# b = ot.unif(n) +# X_s = np.reshape(1.0 * np.arange(n), (n, 1)) +# X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) +# ab, bb, X_sb, X_tb = nx.from_numpy(a, b, X_s, X_t) +# value, value_linear, lazy_plan, Q, R, g = lowrank_sinkhorn(X_sb, X_tb, ab, bb, reg=0.1) +# P = lazy_plan[:] # default shape for lazy_plan in lowrank_sinkhorn is (ns, nt) +# np.testing.assert_allclose(ab, P.sum(1), atol=1e-05) +# np.testing.assert_allclose(bb, P.sum(0), atol=1e-05) \ No newline at end of file From 8c6ac67a715096f9d8b1cc41e0dca73dbfff8774 Mon Sep 17 00:00:00 2001 From: laudavid Date: Fri, 24 Nov 2023 15:43:40 +0100 Subject: [PATCH 19/22] Debug tests + add lowrank to solve_sample --- CONTRIBUTORS.md | 1 + README.md | 2 ++ RELEASES.md | 1 + ot/__init__.py | 3 +- ot/lowrank.py | 67 ++++++++++++++++++++++---------------------- ot/solvers.py | 21 ++++++++++++++ test/test_lowrank.py | 37 ++++++++++++------------ test/test_solvers.py | 6 ++-- 8 files changed, 84 insertions(+), 54 deletions(-) diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index c7916f50a..5cc34f38b 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -45,6 +45,7 @@ The contributors to this library are: * [Ronak Mehta](https://ronakrm.github.io) (Efficient Discrete Multi Marginal Optimal Transport Regularization) * [Xizheng Yu](https://github.com/x12hengyu) (Efficient Discrete Multi Marginal Optimal Transport Regularization) * [Sonia Mazelet](https://github.com/SoniaMaz8) (Template based GNN layers) +* [Laurène David](https://github.com/laudavid) (Low rank sinkhorn) ## Acknowledgments diff --git a/README.md b/README.md index 84b3cf0ee..94bf97043 100644 --- a/README.md +++ b/README.md @@ -343,3 +343,5 @@ distances between Gaussian distributions](https://hal.science/hal-03197398v2/fil [61] Charlier, B., Feydy, J., Glaunes, J. A., Collin, F. D., & Durif, G. (2021). [Kernel operations on the gpu, with autodiff, without memory overflows](https://www.jmlr.org/papers/volume22/20-275/20-275.pdf). The Journal of Machine Learning Research, 22(1), 3457-3462. [62] H. Van Assel, C. Vincent-Cuaz, T. Vayer, R. Flamary, N. Courty (2023). [Interpolating between Clustering and Dimensionality Reduction with Gromov-Wasserstein](https://arxiv.org/pdf/2310.03398.pdf). NeurIPS 2023 Workshop Optimal Transport and Machine Learning. + +[63] Scetbon, M., Cuturi, M., & Peyré, G. (2021). [Low-Rank Sinkhorn Factorization](https://arxiv.org/pdf/2103.04737.pdf). diff --git a/RELEASES.md b/RELEASES.md index 349c56214..befda9d30 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -20,6 +20,7 @@ + Wrapper for `geomloss`` solver on empirical samples (PR #571) + Add `stop_criterion` feature to (un)regularized (f)gw barycenter solvers (PR #578) + Add `fixed_structure` and `fixed_features` to entropic fgw barycenter solver (PR #578) ++ Added support for [Low-Rank Sinkhorn Factorization](https://arxiv.org/pdf/2103.04737.pdf) (PR #568) #### Closed issues - Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504) diff --git a/ot/__init__.py b/ot/__init__.py index bd26d96e3..99d075e5a 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -35,6 +35,7 @@ from . import factored from . import solvers from . import gaussian +from . import lowrank # OT functions @@ -51,7 +52,7 @@ gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2) from .weak import weak_optimal_transport from .factored import factored_optimal_transport -from .solvers import solve, solve_gromov +from .solvers import solve, solve_gromov, solve_sample from .lowrank import lowrank_sinkhorn # utils functions diff --git a/ot/lowrank.py b/ot/lowrank.py index ed526cf1c..b2e443b74 100644 --- a/ot/lowrank.py +++ b/ot/lowrank.py @@ -8,11 +8,11 @@ import warnings -from .utils import unif, LazyTensor +from .utils import unif, get_lowrank_lazytensor from .backend import get_backend -def compute_lr_cost_matrix(X_s, X_t, nx=None): +def compute_lr_sqeuclidean_matrix(X_s, X_t, nx=None): """ Compute low rank decomposition of a sqeuclidean cost matrix. This function won't work for other metrics. @@ -21,8 +21,8 @@ def compute_lr_cost_matrix(X_s, X_t, nx=None): References ---------- - .. Scetbon, M., Cuturi, M., & Peyré, G (2021). - Low-Rank Sinkhorn Factorization. arXiv preprint arXiv:2103.04737. + .. [63] Scetbon, M., Cuturi, M., & Peyré, G (2021). + "Low-Rank Sinkhorn Factorization" arXiv preprint arXiv:2103.04737. """ if nx is None: @@ -54,9 +54,9 @@ def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, stopThr, numItermax, warn, nx=No References ---------- - .. Scetbon, M., Cuturi, M., & Peyré, G (2021). - Low-Rank Sinkhorn Factorization. arXiv preprint arXiv:2103.04737. - + .. [63] Scetbon, M., Cuturi, M., & Peyré, G (2021). + "Low-Rank Sinkhorn Factorization" arXiv preprint arXiv:2103.04737. + """ # POT backend if None @@ -66,7 +66,7 @@ def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, stopThr, numItermax, warn, nx=No # ----------------- Initialisation of Dykstra algorithm ----------------- r = len(eps3) # rank - g_ = eps3.copy() # \tilde{g} + g_ = nx.copy(eps3) # \tilde{g} q3_1, q3_2 = nx.ones(r), nx.ones(r) # q^{(3)}_1, q^{(3)}_2 v1_, v2_ = nx.ones(r), nx.ones(r) # \tilde{v}^{(1)}, \tilde{v}^{(2)} q1, q2 = nx.ones(r), nx.ones(r) # q^{(1)}, q^{(2)} @@ -87,7 +87,7 @@ def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, stopThr, numItermax, warn, nx=No # Compute g, g^{(3)}_1 and update \tilde{g} g = nx.maximum(alpha, g_ * q3_1) q3_1 = (g_ * q3_1) / g - g_ = g.copy() + g_ = nx.copy(g) # Compute new value of g with \prod prod1 = ((v1_ * q1) * nx.dot(eps1.T, u1)) @@ -104,8 +104,8 @@ def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, stopThr, numItermax, warn, nx=No q3_2 = (g_ * q3_2) / g # Update values of \tilde{v}^{(1)}, \tilde{v}^{(2)} and \tilde{g} - v1_, v2_ = v1.copy(), v2.copy() - g_ = g.copy() + v1_, v2_ = nx.copy(v1), nx.copy(v2) + g_ = nx.copy(g) # Compute error err1 = nx.sum(nx.abs(u1 * (eps1 @ v1) - p1)) @@ -133,7 +133,7 @@ def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, stopThr, numItermax, warn, nx=No def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank="auto", alpha="auto", - numItermax=10000, stopThr=1e-9, warn=True, shape_plan="auto"): + numItermax=10000, stopThr=1e-9, warn=True, log=False): r''' Solve the entropic regularization optimal transport problem under low-nonnegative rank constraints. @@ -179,8 +179,8 @@ def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank="auto", alpha="auto", Stop threshold on error (>0) warn : bool, optional if True, raises a warning if the algorithm doesn't convergence. - shape_plan : tuple - Shape of the lazy_plan + log : bool, optional + record log if True Returns @@ -202,8 +202,8 @@ def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank="auto", alpha="auto", References ---------- - .. Scetbon, M., Cuturi, M., & Peyré, G (2021). - Low-Rank Sinkhorn Factorization. arXiv preprint arXiv:2103.04737. + .. [63] Scetbon, M., Cuturi, M., & Peyré, G (2021). + "Low-Rank Sinkhorn Factorization" arXiv preprint arXiv:2103.04737. ''' @@ -230,12 +230,8 @@ def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank="auto", alpha="auto", if 1/r < alpha : raise ValueError("alpha ({a}) should be smaller than 1/rank ({r}) for the Dykstra algorithm to converge.".format(a=alpha,r=1/rank)) - # Default value for shape tensor parameter in LazyTensor - if shape_plan == "auto": - shape_plan = (ns,nt) - # Low rank decomposition of the sqeuclidean cost matrix (A, B) - M1, M2 = compute_lr_cost_matrix(X_s, X_t, nx=None) + M1, M2 = compute_lr_sqeuclidean_matrix(X_s, X_t, nx=None) # Compute gamma (see "Section 3.4, proposition 4" in the paper) L = nx.sqrt(3*(2/(alpha**4))*((nx.norm(M1)*nx.norm(M2))**2) + (reg + (2/(alpha**3))*(nx.norm(M1)*nx.norm(M2)))**2) @@ -246,7 +242,6 @@ def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank="auto", alpha="auto", k = 100 # not specified in paper ? - # -------------------------- Low rank algorithm ------------------------------ # see "Section 3.3, Algorithm 3 LOT" in the paper @@ -259,29 +254,27 @@ def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank="auto", alpha="auto", CQ_ = nx.dot(M1.T, Q) CQ = nx.dot(M2, CQ_) - diag_g = nx.diag(1/g) - - eps1 = nx.exp(-gamma*(nx.dot(CR,diag_g)) - ((gamma*reg)-1)*nx.log(Q)) - eps2 = nx.exp(-gamma*(nx.dot(CQ,diag_g)) - ((gamma*reg)-1)*nx.log(R)) + diag_g = (1/g)[None,:] + + eps1 = nx.exp(-gamma*(CR*diag_g) - ((gamma*reg)-1)*nx.log(Q)) + eps2 = nx.exp(-gamma*(CQ*diag_g) - ((gamma*reg)-1)*nx.log(R)) omega = nx.diag(nx.dot(Q.T, CR)) eps3 = nx.exp(gamma*omega/(g**2) - (gamma*reg - 1)*nx.log(g)) Q, R, g = LR_Dysktra(eps1, eps2, eps3, a, b, alpha, stopThr, numItermax, warn, nx) - + Q = Q+1e-16 + R = R+1e-16 # ----------------- Compute lazy_plan, value and value_linear ------------------ # see "Section 3.2: The Low-rank OT Problem" in the paper # Compute lazy plan (using LazyTensor class) - plan1 = Q - plan2 = nx.dot(nx.diag(1/g),R.T) # low memory cost since shape (r*m) - compute_plan = lambda i,j,P1,P2: nx.dot(P1[i,:], P2[:,j]) # function for LazyTensor - lazy_plan = LazyTensor(shape_plan, compute_plan, P1=plan1, P2=plan2) + lazy_plan = get_lowrank_lazytensor(Q, R, 1/g) # Compute value_linear (using trace formula) v1 = nx.dot(Q.T,M1) - v2 = nx.dot(R,nx.dot(diag_g.T,v1)) + v2 = nx.dot(R,(v1.T*diag_g).T) value_linear = nx.sum(nx.diag(nx.dot(M2.T, v2))) # Compute value with entropy reg (entropy of Q, R, g must be computed separatly, see "Section 3.2" in the paper) @@ -290,7 +283,15 @@ def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank="auto", alpha="auto", reg_R = nx.sum(R * nx.log(R + 1e-16)) # entropy for R value = value_linear + reg * (reg_Q + reg_g + reg_R) - return value, value_linear, lazy_plan, Q, R, g + if log: + dict_log = dict() + dict_log["value"] = value + dict_log["value_linear"] = value_linear + dict_log["lazy_plan"] = lazy_plan + + return Q, R, g, dict_log + + return Q, R, g diff --git a/ot/solvers.py b/ot/solvers.py index a41762a5c..958b951d1 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -22,6 +22,7 @@ from .partial import partial_gromov_wasserstein2, entropic_partial_gromov_wasserstein2 from .gaussian import empirical_bures_wasserstein_distance from .factored import factored_optimal_transport +from .lowrank import lowrank_sinkhorn lst_method_lazy = ['1d', 'gaussian', 'lowrank', 'factored', 'geomloss', 'geomloss_auto', 'geomloss_tensorized', 'geomloss_online', 'geomloss_multiscale'] @@ -1247,6 +1248,26 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t lazy_plan = log['lazy_plan'] if not lazy0: # store plan if not lazy plan = lazy_plan[:] + + elif method == "lowrank": + + if not metric.lower() in ['sqeuclidean']: + raise (NotImplementedError('Not implemented metric="{}"'.format(metric))) + + if max_iter is None: + max_iter = 1000 + if tol is None: + tol = 1e-9 + if reg is None: + reg = 0 + + Q, R, g, log = lowrank_sinkhorn(X_a, X_b, reg=reg, a=a, b=b, numItermax=max_iter, stopThr=tol, log=True) + value = log['value'] + value_linear = log['value_linear'] + lazy_plan = log['lazy_plan'] + if not lazy0: # store plan if not lazy + plan = lazy_plan[:] + elif method.startswith('geomloss'): # Geomloss solver for entropi OT diff --git a/test/test_lowrank.py b/test/test_lowrank.py index 43c3655f4..e3ffe6df3 100644 --- a/test/test_lowrank.py +++ b/test/test_lowrank.py @@ -13,13 +13,13 @@ ################################################## WORK IN PROGRESS ####################################################### -def test_compute_lr_cost_matrix(): +def test_compute_lr_sqeuclidean_matrix(): # test computation of low rank cost matrices M1 and M2 n = 100 X_s = np.reshape(1.0 * np.arange(2*n), (n, 2)) X_t = np.reshape(1.0 * np.arange(2*n), (n, 2)) - M1, M2 = ot.lowrank.compute_lr_cost_matrix(X_s, X_t) + M1, M2 = ot.lowrank.compute_lr_sqeuclidean_matrix(X_s, X_t) M = ot.dist(X_s, X_t, metric="sqeuclidean") # original cost matrix np.testing.assert_allclose( @@ -35,8 +35,9 @@ def test_lowrank_sinkhorn(): X_s = np.reshape(1.0 * np.arange(n), (n, 1)) X_t = np.reshape(1.0 * np.arange(n), (n, 1)) - value, value_linear, lazy_plan, Q, R, g = ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, 0.1) - P = lazy_plan[:] # default shape for lazy_plan in lowrank_sinkhorn is (ns, nt) + Q, R, g, log = ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, reg=0.1, log=True) + P = log["lazy_plan"][:] + value_linear = log["value_linear"] # check constraints for P np.testing.assert_allclose(a, P.sum(1), atol=1e-05) @@ -58,7 +59,7 @@ def test_lowrank_sinkhorn(): @pytest.mark.parametrize(("alpha, rank"),((0.8,2),(0.5,3),(0.2,6))) -def test_lowrank_sinkhorn_alpha_warning(alpha,rank): +def test_lowrank_sinkhorn_alpha_error(alpha,rank): # Test warning for value of alpha n = 100 a = ot.unif(n) @@ -71,20 +72,20 @@ def test_lowrank_sinkhorn_alpha_warning(alpha,rank): ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, reg=0.1, rank=rank, alpha=alpha, warn=False) +def test_lowrank_sinkhorn_backends(nx): + # Test low rank sinkhorn for different backends + n = 100 + a = ot.unif(n) + b = ot.unif(n) -# def test_lowrank_sinkhorn_backends(nx): -# # Test low rank sinkhorn for different backends -# n = 100 -# a = ot.unif(n) -# b = ot.unif(n) - -# X_s = np.reshape(1.0 * np.arange(n), (n, 1)) -# X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) + X_s = np.reshape(1.0 * np.arange(n), (n, 1)) + X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) -# ab, bb, X_sb, X_tb = nx.from_numpy(a, b, X_s, X_t) + ab, bb, X_sb, X_tb = nx.from_numpy(a, b, X_s, X_t) -# value, value_linear, lazy_plan, Q, R, g = lowrank_sinkhorn(X_sb, X_tb, ab, bb, reg=0.1) -# P = lazy_plan[:] # default shape for lazy_plan in lowrank_sinkhorn is (ns, nt) + Q, R, g, log = ot.lowrank.lowrank_sinkhorn(X_sb, X_tb, ab, bb, reg=0.1, log=True) + lazy_plan = log["lazy_plan"] + P = lazy_plan[:] -# np.testing.assert_allclose(ab, P.sum(1), atol=1e-05) -# np.testing.assert_allclose(bb, P.sum(0), atol=1e-05) \ No newline at end of file + np.testing.assert_allclose(ab, P.sum(1), atol=1e-05) + np.testing.assert_allclose(bb, P.sum(0), atol=1e-05) \ No newline at end of file diff --git a/test/test_solvers.py b/test/test_solvers.py index bf07b7af8..7cb26a096 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -30,12 +30,14 @@ {'method': 'gaussian'}, {'method': 'gaussian', 'reg': 1}, {'method': 'factored', 'rank': 10}, + {'method': 'lowrank', 'reg':0.1} ] lst_parameters_solve_sample_NotImplemented = [ {'method': '1d', 'metric': 'any other one'}, # fail 1d on weird metrics {'method': 'gaussian', 'metric': 'euclidean'}, # fail gaussian on metric not euclidean - {'method': 'factored', 'metric': 'euclidean'}, # fail factored on metric not euclidean + {'method': 'factored', 'metric': 'euclidean'}, # fail factored on metric not euclidean + {"method": 'lowrank', 'metric':'euclidean'}, # fail lowrank on metric not euclidean {'lazy': True}, # fail lazy for non regularized {'lazy': True, 'unbalanced': 1}, # fail lazy for non regularized unbalanced {'lazy': True, 'reg': 1, 'unbalanced': 1}, # fail lazy for unbalanced and regularized @@ -413,7 +415,7 @@ def test_solve_sample_methods(nx, method_params): assert_allclose_sol(sol, solb) sol2 = ot.solve_sample(x, x, **method_params) - if method_params['method'] != 'factored': + if method_params['method'] not in ['factored','lowrank']: np.testing.assert_allclose(sol2.value, 0) From bc7af6b33faa929dc2747a4018d6a9b07a4d71ee Mon Sep 17 00:00:00 2001 From: laudavid Date: Sat, 25 Nov 2023 13:06:07 +0100 Subject: [PATCH 20/22] fix torch backend for lowrank --- ot/lowrank.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/ot/lowrank.py b/ot/lowrank.py index b2e443b74..7b44ad4a0 100644 --- a/ot/lowrank.py +++ b/ot/lowrank.py @@ -8,7 +8,7 @@ import warnings -from .utils import unif, get_lowrank_lazytensor +from .utils import unif, list_to_array, get_lowrank_lazytensor from .backend import get_backend @@ -33,15 +33,17 @@ def compute_lr_sqeuclidean_matrix(X_s, X_t, nx=None): d = X_s.shape[1] # First low rank decomposition of the cost matrix (A) - M1 = nx.zeros((ns,(d+2))) - M1[:,0] = [nx.norm(X_s[i,:])**2 for i in range(ns)] - M1[:,1] = nx.ones(ns) + M1 = nx.zeros((ns,(d+2)), type_as=X_s) + norm_M1 = list_to_array([nx.norm(X_s[i,:])**2 for i in range(ns)]) + M1[:,0] = nx.from_numpy(norm_M1) + M1[:,1] = nx.ones(ns, type_as=X_s) M1[:,2:] = -2*X_s # Second low rank decomposition of the cost matrix (B) - M2 = nx.zeros((nt,(d+2))) - M2[:,0] = nx.ones(nt) - M2[:,1] = [nx.norm(X_t[i,:])**2 for i in range(nt)] + M2 = nx.zeros((nt,(d+2)), type_as=X_s) + M2[:,0] = nx.ones(nt, type_as=X_s) + norm_M2 = list_to_array([nx.norm(X_t[i,:])**2 for i in range(nt)]) + M2[:,1] = nx.from_numpy(norm_M2) M2[:,2:] = X_t return M1, M2 @@ -67,9 +69,9 @@ def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, stopThr, numItermax, warn, nx=No # ----------------- Initialisation of Dykstra algorithm ----------------- r = len(eps3) # rank g_ = nx.copy(eps3) # \tilde{g} - q3_1, q3_2 = nx.ones(r), nx.ones(r) # q^{(3)}_1, q^{(3)}_2 - v1_, v2_ = nx.ones(r), nx.ones(r) # \tilde{v}^{(1)}, \tilde{v}^{(2)} - q1, q2 = nx.ones(r), nx.ones(r) # q^{(1)}, q^{(2)} + q3_1, q3_2 = nx.ones(r, type_as=p1), nx.ones(r, type_as=p1) # q^{(3)}_1, q^{(3)}_2 + v1_, v2_ = nx.ones(r, type_as=p1), nx.ones(r, type_as=p1) # \tilde{v}^{(1)}, \tilde{v}^{(2)} + q1, q2 = nx.ones(r, type_as=p1), nx.ones(r, type_as=p1) # q^{(1)}, q^{(2)} err = 1 # initial error @@ -238,7 +240,7 @@ def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank="auto", alpha="auto", gamma = 1/(2*L) # Initialize the low rank matrices Q, R, g - Q, R, g = nx.ones((ns,r)), nx.ones((nt,r)), nx.ones(r) + Q, R, g = nx.ones((ns,r), type_as=a), nx.ones((nt,r), type_as=a), nx.ones(r, type_as=a) k = 100 # not specified in paper ? @@ -295,5 +297,3 @@ def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank="auto", alpha="auto", - - From b40705c463d363777f1d79c94bebdc14e8e1c33d Mon Sep 17 00:00:00 2001 From: laudavid Date: Tue, 28 Nov 2023 13:41:10 +0100 Subject: [PATCH 21/22] fix jax backend and skip tf --- ot/lowrank.py | 185 ++++++++++++++++++++----------------------- test/test_lowrank.py | 53 ++++++------- 2 files changed, 113 insertions(+), 125 deletions(-) diff --git a/ot/lowrank.py b/ot/lowrank.py index 7b44ad4a0..ad804efd8 100644 --- a/ot/lowrank.py +++ b/ot/lowrank.py @@ -14,43 +14,37 @@ def compute_lr_sqeuclidean_matrix(X_s, X_t, nx=None): """ - Compute low rank decomposition of a sqeuclidean cost matrix. - This function won't work for other metrics. + Compute low rank decomposition of a sqeuclidean cost matrix. + This function won't work for other metrics. See "Section 3.5, proposition 1" of the paper References ---------- .. [63] Scetbon, M., Cuturi, M., & Peyré, G (2021). - "Low-Rank Sinkhorn Factorization" arXiv preprint arXiv:2103.04737. + "Low-Rank Sinkhorn Factorization" arXiv preprint arXiv:2103.04737. """ if nx is None: - nx = get_backend(X_s,X_t) - + nx = get_backend(X_s, X_t) + ns = X_s.shape[0] nt = X_t.shape[0] - d = X_s.shape[1] # First low rank decomposition of the cost matrix (A) - M1 = nx.zeros((ns,(d+2)), type_as=X_s) - norm_M1 = list_to_array([nx.norm(X_s[i,:])**2 for i in range(ns)]) - M1[:,0] = nx.from_numpy(norm_M1) - M1[:,1] = nx.ones(ns, type_as=X_s) - M1[:,2:] = -2*X_s + array1 = nx.reshape(nx.sum(X_s**2, 1), (-1, 1)) + array2 = nx.reshape(nx.ones(ns, type_as=X_s), (-1, 1)) + M1 = nx.concatenate((array1, array2, -2 * X_s), axis=1) # Second low rank decomposition of the cost matrix (B) - M2 = nx.zeros((nt,(d+2)), type_as=X_s) - M2[:,0] = nx.ones(nt, type_as=X_s) - norm_M2 = list_to_array([nx.norm(X_t[i,:])**2 for i in range(nt)]) - M2[:,1] = nx.from_numpy(norm_M2) - M2[:,2:] = X_t + array1 = nx.reshape(nx.ones(nt, type_as=X_s), (-1, 1)) + array2 = nx.reshape(nx.sum(X_t**2, 1), (-1, 1)) + M2 = nx.concatenate((array1, array2, X_t), axis=1) return M1, M2 - -def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, stopThr, numItermax, warn, nx=None): +def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, stopThr, numItermax, warn, nx=None): """ Implementation of the Dykstra algorithm for the Low Rank sinkhorn OT solver. @@ -58,30 +52,27 @@ def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, stopThr, numItermax, warn, nx=No ---------- .. [63] Scetbon, M., Cuturi, M., & Peyré, G (2021). "Low-Rank Sinkhorn Factorization" arXiv preprint arXiv:2103.04737. - + """ # POT backend if None if nx is None: nx = get_backend(eps1, eps2, eps3, p1, p2) - # ----------------- Initialisation of Dykstra algorithm ----------------- - r = len(eps3) # rank - g_ = nx.copy(eps3) # \tilde{g} - q3_1, q3_2 = nx.ones(r, type_as=p1), nx.ones(r, type_as=p1) # q^{(3)}_1, q^{(3)}_2 - v1_, v2_ = nx.ones(r, type_as=p1), nx.ones(r, type_as=p1) # \tilde{v}^{(1)}, \tilde{v}^{(2)} - q1, q2 = nx.ones(r, type_as=p1), nx.ones(r, type_as=p1) # q^{(1)}, q^{(2)} - err = 1 # initial error - + r = len(eps3) # rank + g_ = nx.copy(eps3) # \tilde{g} + q3_1, q3_2 = nx.ones(r, type_as=p1), nx.ones(r, type_as=p1) # q^{(3)}_1, q^{(3)}_2 + v1_, v2_ = nx.ones(r, type_as=p1), nx.ones(r, type_as=p1) # \tilde{v}^{(1)}, \tilde{v}^{(2)} + q1, q2 = nx.ones(r, type_as=p1), nx.ones(r, type_as=p1) # q^{(1)}, q^{(2)} + err = 1 # initial error # --------------------- Dykstra algorithm ------------------------- - + # See Section 3.3 - "Algorithm 2 LR-Dykstra" in paper - - for ii in range(numItermax): - if err > stopThr: + for ii in range(numItermax): + if err > stopThr: # Compute u^{(1)} and u^{(2)} u1 = p1 / nx.dot(eps1, v1_) u2 = p2 / nx.dot(eps2, v2_) @@ -92,19 +83,19 @@ def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, stopThr, numItermax, warn, nx=No g_ = nx.copy(g) # Compute new value of g with \prod - prod1 = ((v1_ * q1) * nx.dot(eps1.T, u1)) - prod2 = ((v2_ * q2) * nx.dot(eps2.T, u2)) - g = (g_ * q3_2 * prod1 * prod2)**(1/3) + prod1 = (v1_ * q1) * nx.dot(eps1.T, u1) + prod2 = (v2_ * q2) * nx.dot(eps2.T, u2) + g = (g_ * q3_2 * prod1 * prod2) ** (1 / 3) # Compute v^{(1)} and v^{(2)} - v1 = g / nx.dot(eps1.T,u1) - v2 = g / nx.dot(eps2.T,u2) + v1 = g / nx.dot(eps1.T, u1) + v2 = g / nx.dot(eps2.T, u2) # Compute q^{(1)}, q^{(2)} and q^{(3)}_2 q1 = (v1_ * q1) / v1 q2 = (v2_ * q2) / v2 q3_2 = (g_ * q3_2) / g - + # Update values of \tilde{v}^{(1)}, \tilde{v}^{(2)} and \tilde{g} v1_, v2_ = nx.copy(v1), nx.copy(v2) g_ = nx.copy(g) @@ -116,36 +107,32 @@ def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, stopThr, numItermax, warn, nx=No else: break - - else: + + else: if warn: - warnings.warn("Sinkhorn did not converge. You might want to " - "increase the number of iterations `numItermax` ") + warnings.warn( + "Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + ) # Compute low rank matrices Q, R - Q = u1[:,None] * eps1 * v1[None,:] - R = u2[:,None] * eps2 * v2[None,:] + Q = u1[:, None] * eps1 * v1[None, :] + R = u2[:, None] * eps2 * v2[None, :] return Q, R, g - - -#################################### LOW RANK SINKHORN ALGORITHM ######################################### - - -def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank="auto", alpha="auto", - numItermax=10000, stopThr=1e-9, warn=True, log=False): - - r''' +def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank="auto", alpha="auto", + numItermax=1000, stopThr=1e-9, warn=True, log=False): + r""" Solve the entropic regularization optimal transport problem under low-nonnegative rank constraints. The function solves the following optimization problem: .. math:: - \mathop{\inf_{(Q,R,g) \in \mathcal{C(a,b,r)}}} \langle C, Q\mathrm{diag}(1/g)R^T \rangle - + \mathop{\inf_{(Q,R,g) \in \mathcal{C(a,b,r)}}} \langle C, Q\mathrm{diag}(1/g)R^T \rangle - \mathrm{reg} \cdot H((Q,R,g)) - + where : - :math:`C` is the (`dim_a`, `dim_b`) metric cost matrix - :math:`H((Q,R,g))` is the values of the three respective entropies evaluated for each term. @@ -153,11 +140,11 @@ def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank="auto", alpha="auto", - :math: `g` is the weight vector for the low-rank decomposition of the OT plan - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1) - :math: `r` is the rank of the OT plan - - :math: `\mathcal{C(a,b,r)}` are the low-rank couplings of the OT problem - \mathcal{C(a,b,r)} = \mathcal{C_1(a,b,r)} \cap \mathcal{C_2(r)} with + - :math: `\mathcal{C(a,b,r)}` are the low-rank couplings of the OT problem + \mathcal{C(a,b,r)} = \mathcal{C_1(a,b,r)} \cap \mathcal{C_2(r)} with \mathcal{C_1(a,b,r)} = \{ (Q,R,g) s.t Q\mathbb{1}_r = a, R^T \mathbb{1}_m = b \} \mathcal{C_2(r)} = \{ (Q,R,g) s.t Q\mathbb{1}_n = R^T \mathbb{1}_m = g \} - + Parameters ---------- @@ -184,30 +171,30 @@ def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank="auto", alpha="auto", log : bool, optional record log if True - + Returns ------- lazy_plan : LazyTensor() - OT plan in a LazyTensor object of shape (shape_plan) + OT plan in a LazyTensor object of shape (shape_plan) See :any:`LazyTensor` for more information. value : float - Optimal value of the optimization problem, + Optimal value of the optimization problem value_linear : float - Linear OT loss with the optimal OT + Linear OT loss with the optimal OT Q : array-like, shape (n_samples_a, r) - First low-rank matrix decomposition of the OT plan + First low-rank matrix decomposition of the OT plan R: array-like, shape (n_samples_b, r) Second low-rank matrix decomposition of the OT plan g : array-like, shape (r, ) Weight vector for the low-rank decomposition of the OT plan - - + + References ---------- .. [63] Scetbon, M., Cuturi, M., & Peyré, G (2021). "Low-Rank Sinkhorn Factorization" arXiv preprint arXiv:2103.04737. - ''' + """ # POT backend nx = get_backend(X_s, X_t) @@ -223,66 +210,72 @@ def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank="auto", alpha="auto", r = rank if rank == "auto": r = min(ns, nt) - + if alpha == "auto": alpha = 1e-10 # Dykstra algorithm won't converge if 1/rank < alpha (alpha is the lower bound for 1/rank) # (see "Section 3.2: The Low-rank OT Problem (LOT)" in the paper) - if 1/r < alpha : - raise ValueError("alpha ({a}) should be smaller than 1/rank ({r}) for the Dykstra algorithm to converge.".format(a=alpha,r=1/rank)) + if 1 / r < alpha: + raise ValueError("alpha ({a}) should be smaller than 1/rank ({r}) for the Dykstra algorithm to converge.".format( + a=alpha, r=1 / rank)) # Low rank decomposition of the sqeuclidean cost matrix (A, B) M1, M2 = compute_lr_sqeuclidean_matrix(X_s, X_t, nx=None) # Compute gamma (see "Section 3.4, proposition 4" in the paper) - L = nx.sqrt(3*(2/(alpha**4))*((nx.norm(M1)*nx.norm(M2))**2) + (reg + (2/(alpha**3))*(nx.norm(M1)*nx.norm(M2)))**2) - gamma = 1/(2*L) - - # Initialize the low rank matrices Q, R, g - Q, R, g = nx.ones((ns,r), type_as=a), nx.ones((nt,r), type_as=a), nx.ones(r, type_as=a) - k = 100 # not specified in paper ? - + L = nx.sqrt( + 3 * (2 / (alpha**4)) * ((nx.norm(M1) * nx.norm(M2)) ** 2) + + (reg + (2 / (alpha**3)) * (nx.norm(M1) * nx.norm(M2))) ** 2 + ) + gamma = 1 / (2 * L) + + # Initialize the low rank matrices Q, R, g + Q = nx.ones((ns, r), type_as=a) + R = nx.ones((nt, r), type_as=a) + g = nx.ones(r, type_as=a) + k = 100 # -------------------------- Low rank algorithm ------------------------------ # see "Section 3.3, Algorithm 3 LOT" in the paper - for ii in range(k): + for ii in range(k): # Compute the C*R dot matrix using the lr decomposition of C CR_ = nx.dot(M2.T, R) - CR = nx.dot(M1, CR_) - + CR = nx.dot(M1, CR_) + # Compute the C.t * Q dot matrix using the lr decomposition of C CQ_ = nx.dot(M1.T, Q) CQ = nx.dot(M2, CQ_) - - diag_g = (1/g)[None,:] - eps1 = nx.exp(-gamma*(CR*diag_g) - ((gamma*reg)-1)*nx.log(Q)) - eps2 = nx.exp(-gamma*(CQ*diag_g) - ((gamma*reg)-1)*nx.log(R)) - omega = nx.diag(nx.dot(Q.T, CR)) - eps3 = nx.exp(gamma*omega/(g**2) - (gamma*reg - 1)*nx.log(g)) + diag_g = (1 / g)[None, :] - Q, R, g = LR_Dysktra(eps1, eps2, eps3, a, b, alpha, stopThr, numItermax, warn, nx) - Q = Q+1e-16 - R = R+1e-16 + eps1 = nx.exp(-gamma * (CR * diag_g) - ((gamma * reg) - 1) * nx.log(Q)) + eps2 = nx.exp(-gamma * (CQ * diag_g) - ((gamma * reg) - 1) * nx.log(R)) + omega = nx.diag(nx.dot(Q.T, CR)) + eps3 = nx.exp(gamma * omega / (g**2) - (gamma * reg - 1) * nx.log(g)) + Q, R, g = LR_Dysktra( + eps1, eps2, eps3, a, b, alpha, stopThr, numItermax, warn, nx + ) + Q = Q + 1e-16 + R = R + 1e-16 # ----------------- Compute lazy_plan, value and value_linear ------------------ # see "Section 3.2: The Low-rank OT Problem" in the paper # Compute lazy plan (using LazyTensor class) - lazy_plan = get_lowrank_lazytensor(Q, R, 1/g) - + lazy_plan = get_lowrank_lazytensor(Q, R, 1 / g) + # Compute value_linear (using trace formula) - v1 = nx.dot(Q.T,M1) - v2 = nx.dot(R,(v1.T*diag_g).T) + v1 = nx.dot(Q.T, M1) + v2 = nx.dot(R, (v1.T * diag_g).T) value_linear = nx.sum(nx.diag(nx.dot(M2.T, v2))) # Compute value with entropy reg (entropy of Q, R, g must be computed separatly, see "Section 3.2" in the paper) - reg_Q = nx.sum(Q * nx.log(Q + 1e-16)) # entropy for Q - reg_g = nx.sum(g * nx.log(g + 1e-16)) # entropy for g - reg_R = nx.sum(R * nx.log(R + 1e-16)) # entropy for R + reg_Q = nx.sum(Q * nx.log(Q + 1e-16)) # entropy for Q + reg_g = nx.sum(g * nx.log(g + 1e-16)) # entropy for g + reg_R = nx.sum(R * nx.log(R + 1e-16)) # entropy for R value = value_linear + reg * (reg_Q + reg_g + reg_R) if log: @@ -290,10 +283,8 @@ def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank="auto", alpha="auto", dict_log["value"] = value dict_log["value_linear"] = value_linear dict_log["lazy_plan"] = lazy_plan - + return Q, R, g, dict_log return Q, R, g - - diff --git a/test/test_lowrank.py b/test/test_lowrank.py index e3ffe6df3..65f76a77b 100644 --- a/test/test_lowrank.py +++ b/test/test_lowrank.py @@ -1,29 +1,24 @@ -##################################################################################################### -####################################### WORK IN PROGRESS ############################################ -##################################################################################################### - - """ Test for low rank sinkhorn solvers """ +# Author: Laurène DAVID +# +# License: MIT License + import ot import numpy as np import pytest - -################################################## WORK IN PROGRESS ####################################################### - def test_compute_lr_sqeuclidean_matrix(): # test computation of low rank cost matrices M1 and M2 n = 100 - X_s = np.reshape(1.0 * np.arange(2*n), (n, 2)) - X_t = np.reshape(1.0 * np.arange(2*n), (n, 2)) + X_s = np.reshape(1.0 * np.arange(2 * n), (n, 2)) + X_t = np.reshape(1.0 * np.arange(2 * n), (n, 2)) M1, M2 = ot.lowrank.compute_lr_sqeuclidean_matrix(X_s, X_t) - M = ot.dist(X_s, X_t, metric="sqeuclidean") # original cost matrix + M = ot.dist(X_s, X_t, metric="sqeuclidean") # original cost matrix - np.testing.assert_allclose( - np.dot(M1,M2.T), M, atol=1e-05) + np.testing.assert_allclose(np.dot(M1, M2.T), M, atol=1e-05) def test_lowrank_sinkhorn(): @@ -40,38 +35,40 @@ def test_lowrank_sinkhorn(): value_linear = log["value_linear"] # check constraints for P - np.testing.assert_allclose(a, P.sum(1), atol=1e-05) - np.testing.assert_allclose(b, P.sum(0), atol=1e-05) - + np.testing.assert_allclose(a, P.sum(1), atol=1e-05) + np.testing.assert_allclose(b, P.sum(0), atol=1e-05) + # check if lazy_plan is equal to the fully computed plan - P_true = np.dot(Q,np.dot(np.diag(1/g),R.T)) + P_true = np.dot(Q, np.dot(np.diag(1 / g), R.T)) np.testing.assert_allclose(P, P_true, atol=1e-05) # check if value_linear is correct with its original formula M = ot.dist(X_s, X_t, metric="sqeuclidean") value_linear_true = np.sum(M * P_true) np.testing.assert_allclose(value_linear, value_linear_true, atol=1e-05) - - # check warn parameter when Dykstra algorithm doesn't converge + + # check warn parameter when Dykstra algorithm doesn't converge with pytest.warns(UserWarning): ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, reg=0.1, stopThr=0, numItermax=1) - -@pytest.mark.parametrize(("alpha, rank"),((0.8,2),(0.5,3),(0.2,6))) -def test_lowrank_sinkhorn_alpha_error(alpha,rank): - # Test warning for value of alpha +@pytest.mark.parametrize(("alpha, rank"), ((0.8, 2), (0.5, 3), (0.2, 6))) +def test_lowrank_sinkhorn_alpha_error(alpha, rank): + # Test warning for value of alpha n = 100 a = ot.unif(n) b = ot.unif(n) X_s = np.reshape(1.0 * np.arange(n), (n, 1)) X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) - + with pytest.raises(ValueError): - ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, reg=0.1, rank=rank, alpha=alpha, warn=False) + ot.lowrank.lowrank_sinkhorn( + X_s, X_t, a, b, reg=0.1, rank=rank, alpha=alpha, warn=False + ) +@pytest.skip_backend('tf') def test_lowrank_sinkhorn_backends(nx): # Test low rank sinkhorn for different backends n = 100 @@ -85,7 +82,7 @@ def test_lowrank_sinkhorn_backends(nx): Q, R, g, log = ot.lowrank.lowrank_sinkhorn(X_sb, X_tb, ab, bb, reg=0.1, log=True) lazy_plan = log["lazy_plan"] - P = lazy_plan[:] + P = lazy_plan[:] - np.testing.assert_allclose(ab, P.sum(1), atol=1e-05) - np.testing.assert_allclose(bb, P.sum(0), atol=1e-05) \ No newline at end of file + np.testing.assert_allclose(ab, P.sum(1), atol=1e-05) + np.testing.assert_allclose(bb, P.sum(0), atol=1e-05) From 55c8d2bafcd01ba0fb52b65f5324b165d2f0ff35 Mon Sep 17 00:00:00 2001 From: laudavid Date: Tue, 28 Nov 2023 22:06:28 +0100 Subject: [PATCH 22/22] fix pep 8 tests --- ot/lowrank.py | 3 +-- ot/solvers.py | 7 +++---- ot/utils.py | 3 ++- test/test_solvers.py | 8 ++++---- test/test_utils.py | 1 - 5 files changed, 10 insertions(+), 12 deletions(-) diff --git a/ot/lowrank.py b/ot/lowrank.py index ad804efd8..365d78ed9 100644 --- a/ot/lowrank.py +++ b/ot/lowrank.py @@ -8,7 +8,7 @@ import warnings -from .utils import unif, list_to_array, get_lowrank_lazytensor +from .utils import unif, get_lowrank_lazytensor from .backend import get_backend @@ -287,4 +287,3 @@ def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank="auto", alpha="auto", return Q, R, g, dict_log return Q, R, g - diff --git a/ot/solvers.py b/ot/solvers.py index 958b951d1..40a03e974 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -1248,19 +1248,19 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t lazy_plan = log['lazy_plan'] if not lazy0: # store plan if not lazy plan = lazy_plan[:] - + elif method == "lowrank": if not metric.lower() in ['sqeuclidean']: raise (NotImplementedError('Not implemented metric="{}"'.format(metric))) - + if max_iter is None: max_iter = 1000 if tol is None: tol = 1e-9 if reg is None: reg = 0 - + Q, R, g, log = lowrank_sinkhorn(X_a, X_b, reg=reg, a=a, b=b, numItermax=max_iter, stopThr=tol, log=True) value = log['value'] value_linear = log['value_linear'] @@ -1268,7 +1268,6 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t if not lazy0: # store plan if not lazy plan = lazy_plan[:] - elif method.startswith('geomloss'): # Geomloss solver for entropi OT split_method = method.split('_') diff --git a/ot/utils.py b/ot/utils.py index 3e67a08b4..cb29b21c9 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -1176,6 +1176,7 @@ def citation(self): } """ + class LazyTensor(object): """ A lazy tensor is a tensor that is not stored in memory. Instead, it is defined by a function that computes its values on the fly from slices. @@ -1240,4 +1241,4 @@ def __getitem__(self, key): return self._getitem(*k, **self.kwargs) def __repr__(self): - return "LazyTensor(shape={},attributes=({}))".format(self.shape, ','.join(self.kwargs.keys())) \ No newline at end of file + return "LazyTensor(shape={},attributes=({}))".format(self.shape, ','.join(self.kwargs.keys())) diff --git a/test/test_solvers.py b/test/test_solvers.py index 7cb26a096..343220c45 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -30,14 +30,14 @@ {'method': 'gaussian'}, {'method': 'gaussian', 'reg': 1}, {'method': 'factored', 'rank': 10}, - {'method': 'lowrank', 'reg':0.1} + {'method': 'lowrank', 'reg': 0.1} ] lst_parameters_solve_sample_NotImplemented = [ {'method': '1d', 'metric': 'any other one'}, # fail 1d on weird metrics {'method': 'gaussian', 'metric': 'euclidean'}, # fail gaussian on metric not euclidean - {'method': 'factored', 'metric': 'euclidean'}, # fail factored on metric not euclidean - {"method": 'lowrank', 'metric':'euclidean'}, # fail lowrank on metric not euclidean + {'method': 'factored', 'metric': 'euclidean'}, # fail factored on metric not euclidean + {"method": 'lowrank', 'metric': 'euclidean'}, # fail lowrank on metric not euclidean {'lazy': True}, # fail lazy for non regularized {'lazy': True, 'unbalanced': 1}, # fail lazy for non regularized unbalanced {'lazy': True, 'reg': 1, 'unbalanced': 1}, # fail lazy for unbalanced and regularized @@ -415,7 +415,7 @@ def test_solve_sample_methods(nx, method_params): assert_allclose_sol(sol, solb) sol2 = ot.solve_sample(x, x, **method_params) - if method_params['method'] not in ['factored','lowrank']: + if method_params['method'] not in ['factored', 'lowrank']: np.testing.assert_allclose(sol2.value, 0) diff --git a/test/test_utils.py b/test/test_utils.py index 4a6dec9cf..258a1c742 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -583,4 +583,3 @@ def test_lowrank_LazyTensor(nx): T = ot.utils.get_lowrank_lazytensor(X1, X2, diag_d, nx=nx) np.testing.assert_allclose(nx.to_numpy(T[:]), nx.to_numpy(T0)) - \ No newline at end of file