
Commit 077184f

fix linux-minimal-deps + code review

committed · 1 parent 477deed

File tree

3 files changed: +55 -42 lines changed


examples/others/plot_lowrank_sinkhorn.py

Lines changed: 4 additions & 2 deletions
@@ -13,6 +13,8 @@
 # Author: Laurène David <laurene.david@ip-paris.fr>
 #
 # License: MIT License
+#
+# sphinx_gallery_thumbnail_number = 2

 import numpy as np
 import matplotlib.pylab as pl
@@ -86,7 +88,7 @@
 #%%

 # Plot sinkhorn vs low rank sinkhorn
-pl.figure(3, figsize=(10, 4))
+pl.figure(1, figsize=(10, 4))

 pl.subplot(1, 3, 1)
 pl.imshow(list_P_Sin[0], interpolation='nearest')
@@ -107,7 +109,7 @@

 #%%

-pl.figure(3, figsize=(10, 4))
+pl.figure(2, figsize=(10, 4))

 pl.subplot(1, 3, 1)
 pl.imshow(list_P_LR[0], interpolation='nearest')
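For context, sphinx_gallery_thumbnail_number is read by Sphinx-Gallery from a comment in the example script and picks which figure becomes the gallery thumbnail; renumbering the figures to 1 and 2 keeps that index stable. A minimal sketch of the pattern, assuming the same two-figure layout as the example (data and titles are illustrative):

# sphinx_gallery_thumbnail_number = 2  # Sphinx-Gallery uses the 2nd figure as thumbnail
import numpy as np
import matplotlib.pylab as pl

pl.figure(1, figsize=(10, 4))  # figure 1: Sinkhorn couplings
pl.imshow(np.random.rand(20, 20), interpolation='nearest')

pl.figure(2, figsize=(10, 4))  # figure 2: low rank Sinkhorn couplings (the thumbnail)
pl.imshow(np.random.rand(20, 20), interpolation='nearest')

pl.show()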

ot/lowrank.py

Lines changed: 38 additions & 34 deletions
@@ -11,10 +11,16 @@
 from .utils import unif, dist, get_lowrank_lazytensor
 from .backend import get_backend
 from .bregman import sinkhorn
-from sklearn.cluster import KMeans

+# test if sklearn is installed for linux-minimal-deps
+try:
+    import sklearn.cluster
+    sklearn_import = True
+except ImportError:
+    sklearn_import = False

-def _init_lr_sinkhorn(X_s, X_t, a, b, rank, init, reg_init=None, random_state=None, nx=None):
+
+def _init_lr_sinkhorn(X_s, X_t, a, b, rank, init, reg_init, random_state, nx=None):
     """
     Implementation of different initialization strategies for the low rank sinkhorn solver (Q ,R, g).
     This function is specific to lowrank_sinkhorn.
@@ -33,11 +39,11 @@ def _init_lr_sinkhorn(X_s, X_t, a, b, rank, init, reg_init=None, random_state=No
         Nonnegative rank of the OT plan.
     init : str
         Initialization strategy for Q, R and g. 'random', 'trivial' or 'kmeans'
-    reg_init : float, optional. Default is None. (>0)
-        Regularization term for a 'kmeans' init. If None, 1 is considered.
-    random_state : default None
+    reg_init : float, optional.
+        Regularization term for a 'kmeans' init.
+    random_state : int, optional.
         Random state for a "random" or 'kmeans' init strategy
-    nx : default None
+    nx : optional, Default is None
         POT backend

@@ -61,12 +67,6 @@ def _init_lr_sinkhorn(X_s, X_t, a, b, rank, init, reg_init=None, random_state=No
     if nx is None:
         nx = get_backend(X_s, X_t, a, b)

-    if reg_init is None:
-        reg_init = 0.1
-
-    if random_state is None:
-        random_state = 49
-
     ns = X_s.shape[0]
     nt = X_t.shape[0]
     r = rank
@@ -86,7 +86,7 @@ def _init_lr_sinkhorn(X_s, X_t, a, b, rank, init, reg_init=None, random_state=No
         R = nx.abs(nx.randn(nt, rank, type_as=X_s)) + 1
         R = (R.T * (b / nx.sum(R, axis=1))).T

-    if init == "trivial":
+    if init == "deterministic":
         # Init g
         g = nx.ones(rank) / rank

@@ -114,24 +114,28 @@ def _init_lr_sinkhorn(X_s, X_t, a, b, rank, init, reg_init=None, random_state=No
         R = R1 + R2

     if init == "kmeans":
-        # Init g
-        g = nx.ones(rank, type_as=X_s) / rank
-
-        # Init Q
-        kmeans_Xs = KMeans(n_clusters=rank, random_state=random_state, n_init="auto")
-        kmeans_Xs.fit(X_s)
-        Z_Xs = nx.from_numpy(kmeans_Xs.cluster_centers_)
-        C_Xs = dist(X_s, Z_Xs)  # shape (ns, rank)
-        C_Xs = C_Xs / nx.max(C_Xs)
-        Q = sinkhorn(a, g, C_Xs, reg=reg_init, numItermax=10000, stopThr=1e-3)
+        if sklearn_import:
+            # Init g
+            g = nx.ones(rank, type_as=X_s) / rank
+
+            # Init Q
+            kmeans_Xs = sklearn.cluster.KMeans(n_clusters=rank, random_state=random_state, n_init="auto")
+            kmeans_Xs.fit(X_s)
+            Z_Xs = nx.from_numpy(kmeans_Xs.cluster_centers_)
+            C_Xs = dist(X_s, Z_Xs)  # shape (ns, rank)
+            C_Xs = C_Xs / nx.max(C_Xs)
+            Q = sinkhorn(a, g, C_Xs, reg=reg_init, numItermax=10000, stopThr=1e-3)
+
+            # Init R
+            kmeans_Xt = sklearn.cluster.KMeans(n_clusters=rank, random_state=random_state, n_init="auto")
+            kmeans_Xt.fit(X_t)
+            Z_Xt = nx.from_numpy(kmeans_Xt.cluster_centers_)
+            C_Xt = dist(X_t, Z_Xt)  # shape (nt, rank)
+            C_Xt = C_Xt / nx.max(C_Xt)
+            R = sinkhorn(b, g, C_Xt, reg=reg_init, numItermax=10000, stopThr=1e-3)

-        # Init R
-        kmeans_Xt = KMeans(n_clusters=rank, random_state=random_state, n_init="auto")
-        kmeans_Xt.fit(X_t)
-        Z_Xt = nx.from_numpy(kmeans_Xt.cluster_centers_)
-        C_Xt = dist(X_t, Z_Xt)  # shape (nt, rank)
-        C_Xt = C_Xt / nx.max(C_Xt)
-        R = sinkhorn(b, g, C_Xt, reg=reg_init, numItermax=10000, stopThr=1e-3)
+        else:
+            raise ImportError("Scikit-learn should be installed to use the 'kmeans' init.")

     return Q, R, g

@@ -306,7 +310,7 @@ def _LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, stopThr, numItermax, warn, nx=N


 def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=None, alpha=1e-10, rescale_cost=True,
-                     init="random", reg_init=None, seed_init=None, gamma_init="rescale",
+                     init="random", reg_init=1e-1, seed_init=49, gamma_init="rescale",
                      numItermax=2000, stopThr=1e-7, warn=True, log=False):
     r"""
     Solve the entropic regularization optimal transport problem under low-nonnegative rank constraints
@@ -347,10 +351,10 @@ def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=None, alpha=1e-10, re
     rescale_cost : bool, optional. Default is False
         Rescale the low rank factorization of the sqeuclidean cost matrix
     init : str, optional. Default is 'random'.
-        Initialization strategy for the low rank couplings. 'random', 'trivial' or 'kmeans'
-    reg_init : float, optional. Default is None. (>0)
+        Initialization strategy for the low rank couplings. 'random', 'deterministic' or 'kmeans'
+    reg_init : float, optional. Default is 1e-1. (>0)
         Regularization term for a 'kmeans' init. If None, 1 is considered.
-    seed_init : int, optional. Default is None. (>0)
+    seed_init : int, optional. Default is 49. (>0)
         Random state for a 'random' or 'kmeans' init strategy.
     gamma_init : str, optional. Default is "rescale".
         Initialization strategy for gamma. 'rescale', or 'theory'
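With scikit-learn now optional, lowrank_sinkhorn keeps working with the 'random' and 'deterministic' inits and only raises ImportError when 'kmeans' is requested without scikit-learn installed. A minimal usage sketch under that assumption (point clouds, rank and reg values are illustrative):

import numpy as np
import ot

n = 100
X_s = np.reshape(1.0 * np.arange(n), (n, 1))  # source samples
X_t = np.reshape(1.0 * np.arange(n), (n, 1))  # target samples
a, b = ot.unif(n), ot.unif(n)                 # uniform marginals

try:
    # 'kmeans' init requires scikit-learn; reg_init/seed_init now default to 1e-1/49
    Q, R, g, log = ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, reg=0.1, rank=3,
                                               init="kmeans", log=True)
except ImportError:
    # fall back to an init with no extra dependency
    Q, R, g, log = ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, reg=0.1, rank=3,
                                               init="deterministic", log=True)

P = log["lazy_plan"][:]  # densify the low rank coupling Q diag(1/g) R^T
print(P.shape)           # (100, 100)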

test/test_lowrank.py

Lines changed: 13 additions & 6 deletions
@@ -7,6 +7,7 @@
 import ot
 import numpy as np
 import pytest
+from ot.lowrank import sklearn_import  # check sklearn installation


 def test_compute_lr_sqeuclidean_matrix():
@@ -52,7 +53,7 @@ def test_lowrank_sinkhorn():
     ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, reg=0.1, stopThr=0, numItermax=1)


-@pytest.mark.parametrize(("init"), ("random", "trivial", "kmeans"))
+@pytest.mark.parametrize(("init"), ("random", "deterministic", "kmeans"))
 def test_lowrank_sinkhorn_init(init):
     # test lowrank inits
     n = 100
@@ -62,12 +63,18 @@ def test_lowrank_sinkhorn_init(init):
     X_s = np.reshape(1.0 * np.arange(n), (n, 1))
     X_t = np.reshape(1.0 * np.arange(n), (n, 1))

-    Q, R, g, log = ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, reg=0.1, log=True, init=init, reg_init=1)
-    P = log["lazy_plan"][:]
+    # test ImportError if init="kmeans" and sklearn not imported
+    if init in ["random", "deterministic"] or ((init == "kmeans") and (sklearn_import is True)):
+        Q, R, g, log = ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, reg=0.1, init=init, log=True)
+        P = log["lazy_plan"][:]

-    # check constraints for P
-    np.testing.assert_allclose(a, P.sum(1), atol=1e-05)
-    np.testing.assert_allclose(b, P.sum(0), atol=1e-05)
+        # check constraints for P
+        np.testing.assert_allclose(a, P.sum(1), atol=1e-05)
+        np.testing.assert_allclose(b, P.sum(0), atol=1e-05)
+
+    else:
+        with pytest.raises(ImportError):
+            Q, R, g = ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, reg=0.1, init=init)


 @pytest.mark.parametrize(("alpha, rank"), ((0.8, 2), (0.5, 3), (0.2, 6)))
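Because sklearn_import is a module-level flag in ot.lowrank, callers outside the test suite can reuse the same check to pick a safe default instead of catching the exception. A small sketch under that assumption (choose_init is a hypothetical helper, not part of POT):

from ot.lowrank import sklearn_import


def choose_init(preferred="kmeans"):
    # Hypothetical helper: fall back to 'random' when scikit-learn is missing.
    if preferred == "kmeans" and not sklearn_import:
        return "random"
    return preferred


print(choose_init())  # 'kmeans' if scikit-learn is installed, else 'random'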
