
Commit 0024d07

[MRG] Low rank sinkhorn algorithm (#568)
* new file for lr sinkhorn
* lr sinkhorn, solve_sample, OTResultLazy
* add test functions + small modif lr_sin/solve_sample
* add import to __init__
* modify low rank, remove solve_sample, OTResultLazy
* new file for lr sinkhorn
* lr sinkhorn, solve_sample, OTResultLazy
* add test functions + small modif lr_sin/solve_sample
* add import to __init__
* remove test solve_sample
* add value, value_linear, lazy_plan
* add comments to lr algorithm
* modify test functions + add comments to lowrank
* modify __init__ with lowrank
* debug lowrank + test
* debug test function low_rank
* error test
* final debug of lowrank + add new test functions
* Debug tests + add lowrank to solve_sample
* fix torch backend for lowrank
* fix jax backend and skip tf
* fix pep 8 tests
1 parent 659cde8 commit 0024d07

File tree: 8 files changed (+459, -2 lines)

CONTRIBUTORS.md

Lines changed: 1 addition & 0 deletions
@@ -45,6 +45,7 @@ The contributors to this library are:
 * [Ronak Mehta](https://ronakrm.github.io) (Efficient Discrete Multi Marginal Optimal Transport Regularization)
 * [Xizheng Yu](https://github.com/x12hengyu) (Efficient Discrete Multi Marginal Optimal Transport Regularization)
 * [Sonia Mazelet](https://github.com/SoniaMaz8) (Template based GNN layers)
+* [Laurène David](https://github.com/laudavid) (Low rank sinkhorn)

 ## Acknowledgments

README.md

Lines changed: 2 additions & 0 deletions
@@ -347,3 +347,5 @@ distances between Gaussian distributions](https://hal.science/hal-03197398v2/fil
 [63] Li, J., Tang, J., Kong, L., Liu, H., Li, J., So, A. M. C., & Blanchet, J. (2022). [A Convergent Single-Loop Algorithm for Relaxation of Gromov-Wasserstein in Graph Data](https://openreview.net/pdf?id=0jxPyVWmiiF). In The Eleventh International Conference on Learning Representations.

 [64] Ma, X., Chu, X., Wang, Y., Lin, Y., Zhao, J., Ma, L., & Zhu, W. (2023). [Fused Gromov-Wasserstein Graph Mixup for Graph-level Classifications](https://openreview.net/pdf?id=uqkUguNu40). In Thirty-seventh Conference on Neural Information Processing Systems.
+
+[65] Scetbon, M., Cuturi, M., & Peyré, G. (2021). [Low-Rank Sinkhorn Factorization](https://arxiv.org/pdf/2103.04737.pdf). In International Conference on Machine Learning.

RELEASES.md

Lines changed: 1 addition & 0 deletions
@@ -22,6 +22,7 @@
 + Add `fixed_structure` and `fixed_features` to entropic fgw barycenter solver (PR #578)
 + Add new BAPG solvers with KL projections for GW and FGW (PR #581)
 + Add Bures-Wasserstein barycenter in `ot.gaussian` and example (PR #582, PR #584)
++ Add support for [Low-Rank Sinkhorn Factorization](https://arxiv.org/pdf/2103.04737.pdf) (PR #568)


 #### Closed issues

ot/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -35,6 +35,7 @@
 from . import factored
 from . import solvers
 from . import gaussian
+from . import lowrank


 # OT functions
@@ -52,6 +53,7 @@
 from .weak import weak_optimal_transport
 from .factored import factored_optimal_transport
 from .solvers import solve, solve_gromov, solve_sample
+from .lowrank import lowrank_sinkhorn

 # utils functions
 from .utils import dist, unif, tic, toc, toq
@@ -69,4 +71,4 @@
 'factored_optimal_transport', 'solve', 'solve_gromov', 'solve_sample',
 'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath', 'solvers',
 'binary_search_circle', 'wasserstein_circle',
-'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif']
+'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif', 'lowrank_sinkhorn']

ot/lowrank.py

Lines changed: 341 additions & 0 deletions
@@ -0,0 +1,341 @@
"""
Low rank OT solvers
"""

# Author: Laurène David <laurene.david@ip-paris.fr>
#
# License: MIT License


import warnings
from .utils import unif, get_lowrank_lazytensor
from .backend import get_backend


def compute_lr_sqeuclidean_matrix(X_s, X_t, nx=None):
    """
    Compute the low rank decomposition of a squared Euclidean distance matrix.
    This decomposition is only valid for the squared Euclidean metric and
    won't work for any other distance.

    See "Section 3.5, proposition 1" of the paper.

    Parameters
    ----------
    X_s : array-like, shape (n_samples_a, dim)
        samples in the source domain
    X_t : array-like, shape (n_samples_b, dim)
        samples in the target domain
    nx : POT backend, default None

    Returns
    -------
    M1 : array-like, shape (n_samples_a, dim+2)
        First low rank decomposition of the distance matrix
    M2 : array-like, shape (n_samples_b, dim+2)
        Second low rank decomposition of the distance matrix

    References
    ----------
    .. [65] Scetbon, M., Cuturi, M., & Peyré, G. (2021).
        "Low-Rank Sinkhorn Factorization". In International Conference on Machine Learning.
    """

    if nx is None:
        nx = get_backend(X_s, X_t)

    ns = X_s.shape[0]
    nt = X_t.shape[0]

    # First low rank decomposition of the cost matrix (A)
    array1 = nx.reshape(nx.sum(X_s**2, 1), (-1, 1))
    array2 = nx.reshape(nx.ones(ns, type_as=X_s), (-1, 1))
    M1 = nx.concatenate((array1, array2, -2 * X_s), axis=1)

    # Second low rank decomposition of the cost matrix (B)
    array1 = nx.reshape(nx.ones(nt, type_as=X_s), (-1, 1))
    array2 = nx.reshape(nx.sum(X_t**2, 1), (-1, 1))
    M2 = nx.concatenate((array1, array2, X_t), axis=1)

    return M1, M2
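
# --- Editor's sketch (illustrative, not part of this commit) ---
# Quick NumPy check of the factorization above: the two factors reconstruct
# the squared Euclidean cost matrix exactly, M1 @ M2.T == C, so the cost
# matrix has rank at most dim + 2.
import numpy as np
from ot.lowrank import compute_lr_sqeuclidean_matrix

rng = np.random.RandomState(42)
X_s, X_t = rng.randn(5, 3), rng.randn(7, 3)
M1, M2 = compute_lr_sqeuclidean_matrix(X_s, X_t)
C = ((X_s[:, None, :] - X_t[None, :, :]) ** 2).sum(-1)  # dense sqeuclidean cost
assert np.allclose(M1 @ M2.T, C)
# --- end of editor's sketch ---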


def _LR_Dykstra(eps1, eps2, eps3, p1, p2, alpha, stopThr, numItermax, warn, nx=None):
    """
    Implementation of the Dykstra algorithm for the Low Rank Sinkhorn OT solver.
    This function is specific to lowrank_sinkhorn.

    Parameters
    ----------
    eps1 : array-like, shape (n_samples_a, r)
        First input parameter of the Dykstra algorithm
    eps2 : array-like, shape (n_samples_b, r)
        Second input parameter of the Dykstra algorithm
    eps3 : array-like, shape (r,)
        Third input parameter of the Dykstra algorithm
    p1 : array-like, shape (n_samples_a,)
        Sample weights in the source domain (same as "a" in lowrank_sinkhorn)
    p2 : array-like, shape (n_samples_b,)
        Sample weights in the target domain (same as "b" in lowrank_sinkhorn)
    alpha : float
        Lower bound for the weight vector g (same as "alpha" in lowrank_sinkhorn)
    stopThr : float
        Stop threshold on error
    numItermax : int
        Max number of iterations
    warn : bool, optional
        if True, raises a warning if the algorithm doesn't converge.
    nx : POT backend, default None

    Returns
    -------
    Q : array-like, shape (n_samples_a, r)
        Dykstra update of the first low-rank matrix decomposition Q
    R : array-like, shape (n_samples_b, r)
        Dykstra update of the second low-rank matrix decomposition R
    g : array-like, shape (r,)
        Dykstra update of the weight vector g

    References
    ----------
    .. [65] Scetbon, M., Cuturi, M., & Peyré, G. (2021).
        "Low-Rank Sinkhorn Factorization". In International Conference on Machine Learning.
    """

    # POT backend if None
    if nx is None:
        nx = get_backend(eps1, eps2, eps3, p1, p2)

    # ----------------- Initialisation of Dykstra algorithm -----------------
    r = len(eps3)  # rank
    g_ = nx.copy(eps3)  # \tilde{g}
    q3_1, q3_2 = nx.ones(r, type_as=p1), nx.ones(r, type_as=p1)  # q^{(3)}_1, q^{(3)}_2
    v1_, v2_ = nx.ones(r, type_as=p1), nx.ones(r, type_as=p1)  # \tilde{v}^{(1)}, \tilde{v}^{(2)}
    q1, q2 = nx.ones(r, type_as=p1), nx.ones(r, type_as=p1)  # q^{(1)}, q^{(2)}
    err = 1  # initial error

    # --------------------- Dykstra algorithm -------------------------

    # See Section 3.3 - "Algorithm 2 LR-Dykstra" in the paper

    for ii in range(numItermax):
        if err > stopThr:
            # Compute u^{(1)} and u^{(2)}
            u1 = p1 / nx.dot(eps1, v1_)
            u2 = p2 / nx.dot(eps2, v2_)

            # Compute g, q^{(3)}_1 and update \tilde{g}
            g = nx.maximum(alpha, g_ * q3_1)
            q3_1 = (g_ * q3_1) / g
            g_ = nx.copy(g)

            # Compute new value of g with \prod
            prod1 = (v1_ * q1) * nx.dot(eps1.T, u1)
            prod2 = (v2_ * q2) * nx.dot(eps2.T, u2)
            g = (g_ * q3_2 * prod1 * prod2) ** (1 / 3)

            # Compute v^{(1)} and v^{(2)}
            v1 = g / nx.dot(eps1.T, u1)
            v2 = g / nx.dot(eps2.T, u2)

            # Compute q^{(1)}, q^{(2)} and q^{(3)}_2
            q1 = (v1_ * q1) / v1
            q2 = (v2_ * q2) / v2
            q3_2 = (g_ * q3_2) / g

            # Update values of \tilde{v}^{(1)}, \tilde{v}^{(2)} and \tilde{g}
            v1_, v2_ = nx.copy(v1), nx.copy(v2)
            g_ = nx.copy(g)

            # Compute error (deviation of the scaled marginals from p1 and p2)
            err1 = nx.sum(nx.abs(u1 * (eps1 @ v1) - p1))
            err2 = nx.sum(nx.abs(u2 * (eps2 @ v2) - p2))
            err = err1 + err2

        else:
            break

    else:
        if warn:
            warnings.warn(
                "Sinkhorn did not converge. You might want to "
                "increase the number of iterations `numItermax`."
            )

    # Compute low rank matrices Q, R
    Q = u1[:, None] * eps1 * v1[None, :]
    R = u2[:, None] * eps2 * v2[None, :]

    return Q, R, g
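
# --- Editor's sketch (illustrative, not part of this commit) ---
# The Dykstra iterations above project onto the low-rank coupling set
# C(a, b, r) of the paper: Q 1_r = a, R 1_r = b and Q^T 1_n = R^T 1_m = g.
# A hypothetical NumPy helper to sanity-check a converged (Q, R, g) triple:
import numpy as np

def check_lowrank_coupling(Q, R, g, a, b, atol=1e-5):
    assert np.allclose(Q.sum(axis=1), a, atol=atol)  # row marginal: Q 1_r = a
    assert np.allclose(R.sum(axis=1), b, atol=atol)  # row marginal: R 1_r = b
    assert np.allclose(Q.sum(axis=0), g, atol=atol)  # column marginal: Q^T 1_n = g
    assert np.allclose(R.sum(axis=0), g, atol=atol)  # column marginal: R^T 1_m = g
# --- end of editor's sketch ---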


def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=None, alpha=None,
                     numItermax=1000, stopThr=1e-9, warn=True, log=False):
    r"""
    Solve the entropic regularized optimal transport problem under a low nonnegative rank constraint.

    The function solves the following optimization problem:

    .. math::
        \mathop{\inf_{(Q, R, g) \in \mathcal{C}(a, b, r)}} \langle C, Q \mathrm{diag}(1/g) R^T \rangle -
        \mathrm{reg} \cdot H((Q, R, g))

    where:
    - :math:`C` is the (`dim_a`, `dim_b`) metric cost matrix
    - :math:`H((Q, R, g))` is the sum of the entropies of the three terms :math:`Q`, :math:`R` and :math:`g`
    - :math:`Q` and :math:`R` are the low-rank matrix decomposition of the OT plan
    - :math:`g` is the weight vector for the low-rank decomposition of the OT plan
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1)
    - :math:`r` is the rank of the OT plan
    - :math:`\mathcal{C}(a, b, r)` is the set of low-rank couplings of the OT problem

    Parameters
    ----------
    X_s : array-like, shape (n_samples_a, dim)
        samples in the source domain
    X_t : array-like, shape (n_samples_b, dim)
        samples in the target domain
    a : array-like, shape (n_samples_a,)
        sample weights in the source domain
    b : array-like, shape (n_samples_b,)
        sample weights in the target domain
    reg : float, optional
        Regularization term (>0)
    rank : int, optional. Default is None. (>0)
        Nonnegative rank of the OT plan. If None, min(ns, nt) is used.
    alpha : float, optional. Default is None. (>0 and <1/r)
        Lower bound for the weight vector g. If None, 1e-10 is used.
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (>0)
    warn : bool, optional
        if True, raises a warning if the algorithm doesn't converge.
    log : bool, optional
        record log if True

    Returns
    -------
    Q : array-like, shape (n_samples_a, r)
        First low-rank matrix decomposition of the OT plan
    R : array-like, shape (n_samples_b, r)
        Second low-rank matrix decomposition of the OT plan
    g : array-like, shape (r,)
        Weight vector for the low-rank decomposition of the OT plan
    log : dict, optional
        Returned only if `log` is True. Contains the OT plan as a
        :any:`LazyTensor` object ("lazy_plan"), the optimal value of the
        optimization problem ("value") and the linear OT loss with the
        optimal plan ("value_linear").

    References
    ----------
    .. [65] Scetbon, M., Cuturi, M., & Peyré, G. (2021).
        "Low-Rank Sinkhorn Factorization". In International Conference on Machine Learning.
    """

    # POT backend
    nx = get_backend(X_s, X_t)
    ns, nt = X_s.shape[0], X_t.shape[0]

    # Initialize weights a, b
    if a is None:
        a = unif(ns, type_as=X_s)
    if b is None:
        b = unif(nt, type_as=X_t)

    # Compute rank (see Section 3.1, def 1)
    r = rank
    if rank is None:
        r = min(ns, nt)

    if r <= 0:
        raise ValueError("The rank parameter must be a positive integer")

    if alpha is None:
        alpha = 1e-10

    # Dykstra won't converge if alpha > 1/rank: g has r entries that are each
    # lower-bounded by alpha and must sum to 1
    # (see "Section 3.2: The Low-rank OT Problem (LOT)" in the paper)
    if 1 / r < alpha:
        raise ValueError("alpha ({a}) should be smaller than 1/rank ({r}) for the Dykstra algorithm to converge.".format(
            a=alpha, r=1 / r))

    # Low rank decomposition of the sqeuclidean cost matrix (A, B)
    M1, M2 = compute_lr_sqeuclidean_matrix(X_s, X_t, nx=nx)

    # Compute gamma (see "Section 3.4, proposition 4" in the paper)
    L = nx.sqrt(
        3 * (2 / (alpha**4)) * ((nx.norm(M1) * nx.norm(M2)) ** 2) +
        (reg + (2 / (alpha**3)) * (nx.norm(M1) * nx.norm(M2))) ** 2
    )
    gamma = 1 / (2 * L)

    # Initialize the low rank matrices Q, R, g
    Q = nx.ones((ns, r), type_as=a)
    R = nx.ones((nt, r), type_as=a)
    g = nx.ones(r, type_as=a)
    k = 100  # number of outer mirror-descent iterations

    # -------------------------- Low rank algorithm ------------------------------
    # see "Section 3.3, Algorithm 3 LOT" in the paper

    for ii in range(k):
        # Compute the C @ R dot matrix using the lr decomposition of C
        CR_ = nx.dot(M2.T, R)
        CR = nx.dot(M1, CR_)

        # Compute the C.T @ Q dot matrix using the lr decomposition of C
        CQ_ = nx.dot(M1.T, Q)
        CQ = nx.dot(M2, CQ_)

        diag_g = (1 / g)[None, :]

        eps1 = nx.exp(-gamma * (CR * diag_g) - ((gamma * reg) - 1) * nx.log(Q))
        eps2 = nx.exp(-gamma * (CQ * diag_g) - ((gamma * reg) - 1) * nx.log(R))
        omega = nx.diag(nx.dot(Q.T, CR))
        eps3 = nx.exp(gamma * omega / (g**2) - (gamma * reg - 1) * nx.log(g))

        Q, R, g = _LR_Dykstra(
            eps1, eps2, eps3, a, b, alpha, stopThr, numItermax, warn, nx
        )
        Q = Q + 1e-16
        R = R + 1e-16

    # ----------------- Compute lazy_plan, value and value_linear ------------------
    # see "Section 3.2: The Low-rank OT Problem" in the paper

    # Compute lazy plan (using LazyTensor class)
    lazy_plan = get_lowrank_lazytensor(Q, R, 1 / g)

    # Compute value_linear (using trace formula)
    v1 = nx.dot(Q.T, M1)
    v2 = nx.dot(R, (v1.T * diag_g).T)
    value_linear = nx.sum(nx.diag(nx.dot(M2.T, v2)))

    # Compute value with entropy reg (the entropies of Q, R and g are computed separately, see "Section 3.2" in the paper)
    reg_Q = nx.sum(Q * nx.log(Q + 1e-16))  # entropy for Q
    reg_g = nx.sum(g * nx.log(g + 1e-16))  # entropy for g
    reg_R = nx.sum(R * nx.log(R + 1e-16))  # entropy for R
    value = value_linear + reg * (reg_Q + reg_g + reg_R)

    if log:
        dict_log = dict()
        dict_log["value"] = value
        dict_log["value_linear"] = value_linear
        dict_log["lazy_plan"] = lazy_plan

        return Q, R, g, dict_log

    return Q, R, g
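
A minimal usage sketch (editor's addition, assuming this commit is installed; the point clouds, rank and regularization values below are arbitrary illustrations):

import numpy as np
import ot

rng = np.random.RandomState(0)
X_s = rng.randn(100, 2)  # source samples
X_t = rng.randn(120, 2)  # target samples

# rank-10 approximation of the OT plan with entropic regularization 0.1
Q, R, g, log = ot.lowrank_sinkhorn(X_s, X_t, reg=0.1, rank=10, log=True)

print(Q.shape, R.shape, g.shape)  # (100, 10) (120, 10) (10,)
print(log["value"], log["value_linear"])

# for small problems, the dense plan can be recovered as Q diag(1/g) R^T
P = Q @ np.diag(1.0 / g) @ R.T
assert np.allclose(P.sum(), 1.0, atol=1e-4)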
