|
5 | 5 | ==============================
|
6 | 6 |
|
7 | 7 | This example illustrates the computation of FGW for attributed graphs
|
8 |
| -using 3 different solvers to estimate the distance based on Conditional |
9 |
| -Gradient [24] or Sinkhorn projections [12, 51]. |
| 8 | +using 4 different solvers to estimate the distance based on Conditional |
| 9 | +Gradient [24], Sinkhorn projections [12, 51] and alternated Bregman |
| 10 | +projections [63, 64]. |
10 | 11 |
|
11 | 12 | We generate two graphs following Stochastic Block Models further endowed with
|
12 | 13 | node features and compute their FGW matchings.
|
|
23 | 24 | [51] Xu, H., Luo, D., Zha, H., & Carin, L. (2019).
|
24 | 25 | "Gromov-wasserstein learning for graph matching and node embedding".
|
25 | 26 | In International Conference on Machine Learning (ICML), 2019.
|
| 27 | +
|
| 28 | +[63] Li, J., Tang, J., Kong, L., Liu, H., Li, J., So, A. M. C., & Blanchet, J. |
| 29 | +"A Convergent Single-Loop Algorithm for Relaxation of Gromov-Wasserstein in |
| 30 | +Graph Data". International Conference on Learning Representations (ICLR), 2023. |
| 31 | +
|
| 32 | +[64] Ma, X., Chu, X., Wang, Y., Lin, Y., Zhao, J., Ma, L., & Zhu, W. |
| 33 | +"Fused Gromov-Wasserstein Graph Mixup for Graph-level Classifications". |
| 34 | +In Thirty-seventh Conference on Neural Information Processing Systems |
| 35 | +(NeurIPS), 2023. |
| 36 | +
|
26 | 37 | """
|
27 | 38 |
|
28 | 39 | # Author: Cédric Vincent-Cuaz <cedvincentcuaz@gmail.com>
|
|
33 | 44 |
|
34 | 45 | import numpy as np
|
35 | 46 | import matplotlib.pylab as pl
|
36 |
| -from ot.gromov import fused_gromov_wasserstein, entropic_fused_gromov_wasserstein |
| 47 | +from ot.gromov import (fused_gromov_wasserstein, |
| 48 | + entropic_fused_gromov_wasserstein, |
| 49 | + BAPG_fused_gromov_wasserstein) |
37 | 50 | import networkx
|
38 | 51 | from networkx.generators.community import stochastic_block_model as sbm
|
| 52 | +from time import time |
39 | 53 |
|
40 | 54 | #############################################################################
|
41 | 55 | #
|
|
85 | 99 |
|
86 | 100 |
|
87 | 101 | # Conditional Gradient algorithm
|
88 |
| -fgw0, log0 = fused_gromov_wasserstein( |
89 |
| - M, C2, C3, h2, h3, 'square_loss', alpha=alpha, verbose=True, log=True) |
| 102 | +print('Conditional Gradient \n') |
| 103 | +start_cg = time() |
| 104 | +T_cg, log_cg = fused_gromov_wasserstein( |
| 105 | + M, C2, C3, h2, h3, 'square_loss', alpha=alpha, tol_rel=1e-9, |
| 106 | + verbose=True, log=True) |
| 107 | +end_cg = time() |
| 108 | +time_cg = 1000 * (end_cg - start_cg) |
90 | 109 |
|
91 | 110 | # Proximal Point algorithm with Kullback-Leibler as proximal operator
|
92 |
| -fgw, log = entropic_fused_gromov_wasserstein( |
| 111 | +print('Proximal Point Algorithm \n') |
| 112 | +start_ppa = time() |
| 113 | +T_ppa, log_ppa = entropic_fused_gromov_wasserstein( |
93 | 114 | M, C2, C3, h2, h3, 'square_loss', alpha=alpha, epsilon=1., solver='PPA',
|
94 |
| - log=True, verbose=True, warmstart=False, numItermax=10) |
| 115 | + tol=1e-9, log=True, verbose=True, warmstart=False, numItermax=10) |
| 116 | +end_ppa = time() |
| 117 | +time_ppa = 1000 * (end_ppa - start_ppa) |
95 | 118 |
|
96 | 119 | # Projected Gradient algorithm with entropic regularization
|
97 |
| -fgwe, loge = entropic_fused_gromov_wasserstein( |
| 120 | +print('Projected Gradient Descent \n') |
| 121 | +start_pgd = time() |
| 122 | +T_pgd, log_pgd = entropic_fused_gromov_wasserstein( |
98 | 123 | M, C2, C3, h2, h3, 'square_loss', alpha=alpha, epsilon=0.01, solver='PGD',
|
99 |
| - log=True, verbose=True, warmstart=False, numItermax=10) |
100 |
| - |
101 |
| -print('Fused Gromov-Wasserstein distance estimated with Conditional Gradient solver: ' + str(log0['fgw_dist'])) |
102 |
| -print('Fused Gromov-Wasserstein distance estimated with Proximal Point solver: ' + str(log['fgw_dist'])) |
103 |
| -print('Entropic Fused Gromov-Wasserstein distance estimated with Projected Gradient solver: ' + str(loge['fgw_dist'])) |
| 124 | + tol=1e-9, log=True, verbose=True, warmstart=False, numItermax=10) |
| 125 | +end_pgd = time() |
| 126 | +time_pgd = 1000 * (end_pgd - start_pgd) |
| 127 | + |
| 128 | +# Bregman Alternated Projected Gradient (BAPG) algorithm with Kullback-Leibler as proximal operator |
| 129 | +print('Bregman Alternated Projected Gradient \n') |
| 130 | +start_bapg = time() |
| 131 | +T_bapg, log_bapg = BAPG_fused_gromov_wasserstein( |
| 132 | + M, C2, C3, h2, h3, 'square_loss', alpha=alpha, epsilon=1., |
| 133 | + tol=1e-9, marginal_loss=True, verbose=True, log=True) |
| 134 | +end_bapg = time() |
| 135 | +time_bapg = 1000 * (end_bapg - start_bapg) |
| 136 | + |
| 137 | +print('Fused Gromov-Wasserstein distance estimated with Conditional Gradient solver: ' + str(log_cg['fgw_dist'])) |
| 138 | +print('Fused Gromov-Wasserstein distance estimated with Proximal Point solver: ' + str(log_ppa['fgw_dist'])) |
| 139 | +print('Entropic Fused Gromov-Wasserstein distance estimated with Projected Gradient solver: ' + str(log_pgd['fgw_dist'])) |
| 140 | +print('Fused Gromov-Wasserstein distance estimated with Bregman Alternated Projected Gradient solver: ' + str(log_bapg['fgw_dist'])) |
104 | 141 |
|
105 | 142 | # compute OT sparsity level
|
106 |
| -fgw0_sparsity = 100 * (fgw0 == 0.).astype(np.float64).sum() / (N2 * N3) |
107 |
| -fgw_sparsity = 100 * (fgw == 0.).astype(np.float64).sum() / (N2 * N3) |
108 |
| -fgwe_sparsity = 100 * (fgwe == 0.).astype(np.float64).sum() / (N2 * N3) |
| 143 | +T_cg_sparsity = 100 * (T_cg == 0.).astype(np.float64).sum() / (N2 * N3) |
| 144 | +T_ppa_sparsity = 100 * (T_ppa == 0.).astype(np.float64).sum() / (N2 * N3) |
| 145 | +T_pgd_sparsity = 100 * (T_pgd == 0.).astype(np.float64).sum() / (N2 * N3) |
| 146 | +T_bapg_sparsity = 100 * (T_bapg == 0.).astype(np.float64).sum() / (N2 * N3) |
109 | 147 |
|
110 |
| -# Methods using Sinkhorn projections tend to produce feasibility errors on the |
| 148 | +# Methods using Sinkhorn/Bregman projections tend to produce feasibility errors on the |
111 | 149 | # marginal constraints
|
112 | 150 |
|
113 |
| -err0 = np.linalg.norm(fgw0.sum(1) - h2) + np.linalg.norm(fgw0.sum(0) - h3) |
114 |
| -err = np.linalg.norm(fgw.sum(1) - h2) + np.linalg.norm(fgw.sum(0) - h3) |
115 |
| -erre = np.linalg.norm(fgwe.sum(1) - h2) + np.linalg.norm(fgwe.sum(0) - h3) |
| 151 | +err_cg = np.linalg.norm(T_cg.sum(1) - h2) + np.linalg.norm(T_cg.sum(0) - h3) |
| 152 | +err_ppa = np.linalg.norm(T_ppa.sum(1) - h2) + np.linalg.norm(T_ppa.sum(0) - h3) |
| 153 | +err_pgd = np.linalg.norm(T_pgd.sum(1) - h2) + np.linalg.norm(T_pgd.sum(0) - h3) |
| 154 | +err_bapg = np.linalg.norm(T_bapg.sum(1) - h2) + np.linalg.norm(T_bapg.sum(0) - h3) |
116 | 155 |
|
117 | 156 | #############################################################################
|
118 | 157 | #
|
@@ -242,46 +281,52 @@ def draw_transp_colored_GW(G1, C1, G2, C2, part_G1, p1, p2, T,
|
242 | 281 | seed_G2 = 0
|
243 | 282 | seed_G3 = 4
|
244 | 283 |
|
245 |
| -pl.figure(2, figsize=(12, 3.5)) |
| 284 | +pl.figure(2, figsize=(15, 3.5)) |
246 | 285 | pl.clf()
|
247 |
| -pl.subplot(131) |
| 286 | +pl.subplot(141) |
248 | 287 | pl.axis('off')
|
249 |
| -pl.axis |
250 |
| -pl.title('(CG algo) FGW=%s \n \n OT sparsity = %s \n feasibility error = %s' % ( |
251 |
| - np.round(log0['fgw_dist'], 3), str(np.round(fgw0_sparsity, 2)) + ' %', |
252 |
| - np.round(err0, 4)), fontsize=fontsize) |
253 | 288 |
|
254 |
| -p0, q0 = fgw0.sum(1), fgw0.sum(0) # check marginals |
| 289 | +pl.title('(CG) FGW=%s\n \n OT sparsity = %s \n marg. error = %s \n runtime = %s' % ( |
| 290 | + np.round(log_cg['fgw_dist'], 3), str(np.round(T_cg_sparsity, 2)) + ' %', |
| 291 | + np.round(err_cg, 4), str(np.round(time_cg, 2)) + ' ms'), fontsize=fontsize) |
255 | 292 |
|
256 | 293 | pos1, pos2 = draw_transp_colored_GW(
|
257 |
| - weightedG2, C2, weightedG3, C3, part_G2, p1=p0, p2=q0, T=fgw0, |
258 |
| - shiftx=1.5, node_size=node_size, seed_G1=seed_G2, seed_G2=seed_G3) |
| 294 | + weightedG2, C2, weightedG3, C3, part_G2, p1=T_cg.sum(1), p2=T_cg.sum(0), |
| 295 | + T=T_cg, shiftx=1.5, node_size=node_size, seed_G1=seed_G2, seed_G2=seed_G3) |
259 | 296 |
|
260 |
| -pl.subplot(132) |
| 297 | +pl.subplot(142) |
261 | 298 | pl.axis('off')
|
262 | 299 |
|
263 |
| -p, q = fgw.sum(1), fgw.sum(0) # check marginals |
264 |
| - |
265 |
| -pl.title('(PP algo) FGW=%s\n \n OT sparsity = %s \n feasibility error = %s' % ( |
266 |
| - np.round(log['fgw_dist'], 3), str(np.round(fgw_sparsity, 2)) + ' %', |
267 |
| - np.round(err, 4)), fontsize=fontsize) |
| 300 | +pl.title('(PPA) FGW=%s\n \n OT sparsity = %s \n marg. error = %s \n runtime = %s' % ( |
| 301 | + np.round(log_ppa['fgw_dist'], 3), str(np.round(T_ppa_sparsity, 2)) + ' %', |
| 302 | + np.round(err_ppa, 4), str(np.round(time_ppa, 2)) + ' ms'), fontsize=fontsize) |
268 | 303 |
|
269 | 304 | pos1, pos2 = draw_transp_colored_GW(
|
270 |
| - weightedG2, C2, weightedG3, C3, part_G2, p1=p, p2=q, T=fgw, |
271 |
| - pos1=pos1, pos2=pos2, shiftx=0., node_size=node_size, seed_G1=0, seed_G2=0) |
| 305 | + weightedG2, C2, weightedG3, C3, part_G2, p1=T_ppa.sum(1), p2=T_ppa.sum(0), |
| 306 | + T=T_ppa, pos1=pos1, pos2=pos2, shiftx=0., node_size=node_size, seed_G1=0, seed_G2=0) |
272 | 307 |
|
273 |
| -pl.subplot(133) |
| 308 | +pl.subplot(143) |
274 | 309 | pl.axis('off')
|
275 | 310 |
|
276 |
| -pe, qe = fgwe.sum(1), fgwe.sum(0) # check marginals |
| 311 | +pl.title('(PGD) Entropic FGW=%s\n \n OT sparsity = %s \n marg. error = %s \n runtime = %s' % ( |
| 312 | + np.round(log_pgd['fgw_dist'], 3), str(np.round(T_pgd_sparsity, 2)) + ' %', |
| 313 | + np.round(err_pgd, 4), str(np.round(time_pgd, 2)) + ' ms'), fontsize=fontsize) |
| 314 | + |
| 315 | +pos1, pos2 = draw_transp_colored_GW( |
| 316 | + weightedG2, C2, weightedG3, C3, part_G2, p1=T_pgd.sum(1), p2=T_pgd.sum(0), |
| 317 | + T=T_pgd, pos1=pos1, pos2=pos2, shiftx=0., node_size=node_size, seed_G1=0, seed_G2=0) |
| 318 | + |
| 319 | + |
| 320 | +pl.subplot(144) |
| 321 | +pl.axis('off') |
277 | 322 |
|
278 |
| -pl.title('Entropic FGW=%s\n \n OT sparsity = %s \n feasibility error = %s' % ( |
279 |
| - np.round(loge['fgw_dist'], 3), str(np.round(fgwe_sparsity, 2)) + ' %', |
280 |
| - np.round(erre, 4)), fontsize=fontsize) |
| 323 | +pl.title('(BAPG) FGW=%s\n \n OT sparsity = %s \n marg. error = %s \n runtime = %s' % ( |
| 324 | + np.round(log_bapg['fgw_dist'], 3), str(np.round(T_bapg_sparsity, 2)) + ' %', |
| 325 | + np.round(err_bapg, 4), str(np.round(time_bapg, 2)) + ' ms'), fontsize=fontsize) |
281 | 326 |
|
282 | 327 | pos1, pos2 = draw_transp_colored_GW(
|
283 |
| - weightedG2, C2, weightedG3, C3, part_G2, p1=pe, p2=qe, T=fgwe, |
284 |
| - pos1=pos1, pos2=pos2, shiftx=0., node_size=node_size, seed_G1=0, seed_G2=0) |
| 328 | + weightedG2, C2, weightedG3, C3, part_G2, p1=T_bapg.sum(1), p2=T_bapg.sum(0), |
| 329 | + T=T_bapg, pos1=pos1, pos2=pos2, shiftx=0., node_size=node_size, seed_G1=0, seed_G2=0) |
285 | 330 |
|
286 | 331 | pl.tight_layout()
|
287 | 332 |
|
|
0 commit comments