Skip to content

Commit 7ff43c9

Browse files
authored
Merge branch 'master' into da-on-jax
2 parents 95357c1 + f5fbdd6 commit 7ff43c9

File tree

7 files changed

+404
-74
lines changed

7 files changed

+404
-74
lines changed

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,8 @@ The numerous contributors to this library are listed [here](CONTRIBUTORS.md).
194194

195195
POT has benefited from the financing or manpower from the following partners:
196196

197-
<img src="https://pythonot.github.io/master/_static/images/logo_anr.jpg" alt="ANR" style="height:60px;"/><img src="https://pythonot.github.io/master/_static/images/logo_cnrs.jpg" alt="CNRS" style="height:60px;"/><img src="https://pythonot.github.io/master/_static/images/logo_3ia.jpg" alt="3IA" style="height:60px;"/>
197+
<img src="https://pythonot.github.io/master/_static/images/logo_anr.jpg" alt="ANR" style="height:60px;"/><img src="https://pythonot.github.io/master/_static/images/logo_cnrs.jpg" alt="CNRS" style="height:60px;"/><img src="https://pythonot.github.io/master/_static/images/logo_3ia.jpg" alt="3IA" style="height:60px;"/><img src="https://pythonot.github.io/master/_static/images/logo_hiparis.png" alt="Hi!PARIS" style="height:60px;"/>
198+
198199

199200

200201
## Contributions and code of conduct
76.9 KB
Loading
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
========================================
4+
Low rank Sinkhorn
5+
========================================
6+
7+
This example illustrates the computation of Low Rank Sinkhorn [26].
8+
9+
[65] Scetbon, M., Cuturi, M., & Peyré, G. (2021).
10+
"Low-rank Sinkhorn factorization". In International Conference on Machine Learning.
11+
"""
12+
13+
# Author: Laurène David <laurene.david@ip-paris.fr>
14+
#
15+
# License: MIT License
16+
#
17+
# sphinx_gallery_thumbnail_number = 2
18+
19+
import numpy as np
20+
import matplotlib.pylab as pl
21+
import ot.plot
22+
from ot.datasets import make_1D_gauss as gauss
23+
24+
##############################################################################
25+
# Generate data
26+
# -------------
27+
28+
#%% parameters
29+
30+
n = 100
31+
m = 120
32+
33+
# Gaussian distribution
34+
a = gauss(n, m=int(n / 3), s=25 / np.sqrt(2)) + 1.5 * gauss(n, m=int(5 * n / 6), s=15 / np.sqrt(2))
35+
a = a / np.sum(a)
36+
37+
b = 2 * gauss(m, m=int(m / 5), s=30 / np.sqrt(2)) + gauss(m, m=int(m / 2), s=35 / np.sqrt(2))
38+
b = b / np.sum(b)
39+
40+
# Source and target distribution
41+
X = np.arange(n).reshape(-1, 1)
42+
Y = np.arange(m).reshape(-1, 1)
43+
44+
45+
##############################################################################
46+
# Solve Low rank sinkhorn
47+
# ------------
48+
49+
#%%
50+
# Solve low rank sinkhorn
51+
Q, R, g, log = ot.lowrank_sinkhorn(X, Y, a, b, rank=10, init="random", gamma_init="rescale", rescale_cost=True, warn=False, log=True)
52+
P = log["lazy_plan"][:]
53+
54+
ot.plot.plot1D_mat(a, b, P, 'OT matrix Low rank')
55+
56+
57+
##############################################################################
58+
# Sinkhorn vs Low Rank Sinkhorn
59+
# -----------------------
60+
# Compare Sinkhorn and Low rank sinkhorn with different regularizations and ranks.
61+
62+
#%% Sinkhorn
63+
64+
# Compute cost matrix for sinkhorn OT
65+
M = ot.dist(X, Y)
66+
M = M / np.max(M)
67+
68+
# Solve sinkhorn with different regularizations using ot.solve
69+
list_reg = [0.05, 0.005, 0.001]
70+
list_P_Sin = []
71+
72+
for reg in list_reg:
73+
P = ot.solve(M, a, b, reg=reg, max_iter=2000, tol=1e-8).plan
74+
list_P_Sin.append(P)
75+
76+
#%% Low rank sinkhorn
77+
78+
# Solve low rank sinkhorn with different ranks using ot.solve_sample
79+
list_rank = [3, 10, 50]
80+
list_P_LR = []
81+
82+
for rank in list_rank:
83+
P = ot.solve_sample(X, Y, a, b, method='lowrank', rank=rank).plan
84+
P = P[:]
85+
list_P_LR.append(P)
86+
87+
88+
#%%
89+
90+
# Plot sinkhorn vs low rank sinkhorn
91+
pl.figure(1, figsize=(10, 4))
92+
93+
pl.subplot(1, 3, 1)
94+
pl.imshow(list_P_Sin[0], interpolation='nearest')
95+
pl.axis('off')
96+
pl.title('Sinkhorn (reg=0.05)')
97+
98+
pl.subplot(1, 3, 2)
99+
pl.imshow(list_P_Sin[1], interpolation='nearest')
100+
pl.axis('off')
101+
pl.title('Sinkhorn (reg=0.005)')
102+
103+
pl.subplot(1, 3, 3)
104+
pl.imshow(list_P_Sin[2], interpolation='nearest')
105+
pl.axis('off')
106+
pl.title('Sinkhorn (reg=0.001)')
107+
pl.show()
108+
109+
110+
#%%
111+
112+
pl.figure(2, figsize=(10, 4))
113+
114+
pl.subplot(1, 3, 1)
115+
pl.imshow(list_P_LR[0], interpolation='nearest')
116+
pl.axis('off')
117+
pl.title('Low rank (rank=3)')
118+
119+
pl.subplot(1, 3, 2)
120+
pl.imshow(list_P_LR[1], interpolation='nearest')
121+
pl.axis('off')
122+
pl.title('Low rank (rank=10)')
123+
124+
pl.subplot(1, 3, 3)
125+
pl.imshow(list_P_LR[2], interpolation='nearest')
126+
pl.axis('off')
127+
pl.title('Low rank (rank=50)')
128+
129+
pl.tight_layout()

0 commit comments

Comments
 (0)