Skip to content

Implementation of Low Rank Gromov-Wasserstein #614

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 59 commits into from
May 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
59 commits
Select commit Hold shift + click to select a range
f49f6b4
new file for lr sinkhorn
laudavid Oct 24, 2023
3c4b50f
lr sinkhorn, solve_sample, OTResultLazy
laudavid Oct 24, 2023
3034e57
add test functions + small modif lr_sin/solve_sample
laudavid Oct 25, 2023
085863a
add import to __init__
laudavid Oct 26, 2023
9becafc
modify low rank, remove solve_sample,OTResultLazy
laudavid Nov 3, 2023
855234d
pull from master
laudavid Nov 3, 2023
6ea251c
new file for lr sinkhorn
laudavid Oct 24, 2023
965e4d6
lr sinkhorn, solve_sample, OTResultLazy
laudavid Oct 24, 2023
fd5e26d
add test functions + small modif lr_sin/solve_sample
laudavid Oct 25, 2023
3df3b77
add import to __init__
laudavid Oct 26, 2023
fae28f7
Merge branch 'master' of https://github.com/hi-paris/POT into lowrank_v2
laudavid Nov 3, 2023
ab5475b
remove test solve_sample
laudavid Nov 3, 2023
f1c8cdd
add value, value_linear, lazy_plan
laudavid Nov 8, 2023
b1a2136
Merge branch 'PythonOT:master' into master
laudavid Nov 8, 2023
9e51a83
Merge branch 'master' of https://github.com/hi-paris/POT into lowrank_v2
laudavid Nov 8, 2023
df01cff
add comments to lr algorithm
laudavid Nov 8, 2023
a0b0a9d
Merge branch 'PythonOT:master' into master
laudavid Nov 9, 2023
7075c8b
Merge branch 'master' of https://github.com/hi-paris/POT into lowrank_v2
laudavid Nov 9, 2023
5f2af0e
Merge branch 'PythonOT:master' into lowrank_v2
laudavid Nov 9, 2023
5bc9de9
modify test functions + add comments to lowrank
laudavid Nov 9, 2023
c66951b
Merge branch 'lowrank_v2' of https://github.com/hi-paris/POT into low…
laudavid Nov 9, 2023
6040e6f
modify __init__ with lowrank
laudavid Nov 9, 2023
a7fdffd
debug lowrank + test
laudavid Nov 14, 2023
d90c186
debug test function low_rank
laudavid Nov 14, 2023
ea3a3e0
error test
laudavid Nov 14, 2023
f6a36bf
Merge branch 'PythonOT:master' into master
laudavid Nov 15, 2023
fe067fd
Merge branch 'master' of https://github.com/hi-paris/POT into lowrank_v2
laudavid Nov 15, 2023
5d3ed32
Merge branch 'PythonOT:master' into lowrank_v2
laudavid Nov 15, 2023
3e6b9aa
Merge branch 'lowrank_v2' of https://github.com/hi-paris/POT into low…
laudavid Nov 15, 2023
165e8f5
final debug of lowrank + add new test functions
laudavid Nov 15, 2023
de54bb9
branch up to date with master
laudavid Nov 17, 2023
bdcfaf6
Merge branch 'master' of https://github.com/hi-paris/POT into lowrank_v2
laudavid Nov 22, 2023
d0d9f46
Merge branch 'master' of https://github.com/hi-paris/POT into lowrank_v2
laudavid Nov 24, 2023
ec96836
Merge branch 'PythonOT:master' into lowrank_v2
laudavid Nov 24, 2023
8c6ac67
Debug tests + add lowrank to solve_sample
laudavid Nov 24, 2023
fafc5f6
Merge branch 'lowrank_v2' of https://github.com/hi-paris/POT into low…
laudavid Nov 24, 2023
bc7af6b
fix torch backend for lowrank
laudavid Nov 25, 2023
b40705c
fix jax backend and skip tf
laudavid Nov 28, 2023
55c8d2b
fix pep 8 tests
laudavid Nov 28, 2023
218c54a
merge master + doc for lowrank
laudavid Dec 5, 2023
e25c91d
add lowrank init + test functions
laudavid Dec 6, 2023
0436e95
Add init strategies in lowrank + example (PythonOT#588)
laudavid Dec 18, 2023
3649633
modified lowrank
laudavid Dec 20, 2023
78091a0
merge with upstream POT
laudavid Dec 20, 2023
5c4c4a8
changes from code review
laudavid Dec 20, 2023
477deed
fix error test pep8
laudavid Dec 20, 2023
077184f
fix linux-minimal-deps + code review
laudavid Dec 21, 2023
7d3071f
Implementation of LR GW + add method in __init__
laudavid Mar 8, 2024
bd06dc4
add LR gw paper in README.md
laudavid Mar 8, 2024
99b38d5
add tests for low rank GW
laudavid Mar 8, 2024
cbc4d1f
add examples for Low Rank GW
laudavid Mar 8, 2024
a0ddb7e
merge latest version of POT
laudavid Mar 8, 2024
269ed30
fix __init__
laudavid Mar 8, 2024
440f3a2
change atol of lr backends
laudavid Mar 8, 2024
ab770ce
fix pep8 errors
laudavid Mar 8, 2024
df28231
pull POT master + new gromov/_lowrank.py
laudavid Apr 17, 2024
96bad89
modif for code review
laudavid Apr 26, 2024
a960b28
pull upstream master
laudavid Apr 26, 2024
cd3a0af
Merge branch 'master' of https://github.com/PythonOT/POT into lowrank…
laudavid May 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CONTRIBUTORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ The contributors to this library are:
* [Ronak Mehta](https://ronakrm.github.io) (Efficient Discrete Multi Marginal Optimal Transport Regularization)
* [Xizheng Yu](https://github.com/x12hengyu) (Efficient Discrete Multi Marginal Optimal Transport Regularization)
* [Sonia Mazelet](https://github.com/SoniaMaz8) (Template based GNN layers)
* [Laurène David](https://github.com/laudavid) (Low rank sinkhorn)
* [Laurène David](https://github.com/laudavid) (Low rank sinkhorn, Low rank Gromov-Wasserstein samples)

## Acknowledgments

Expand Down
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -357,3 +357,5 @@ distances between Gaussian distributions](https://hal.science/hal-03197398v2/fil
[65] Scetbon, M., Cuturi, M., & Peyré, G. (2021). [Low-Rank Sinkhorn Factorization](https://arxiv.org/pdf/2103.04737.pdf).

[66] Pooladian, Aram-Alexandre, and Jonathan Niles-Weed. [Entropic estimation of optimal transport maps](https://arxiv.org/pdf/2109.12004.pdf). arXiv preprint arXiv:2109.12004 (2021).

[67] Scetbon, M., Peyré, G. & Cuturi, M. (2022). [Linear-Time GromovWasserstein Distances using Low Rank Couplings and Costs](https://proceedings.mlr.press/v162/scetbon22b/scetbon22b.pdf). In International Conference on Machine Learning (ICML), 2022.
1 change: 1 addition & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
+ Continuous entropic mapping (PR #613)
+ New general unbalanced solvers for `ot.solve` and BFGS solver and illustrative example (PR #620)
+ Add gradient computation with envelope theorem to sinkhorn solver of `ot.solve` with `grad='envelope'` (PR #605).
+ Added support for [Low rank Gromov-Wasserstein](https://proceedings.mlr.press/v162/scetbon22b/scetbon22b.pdf) with `ot.gromov.lowrank_gromov_wasserstein_samples` (PR #614)

#### Closed issues
- Fix gpu compatibility of sr(F)GW solvers when `G0 is not None`(PR #596)
Expand Down
173 changes: 173 additions & 0 deletions examples/others/plot_lowrank_GW.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
# -*- coding: utf-8 -*-
"""
========================================
Low rank Gromov-Wasterstein between samples
========================================

Comparaison between entropic Gromov-Wasserstein and Low Rank Gromov Wasserstein [67]
on two curves in 2D and 3D, both sampled with 200 points.

The squared Euclidean distance is considered as the ground cost for both samples.

[67] Scetbon, M., Peyré, G. & Cuturi, M. (2022).
"Linear-Time GromovWasserstein Distances using Low Rank Couplings and Costs".
In International Conference on Machine Learning (ICML), 2022.
"""

# Author: Laurène David <laurene.david@ip-paris.fr>
#
# License: MIT License
#
# sphinx_gallery_thumbnail_number = 3

#%%
import numpy as np
import matplotlib.pylab as pl
import ot.plot
import time

##############################################################################
# Generate data
# -------------

#%% parameters
n_samples = 200

# Generate 2D and 3D curves
theta = np.linspace(-4 * np.pi, 4 * np.pi, n_samples)
z = np.linspace(1, 2, n_samples)
r = z**2 + 1
x = r * np.sin(theta)
y = r * np.cos(theta)

# Source and target distribution
X = np.concatenate([x.reshape(-1, 1), z.reshape(-1, 1)], axis=1)
Y = np.concatenate([x.reshape(-1, 1), y.reshape(-1, 1), z.reshape(-1, 1)], axis=1)


##############################################################################
# Plot data
# ------------

#%%
# Plot the source and target samples
fig = pl.figure(1, figsize=(10, 4))

ax = fig.add_subplot(121)
ax.plot(X[:, 0], X[:, 1], color="blue", linewidth=6)
ax.tick_params(left=False, right=False, labelleft=False,
labelbottom=False, bottom=False)
ax.set_title("2D curve (source)")

ax2 = fig.add_subplot(122, projection="3d")
ax2.plot(Y[:, 0], Y[:, 1], Y[:, 2], c='red', linewidth=6)
ax2.tick_params(left=False, right=False, labelleft=False,
labelbottom=False, bottom=False)
ax2.view_init(15, -50)
ax2.set_title("3D curve (target)")

pl.tight_layout()
pl.show()


##############################################################################
# Entropic Gromov-Wasserstein
# ------------

#%%

# Compute cost matrices
C1 = ot.dist(X, X, metric="sqeuclidean")
C2 = ot.dist(Y, Y, metric="sqeuclidean")

# Scale cost matrices
r1 = C1.max()
r2 = C2.max()

C1 = C1 / r1
C2 = C2 / r2


# Solve entropic gw
reg = 5 * 1e-3

start = time.time()
gw, log = ot.gromov.entropic_gromov_wasserstein(
C1, C2, tol=1e-3, epsilon=reg,
log=True, verbose=False)

end = time.time()
time_entropic = end - start

entropic_gw_loss = np.round(log['gw_dist'], 3)

# Plot entropic gw
pl.figure(2)
pl.imshow(gw, interpolation="nearest", aspect="auto")
pl.title("Entropic Gromov-Wasserstein (loss={})".format(entropic_gw_loss))
pl.show()


##############################################################################
# Low rank squared euclidean cost matrices
# ------------
# %%

# Compute the low rank sqeuclidean cost decompositions
A1, A2 = ot.lowrank.compute_lr_sqeuclidean_matrix(X, X, rescale_cost=False)
B1, B2 = ot.lowrank.compute_lr_sqeuclidean_matrix(Y, Y, rescale_cost=False)

# Scale the low rank cost matrices
A1, A2 = A1 / np.sqrt(r1), A2 / np.sqrt(r1)
B1, B2 = B1 / np.sqrt(r2), B2 / np.sqrt(r2)


##############################################################################
# Low rank Gromov-Wasserstein
# ------------
# %%

# Solve low rank gromov-wasserstein with different ranks
list_rank = [10, 50]
list_P_GW = []
list_loss_GW = []
list_time_GW = []

for rank in list_rank:
start = time.time()

Q, R, g, log = ot.lowrank_gromov_wasserstein_samples(
X, Y, reg=0, rank=rank, rescale_cost=False, cost_factorized_Xs=(A1, A2),
cost_factorized_Xt=(B1, B2), seed_init=49, numItermax=1000, log=True, stopThr=1e-6,
)
end = time.time()

P = log["lazy_plan"][:]
loss = log["value"]

list_P_GW.append(P)
list_loss_GW.append(np.round(loss, 3))
list_time_GW.append(end - start)


# %%
# Plot low rank GW with different ranks
pl.figure(3, figsize=(10, 4))

pl.subplot(1, 2, 1)
pl.imshow(list_P_GW[0], interpolation="nearest", aspect="auto")
pl.title('Low rank GW (rank=10, loss={})'.format(list_loss_GW[0]))

pl.subplot(1, 2, 2)
pl.imshow(list_P_GW[1], interpolation="nearest", aspect="auto")
pl.title('Low rank GW (rank=50, loss={})'.format(list_loss_GW[1]))

pl.tight_layout()
pl.show()


# %%
# Compare computation time between entropic GW and low rank GW
print("Entropic GW: {:.2f}s".format(time_entropic))
print("Low rank GW (rank=10): {:.2f}s".format(list_time_GW[0]))
print("Low rank GW (rank=50): {:.2f}s".format(list_time_GW[1]))
7 changes: 4 additions & 3 deletions ot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@
from .sliced import (sliced_wasserstein_distance, max_sliced_wasserstein_distance,
sliced_wasserstein_sphere, sliced_wasserstein_sphere_unif)
from .gromov import (gromov_wasserstein, gromov_wasserstein2,
gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2)
gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2,
lowrank_gromov_wasserstein_samples)
from .weak import weak_optimal_transport
from .factored import factored_optimal_transport
from .solvers import solve, solve_gromov, solve_sample
Expand All @@ -71,5 +72,5 @@
'factored_optimal_transport', 'solve', 'solve_gromov', 'solve_sample',
'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath', 'solvers',
'binary_search_circle', 'wasserstein_circle',
'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif',
'lowrank_sinkhorn']
'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif', 'lowrank_sinkhorn',
'lowrank_gromov_wasserstein_samples']
4 changes: 3 additions & 1 deletion ot/gromov/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@
fused_gromov_wasserstein_dictionary_learning,
fused_gromov_wasserstein_linear_unmixing)

from ._lowrank import (_flat_product_operator, lowrank_gromov_wasserstein_samples)


__all__ = ['init_matrix', 'tensor_product', 'gwloss', 'gwggrad', 'update_square_loss',
'update_kl_loss', 'update_feature_matrix', 'init_matrix_semirelaxed',
Expand All @@ -64,4 +66,4 @@
'entropic_semirelaxed_gromov_wasserstein2', 'entropic_semirelaxed_fused_gromov_wasserstein',
'entropic_semirelaxed_fused_gromov_wasserstein2', 'gromov_wasserstein_dictionary_learning',
'gromov_wasserstein_linear_unmixing', 'fused_gromov_wasserstein_dictionary_learning',
'fused_gromov_wasserstein_linear_unmixing']
'fused_gromov_wasserstein_linear_unmixing', 'lowrank_gromov_wasserstein_samples']
Loading
Loading