Skip to content

[MRG] Add Wasserstein barycenter for Gaussian distribution #582

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
+ Wrapper for `geomloss` solver on empirical samples (PR #571)
+ Add `stop_criterion` feature to (un)regularized (f)gw barycenter solvers (PR #578)
+ Add `fixed_structure` and `fixed_features` to entropic fgw barycenter solver (PR #578)
+ Add Bures-Wasserstein barycenter in `ot.gaussian` (PR #582)

#### Closed issues
- Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504)
Expand Down
182 changes: 182 additions & 0 deletions ot/gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,188 @@ def empirical_bures_wasserstein_distance(xs, xt, reg=1e-6, ws=None,
return W


def bures_wasserstein_barycenter(m, C, weights=None, num_iter=1000, eps=1e-7, log=False):
    r"""Return the Bures-Wasserstein barycenter of Gaussian distributions.

    The function estimates the Wasserstein barycenter of :math:`k` Gaussian
    distributions :math:`\left\{\mathcal{N}(\mu_i,\Sigma_i)\right\}_{i=1}^k`
    with a fixed-point iteration on the covariance
    :ref:`[1] <references-bures-wasserstein-barycenter>`.

    The barycenter is still a Gaussian distribution
    :math:`\mathcal{N}(\mu_b,\Sigma_b)` where:

    .. math::
        \mu_b = \sum_{i=1}^k w_i \mu_i

    and the barycentric covariance is the solution of the following
    fixed-point equation:

    .. math::
        \Sigma_b = \sum_{i=1}^k w_i \left(\Sigma_b^{1/2}\Sigma_i\Sigma_b^{1/2}\right)^{1/2}


    Parameters
    ----------
    m : array-like (k,d)
        mean of k distributions (a (k,1,d) array is also accepted, each mean
        is flattened before averaging)
    C : array-like (k,d,d)
        covariance of k distributions
    weights : array-like (k,), optional
        weights for each distribution (uniform weights if None)
    num_iter : int, optional
        maximum number of iterations of the fixed-point algorithm
    eps : float, optional
        tolerance on the norm of the difference between two successive
        covariance iterates
    log : bool, optional
        record log if True


    Returns
    -------
    mb : (d,) array-like
        mean of the barycenter
    Cb : (d, d) array-like
        covariance of the barycenter
    log : dict
        log dictionary returned only if log==True in parameters. It contains
        ``'num_iter'`` (0-based index of the last iteration performed, -1 if
        none) and ``'final_diff'`` (last difference norm, None if no
        iteration was performed).


    .. _references-bures-wasserstein-barycenter:
    References
    ----------
    .. [1] M. Agueh and G. Carlier, "Barycenters in the Wasserstein space",
        SIAM Journal on Mathematical Analysis, vol. 43, no. 2, pp. 904-924,
        2011.
    """
    import warnings

    nx = get_backend(*C, *m,)

    k = len(C)
    if weights is None:
        weights = nx.ones(k, type_as=C[0]) / k

    # Weighted average of the means. Each mean is flattened first so that
    # both (k, d) and (k, 1, d) inputs give a (d,) result.
    mb = nx.sum(nx.reshape(m, (k, -1)) * weights[:, None], axis=0)

    # Initialize the covariance with the plain average of the covariances
    Cb = nx.mean(C, axis=0)

    # Defined up front so that the log is valid even when num_iter == 0
    it, diff = -1, None
    for it in range(num_iter):
        # fixed point update
        Cb12 = nx.sqrtm(Cb)

        # (d, d) @ (k, d, d) @ (d, d) broadcasts over the k covariances
        Cnew = Cb12 @ C @ Cb12
        Cnew = nx.stack([nx.sqrtm(Cnew[i]) for i in range(k)], axis=0)
        Cnew = nx.sum(Cnew * weights[:, None, None], axis=0)

        # check convergence; always keep the most recent iterate
        diff = nx.norm(Cb - Cnew)
        Cb = Cnew
        if diff <= eps:
            break
    else:
        # the loop ran out of iterations without reaching the tolerance
        warnings.warn("Did not converge.", stacklevel=2)

    if log:
        log = {}
        log['num_iter'] = it
        log['final_diff'] = diff
        return mb, Cb, log
    else:
        return mb, Cb


def empirical_bures_wasserstein_barycenter(
    X, reg=1e-6, weights=None, num_iter=1000, eps=1e-7,
    w=None, bias=True, log=False
):
    r"""Return the Bures-Wasserstein barycenter estimated from samples.

    A mean and a regularized covariance are estimated from each set of
    samples, and the barycenter of the corresponding Gaussian distributions
    :math:`\left\{\mathcal{N}(\mu_i,\Sigma_i)\right\}_{i=1}^k` is then
    computed with :any:`bures_wasserstein_barycenter`
    :ref:`[1] <references-empirical-bures-wasserstein-barycenter>`.

    The barycenter is still a Gaussian distribution
    :math:`\mathcal{N}(\mu_b,\Sigma_b)` where:

    .. math::
        \mu_b = \sum_{i=1}^k w_i \mu_i

    and the barycentric covariance is the solution of the following
    fixed-point equation:

    .. math::
        \Sigma_b = \sum_{i=1}^k w_i \left(\Sigma_b^{1/2}\Sigma_i\Sigma_b^{1/2}\right)^{1/2}


    Parameters
    ----------
    X : list of array-like (n,d)
        samples in each distribution
    reg : float,optional
        regularization added to the diagonals of covariances (>0)
    weights : array-like (k,), optional
        weights for each distribution
    num_iter : int, optional
        number of iteration for the fixed point algorithm
    eps : float, optional
        tolerance for the fixed point algorithm
    w : list of array-like (n,) or (n,1), optional
        weights for each sample in each distribution, used for the
        covariance estimates (uniform weights if None)
    bias: boolean, optional
        if True, each set of samples is centered on its empirical mean,
        else the means are taken equal to 0 and the samples are left
        uncentered (default:True)
    log : bool, optional
        record log if True


    Returns
    -------
    mb : array-like
        mean of the barycenter, as returned by
        :any:`bures_wasserstein_barycenter`
    Cb : (d, d) array-like
        covariance of the barycenter
    log : dict
        log dictionary return only if log==True in parameters


    .. _references-empirical-bures-wasserstein-barycenter:
    References
    ----------
    .. [1] M. Agueh and G. Carlier, "Barycenters in the Wasserstein space",
        SIAM Journal on Mathematical Analysis, vol. 43, no. 2, pp. 904-924,
        2011.
    """
    X = list_to_array(*X)
    nx = get_backend(*X)

    k = len(X)
    d = [x.shape[1] for x in X]

    if bias:
        m = [nx.mean(x, axis=0)[None, :] for x in X]
        X = [x - mean for x, mean in zip(X, m)]
    else:
        m = [nx.zeros((1, d[i]), type_as=X[i]) for i in range(k)]

    if w is None:
        w = [nx.ones((x.shape[0], 1), type_as=x) / x.shape[0] for x in X]
    else:
        # Column vectors are needed to scale the rows of X[i]; reshaping
        # lets callers pass the documented (n,) weights as well as (n, 1).
        w = [nx.reshape(w[i], (-1, 1)) for i in range(k)]

    # Weighted empirical covariances, regularized on the diagonal
    C = [
        nx.dot((X[i] * w[i]).T, X[i]) / nx.sum(w[i]) + reg * nx.eye(d[i], type_as=X[i])
        for i in range(k)
    ]
    m = nx.stack(m, axis=0)
    C = nx.stack(C, axis=0)

    # Returns (mb, Cb) or (mb, Cb, log) depending on `log`
    return bures_wasserstein_barycenter(
        m, C, weights=weights, num_iter=num_iter, eps=eps, log=log
    )


def gaussian_gromov_wasserstein_distance(Cov_s, Cov_t, log=False):
r""" Return the Gaussian Gromov-Wasserstein value from [57].

Expand Down
65 changes: 65 additions & 0 deletions test/test_gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,71 @@ def test_empirical_bures_wasserstein_distance(nx, bias):
np.testing.assert_allclose(10 * bias, nx.to_numpy(Wb), rtol=1e-2, atol=1e-2)


def test_bures_wasserstein_barycenter(nx):
    """Check log/no-log consistency, the weights argument and the
    closed form for diagonal covariances."""
    n = 50
    k = 10
    m = []
    C = []
    for _ in range(k):
        X_ = make_data_classif('3gauss', n)[0]
        m.append(np.mean(X_, axis=0))
        C.append(np.cov(X_.T))
    # m has the documented (k, d) shape and C the (k, d, d) shape
    m = nx.from_numpy(np.array(m))
    C = nx.from_numpy(np.array(C))

    mblog, Cblog, log = ot.gaussian.bures_wasserstein_barycenter(m, C, log=True)
    mb, Cb = ot.gaussian.bures_wasserstein_barycenter(m, C, log=False)

    # Convert to numpy before comparing so that every backend is supported
    np.testing.assert_allclose(nx.to_numpy(Cb), nx.to_numpy(Cblog), rtol=1e-2, atol=1e-2)
    np.testing.assert_allclose(nx.to_numpy(mb), nx.to_numpy(mblog), rtol=1e-2, atol=1e-2)

    # Test weights argument: explicit uniform weights must match the default
    weights = nx.ones(k, type_as=C) / k
    mbw, Cbw = ot.gaussian.bures_wasserstein_barycenter(m, C, weights=weights, log=False)
    np.testing.assert_allclose(nx.to_numpy(Cbw), nx.to_numpy(Cb), rtol=1e-2, atol=1e-2)

    # test with closed form for diagonal covariance matrices:
    # the barycenter covariance is the square of the mean of the square roots
    Cdiag = nx.stack([nx.diag(nx.diag(C[i])) for i in range(k)], axis=0)
    mbdiag, Cbdiag = ot.gaussian.bures_wasserstein_barycenter(m, Cdiag, log=False)

    Cdiag_sqrt = nx.stack([nx.sqrtm(Ci) for Ci in Cdiag], axis=0)
    Cdiag_mean = nx.mean(Cdiag_sqrt, axis=0)
    Cdiag_cf = Cdiag_mean @ Cdiag_mean

    np.testing.assert_allclose(nx.to_numpy(Cbdiag), nx.to_numpy(Cdiag_cf), rtol=1e-2, atol=1e-2)


@pytest.mark.parametrize("bias", [True, False])
def test_empirical_bures_wasserstein_barycenter(nx, bias):
    """Check that the empirical barycenter is identical with and without log."""
    n = 50
    k = 10
    X = [make_data_classif('3gauss', n)[0] for _ in range(k)]

    X = nx.from_numpy(*X)

    mblog, Cblog, log = ot.gaussian.empirical_bures_wasserstein_barycenter(X, log=True, bias=bias)
    mb, Cb = ot.gaussian.empirical_bures_wasserstein_barycenter(X, log=False, bias=bias)

    # Convert to numpy before comparing so that every backend is supported
    np.testing.assert_allclose(nx.to_numpy(Cb), nx.to_numpy(Cblog), rtol=1e-2, atol=1e-2)
    np.testing.assert_allclose(nx.to_numpy(mb), nx.to_numpy(mblog), rtol=1e-2, atol=1e-2)


@pytest.mark.parametrize("d_target", [1, 2, 3, 10])
def test_gaussian_gromov_wasserstein_distance(nx, d_target):
ns = 400
Expand Down