Skip to content

Commit 659cde8

Browse files
authored
[MRG] Add Bures-Wasserstein arycenetrs example (and debug the solver) (#584)
* add exmaple and debug barycenters * debug barycenter again
1 parent 55a851e commit 659cde8

File tree

3 files changed

+133
-6
lines changed

3 files changed

+133
-6
lines changed

RELEASES.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
+ Add `stop_criterion` feature to (un)regularized (f)gw barycenter solvers (PR #578)
2222
+ Add `fixed_structure` and `fixed_features` to entropic fgw barycenter solver (PR #578)
2323
+ Add new BAPG solvers with KL projections for GW and FGW (PR #581)
24-
+ Add Bures-Wasserstein barycenter in `ot.gaussian` (PR #582)
24+
+ Add Bures-Wasserstein barycenter in `ot.gaussian` and example (PR #582, PR #584)
2525

2626

2727
#### Closed issues
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
========================================================
4+
Gaussian Bures-Wasserstein barycenters
5+
========================================================
6+
7+
Illustration of Gaussian Bures-Wasserstein barycenters.
8+
9+
"""
10+
11+
# Authors: Rémi Flamary <remi.flamary@polytechnique.edu>
12+
#
13+
# License: MIT License
14+
15+
# sphinx_gallery_thumbnail_number = 2
16+
# %%
17+
from matplotlib import colors
18+
from matplotlib.patches import Ellipse
19+
import numpy as np
20+
import matplotlib.pylab as pl
21+
import ot
22+
23+
24+
# %%
25+
# Define Gaussian Covariances and distributions
26+
# ---------------------------------------------
27+
28+
C1 = np.array([[0.5, -0.4], [-0.4, 0.5]])
29+
C2 = np.array([[1, 0.3], [0.3, 1]])
30+
C3 = np.array([[1.5, 0], [0, 0.5]])
31+
C4 = np.array([[0.5, 0], [0, 1.5]])
32+
33+
C = np.stack((C1, C2, C3, C4))
34+
35+
m1 = np.array([0, 0])
36+
m2 = np.array([0, 4])
37+
m3 = np.array([4, 0])
38+
m4 = np.array([4, 4])
39+
40+
m = np.stack((m1, m2, m3, m4))
41+
42+
# %%
43+
# Plot the distributions
44+
# ----------------------
45+
46+
47+
def draw_cov(mu, C, color=None, label=None, nstd=1):
48+
49+
def eigsorted(cov):
50+
vals, vecs = np.linalg.eigh(cov)
51+
order = vals.argsort()[::-1]
52+
return vals[order], vecs[:, order]
53+
54+
vals, vecs = eigsorted(C)
55+
theta = np.degrees(np.arctan2(*vecs[:, 0][::-1]))
56+
w, h = 2 * nstd * np.sqrt(vals)
57+
ell = Ellipse(xy=(mu[0], mu[1]),
58+
width=w, height=h, alpha=0.5,
59+
angle=theta, facecolor=color, edgecolor=color, label=label, fill=True)
60+
pl.gca().add_artist(ell)
61+
#pl.scatter(mu[0],mu[1],color=color, marker='x')
62+
63+
64+
axis = [-1.5, 5.5, -1.5, 5.5]
65+
66+
pl.figure(1, (8, 2))
67+
pl.clf()
68+
69+
pl.subplot(1, 4, 1)
70+
draw_cov(m1, C1, color='C0')
71+
pl.axis(axis)
72+
pl.title('$\mathcal{N}(m_1,\Sigma_1)$')
73+
74+
pl.subplot(1, 4, 2)
75+
draw_cov(m2, C2, color='C1')
76+
pl.axis(axis)
77+
pl.title('$\mathcal{N}(m_2,\Sigma_2)$')
78+
79+
pl.subplot(1, 4, 3)
80+
draw_cov(m3, C3, color='C2')
81+
pl.axis(axis)
82+
pl.title('$\mathcal{N}(m_3,\Sigma_3)$')
83+
84+
pl.subplot(1, 4, 4)
85+
draw_cov(m4, C4, color='C3')
86+
pl.axis(axis)
87+
pl.title('$\mathcal{N}(m_4,\Sigma_4)$')
88+
89+
# %%
90+
# Compute Bures-Wasserstein barycenters and plot them
91+
# -------------------------------------------
92+
93+
# basis for bilinear interpolation
94+
v1 = np.array((1, 0, 0, 0))
95+
v2 = np.array((0, 1, 0, 0))
96+
v3 = np.array((0, 0, 1, 0))
97+
v4 = np.array((0, 0, 0, 1))
98+
99+
100+
colors = np.stack((colors.to_rgb('C0'),
101+
colors.to_rgb('C1'),
102+
colors.to_rgb('C2'),
103+
colors.to_rgb('C3')))
104+
105+
pl.figure(2, (8, 8))
106+
107+
nb_interp = 6
108+
109+
for i in range(nb_interp):
110+
for j in range(nb_interp):
111+
tx = float(i) / (nb_interp - 1)
112+
ty = float(j) / (nb_interp - 1)
113+
114+
# weights are constructed by bilinear interpolation
115+
tmp1 = (1 - tx) * v1 + tx * v2
116+
tmp2 = (1 - tx) * v3 + tx * v4
117+
weights = (1 - ty) * tmp1 + ty * tmp2
118+
119+
color = np.dot(colors.T, weights)
120+
121+
mb, Cb = ot.gaussian.bures_wasserstein_barycenter(m, C, weights)
122+
123+
draw_cov(mb, Cb, color=color, label=None, nstd=0.3)
124+
125+
pl.axis(axis)
126+
pl.axis('off')
127+
pl.tight_layout()

ot/gaussian.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -399,14 +399,14 @@ def bures_wasserstein_barycenter(m, C, weights=None, num_iter=1000, eps=1e-7, lo
399399
"""
400400
nx = get_backend(*C, *m,)
401401

402+
if weights is None:
403+
weights = nx.ones(C.shape[0], type_as=C[0]) / C.shape[0]
404+
402405
# Compute the mean barycenter
403-
mb = nx.mean(m)
406+
mb = nx.sum(m * weights[:, None], axis=0)
404407

405408
# Init the covariance barycenter
406-
Cb = nx.mean(C, axis=0)
407-
408-
if weights is None:
409-
weights = nx.ones(len(C), type_as=C[0]) / len(C)
409+
Cb = nx.mean(C * weights[:, None, None], axis=0)
410410

411411
for it in range(num_iter):
412412
# fixed point update

0 commit comments

Comments
 (0)