Commit 2472dd4

Implementation of Low Rank Gromov-Wasserstein (#614)
* new file for lr sinkhorn
* lr sinkhorn, solve_sample, OTResultLazy
* add test functions + small modif lr_sin/solve_sample
* add import to __init__
* modify low rank, remove solve_sample,OTResultLazy
* new file for lr sinkhorn
* lr sinkhorn, solve_sample, OTResultLazy
* add test functions + small modif lr_sin/solve_sample
* add import to __init__
* remove test solve_sample
* add value, value_linear, lazy_plan
* add comments to lr algorithm
* modify test functions + add comments to lowrank
* modify __init__ with lowrank
* debug lowrank + test
* debug test function low_rank
* error test
* final debug of lowrank + add new test functions
* Debug tests + add lowrank to solve_sample
* fix torch backend for lowrank
* fix jax backend and skip tf
* fix pep 8 tests
* add lowrank init + test functions
* Add init strategies in lowrank + example (#588)
* modified lowrank
* changes from code review
* fix error test pep8
* fix linux-minimal-deps + code review
* Implementation of LR GW + add method in __init__
* add LR gw paper in README.md
* add tests for low rank GW
* add examples for Low Rank GW
* fix __init__
* change atol of lr backends
* fix pep8 errors
* modif for code review
1 parent e01c4e6 commit 2472dd4

File tree

8 files changed

+622
-5
lines changed


CONTRIBUTORS.md

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -50,7 +50,7 @@ The contributors to this library are:
5050
* [Ronak Mehta](https://ronakrm.github.io) (Efficient Discrete Multi Marginal Optimal Transport Regularization)
5151
* [Xizheng Yu](https://github.com/x12hengyu) (Efficient Discrete Multi Marginal Optimal Transport Regularization)
5252
* [Sonia Mazelet](https://github.com/SoniaMaz8) (Template based GNN layers)
53-
* [Laurène David](https://github.com/laudavid) (Low rank sinkhorn)
53+
* [Laurène David](https://github.com/laudavid) (Low rank sinkhorn, Low rank Gromov-Wasserstein samples)
5454

5555
## Acknowledgments
5656

README.md

Lines changed: 2 additions & 0 deletions
@@ -357,3 +357,5 @@ distances between Gaussian distributions](https://hal.science/hal-03197398v2/fil
357357
[65] Scetbon, M., Cuturi, M., & Peyré, G. (2021). [Low-Rank Sinkhorn Factorization](https://arxiv.org/pdf/2103.04737.pdf).
358358

359359
[66] Pooladian, Aram-Alexandre, and Jonathan Niles-Weed. [Entropic estimation of optimal transport maps](https://arxiv.org/pdf/2109.12004.pdf). arXiv preprint arXiv:2109.12004 (2021).
360+
361+
[67] Scetbon, M., Peyré, G. & Cuturi, M. (2022). [Linear-Time Gromov Wasserstein Distances using Low Rank Couplings and Costs](https://proceedings.mlr.press/v162/scetbon22b/scetbon22b.pdf). In International Conference on Machine Learning (ICML), 2022.

RELEASES.md

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@
77
+ Continuous entropic mapping (PR #613)
88
+ New general unbalanced solvers for `ot.solve` and BFGS solver and illustrative example (PR #620)
99
+ Add gradient computation with envelope theorem to sinkhorn solver of `ot.solve` with `grad='envelope'` (PR #605).
10+
+ Added support for [Low rank Gromov-Wasserstein](https://proceedings.mlr.press/v162/scetbon22b/scetbon22b.pdf) with `ot.gromov.lowrank_gromov_wasserstein_samples` (PR #614)
1011

1112
#### Closed issues
1213
- Fix gpu compatibility of sr(F)GW solvers when `G0 is not None`(PR #596)

examples/others/plot_lowrank_GW.py

Lines changed: 173 additions & 0 deletions
@@ -0,0 +1,173 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
===========================================
4+
Low rank Gromov-Wasserstein between samples
5+
===========================================
6+
7+
Comparison between entropic Gromov-Wasserstein and low rank Gromov-Wasserstein [67]
8+
on two curves in 2D and 3D, both sampled with 200 points.
9+
10+
The squared Euclidean distance is considered as the ground cost for both samples.
11+
12+
[67] Scetbon, M., Peyré, G. & Cuturi, M. (2022).
13+
"Linear-Time Gromov Wasserstein Distances using Low Rank Couplings and Costs".
14+
In International Conference on Machine Learning (ICML), 2022.
15+
"""
16+
17+
# Author: Laurène David <laurene.david@ip-paris.fr>
18+
#
19+
# License: MIT License
20+
#
21+
# sphinx_gallery_thumbnail_number = 3
22+
23+
#%%
24+
import numpy as np
25+
import matplotlib.pylab as pl
26+
import ot.plot
27+
import time
28+
29+
##############################################################################
30+
# Generate data
31+
# -------------
32+
33+
#%% parameters
34+
n_samples = 200
35+
36+
# Generate 2D and 3D curves
37+
theta = np.linspace(-4 * np.pi, 4 * np.pi, n_samples)
38+
z = np.linspace(1, 2, n_samples)
39+
r = z**2 + 1
40+
x = r * np.sin(theta)
41+
y = r * np.cos(theta)
42+
43+
# Source and target distribution
44+
X = np.concatenate([x.reshape(-1, 1), z.reshape(-1, 1)], axis=1)
45+
Y = np.concatenate([x.reshape(-1, 1), y.reshape(-1, 1), z.reshape(-1, 1)], axis=1)
46+
47+
48+
##############################################################################
49+
# Plot data
50+
# ------------
51+
52+
#%%
53+
# Plot the source and target samples
54+
fig = pl.figure(1, figsize=(10, 4))
55+
56+
ax = fig.add_subplot(121)
57+
ax.plot(X[:, 0], X[:, 1], color="blue", linewidth=6)
58+
ax.tick_params(left=False, right=False, labelleft=False,
59+
labelbottom=False, bottom=False)
60+
ax.set_title("2D curve (source)")
61+
62+
ax2 = fig.add_subplot(122, projection="3d")
63+
ax2.plot(Y[:, 0], Y[:, 1], Y[:, 2], c='red', linewidth=6)
64+
ax2.tick_params(left=False, right=False, labelleft=False,
65+
labelbottom=False, bottom=False)
66+
ax2.view_init(15, -50)
67+
ax2.set_title("3D curve (target)")
68+
69+
pl.tight_layout()
70+
pl.show()
71+
72+
73+
##############################################################################
74+
# Entropic Gromov-Wasserstein
75+
# ---------------------------
76+
77+
#%%
78+
79+
# Compute cost matrices
80+
C1 = ot.dist(X, X, metric="sqeuclidean")
81+
C2 = ot.dist(Y, Y, metric="sqeuclidean")
82+
83+
# Scale cost matrices
84+
r1 = C1.max()
85+
r2 = C2.max()
86+
87+
C1 = C1 / r1
88+
C2 = C2 / r2
89+
90+
91+
# Solve entropic gw
92+
reg = 5 * 1e-3
93+
94+
start = time.time()
95+
gw, log = ot.gromov.entropic_gromov_wasserstein(
96+
C1, C2, tol=1e-3, epsilon=reg,
97+
log=True, verbose=False)
98+
99+
end = time.time()
100+
time_entropic = end - start
101+
102+
entropic_gw_loss = np.round(log['gw_dist'], 3)
103+
104+
# Plot entropic gw
105+
pl.figure(2)
106+
pl.imshow(gw, interpolation="nearest", aspect="auto")
107+
pl.title("Entropic Gromov-Wasserstein (loss={})".format(entropic_gw_loss))
108+
pl.show()
109+
110+
111+
##############################################################################
112+
# Low rank squared Euclidean cost matrices
113+
# ----------------------------------------
114+
# %%
115+
116+
# Compute the low rank sqeuclidean cost decompositions
117+
A1, A2 = ot.lowrank.compute_lr_sqeuclidean_matrix(X, X, rescale_cost=False)
118+
B1, B2 = ot.lowrank.compute_lr_sqeuclidean_matrix(Y, Y, rescale_cost=False)
119+
120+
# Scale the low rank cost matrices
121+
A1, A2 = A1 / np.sqrt(r1), A2 / np.sqrt(r1)
122+
B1, B2 = B1 / np.sqrt(r2), B2 / np.sqrt(r2)
123+
124+
125+
##############################################################################
126+
# Low rank Gromov-Wasserstein
127+
# ---------------------------
128+
# %%
129+
130+
# Solve low rank gromov-wasserstein with different ranks
131+
list_rank = [10, 50]
132+
list_P_GW = []
133+
list_loss_GW = []
134+
list_time_GW = []
135+
136+
for rank in list_rank:
137+
    start = time.time()
138+
139+
    Q, R, g, log = ot.lowrank_gromov_wasserstein_samples(
140+
        X, Y, reg=0, rank=rank, rescale_cost=False, cost_factorized_Xs=(A1, A2),
141+
        cost_factorized_Xt=(B1, B2), seed_init=49, numItermax=1000, log=True, stopThr=1e-6,
142+
    )
143+
    end = time.time()
144+
145+
    P = log["lazy_plan"][:]
146+
    loss = log["value"]
147+
148+
    list_P_GW.append(P)
149+
    list_loss_GW.append(np.round(loss, 3))
150+
    list_time_GW.append(end - start)
151+
152+
153+
# %%
154+
# Plot low rank GW with different ranks
155+
pl.figure(3, figsize=(10, 4))
156+
157+
pl.subplot(1, 2, 1)
158+
pl.imshow(list_P_GW[0], interpolation="nearest", aspect="auto")
159+
pl.title('Low rank GW (rank=10, loss={})'.format(list_loss_GW[0]))
160+
161+
pl.subplot(1, 2, 2)
162+
pl.imshow(list_P_GW[1], interpolation="nearest", aspect="auto")
163+
pl.title('Low rank GW (rank=50, loss={})'.format(list_loss_GW[1]))
164+
165+
pl.tight_layout()
166+
pl.show()
167+
168+
169+
# %%
170+
# Compare computation time between entropic GW and low rank GW
171+
print("Entropic GW: {:.2f}s".format(time_entropic))
172+
print("Low rank GW (rank=10): {:.2f}s".format(list_time_GW[0]))
173+
print("Low rank GW (rank=50): {:.2f}s".format(list_time_GW[1]))
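
Note on the cost factorization used in `plot_lowrank_GW.py`: the example calls `ot.lowrank.compute_lr_sqeuclidean_matrix` and then divides each factor by `np.sqrt(r1)` (resp. `np.sqrt(r2)`). The NumPy-only sketch below illustrates why a squared Euclidean cost matrix admits an exact factorization of rank at most `d + 2`, and why scaling both factors by `1 / sqrt(r)` matches the scaling `C / r` applied to the dense cost. The helper name `lr_sqeuclidean_factors` is hypothetical and written only for illustration. It is not POT's implementation, and the column ordering POT uses may differ.

```python
import numpy as np


def lr_sqeuclidean_factors(X, Y):
    """Hypothetical helper: return A1 (n, d+2) and A2 (m, d+2) such that
    A1 @ A2.T is the matrix of squared Euclidean distances between the
    rows of X and the rows of Y.

    Uses ||x - y||^2 = ||x||^2 + ||y||^2 - 2 <x, y>.
    """
    n, m = X.shape[0], Y.shape[0]
    sq_X = np.sum(X ** 2, axis=1, keepdims=True)  # (n, 1) squared norms
    sq_Y = np.sum(Y ** 2, axis=1, keepdims=True)  # (m, 1) squared norms
    A1 = np.concatenate([sq_X, np.ones((n, 1)), -2 * X], axis=1)
    A2 = np.concatenate([np.ones((m, 1)), sq_Y, Y], axis=1)
    return A1, A2


rng = np.random.default_rng(0)
X = rng.normal(size=(6, 2))

# Dense reference cost, computed by brute force
C = ((X[:, None, :] - X[None, :, :]) ** 2).sum(axis=-1)

A1, A2 = lr_sqeuclidean_factors(X, X)

# Scaling each factor by 1 / sqrt(r) scales the product by 1 / r
r = C.max()
A1_scaled, A2_scaled = A1 / np.sqrt(r), A2 / np.sqrt(r)
```

With these factors, `A1 @ A2.T` reproduces `C` and `A1_scaled @ A2_scaled.T` reproduces `C / r`, which is the consistency the example depends on when it passes `cost_factorized_Xs` and `cost_factorized_Xt` with `rescale_cost=False`.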

ot/__init__.py

Lines changed: 4 additions & 3 deletions
@@ -49,7 +49,8 @@
4949
from .sliced import (sliced_wasserstein_distance, max_sliced_wasserstein_distance,
5050
sliced_wasserstein_sphere, sliced_wasserstein_sphere_unif)
5151
from .gromov import (gromov_wasserstein, gromov_wasserstein2,
52-
gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2)
52+
gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2,
53+
lowrank_gromov_wasserstein_samples)
5354
from .weak import weak_optimal_transport
5455
from .factored import factored_optimal_transport
5556
from .solvers import solve, solve_gromov, solve_sample
@@ -71,5 +72,5 @@
7172
'factored_optimal_transport', 'solve', 'solve_gromov', 'solve_sample',
7273
'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath', 'solvers',
7374
'binary_search_circle', 'wasserstein_circle',
74-
'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif',
75-
'lowrank_sinkhorn']
75+
'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif', 'lowrank_sinkhorn',
76+
'lowrank_gromov_wasserstein_samples']

ot/gromov/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -47,6 +47,8 @@
4747
fused_gromov_wasserstein_dictionary_learning,
4848
fused_gromov_wasserstein_linear_unmixing)
4949

50+
from ._lowrank import (_flat_product_operator, lowrank_gromov_wasserstein_samples)
51+
5052

5153
__all__ = ['init_matrix', 'tensor_product', 'gwloss', 'gwggrad', 'update_square_loss',
5254
'update_kl_loss', 'update_feature_matrix', 'init_matrix_semirelaxed',
@@ -64,4 +66,4 @@
6466
'entropic_semirelaxed_gromov_wasserstein2', 'entropic_semirelaxed_fused_gromov_wasserstein',
6567
'entropic_semirelaxed_fused_gromov_wasserstein2', 'gromov_wasserstein_dictionary_learning',
6668
'gromov_wasserstein_linear_unmixing', 'fused_gromov_wasserstein_dictionary_learning',
67-
'fused_gromov_wasserstein_linear_unmixing']
69+
'fused_gromov_wasserstein_linear_unmixing', 'lowrank_gromov_wasserstein_samples']
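
Note on the returned factors: `lowrank_gromov_wasserstein_samples` returns `Q`, `R` and `g`, and the example script materializes the plan with `log["lazy_plan"][:]`. In the low rank formulation of [65] and [67], the coupling is written as `P = Q diag(1/g) R^T`. The sketch below assumes that this is the convention behind `lazy_plan`, which should be checked against the POT documentation. It uses a deliberately trivial feasible pair of factors, and the function names are hypothetical.

```python
import numpy as np


def dense_plan(Q, R, g):
    """Materialize P = Q diag(1/g) R^T as an (n, m) array."""
    # Q / g divides column j of Q by g[j] through broadcasting
    return (Q / g) @ R.T


def plan_matvec(Q, R, g, v):
    """Compute P @ v in O((n + m) * r) operations without forming P."""
    return Q @ ((R.T @ v) / g)


n, m = 5, 4
a = np.full(n, 1 / n)          # source weights
b = np.full(m, 1 / m)          # target weights
g = np.array([0.5, 0.3, 0.2])  # inner weights, rank r = 3

# Trivially feasible factors: Q has marginals (a, g), R has marginals (b, g)
Q = np.outer(a, g)
R = np.outer(b, g)

P = dense_plan(Q, R, g)
v = np.arange(m, dtype=float)
Pv = plan_matvec(Q, R, g, v)
```

Here `P` has marginals `a` and `b`, and `plan_matvec` agrees with `P @ v`. Slicing the lazy plan with `[:]`, as the example does for plotting, costs O(n m) memory, so for large sample sizes it is preferable to work with the factors directly.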
