diff --git a/README.md b/README.md
index a9a94c53e..3c9b212ce 100644
--- a/README.md
+++ b/README.md
@@ -194,7 +194,8 @@ The numerous contributors to this library are listed [here](CONTRIBUTORS.md).
POT has benefited from the financing or manpower from the following partners:
-

+


+
## Contributions and code of conduct
diff --git a/docs/source/_static/images/logo_hiparis.png b/docs/source/_static/images/logo_hiparis.png
new file mode 100644
index 000000000..1ce6dfb5a
Binary files /dev/null and b/docs/source/_static/images/logo_hiparis.png differ
diff --git a/examples/others/plot_lowrank_sinkhorn.py b/examples/others/plot_lowrank_sinkhorn.py
new file mode 100644
index 000000000..ece35b295
--- /dev/null
+++ b/examples/others/plot_lowrank_sinkhorn.py
@@ -0,0 +1,129 @@
+# -*- coding: utf-8 -*-
+"""
+========================================
+Low rank Sinkhorn
+========================================
+
+This example illustrates the computation of Low Rank Sinkhorn [65].
+
+[65] Scetbon, M., Cuturi, M., & Peyré, G. (2021).
+"Low-rank Sinkhorn factorization". In International Conference on Machine Learning.
+"""
+
+# Author: Laurène David
+#
+# License: MIT License
+#
+# sphinx_gallery_thumbnail_number = 2
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot.plot
+from ot.datasets import make_1D_gauss as gauss
+
+##############################################################################
+# Generate data
+# -------------
+
+#%% parameters
+
+n = 100
+m = 120
+
+# Gaussian distribution
+a = gauss(n, m=int(n / 3), s=25 / np.sqrt(2)) + 1.5 * gauss(n, m=int(5 * n / 6), s=15 / np.sqrt(2))
+a = a / np.sum(a)
+
+b = 2 * gauss(m, m=int(m / 5), s=30 / np.sqrt(2)) + gauss(m, m=int(m / 2), s=35 / np.sqrt(2))
+b = b / np.sum(b)
+
+# Source and target sample positions (regular 1D grids)
+X = np.arange(n).reshape(-1, 1)
+Y = np.arange(m).reshape(-1, 1)
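+# (lowrank_sinkhorn works directly on the samples and factorizes the squared
+# Euclidean cost internally, so no dense cost matrix is built here)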
+
+
+##############################################################################
+# Solve low-rank Sinkhorn
+# -----------------------
+
+#%%
+# Solve low rank sinkhorn
+Q, R, g, log = ot.lowrank_sinkhorn(X, Y, a, b, rank=10, init="random", gamma_init="rescale", rescale_cost=True, warn=False, log=True)
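+
+# the plan is returned in factored form; the lazy plan Q @ diag(1/g) @ R.T
+# can be materialized as a dense array by indexing with [:]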
+P = log["lazy_plan"][:]
+
+ot.plot.plot1D_mat(a, b, P, 'OT matrix Low rank')
+
+
+##############################################################################
+# Sinkhorn vs Low Rank Sinkhorn
+# -----------------------------
+# Compare Sinkhorn and low-rank Sinkhorn for different regularization values and ranks.
+
+#%% Sinkhorn
+
+# Compute cost matrix for sinkhorn OT
+M = ot.dist(X, Y)
+M = M / np.max(M)
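+# (normalizing the cost keeps the regularization values below on a comparable scale)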
+
+# Solve sinkhorn with different regularizations using ot.solve
+list_reg = [0.05, 0.005, 0.001]
+list_P_Sin = []
+
+for reg in list_reg:
+ P = ot.solve(M, a, b, reg=reg, max_iter=2000, tol=1e-8).plan
+ list_P_Sin.append(P)
+
+#%% Low rank sinkhorn
+
+# Solve low rank sinkhorn with different ranks using ot.solve_sample
+list_rank = [3, 10, 50]
+list_P_LR = []
+
+for rank in list_rank:
+ P = ot.solve_sample(X, Y, a, b, method='lowrank', rank=rank).plan
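+    # the plan is returned as a lazy tensor; [:] materializes it to a dense array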
+ P = P[:]
+ list_P_LR.append(P)
+
+
+#%%
+
+# Plot sinkhorn vs low rank sinkhorn
+pl.figure(1, figsize=(10, 4))
+
+pl.subplot(1, 3, 1)
+pl.imshow(list_P_Sin[0], interpolation='nearest')
+pl.axis('off')
+pl.title('Sinkhorn (reg=0.05)')
+
+pl.subplot(1, 3, 2)
+pl.imshow(list_P_Sin[1], interpolation='nearest')
+pl.axis('off')
+pl.title('Sinkhorn (reg=0.005)')
+
+pl.subplot(1, 3, 3)
+pl.imshow(list_P_Sin[2], interpolation='nearest')
+pl.axis('off')
+pl.title('Sinkhorn (reg=0.001)')
+pl.show()
+
+
+#%%
+
+pl.figure(2, figsize=(10, 4))
+
+pl.subplot(1, 3, 1)
+pl.imshow(list_P_LR[0], interpolation='nearest')
+pl.axis('off')
+pl.title('Low rank (rank=3)')
+
+pl.subplot(1, 3, 2)
+pl.imshow(list_P_LR[1], interpolation='nearest')
+pl.axis('off')
+pl.title('Low rank (rank=10)')
+
+pl.subplot(1, 3, 3)
+pl.imshow(list_P_LR[2], interpolation='nearest')
+pl.axis('off')
+pl.title('Low rank (rank=50)')
+
+pl.tight_layout()
diff --git a/ot/lowrank.py b/ot/lowrank.py
index 5c8f673cb..f6c1469bd 100644
--- a/ot/lowrank.py
+++ b/ot/lowrank.py
@@ -8,14 +8,142 @@
import warnings
-from .utils import unif, get_lowrank_lazytensor
+from .utils import unif, dist, get_lowrank_lazytensor
from .backend import get_backend
+from .bregman import sinkhorn
+# test if sklearn is installed for linux-minimal-deps
+try:
+ import sklearn.cluster
+ sklearn_import = True
+except ImportError:
+ sklearn_import = False
-def compute_lr_sqeuclidean_matrix(X_s, X_t, nx=None):
+
+def _init_lr_sinkhorn(X_s, X_t, a, b, rank, init, reg_init, random_state, nx=None):
+ """
+ Implementation of different initialization strategies for the low rank sinkhorn solver (Q ,R, g).
+ This function is specific to lowrank_sinkhorn.
+
+ Parameters
+ ----------
+ X_s : array-like, shape (n_samples_a, dim)
+ samples in the source domain
+ X_t : array-like, shape (n_samples_b, dim)
+ samples in the target domain
+ a : array-like, shape (n_samples_a,)
+ samples weights in the source domain
+ b : array-like, shape (n_samples_b,)
+ samples weights in the target domain
+ rank : int
+ Nonnegative rank of the OT plan.
+ init : str
+        Initialization strategy for Q, R and g: 'random', 'deterministic' or 'kmeans'
+ reg_init : float, optional.
+ Regularization term for a 'kmeans' init.
+ random_state : int, optional.
+        Random state for a 'random' or 'kmeans' init strategy
+    nx : default None
+        POT backend
+
+
+ Returns
+ ---------
+ Q : array-like, shape (n_samples_a, r)
+ Init for the first low-rank matrix decomposition of the OT plan (Q)
+    R : array-like, shape (n_samples_b, r)
+ Init for the second low-rank matrix decomposition of the OT plan (R)
+ g : array-like, shape (r, )
+ Init for the weight vector of the low-rank decomposition of the OT plan (g)
+
+
+ References
+ -----------
+ .. [65] Scetbon, M., Cuturi, M., & Peyré, G. (2021).
+ "Low-rank Sinkhorn factorization". In International Conference on Machine Learning.
+
+ """
+
+ if nx is None:
+ nx = get_backend(X_s, X_t, a, b)
+
+ ns = X_s.shape[0]
+ nt = X_t.shape[0]
+ r = rank
+
+ if init == "random":
+ nx.seed(seed=random_state)
+
+ # Init g
+ g = nx.abs(nx.randn(r, type_as=X_s)) + 1
+ g = g / nx.sum(g)
+
+ # Init Q
+ Q = nx.abs(nx.randn(ns, r, type_as=X_s)) + 1
+ Q = (Q.T * (a / nx.sum(Q, axis=1))).T
+
+ # Init R
+ R = nx.abs(nx.randn(nt, rank, type_as=X_s)) + 1
+ R = (R.T * (b / nx.sum(R, axis=1))).T
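+        # rows are rescaled so that Q 1_r = a and R 1_r = b (marginal constraints)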
+
+ if init == "deterministic":
+ # Init g
+        g = nx.ones(rank, type_as=X_s) / rank
+
+ lambda_1 = min(nx.min(a), nx.min(g), nx.min(b)) / 2
+ a1 = nx.arange(start=1, stop=ns + 1, type_as=X_s)
+ a1 = a1 / nx.sum(a1)
+ a2 = (a - lambda_1 * a1) / (1 - lambda_1)
+
+ b1 = nx.arange(start=1, stop=nt + 1, type_as=X_s)
+ b1 = b1 / nx.sum(b1)
+ b2 = (b - lambda_1 * b1) / (1 - lambda_1)
+
+ g1 = nx.arange(start=1, stop=rank + 1, type_as=X_s)
+ g1 = g1 / nx.sum(g1)
+ g2 = (g - lambda_1 * g1) / (1 - lambda_1)
+
+ # Init Q
+ Q1 = lambda_1 * nx.dot(a1[:, None], nx.reshape(g1, (1, -1)))
+ Q2 = (1 - lambda_1) * nx.dot(a2[:, None], nx.reshape(g2, (1, -1)))
+ Q = Q1 + Q2
+
+ # Init R
+ R1 = lambda_1 * nx.dot(b1[:, None], nx.reshape(g1, (1, -1)))
+ R2 = (1 - lambda_1) * nx.dot(b2[:, None], nx.reshape(g2, (1, -1)))
+ R = R1 + R2
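+        # by construction, Q 1_r = a, R 1_r = b and Q^T 1_n = R^T 1_m = g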
+
+ if init == "kmeans":
+ if sklearn_import:
+ # Init g
+ g = nx.ones(rank, type_as=X_s) / rank
+
+ # Init Q
+ kmeans_Xs = sklearn.cluster.KMeans(n_clusters=rank, random_state=random_state, n_init="auto")
+ kmeans_Xs.fit(X_s)
+ Z_Xs = nx.from_numpy(kmeans_Xs.cluster_centers_)
+ C_Xs = dist(X_s, Z_Xs) # shape (ns, rank)
+ C_Xs = C_Xs / nx.max(C_Xs)
+ Q = sinkhorn(a, g, C_Xs, reg=reg_init, numItermax=10000, stopThr=1e-3)
+
+ # Init R
+ kmeans_Xt = sklearn.cluster.KMeans(n_clusters=rank, random_state=random_state, n_init="auto")
+ kmeans_Xt.fit(X_t)
+ Z_Xt = nx.from_numpy(kmeans_Xt.cluster_centers_)
+ C_Xt = dist(X_t, Z_Xt) # shape (nt, rank)
+ C_Xt = C_Xt / nx.max(C_Xt)
+ R = sinkhorn(b, g, C_Xt, reg=reg_init, numItermax=10000, stopThr=1e-3)
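+            # Q and R are entropic OT plans between the data and their kmeans
+            # centroids, so their marginals are (a, g) and (b, g) respectively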
+
+ else:
+ raise ImportError("Scikit-learn should be installed to use the 'kmeans' init.")
+
+ return Q, R, g
+
+
+def compute_lr_sqeuclidean_matrix(X_s, X_t, rescale_cost, nx=None):
"""
Compute the low rank decomposition of a squared euclidean distance matrix.
- This function won't work for any other distance metric.
+ This function won't work for other distance metrics.
See "Section 3.5, proposition 1"
@@ -25,7 +153,10 @@ def compute_lr_sqeuclidean_matrix(X_s, X_t, nx=None):
samples in the source domain
X_t : array-like, shape (n_samples_b, dim)
samples in the target domain
- nx : POT backend, default none
+ rescale_cost : bool
+ Rescale the low rank factorization of the sqeuclidean cost matrix
+ nx : default None
+ POT backend
Returns
@@ -37,9 +168,9 @@ def compute_lr_sqeuclidean_matrix(X_s, X_t, nx=None):
References
- ----------
+ -----------
.. [65] Scetbon, M., Cuturi, M., & Peyré, G. (2021).
- "Low-rank Sinkhorn factorization". In International Conference on Machine Learning.
+ "Low-rank Sinkhorn factorization". In International Conference on Machine Learning.
"""
if nx is None:
@@ -50,14 +181,18 @@ def compute_lr_sqeuclidean_matrix(X_s, X_t, nx=None):
# First low rank decomposition of the cost matrix (A)
array1 = nx.reshape(nx.sum(X_s**2, 1), (-1, 1))
- array2 = nx.reshape(nx.ones(ns, type_as=X_s), (-1, 1))
+ array2 = nx.ones((ns, 1), type_as=X_s)
M1 = nx.concatenate((array1, array2, -2 * X_s), axis=1)
# Second low rank decomposition of the cost matrix (B)
- array1 = nx.reshape(nx.ones(nt, type_as=X_s), (-1, 1))
+ array1 = nx.ones((nt, 1), type_as=X_s)
array2 = nx.reshape(nx.sum(X_t**2, 1), (-1, 1))
M2 = nx.concatenate((array1, array2, X_t), axis=1)
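+
+    # rescaling both factors keeps M1 @ M2.T proportional to the squared
+    # Euclidean cost while keeping its entries in a numerically safe range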
+ if rescale_cost is True:
+ M1 = M1 / nx.sqrt(nx.max(M1))
+ M2 = M2 / nx.sqrt(nx.max(M2))
+
return M1, M2
@@ -103,7 +238,7 @@ def _LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, stopThr, numItermax, warn, nx=N
References
----------
.. [65] Scetbon, M., Cuturi, M., & Peyré, G. (2021).
- "Low-rank Sinkhorn factorization". In International Conference on Machine Learning.
+ "Low-rank Sinkhorn Factorization". In International Conference on Machine Learning.
"""
@@ -163,7 +298,7 @@ def _LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, stopThr, numItermax, warn, nx=N
else:
if warn:
warnings.warn(
- "Sinkhorn did not converge. You might want to "
+ "Dykstra did not converge. You might want to "
"increase the number of iterations `numItermax` "
)
@@ -174,10 +309,12 @@ def _LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, stopThr, numItermax, warn, nx=N
return Q, R, g
-def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=None, alpha=None,
- numItermax=1000, stopThr=1e-9, warn=True, log=False):
+def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=None, alpha=1e-10, rescale_cost=True,
+ init="random", reg_init=1e-1, seed_init=49, gamma_init="rescale",
+ numItermax=2000, stopThr=1e-7, warn=True, log=False):
r"""
- Solve the entropic regularization optimal transport problem under low-nonnegative rank constraints.
+ Solve the entropic regularization optimal transport problem under low-nonnegative rank constraints
+ on the couplings.
The function solves the following optimization problem:
@@ -207,14 +344,26 @@ def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=None, alpha=None,
samples weights in the target domain
reg : float, optional
Regularization term >0
- rank: int, optional. Default is None. (>0)
+ rank : int, optional. Default is None. (>0)
Nonnegative rank of the OT plan. If None, min(ns, nt) is considered.
- alpha: int, optional. Default is None. (>0 and <1/r)
- Lower bound for the weight vector g. If None, 1e-10 is considered
- numItermax : int, optional
- Max number of iterations
- stopThr : float, optional
- Stop threshold on error (>0)
+    alpha : float, optional. Default is 1e-10. (>0 and <1/r)
+ Lower bound for the weight vector g.
+    rescale_cost : bool, optional. Default is True
+ Rescale the low rank factorization of the sqeuclidean cost matrix
+ init : str, optional. Default is 'random'.
+ Initialization strategy for the low rank couplings. 'random', 'deterministic' or 'kmeans'
+    reg_init : float, optional. Default is 1e-1. (>0)
+        Regularization term for a 'kmeans' init.
+ seed_init : int, optional. Default is 49. (>0)
+ Random state for a 'random' or 'kmeans' init strategy.
+ gamma_init : str, optional. Default is "rescale".
+        Initialization strategy for gamma: 'rescale' or 'theory'.
+        Gamma is a constant that scales the convergence criterion of the Mirror
+        Descent optimization scheme used to compute the low-rank couplings (Q, R and g).
+ numItermax : int, optional. Default is 2000.
+ Max number of iterations for the Dykstra algorithm
+ stopThr : float, optional. Default is 1e-7.
+ Stop threshold on error (>0) in Dykstra
warn : bool, optional
        if True, raises a warning if the algorithm doesn't converge.
log : bool, optional
@@ -222,26 +371,21 @@ def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=None, alpha=None,
Returns
- -------
- lazy_plan : LazyTensor()
- OT plan in a LazyTensor object of shape (shape_plan)
- See :any:`LazyTensor` for more information.
- value : float
- Optimal value of the optimization problem
- value_linear : float
- Linear OT loss with the optimal OT
+ ---------
Q : array-like, shape (n_samples_a, r)
First low-rank matrix decomposition of the OT plan
R: array-like, shape (n_samples_b, r)
Second low-rank matrix decomposition of the OT plan
g : array-like, shape (r, )
- Weight vector for the low-rank decomposition of the OT plan
+        Weight vector for the low-rank decomposition of the OT plan
+    log : dict (lazy_plan, value and value_linear)
+        log dictionary returned only if log==True in parameters
References
----------
- .. [65] Scetbon, M., Cuturi, M., & Peyré, G (2021).
- "Low-Rank Sinkhorn Factorization" arXiv preprint arXiv:2103.04737.
+ .. [65] Scetbon, M., Cuturi, M., & Peyré, G. (2021).
+ "Low-rank Sinkhorn Factorization". In International Conference on Machine Learning.
"""
@@ -259,59 +403,70 @@ def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=None, alpha=None,
r = rank
if rank is None:
r = min(ns, nt)
+ else:
+ r = min(ns, nt, rank)
- if alpha is None:
- alpha = 1e-10
+ if r <= 0:
+ raise ValueError("The rank parameter cannot have a negative value")
- # Dykstra algorithm won't converge if 1/rank < alpha (alpha is the lower bound for 1/rank)
- # (see "Section 3.2: The Low-rank OT Problem (LOT)" in the paper)
+ # Dykstra won't converge if 1/rank < alpha (see Section 3.2)
if 1 / r < alpha:
raise ValueError("alpha ({a}) should be smaller than 1/rank ({r}) for the Dykstra algorithm to converge.".format(
            a=alpha, r=1 / r))
- if r <= 0:
- raise ValueError("The rank parameter cannot have a negative value")
+ # Low rank decomposition of the sqeuclidean cost matrix
+ M1, M2 = compute_lr_sqeuclidean_matrix(X_s, X_t, rescale_cost, nx)
- # Low rank decomposition of the sqeuclidean cost matrix (A, B)
- M1, M2 = compute_lr_sqeuclidean_matrix(X_s, X_t, nx=None)
+ # Initialize the low rank matrices Q, R, g
+ Q, R, g = _init_lr_sinkhorn(X_s, X_t, a, b, r, init, reg_init, seed_init, nx=nx)
- # Compute gamma (see "Section 3.4, proposition 4" in the paper)
- L = nx.sqrt(
- 3 * (2 / (alpha**4)) * ((nx.norm(M1) * nx.norm(M2)) ** 2) +
- (reg + (2 / (alpha**3)) * (nx.norm(M1) * nx.norm(M2))) ** 2
- )
- gamma = 1 / (2 * L)
+ # Gamma initialization
+ if gamma_init == "theory":
+ L = nx.sqrt(
+ 3 * (2 / (alpha**4)) * ((nx.norm(M1) * nx.norm(M2)) ** 2) +
+ (reg + (2 / (alpha**3)) * (nx.norm(M1) * nx.norm(M2))) ** 2
+ )
+ gamma = 1 / (2 * L)
- # Initialize the low rank matrices Q, R, g
- Q = nx.ones((ns, r), type_as=a)
- R = nx.ones((nt, r), type_as=a)
- g = nx.ones(r, type_as=a)
- k = 100
+ if gamma_init not in ["rescale", "theory"]:
+ raise (NotImplementedError('Not implemented gamma_init="{}"'.format(gamma_init)))
# -------------------------- Low rank algorithm ------------------------------
- # see "Section 3.3, Algorithm 3 LOT" in the paper
+ # see "Section 3.3, Algorithm 3 LOT"
- for ii in range(k):
- # Compute the C*R dot matrix using the lr decomposition of C
- CR_ = nx.dot(M2.T, R)
- CR = nx.dot(M1, CR_)
+ for ii in range(100):
+        # Compute C @ R using the low-rank factors of C
+ CR = nx.dot(M2.T, R)
+ CR_ = nx.dot(M1, CR)
+ diag_g = (1 / g)[None, :]
+ CR_g = CR_ * diag_g
- # Compute the C.t * Q dot matrix using the lr decomposition of C
- CQ_ = nx.dot(M1.T, Q)
- CQ = nx.dot(M2, CQ_)
+        # Compute C.T @ Q using the low-rank factors of C
+ CQ = nx.dot(M1.T, Q)
+ CQ_ = nx.dot(M2, CQ)
+ CQ_g = CQ_ * diag_g
- diag_g = (1 / g)[None, :]
+ # Compute omega
+ omega = nx.diag(nx.dot(Q.T, CR_))
+
+ # Rescale gamma at each iteration
+ if gamma_init == "rescale":
+ norm_1 = nx.max(nx.abs(CR_ * diag_g + reg * nx.log(Q))) ** 2
+ norm_2 = nx.max(nx.abs(CQ_ * diag_g + reg * nx.log(R))) ** 2
+ norm_3 = nx.max(nx.abs(-omega * diag_g)) ** 2
+ gamma = 10 / max(norm_1, norm_2, norm_3)
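+            # gamma acts as the mirror-descent step size; adapting it to the
+            # current gradient magnitudes keeps eps1, eps2 and eps3 well scaled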
- eps1 = nx.exp(-gamma * (CR * diag_g) - ((gamma * reg) - 1) * nx.log(Q))
- eps2 = nx.exp(-gamma * (CQ * diag_g) - ((gamma * reg) - 1) * nx.log(R))
- omega = nx.diag(nx.dot(Q.T, CR))
- eps3 = nx.exp(gamma * omega / (g**2) - (gamma * reg - 1) * nx.log(g))
+ eps1 = nx.exp(-gamma * CR_g - ((gamma * reg) - 1) * nx.log(Q))
+ eps2 = nx.exp(-gamma * CQ_g - ((gamma * reg) - 1) * nx.log(R))
+ eps3 = nx.exp((gamma * omega / (g**2)) - (gamma * reg - 1) * nx.log(g))
+ # LR Dykstra algorithm
Q, R, g = _LR_Dysktra(
eps1, eps2, eps3, a, b, alpha, stopThr, numItermax, warn, nx
)
Q = Q + 1e-16
R = R + 1e-16
+ g = g + 1e-16
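+    # small offsets avoid division by zero and log(0) in the computations below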
# ----------------- Compute lazy_plan, value and value_linear ------------------
# see "Section 3.2: The Low-rank OT Problem" in the paper
@@ -324,7 +479,7 @@ def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=None, alpha=None,
v2 = nx.dot(R, (v1.T * diag_g).T)
value_linear = nx.sum(nx.diag(nx.dot(M2.T, v2)))
- # Compute value with entropy reg (entropy of Q, R, g must be computed separatly, see "Section 3.2" in the paper)
+ # Compute value with entropy reg (see "Section 3.2" in the paper)
reg_Q = nx.sum(Q * nx.log(Q + 1e-16)) # entropy for Q
reg_g = nx.sum(g * nx.log(g + 1e-16)) # entropy for g
reg_R = nx.sum(R * nx.log(R + 1e-16)) # entropy for R
diff --git a/ot/solvers.py b/ot/solvers.py
index 40a03e974..c4c0c79ed 100644
--- a/ot/solvers.py
+++ b/ot/solvers.py
@@ -1173,6 +1173,10 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t
Unbalanced optimal transport through non-negative penalized
linear regression. NeurIPS.
+ .. [65] Scetbon, M., Cuturi, M., & Peyré, G. (2021).
+ Low-rank Sinkhorn Factorization. In International Conference on
+ Machine Learning.
+
"""
@@ -1255,13 +1259,13 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t
raise (NotImplementedError('Not implemented metric="{}"'.format(metric)))
if max_iter is None:
- max_iter = 1000
+ max_iter = 2000
if tol is None:
- tol = 1e-9
+ tol = 1e-7
if reg is None:
reg = 0
- Q, R, g, log = lowrank_sinkhorn(X_a, X_b, reg=reg, a=a, b=b, numItermax=max_iter, stopThr=tol, log=True)
+ Q, R, g, log = lowrank_sinkhorn(X_a, X_b, rank=rank, reg=reg, a=a, b=b, numItermax=max_iter, stopThr=tol, log=True)
value = log['value']
value_linear = log['value_linear']
lazy_plan = log['lazy_plan']
diff --git a/test/test_lowrank.py b/test/test_lowrank.py
index 65f76a77b..60b2d633f 100644
--- a/test/test_lowrank.py
+++ b/test/test_lowrank.py
@@ -7,6 +7,7 @@
import ot
import numpy as np
import pytest
+from ot.lowrank import sklearn_import  # check sklearn installation
def test_compute_lr_sqeuclidean_matrix():
@@ -15,7 +16,7 @@ def test_compute_lr_sqeuclidean_matrix():
X_s = np.reshape(1.0 * np.arange(2 * n), (n, 2))
X_t = np.reshape(1.0 * np.arange(2 * n), (n, 2))
- M1, M2 = ot.lowrank.compute_lr_sqeuclidean_matrix(X_s, X_t)
+ M1, M2 = ot.lowrank.compute_lr_sqeuclidean_matrix(X_s, X_t, rescale_cost=False)
M = ot.dist(X_s, X_t, metric="sqeuclidean") # original cost matrix
np.testing.assert_allclose(np.dot(M1, M2.T), M, atol=1e-05)
@@ -30,7 +31,7 @@ def test_lowrank_sinkhorn():
X_s = np.reshape(1.0 * np.arange(n), (n, 1))
X_t = np.reshape(1.0 * np.arange(n), (n, 1))
- Q, R, g, log = ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, reg=0.1, log=True)
+ Q, R, g, log = ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, reg=0.1, log=True, rescale_cost=False)
P = log["lazy_plan"][:]
value_linear = log["value_linear"]
@@ -52,6 +53,30 @@ def test_lowrank_sinkhorn():
ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, reg=0.1, stopThr=0, numItermax=1)
+@pytest.mark.parametrize(("init"), ("random", "deterministic", "kmeans"))
+def test_lowrank_sinkhorn_init(init):
+ # test lowrank inits
+ n = 100
+ a = ot.unif(n)
+ b = ot.unif(n)
+
+ X_s = np.reshape(1.0 * np.arange(n), (n, 1))
+ X_t = np.reshape(1.0 * np.arange(n), (n, 1))
+
+ # test ImportError if init="kmeans" and sklearn not imported
+ if init in ["random", "deterministic"] or ((init == "kmeans") and (sklearn_import is True)):
+ Q, R, g, log = ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, reg=0.1, init=init, log=True)
+ P = log["lazy_plan"][:]
+
+ # check constraints for P
+ np.testing.assert_allclose(a, P.sum(1), atol=1e-05)
+ np.testing.assert_allclose(b, P.sum(0), atol=1e-05)
+
+ else:
+ with pytest.raises(ImportError):
+ Q, R, g = ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, reg=0.1, init=init)
+
+
@pytest.mark.parametrize(("alpha, rank"), ((0.8, 2), (0.5, 3), (0.2, 6)))
def test_lowrank_sinkhorn_alpha_error(alpha, rank):
# Test warning for value of alpha
@@ -63,9 +88,25 @@ def test_lowrank_sinkhorn_alpha_error(alpha, rank):
X_t = np.reshape(1.0 * np.arange(0, n), (n, 1))
with pytest.raises(ValueError):
- ot.lowrank.lowrank_sinkhorn(
- X_s, X_t, a, b, reg=0.1, rank=rank, alpha=alpha, warn=False
- )
+ ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, reg=0.1, rank=rank, alpha=alpha, warn=False)
+
+
+@pytest.mark.parametrize(("gamma_init"), ("rescale", "theory"))
+def test_lowrank_sinkhorn_gamma_init(gamma_init):
+    # Test lr sinkhorn with different gamma_init strategies
+ n = 100
+ a = ot.unif(n)
+ b = ot.unif(n)
+
+ X_s = np.reshape(1.0 * np.arange(n), (n, 1))
+ X_t = np.reshape(1.0 * np.arange(n), (n, 1))
+
+ Q, R, g, log = ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, reg=0.1, gamma_init=gamma_init, log=True)
+ P = log["lazy_plan"][:]
+
+ # check constraints for P
+ np.testing.assert_allclose(a, P.sum(1), atol=1e-05)
+ np.testing.assert_allclose(b, P.sum(0), atol=1e-05)
@pytest.skip_backend('tf')
diff --git a/test/test_solvers.py b/test/test_solvers.py
index 343220c45..164989811 100644
--- a/test/test_solvers.py
+++ b/test/test_solvers.py
@@ -30,7 +30,7 @@
{'method': 'gaussian'},
{'method': 'gaussian', 'reg': 1},
{'method': 'factored', 'rank': 10},
- {'method': 'lowrank', 'reg': 0.1}
+ {'method': 'lowrank', 'rank': 10}
]
lst_parameters_solve_sample_NotImplemented = [