Commit 55a851e

[MRG] BAPG solvers for GW and FGW (#581)
* add citation for srgw-kl
* init commit - BAPG for GW and FGW
* add tests
* update example with fgw solvers comparison
* change BAPG names + improve doc
* merge

---------

Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
1 parent f809253 commit 55a851e

6 files changed: +939 -63 lines changed

README.md

Lines changed: 4 additions & 0 deletions
@@ -343,3 +343,7 @@ distances between Gaussian distributions](https://hal.science/hal-03197398v2/fil
 [61] Charlier, B., Feydy, J., Glaunes, J. A., Collin, F. D., & Durif, G. (2021). [Kernel operations on the gpu, with autodiff, without memory overflows](https://www.jmlr.org/papers/volume22/20-275/20-275.pdf). The Journal of Machine Learning Research, 22(1), 3457-3462.

 [62] H. Van Assel, C. Vincent-Cuaz, T. Vayer, R. Flamary, N. Courty (2023). [Interpolating between Clustering and Dimensionality Reduction with Gromov-Wasserstein](https://arxiv.org/pdf/2310.03398.pdf). NeurIPS 2023 Workshop Optimal Transport and Machine Learning.
+
+[63] Li, J., Tang, J., Kong, L., Liu, H., Li, J., So, A. M. C., & Blanchet, J. (2022). [A Convergent Single-Loop Algorithm for Relaxation of Gromov-Wasserstein in Graph Data](https://openreview.net/pdf?id=0jxPyVWmiiF). In The Eleventh International Conference on Learning Representations.
+
+[64] Ma, X., Chu, X., Wang, Y., Lin, Y., Zhao, J., Ma, L., & Zhu, W. (2023). [Fused Gromov-Wasserstein Graph Mixup for Graph-level Classifications](https://openreview.net/pdf?id=uqkUguNu40). In Thirty-seventh Conference on Neural Information Processing Systems.

RELEASES.md

Lines changed: 2 additions & 0 deletions
@@ -20,8 +20,10 @@
 + Wrapper for `geomloss` solver on empirical samples (PR #571)
 + Add `stop_criterion` feature to (un)regularized (f)gw barycenter solvers (PR #578)
 + Add `fixed_structure` and `fixed_features` to entropic fgw barycenter solver (PR #578)
++ Add new BAPG solvers with KL projections for GW and FGW (PR #581)
 + Add Bures-Wasserstein barycenter in `ot.gaussian` (PR #582)

+
 #### Closed issues
 - Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504)
 - Lazily instantiate backends to avoid unnecessary GPU memory pre-allocations on package import (Issue #516, PR #520)

examples/gromov/plot_fgw_solvers.py

Lines changed: 89 additions & 44 deletions
@@ -5,8 +5,9 @@
 ==============================

 This example illustrates the computation of FGW for attributed graphs
-using 3 different solvers to estimate the distance based on Conditional
-Gradient [24] or Sinkhorn projections [12, 51].
+using 4 different solvers to estimate the distance based on Conditional
+Gradient [24], Sinkhorn projections [12, 51] and alternated Bregman
+projections [63, 64].

 We generate two graphs following Stochastic Block Models further endowed with
 node features and compute their FGW matchings.
@@ -23,6 +24,16 @@
 [51] Xu, H., Luo, D., Zha, H., & Duke, L. C. (2019).
 "Gromov-wasserstein learning for graph matching and node embedding".
 In International Conference on Machine Learning (ICML), 2019.
+
+[63] Li, J., Tang, J., Kong, L., Liu, H., Li, J., So, A. M. C., & Blanchet, J.
+"A Convergent Single-Loop Algorithm for Relaxation of Gromov-Wasserstein in
+Graph Data". International Conference on Learning Representations (ICLR), 2023.
+
+[64] Ma, X., Chu, X., Wang, Y., Lin, Y., Zhao, J., Ma, L., & Zhu, W.
+"Fused Gromov-Wasserstein Graph Mixup for Graph-level Classifications".
+In Thirty-seventh Conference on Neural Information Processing Systems
+(NeurIPS), 2023.
+
 """

 # Author: Cédric Vincent-Cuaz <cedvincentcuaz@gmail.com>
@@ -33,9 +44,12 @@

 import numpy as np
 import matplotlib.pylab as pl
-from ot.gromov import fused_gromov_wasserstein, entropic_fused_gromov_wasserstein
+from ot.gromov import (fused_gromov_wasserstein,
+                       entropic_fused_gromov_wasserstein,
+                       BAPG_fused_gromov_wasserstein)
 import networkx
 from networkx.generators.community import stochastic_block_model as sbm
+from time import time

 #############################################################################
 #
@@ -85,34 +99,59 @@


 # Conditional Gradient algorithm
-fgw0, log0 = fused_gromov_wasserstein(
-    M, C2, C3, h2, h3, 'square_loss', alpha=alpha, verbose=True, log=True)
+print('Conditional Gradient \n')
+start_cg = time()
+T_cg, log_cg = fused_gromov_wasserstein(
+    M, C2, C3, h2, h3, 'square_loss', alpha=alpha, tol_rel=1e-9,
+    verbose=True, log=True)
+end_cg = time()
+time_cg = 1000 * (end_cg - start_cg)

 # Proximal Point algorithm with Kullback-Leibler as proximal operator
-fgw, log = entropic_fused_gromov_wasserstein(
+print('Proximal Point Algorithm \n')
+start_ppa = time()
+T_ppa, log_ppa = entropic_fused_gromov_wasserstein(
     M, C2, C3, h2, h3, 'square_loss', alpha=alpha, epsilon=1., solver='PPA',
-    log=True, verbose=True, warmstart=False, numItermax=10)
+    tol=1e-9, log=True, verbose=True, warmstart=False, numItermax=10)
+end_ppa = time()
+time_ppa = 1000 * (end_ppa - start_ppa)

 # Projected Gradient algorithm with entropic regularization
-fgwe, loge = entropic_fused_gromov_wasserstein(
+print('Projected Gradient Descent \n')
+start_pgd = time()
+T_pgd, log_pgd = entropic_fused_gromov_wasserstein(
     M, C2, C3, h2, h3, 'square_loss', alpha=alpha, epsilon=0.01, solver='PGD',
-    log=True, verbose=True, warmstart=False, numItermax=10)
-
-print('Fused Gromov-Wasserstein distance estimated with Conditional Gradient solver: ' + str(log0['fgw_dist']))
-print('Fused Gromov-Wasserstein distance estimated with Proximal Point solver: ' + str(log['fgw_dist']))
-print('Entropic Fused Gromov-Wasserstein distance estimated with Projected Gradient solver: ' + str(loge['fgw_dist']))
+    tol=1e-9, log=True, verbose=True, warmstart=False, numItermax=10)
+end_pgd = time()
+time_pgd = 1000 * (end_pgd - start_pgd)
+
+# Bregman Alternated Projected Gradient algorithm with Kullback-Leibler as proximal operator
+print('Bregman Alternated Projected Gradient \n')
+start_bapg = time()
+T_bapg, log_bapg = BAPG_fused_gromov_wasserstein(
+    M, C2, C3, h2, h3, 'square_loss', alpha=alpha, epsilon=1.,
+    tol=1e-9, marginal_loss=True, verbose=True, log=True)
+end_bapg = time()
+time_bapg = 1000 * (end_bapg - start_bapg)
+
+print('Fused Gromov-Wasserstein distance estimated with Conditional Gradient solver: ' + str(log_cg['fgw_dist']))
+print('Fused Gromov-Wasserstein distance estimated with Proximal Point solver: ' + str(log_ppa['fgw_dist']))
+print('Entropic Fused Gromov-Wasserstein distance estimated with Projected Gradient solver: ' + str(log_pgd['fgw_dist']))
+print('Fused Gromov-Wasserstein distance estimated with BAPG solver: ' + str(log_bapg['fgw_dist']))

 # compute OT sparsity level
-fgw0_sparsity = 100 * (fgw0 == 0.).astype(np.float64).sum() / (N2 * N3)
-fgw_sparsity = 100 * (fgw == 0.).astype(np.float64).sum() / (N2 * N3)
-fgwe_sparsity = 100 * (fgwe == 0.).astype(np.float64).sum() / (N2 * N3)
+T_cg_sparsity = 100 * (T_cg == 0.).astype(np.float64).sum() / (N2 * N3)
+T_ppa_sparsity = 100 * (T_ppa == 0.).astype(np.float64).sum() / (N2 * N3)
+T_pgd_sparsity = 100 * (T_pgd == 0.).astype(np.float64).sum() / (N2 * N3)
+T_bapg_sparsity = 100 * (T_bapg == 0.).astype(np.float64).sum() / (N2 * N3)

-# Methods using Sinkhorn projections tend to produce feasibility errors on the
+# Methods using Sinkhorn/Bregman projections tend to produce feasibility errors on the
 # marginal constraints

-err0 = np.linalg.norm(fgw0.sum(1) - h2) + np.linalg.norm(fgw0.sum(0) - h3)
-err = np.linalg.norm(fgw.sum(1) - h2) + np.linalg.norm(fgw.sum(0) - h3)
-erre = np.linalg.norm(fgwe.sum(1) - h2) + np.linalg.norm(fgwe.sum(0) - h3)
+err_cg = np.linalg.norm(T_cg.sum(1) - h2) + np.linalg.norm(T_cg.sum(0) - h3)
+err_ppa = np.linalg.norm(T_ppa.sum(1) - h2) + np.linalg.norm(T_ppa.sum(0) - h3)
+err_pgd = np.linalg.norm(T_pgd.sum(1) - h2) + np.linalg.norm(T_pgd.sum(0) - h3)
+err_bapg = np.linalg.norm(T_bapg.sum(1) - h2) + np.linalg.norm(T_bapg.sum(0) - h3)

 #############################################################################
 #
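
The hunk above times each solver call and then scores the returned plan by its sparsity level and by the error on its marginals. A dependency-free sketch of those diagnostics follows. `timed_ms` and `plan_diagnostics` are hypothetical helper names that do not exist in POT, plain nested lists stand in for the NumPy arrays of the example, and `time.perf_counter` is swapped in for `time()` as an assumption about what a reusable helper might prefer.

```python
import math
import time


def timed_ms(fn, *args, **kwargs):
    """Call fn and return (result, elapsed wall-clock time in milliseconds)."""
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    return result, 1000.0 * (time.perf_counter() - start)


def plan_diagnostics(T, p, q):
    """Return (sparsity in %, marginal error) for a plan T given as nested lists."""
    n, m = len(T), len(T[0])
    # Share of exactly-zero entries, as in `100 * (T == 0.).sum() / (N2 * N3)`.
    zeros = sum(1 for row in T for v in row if v == 0.0)
    sparsity = 100.0 * zeros / (n * m)
    # Euclidean norms of the row and column marginal residuals, summed.
    row_sums = [sum(row) for row in T]
    col_sums = [sum(T[i][j] for i in range(n)) for j in range(m)]
    err_rows = math.sqrt(sum((r - pi) ** 2 for r, pi in zip(row_sums, p)))
    err_cols = math.sqrt(sum((c - qj) ** 2 for c, qj in zip(col_sums, q)))
    return sparsity, err_rows + err_cols


# Toy 2x2 plan (hypothetical values) with exact uniform marginals and two zeros.
T = [[0.5, 0.0], [0.0, 0.5]]
(sparsity, err), elapsed_ms = timed_ms(plan_diagnostics, T, [0.5, 0.5], [0.5, 0.5])
print(sparsity, err)
```

Wrapping the solver call in a helper like this would avoid repeating the `start_*` / `end_*` bookkeeping for each of the four solvers, at the cost of one extra level of indirection in a gallery example.
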
@@ -242,46 +281,52 @@ def draw_transp_colored_GW(G1, C1, G2, C2, part_G1, p1, p2, T,
 seed_G2 = 0
 seed_G3 = 4

-pl.figure(2, figsize=(12, 3.5))
+pl.figure(2, figsize=(15, 3.5))
 pl.clf()
-pl.subplot(131)
+pl.subplot(141)
 pl.axis('off')
-pl.axis
-pl.title('(CG algo) FGW=%s \n \n OT sparsity = %s \n feasibility error = %s' % (
-    np.round(log0['fgw_dist'], 3), str(np.round(fgw0_sparsity, 2)) + ' %',
-    np.round(err0, 4)), fontsize=fontsize)

-p0, q0 = fgw0.sum(1), fgw0.sum(0)  # check marginals
+pl.title('(CG) FGW=%s\n \n OT sparsity = %s \n marg. error = %s \n runtime = %s' % (
+    np.round(log_cg['fgw_dist'], 3), str(np.round(T_cg_sparsity, 2)) + ' %',
+    np.round(err_cg, 4), str(np.round(time_cg, 2)) + ' ms'), fontsize=fontsize)

 pos1, pos2 = draw_transp_colored_GW(
-    weightedG2, C2, weightedG3, C3, part_G2, p1=p0, p2=q0, T=fgw0,
-    shiftx=1.5, node_size=node_size, seed_G1=seed_G2, seed_G2=seed_G3)
+    weightedG2, C2, weightedG3, C3, part_G2, p1=T_cg.sum(1), p2=T_cg.sum(0),
+    T=T_cg, shiftx=1.5, node_size=node_size, seed_G1=seed_G2, seed_G2=seed_G3)

-pl.subplot(132)
+pl.subplot(142)
 pl.axis('off')

-p, q = fgw.sum(1), fgw.sum(0)  # check marginals
-
-pl.title('(PP algo) FGW=%s\n \n OT sparsity = %s \n feasibility error = %s' % (
-    np.round(log['fgw_dist'], 3), str(np.round(fgw_sparsity, 2)) + ' %',
-    np.round(err, 4)), fontsize=fontsize)
+pl.title('(PPA) FGW=%s\n \n OT sparsity = %s \n marg. error = %s \n runtime = %s' % (
+    np.round(log_ppa['fgw_dist'], 3), str(np.round(T_ppa_sparsity, 2)) + ' %',
+    np.round(err_ppa, 4), str(np.round(time_ppa, 2)) + ' ms'), fontsize=fontsize)

 pos1, pos2 = draw_transp_colored_GW(
-    weightedG2, C2, weightedG3, C3, part_G2, p1=p, p2=q, T=fgw,
-    pos1=pos1, pos2=pos2, shiftx=0., node_size=node_size, seed_G1=0, seed_G2=0)
+    weightedG2, C2, weightedG3, C3, part_G2, p1=T_ppa.sum(1), p2=T_ppa.sum(0),
+    T=T_ppa, pos1=pos1, pos2=pos2, shiftx=0., node_size=node_size, seed_G1=0, seed_G2=0)

-pl.subplot(133)
+pl.subplot(143)
 pl.axis('off')

-pe, qe = fgwe.sum(1), fgwe.sum(0)  # check marginals
+pl.title('(PGD) Entropic FGW=%s\n \n OT sparsity = %s \n marg. error = %s \n runtime = %s' % (
+    np.round(log_pgd['fgw_dist'], 3), str(np.round(T_pgd_sparsity, 2)) + ' %',
+    np.round(err_pgd, 4), str(np.round(time_pgd, 2)) + ' ms'), fontsize=fontsize)
+
+pos1, pos2 = draw_transp_colored_GW(
+    weightedG2, C2, weightedG3, C3, part_G2, p1=T_pgd.sum(1), p2=T_pgd.sum(0),
+    T=T_pgd, pos1=pos1, pos2=pos2, shiftx=0., node_size=node_size, seed_G1=0, seed_G2=0)
+
+
+pl.subplot(144)
+pl.axis('off')

-pl.title('Entropic FGW=%s\n \n OT sparsity = %s \n feasibility error = %s' % (
-    np.round(loge['fgw_dist'], 3), str(np.round(fgwe_sparsity, 2)) + ' %',
-    np.round(erre, 4)), fontsize=fontsize)
+pl.title('(BAPG) FGW=%s\n \n OT sparsity = %s \n marg. error = %s \n runtime = %s' % (
+    np.round(log_bapg['fgw_dist'], 3), str(np.round(T_bapg_sparsity, 2)) + ' %',
+    np.round(err_bapg, 4), str(np.round(time_bapg, 2)) + ' ms'), fontsize=fontsize)

 pos1, pos2 = draw_transp_colored_GW(
-    weightedG2, C2, weightedG3, C3, part_G2, p1=pe, p2=qe, T=fgwe,
-    pos1=pos1, pos2=pos2, shiftx=0., node_size=node_size, seed_G1=0, seed_G2=0)
+    weightedG2, C2, weightedG3, C3, part_G2, p1=T_bapg.sum(1), p2=T_bapg.sum(0),
+    T=T_bapg, pos1=pos1, pos2=pos2, shiftx=0., node_size=node_size, seed_G1=0, seed_G2=0)

 pl.tight_layout()

ot/gromov/__init__.py

Lines changed: 7 additions & 1 deletion
@@ -20,9 +20,13 @@

 from ._bregman import (entropic_gromov_wasserstein,
                        entropic_gromov_wasserstein2,
+                       BAPG_gromov_wasserstein,
+                       BAPG_gromov_wasserstein2,
                        entropic_gromov_barycenters,
                        entropic_fused_gromov_wasserstein,
                        entropic_fused_gromov_wasserstein2,
+                       BAPG_fused_gromov_wasserstein,
+                       BAPG_fused_gromov_wasserstein2,
                        entropic_fused_gromov_barycenters)

 from ._estimators import (GW_distance_estimation, pointwise_gromov_wasserstein,
@@ -49,8 +53,10 @@
            'gromov_wasserstein', 'gromov_wasserstein2', 'fused_gromov_wasserstein',
            'fused_gromov_wasserstein2', 'solve_gromov_linesearch', 'gromov_barycenters',
            'fgw_barycenters', 'entropic_gromov_wasserstein', 'entropic_gromov_wasserstein2',
+           'BAPG_gromov_wasserstein', 'BAPG_gromov_wasserstein2',
            'entropic_gromov_barycenters', 'entropic_fused_gromov_wasserstein',
-           'entropic_fused_gromov_wasserstein2', 'entropic_fused_gromov_barycenters',
+           'entropic_fused_gromov_wasserstein2', 'BAPG_fused_gromov_wasserstein',
+           'BAPG_fused_gromov_wasserstein2', 'entropic_fused_gromov_barycenters',
            'GW_distance_estimation', 'pointwise_gromov_wasserstein', 'sampled_gromov_wasserstein',
            'semirelaxed_gromov_wasserstein', 'semirelaxed_gromov_wasserstein2',
            'semirelaxed_fused_gromov_wasserstein', 'semirelaxed_fused_gromov_wasserstein2',
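
The `BAPG_*` names exported here refer to solvers built on alternated Bregman (KL) projections, and the example file in this commit notes that such projections tend to leave a feasibility error on the marginals. The sketch below illustrates one iteration of that idea under simplifying assumptions. It is not POT's implementation: `kl_project_rows`, `kl_project_cols`, `multiplicative_update` and `alternated_kl_step` are hypothetical names, plain lists replace backend arrays, and the GW/FGW gradient is replaced by a constant toy cost matrix.

```python
import math


def kl_project_rows(K, p):
    """Rescale each row of the positive matrix K so that row i sums to p[i]."""
    return [[p[i] * v / sum(row) for v in row] for i, row in enumerate(K)]


def kl_project_cols(K, q):
    """Rescale each column of the positive matrix K so that column j sums to q[j]."""
    col_sums = [sum(row[j] for row in K) for j in range(len(q))]
    return [[q[j] * v / col_sums[j] for j, v in enumerate(row)] for row in K]


def multiplicative_update(T, G, epsilon):
    """Entrywise T * exp(-G / epsilon): a gradient step in the KL geometry."""
    return [[t * math.exp(-g / epsilon) for t, g in zip(t_row, g_row)]
            for t_row, g_row in zip(T, G)]


def alternated_kl_step(T, grad, p, q, epsilon):
    """Gradient step + row projection, then gradient step + column projection."""
    T = kl_project_rows(multiplicative_update(T, grad(T), epsilon), p)
    T = kl_project_cols(multiplicative_update(T, grad(T), epsilon), q)
    return T


# Toy problem with hypothetical values: the gradient is a constant matrix M.
M = [[0.0, 1.0], [1.0, 0.0]]
p, q = [0.5, 0.5], [0.25, 0.75]
T = [[pi * qj for qj in q] for pi in p]  # start from the product coupling
T = alternated_kl_step(T, lambda T_: M, p, q, epsilon=1.0)

# Columns were projected last, so only the row marginals carry an error.
row_err = max(abs(sum(row) - pi) for row, pi in zip(T, p))
col_err = max(abs(sum(row[j] for row in T) - q[j]) for j in range(len(q)))
print(row_err, col_err)
```

Because each iteration ends on a single projection rather than a full Sinkhorn loop, only the most recently projected marginal is matched exactly, which is consistent with the marginal error reported for the BAPG panel in the example.
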
