From f49f6b4f34ddd3a2313e1df00c487bd7f47df845 Mon Sep 17 00:00:00 2001 From: laudavid Date: Tue, 24 Oct 2023 14:54:42 +0200 Subject: [PATCH 01/11] new file for lr sinkhorn --- ot/lowrank.py | 171 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 171 insertions(+) create mode 100644 ot/lowrank.py diff --git a/ot/lowrank.py b/ot/lowrank.py new file mode 100644 index 000000000..ba46cd1ed --- /dev/null +++ b/ot/lowrank.py @@ -0,0 +1,171 @@ +################################################################################################################# +############################################## WORK IN PROGRESS ################################################# +################################################################################################################# + + +from ot.utils import unif, list_to_array +from ot.backend import get_backend +from ot.datasets import make_1D_gauss as gauss + + + +################################## LR-DYSKTRA ALGORITHM ########################################## + +def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, dykstra_w): + """ + Implementation of the Dykstra algorithm for low rank Sinkhorn + """ + + # get dykstra parameters + q3_1, q3_2, v1_, v2_, q1, q2 = dykstra_w + + # POT backend + eps1, eps2, eps3, p1, p2 = list_to_array(eps1, eps2, eps3, p1, p2) + q3_1, q3_2, v1_, v2_, q1, q2 = list_to_array(q3_1, q3_2, v1_, v2_, q1, q2) + + nx = get_backend(eps1, eps2, eps3, p1, p2, q3_1, q3_2, v1_, v2_, q1, q2) + + # ------- Dykstra algorithm ------ + g_ = eps3 + + u1 = p1 / nx.dot(eps1, v1_) + u2 = p2 / nx.dot(eps2, v2_) + + g = nx.maximum(alpha, g_ * q3_1) + q3_1 = (g_ * q3_1) / g + g_ = g + + prod1 = ((v1_ * q1) * nx.dot(eps1.T, u1)) + prod2 = ((v2_ * q2) * nx.dot(eps2.T, u2)) + g = (g_ * q3_2 * prod1 * prod2)**(1/3) + + v1 = g / nx.dot(eps1.T,u1) + v2 = g / nx.dot(eps2.T,u2) + + q1 = (v1_ * q1) / v1 + q2 = (v2_ * q2) / v2 + q3_2 = (g_ * q3_2) / g + + v1_, v2_ = v1, v2 + g_ = g + + # Compute error + err1 = 
nx.sum(nx.abs(u1 * (eps1 @ v1) - p1)) + err2 = nx.sum(nx.abs(u2 * (eps2 @ v2) - p2)) + err = err1 + err2 + + # Compute low rank matrices Q, R + Q = u1[:,None] * eps1 * v1[None,:] + R = u2[:,None] * eps2 * v2[None,:] + + dykstra_w = [q3_1, q3_2, v1_, v2_, q1, q2] + + return Q, R, g, err, dykstra_w + + + +#################################### LOW RANK SINKHORN ALGORITHM ######################################### + + +def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=4, metric='sqeuclidean', alpha=1e-10, numIterMax=10000, stopThr=1e-20): + r''' + Solve the entropic regularization optimal transport problem under low-nonnegative low rank constraints + + Parameters + ---------- + X_s : array-like, shape (n_samples_a, dim) + samples in the source domain + X_t : array-like, shape (n_samples_b, dim) + samples in the target domain + reg : float + Regularization term >0 + a : array-like, shape (n_samples_a,) + samples weights in the source domain + b : array-like, shape (n_samples_b,) + samples weights in the target domain + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (>0) + + Returns + ------- + Q : array-like, shape (n_samples_a, r) + First low-rank matrix decomposition of the OT plan + R: array-like, shape (n_samples_b, r) + Second low-rank matrix decomposition of the OT plan + g : array-like, shape (r, ) + ... 
+ + ''' + + X_s, X_t = list_to_array(X_s, X_t) + nx = get_backend(X_s, X_t) + + ns, nt = X_s.shape[0], X_t.shape[0] + if a is None: + a = nx.from_numpy(unif(ns), type_as=X_s) + if b is None: + b = nx.from_numpy(unif(nt), type_as=X_s) + + M = ot.dist(X_s,X_t, metric=metric) + + # Compute rank + r = min(ns, nt, r) + + # Compute gamma + L = nx.sqrt((2/(alpha**4))*nx.norm(M)**2 + (reg + (2/(alpha**3))*nx.norm(M))**2) + gamma = 1/(2*L) + + # Initialisation + Q, R, g = nx.ones((ns,r)), nx.ones((nt,r)), nx.ones(r) + q3_1, q3_2 = nx.ones(r), nx.ones(r) + v1_, v2_ = nx.ones(r), nx.ones(r) + q1, q2 = nx.ones(r), nx.ones(r) + dykstra_w = [q3_1, q3_2, v1_, v2_, q1, q2] + n_iter = 0 + err = 1 + + while n_iter < numIterMax: + if err > stopThr: + n_iter = n_iter + 1 + + CR = nx.dot(M,R) + C_t_Q = nx.dot(M.T,Q) + diag_g = (1/g)[:,None] + + eps1 = nx.exp(-gamma*(nx.dot(CR,diag_g)) - ((gamma*reg)-1)*nx.log(Q)) + eps2 = nx.exp(-gamma*(nx.dot(C_t_Q,diag_g)) - ((gamma*reg)-1)*nx.log(R)) + omega = nx.diag(nx.dot(Q.T, CR)) + eps3 = nx.exp(gamma*omega/(g**2) - (gamma*reg - 1)*nx.log(g)) + + Q, R, g, err, dykstra_w = LR_Dysktra(eps1, eps2, eps3, a, b, alpha, dykstra_w) + else: + break + + return Q, R, g + + + + + +############################################################################ +## Test with X_s, X_t from ot.datasets +############################################################################# + +import numpy as np +import ot + +Xs, _ = ot.datasets.make_data_classif('3gauss', n=1000) +Xt, _ = ot.datasets.make_data_classif('3gauss2', n=1500) + + +Q, R, g = lowrank_sinkhorn(Xs,Xt,reg=0.1) +M = ot.dist(Xs,Xt) +P = np.dot(Q,np.dot(np.diag(1/g),R.T)) + +print(np.sum(P)) + + + + From 3c4b50fdb660f27cc080618edb664d17086d93a9 Mon Sep 17 00:00:00 2001 From: laudavid Date: Tue, 24 Oct 2023 16:47:21 +0200 Subject: [PATCH 02/11] lr sinkhorn, solve_sample, OTResultLazy --- ot/lowrank.py | 40 +++++++------ ot/solvers.py | 161 ++++++++++++++++++++++++++++++++++++++++++++++++++ ot/utils.py | 
90 ++++++++++++++++++++++++++++ 3 files changed, 272 insertions(+), 19 deletions(-) diff --git a/ot/lowrank.py b/ot/lowrank.py index ba46cd1ed..a1c73bdf3 100644 --- a/ot/lowrank.py +++ b/ot/lowrank.py @@ -2,8 +2,10 @@ ############################################## WORK IN PROGRESS ################################################# ################################################################################################################# +## Implementation of the LR-Dykstra algorithm and low rank sinkhorn algorithms -from ot.utils import unif, list_to_array + +from ot.utils import unif, list_to_array, dist from ot.backend import get_backend from ot.datasets import make_1D_gauss as gauss @@ -11,13 +13,13 @@ ################################## LR-DYSKTRA ALGORITHM ########################################## -def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, dykstra_w): +def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, dykstra_p): """ Implementation of the Dykstra algorithm for low rank Sinkhorn """ # get dykstra parameters - q3_1, q3_2, v1_, v2_, q1, q2 = dykstra_w + q3_1, q3_2, v1_, v2_, q1, q2 = dykstra_p # POT backend eps1, eps2, eps3, p1, p2 = list_to_array(eps1, eps2, eps3, p1, p2) @@ -58,18 +60,18 @@ def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, dykstra_w): Q = u1[:,None] * eps1 * v1[None,:] R = u2[:,None] * eps2 * v2[None,:] - dykstra_w = [q3_1, q3_2, v1_, v2_, q1, q2] + dykstra_p = [q3_1, q3_2, v1_, v2_, q1, q2] - return Q, R, g, err, dykstra_w + return Q, R, g, err, dykstra_p #################################### LOW RANK SINKHORN ALGORITHM ######################################### -def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=4, metric='sqeuclidean', alpha=1e-10, numIterMax=10000, stopThr=1e-20): +def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=2, metric='sqeuclidean', alpha=1e-10, numIterMax=10000, stopThr=1e-20): r''' - Solve the entropic regularization optimal transport problem under low-nonnegative low rank constraints + Solve the 
entropic regularization optimal transport problem under low-nonnegative rank constraints on the feasible couplings. Parameters ---------- @@ -95,7 +97,7 @@ def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=4, metric='sqeuclidean', R: array-like, shape (n_samples_b, r) Second low-rank matrix decomposition of the OT plan g : array-like, shape (r, ) - ... + Third low-rank matrix decomposition of the OT plan ''' @@ -108,7 +110,7 @@ def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=4, metric='sqeuclidean', if b is None: b = nx.from_numpy(unif(nt), type_as=X_s) - M = ot.dist(X_s,X_t, metric=metric) + M = dist(X_s,X_t, metric=metric) # Compute rank r = min(ns, nt, r) @@ -122,7 +124,7 @@ def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=4, metric='sqeuclidean', q3_1, q3_2 = nx.ones(r), nx.ones(r) v1_, v2_ = nx.ones(r), nx.ones(r) q1, q2 = nx.ones(r), nx.ones(r) - dykstra_w = [q3_1, q3_2, v1_, v2_, q1, q2] + dykstra_p = [q3_1, q3_2, v1_, v2_, q1, q2] n_iter = 0 err = 1 @@ -139,7 +141,7 @@ def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=4, metric='sqeuclidean', omega = nx.diag(nx.dot(Q.T, CR)) eps3 = nx.exp(gamma*omega/(g**2) - (gamma*reg - 1)*nx.log(g)) - Q, R, g, err, dykstra_w = LR_Dysktra(eps1, eps2, eps3, a, b, alpha, dykstra_w) + Q, R, g, err, dykstra_p = LR_Dysktra(eps1, eps2, eps3, a, b, alpha, dykstra_p) else: break @@ -153,18 +155,18 @@ def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=4, metric='sqeuclidean', ## Test with X_s, X_t from ot.datasets ############################################################################# -import numpy as np -import ot +# import numpy as np +# import ot -Xs, _ = ot.datasets.make_data_classif('3gauss', n=1000) -Xt, _ = ot.datasets.make_data_classif('3gauss2', n=1500) +# Xs, _ = ot.datasets.make_data_classif('3gauss', n=1000) +# Xt, _ = ot.datasets.make_data_classif('3gauss2', n=1500) -Q, R, g = lowrank_sinkhorn(Xs,Xt,reg=0.1) -M = ot.dist(Xs,Xt) -P = np.dot(Q,np.dot(np.diag(1/g),R.T)) +# Q, R, g = 
lowrank_sinkhorn(Xs,Xt,reg=0.1) +# M = ot.dist(Xs,Xt) +# P = np.dot(Q,np.dot(np.diag(1/g),R.T)) -print(np.sum(P)) +# print(np.sum(P)) diff --git a/ot/solvers.py b/ot/solvers.py index 0313cf588..9c2746c25 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -848,3 +848,164 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, value_linear=value_linear, value_quad=value_quad, plan=plan, status=status, backend=nx) return res + + + + + + +################################## WORK IN PROGRESS ##################################### + +## Implementation of the ot.solve_sample function +## Function isn't complete, still work in progress for reg == 0 / reg is None case (and unbalanced cases) + + +from .utils import unif, list_to_array, dist, OTResultLazy +from .bregman import empirical_sinkhorn + + +def solve_sample(X_s, X_t, a=None, b=None, metric='sqeuclidean', reg=None, reg_type="KL", unbalanced=None, + unbalanced_type='KL', is_Lazy=False, batch_size=None, n_threads=1, max_iter=None, plan_init=None, + potentials_init=None, tol=None, verbose=False): + + r"""Solve the discrete optimal transport problem using the samples in the source and target domains. + It returns either a :any:`OTResult` or :any:`OTResultLazy` object. + + The function solves the following general optimal transport problem + + .. math:: + \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_r R(\mathbf{T}) + + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b}) + + The regularization is selected with `reg` (:math:`\lambda_r`) and `reg_type`. By + default ``reg=None`` and there is no regularization. The unbalanced marginal + penalization can be selected with `unbalanced` (:math:`\lambda_u`) and + `unbalanced_type`. By default ``unbalanced=None`` and the function + solves the exact optimal transport problem (respecting the marginals). 
+ + Parameters + ---------- + X_s : array-like, shape (n_samples_a, dim) + samples in the source domain + X_t : array-like, shape (n_samples_b, dim) + samples in the target domain + a : array-like, shape (dim_a,), optional + Samples weights in the source domain (default is uniform) + b : array-like, shape (dim_b,), optional + Samples weights in the source domain (default is uniform) + reg : float, optional + Regularization weight :math:`\lambda_r`, by default None (no reg., exact + OT) + reg_type : str, optional + Type of regularization :math:`R` either "KL", "L2", "entropy", by default "KL" + unbalanced : float, optional + Unbalanced penalization weight :math:`\lambda_u`, by default None + (balanced OT) + unbalanced_type : str, optional + Type of unbalanced penalization function :math:`U` either "KL", "L2", "TV", by default "KL" + is_Lazy : bool, optional + Return :any:`OTResultlazy` object to reduce memory cost when True, by default False + n_threads : int, optional + Number of OMP threads for exact OT solver, by default 1 + max_iter : int, optional + Maximum number of iteration, by default None (default values in each solvers) + plan_init : array_like, shape (dim_a, dim_b), optional + Initialization of the OT plan for iterative methods, by default None + potentials_init : (array_like(dim_a,),array_like(dim_b,)), optional + Initialization of the OT dual potentials for iterative methods, by default None + tol : _type_, optional + Tolerance for solution precision, by default None (default values in each solvers) + verbose : bool, optional + Print information in the solver, by default False + + Returns + ------- + + res_lazy : OTResultLazy() + Result of the optimization problem. This class only returns a partial OT plan and the OT dual potentials to reduce memory costs. 
+ The information can be obtained as follows: + + - res.lazy_plan : OT plan computed on a subsample of X_s and X_t :math:`\mathbf{T}` + - res.potentials : OT dual potentials + + See :any:`OTResultLazy` for more information. + + res : OTResult() + Result of the optimization problem. The information can be obtained as follows: + + - res.plan : OT plan :math:`\mathbf{T}` + - res.potentials : OT dual potentials + - res.value : Optimal value of the optimization problem + - res.value_linear : Linear OT loss with the optimal OT plan + + See :any:`OTResult` for more information. + + + """ + + X_s, X_t = list_to_array(X_s,X_t) + + # detect backend + arr = [X_s,X_t] + if a is not None: + arr.append(a) + if b is not None: + arr.append(b) + nx = get_backend(*arr) + + # create uniform weights if not given + ns, nt = X_s.shape[0], X_t.shape[0] + if a is None: + a = nx.from_numpy(unif(ns), type_as=X_s) + if b is None: + b = nx.from_numpy(unif(nt), type_as=X_s) + + # default values for solutions + potentials = None + lazy_plan = None + + if max_iter is None: + max_iter = 1000 + if tol is None: + tol = 1e-9 + if batch_size is None: + batch_size = 100 + + if is_Lazy: + ################# WIP #################### + if reg is None or reg == 0: # EMD solver for isLazy ? + if unbalanced is None: # not sure "unbalanced" parameter is needed here ? 
(since we won't compute value) + pass + elif unbalanced_type.lower() in ['kl', 'l2']: + pass + elif unbalanced_type.lower() == 'tv': + pass + pass + ############################################# + + else: + # compute potentials + u, v, log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric='sqeuclidean', numIterMax=max_iter, stopThr=tol, + isLazy=True, batchSize=batch_size, verbose=verbose, log=True) + potentials = (log["u"], log["v"]) + + # compute lazy_plan + ns_lazy, nt_lazy = 100, 100 # size of the lazy_plan (subplan) + M = dist(X_s[:ns_lazy,:], X_t[:nt_lazy,:], metric) + K = nx.exp(M / (-reg)) + lazy_plan = u[:ns_lazy].reshape((-1, 1)) * K * v[:nt_lazy].reshape((1, -1)) + + res_lazy = OTResultLazy(potentials=potentials, lazy_plan=lazy_plan, backend=nx) + return res_lazy + + else: + # compute cost matrix M and use solve function + M = dist(X_s, X_t, metric) + + res = solve(M, a, b, reg, reg_type, unbalanced, unbalanced_type, n_threads, max_iter, plan_init, potentials_init, tol, verbose) + return res + + + + diff --git a/ot/utils.py b/ot/utils.py index 8cbb0db25..d570b9f30 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -938,3 +938,93 @@ def citation(self): url = {http://jmlr.org/papers/v22/20-451.html} } """ + + + +############################## WORK IN PROGRESS #################################### + +## Implementation of the OTResultLazy class for ot.solve_sample() with potentials and lazy_plan as attributes + +class OTResultLazy: + def __init__(self, potentials=None, lazy_plan=None, backend=None): + + self._potentials = potentials + self._lazy_plan = lazy_plan + self._backend = backend if backend is not None else NumpyBackend() + + + # Dual potentials -------------------------------------------- + + def __repr__(self): + s = 'OTResultLazy(' + if self._lazy_plan is not None: + s += 'lazy_plan={}(shape={}),'.format(self._lazy_plan.__class__.__name__, self._lazy_plan.shape) + + if s[-1] != '(': + s = s[:-1] + ')' + else: + s = s + ')' + return s + + @property + def 
potentials(self): + """Dual potentials, i.e. Lagrange multipliers for the marginal constraints. + + This pair of arrays has the same shape, numerical type + and properties as the input weights "a" and "b". + """ + if self._potentials is not None: + return self._potentials + else: + raise NotImplementedError() + + @property + def potential_a(self): + """First dual potential, associated to the "source" measure "a".""" + if self._potentials is not None: + return self._potentials[0] + else: + raise NotImplementedError() + + @property + def potential_b(self): + """Second dual potential, associated to the "target" measure "b".""" + if self._potentials is not None: + return self._potentials[1] + else: + raise NotImplementedError() + + # Transport plan ------------------------------------------- + @property + def lazy_plan(self): + """A subset of the Transport plan, encoded as a dense array.""" + + if self._lazy_plan is not None: + return self._lazy_plan + else: + raise NotImplementedError() + + @property + def citation(self): + """Appropriate citation(s) for this result, in plain text and BibTex formats.""" + + # The string below refers to the POT library: + # successor methods may concatenate the relevant references + # to the original definitions, solvers and underlying numerical backends. + return """POT library: + + POT Python Optimal Transport library, Journal of Machine Learning Research, 22(78):1−8, 2021. + Website: https://pythonot.github.io/ + Rémi Flamary, Nicolas Courty, Alexandre Gramfort, Mokhtar Z. Alaya, Aurélie Boisbunon, Stanislas Chambon, Laetitia Chapel, Adrien Corenflos, Kilian Fatras, Nemo Fournier, Léo Gautheron, Nathalie T.H. Gayraud, Hicham Janati, Alain Rakotomamonjy, Ievgen Redko, Antoine Rolet, Antony Schutz, Vivien Seguy, Danica J. Sutherland, Romain Tavenard, Alexander Tong, Titouan Vayer; + + @article{flamary2021pot, + author = {R{\'e}mi Flamary and Nicolas Courty and Alexandre Gramfort and Mokhtar Z. 
Alaya and Aur{\'e}lie Boisbunon and Stanislas Chambon and Laetitia Chapel and Adrien Corenflos and Kilian Fatras and Nemo Fournier and L{\'e}o Gautheron and Nathalie T.H. Gayraud and Hicham Janati and Alain Rakotomamonjy and Ievgen Redko and Antoine Rolet and Antony Schutz and Vivien Seguy and Danica J. Sutherland and Romain Tavenard and Alexander Tong and Titouan Vayer}, + title = {{POT}: {Python} {Optimal} {Transport}}, + journal = {Journal of Machine Learning Research}, + year = {2021}, + volume = {22}, + number = {78}, + pages = {1-8}, + url = {http://jmlr.org/papers/v22/20-451.html} + } + """ \ No newline at end of file From 3034e575c55d2ce56499be6849e1906fe52f0573 Mon Sep 17 00:00:00 2001 From: laudavid Date: Wed, 25 Oct 2023 17:39:08 +0200 Subject: [PATCH 03/11] add test functions + small modif lr_sin/solve_sample --- ot/lowrank.py | 97 ++++++++++++++++++++++++++++------------- ot/solvers.py | 47 +++++++++++--------- test/test_lowrank.py | 84 ++++++++++++++++++++++++++++++++++++ test/test_solvers.py | 100 +++++++++++++++++++++++++++++++++++++++++++ test/test_utils.py | 27 ++++++++++++ 5 files changed, 304 insertions(+), 51 deletions(-) create mode 100644 test/test_lowrank.py diff --git a/ot/lowrank.py b/ot/lowrank.py index a1c73bdf3..22ff8b754 100644 --- a/ot/lowrank.py +++ b/ot/lowrank.py @@ -4,10 +4,9 @@ ## Implementation of the LR-Dykstra algorithm and low rank sinkhorn algorithms - -from ot.utils import unif, list_to_array, dist -from ot.backend import get_backend -from ot.datasets import make_1D_gauss as gauss +import warnings +from .utils import unif, list_to_array, dist +from .backend import get_backend @@ -15,7 +14,7 @@ def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, dykstra_p): """ - Implementation of the Dykstra algorithm for low rank Sinkhorn + Implementation of the Dykstra algorithm for low rank sinkhorn """ # get dykstra parameters @@ -69,9 +68,12 @@ def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, dykstra_p): #################################### 
LOW RANK SINKHORN ALGORITHM ######################################### -def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=2, metric='sqeuclidean', alpha=1e-10, numIterMax=10000, stopThr=1e-20): +def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=2, metric='sqeuclidean', alpha="auto", + numItermax=10000, stopThr=1e-9, warn=True, verbose=False): r''' - Solve the entropic regularization optimal transport problem under low-nonnegative rank constraints on the feasible couplings. + Solve the entropic regularization optimal transport problem under low-nonnegative rank constraints. + + This function returns the two low-rank matrix decomposition of the OT plan (Q,R), as well as the weight vector g. Parameters ---------- @@ -79,17 +81,22 @@ def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=2, metric='sqeuclidean', samples in the source domain X_t : array-like, shape (n_samples_b, dim) samples in the target domain - reg : float - Regularization term >0 a : array-like, shape (n_samples_a,) samples weights in the source domain b : array-like, shape (n_samples_b,) samples weights in the target domain + reg : float, optional + Regularization term >0 + rank: int, optional + Nonnegative rank of the OT plan + alpha: int, optional + Lower bound for the weight vector g (>0 and <1/r) numItermax : int, optional Max number of iterations stopThr : float, optional Stop threshold on error (>0) + Returns ------- Q : array-like, shape (n_samples_a, r) @@ -97,7 +104,14 @@ def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=2, metric='sqeuclidean', R: array-like, shape (n_samples_b, r) Second low-rank matrix decomposition of the OT plan g : array-like, shape (r, ) - Third low-rank matrix decomposition of the OT plan + Weight vector for the low-rank decomposition of the OT plan + + + References + ---------- + + .. Scetbon, M., Cuturi, M., & Peyré, G (2021). + Low-Rank Sinkhorn Factorization. arXiv preprint arXiv:2103.04737. 
''' @@ -110,13 +124,22 @@ def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=2, metric='sqeuclidean', if b is None: b = nx.from_numpy(unif(nt), type_as=X_s) + # Compute cost matrix M = dist(X_s,X_t, metric=metric) - + # Compute rank - r = min(ns, nt, r) + rank = min(ns, nt, rank) + r = rank + + if alpha == 'auto': + alpha = 1.0 / (r + 1) + + if (1/r < alpha) or (alpha < 0): + warnings.warn("The provided alpha value might lead to instabilities.") + # Compute gamma - L = nx.sqrt((2/(alpha**4))*nx.norm(M)**2 + (reg + (2/(alpha**3))*nx.norm(M))**2) + L = nx.sqrt((2/(alpha**4))*(nx.norm(M)**2) + (reg + (2/(alpha**3))*(nx.norm(M))**2)) gamma = 1/(2*L) # Initialisation @@ -125,25 +148,34 @@ def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=2, metric='sqeuclidean', v1_, v2_ = nx.ones(r), nx.ones(r) q1, q2 = nx.ones(r), nx.ones(r) dykstra_p = [q3_1, q3_2, v1_, v2_, q1, q2] - n_iter = 0 err = 1 - while n_iter < numIterMax: - if err > stopThr: - n_iter = n_iter + 1 - - CR = nx.dot(M,R) - C_t_Q = nx.dot(M.T,Q) - diag_g = (1/g)[:,None] - - eps1 = nx.exp(-gamma*(nx.dot(CR,diag_g)) - ((gamma*reg)-1)*nx.log(Q)) - eps2 = nx.exp(-gamma*(nx.dot(C_t_Q,diag_g)) - ((gamma*reg)-1)*nx.log(R)) - omega = nx.diag(nx.dot(Q.T, CR)) - eps3 = nx.exp(gamma*omega/(g**2) - (gamma*reg - 1)*nx.log(g)) - - Q, R, g, err, dykstra_p = LR_Dysktra(eps1, eps2, eps3, a, b, alpha, dykstra_p) - else: + for ii in range(numItermax): + CR = nx.dot(M,R) + C_t_Q = nx.dot(M.T,Q) + diag_g = (1/g)[:,None] + + eps1 = nx.exp(-gamma*(nx.dot(CR,diag_g)) - ((gamma*reg)-1)*nx.log(Q)) + eps2 = nx.exp(-gamma*(nx.dot(C_t_Q,diag_g)) - ((gamma*reg)-1)*nx.log(R)) + omega = nx.diag(nx.dot(Q.T, CR)) + eps3 = nx.exp(gamma*omega/(g**2) - (gamma*reg - 1)*nx.log(g)) + + Q, R, g, err, dykstra_p = LR_Dysktra(eps1, eps2, eps3, a, b, alpha, dykstra_p) + + if err < stopThr: break + + if verbose: + if ii % 200 == 0: + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(ii, err)) + + else: 
+ if warn: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") return Q, R, g @@ -161,8 +193,13 @@ def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=2, metric='sqeuclidean', # Xs, _ = ot.datasets.make_data_classif('3gauss', n=1000) # Xt, _ = ot.datasets.make_data_classif('3gauss2', n=1500) +# ns = Xs.shape[0] +# nt = Xt.shape[0] + +# a = unif(ns) +# b = unif(nt) -# Q, R, g = lowrank_sinkhorn(Xs,Xt,reg=0.1) +# Q, R, g = lowrank_sinkhorn(Xs, Xt, reg=0.1, metric='euclidean', verbose=True, numItermax=100) # M = ot.dist(Xs,Xt) # P = np.dot(Q,np.dot(np.diag(1/g),R.T)) diff --git a/ot/solvers.py b/ot/solvers.py index 9c2746c25..c176969ca 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -926,7 +926,7 @@ def solve_sample(X_s, X_t, a=None, b=None, metric='sqeuclidean', reg=None, reg_t Result of the optimization problem. This class only returns a partial OT plan and the OT dual potentials to reduce memory costs. The information can be obtained as follows: - - res.lazy_plan : OT plan computed on a subsample of X_s and X_t :math:`\mathbf{T}` + - res.lazy_plan : OT plan computed on a subsample of X_s and X_t - res.potentials : OT dual potentials See :any:`OTResultLazy` for more information. @@ -975,29 +975,34 @@ def solve_sample(X_s, X_t, a=None, b=None, metric='sqeuclidean', reg=None, reg_t if is_Lazy: ################# WIP #################### if reg is None or reg == 0: # EMD solver for isLazy ? - if unbalanced is None: # not sure "unbalanced" parameter is needed here ? (since we won't compute value) - pass - elif unbalanced_type.lower() in ['kl', 'l2']: - pass - elif unbalanced_type.lower() == 'tv': - pass - pass + + if unbalanced is None: # balanced EMD solver for isLazy ? 
+ raise (NotImplementedError('Not implemented balanced with no regularization')) + + else: + raise (NotImplementedError('Not implemented unbalanced_type="{}" with no regularization'.format(unbalanced_type))) + + ############################################# else: - # compute potentials - u, v, log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric='sqeuclidean', numIterMax=max_iter, stopThr=tol, - isLazy=True, batchSize=batch_size, verbose=verbose, log=True) - potentials = (log["u"], log["v"]) - - # compute lazy_plan - ns_lazy, nt_lazy = 100, 100 # size of the lazy_plan (subplan) - M = dist(X_s[:ns_lazy,:], X_t[:nt_lazy,:], metric) - K = nx.exp(M / (-reg)) - lazy_plan = u[:ns_lazy].reshape((-1, 1)) * K * v[:nt_lazy].reshape((1, -1)) - - res_lazy = OTResultLazy(potentials=potentials, lazy_plan=lazy_plan, backend=nx) - return res_lazy + if unbalanced is None: + u, v, log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric='sqeuclidean', numIterMax=max_iter, stopThr=tol, + isLazy=True, batchSize=batch_size, verbose=verbose, log=True) + # compute potentials + potentials = (log["u"], log["v"]) + + # compute lazy_plan + ns_lazy, nt_lazy = 100, 100 # size of the lazy_plan (subplan) + M = dist(X_s[:ns_lazy,:], X_t[:nt_lazy,:], metric) + K = nx.exp(M / (-reg)) + lazy_plan = u[:ns_lazy].reshape((-1, 1)) * K * v[:nt_lazy].reshape((1, -1)) + + res_lazy = OTResultLazy(potentials=potentials, lazy_plan=lazy_plan, backend=nx) + return res_lazy + + else: + raise (NotImplementedError('Not implemented unbalanced_type="{}" with regularization'.format(unbalanced_type))) else: # compute cost matrix M and use solve function diff --git a/test/test_lowrank.py b/test/test_lowrank.py new file mode 100644 index 000000000..6e1f24067 --- /dev/null +++ b/test/test_lowrank.py @@ -0,0 +1,84 @@ +##################################################################################################### +####################################### WORK IN PROGRESS ############################################ 
+##################################################################################################### + + +""" Test for low rank sinkhorn solvers """ + +import ot +import numpy as np +import pytest +from itertools import product + + +def test_LR_Dykstra(): + # test for LR_Dykstra algorithm ? catch nan values ? + pass + + +@pytest.mark.parametrize("verbose, warn", product([True, False], [True, False])) +def test_lowrank_sinkhorn(verbose, warn): + # test low rank sinkhorn + n = 100 + a = ot.unif(n) + b = ot.unif(n) + + X_s = np.reshape(1.0 * np.arange(n), (n, 1)) + X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) + + Q_sqe, R_sqe, g_sqe = ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, 0.1) + P_sqe = np.dot(Q_sqe,np.dot(np.diag(1/g_sqe),R_sqe.T)) + + Q_m, R_m, g_m = ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, 0.1, metric='euclidean') + P_m = np.dot(Q_m,np.dot(np.diag(1/g_m),R_m.T)) + + # check constraints + np.testing.assert_allclose( + a, P_sqe.sum(1), atol=1e-05) # metric sqeuclidian + np.testing.assert_allclose( + b, P_sqe.sum(0), atol=1e-05) # metric sqeuclidian + np.testing.assert_allclose( + a, P_m.sum(1), atol=1e-05) # metric euclidian + np.testing.assert_allclose( + b, P_m.sum(0), atol=1e-05) # metric euclidian + + with pytest.warns(UserWarning): + ot.lowrank.lowrank_sinkhorn(X_s, X_t, 0.1, a, b, stopThr=0, numItermax=1) + + + +@pytest.mark.parametrize(("alpha, rank"),((0.8,2),(0.5,3),(0.2,4))) +def test_lowrank_sinkhorn_alpha_warning(alpha,rank): + # test warning for value of alpha + n = 100 + a = ot.unif(n) + b = ot.unif(n) + + X_s = np.reshape(1.0 * np.arange(n), (n, 1)) + X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) + + with pytest.warns(UserWarning): + ot.lowrank.lowrank_sinkhorn(X_s, X_t, 0.1, a, b, r=rank, alpha=alpha, warn=False) + + + +def test_lowrank_sinkhorn_backends(nx): + # test low rank sinkhorn for different backends + n = 100 + a = ot.unif(n) + b = ot.unif(n) + + X_s = np.reshape(1.0 * np.arange(n), (n, 1)) + X_t = np.reshape(1.0 * 
np.arange(0, n), (n, 1)) + + ab, bb, X_sb, X_tb = nx.from_numpy(a, b, X_s, X_t) + + Q, R, g = nx.to_numpy(ot.lowrank.lowrank_sinkhorn(X_sb, X_tb, ab, bb, 0.1)) + P = np.dot(Q,np.dot(np.diag(1/g),R.T)) + + np.testing.assert_allclose(a, P.sum(1), atol=1e-05) + np.testing.assert_allclose(b, P.sum(0), atol=1e-05) + + + + diff --git a/test/test_solvers.py b/test/test_solvers.py index f0f5b638f..5a05d54cf 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -255,3 +255,103 @@ def test_solve_gromov_not_implemented(nx): ot.solve_gromov(Ca, Cb, reg=1, unbalanced_type='partial', unbalanced=1.5) with pytest.raises(NotImplementedError): ot.solve_gromov(Ca, Cb, reg=1, unbalanced_type='partial', unbalanced=0.5, symmetric=False) + + + + +########################################################################################################### +############################################ WORK IN PROGRESS ############################################# +########################################################################################################### + +def assert_allclose_sol_sample(sol1, sol2): + # test attributes of OTResultLazy class + lst_attr = ['potentials','potential_a', 'potential_b', 'lazy_plan'] + + nx1 = sol1._backend if sol1._backend is not None else ot.backend.NumpyBackend() + nx2 = sol2._backend if sol2._backend is not None else ot.backend.NumpyBackend() + + for attr in lst_attr: + try: + np.allclose(nx1.to_numpy(getattr(sol1, attr)), nx2.to_numpy(getattr(sol2, attr))) + except NotImplementedError: + pass + + +@pytest.mark.parametrize("reg,reg_type,unbalanced,unbalanced_type", itertools.product(lst_reg, lst_reg_type, lst_unbalanced, lst_unbalanced_type)) +def test_solve_sample(nx): + # test solve_sample when is_Lazy = False + n = 100 + X_s = np.reshape(1.0 * np.arange(n), (n, 1)) + X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) + + a = ot.utils.unif(X_s.shape[0]) + b = ot.utils.unif(X_t.shape[0]) + + # solve unif weights + sol0 = 
ot.solve_sample(X_s, X_t) + + # solve signe weights + sol = ot.solve_sample(X_s, X_t, a, b) + + # check some attributes + sol.potentials + sol.sparse_plan + sol.marginals + sol.status + + assert_allclose_sol(sol0, sol) + + # solve in backend + X_sb, X_tb, ab, bb = nx.from_numpy(X_s, X_t, a, b) + solb = ot.solve_sample(X_sb, X_tb, ab, bb) + + assert_allclose_sol(sol, solb) + + # test not implemented unbalanced and check raise + with pytest.raises(NotImplementedError): + sol0 = ot.solve_sample(X_s, X_t, unbalanced=1, unbalanced_type='cryptic divergence') + + # test not implemented reg_type and check raise + with pytest.raises(NotImplementedError): + sol0 = ot.solve_sample(X_s, X_t, reg=1, reg_type='cryptic divergence') + + + +def test_lazy_solve_sample(nx): + # test solve_sample when is_Lazy = True + n = 100 + X_s = np.reshape(1.0 * np.arange(n), (n, 1)) + X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) + + a = ot.utils.unif(X_s.shape[0]) + b = ot.utils.unif(X_t.shape[0]) + + # solve unif weights + sol0 = ot.solve_sample(X_s, X_t, reg=0.1, is_Lazy=True) # reg != 0 or None since no implementation yet for is_Lazy=True + + # solve signe weights + sol = ot.solve_sample(X_s, X_t, a, b, reg=0.1, is_Lazy=True) + + # check some attributes + sol.potentials + sol.lazy_plan + + assert_allclose_sol_sample(sol0, sol) + + # solve in backend + X_sb, X_tb, ab, bb = nx.from_numpy(X_s, X_t, a, b) + solb = ot.solve_sample(X_sb, X_tb, ab, bb, reg=0.1, is_Lazy=True) + + assert_allclose_sol_sample(sol, solb) + + # test not implemented reg==0 (or None) + balanced and check raise + with pytest.raises(NotImplementedError): + sol0 = ot.solve_sample(X_s, X_t, is_Lazy=True) # reg == 0 (or None) + unbalanced= None are default + + # test not implemented reg==0 (or None) + unbalanced_type and check raise + with pytest.raises(NotImplementedError): + sol0 = ot.solve_sample(X_s, X_t, unbalanced_type="kl", is_Lazy=True) # reg == 0 (or None) is default + + # test not implemented reg != 0 + 
unbalanced_type and check raise + with pytest.raises(NotImplementedError): + sol0 = ot.solve_sample(X_s, X_t, reg=0.1, unbalanced_type="kl", is_Lazy=True) \ No newline at end of file diff --git a/test/test_utils.py b/test/test_utils.py index 40324518e..a14be460e 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -401,3 +401,30 @@ def test_get_coordinate_circle(): x_p = ot.utils.get_coordinate_circle(x) np.testing.assert_allclose(u[0], x_p) + + + +############################################################################################## +##################################### WORK IN PROGRESS ####################################### +############################################################################################## + +# test function for OTResultLazy + +def test_OTResultLazy(): + + res_lazy = ot.utils.OTResultLazy() + + # test print + print(res_lazy) + + # tets get citation + print(res_lazy.citation) + + lst_attributes = ['lazy_plan', + 'potential_a', + 'potential_b', + 'potentials'] + + for at in lst_attributes: + with pytest.raises(NotImplementedError): + getattr(res_lazy, at) \ No newline at end of file From 085863aef96f0d19e740879dfae158a762275a67 Mon Sep 17 00:00:00 2001 From: laudavid Date: Thu, 26 Oct 2023 10:49:23 +0200 Subject: [PATCH 04/11] add import to __init__ --- ot/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ot/__init__.py b/ot/__init__.py index f16b6fcfc..cb00f4553 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -35,6 +35,7 @@ from . import factored from . import solvers from . import gaussian +from . 
import lowrank # OT functions from .lp import (emd, emd2, emd_1d, emd2_1d, wasserstein_1d, @@ -50,7 +51,8 @@ gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2) from .weak import weak_optimal_transport from .factored import factored_optimal_transport -from .solvers import solve, solve_gromov +from .solvers import solve, solve_gromov, solve_sample +from .lowrank import lowrank_sinkhorn # utils functions from .utils import dist, unif, tic, toc, toq From 9becafc305fd6b2cc5390b0de16bae015bd41121 Mon Sep 17 00:00:00 2001 From: laudavid Date: Fri, 3 Nov 2023 11:38:40 +0100 Subject: [PATCH 05/11] modify low rank, remove solve_sample,OTResultLazy --- ot/__init__.py | 4 +- ot/lowrank.py | 200 ++++++++++++++++++++++++------------------- ot/solvers.py | 160 ---------------------------------- ot/utils.py | 89 ------------------- test/test_lowrank.py | 54 ++++++------ test/test_solvers.py | 97 --------------------- test/test_utils.py | 25 ------ 7 files changed, 142 insertions(+), 487 deletions(-) diff --git a/ot/__init__.py b/ot/__init__.py index cb00f4553..4aba450af 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -51,7 +51,7 @@ gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2) from .weak import weak_optimal_transport from .factored import factored_optimal_transport -from .solvers import solve, solve_gromov, solve_sample +from .solvers import solve, solve_gromov from .lowrank import lowrank_sinkhorn # utils functions @@ -70,4 +70,4 @@ 'factored_optimal_transport', 'solve', 'solve_gromov', 'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath', 'solvers', 'binary_search_circle', 'wasserstein_circle', - 'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif'] + 'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif', 'lowrank_sinkhorn'] diff --git a/ot/lowrank.py b/ot/lowrank.py index 22ff8b754..b3fce8de0 100644 --- a/ot/lowrank.py +++ b/ot/lowrank.py @@ -1,78 +1,88 @@ +""" +Low rank OT 
solvers +""" + +# Author: Laurène David +# +# License: MIT License + + + ################################################################################################################# ############################################## WORK IN PROGRESS ################################################# ################################################################################################################# -## Implementation of the LR-Dykstra algorithm and low rank sinkhorn algorithms import warnings -from .utils import unif, list_to_array, dist -from .backend import get_backend +from ot.utils import unif +from ot.backend import get_backend ################################## LR-DYSKTRA ALGORITHM ########################################## -def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, dykstra_p): +def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, dykstra_p, stopThr, nx=None): """ - Implementation of the Dykstra algorithm for low rank sinkhorn + Implementation of the Dykstra algorithm for the Low rank sinkhorn solver + """ + # Get dykstra parameters + g, q3_1, q3_2, v1_, v2_, q1, q2, u1, u2, v1, v2 = dykstra_p + g_ = eps3.copy() + err = 1 - # get dykstra parameters - q3_1, q3_2, v1_, v2_, q1, q2 = dykstra_p + # POT backend if needed + if nx is None: + nx = get_backend(eps1, eps2, eps3, p1, p2, + g, q3_1, q3_2, v1_, v2_, q1, q2, u1, u2) - # POT backend - eps1, eps2, eps3, p1, p2 = list_to_array(eps1, eps2, eps3, p1, p2) - q3_1, q3_2, v1_, v2_, q1, q2 = list_to_array(q3_1, q3_2, v1_, v2_, q1, q2) - - nx = get_backend(eps1, eps2, eps3, p1, p2, q3_1, q3_2, v1_, v2_, q1, q2) - - # ------- Dykstra algorithm ------ - g_ = eps3 - u1 = p1 / nx.dot(eps1, v1_) - u2 = p2 / nx.dot(eps2, v2_) + # ------- Dykstra algorithm ------ + while err > stopThr : + u1 = p1 / nx.dot(eps1, v1_) + u2 = p2 / nx.dot(eps2, v2_) - g = nx.maximum(alpha, g_ * q3_1) - q3_1 = (g_ * q3_1) / g - g_ = g + g = nx.maximum(alpha, g_ * q3_1) + q3_1 = (g_ * q3_1) / g + g_ = g.copy() - prod1 = ((v1_ * q1) 
* nx.dot(eps1.T, u1)) - prod2 = ((v2_ * q2) * nx.dot(eps2.T, u2)) - g = (g_ * q3_2 * prod1 * prod2)**(1/3) + prod1 = ((v1_ * q1) * nx.dot(eps1.T, u1)) + prod2 = ((v2_ * q2) * nx.dot(eps2.T, u2)) + g = (g_ * q3_2 * prod1 * prod2)**(1/3) - v1 = g / nx.dot(eps1.T,u1) - v2 = g / nx.dot(eps2.T,u2) + v1 = g / nx.dot(eps1.T,u1) + v2 = g / nx.dot(eps2.T,u2) + q1 = (v1_ * q1) / v1 - q1 = (v1_ * q1) / v1 - q2 = (v2_ * q2) / v2 - q3_2 = (g_ * q3_2) / g - - v1_, v2_ = v1, v2 - g_ = g + q2 = (v2_ * q2) / v2 + q3_2 = (g_ * q3_2) / g + + v1_, v2_ = v1.copy(), v2.copy() + g_ = g.copy() - # Compute error - err1 = nx.sum(nx.abs(u1 * (eps1 @ v1) - p1)) - err2 = nx.sum(nx.abs(u2 * (eps2 @ v2) - p2)) - err = err1 + err2 + # Compute error + err1 = nx.sum(nx.abs(u1 * (eps1 @ v1) - p1)) + err2 = nx.sum(nx.abs(u2 * (eps2 @ v2) - p2)) + err = err1 + err2 # Compute low rank matrices Q, R Q = u1[:,None] * eps1 * v1[None,:] R = u2[:,None] * eps2 * v2[None,:] - dykstra_p = [q3_1, q3_2, v1_, v2_, q1, q2] + dykstra_p = [g, q3_1, q3_2, v1_, v2_, q1, q2, u1, u2, v1, v2] - return Q, R, g, err, dykstra_p + return Q, R, dykstra_p #################################### LOW RANK SINKHORN ALGORITHM ######################################### -def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=2, metric='sqeuclidean', alpha="auto", - numItermax=10000, stopThr=1e-9, warn=True, verbose=False): +def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=2, alpha="auto", + numItermax=1000, stopThr=1e-9, warn=True, verbose=False): #stopThr = 1e-9 + r''' Solve the entropic regularization optimal transport problem under low-nonnegative rank constraints. - This function returns the two low-rank matrix decomposition of the OT plan (Q,R), as well as the weight vector g. 
Parameters @@ -95,6 +105,9 @@ def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=2, metric='sqeuclidea Max number of iterations stopThr : float, optional Stop threshold on error (>0) + warn: + + verbose: Returns @@ -109,73 +122,87 @@ def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=2, metric='sqeuclidea References ---------- - .. Scetbon, M., Cuturi, M., & Peyré, G (2021). Low-Rank Sinkhorn Factorization. arXiv preprint arXiv:2103.04737. ''' - X_s, X_t = list_to_array(X_s, X_t) nx = get_backend(X_s, X_t) - ns, nt = X_s.shape[0], X_t.shape[0] if a is None: - a = nx.from_numpy(unif(ns), type_as=X_s) + a = unif(ns, type_as=X_s) if b is None: - b = nx.from_numpy(unif(nt), type_as=X_s) + b = unif(nt, type_as=X_t) - # Compute cost matrix - M = dist(X_s,X_t, metric=metric) + d = X_s.shape[1] + + # First low rank decomposition of the cost matrix (A) + M1 = nx.zeros((ns,(d+2))) + M1[:,0] = [nx.norm(X_s[i,:])**2 for i in range(ns)] + M1[:,1] = nx.ones(ns) + M1[:,2:] = -2*X_s + + # Second low rank decomposition of the cost matrix (B) + M2 = nx.zeros((nt,(d+2))) + M2[:,0] = nx.ones(nt) + M2[:,1] = [nx.norm(X_t[i,:])**2 for i in range(nt)] + M2[:,2:] = X_t # Compute rank rank = min(ns, nt, rank) r = rank + # Alpha: lower bound for 1/rank if alpha == 'auto': - alpha = 1.0 / (r + 1) + alpha = 1e-3 # no convergence with alpha = 1 / (r+1) if (1/r < alpha) or (alpha < 0): - warnings.warn("The provided alpha value might lead to instabilities.") + warnings.warn("The provided alpha value might lead to instabilities.") - # Compute gamma - L = nx.sqrt((2/(alpha**4))*(nx.norm(M)**2) + (reg + (2/(alpha**3))*(nx.norm(M))**2)) + L = nx.sqrt(3*(2/(alpha**4))*((nx.norm(M1)*nx.norm(M2))**2) + (reg + (2/(alpha**3))*(nx.norm(M1)*nx.norm(M2)))**2) gamma = 1/(2*L) - # Initialisation + # Initialize the low rank matrices Q, R, g Q, R, g = nx.ones((ns,r)), nx.ones((nt,r)), nx.ones(r) + + # Initialize parameters for Dykstra algorithm q3_1, q3_2 = nx.ones(r), nx.ones(r) + u1, u2 = 
nx.ones(ns), nx.ones(nt) + v1, v2 = nx.ones(r), nx.ones(r) v1_, v2_ = nx.ones(r), nx.ones(r) q1, q2 = nx.ones(r), nx.ones(r) - dykstra_p = [q3_1, q3_2, v1_, v2_, q1, q2] - err = 1 + dykstra_p = [g, q3_1, q3_2, v1_, v2_, q1, q2, u1, u2, v1, v2] + - for ii in range(numItermax): - CR = nx.dot(M,R) - C_t_Q = nx.dot(M.T,Q) + for ii in range(numItermax): + CR_ = nx.dot(M2.T, R) + CR = nx.dot(M1, CR_) + + CQ_ = nx.dot(M1.T, Q) + CQ = nx.dot(M2, CQ_) + diag_g = (1/g)[:,None] eps1 = nx.exp(-gamma*(nx.dot(CR,diag_g)) - ((gamma*reg)-1)*nx.log(Q)) - eps2 = nx.exp(-gamma*(nx.dot(C_t_Q,diag_g)) - ((gamma*reg)-1)*nx.log(R)) + eps2 = nx.exp(-gamma*(nx.dot(CQ,diag_g)) - ((gamma*reg)-1)*nx.log(R)) omega = nx.diag(nx.dot(Q.T, CR)) eps3 = nx.exp(gamma*omega/(g**2) - (gamma*reg - 1)*nx.log(g)) - Q, R, g, err, dykstra_p = LR_Dysktra(eps1, eps2, eps3, a, b, alpha, dykstra_p) - - if err < stopThr: - break - - if verbose: - if ii % 200 == 0: - print( - '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(ii, err)) - - else: - if warn: - warnings.warn("Sinkhorn did not converge. You might want to " - "increase the number of iterations `numItermax` " - "or the regularization parameter `reg`.") + Q, R, dykstra_p = LR_Dysktra(eps1, eps2, eps3, a, b, alpha, dykstra_p, stopThr, nx) + g = dykstra_p[0] + + # if verbose: + # if ii % 200 == 0: + # print( + # '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + # print('{:5d}|{:8e}|'.format(ii, err)) + + # else: + # if warn: + # warnings.warn("Sinkhorn did not converge. 
You might want to " + # "increase the number of iterations `numItermax` " + # "or the regularization parameter `reg`.") return Q, R, g @@ -187,24 +214,23 @@ def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=2, metric='sqeuclidea ## Test with X_s, X_t from ot.datasets ############################################################################# -# import numpy as np -# import ot - -# Xs, _ = ot.datasets.make_data_classif('3gauss', n=1000) -# Xt, _ = ot.datasets.make_data_classif('3gauss2', n=1500) +import numpy as np +import ot -# ns = Xs.shape[0] -# nt = Xt.shape[0] +Xs, _ = ot.datasets.make_data_classif('3gauss', n=1000) +Xt, _ = ot.datasets.make_data_classif('3gauss2', n=1500) -# a = unif(ns) -# b = unif(nt) +ns = Xs.shape[0] +nt = Xt.shape[0] -# Q, R, g = lowrank_sinkhorn(Xs, Xt, reg=0.1, metric='euclidean', verbose=True, numItermax=100) -# M = ot.dist(Xs,Xt) -# P = np.dot(Q,np.dot(np.diag(1/g),R.T)) +a = unif(ns) +b = unif(nt) -# print(np.sum(P)) +Q, R, g = lowrank_sinkhorn(Xs, Xt, reg=0.1, verbose=True, numItermax=20) +M = ot.dist(Xs,Xt) +P = np.dot(Q,np.dot(np.diag(1/g),R.T)) +print(np.sum(P)) diff --git a/ot/solvers.py b/ot/solvers.py index c176969ca..8d6e10a5f 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -854,163 +854,3 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, -################################## WORK IN PROGRESS ##################################### - -## Implementation of the ot.solve_sample function -## Function isn't complete, still work in progress for reg == 0 / reg is None case (and unbalanced cases) - - -from .utils import unif, list_to_array, dist, OTResultLazy -from .bregman import empirical_sinkhorn - - -def solve_sample(X_s, X_t, a=None, b=None, metric='sqeuclidean', reg=None, reg_type="KL", unbalanced=None, - unbalanced_type='KL', is_Lazy=False, batch_size=None, n_threads=1, max_iter=None, plan_init=None, - potentials_init=None, tol=None, verbose=False): - - r"""Solve the discrete optimal 
transport problem using the samples in the source and target domains. - It returns either a :any:`OTResult` or :any:`OTResultLazy` object. - - The function solves the following general optimal transport problem - - .. math:: - \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_r R(\mathbf{T}) + - \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + - \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b}) - - The regularization is selected with `reg` (:math:`\lambda_r`) and `reg_type`. By - default ``reg=None`` and there is no regularization. The unbalanced marginal - penalization can be selected with `unbalanced` (:math:`\lambda_u`) and - `unbalanced_type`. By default ``unbalanced=None`` and the function - solves the exact optimal transport problem (respecting the marginals). - - Parameters - ---------- - X_s : array-like, shape (n_samples_a, dim) - samples in the source domain - X_t : array-like, shape (n_samples_b, dim) - samples in the target domain - a : array-like, shape (dim_a,), optional - Samples weights in the source domain (default is uniform) - b : array-like, shape (dim_b,), optional - Samples weights in the source domain (default is uniform) - reg : float, optional - Regularization weight :math:`\lambda_r`, by default None (no reg., exact - OT) - reg_type : str, optional - Type of regularization :math:`R` either "KL", "L2", "entropy", by default "KL" - unbalanced : float, optional - Unbalanced penalization weight :math:`\lambda_u`, by default None - (balanced OT) - unbalanced_type : str, optional - Type of unbalanced penalization function :math:`U` either "KL", "L2", "TV", by default "KL" - is_Lazy : bool, optional - Return :any:`OTResultlazy` object to reduce memory cost when True, by default False - n_threads : int, optional - Number of OMP threads for exact OT solver, by default 1 - max_iter : int, optional - Maximum number of iteration, by default None (default values in each solvers) - plan_init : array_like, shape (dim_a, dim_b), optional - 
Initialization of the OT plan for iterative methods, by default None - potentials_init : (array_like(dim_a,),array_like(dim_b,)), optional - Initialization of the OT dual potentials for iterative methods, by default None - tol : _type_, optional - Tolerance for solution precision, by default None (default values in each solvers) - verbose : bool, optional - Print information in the solver, by default False - - Returns - ------- - - res_lazy : OTResultLazy() - Result of the optimization problem. This class only returns a partial OT plan and the OT dual potentials to reduce memory costs. - The information can be obtained as follows: - - - res.lazy_plan : OT plan computed on a subsample of X_s and X_t - - res.potentials : OT dual potentials - - See :any:`OTResultLazy` for more information. - - res : OTResult() - Result of the optimization problem. The information can be obtained as follows: - - - res.plan : OT plan :math:`\mathbf{T}` - - res.potentials : OT dual potentials - - res.value : Optimal value of the optimization problem - - res.value_linear : Linear OT loss with the optimal OT plan - - See :any:`OTResult` for more information. - - - """ - - X_s, X_t = list_to_array(X_s,X_t) - - # detect backend - arr = [X_s,X_t] - if a is not None: - arr.append(a) - if b is not None: - arr.append(b) - nx = get_backend(*arr) - - # create uniform weights if not given - ns, nt = X_s.shape[0], X_t.shape[0] - if a is None: - a = nx.from_numpy(unif(ns), type_as=X_s) - if b is None: - b = nx.from_numpy(unif(nt), type_as=X_s) - - # default values for solutions - potentials = None - lazy_plan = None - - if max_iter is None: - max_iter = 1000 - if tol is None: - tol = 1e-9 - if batch_size is None: - batch_size = 100 - - if is_Lazy: - ################# WIP #################### - if reg is None or reg == 0: # EMD solver for isLazy ? - - if unbalanced is None: # balanced EMD solver for isLazy ? 
- raise (NotImplementedError('Not implemented balanced with no regularization')) - - else: - raise (NotImplementedError('Not implemented unbalanced_type="{}" with no regularization'.format(unbalanced_type))) - - - ############################################# - - else: - if unbalanced is None: - u, v, log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric='sqeuclidean', numIterMax=max_iter, stopThr=tol, - isLazy=True, batchSize=batch_size, verbose=verbose, log=True) - # compute potentials - potentials = (log["u"], log["v"]) - - # compute lazy_plan - ns_lazy, nt_lazy = 100, 100 # size of the lazy_plan (subplan) - M = dist(X_s[:ns_lazy,:], X_t[:nt_lazy,:], metric) - K = nx.exp(M / (-reg)) - lazy_plan = u[:ns_lazy].reshape((-1, 1)) * K * v[:nt_lazy].reshape((1, -1)) - - res_lazy = OTResultLazy(potentials=potentials, lazy_plan=lazy_plan, backend=nx) - return res_lazy - - else: - raise (NotImplementedError('Not implemented unbalanced_type="{}" with regularization'.format(unbalanced_type))) - - else: - # compute cost matrix M and use solve function - M = dist(X_s, X_t, metric) - - res = solve(M, a, b, reg, reg_type, unbalanced, unbalanced_type, n_threads, max_iter, plan_init, potentials_init, tol, verbose) - return res - - - - diff --git a/ot/utils.py b/ot/utils.py index d570b9f30..01944f56b 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -939,92 +939,3 @@ def citation(self): } """ - - -############################## WORK IN PROGRESS #################################### - -## Implementation of the OTResultLazy class for ot.solve_sample() with potentials and lazy_plan as attributes - -class OTResultLazy: - def __init__(self, potentials=None, lazy_plan=None, backend=None): - - self._potentials = potentials - self._lazy_plan = lazy_plan - self._backend = backend if backend is not None else NumpyBackend() - - - # Dual potentials -------------------------------------------- - - def __repr__(self): - s = 'OTResultLazy(' - if self._lazy_plan is not None: - s += 
'lazy_plan={}(shape={}),'.format(self._lazy_plan.__class__.__name__, self._lazy_plan.shape) - - if s[-1] != '(': - s = s[:-1] + ')' - else: - s = s + ')' - return s - - @property - def potentials(self): - """Dual potentials, i.e. Lagrange multipliers for the marginal constraints. - - This pair of arrays has the same shape, numerical type - and properties as the input weights "a" and "b". - """ - if self._potentials is not None: - return self._potentials - else: - raise NotImplementedError() - - @property - def potential_a(self): - """First dual potential, associated to the "source" measure "a".""" - if self._potentials is not None: - return self._potentials[0] - else: - raise NotImplementedError() - - @property - def potential_b(self): - """Second dual potential, associated to the "target" measure "b".""" - if self._potentials is not None: - return self._potentials[1] - else: - raise NotImplementedError() - - # Transport plan ------------------------------------------- - @property - def lazy_plan(self): - """A subset of the Transport plan, encoded as a dense array.""" - - if self._lazy_plan is not None: - return self._lazy_plan - else: - raise NotImplementedError() - - @property - def citation(self): - """Appropriate citation(s) for this result, in plain text and BibTex formats.""" - - # The string below refers to the POT library: - # successor methods may concatenate the relevant references - # to the original definitions, solvers and underlying numerical backends. - return """POT library: - - POT Python Optimal Transport library, Journal of Machine Learning Research, 22(78):1−8, 2021. - Website: https://pythonot.github.io/ - Rémi Flamary, Nicolas Courty, Alexandre Gramfort, Mokhtar Z. Alaya, Aurélie Boisbunon, Stanislas Chambon, Laetitia Chapel, Adrien Corenflos, Kilian Fatras, Nemo Fournier, Léo Gautheron, Nathalie T.H. Gayraud, Hicham Janati, Alain Rakotomamonjy, Ievgen Redko, Antoine Rolet, Antony Schutz, Vivien Seguy, Danica J. 
Sutherland, Romain Tavenard, Alexander Tong, Titouan Vayer; - - @article{flamary2021pot, - author = {R{\'e}mi Flamary and Nicolas Courty and Alexandre Gramfort and Mokhtar Z. Alaya and Aur{\'e}lie Boisbunon and Stanislas Chambon and Laetitia Chapel and Adrien Corenflos and Kilian Fatras and Nemo Fournier and L{\'e}o Gautheron and Nathalie T.H. Gayraud and Hicham Janati and Alain Rakotomamonjy and Ievgen Redko and Antoine Rolet and Antony Schutz and Vivien Seguy and Danica J. Sutherland and Romain Tavenard and Alexander Tong and Titouan Vayer}, - title = {{POT}: {Python} {Optimal} {Transport}}, - journal = {Journal of Machine Learning Research}, - year = {2021}, - volume = {22}, - number = {78}, - pages = {1-8}, - url = {http://jmlr.org/papers/v22/20-451.html} - } - """ \ No newline at end of file diff --git a/test/test_lowrank.py b/test/test_lowrank.py index 6e1f24067..7d90ce9ef 100644 --- a/test/test_lowrank.py +++ b/test/test_lowrank.py @@ -16,34 +16,34 @@ def test_LR_Dykstra(): pass -@pytest.mark.parametrize("verbose, warn", product([True, False], [True, False])) -def test_lowrank_sinkhorn(verbose, warn): - # test low rank sinkhorn - n = 100 - a = ot.unif(n) - b = ot.unif(n) - - X_s = np.reshape(1.0 * np.arange(n), (n, 1)) - X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) - - Q_sqe, R_sqe, g_sqe = ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, 0.1) - P_sqe = np.dot(Q_sqe,np.dot(np.diag(1/g_sqe),R_sqe.T)) - - Q_m, R_m, g_m = ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, 0.1, metric='euclidean') - P_m = np.dot(Q_m,np.dot(np.diag(1/g_m),R_m.T)) - - # check constraints - np.testing.assert_allclose( - a, P_sqe.sum(1), atol=1e-05) # metric sqeuclidian - np.testing.assert_allclose( - b, P_sqe.sum(0), atol=1e-05) # metric sqeuclidian - np.testing.assert_allclose( - a, P_m.sum(1), atol=1e-05) # metric euclidian - np.testing.assert_allclose( - b, P_m.sum(0), atol=1e-05) # metric euclidian +# @pytest.mark.parametrize("verbose, warn", product([True, False], [True, False])) +# 
def test_lowrank_sinkhorn(verbose, warn): +# # test low rank sinkhorn +# n = 100 +# a = ot.unif(n) +# b = ot.unif(n) + +# X_s = np.reshape(1.0 * np.arange(n), (n, 1)) +# X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) + +# Q_sqe, R_sqe, g_sqe = ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, 0.1) +# P_sqe = np.dot(Q_sqe,np.dot(np.diag(1/g_sqe),R_sqe.T)) + +# Q_m, R_m, g_m = ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, 0.1, metric='euclidean') +# P_m = np.dot(Q_m,np.dot(np.diag(1/g_m),R_m.T)) + +# # check constraints +# np.testing.assert_allclose( +# a, P_sqe.sum(1), atol=1e-05) # metric sqeuclidian +# np.testing.assert_allclose( +# b, P_sqe.sum(0), atol=1e-05) # metric sqeuclidian +# np.testing.assert_allclose( +# a, P_m.sum(1), atol=1e-05) # metric euclidian +# np.testing.assert_allclose( +# b, P_m.sum(0), atol=1e-05) # metric euclidian - with pytest.warns(UserWarning): - ot.lowrank.lowrank_sinkhorn(X_s, X_t, 0.1, a, b, stopThr=0, numItermax=1) +# with pytest.warns(UserWarning): +# ot.lowrank.lowrank_sinkhorn(X_s, X_t, 0.1, a, b, stopThr=0, numItermax=1) diff --git a/test/test_solvers.py b/test/test_solvers.py index 5a05d54cf..e845ac7c2 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -258,100 +258,3 @@ def test_solve_gromov_not_implemented(nx): - -########################################################################################################### -############################################ WORK IN PROGRESS ############################################# -########################################################################################################### - -def assert_allclose_sol_sample(sol1, sol2): - # test attributes of OTResultLazy class - lst_attr = ['potentials','potential_a', 'potential_b', 'lazy_plan'] - - nx1 = sol1._backend if sol1._backend is not None else ot.backend.NumpyBackend() - nx2 = sol2._backend if sol2._backend is not None else ot.backend.NumpyBackend() - - for attr in lst_attr: - try: - 
np.allclose(nx1.to_numpy(getattr(sol1, attr)), nx2.to_numpy(getattr(sol2, attr))) - except NotImplementedError: - pass - - -@pytest.mark.parametrize("reg,reg_type,unbalanced,unbalanced_type", itertools.product(lst_reg, lst_reg_type, lst_unbalanced, lst_unbalanced_type)) -def test_solve_sample(nx): - # test solve_sample when is_Lazy = False - n = 100 - X_s = np.reshape(1.0 * np.arange(n), (n, 1)) - X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) - - a = ot.utils.unif(X_s.shape[0]) - b = ot.utils.unif(X_t.shape[0]) - - # solve unif weights - sol0 = ot.solve_sample(X_s, X_t) - - # solve signe weights - sol = ot.solve_sample(X_s, X_t, a, b) - - # check some attributes - sol.potentials - sol.sparse_plan - sol.marginals - sol.status - - assert_allclose_sol(sol0, sol) - - # solve in backend - X_sb, X_tb, ab, bb = nx.from_numpy(X_s, X_t, a, b) - solb = ot.solve_sample(X_sb, X_tb, ab, bb) - - assert_allclose_sol(sol, solb) - - # test not implemented unbalanced and check raise - with pytest.raises(NotImplementedError): - sol0 = ot.solve_sample(X_s, X_t, unbalanced=1, unbalanced_type='cryptic divergence') - - # test not implemented reg_type and check raise - with pytest.raises(NotImplementedError): - sol0 = ot.solve_sample(X_s, X_t, reg=1, reg_type='cryptic divergence') - - - -def test_lazy_solve_sample(nx): - # test solve_sample when is_Lazy = True - n = 100 - X_s = np.reshape(1.0 * np.arange(n), (n, 1)) - X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) - - a = ot.utils.unif(X_s.shape[0]) - b = ot.utils.unif(X_t.shape[0]) - - # solve unif weights - sol0 = ot.solve_sample(X_s, X_t, reg=0.1, is_Lazy=True) # reg != 0 or None since no implementation yet for is_Lazy=True - - # solve signe weights - sol = ot.solve_sample(X_s, X_t, a, b, reg=0.1, is_Lazy=True) - - # check some attributes - sol.potentials - sol.lazy_plan - - assert_allclose_sol_sample(sol0, sol) - - # solve in backend - X_sb, X_tb, ab, bb = nx.from_numpy(X_s, X_t, a, b) - solb = ot.solve_sample(X_sb, X_tb, ab, bb, 
reg=0.1, is_Lazy=True) - - assert_allclose_sol_sample(sol, solb) - - # test not implemented reg==0 (or None) + balanced and check raise - with pytest.raises(NotImplementedError): - sol0 = ot.solve_sample(X_s, X_t, is_Lazy=True) # reg == 0 (or None) + unbalanced= None are default - - # test not implemented reg==0 (or None) + unbalanced_type and check raise - with pytest.raises(NotImplementedError): - sol0 = ot.solve_sample(X_s, X_t, unbalanced_type="kl", is_Lazy=True) # reg == 0 (or None) is default - - # test not implemented reg != 0 + unbalanced_type and check raise - with pytest.raises(NotImplementedError): - sol0 = ot.solve_sample(X_s, X_t, reg=0.1, unbalanced_type="kl", is_Lazy=True) \ No newline at end of file diff --git a/test/test_utils.py b/test/test_utils.py index a14be460e..bbadec65a 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -403,28 +403,3 @@ def test_get_coordinate_circle(): np.testing.assert_allclose(u[0], x_p) - -############################################################################################## -##################################### WORK IN PROGRESS ####################################### -############################################################################################## - -# test function for OTResultLazy - -def test_OTResultLazy(): - - res_lazy = ot.utils.OTResultLazy() - - # test print - print(res_lazy) - - # tets get citation - print(res_lazy.citation) - - lst_attributes = ['lazy_plan', - 'potential_a', - 'potential_b', - 'potentials'] - - for at in lst_attributes: - with pytest.raises(NotImplementedError): - getattr(res_lazy, at) \ No newline at end of file From 58576a3f89ac3322c93588832bdb4daa380ce384 Mon Sep 17 00:00:00 2001 From: laudavid Date: Fri, 3 Nov 2023 13:59:01 +0100 Subject: [PATCH 06/11] solve_sample + test functions --- ot/solvers.py | 147 ++++++++++++++++++++++++++++++++++++++++++- test/test_solvers.py | 98 +++++++++++++++++++++++++++++ 2 files changed, 243 insertions(+), 2 deletions(-) diff 
--git a/ot/solvers.py b/ot/solvers.py index 8d6e10a5f..6eb19ea05 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -7,11 +7,11 @@ # # License: MIT License -from .utils import OTResult +from .utils import OTResult, unif, dist from .lp import emd2 from .backend import get_backend from .unbalanced import mm_unbalanced, sinkhorn_knopp_unbalanced, lbfgsb_unbalanced -from .bregman import sinkhorn_log +from .bregman import sinkhorn_log, empirical_sinkhorn from .partial import partial_wasserstein_lagrange from .smooth import smooth_ot_dual from .gromov import (gromov_wasserstein2, fused_gromov_wasserstein2, @@ -20,6 +20,8 @@ entropic_semirelaxed_fused_gromov_wasserstein2, entropic_semirelaxed_gromov_wasserstein2) from .partial import partial_gromov_wasserstein2, entropic_partial_gromov_wasserstein2 +from .bregman import empirical_sinkhorn + #, entropic_gromov_wasserstein2, entropic_fused_gromov_wasserstein2 @@ -851,6 +853,147 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, +########## ot.solve_sample function ########### + +from .bregman import empirical_sinkhorn +from .utils import unif, dist + + +def solve_sample(X_s, X_t, a=None, b=None, metric='sqeuclidean', reg=None, reg_type="KL", unbalanced=None, + unbalanced_type='KL', is_Lazy=False, batch_size=None, n_threads=1, max_iter=None, plan_init=None, + potentials_init=None, tol=None, verbose=False): + + r"""Solve the discrete optimal transport problem using the samples in the source and target domains. + It returns either a :any:`OTResult` or :any:`OTResultLazy` object. + + The function solves the following general optimal transport problem + + .. math:: + \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_r R(\mathbf{T}) + + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b}) + + The regularization is selected with `reg` (:math:`\lambda_r`) and `reg_type`. By + default ``reg=None`` and there is no regularization. 
The unbalanced marginal + penalization can be selected with `unbalanced` (:math:`\lambda_u`) and + `unbalanced_type`. By default ``unbalanced=None`` and the function + solves the exact optimal transport problem (respecting the marginals). + + Parameters + ---------- + X_s : array-like, shape (n_samples_a, dim) + samples in the source domain + X_t : array-like, shape (n_samples_b, dim) + samples in the target domain + a : array-like, shape (dim_a,), optional + Samples weights in the source domain (default is uniform) + b : array-like, shape (dim_b,), optional + Samples weights in the source domain (default is uniform) + reg : float, optional + Regularization weight :math:`\lambda_r`, by default None (no reg., exact + OT) + reg_type : str, optional + Type of regularization :math:`R` either "KL", "L2", "entropy", by default "KL" + unbalanced : float, optional + Unbalanced penalization weight :math:`\lambda_u`, by default None + (balanced OT) + unbalanced_type : str, optional + Type of unbalanced penalization function :math:`U` either "KL", "L2", "TV", by default "KL" + is_Lazy : bool, optional + Return :any:`OTResultlazy` object to reduce memory cost when True, by default False + n_threads : int, optional + Number of OMP threads for exact OT solver, by default 1 + max_iter : int, optional + Maximum number of iteration, by default None (default values in each solvers) + plan_init : array_like, shape (dim_a, dim_b), optional + Initialization of the OT plan for iterative methods, by default None + potentials_init : (array_like(dim_a,),array_like(dim_b,)), optional + Initialization of the OT dual potentials for iterative methods, by default None + tol : _type_, optional + Tolerance for solution precision, by default None (default values in each solvers) + verbose : bool, optional + Print information in the solver, by default False + + Returns + ------- + + res : OTResult() + Result of the optimization problem. 
The information can be obtained as follows: + + - res.plan : OT plan :math:`\mathbf{T}` + - res.potentials : OT dual potentials + - res.value : Optimal value of the optimization problem + - res.value_linear : Linear OT loss with the optimal OT plan + + See :any:`OTResult` for more information. + + + """ + # Detect backend + arr = [X_s,X_t] + if a is not None: + arr.append(a) + if b is not None: + arr.append(b) + nx = get_backend(*arr) + + # Create uniform weights if not given + ns, nt = X_s.shape[0], X_t.shape[0] + if a is None: + a = nx.from_numpy(unif(ns), type_as=X_s) + if b is None: + b = nx.from_numpy(unif(nt), type_as=X_s) + + if metric is not 'sqeuclidean': + raise (NotImplementedError('Only implemented for sqeuclidean metric')) + + + # default values for solutions + potentials = None + lazy_plan = None + + if max_iter is None: + max_iter = 1000 + if tol is None: + tol = 1e-9 + if batch_size is None: + batch_size = 100 + + if is_Lazy: + ################# WIP #################### + if reg is None or reg == 0: # EMD solver for isLazy ? + + if unbalanced is None: # balanced EMD solver for isLazy ? 
+ raise (NotImplementedError('Not implemented balanced with no regularization')) + + else: + raise (NotImplementedError('Not implemented unbalanced_type="{}" with no regularization'.format(unbalanced_type))) + + + ############################################# + + else: + if unbalanced is None: + u, v, log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric='sqeuclidean', numIterMax=max_iter, stopThr=tol, + isLazy=True, batchSize=batch_size, verbose=verbose, log=True) + # compute potentials + potentials = (log["u"], log["v"]) + + # compute lazy_plan + pass + + else: + raise (NotImplementedError('Not implemented unbalanced_type="{}" with regularization'.format(unbalanced_type))) + + else: + # compute cost matrix M and use solve function + M = dist(X_s, X_t, metric) + + res = solve(M, a, b, reg, reg_type, unbalanced, unbalanced_type, n_threads, max_iter, plan_init, potentials_init, tol, verbose) + return res + + + diff --git a/test/test_solvers.py b/test/test_solvers.py index e845ac7c2..77723ddd0 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -258,3 +258,101 @@ def test_solve_gromov_not_implemented(nx): + +######## Test functions for ot.solve_sample ######## + +def assert_allclose_sol_sample(sol1, sol2): + # test attributes of OTResultLazy class + lst_attr = ['potentials','potential_a', 'potential_b', 'lazy_plan'] + + nx1 = sol1._backend if sol1._backend is not None else ot.backend.NumpyBackend() + nx2 = sol2._backend if sol2._backend is not None else ot.backend.NumpyBackend() + + for attr in lst_attr: + try: + np.allclose(nx1.to_numpy(getattr(sol1, attr)), nx2.to_numpy(getattr(sol2, attr))) + except NotImplementedError: + pass + + +@pytest.mark.parametrize("reg,reg_type,unbalanced,unbalanced_type", itertools.product(lst_reg, lst_reg_type, lst_unbalanced, lst_unbalanced_type)) +def test_solve_sample(nx): + # test solve_sample when is_Lazy = False + n = 100 + X_s = np.reshape(1.0 * np.arange(n), (n, 1)) + X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) + 
+ a = ot.utils.unif(X_s.shape[0]) + b = ot.utils.unif(X_t.shape[0]) + + # solve unif weights + sol0 = ot.solve_sample(X_s, X_t) + + # solve signe weights + sol = ot.solve_sample(X_s, X_t, a, b) + + # check some attributes + sol.potentials + sol.sparse_plan + sol.marginals + sol.status + + assert_allclose_sol(sol0, sol) + + # solve in backend + X_sb, X_tb, ab, bb = nx.from_numpy(X_s, X_t, a, b) + solb = ot.solve_sample(X_sb, X_tb, ab, bb) + + assert_allclose_sol(sol, solb) + + # test not implemented unbalanced and check raise + with pytest.raises(NotImplementedError): + sol0 = ot.solve_sample(X_s, X_t, unbalanced=1, unbalanced_type='cryptic divergence') + + # test not implemented reg_type and check raise + with pytest.raises(NotImplementedError): + sol0 = ot.solve_sample(X_s, X_t, reg=1, reg_type='cryptic divergence') + + + +def test_lazy_solve_sample(nx): + # test solve_sample when is_Lazy = True + n = 100 + X_s = np.reshape(1.0 * np.arange(n), (n, 1)) + X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) + + a = ot.utils.unif(X_s.shape[0]) + b = ot.utils.unif(X_t.shape[0]) + + # solve unif weights + sol0 = ot.solve_sample(X_s, X_t, reg=0.1, is_Lazy=True) # reg != 0 or None since no implementation yet for is_Lazy=True + + # solve signe weights + sol = ot.solve_sample(X_s, X_t, a, b, reg=0.1, is_Lazy=True) + + # check some attributes + sol.potentials + sol.lazy_plan + + assert_allclose_sol_sample(sol0, sol) + + # solve in backend + X_sb, X_tb, ab, bb = nx.from_numpy(X_s, X_t, a, b) + solb = ot.solve_sample(X_sb, X_tb, ab, bb, reg=0.1, is_Lazy=True) + + assert_allclose_sol_sample(sol, solb) + + # test not implemented reg==0 (or None) + balanced and check raise + with pytest.raises(NotImplementedError): + sol0 = ot.solve_sample(X_s, X_t, is_Lazy=True) # reg == 0 (or None) + unbalanced= None are default + + # test not implemented reg==0 (or None) + unbalanced_type and check raise + with pytest.raises(NotImplementedError): + sol0 = ot.solve_sample(X_s, X_t, 
unbalanced_type="kl", is_Lazy=True) # reg == 0 (or None) is default + + # test not implemented reg != 0 + unbalanced_type and check raise + with pytest.raises(NotImplementedError): + sol0 = ot.solve_sample(X_s, X_t, reg=0.1, unbalanced_type="kl", is_Lazy=True) + + + From ed1b22d13da6082a69c099e59915fed3301ff208 Mon Sep 17 00:00:00 2001 From: laudavid Date: Fri, 3 Nov 2023 14:04:11 +0100 Subject: [PATCH 07/11] remove low rank from branch --- ot/__init__.py | 9 +- ot/lowrank.py | 247 ------------------------------------------- test/test_lowrank.py | 84 --------------- 3 files changed, 4 insertions(+), 336 deletions(-) delete mode 100644 ot/lowrank.py delete mode 100644 test/test_lowrank.py diff --git a/ot/__init__.py b/ot/__init__.py index 4aba450af..034875c55 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -35,7 +35,6 @@ from . import factored from . import solvers from . import gaussian -from . import lowrank # OT functions from .lp import (emd, emd2, emd_1d, emd2_1d, wasserstein_1d, @@ -51,8 +50,8 @@ gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2) from .weak import weak_optimal_transport from .factored import factored_optimal_transport -from .solvers import solve, solve_gromov -from .lowrank import lowrank_sinkhorn +from .solvers import solve, solve_gromov, solve_sample + # utils functions from .utils import dist, unif, tic, toc, toq @@ -67,7 +66,7 @@ 'sinkhorn_unbalanced2', 'sliced_wasserstein_distance', 'sliced_wasserstein_sphere', 'gromov_wasserstein', 'gromov_wasserstein2', 'gromov_barycenters', 'fused_gromov_wasserstein', 'fused_gromov_wasserstein2', 'max_sliced_wasserstein_distance', 'weak_optimal_transport', - 'factored_optimal_transport', 'solve', 'solve_gromov', + 'factored_optimal_transport', 'solve', 'solve_gromov', 'solve_sample' 'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath', 'solvers', 'binary_search_circle', 'wasserstein_circle', - 'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif', 
'lowrank_sinkhorn'] + 'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif'] diff --git a/ot/lowrank.py b/ot/lowrank.py deleted file mode 100644 index d583f4741..000000000 --- a/ot/lowrank.py +++ /dev/null @@ -1,247 +0,0 @@ -""" -Low rank OT solvers -""" - -# Author: Laurène David -# -# License: MIT License - - - -################################################################################################################# -############################################## WORK IN PROGRESS ################################################# -################################################################################################################# - - -import warnings -from ot.utils import unif -from ot.backend import get_backend - - - -################################## LR-DYSKTRA ALGORITHM ########################################## - -def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, dykstra_p, stopThr, nx=None): - """ - Implementation of the Dykstra algorithm for the Low rank sinkhorn solver - - """ - # Get dykstra parameters - g, q3_1, q3_2, v1_, v2_, q1, q2, u1, u2, v1, v2 = dykstra_p - g_ = eps3.copy() - err = 1 - - # POT backend if needed - if nx is None: - nx = get_backend(eps1, eps2, eps3, p1, p2, - g, q3_1, q3_2, v1_, v2_, q1, q2, u1, u2) - - - # ------- Dykstra algorithm ------ - while err > stopThr : - u1 = p1 / nx.dot(eps1, v1_) - u2 = p2 / nx.dot(eps2, v2_) - - g = nx.maximum(alpha, g_ * q3_1) - q3_1 = (g_ * q3_1) / g - g_ = g.copy() - - prod1 = ((v1_ * q1) * nx.dot(eps1.T, u1)) - prod2 = ((v2_ * q2) * nx.dot(eps2.T, u2)) - g = (g_ * q3_2 * prod1 * prod2)**(1/3) - - v1 = g / nx.dot(eps1.T,u1) - v2 = g / nx.dot(eps2.T,u2) - q1 = (v1_ * q1) / v1 - - q2 = (v2_ * q2) / v2 - q3_2 = (g_ * q3_2) / g - - v1_, v2_ = v1.copy(), v2.copy() - g_ = g.copy() - - # Compute error - err1 = nx.sum(nx.abs(u1 * (eps1 @ v1) - p1)) - err2 = nx.sum(nx.abs(u2 * (eps2 @ v2) - p2)) - err = err1 + err2 - - # Compute low rank matrices Q, R - Q = u1[:,None] * 
eps1 * v1[None,:] - R = u2[:,None] * eps2 * v2[None,:] - - dykstra_p = [g, q3_1, q3_2, v1_, v2_, q1, q2, u1, u2, v1, v2] - - return Q, R, dykstra_p - - - - -#################################### LOW RANK SINKHORN ALGORITHM ######################################### - - -def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=2, alpha="auto", - numItermax=1000, stopThr=1e-9, warn=True, verbose=False): #stopThr = 1e-9 - - r''' - Solve the entropic regularization optimal transport problem under low-nonnegative rank constraints. - This function returns the two low-rank matrix decomposition of the OT plan (Q,R), as well as the weight vector g. - - Parameters - ---------- - X_s : array-like, shape (n_samples_a, dim) - samples in the source domain - X_t : array-like, shape (n_samples_b, dim) - samples in the target domain - a : array-like, shape (n_samples_a,) - samples weights in the source domain - b : array-like, shape (n_samples_b,) - samples weights in the target domain - reg : float, optional - Regularization term >0 - rank: int, optional - Nonnegative rank of the OT plan - alpha: int, optional - Lower bound for the weight vector g (>0 and <1/r) - numItermax : int, optional - Max number of iterations - stopThr : float, optional - Stop threshold on error (>0) - warn: - - verbose: - - - Returns - ------- - Q : array-like, shape (n_samples_a, r) - First low-rank matrix decomposition of the OT plan - R: array-like, shape (n_samples_b, r) - Second low-rank matrix decomposition of the OT plan - g : array-like, shape (r, ) - Weight vector for the low-rank decomposition of the OT plan - - - References - ---------- - .. Scetbon, M., Cuturi, M., & Peyré, G (2021). - Low-Rank Sinkhorn Factorization. arXiv preprint arXiv:2103.04737. 
- - ''' - # POT backend - nx = get_backend(X_s, X_t) - ns, nt = X_s.shape[0], X_t.shape[0] - if a is None: - a = unif(ns, type_as=X_s) - if b is None: - b = unif(nt, type_as=X_t) - - d = X_s.shape[1] - - # First low rank decomposition of the cost matrix (A) - M1 = nx.zeros((ns,(d+2))) - M1[:,0] = [nx.norm(X_s[i,:])**2 for i in range(ns)] - M1[:,1] = nx.ones(ns) - M1[:,2:] = -2*X_s - - # Second low rank decomposition of the cost matrix (B) - M2 = nx.zeros((nt,(d+2))) - M2[:,0] = nx.ones(nt) - M2[:,1] = [nx.norm(X_t[i,:])**2 for i in range(nt)] - M2[:,2:] = X_t - - # Compute rank - rank = min(ns, nt, rank) - r = rank - - # Alpha: lower bound for 1/rank - if alpha == 'auto': - alpha = 1e-3 # no convergence with alpha = 1 / (r+1) - - if (1/r < alpha) or (alpha < 0): - warnings.warn("The provided alpha value might lead to instabilities.") - - # Compute gamma - L = nx.sqrt(3*(2/(alpha**4))*((nx.norm(M1)*nx.norm(M2))**2) + (reg + (2/(alpha**3))*(nx.norm(M1)*nx.norm(M2)))**2) - gamma = 1/(2*L) - - # Initialize the low rank matrices Q, R, g - Q, R, g = nx.ones((ns,r)), nx.ones((nt,r)), nx.ones(r) - - # Initialize parameters for Dykstra algorithm - q3_1, q3_2 = nx.ones(r), nx.ones(r) - u1, u2 = nx.ones(ns), nx.ones(nt) - v1, v2 = nx.ones(r), nx.ones(r) - v1_, v2_ = nx.ones(r), nx.ones(r) - q1, q2 = nx.ones(r), nx.ones(r) - dykstra_p = [g, q3_1, q3_2, v1_, v2_, q1, q2, u1, u2, v1, v2] - - - for ii in range(numItermax): - CR_ = nx.dot(M2.T, R) - CR = nx.dot(M1, CR_) - - CQ_ = nx.dot(M1.T, Q) - CQ = nx.dot(M2, CQ_) - - diag_g = (1/g)[:,None] - - eps1 = nx.exp(-gamma*(nx.dot(CR,diag_g)) - ((gamma*reg)-1)*nx.log(Q)) - eps2 = nx.exp(-gamma*(nx.dot(CQ,diag_g)) - ((gamma*reg)-1)*nx.log(R)) - omega = nx.diag(nx.dot(Q.T, CR)) - eps3 = nx.exp(gamma*omega/(g**2) - (gamma*reg - 1)*nx.log(g)) - - Q, R, dykstra_p = LR_Dysktra(eps1, eps2, eps3, a, b, alpha, dykstra_p, stopThr, nx) - g = dykstra_p[0] - - # if verbose: - # if ii % 200 == 0: - # print( - # '{:5s}|{:12s}'.format('It.', 'Err') + 
'\n' + '-' * 19) - # print('{:5d}|{:8e}|'.format(ii, err)) - - # else: - # if warn: - # warnings.warn("Sinkhorn did not converge. You might want to " - # "increase the number of iterations `numItermax` " - # "or the regularization parameter `reg`.") - - - # Compute OT value using trace formula for scalar product - v1 = nx.dot(Q.T,M1) - v2 = nx.dot(R,nx.dot(diag_g.T,v1)) - value_linear = nx.sum(nx.diag(nx.dot(v2,M2.T))) # compute Trace - - #value = value_linear + reg * nx.sum(plan * nx.log(plan + 1e-16)) - - #value - - return value_linear, Q, R, g - - - - - -############################################################################ -## Test with X_s, X_t from ot.datasets -############################################################################# - -import numpy as np -import ot - -Xs, _ = ot.datasets.make_data_classif('3gauss', n=1000) -Xt, _ = ot.datasets.make_data_classif('3gauss2', n=1500) - -ns = Xs.shape[0] -nt = Xt.shape[0] - -a = unif(ns) -b = unif(nt) - -Q, R, g = lowrank_sinkhorn(Xs, Xt, reg=0.1, verbose=True, numItermax=20) -M = ot.dist(Xs,Xt) -P = np.dot(Q,np.dot(np.diag(1/g),R.T)) - -print(np.sum(P)) - - - diff --git a/test/test_lowrank.py b/test/test_lowrank.py deleted file mode 100644 index 7d90ce9ef..000000000 --- a/test/test_lowrank.py +++ /dev/null @@ -1,84 +0,0 @@ -##################################################################################################### -####################################### WORK IN PROGRESS ############################################ -##################################################################################################### - - -""" Test for low rank sinkhorn solvers """ - -import ot -import numpy as np -import pytest -from itertools import product - - -def test_LR_Dykstra(): - # test for LR_Dykstra algorithm ? catch nan values ? 
- pass - - -# @pytest.mark.parametrize("verbose, warn", product([True, False], [True, False])) -# def test_lowrank_sinkhorn(verbose, warn): -# # test low rank sinkhorn -# n = 100 -# a = ot.unif(n) -# b = ot.unif(n) - -# X_s = np.reshape(1.0 * np.arange(n), (n, 1)) -# X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) - -# Q_sqe, R_sqe, g_sqe = ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, 0.1) -# P_sqe = np.dot(Q_sqe,np.dot(np.diag(1/g_sqe),R_sqe.T)) - -# Q_m, R_m, g_m = ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, 0.1, metric='euclidean') -# P_m = np.dot(Q_m,np.dot(np.diag(1/g_m),R_m.T)) - -# # check constraints -# np.testing.assert_allclose( -# a, P_sqe.sum(1), atol=1e-05) # metric sqeuclidian -# np.testing.assert_allclose( -# b, P_sqe.sum(0), atol=1e-05) # metric sqeuclidian -# np.testing.assert_allclose( -# a, P_m.sum(1), atol=1e-05) # metric euclidian -# np.testing.assert_allclose( -# b, P_m.sum(0), atol=1e-05) # metric euclidian - -# with pytest.warns(UserWarning): -# ot.lowrank.lowrank_sinkhorn(X_s, X_t, 0.1, a, b, stopThr=0, numItermax=1) - - - -@pytest.mark.parametrize(("alpha, rank"),((0.8,2),(0.5,3),(0.2,4))) -def test_lowrank_sinkhorn_alpha_warning(alpha,rank): - # test warning for value of alpha - n = 100 - a = ot.unif(n) - b = ot.unif(n) - - X_s = np.reshape(1.0 * np.arange(n), (n, 1)) - X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) - - with pytest.warns(UserWarning): - ot.lowrank.lowrank_sinkhorn(X_s, X_t, 0.1, a, b, r=rank, alpha=alpha, warn=False) - - - -def test_lowrank_sinkhorn_backends(nx): - # test low rank sinkhorn for different backends - n = 100 - a = ot.unif(n) - b = ot.unif(n) - - X_s = np.reshape(1.0 * np.arange(n), (n, 1)) - X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) - - ab, bb, X_sb, X_tb = nx.from_numpy(a, b, X_s, X_t) - - Q, R, g = nx.to_numpy(ot.lowrank.lowrank_sinkhorn(X_sb, X_tb, ab, bb, 0.1)) - P = np.dot(Q,np.dot(np.diag(1/g),R.T)) - - np.testing.assert_allclose(a, P.sum(1), atol=1e-05) - np.testing.assert_allclose(b, P.sum(0), 
atol=1e-05) - - - - From 6ea251c89ecf52603eb81c798a0769e9a2cb9f54 Mon Sep 17 00:00:00 2001 From: laudavid Date: Tue, 24 Oct 2023 14:54:42 +0200 Subject: [PATCH 08/11] new file for lr sinkhorn --- ot/lowrank.py | 171 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 171 insertions(+) create mode 100644 ot/lowrank.py diff --git a/ot/lowrank.py b/ot/lowrank.py new file mode 100644 index 000000000..ba46cd1ed --- /dev/null +++ b/ot/lowrank.py @@ -0,0 +1,171 @@ +################################################################################################################# +############################################## WORK IN PROGRESS ################################################# +################################################################################################################# + + +from ot.utils import unif, list_to_array +from ot.backend import get_backend +from ot.datasets import make_1D_gauss as gauss + + + +################################## LR-DYSKTRA ALGORITHM ########################################## + +def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, dykstra_w): + """ + Implementation of the Dykstra algorithm for low rank Sinkhorn + """ + + # get dykstra parameters + q3_1, q3_2, v1_, v2_, q1, q2 = dykstra_w + + # POT backend + eps1, eps2, eps3, p1, p2 = list_to_array(eps1, eps2, eps3, p1, p2) + q3_1, q3_2, v1_, v2_, q1, q2 = list_to_array(q3_1, q3_2, v1_, v2_, q1, q2) + + nx = get_backend(eps1, eps2, eps3, p1, p2, q3_1, q3_2, v1_, v2_, q1, q2) + + # ------- Dykstra algorithm ------ + g_ = eps3 + + u1 = p1 / nx.dot(eps1, v1_) + u2 = p2 / nx.dot(eps2, v2_) + + g = nx.maximum(alpha, g_ * q3_1) + q3_1 = (g_ * q3_1) / g + g_ = g + + prod1 = ((v1_ * q1) * nx.dot(eps1.T, u1)) + prod2 = ((v2_ * q2) * nx.dot(eps2.T, u2)) + g = (g_ * q3_2 * prod1 * prod2)**(1/3) + + v1 = g / nx.dot(eps1.T,u1) + v2 = g / nx.dot(eps2.T,u2) + + q1 = (v1_ * q1) / v1 + q2 = (v2_ * q2) / v2 + q3_2 = (g_ * q3_2) / g + + v1_, v2_ = v1, v2 + g_ = g + + # 
Compute error + err1 = nx.sum(nx.abs(u1 * (eps1 @ v1) - p1)) + err2 = nx.sum(nx.abs(u2 * (eps2 @ v2) - p2)) + err = err1 + err2 + + # Compute low rank matrices Q, R + Q = u1[:,None] * eps1 * v1[None,:] + R = u2[:,None] * eps2 * v2[None,:] + + dykstra_w = [q3_1, q3_2, v1_, v2_, q1, q2] + + return Q, R, g, err, dykstra_w + + + +#################################### LOW RANK SINKHORN ALGORITHM ######################################### + + +def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=4, metric='sqeuclidean', alpha=1e-10, numIterMax=10000, stopThr=1e-20): + r''' + Solve the entropic regularization optimal transport problem under low-nonnegative low rank constraints + + Parameters + ---------- + X_s : array-like, shape (n_samples_a, dim) + samples in the source domain + X_t : array-like, shape (n_samples_b, dim) + samples in the target domain + reg : float + Regularization term >0 + a : array-like, shape (n_samples_a,) + samples weights in the source domain + b : array-like, shape (n_samples_b,) + samples weights in the target domain + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (>0) + + Returns + ------- + Q : array-like, shape (n_samples_a, r) + First low-rank matrix decomposition of the OT plan + R: array-like, shape (n_samples_b, r) + Second low-rank matrix decomposition of the OT plan + g : array-like, shape (r, ) + ... 
+ + ''' + + X_s, X_t = list_to_array(X_s, X_t) + nx = get_backend(X_s, X_t) + + ns, nt = X_s.shape[0], X_t.shape[0] + if a is None: + a = nx.from_numpy(unif(ns), type_as=X_s) + if b is None: + b = nx.from_numpy(unif(nt), type_as=X_s) + + M = ot.dist(X_s,X_t, metric=metric) + + # Compute rank + r = min(ns, nt, r) + + # Compute gamma + L = nx.sqrt((2/(alpha**4))*nx.norm(M)**2 + (reg + (2/(alpha**3))*nx.norm(M))**2) + gamma = 1/(2*L) + + # Initialisation + Q, R, g = nx.ones((ns,r)), nx.ones((nt,r)), nx.ones(r) + q3_1, q3_2 = nx.ones(r), nx.ones(r) + v1_, v2_ = nx.ones(r), nx.ones(r) + q1, q2 = nx.ones(r), nx.ones(r) + dykstra_w = [q3_1, q3_2, v1_, v2_, q1, q2] + n_iter = 0 + err = 1 + + while n_iter < numIterMax: + if err > stopThr: + n_iter = n_iter + 1 + + CR = nx.dot(M,R) + C_t_Q = nx.dot(M.T,Q) + diag_g = (1/g)[:,None] + + eps1 = nx.exp(-gamma*(nx.dot(CR,diag_g)) - ((gamma*reg)-1)*nx.log(Q)) + eps2 = nx.exp(-gamma*(nx.dot(C_t_Q,diag_g)) - ((gamma*reg)-1)*nx.log(R)) + omega = nx.diag(nx.dot(Q.T, CR)) + eps3 = nx.exp(gamma*omega/(g**2) - (gamma*reg - 1)*nx.log(g)) + + Q, R, g, err, dykstra_w = LR_Dysktra(eps1, eps2, eps3, a, b, alpha, dykstra_w) + else: + break + + return Q, R, g + + + + + +############################################################################ +## Test with X_s, X_t from ot.datasets +############################################################################# + +import numpy as np +import ot + +Xs, _ = ot.datasets.make_data_classif('3gauss', n=1000) +Xt, _ = ot.datasets.make_data_classif('3gauss2', n=1500) + + +Q, R, g = lowrank_sinkhorn(Xs,Xt,reg=0.1) +M = ot.dist(Xs,Xt) +P = np.dot(Q,np.dot(np.diag(1/g),R.T)) + +print(np.sum(P)) + + + + From 965e4d69113f6fe8eab106412b652dabdbc05712 Mon Sep 17 00:00:00 2001 From: laudavid Date: Tue, 24 Oct 2023 16:47:21 +0200 Subject: [PATCH 09/11] lr sinkhorn, solve_sample, OTResultLazy --- ot/lowrank.py | 40 +++++++------ ot/solvers.py | 161 ++++++++++++++++++++++++++++++++++++++++++++++++++ ot/utils.py | 
3 +- 3 files changed, 183 insertions(+), 21 deletions(-) diff --git a/ot/lowrank.py b/ot/lowrank.py index ba46cd1ed..a1c73bdf3 100644 --- a/ot/lowrank.py +++ b/ot/lowrank.py @@ -2,8 +2,10 @@ ############################################## WORK IN PROGRESS ################################################# ################################################################################################################# +## Implementation of the LR-Dykstra algorithm and low rank sinkhorn algorithms -from ot.utils import unif, list_to_array + +from ot.utils import unif, list_to_array, dist from ot.backend import get_backend from ot.datasets import make_1D_gauss as gauss @@ -11,13 +13,13 @@ ################################## LR-DYSKTRA ALGORITHM ########################################## -def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, dykstra_w): +def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, dykstra_p): """ Implementation of the Dykstra algorithm for low rank Sinkhorn """ # get dykstra parameters - q3_1, q3_2, v1_, v2_, q1, q2 = dykstra_w + q3_1, q3_2, v1_, v2_, q1, q2 = dykstra_p # POT backend eps1, eps2, eps3, p1, p2 = list_to_array(eps1, eps2, eps3, p1, p2) @@ -58,18 +60,18 @@ def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, dykstra_w): Q = u1[:,None] * eps1 * v1[None,:] R = u2[:,None] * eps2 * v2[None,:] - dykstra_w = [q3_1, q3_2, v1_, v2_, q1, q2] + dykstra_p = [q3_1, q3_2, v1_, v2_, q1, q2] - return Q, R, g, err, dykstra_w + return Q, R, g, err, dykstra_p #################################### LOW RANK SINKHORN ALGORITHM ######################################### -def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=4, metric='sqeuclidean', alpha=1e-10, numIterMax=10000, stopThr=1e-20): +def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=2, metric='sqeuclidean', alpha=1e-10, numIterMax=10000, stopThr=1e-20): r''' - Solve the entropic regularization optimal transport problem under low-nonnegative low rank constraints + Solve the entropic regularization 
optimal transport problem under low-nonnegative rank constraints on the feasible couplings. Parameters ---------- @@ -95,7 +97,7 @@ def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=4, metric='sqeuclidean', R: array-like, shape (n_samples_b, r) Second low-rank matrix decomposition of the OT plan g : array-like, shape (r, ) - ... + Third low-rank matrix decomposition of the OT plan ''' @@ -108,7 +110,7 @@ def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=4, metric='sqeuclidean', if b is None: b = nx.from_numpy(unif(nt), type_as=X_s) - M = ot.dist(X_s,X_t, metric=metric) + M = dist(X_s,X_t, metric=metric) # Compute rank r = min(ns, nt, r) @@ -122,7 +124,7 @@ def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=4, metric='sqeuclidean', q3_1, q3_2 = nx.ones(r), nx.ones(r) v1_, v2_ = nx.ones(r), nx.ones(r) q1, q2 = nx.ones(r), nx.ones(r) - dykstra_w = [q3_1, q3_2, v1_, v2_, q1, q2] + dykstra_p = [q3_1, q3_2, v1_, v2_, q1, q2] n_iter = 0 err = 1 @@ -139,7 +141,7 @@ def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=4, metric='sqeuclidean', omega = nx.diag(nx.dot(Q.T, CR)) eps3 = nx.exp(gamma*omega/(g**2) - (gamma*reg - 1)*nx.log(g)) - Q, R, g, err, dykstra_w = LR_Dysktra(eps1, eps2, eps3, a, b, alpha, dykstra_w) + Q, R, g, err, dykstra_p = LR_Dysktra(eps1, eps2, eps3, a, b, alpha, dykstra_p) else: break @@ -153,18 +155,18 @@ def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=4, metric='sqeuclidean', ## Test with X_s, X_t from ot.datasets ############################################################################# -import numpy as np -import ot +# import numpy as np +# import ot -Xs, _ = ot.datasets.make_data_classif('3gauss', n=1000) -Xt, _ = ot.datasets.make_data_classif('3gauss2', n=1500) +# Xs, _ = ot.datasets.make_data_classif('3gauss', n=1000) +# Xt, _ = ot.datasets.make_data_classif('3gauss2', n=1500) -Q, R, g = lowrank_sinkhorn(Xs,Xt,reg=0.1) -M = ot.dist(Xs,Xt) -P = np.dot(Q,np.dot(np.diag(1/g),R.T)) +# Q, R, g = 
lowrank_sinkhorn(Xs,Xt,reg=0.1) +# M = ot.dist(Xs,Xt) +# P = np.dot(Q,np.dot(np.diag(1/g),R.T)) -print(np.sum(P)) +# print(np.sum(P)) diff --git a/ot/solvers.py b/ot/solvers.py index 0313cf588..9c2746c25 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -848,3 +848,164 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, value_linear=value_linear, value_quad=value_quad, plan=plan, status=status, backend=nx) return res + + + + + + +################################## WORK IN PROGRESS ##################################### + +## Implementation of the ot.solve_sample function +## Function isn't complete, still work in progress for reg == 0 / reg is None case (and unbalanced cases) + + +from .utils import unif, list_to_array, dist, OTResultLazy +from .bregman import empirical_sinkhorn + + +def solve_sample(X_s, X_t, a=None, b=None, metric='sqeuclidean', reg=None, reg_type="KL", unbalanced=None, + unbalanced_type='KL', is_Lazy=False, batch_size=None, n_threads=1, max_iter=None, plan_init=None, + potentials_init=None, tol=None, verbose=False): + + r"""Solve the discrete optimal transport problem using the samples in the source and target domains. + It returns either a :any:`OTResult` or :any:`OTResultLazy` object. + + The function solves the following general optimal transport problem + + .. math:: + \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_r R(\mathbf{T}) + + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b}) + + The regularization is selected with `reg` (:math:`\lambda_r`) and `reg_type`. By + default ``reg=None`` and there is no regularization. The unbalanced marginal + penalization can be selected with `unbalanced` (:math:`\lambda_u`) and + `unbalanced_type`. By default ``unbalanced=None`` and the function + solves the exact optimal transport problem (respecting the marginals). 
+ + Parameters + ---------- + X_s : array-like, shape (n_samples_a, dim) + samples in the source domain + X_t : array-like, shape (n_samples_b, dim) + samples in the target domain + a : array-like, shape (dim_a,), optional + Samples weights in the source domain (default is uniform) + b : array-like, shape (dim_b,), optional + Samples weights in the source domain (default is uniform) + reg : float, optional + Regularization weight :math:`\lambda_r`, by default None (no reg., exact + OT) + reg_type : str, optional + Type of regularization :math:`R` either "KL", "L2", "entropy", by default "KL" + unbalanced : float, optional + Unbalanced penalization weight :math:`\lambda_u`, by default None + (balanced OT) + unbalanced_type : str, optional + Type of unbalanced penalization function :math:`U` either "KL", "L2", "TV", by default "KL" + is_Lazy : bool, optional + Return :any:`OTResultlazy` object to reduce memory cost when True, by default False + n_threads : int, optional + Number of OMP threads for exact OT solver, by default 1 + max_iter : int, optional + Maximum number of iteration, by default None (default values in each solvers) + plan_init : array_like, shape (dim_a, dim_b), optional + Initialization of the OT plan for iterative methods, by default None + potentials_init : (array_like(dim_a,),array_like(dim_b,)), optional + Initialization of the OT dual potentials for iterative methods, by default None + tol : _type_, optional + Tolerance for solution precision, by default None (default values in each solvers) + verbose : bool, optional + Print information in the solver, by default False + + Returns + ------- + + res_lazy : OTResultLazy() + Result of the optimization problem. This class only returns a partial OT plan and the OT dual potentials to reduce memory costs. 
+ The information can be obtained as follows: + + - res.lazy_plan : OT plan computed on a subsample of X_s and X_t :math:`\mathbf{T}` + - res.potentials : OT dual potentials + + See :any:`OTResultLazy` for more information. + + res : OTResult() + Result of the optimization problem. The information can be obtained as follows: + + - res.plan : OT plan :math:`\mathbf{T}` + - res.potentials : OT dual potentials + - res.value : Optimal value of the optimization problem + - res.value_linear : Linear OT loss with the optimal OT plan + + See :any:`OTResult` for more information. + + + """ + + X_s, X_t = list_to_array(X_s,X_t) + + # detect backend + arr = [X_s,X_t] + if a is not None: + arr.append(a) + if b is not None: + arr.append(b) + nx = get_backend(*arr) + + # create uniform weights if not given + ns, nt = X_s.shape[0], X_t.shape[0] + if a is None: + a = nx.from_numpy(unif(ns), type_as=X_s) + if b is None: + b = nx.from_numpy(unif(nt), type_as=X_s) + + # default values for solutions + potentials = None + lazy_plan = None + + if max_iter is None: + max_iter = 1000 + if tol is None: + tol = 1e-9 + if batch_size is None: + batch_size = 100 + + if is_Lazy: + ################# WIP #################### + if reg is None or reg == 0: # EMD solver for isLazy ? + if unbalanced is None: # not sure "unbalanced" parameter is needed here ? 
(since we won't compute value) + pass + elif unbalanced_type.lower() in ['kl', 'l2']: + pass + elif unbalanced_type.lower() == 'tv': + pass + pass + ############################################# + + else: + # compute potentials + u, v, log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric='sqeuclidean', numIterMax=max_iter, stopThr=tol, + isLazy=True, batchSize=batch_size, verbose=verbose, log=True) + potentials = (log["u"], log["v"]) + + # compute lazy_plan + ns_lazy, nt_lazy = 100, 100 # size of the lazy_plan (subplan) + M = dist(X_s[:ns_lazy,:], X_t[:nt_lazy,:], metric) + K = nx.exp(M / (-reg)) + lazy_plan = u[:ns_lazy].reshape((-1, 1)) * K * v[:nt_lazy].reshape((1, -1)) + + res_lazy = OTResultLazy(potentials=potentials, lazy_plan=lazy_plan, backend=nx) + return res_lazy + + else: + # compute cost matrix M and use solve function + M = dist(X_s, X_t, metric) + + res = solve(M, a, b, reg, reg_type, unbalanced, unbalanced_type, n_threads, max_iter, plan_init, potentials_init, tol, verbose) + return res + + + + diff --git a/ot/utils.py b/ot/utils.py index 0936648ca..2f4cfc9e7 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -1168,7 +1168,6 @@ def citation(self): } """ - class LazyTensor(object): """ A lazy tensor is a tensor that is not stored in memory. Instead, it is defined by a function that computes its values on the fly from slices. 
@@ -1233,4 +1232,4 @@ def __getitem__(self, key): return self._getitem(*k, **self.kwargs) def __repr__(self): - return "LazyTensor(shape={},attributes=({}))".format(self.shape, ','.join(self.kwargs.keys())) + return "LazyTensor(shape={},attributes=({}))".format(self.shape, ','.join(self.kwargs.keys())) \ No newline at end of file From fd5e26d86e484f310f55a792bca89de13bd7340f Mon Sep 17 00:00:00 2001 From: laudavid Date: Wed, 25 Oct 2023 17:39:08 +0200 Subject: [PATCH 10/11] add test functions + small modif lr_sin/solve_sample --- ot/lowrank.py | 97 ++++++++++++++++++++++++++++------------- ot/solvers.py | 47 +++++++++++--------- test/test_lowrank.py | 84 ++++++++++++++++++++++++++++++++++++ test/test_solvers.py | 100 +++++++++++++++++++++++++++++++++++++++++++ test/test_utils.py | 1 + 5 files changed, 278 insertions(+), 51 deletions(-) create mode 100644 test/test_lowrank.py diff --git a/ot/lowrank.py b/ot/lowrank.py index a1c73bdf3..22ff8b754 100644 --- a/ot/lowrank.py +++ b/ot/lowrank.py @@ -4,10 +4,9 @@ ## Implementation of the LR-Dykstra algorithm and low rank sinkhorn algorithms - -from ot.utils import unif, list_to_array, dist -from ot.backend import get_backend -from ot.datasets import make_1D_gauss as gauss +import warnings +from .utils import unif, list_to_array, dist +from .backend import get_backend @@ -15,7 +14,7 @@ def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, dykstra_p): """ - Implementation of the Dykstra algorithm for low rank Sinkhorn + Implementation of the Dykstra algorithm for low rank sinkhorn """ # get dykstra parameters @@ -69,9 +68,12 @@ def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, dykstra_p): #################################### LOW RANK SINKHORN ALGORITHM ######################################### -def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=2, metric='sqeuclidean', alpha=1e-10, numIterMax=10000, stopThr=1e-20): +def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=2, metric='sqeuclidean', alpha="auto", + 
numItermax=10000, stopThr=1e-9, warn=True, verbose=False): r''' - Solve the entropic regularization optimal transport problem under low-nonnegative rank constraints on the feasible couplings. + Solve the entropic regularization optimal transport problem under low-nonnegative rank constraints. + + This function returns the two low-rank factors (Q, R) of the OT plan, as well as the weight vector g. Parameters ---------- @@ -79,17 +81,22 @@ def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=2, metric='sqeuclidean', samples in the source domain X_t : array-like, shape (n_samples_b, dim) samples in the target domain - reg : float - Regularization term >0 a : array-like, shape (n_samples_a,) samples weights in the source domain b : array-like, shape (n_samples_b,) samples weights in the target domain + reg : float, optional + Regularization term >0 + rank : int, optional + Nonnegative rank of the OT plan + alpha : float or 'auto', optional + Lower bound for the weight vector g (>0 and <1/rank); 'auto' selects 1/(rank+1) numItermax : int, optional Max number of iterations stopThr : float, optional Stop threshold on error (>0) + Returns ------- Q : array-like, shape (n_samples_a, r) @@ -97,7 +104,14 @@ def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=2, metric='sqeuclidean', R: array-like, shape (n_samples_b, r) Second low-rank matrix decomposition of the OT plan g : array-like, shape (r, ) - Third low-rank matrix decomposition of the OT plan + Weight vector for the low-rank decomposition of the OT plan + + + References + ---------- + + .. Scetbon, M., Cuturi, M., & Peyré, G. (2021). + Low-Rank Sinkhorn Factorization. arXiv preprint arXiv:2103.04737. 
''' @@ -110,13 +124,22 @@ def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=2, metric='sqeuclidean', if b is None: b = nx.from_numpy(unif(nt), type_as=X_s) + # Compute cost matrix M = dist(X_s,X_t, metric=metric) - + # Compute rank - r = min(ns, nt, r) + rank = min(ns, nt, rank) + r = rank + + if alpha == 'auto': + alpha = 1.0 / (r + 1) + + if (1/r < alpha) or (alpha < 0): + warnings.warn("The provided alpha value might lead to instabilities.") + # Compute gamma - L = nx.sqrt((2/(alpha**4))*nx.norm(M)**2 + (reg + (2/(alpha**3))*nx.norm(M))**2) + L = nx.sqrt((2/(alpha**4))*(nx.norm(M)**2) + (reg + (2/(alpha**3))*nx.norm(M))**2) gamma = 1/(2*L) # Initialisation @@ -125,25 +148,34 @@ def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=2, metric='sqeuclidean', v1_, v2_ = nx.ones(r), nx.ones(r) q1, q2 = nx.ones(r), nx.ones(r) dykstra_p = [q3_1, q3_2, v1_, v2_, q1, q2] - n_iter = 0 err = 1 - while n_iter < numIterMax: - if err > stopThr: - n_iter = n_iter + 1 - - CR = nx.dot(M,R) - C_t_Q = nx.dot(M.T,Q) - diag_g = (1/g)[:,None] - - eps1 = nx.exp(-gamma*(nx.dot(CR,diag_g)) - ((gamma*reg)-1)*nx.log(Q)) - eps2 = nx.exp(-gamma*(nx.dot(C_t_Q,diag_g)) - ((gamma*reg)-1)*nx.log(R)) - omega = nx.diag(nx.dot(Q.T, CR)) - eps3 = nx.exp(gamma*omega/(g**2) - (gamma*reg - 1)*nx.log(g)) - - Q, R, g, err, dykstra_p = LR_Dysktra(eps1, eps2, eps3, a, b, alpha, dykstra_p) - else: + for ii in range(numItermax): + CR = nx.dot(M,R) + C_t_Q = nx.dot(M.T,Q) + diag_g = nx.diag(1/g) + + eps1 = nx.exp(-gamma*(nx.dot(CR,diag_g)) - ((gamma*reg)-1)*nx.log(Q)) + eps2 = nx.exp(-gamma*(nx.dot(C_t_Q,diag_g)) - ((gamma*reg)-1)*nx.log(R)) + omega = nx.diag(nx.dot(Q.T, CR)) + eps3 = nx.exp(gamma*omega/(g**2) - (gamma*reg - 1)*nx.log(g)) + + Q, R, g, err, dykstra_p = LR_Dysktra(eps1, eps2, eps3, a, b, alpha, dykstra_p) + + if err < stopThr: break + + if verbose: + if ii % 200 == 0: + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(ii, err)) + + else: 
+ if warn: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") return Q, R, g @@ -161,8 +193,13 @@ def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=2, metric='sqeuclidean', # Xs, _ = ot.datasets.make_data_classif('3gauss', n=1000) # Xt, _ = ot.datasets.make_data_classif('3gauss2', n=1500) +# ns = Xs.shape[0] +# nt = Xt.shape[0] + +# a = unif(ns) +# b = unif(nt) -# Q, R, g = lowrank_sinkhorn(Xs,Xt,reg=0.1) +# Q, R, g = lowrank_sinkhorn(Xs, Xt, reg=0.1, metric='euclidean', verbose=True, numItermax=100) # M = ot.dist(Xs,Xt) # P = np.dot(Q,np.dot(np.diag(1/g),R.T)) diff --git a/ot/solvers.py b/ot/solvers.py index 9c2746c25..c176969ca 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -926,7 +926,7 @@ def solve_sample(X_s, X_t, a=None, b=None, metric='sqeuclidean', reg=None, reg_t Result of the optimization problem. This class only returns a partial OT plan and the OT dual potentials to reduce memory costs. The information can be obtained as follows: - - res.lazy_plan : OT plan computed on a subsample of X_s and X_t :math:`\mathbf{T}` + - res.lazy_plan : OT plan computed on a subsample of X_s and X_t - res.potentials : OT dual potentials See :any:`OTResultLazy` for more information. @@ -975,29 +975,34 @@ def solve_sample(X_s, X_t, a=None, b=None, metric='sqeuclidean', reg=None, reg_t if is_Lazy: ################# WIP #################### if reg is None or reg == 0: # EMD solver for isLazy ? - if unbalanced is None: # not sure "unbalanced" parameter is needed here ? (since we won't compute value) - pass - elif unbalanced_type.lower() in ['kl', 'l2']: - pass - elif unbalanced_type.lower() == 'tv': - pass - pass + + if unbalanced is None: # balanced EMD solver for isLazy ? 
+ raise (NotImplementedError('Not implemented balanced with no regularization')) + + else: + raise (NotImplementedError('Not implemented unbalanced_type="{}" with no regularization'.format(unbalanced_type))) + + ############################################# else: - # compute potentials - u, v, log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric='sqeuclidean', numIterMax=max_iter, stopThr=tol, - isLazy=True, batchSize=batch_size, verbose=verbose, log=True) - potentials = (log["u"], log["v"]) - - # compute lazy_plan - ns_lazy, nt_lazy = 100, 100 # size of the lazy_plan (subplan) - M = dist(X_s[:ns_lazy,:], X_t[:nt_lazy,:], metric) - K = nx.exp(M / (-reg)) - lazy_plan = u[:ns_lazy].reshape((-1, 1)) * K * v[:nt_lazy].reshape((1, -1)) - - res_lazy = OTResultLazy(potentials=potentials, lazy_plan=lazy_plan, backend=nx) - return res_lazy + if unbalanced is None: + u, v, log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric=metric, numIterMax=max_iter, stopThr=tol, + isLazy=True, batchSize=batch_size, verbose=verbose, log=True) + # compute potentials + potentials = (log["u"], log["v"]) + + # compute lazy_plan + ns_lazy, nt_lazy = 100, 100 # size of the lazy_plan (subplan) + M = dist(X_s[:ns_lazy,:], X_t[:nt_lazy,:], metric) + K = nx.exp(M / (-reg)) + lazy_plan = u[:ns_lazy].reshape((-1, 1)) * K * v[:nt_lazy].reshape((1, -1)) + + res_lazy = OTResultLazy(potentials=potentials, lazy_plan=lazy_plan, backend=nx) + return res_lazy + + else: + raise (NotImplementedError('Not implemented unbalanced_type="{}" with regularization'.format(unbalanced_type))) else: # compute cost matrix M and use solve function diff --git a/test/test_lowrank.py b/test/test_lowrank.py new file mode 100644 index 000000000..6e1f24067 --- /dev/null +++ b/test/test_lowrank.py @@ -0,0 +1,84 @@ +##################################################################################################### +####################################### WORK IN PROGRESS ############################################ 
+##################################################################################################### + + +""" Test for low rank sinkhorn solvers """ + +import ot +import numpy as np +import pytest +from itertools import product + + +def test_LR_Dykstra(): + # test for LR_Dykstra algorithm ? catch nan values ? + pass + + +@pytest.mark.parametrize("verbose, warn", product([True, False], [True, False])) +def test_lowrank_sinkhorn(verbose, warn): + # test low rank sinkhorn + n = 100 + a = ot.unif(n) + b = ot.unif(n) + + X_s = np.reshape(1.0 * np.arange(n), (n, 1)) + X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) + + Q_sqe, R_sqe, g_sqe = ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, 0.1) + P_sqe = np.dot(Q_sqe,np.dot(np.diag(1/g_sqe),R_sqe.T)) + + Q_m, R_m, g_m = ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, 0.1, metric='euclidean') + P_m = np.dot(Q_m,np.dot(np.diag(1/g_m),R_m.T)) + + # check constraints + np.testing.assert_allclose( + a, P_sqe.sum(1), atol=1e-05) # metric sqeuclidean + np.testing.assert_allclose( + b, P_sqe.sum(0), atol=1e-05) # metric sqeuclidean + np.testing.assert_allclose( + a, P_m.sum(1), atol=1e-05) # metric euclidean + np.testing.assert_allclose( + b, P_m.sum(0), atol=1e-05) # metric euclidean + + with pytest.warns(UserWarning): + ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, 0.1, stopThr=0, numItermax=1) + + + +@pytest.mark.parametrize(("alpha, rank"),((0.8,2),(0.5,3),(0.3,4))) +def test_lowrank_sinkhorn_alpha_warning(alpha,rank): + # test warning for value of alpha (each alpha is > 1/rank) + n = 100 + a = ot.unif(n) + b = ot.unif(n) + + X_s = np.reshape(1.0 * np.arange(n), (n, 1)) + X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) + + with pytest.warns(UserWarning): + ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, 0.1, rank=rank, alpha=alpha, warn=False) + + + +def test_lowrank_sinkhorn_backends(nx): + # test low rank sinkhorn for different backends + n = 100 + a = ot.unif(n) + b = ot.unif(n) + + X_s = np.reshape(1.0 * np.arange(n), (n, 1)) + X_t = np.reshape(1.0 * 
np.arange(0, n), (n, 1)) + + ab, bb, X_sb, X_tb = nx.from_numpy(a, b, X_s, X_t) + + Q, R, g = nx.to_numpy(*ot.lowrank.lowrank_sinkhorn(X_sb, X_tb, ab, bb, 0.1)) + P = np.dot(Q,np.dot(np.diag(1/g),R.T)) + + np.testing.assert_allclose(a, P.sum(1), atol=1e-05) + np.testing.assert_allclose(b, P.sum(0), atol=1e-05) + + + + diff --git a/test/test_solvers.py b/test/test_solvers.py index f0f5b638f..5a05d54cf 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -255,3 +255,102 @@ def test_solve_gromov_not_implemented(nx): ot.solve_gromov(Ca, Cb, reg=1, unbalanced_type='partial', unbalanced=1.5) with pytest.raises(NotImplementedError): ot.solve_gromov(Ca, Cb, reg=1, unbalanced_type='partial', unbalanced=0.5, symmetric=False) + + + + +########################################################################################################### +############################################ WORK IN PROGRESS ############################################# +########################################################################################################### + +def assert_allclose_sol_sample(sol1, sol2): + # check that the attributes of two OTResultLazy solutions match + lst_attr = ['potentials','potential_a', 'potential_b', 'lazy_plan'] + + nx1 = sol1._backend if sol1._backend is not None else ot.backend.NumpyBackend() + nx2 = sol2._backend if sol2._backend is not None else ot.backend.NumpyBackend() + + for attr in lst_attr: + try: + np.testing.assert_allclose(nx1.to_numpy(getattr(sol1, attr)), nx2.to_numpy(getattr(sol2, attr))) + except NotImplementedError: + pass + + +def test_solve_sample(nx): + # test solve_sample when is_Lazy = False + n = 100 + X_s = np.reshape(1.0 * np.arange(n), (n, 1)) + X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) + + a = ot.utils.unif(X_s.shape[0]) + b = ot.utils.unif(X_t.shape[0]) + + # solve unif weights + sol0 = 
ot.solve_sample(X_s, X_t) + + # solve signe weights + sol = ot.solve_sample(X_s, X_t, a, b) + + # check some attributes + sol.potentials + sol.sparse_plan + sol.marginals + sol.status + + assert_allclose_sol(sol0, sol) + + # solve in backend + X_sb, X_tb, ab, bb = nx.from_numpy(X_s, X_t, a, b) + solb = ot.solve_sample(X_sb, X_tb, ab, bb) + + assert_allclose_sol(sol, solb) + + # test not implemented unbalanced and check raise + with pytest.raises(NotImplementedError): + sol0 = ot.solve_sample(X_s, X_t, unbalanced=1, unbalanced_type='cryptic divergence') + + # test not implemented reg_type and check raise + with pytest.raises(NotImplementedError): + sol0 = ot.solve_sample(X_s, X_t, reg=1, reg_type='cryptic divergence') + + + +def test_lazy_solve_sample(nx): + # test solve_sample when is_Lazy = True + n = 100 + X_s = np.reshape(1.0 * np.arange(n), (n, 1)) + X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) + + a = ot.utils.unif(X_s.shape[0]) + b = ot.utils.unif(X_t.shape[0]) + + # solve unif weights + sol0 = ot.solve_sample(X_s, X_t, reg=0.1, is_Lazy=True) # reg != 0 or None since no implementation yet for is_Lazy=True + + # solve signe weights + sol = ot.solve_sample(X_s, X_t, a, b, reg=0.1, is_Lazy=True) + + # check some attributes + sol.potentials + sol.lazy_plan + + assert_allclose_sol_sample(sol0, sol) + + # solve in backend + X_sb, X_tb, ab, bb = nx.from_numpy(X_s, X_t, a, b) + solb = ot.solve_sample(X_sb, X_tb, ab, bb, reg=0.1, is_Lazy=True) + + assert_allclose_sol_sample(sol, solb) + + # test not implemented reg==0 (or None) + balanced and check raise + with pytest.raises(NotImplementedError): + sol0 = ot.solve_sample(X_s, X_t, is_Lazy=True) # reg == 0 (or None) + unbalanced= None are default + + # test not implemented reg==0 (or None) + unbalanced_type and check raise + with pytest.raises(NotImplementedError): + sol0 = ot.solve_sample(X_s, X_t, unbalanced_type="kl", is_Lazy=True) # reg == 0 (or None) is default + + # test not implemented reg != 0 + 
unbalanced_type and check raise + with pytest.raises(NotImplementedError): + sol0 = ot.solve_sample(X_s, X_t, reg=0.1, unbalanced_type="kl", is_Lazy=True) \ No newline at end of file diff --git a/test/test_utils.py b/test/test_utils.py index 3a9d590ab..942f403ce 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -569,3 +569,4 @@ def test_lowrank_LazyTensor(nx): T = ot.utils.get_lowrank_lazytensor(X1, X2, diag_d, nx=nx) np.testing.assert_allclose(nx.to_numpy(T[:]), nx.to_numpy(T0)) + \ No newline at end of file From 3df3b77de2605d233224e0ccefa6ee127af9f040 Mon Sep 17 00:00:00 2001 From: laudavid Date: Thu, 26 Oct 2023 10:49:23 +0200 Subject: [PATCH 11/11] add import to __init__ --- ot/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ot/__init__.py b/ot/__init__.py index f16b6fcfc..cb00f4553 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -35,6 +35,7 @@ from . import factored from . import solvers from . import gaussian +from . import lowrank # OT functions from .lp import (emd, emd2, emd_1d, emd2_1d, wasserstein_1d, @@ -50,7 +51,8 @@ gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2) from .weak import weak_optimal_transport from .factored import factored_optimal_transport -from .solvers import solve, solve_gromov +from .solvers import solve, solve_gromov, solve_sample +from .lowrank import lowrank_sinkhorn # utils functions from .utils import dist, unif, tic, toc, toq