
Commit 218c54a

merge master + doc for lowrank
2 parents 55c8d2b + 659cde8


10 files changed: 1387 additions & 84 deletions


README.md

Lines changed: 5 additions & 1 deletion
@@ -344,4 +344,8 @@ distances between Gaussian distributions](https://hal.science/hal-03197398v2/fil
 
 [62] H. Van Assel, C. Vincent-Cuaz, T. Vayer, R. Flamary, N. Courty (2023). [Interpolating between Clustering and Dimensionality Reduction with Gromov-Wasserstein](https://arxiv.org/pdf/2310.03398.pdf). NeurIPS 2023 Workshop Optimal Transport and Machine Learning.
 
-[63] Scetbon, M., Cuturi, M., & Peyré, G. (2021). [Low-Rank Sinkhorn Factorization](https://arxiv.org/pdf/2103.04737.pdf).
+[63] Li, J., Tang, J., Kong, L., Liu, H., Li, J., So, A. M. C., & Blanchet, J. (2022). [A Convergent Single-Loop Algorithm for Relaxation of Gromov-Wasserstein in Graph Data](https://openreview.net/pdf?id=0jxPyVWmiiF). In The Eleventh International Conference on Learning Representations.
+
+[64] Ma, X., Chu, X., Wang, Y., Lin, Y., Zhao, J., Ma, L., & Zhu, W. (2023). [Fused Gromov-Wasserstein Graph Mixup for Graph-level Classifications](https://openreview.net/pdf?id=uqkUguNu40). In Thirty-seventh Conference on Neural Information Processing Systems.
+
+[65] Scetbon, M., Cuturi, M., & Peyré, G. (2021). [Low-Rank Sinkhorn Factorization](https://arxiv.org/pdf/2103.04737.pdf).
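
For context on the renumbered [65] entry (the commit message mentions documentation for the low-rank solver), Scetbon et al. restrict the transport plan to a low-rank factorization. A sketch of the constrained problem in LaTeX, paraphrased from the paper rather than from the POT documentation:

\min_{Q,\,R,\,g} \ \langle C,\ Q\,\mathrm{diag}(1/g)\,R^\top \rangle
\quad \text{s.t.} \quad Q\mathbf{1}_r = a,\ \ R\mathbf{1}_r = b,\ \ Q^\top\mathbf{1}_n = R^\top\mathbf{1}_m = g \in \Delta_r,

so the coupling $P = Q\,\mathrm{diag}(1/g)\,R^\top$ has rank at most $r$, and [65] solves the problem with mirror-descent, Sinkhorn-like updates on $(Q, R, g)$.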

RELEASES.md

Lines changed: 3 additions & 0 deletions
@@ -20,8 +20,11 @@
 + Wrapper for `geomloss` solver on empirical samples (PR #571)
 + Add `stop_criterion` feature to (un)regularized (f)gw barycenter solvers (PR #578)
 + Add `fixed_structure` and `fixed_features` to entropic fgw barycenter solver (PR #578)
++ Add new BAPG solvers with KL projections for GW and FGW (PR #581)
++ Add Bures-Wasserstein barycenter in `ot.gaussian` and example (PR #582, PR #584)
 + Added support for [Low-Rank Sinkhorn Factorization](https://arxiv.org/pdf/2103.04737.pdf) (PR #568)
 
+
 #### Closed issues
 - Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504)
 - Lazily instantiate backends to avoid unnecessary GPU memory pre-allocations on package import (Issue #516, PR #520)
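
As a quick illustration of the new `ot.gaussian` entry above (PR #582, PR #584), here is a minimal sketch of the barycenter call, using only the signature that appears in the example added by this commit (stacked means, stacked covariances, barycentric weights); no other arguments are assumed:

# Minimal sketch of ot.gaussian.bures_wasserstein_barycenter (signature as used
# in the example below; assumed to return the barycenter mean and covariance).
import numpy as np
import ot

m = np.stack((np.array([0., 0.]), np.array([4., 4.])))       # means, shape (k, d)
C = np.stack((np.eye(2), np.array([[1., 0.3], [0.3, 1.]])))  # covariances, shape (k, d, d)
weights = np.array([0.5, 0.5])                               # barycentric weights summing to 1

mb, Cb = ot.gaussian.bures_wasserstein_barycenter(m, C, weights)
print(mb)  # barycenter mean
print(Cb)  # barycenter covariance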
Lines changed: 127 additions & 0 deletions
@@ -0,0 +1,127 @@
# -*- coding: utf-8 -*-
"""
========================================================
Gaussian Bures-Wasserstein barycenters
========================================================

Illustration of Gaussian Bures-Wasserstein barycenters.

"""

# Authors: Rémi Flamary <remi.flamary@polytechnique.edu>
#
# License: MIT License

# sphinx_gallery_thumbnail_number = 2
# %%
from matplotlib import colors
from matplotlib.patches import Ellipse
import numpy as np
import matplotlib.pylab as pl
import ot


# %%
# Define Gaussian Covariances and distributions
# ---------------------------------------------

C1 = np.array([[0.5, -0.4], [-0.4, 0.5]])
C2 = np.array([[1, 0.3], [0.3, 1]])
C3 = np.array([[1.5, 0], [0, 0.5]])
C4 = np.array([[0.5, 0], [0, 1.5]])

C = np.stack((C1, C2, C3, C4))

m1 = np.array([0, 0])
m2 = np.array([0, 4])
m3 = np.array([4, 0])
m4 = np.array([4, 4])

m = np.stack((m1, m2, m3, m4))

# %%
# Plot the distributions
# ----------------------


def draw_cov(mu, C, color=None, label=None, nstd=1):

    def eigsorted(cov):
        vals, vecs = np.linalg.eigh(cov)
        order = vals.argsort()[::-1]
        return vals[order], vecs[:, order]

    vals, vecs = eigsorted(C)
    theta = np.degrees(np.arctan2(*vecs[:, 0][::-1]))
    w, h = 2 * nstd * np.sqrt(vals)
    ell = Ellipse(xy=(mu[0], mu[1]),
                  width=w, height=h, alpha=0.5,
                  angle=theta, facecolor=color, edgecolor=color, label=label, fill=True)
    pl.gca().add_artist(ell)
    # pl.scatter(mu[0], mu[1], color=color, marker='x')


axis = [-1.5, 5.5, -1.5, 5.5]

pl.figure(1, (8, 2))
pl.clf()

pl.subplot(1, 4, 1)
draw_cov(m1, C1, color='C0')
pl.axis(axis)
pl.title(r'$\mathcal{N}(m_1,\Sigma_1)$')

pl.subplot(1, 4, 2)
draw_cov(m2, C2, color='C1')
pl.axis(axis)
pl.title(r'$\mathcal{N}(m_2,\Sigma_2)$')

pl.subplot(1, 4, 3)
draw_cov(m3, C3, color='C2')
pl.axis(axis)
pl.title(r'$\mathcal{N}(m_3,\Sigma_3)$')

pl.subplot(1, 4, 4)
draw_cov(m4, C4, color='C3')
pl.axis(axis)
pl.title(r'$\mathcal{N}(m_4,\Sigma_4)$')

# %%
# Compute Bures-Wasserstein barycenters and plot them
# ---------------------------------------------------

# basis for bilinear interpolation
v1 = np.array((1, 0, 0, 0))
v2 = np.array((0, 1, 0, 0))
v3 = np.array((0, 0, 1, 0))
v4 = np.array((0, 0, 0, 1))


colors = np.stack((colors.to_rgb('C0'),
                   colors.to_rgb('C1'),
                   colors.to_rgb('C2'),
                   colors.to_rgb('C3')))

pl.figure(2, (8, 8))

nb_interp = 6

for i in range(nb_interp):
    for j in range(nb_interp):
        tx = float(i) / (nb_interp - 1)
        ty = float(j) / (nb_interp - 1)

        # weights are constructed by bilinear interpolation
        tmp1 = (1 - tx) * v1 + tx * v2
        tmp2 = (1 - tx) * v3 + tx * v4
        weights = (1 - ty) * tmp1 + ty * tmp2

        color = np.dot(colors.T, weights)

        mb, Cb = ot.gaussian.bures_wasserstein_barycenter(m, C, weights)

        draw_cov(mb, Cb, color=color, label=None, nstd=0.3)

pl.axis(axis)
pl.axis('off')
pl.tight_layout()
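
For reference, the barycenter plotted above has a closed characterization in the Gaussian case. A sketch in LaTeX of the classical fixed-point form (stated here for context, not quoted from the POT documentation): with weights $w_k$ summing to one,

m_b = \sum_k w_k\, m_k, \qquad
\Sigma_b = \sum_k w_k \left( \Sigma_b^{1/2}\, \Sigma_k\, \Sigma_b^{1/2} \right)^{1/2},

where $\Sigma_b$ is the positive semi-definite solution of the fixed-point equation; solvers typically iterate this map until convergence.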

examples/gromov/plot_fgw_solvers.py

Lines changed: 89 additions & 44 deletions
@@ -5,8 +5,9 @@
 ==============================
 
 This example illustrates the computation of FGW for attributed graphs
-using 3 different solvers to estimate the distance based on Conditional
-Gradient [24] or Sinkhorn projections [12, 51].
+using 4 different solvers to estimate the distance based on Conditional
+Gradient [24], Sinkhorn projections [12, 51] and alternated Bregman
+projections [63, 64].
 
 We generate two graphs following Stochastic Block Models further endowed with
 node features and compute their FGW matchings.
@@ -23,6 +24,16 @@
 [51] Xu, H., Luo, D., Zha, H., & Duke, L. C. (2019).
 "Gromov-wasserstein learning for graph matching and node embedding".
 In International Conference on Machine Learning (ICML), 2019.
+
+[63] Li, J., Tang, J., Kong, L., Liu, H., Li, J., So, A. M. C., & Blanchet, J.
+"A Convergent Single-Loop Algorithm for Relaxation of Gromov-Wasserstein in
+Graph Data". International Conference on Learning Representations (ICLR), 2023.
+
+[64] Ma, X., Chu, X., Wang, Y., Lin, Y., Zhao, J., Ma, L., & Zhu, W.
+"Fused Gromov-Wasserstein Graph Mixup for Graph-level Classifications".
+In Thirty-seventh Conference on Neural Information Processing Systems
+(NeurIPS), 2023.
+
 """
 
 # Author: Cédric Vincent-Cuaz <cedvincentcuaz@gmail.com>
@@ -33,9 +44,12 @@
 
 import numpy as np
 import matplotlib.pylab as pl
-from ot.gromov import fused_gromov_wasserstein, entropic_fused_gromov_wasserstein
+from ot.gromov import (fused_gromov_wasserstein,
+                       entropic_fused_gromov_wasserstein,
+                       BAPG_fused_gromov_wasserstein)
 import networkx
 from networkx.generators.community import stochastic_block_model as sbm
+from time import time
 
 #############################################################################
 #
@@ -85,34 +99,59 @@
 
 
 # Conditional Gradient algorithm
-fgw0, log0 = fused_gromov_wasserstein(
-    M, C2, C3, h2, h3, 'square_loss', alpha=alpha, verbose=True, log=True)
+print('Conditional Gradient \n')
+start_cg = time()
+T_cg, log_cg = fused_gromov_wasserstein(
+    M, C2, C3, h2, h3, 'square_loss', alpha=alpha, tol_rel=1e-9,
+    verbose=True, log=True)
+end_cg = time()
+time_cg = 1000 * (end_cg - start_cg)
 
 # Proximal Point algorithm with Kullback-Leibler as proximal operator
-fgw, log = entropic_fused_gromov_wasserstein(
-    M, C2, C3, h2, h3, 'square_loss', alpha=alpha, epsilon=1., solver='PPA',
-    log=True, verbose=True, warmstart=False, numItermax=10)
+print('Proximal Point Algorithm \n')
+start_ppa = time()
+T_ppa, log_ppa = entropic_fused_gromov_wasserstein(
+    M, C2, C3, h2, h3, 'square_loss', alpha=alpha, epsilon=1., solver='PPA',
+    tol=1e-9, log=True, verbose=True, warmstart=False, numItermax=10)
+end_ppa = time()
+time_ppa = 1000 * (end_ppa - start_ppa)
 
 # Projected Gradient algorithm with entropic regularization
-fgwe, loge = entropic_fused_gromov_wasserstein(
-    M, C2, C3, h2, h3, 'square_loss', alpha=alpha, epsilon=0.01, solver='PGD',
-    log=True, verbose=True, warmstart=False, numItermax=10)
-
-print('Fused Gromov-Wasserstein distance estimated with Conditional Gradient solver: ' + str(log0['fgw_dist']))
-print('Fused Gromov-Wasserstein distance estimated with Proximal Point solver: ' + str(log['fgw_dist']))
-print('Entropic Fused Gromov-Wasserstein distance estimated with Projected Gradient solver: ' + str(loge['fgw_dist']))
+print('Projected Gradient Descent \n')
+start_pgd = time()
+T_pgd, log_pgd = entropic_fused_gromov_wasserstein(
+    M, C2, C3, h2, h3, 'square_loss', alpha=alpha, epsilon=0.01, solver='PGD',
+    tol=1e-9, log=True, verbose=True, warmstart=False, numItermax=10)
+end_pgd = time()
+time_pgd = 1000 * (end_pgd - start_pgd)
+
+# Alternated Bregman Projected Gradient algorithm with Kullback-Leibler as proximal operator
+print('Bregman Alternated Projected Gradient \n')
+start_bapg = time()
+T_bapg, log_bapg = BAPG_fused_gromov_wasserstein(
+    M, C2, C3, h2, h3, 'square_loss', alpha=alpha, epsilon=1.,
+    tol=1e-9, marginal_loss=True, verbose=True, log=True)
+end_bapg = time()
+time_bapg = 1000 * (end_bapg - start_bapg)
+
+print('Fused Gromov-Wasserstein distance estimated with Conditional Gradient solver: ' + str(log_cg['fgw_dist']))
+print('Fused Gromov-Wasserstein distance estimated with Proximal Point solver: ' + str(log_ppa['fgw_dist']))
+print('Entropic Fused Gromov-Wasserstein distance estimated with Projected Gradient solver: ' + str(log_pgd['fgw_dist']))
+print('Fused Gromov-Wasserstein distance estimated with Bregman Alternated Projected Gradient solver: ' + str(log_bapg['fgw_dist']))
 
 # compute OT sparsity level
-fgw0_sparsity = 100 * (fgw0 == 0.).astype(np.float64).sum() / (N2 * N3)
-fgw_sparsity = 100 * (fgw == 0.).astype(np.float64).sum() / (N2 * N3)
-fgwe_sparsity = 100 * (fgwe == 0.).astype(np.float64).sum() / (N2 * N3)
+T_cg_sparsity = 100 * (T_cg == 0.).astype(np.float64).sum() / (N2 * N3)
+T_ppa_sparsity = 100 * (T_ppa == 0.).astype(np.float64).sum() / (N2 * N3)
+T_pgd_sparsity = 100 * (T_pgd == 0.).astype(np.float64).sum() / (N2 * N3)
+T_bapg_sparsity = 100 * (T_bapg == 0.).astype(np.float64).sum() / (N2 * N3)
 
-# Methods using Sinkhorn projections tend to produce feasibility errors on the
+# Methods using Sinkhorn/Bregman projections tend to produce feasibility errors on the
 # marginal constraints
 
-err0 = np.linalg.norm(fgw0.sum(1) - h2) + np.linalg.norm(fgw0.sum(0) - h3)
-err = np.linalg.norm(fgw.sum(1) - h2) + np.linalg.norm(fgw.sum(0) - h3)
-erre = np.linalg.norm(fgwe.sum(1) - h2) + np.linalg.norm(fgwe.sum(0) - h3)
+err_cg = np.linalg.norm(T_cg.sum(1) - h2) + np.linalg.norm(T_cg.sum(0) - h3)
+err_ppa = np.linalg.norm(T_ppa.sum(1) - h2) + np.linalg.norm(T_ppa.sum(0) - h3)
+err_pgd = np.linalg.norm(T_pgd.sum(1) - h2) + np.linalg.norm(T_pgd.sum(0) - h3)
+err_bapg = np.linalg.norm(T_bapg.sum(1) - h2) + np.linalg.norm(T_bapg.sum(0) - h3)
 
 #############################################################################
 #
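
For readers comparing the four solvers, the diagnostics computed in the hunk above are, restated in LaTeX (a paraphrase of the code, with h2 and h3 the node weight vectors of the two graphs):

\mathrm{sparsity}(T) = 100 \cdot \frac{\#\{(i,j) : T_{ij} = 0\}}{N_2 N_3},
\qquad
\mathrm{err}(T) = \|T\mathbf{1} - h_2\|_2 + \|T^\top\mathbf{1} - h_3\|_2 .

The conditional-gradient plan is assembled from exact linear OT subproblems, so it satisfies the marginals exactly and is typically much sparser, while the Sinkhorn/Bregman-projection plans are dense and, as the comment in the diff notes, tend to satisfy the marginals only approximately.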
@@ -242,46 +281,52 @@ def draw_transp_colored_GW(G1, C1, G2, C2, part_G1, p1, p2, T,
 seed_G2 = 0
 seed_G3 = 4
 
-pl.figure(2, figsize=(12, 3.5))
+pl.figure(2, figsize=(15, 3.5))
 pl.clf()
-pl.subplot(131)
+pl.subplot(141)
 pl.axis('off')
-pl.axis
-pl.title('(CG algo) FGW=%s \n \n OT sparsity = %s \n feasibility error = %s' % (
-    np.round(log0['fgw_dist'], 3), str(np.round(fgw0_sparsity, 2)) + ' %',
-    np.round(err0, 4)), fontsize=fontsize)
 
-p0, q0 = fgw0.sum(1), fgw0.sum(0)  # check marginals
+pl.title('(CG) FGW=%s\n \n OT sparsity = %s \n marg. error = %s \n runtime = %s' % (
+    np.round(log_cg['fgw_dist'], 3), str(np.round(T_cg_sparsity, 2)) + ' %',
+    np.round(err_cg, 4), str(np.round(time_cg, 2)) + ' ms'), fontsize=fontsize)
 
 pos1, pos2 = draw_transp_colored_GW(
-    weightedG2, C2, weightedG3, C3, part_G2, p1=p0, p2=q0, T=fgw0,
-    shiftx=1.5, node_size=node_size, seed_G1=seed_G2, seed_G2=seed_G3)
+    weightedG2, C2, weightedG3, C3, part_G2, p1=T_cg.sum(1), p2=T_cg.sum(0),
+    T=T_cg, shiftx=1.5, node_size=node_size, seed_G1=seed_G2, seed_G2=seed_G3)
 
-pl.subplot(132)
+pl.subplot(142)
 pl.axis('off')
 
-p, q = fgw.sum(1), fgw.sum(0)  # check marginals
-
-pl.title('(PP algo) FGW=%s\n \n OT sparsity = %s \n feasibility error = %s' % (
-    np.round(log['fgw_dist'], 3), str(np.round(fgw_sparsity, 2)) + ' %',
-    np.round(err, 4)), fontsize=fontsize)
+pl.title('(PPA) FGW=%s\n \n OT sparsity = %s \n marg. error = %s \n runtime = %s' % (
+    np.round(log_ppa['fgw_dist'], 3), str(np.round(T_ppa_sparsity, 2)) + ' %',
+    np.round(err_ppa, 4), str(np.round(time_ppa, 2)) + ' ms'), fontsize=fontsize)
 
 pos1, pos2 = draw_transp_colored_GW(
-    weightedG2, C2, weightedG3, C3, part_G2, p1=p, p2=q, T=fgw,
-    pos1=pos1, pos2=pos2, shiftx=0., node_size=node_size, seed_G1=0, seed_G2=0)
+    weightedG2, C2, weightedG3, C3, part_G2, p1=T_ppa.sum(1), p2=T_ppa.sum(0),
+    T=T_ppa, pos1=pos1, pos2=pos2, shiftx=0., node_size=node_size, seed_G1=0, seed_G2=0)
 
-pl.subplot(133)
+pl.subplot(143)
 pl.axis('off')
 
-pe, qe = fgwe.sum(1), fgwe.sum(0)  # check marginals
+pl.title('(PGD) Entropic FGW=%s\n \n OT sparsity = %s \n marg. error = %s \n runtime = %s' % (
+    np.round(log_pgd['fgw_dist'], 3), str(np.round(T_pgd_sparsity, 2)) + ' %',
+    np.round(err_pgd, 4), str(np.round(time_pgd, 2)) + ' ms'), fontsize=fontsize)
+
+pos1, pos2 = draw_transp_colored_GW(
+    weightedG2, C2, weightedG3, C3, part_G2, p1=T_pgd.sum(1), p2=T_pgd.sum(0),
+    T=T_pgd, pos1=pos1, pos2=pos2, shiftx=0., node_size=node_size, seed_G1=0, seed_G2=0)
+
+
+pl.subplot(144)
+pl.axis('off')
 
-pl.title('Entropic FGW=%s\n \n OT sparsity = %s \n feasibility error = %s' % (
-    np.round(loge['fgw_dist'], 3), str(np.round(fgwe_sparsity, 2)) + ' %',
-    np.round(erre, 4)), fontsize=fontsize)
+pl.title('(BAPG) FGW=%s\n \n OT sparsity = %s \n marg. error = %s \n runtime = %s' % (
+    np.round(log_bapg['fgw_dist'], 3), str(np.round(T_bapg_sparsity, 2)) + ' %',
+    np.round(err_bapg, 4), str(np.round(time_bapg, 2)) + ' ms'), fontsize=fontsize)
 
 pos1, pos2 = draw_transp_colored_GW(
-    weightedG2, C2, weightedG3, C3, part_G2, p1=pe, p2=qe, T=fgwe,
-    pos1=pos1, pos2=pos2, shiftx=0., node_size=node_size, seed_G1=0, seed_G2=0)
+    weightedG2, C2, weightedG3, C3, part_G2, p1=T_bapg.sum(1), p2=T_bapg.sum(0),
+    T=T_bapg, pos1=pos1, pos2=pos2, shiftx=0., node_size=node_size, seed_G1=0, seed_G2=0)
 
 pl.tight_layout()
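
To make the new BAPG entry point easy to try in isolation, here is a minimal, self-contained sketch using only the function and keyword arguments that appear in the diff above; the toy graphs and the `alpha`, `epsilon` and `tol` values are illustrative assumptions, not recommendations:

# Minimal sketch of the new BAPG FGW solver on toy data (arguments mirror the example above).
import numpy as np
from ot.gromov import BAPG_fused_gromov_wasserstein

rng = np.random.RandomState(0)
n1, n2, d = 10, 12, 3

# symmetric toy structure matrices and random node features
C1 = rng.rand(n1, n1)
C1 = (C1 + C1.T) / 2
C2 = rng.rand(n2, n2)
C2 = (C2 + C2.T) / 2
F1, F2 = rng.rand(n1, d), rng.rand(n2, d)

# pairwise feature cost and uniform node weights
M = np.linalg.norm(F1[:, None, :] - F2[None, :, :], axis=2)
h1, h2 = np.ones(n1) / n1, np.ones(n2) / n2

T, log = BAPG_fused_gromov_wasserstein(
    M, C1, C2, h1, h2, 'square_loss', alpha=0.5, epsilon=1.,
    tol=1e-9, marginal_loss=True, verbose=False, log=True)

print('FGW value:', log['fgw_dist'])
print('plan shape:', T.shape)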