Skip to content

Commit f809253

Browse files
authored
[MRG] Add Wasserstein barycenter for Gaussian distribution (#582)
* add barycenter functions * add tests * update release
1 parent 89e6010 commit f809253

File tree

3 files changed

+248
-0
lines changed

3 files changed

+248
-0
lines changed

RELEASES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
+ Wrapper for `geomloss` solver on empirical samples (PR #571)
2121
+ Add `stop_criterion` feature to (un)regularized (f)gw barycenter solvers (PR #578)
2222
+ Add `fixed_structure` and `fixed_features` to entropic fgw barycenter solver (PR #578)
23+
+ Add Bures-Wasserstein barycenter in `ot.gaussian` (PR #582)
2324

2425
#### Closed issues
2526
- Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504)

ot/gaussian.py

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,188 @@ def empirical_bures_wasserstein_distance(xs, xt, reg=1e-6, ws=None,
344344
return W
345345

346346

347+
def bures_wasserstein_barycenter(m, C, weights=None, num_iter=1000, eps=1e-7, log=False):
    r"""Return the Bures-Wasserstein barycenter of Gaussian distributions.

    The function estimates the Wasserstein barycenter of the Gaussian
    distributions :math:`\left\{\mathcal{N}(\mu_i,\Sigma_i)\right\}_{i=1}^k`
    with a fixed-point algorithm
    :ref:`[1] <references-OT-mapping-linear-barycenter>`.

    The barycenter is still a Gaussian distribution
    :math:`\mathcal{N}(\mu_b,\Sigma_b)` where:

    .. math::
        \mu_b = \sum_{i=1}^k w_i \mu_i

    And the barycentric covariance is the solution of the following
    fixed-point equation:

    .. math::
        \Sigma_b = \sum_{i=1}^k w_i \left(\Sigma_b^{1/2}\Sigma_i\Sigma_b^{1/2}\right)^{1/2}

    Parameters
    ----------
    m : array-like (k,d)
        mean of k distributions
    C : array-like (k,d,d)
        covariance of k distributions
    weights : array-like (k), optional
        weights for each distribution (uniform if None). The weights are
        used as given, they are not normalized by this function.
    num_iter : int, optional
        maximum number of iterations of the fixed-point algorithm
    eps : float, optional
        tolerance for the fixed-point algorithm: iterations stop as soon as
        the norm of the difference between two successive covariance
        iterates is lower than or equal to `eps`
    log : bool, optional
        record log if True

    Returns
    -------
    mb : (d,) array-like
        mean of the barycenter (same shape as one entry ``m[i]``)
    Cb : (d, d) array-like
        covariance of the barycenter
    log : dict
        log dictionary returned only if log==True in parameters. It
        contains ``'num_iter'`` (0-based index of the last iteration
        performed, -1 if none was run) and ``'final_diff'`` (norm of the
        last covariance update, None if no iteration was run).

    Warns
    -----
    UserWarning
        if the tolerance `eps` is not reached within `num_iter` iterations


    .. _references-OT-mapping-linear-barycenter:
    References
    ----------
    .. [1] M. Agueh and G. Carlier, "Barycenters in the Wasserstein space",
        SIAM Journal on Mathematical Analysis, vol. 43, no. 2, pp. 904-924,
        2011.
    """
    # local import: the module-level import block is not guaranteed to
    # provide `warnings`
    import warnings

    nx = get_backend(*C, *m)

    n_dist = len(C)
    if weights is None:
        weights = nx.ones(n_dist, type_as=C[0]) / n_dist

    # Weighted mean of the means. Accumulating entry by entry keeps the
    # shape of a single mean whatever the number of trailing dimensions.
    mb = sum(weights[i] * m[i] for i in range(n_dist))

    # Init the covariance barycenter
    Cb = nx.mean(C, axis=0)

    # Defined up front so that the log is well formed even if num_iter == 0
    it = -1
    diff = None

    for it in range(num_iter):
        # fixed point update: Cb <- sum_i w_i (Cb^{1/2} C_i Cb^{1/2})^{1/2}
        Cb12 = nx.sqrtm(Cb)

        inner = Cb12 @ C @ Cb12
        roots = nx.stack([nx.sqrtm(inner[i]) for i in range(n_dist)], axis=0)
        Cnew = nx.sum(roots * weights[:, None, None], axis=0)

        # check convergence (always keep the most recent iterate)
        diff = nx.norm(Cb - Cnew)
        Cb = Cnew
        if diff <= eps:
            break
    else:
        warnings.warn("Did not converge.")

    if log:
        log_dict = {}
        log_dict['num_iter'] = it
        log_dict['final_diff'] = diff
        return mb, Cb, log_dict
    else:
        return mb, Cb
438+
439+
440+
def empirical_bures_wasserstein_barycenter(
    X, reg=1e-6, weights=None, num_iter=1000, eps=1e-7,
    w=None, bias=True, log=False
):
    r"""Return the Bures-Wasserstein barycenter estimated from samples.

    For each of the k sample sets, an empirical mean and a regularized
    empirical covariance are computed. The barycenter of the corresponding
    Gaussian distributions
    :math:`\left\{\mathcal{N}(\mu_i,\Sigma_i)\right\}_{i=1}^k` is then
    obtained with the fixed-point algorithm of
    :any:`bures_wasserstein_barycenter`
    :ref:`[1] <references-OT-mapping-linear-barycenter>`.

    The barycenter is still a Gaussian distribution
    :math:`\mathcal{N}(\mu_b,\Sigma_b)` where:

    .. math::
        \mu_b = \sum_{i=1}^k w_i \mu_i

    And the barycentric covariance is the solution of the following
    fixed-point equation:

    .. math::
        \Sigma_b = \sum_{i=1}^k w_i \left(\Sigma_b^{1/2}\Sigma_i\Sigma_b^{1/2}\right)^{1/2}

    Parameters
    ----------
    X : list of array-like (n,d)
        samples in each distribution
    reg : float, optional
        regularization added to the diagonals of covariances (>0)
    weights : array-like (k,), optional
        weights for each distribution
    num_iter : int, optional
        number of iteration for the fixed point algorithm
    eps : float, optional
        tolerance for the fixed point algorithm
    w : list of array-like (n,1), optional
        weights for each sample in each distribution (uniform if None).
        Each ``w[i]`` is multiplied with ``X[i]`` and therefore has to
        broadcast over its rows.
    bias : boolean, optional
        if True (default), each sample set is centered on its empirical
        mean and these means are used to compute the barycenter mean; if
        False, the samples are not centered and all means are set to zero
    log : bool, optional
        record log if True

    Returns
    -------
    mb : (d,) array-like
        mean of the barycenter
    Cb : (d, d) array-like
        covariance of the barycenter
    log : dict
        log dictionary return only if log==True in parameters


    .. _references-OT-mapping-linear-barycenter:
    References
    ----------
    .. [1] M. Agueh and G. Carlier, "Barycenters in the Wasserstein space",
        SIAM Journal on Mathematical Analysis, vol. 43, no. 2, pp. 904-924,
        2011.
    """
    X = list_to_array(*X)
    nx = get_backend(*X)

    k = len(X)
    d = [X[i].shape[1] for i in range(k)]

    # Means are kept as (d,) vectors so that, once stacked, they match the
    # (k, d) layout documented by bures_wasserstein_barycenter.
    if bias:
        m = [nx.mean(X[i], axis=0) for i in range(k)]
        X = [X[i] - m[i][None, :] for i in range(k)]
    else:
        m = [nx.zeros((d[i],), type_as=X[i]) for i in range(k)]

    if w is None:
        w = [nx.ones((X[i].shape[0], 1), type_as=X[i]) / X[i].shape[0] for i in range(k)]

    # Weighted empirical covariances, regularized on the diagonal
    C = [
        nx.dot((X[i] * w[i]).T, X[i]) / nx.sum(w[i]) + reg * nx.eye(d[i], type_as=X[i])
        for i in range(k)
    ]
    m = nx.stack(m, axis=0)
    C = nx.stack(C, axis=0)

    # The solver already returns (mb, Cb, log) or (mb, Cb) depending on `log`
    return bures_wasserstein_barycenter(
        m, C, weights=weights, num_iter=num_iter, eps=eps, log=log
    )
527+
528+
347529
def gaussian_gromov_wasserstein_distance(Cov_s, Cov_t, log=False):
348530
r""" Return the Gaussian Gromov-Wasserstein value from [57].
349531

test/test_gaussian.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,71 @@ def test_empirical_bures_wasserstein_distance(nx, bias):
108108
np.testing.assert_allclose(10 * bias, nx.to_numpy(Wb), rtol=1e-2, atol=1e-2)
109109

110110

111+
def test_bures_wasserstein_barycenter(nx):
    """Check the Gaussian barycenter solver on means/covariances of toy data.

    Verifies that results do not depend on the `log` flag, that explicit
    uniform weights match the default, and that diagonal covariances
    recover the closed form (mean of the square roots, squared).
    """
    n = 50
    k = 10
    m = []
    C = []
    for _ in range(k):
        X_ = make_data_classif('3gauss', n)[0]
        m.append(np.mean(X_, axis=0))
        C.append(np.cov(X_.T))
    # m is (k, d) and C is (k, d, d)
    m = nx.from_numpy(np.array(m))
    C = nx.from_numpy(np.array(C))

    mblog, Cblog, log = ot.gaussian.bures_wasserstein_barycenter(m, C, log=True)
    mb, Cb = ot.gaussian.bures_wasserstein_barycenter(m, C, log=False)

    # convert to numpy before comparing so that every backend is supported
    np.testing.assert_allclose(nx.to_numpy(Cb), nx.to_numpy(Cblog), rtol=1e-2, atol=1e-2)
    np.testing.assert_allclose(nx.to_numpy(mb), nx.to_numpy(mblog), rtol=1e-2, atol=1e-2)

    # Test weights argument
    weights = nx.ones(k, type_as=C) / k
    mbw, Cbw = ot.gaussian.bures_wasserstein_barycenter(m, C, weights=weights, log=False)
    np.testing.assert_allclose(nx.to_numpy(Cbw), nx.to_numpy(Cb), rtol=1e-2, atol=1e-2)

    # test with closed form for diagonal covariance matrices
    Cdiag = [nx.diag(nx.diag(C[i])) for i in range(k)]
    Cdiag = nx.stack(Cdiag, axis=0)
    mbdiag, Cbdiag = ot.gaussian.bures_wasserstein_barycenter(m, Cdiag, log=False)

    Cdiag_sqrt = [nx.sqrtm(Ci) for Ci in Cdiag]
    Cdiag_sqrt = nx.stack(Cdiag_sqrt, axis=0)
    Cdiag_mean = nx.mean(Cdiag_sqrt, axis=0)
    Cdiag_cf = Cdiag_mean @ Cdiag_mean

    np.testing.assert_allclose(nx.to_numpy(Cbdiag), nx.to_numpy(Cdiag_cf), rtol=1e-2, atol=1e-2)
154+
155+
156+
@pytest.mark.parametrize("bias", [True, False])
def test_empirical_bures_wasserstein_barycenter(nx, bias):
    """Check that the empirical barycenter does not depend on the `log` flag."""
    n = 50
    k = 10
    X = [make_data_classif('3gauss', n)[0] for _ in range(k)]

    X = nx.from_numpy(*X)

    mblog, Cblog, log = ot.gaussian.empirical_bures_wasserstein_barycenter(X, log=True, bias=bias)
    mb, Cb = ot.gaussian.empirical_bures_wasserstein_barycenter(X, log=False, bias=bias)

    # convert to numpy before comparing so that every backend is supported
    np.testing.assert_allclose(nx.to_numpy(Cb), nx.to_numpy(Cblog), rtol=1e-2, atol=1e-2)
    np.testing.assert_allclose(nx.to_numpy(mb), nx.to_numpy(mblog), rtol=1e-2, atol=1e-2)
174+
175+
111176
@pytest.mark.parametrize("d_target", [1, 2, 3, 10])
112177
def test_gaussian_gromov_wasserstein_distance(nx, d_target):
113178
ns = 400

0 commit comments

Comments
 (0)