|
| 1 | +# -*- coding: utf-8 -*- |
| 2 | +""" |
| 3 | +======================================== |
| 4 | +Low rank Gromov-Wasterstein between samples |
| 5 | +======================================== |
| 6 | +
|
| 7 | +Comparaison between entropic Gromov-Wasserstein and Low Rank Gromov Wasserstein [67] |
| 8 | +on two curves in 2D and 3D, both sampled with 200 points. |
| 9 | +
|
| 10 | +The squared Euclidean distance is considered as the ground cost for both samples. |
| 11 | +
|
| 12 | +[67] Scetbon, M., Peyré, G. & Cuturi, M. (2022). |
| 13 | +"Linear-Time GromovWasserstein Distances using Low Rank Couplings and Costs". |
| 14 | +In International Conference on Machine Learning (ICML), 2022. |
| 15 | +""" |
| 16 | + |
| 17 | +# Author: Laurène David <laurene.david@ip-paris.fr> |
| 18 | +# |
| 19 | +# License: MIT License |
| 20 | +# |
| 21 | +# sphinx_gallery_thumbnail_number = 3 |
| 22 | + |
| 23 | +#%% |
| 24 | +import numpy as np |
| 25 | +import matplotlib.pylab as pl |
| 26 | +import ot.plot |
| 27 | +import time |
| 28 | + |
| 29 | +############################################################################## |
| 30 | +# Generate data |
| 31 | +# ------------- |
| 32 | + |
| 33 | +#%% parameters |
| 34 | +n_samples = 200 |
| 35 | + |
| 36 | +# Generate 2D and 3D curves |
| 37 | +theta = np.linspace(-4 * np.pi, 4 * np.pi, n_samples) |
| 38 | +z = np.linspace(1, 2, n_samples) |
| 39 | +r = z**2 + 1 |
| 40 | +x = r * np.sin(theta) |
| 41 | +y = r * np.cos(theta) |
| 42 | + |
| 43 | +# Source and target distribution |
| 44 | +X = np.concatenate([x.reshape(-1, 1), z.reshape(-1, 1)], axis=1) |
| 45 | +Y = np.concatenate([x.reshape(-1, 1), y.reshape(-1, 1), z.reshape(-1, 1)], axis=1) |
| 46 | + |
| 47 | + |
| 48 | +############################################################################## |
| 49 | +# Plot data |
| 50 | +# ------------ |
| 51 | + |
| 52 | +#%% |
| 53 | +# Plot the source and target samples |
| 54 | +fig = pl.figure(1, figsize=(10, 4)) |
| 55 | + |
| 56 | +ax = fig.add_subplot(121) |
| 57 | +ax.plot(X[:, 0], X[:, 1], color="blue", linewidth=6) |
| 58 | +ax.tick_params(left=False, right=False, labelleft=False, |
| 59 | + labelbottom=False, bottom=False) |
| 60 | +ax.set_title("2D curve (source)") |
| 61 | + |
| 62 | +ax2 = fig.add_subplot(122, projection="3d") |
| 63 | +ax2.plot(Y[:, 0], Y[:, 1], Y[:, 2], c='red', linewidth=6) |
| 64 | +ax2.tick_params(left=False, right=False, labelleft=False, |
| 65 | + labelbottom=False, bottom=False) |
| 66 | +ax2.view_init(15, -50) |
| 67 | +ax2.set_title("3D curve (target)") |
| 68 | + |
| 69 | +pl.tight_layout() |
| 70 | +pl.show() |
| 71 | + |
| 72 | + |
| 73 | +############################################################################## |
| 74 | +# Entropic Gromov-Wasserstein |
| 75 | +# ------------ |
| 76 | + |
| 77 | +#%% |
| 78 | + |
| 79 | +# Compute cost matrices |
| 80 | +C1 = ot.dist(X, X, metric="sqeuclidean") |
| 81 | +C2 = ot.dist(Y, Y, metric="sqeuclidean") |
| 82 | + |
| 83 | +# Scale cost matrices |
| 84 | +r1 = C1.max() |
| 85 | +r2 = C2.max() |
| 86 | + |
| 87 | +C1 = C1 / r1 |
| 88 | +C2 = C2 / r2 |
| 89 | + |
| 90 | + |
| 91 | +# Solve entropic gw |
| 92 | +reg = 5 * 1e-3 |
| 93 | + |
| 94 | +start = time.time() |
| 95 | +gw, log = ot.gromov.entropic_gromov_wasserstein( |
| 96 | + C1, C2, tol=1e-3, epsilon=reg, |
| 97 | + log=True, verbose=False) |
| 98 | + |
| 99 | +end = time.time() |
| 100 | +time_entropic = end - start |
| 101 | + |
| 102 | +entropic_gw_loss = np.round(log['gw_dist'], 3) |
| 103 | + |
| 104 | +# Plot entropic gw |
| 105 | +pl.figure(2) |
| 106 | +pl.imshow(gw, interpolation="nearest", aspect="auto") |
| 107 | +pl.title("Entropic Gromov-Wasserstein (loss={})".format(entropic_gw_loss)) |
| 108 | +pl.show() |
| 109 | + |
| 110 | + |
| 111 | +############################################################################## |
| 112 | +# Low rank squared euclidean cost matrices |
| 113 | +# ------------ |
| 114 | +# %% |
| 115 | + |
| 116 | +# Compute the low rank sqeuclidean cost decompositions |
| 117 | +A1, A2 = ot.lowrank.compute_lr_sqeuclidean_matrix(X, X, rescale_cost=False) |
| 118 | +B1, B2 = ot.lowrank.compute_lr_sqeuclidean_matrix(Y, Y, rescale_cost=False) |
| 119 | + |
| 120 | +# Scale the low rank cost matrices |
| 121 | +A1, A2 = A1 / np.sqrt(r1), A2 / np.sqrt(r1) |
| 122 | +B1, B2 = B1 / np.sqrt(r2), B2 / np.sqrt(r2) |
| 123 | + |
| 124 | + |
| 125 | +############################################################################## |
| 126 | +# Low rank Gromov-Wasserstein |
| 127 | +# ------------ |
| 128 | +# %% |
| 129 | + |
| 130 | +# Solve low rank gromov-wasserstein with different ranks |
| 131 | +list_rank = [10, 50] |
| 132 | +list_P_GW = [] |
| 133 | +list_loss_GW = [] |
| 134 | +list_time_GW = [] |
| 135 | + |
| 136 | +for rank in list_rank: |
| 137 | + start = time.time() |
| 138 | + |
| 139 | + Q, R, g, log = ot.lowrank_gromov_wasserstein_samples( |
| 140 | + X, Y, reg=0, rank=rank, rescale_cost=False, cost_factorized_Xs=(A1, A2), |
| 141 | + cost_factorized_Xt=(B1, B2), seed_init=49, numItermax=1000, log=True, stopThr=1e-6, |
| 142 | + ) |
| 143 | + end = time.time() |
| 144 | + |
| 145 | + P = log["lazy_plan"][:] |
| 146 | + loss = log["value"] |
| 147 | + |
| 148 | + list_P_GW.append(P) |
| 149 | + list_loss_GW.append(np.round(loss, 3)) |
| 150 | + list_time_GW.append(end - start) |
| 151 | + |
| 152 | + |
| 153 | +# %% |
| 154 | +# Plot low rank GW with different ranks |
| 155 | +pl.figure(3, figsize=(10, 4)) |
| 156 | + |
| 157 | +pl.subplot(1, 2, 1) |
| 158 | +pl.imshow(list_P_GW[0], interpolation="nearest", aspect="auto") |
| 159 | +pl.title('Low rank GW (rank=10, loss={})'.format(list_loss_GW[0])) |
| 160 | + |
| 161 | +pl.subplot(1, 2, 2) |
| 162 | +pl.imshow(list_P_GW[1], interpolation="nearest", aspect="auto") |
| 163 | +pl.title('Low rank GW (rank=50, loss={})'.format(list_loss_GW[1])) |
| 164 | + |
| 165 | +pl.tight_layout() |
| 166 | +pl.show() |
| 167 | + |
| 168 | + |
| 169 | +# %% |
| 170 | +# Compare computation time between entropic GW and low rank GW |
| 171 | +print("Entropic GW: {:.2f}s".format(time_entropic)) |
| 172 | +print("Low rank GW (rank=10): {:.2f}s".format(list_time_GW[0])) |
| 173 | +print("Low rank GW (rank=50): {:.2f}s".format(list_time_GW[1])) |
0 commit comments