# -*- coding: utf-8 -*-
"""
========================================================
Gaussian Bures-Wasserstein barycenters
========================================================

Illustration of Gaussian Bures-Wasserstein barycenters.

"""

# Authors: Rémi Flamary
#
# License: MIT License

# sphinx_gallery_thumbnail_number = 2
# %%
from matplotlib import colors
from matplotlib.patches import Ellipse
import numpy as np
import matplotlib.pylab as pl
import ot


# %%
# Define Gaussian Covariances and distributions
# ---------------------------------------------

C1 = np.array([[0.5, -0.4], [-0.4, 0.5]])
C2 = np.array([[1, 0.3], [0.3, 1]])
C3 = np.array([[1.5, 0], [0, 0.5]])
C4 = np.array([[0.5, 0], [0, 1.5]])

C = np.stack((C1, C2, C3, C4))

m1 = np.array([0, 0])
m2 = np.array([0, 4])
m3 = np.array([4, 0])
m4 = np.array([4, 4])

m = np.stack((m1, m2, m3, m4))

# %%
# Plot the distributions
# ----------------------


def draw_cov(mu, C, color=None, label=None, nstd=1):
    """Draw a 2D Gaussian as a filled ellipse on the current axes.

    The ellipse is centered on ``mu``, its axes are aligned with the
    eigenvectors of ``C`` and their lengths are ``2 * nstd`` times the
    square roots of the corresponding eigenvalues.

    Parameters
    ----------
    mu : array-like, shape (2,)
        Center of the ellipse (mean of the Gaussian).
    C : array-like, shape (2, 2)
        Symmetric covariance matrix of the Gaussian.
    color : matplotlib color, optional
        Used for both the face and the edge of the ellipse.
    label : str, optional
        Label attached to the ellipse artist.
    nstd : float, optional
        Number of standard deviations covered by each semi-axis.
    """
    # eigh returns eigenvalues in ascending order; put the largest first so
    # that the ellipse width follows the principal direction.
    vals, vecs = np.linalg.eigh(C)
    order = np.argsort(vals)[::-1]
    vals, vecs = vals[order], vecs[:, order]

    # Orientation of the principal eigenvector, in degrees.
    theta = np.degrees(np.arctan2(vecs[1, 0], vecs[0, 0]))

    # Clip at zero so that a slightly negative eigenvalue caused by round-off
    # does not turn the ellipse size into NaN.
    w, h = 2 * nstd * np.sqrt(np.maximum(vals, 0))

    ell = Ellipse(xy=(mu[0], mu[1]),
                  width=w, height=h, alpha=0.5,
                  angle=theta, facecolor=color, edgecolor=color,
                  label=label, fill=True)
    pl.gca().add_artist(ell)


axis = [-1.5, 5.5, -1.5, 5.5]

pl.figure(1, figsize=(8, 2))
pl.clf()

# One subplot per input Gaussian, each with its own default-cycle color.
for k in range(4):
    pl.subplot(1, 4, k + 1)
    draw_cov(m[k], C[k], color='C%d' % k)
    pl.axis(axis)
    # Raw string: the backslashes belong to the LaTeX markup, they are not
    # Python escape sequences.
    pl.title(r'$\mathcal{N}(m_%d,\Sigma_%d)$' % (k + 1, k + 1))

# %%
# Compute Bures-Wasserstein barycenters and plot them
# ---------------------------------------------------

# basis for bilinear interpolation (one canonical vector per input Gaussian)
v1, v2, v3, v4 = np.eye(4)

# RGB color of each input Gaussian; kept under its own name so that the
# imported ``colors`` module is not shadowed.
corner_colors = np.stack([colors.to_rgb('C%d' % k) for k in range(4)])

pl.figure(2, figsize=(8, 8))

nb_interp = 6

for i in range(nb_interp):
    for j in range(nb_interp):
        tx = i / (nb_interp - 1)
        ty = j / (nb_interp - 1)

        # weights are constructed by bilinear interpolation
        tmp1 = (1 - tx) * v1 + tx * v2
        tmp2 = (1 - tx) * v3 + tx * v4
        weights = (1 - ty) * tmp1 + ty * tmp2

        # blend the corner colors with the same weights as the barycenter
        color = np.dot(corner_colors.T, weights)

        mb, Cb = ot.gaussian.bures_wasserstein_barycenter(m, C, weights)

        draw_cov(mb, Cb, color=color, label=None, nstd=0.3)

pl.axis(axis)
pl.axis('off')
pl.tight_layout()