Skip to content

Commit 9fc4ab8

Browse files
committed
add exmaple and debug barycenters
1 parent 55a851e commit 9fc4ab8

File tree

2 files changed

+129
-2
lines changed

2 files changed

+129
-2
lines changed
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
========================================================
4+
Gaussian Bures-Wasserstein barycenters
5+
========================================================
6+
7+
Illustration of Gaussian Bures-Wasserstein barycenters.
8+
9+
"""
10+
11+
# Authors: Rémi Flamary <remi.flamary@polytechnique.edu>
12+
#
13+
# License: MIT License
14+
15+
# sphinx_gallery_thumbnail_number = 2
16+
# %%
17+
from matplotlib import colors
18+
from matplotlib.patches import Ellipse
19+
import numpy as np
20+
import matplotlib.pylab as pl
21+
import ot
22+
23+
24+
# %%
25+
# Define Gaussian Covariances and distributions
26+
# ---------------------------------------------
27+
28+
C1 = np.array([[0.5, -0.4], [-0.4, 0.5]])
29+
C2 = np.array([[1, 0.3], [0.3, 1]])
30+
C3 = np.array([[1.5, 0], [0, 0.5]])
31+
C4 = np.array([[0.5, 0], [0, 1.5]])
32+
33+
C = np.stack((C1, C2, C3, C4))
34+
35+
m1 = np.array([0, 0])
36+
m2 = np.array([0, 4])
37+
m3 = np.array([4, 0])
38+
m4 = np.array([4, 4])
39+
40+
m = np.stack((m1, m2, m3, m4))
41+
42+
# %%
43+
# Plot the distributions
44+
# ----------------------
45+
46+
47+
def draw_cov(mu, C, color=None, label=None, nstd=1):
48+
49+
def eigsorted(cov):
50+
vals, vecs = np.linalg.eigh(cov)
51+
order = vals.argsort()[::-1]
52+
return vals[order], vecs[:, order]
53+
54+
vals, vecs = eigsorted(C)
55+
theta = np.degrees(np.arctan2(*vecs[:, 0][::-1]))
56+
w, h = 2 * nstd * np.sqrt(vals)
57+
ell = Ellipse(xy=(mu[0], mu[1]),
58+
width=w, height=h, alpha=0.5,
59+
angle=theta, facecolor=color, edgecolor=color, label=label, fill=True)
60+
pl.gca().add_artist(ell)
61+
#pl.scatter(mu[0],mu[1],color=color, marker='x')
62+
63+
64+
axis = [-1.5, 5.5, -1.5, 5.5]
65+
66+
pl.figure(1, (8, 2))
67+
pl.clf()
68+
69+
pl.subplot(1, 4, 1)
70+
draw_cov(m1, C1, color='C0')
71+
pl.axis(axis)
72+
pl.title('$\mathcal{N}(m_1,\Sigma_1)$')
73+
74+
pl.subplot(1, 4, 2)
75+
draw_cov(m2, C2, color='C1')
76+
pl.axis(axis)
77+
pl.title('$\mathcal{N}(m_2,\Sigma_2)$')
78+
79+
pl.subplot(1, 4, 3)
80+
draw_cov(m3, C3, color='C2')
81+
pl.axis(axis)
82+
pl.title('$\mathcal{N}(m_3,\Sigma_3)$')
83+
84+
pl.subplot(1, 4, 4)
85+
draw_cov(m4, C4, color='C3')
86+
pl.axis(axis)
87+
pl.title('$\mathcal{N}(m_4,\Sigma_4)$')
88+
89+
# %%
90+
# Compute Bures-Wasserstein barycenters and plot them
91+
# -------------------------------------------
92+
93+
# basis for bilinear interpolation
94+
v1 = np.array((1, 0, 0, 0))
95+
v2 = np.array((0, 1, 0, 0))
96+
v3 = np.array((0, 0, 1, 0))
97+
v4 = np.array((0, 0, 0, 1))
98+
99+
100+
colors = np.stack((colors.to_rgb('C0'),
101+
colors.to_rgb('C1'),
102+
colors.to_rgb('C2'),
103+
colors.to_rgb('C3')))
104+
105+
pl.figure(2, (8, 8))
106+
107+
nb_interp = 6
108+
109+
for i in range(nb_interp):
110+
for j in range(nb_interp):
111+
tx = float(i) / (nb_interp - 1)
112+
ty = float(j) / (nb_interp - 1)
113+
114+
# weights are constructed by bilinear interpolation
115+
tmp1 = (1 - tx) * v1 + tx * v2
116+
tmp2 = (1 - tx) * v3 + tx * v4
117+
weights = (1 - ty) * tmp1 + ty * tmp2
118+
119+
color = np.dot(colors.T, weights)
120+
121+
mb, Cb = ot.gaussian.bures_wasserstein_barycenter(m, C, weights)
122+
123+
draw_cov(mb, Cb, color=color, label=None, nstd=0.3)
124+
125+
pl.axis(axis)
126+
pl.axis('off')
127+
pl.tight_layout()

ot/gaussian.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -400,10 +400,10 @@ def bures_wasserstein_barycenter(m, C, weights=None, num_iter=1000, eps=1e-7, lo
400400
nx = get_backend(*C, *m,)
401401

402402
# Compute the mean barycenter
403-
mb = nx.mean(m)
403+
mb = nx.dot(weights, m)
404404

405405
# Init the covariance barycenter
406-
Cb = nx.mean(C, axis=0)
406+
Cb = nx.mean(C * weights[:, None, None], axis=0)
407407

408408
if weights is None:
409409
weights = nx.ones(len(C), type_as=C[0]) / len(C)

0 commit comments

Comments
 (0)