From e1d815b0fb6e63cba61d4c208a138e85aa3b8821 Mon Sep 17 00:00:00 2001
From: tgnassou
Date: Wed, 29 Nov 2023 14:39:18 +0100
Subject: [PATCH 1/3] add barycenter functions

---
 ot/gaussian.py | 182 +++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 182 insertions(+)

diff --git a/ot/gaussian.py b/ot/gaussian.py
index 0ddb92013..a3ae90f87 100644
--- a/ot/gaussian.py
+++ b/ot/gaussian.py
@@ -344,6 +344,188 @@ def empirical_bures_wasserstein_distance(xs, xt, reg=1e-6, ws=None,
     return W
 
 
+def bures_wasserstein_barycenter(m, C, weights=None, num_iter=1000, eps=1e-7, log=False):
+    r"""Return the (Bures-)Wasserstein barycenter of Gaussian distributions.
+
+    The function estimates the Wasserstein barycenter of the Gaussian
+    distributions. This is equivalent to solving the fixed-point equation below
+    for the Gaussian distributions :math:`\left\{\mathcal{N}(\mu_i,\Sigma_i)\right\}_{i=1}^n`
+    :ref:`[1] <references-OT-mapping-linear-barycenter>`.
+
+    The barycenter still follows a Gaussian distribution :math:`\mathcal{N}(\mu_b,\Sigma_b)`
+    where :
+
+    .. math::
+        \mu_b = \sum_{i=1}^n w_i \mu_i
+
+    And the barycentric covariance is the solution of the following fixed-point algorithm:
+
+    .. math::
+        \Sigma_b = \sum_{i=1}^n w_i \left(\Sigma_b^{1/2}\Sigma_i\Sigma_b^{1/2}\right)^{1/2}
+
+
+    Parameters
+    ----------
+    m : array-like (k,d)
+        mean of k distributions
+    C : array-like (k,d,d)
+        covariance of k distributions
+    weights : array-like (k), optional
+        weights for each distribution (uniform weights if None)
+    num_iter : int, optional
+        maximum number of iterations for the fixed point algorithm
+    eps : float, optional
+        tolerance for the fixed point algorithm
+    log : bool, optional
+        record log if True
+
+
+    Returns
+    -------
+    mb : (d,) array-like
+        mean of the barycenter
+    Cb : (d, d) array-like
+        covariance of the barycenter
+    log : dict
+        log dictionary return only if log==True in parameters
+
+
+    .. _references-OT-mapping-linear-barycenter:
+    References
+    ----------
+    .. [1] M. Agueh and G. Carlier, "Barycenters in the Wasserstein space",
+        SIAM Journal on Mathematical Analysis, vol. 43, no. 2, pp. 904-924,
+        2011.
+    """
+    nx = get_backend(*C, *m,)
+
+    if weights is None:
+        weights = nx.ones(len(C), type_as=C[0]) / len(C)
+
+    # The barycenter mean is the weighted average of the k means, shape (d,)
+    mb = nx.sum(m * weights[:, None], axis=0)
+
+    # Init the covariance barycenter
+    Cb = nx.mean(C, axis=0)
+
+    for it in range(num_iter):
+        # fixed point update
+        Cb12 = nx.sqrtm(Cb)
+
+        Cnew = Cb12 @ C @ Cb12
+        C_ = []
+        for i in range(len(C)):
+            C_.append(nx.sqrtm(Cnew[i]))
+        Cnew = nx.stack(C_, axis=0)
+        Cnew *= weights[:, None, None]
+        Cnew = nx.sum(Cnew, axis=0)
+
+        # check convergence
+        diff = nx.norm(Cb - Cnew)
+        if diff <= eps:
+            break
+        Cb = Cnew
+    else:
+        print("Did not converge.")
+
+    if log:
+        log = {}
+        log['num_iter'] = it
+        log['final_diff'] = diff
+        return mb, Cb, log
+    else:
+        return mb, Cb
+
+
+def empirical_bures_wasserstein_barycenter(
+    X, reg=1e-6, weights=None, num_iter=1000, eps=1e-7,
+    w=None, bias=True, log=False
+):
+    r"""Return the (Bures-)Wasserstein barycenter of empirical distributions.
+
+    The function estimates the Wasserstein barycenter of the Gaussian
+    approximations of the empirical distributions. This is equivalent to solving the
+    fixed-point equation below for the Gaussian distributions :math:`\left\{\mathcal{N}(\mu_i,\Sigma_i)\right\}_{i=1}^n`
+    :ref:`[1] <references-OT-mapping-linear-barycenter>`.
+
+    The barycenter still follows a Gaussian distribution :math:`\mathcal{N}(\mu_b,\Sigma_b)`
+    where :
+
+    .. math::
+        \mu_b = \sum_{i=1}^n w_i \mu_i
+
+    And the barycentric covariance is the solution of the following fixed-point algorithm:
+
+    .. math::
+        \Sigma_b = \sum_{i=1}^n w_i \left(\Sigma_b^{1/2}\Sigma_i\Sigma_b^{1/2}\right)^{1/2}
+
+
+    Parameters
+    ----------
+    X : list of k array-like (n,d)
+        samples in each distribution
+    reg : float,optional
+        regularization added to the diagonals of covariances (>0)
+    weights : array-like (k,), optional
+        weights for each distribution
+    num_iter : int, optional
+        maximum number of iterations for the fixed point algorithm
+    eps : float, optional
+        tolerance for the fixed point algorithm
+    w : list of array-like (n,1), optional
+        weights for each sample in each distribution
+    bias: boolean, optional
+        estimate the mean of each distribution from its samples, else use zero means (default:True)
+    log : bool, optional
+        record log if True
+
+
+    Returns
+    -------
+    mb : (d,) array-like
+        mean of the barycenter
+    Cb : (d, d) array-like
+        covariance of the barycenter
+    log : dict
+        log dictionary return only if log==True in parameters
+
+
+    .. _references-OT-mapping-linear-barycenter:
+    References
+    ----------
+    .. [1] M. Agueh and G. Carlier, "Barycenters in the Wasserstein space",
+        SIAM Journal on Mathematical Analysis, vol. 43, no. 2, pp. 904-924,
+        2011.
+    """
+    X = list_to_array(*X)
+    nx = get_backend(*X)
+
+    k = len(X)
+    d = [X[i].shape[1] for i in range(k)]
+
+    if bias:
+        m = [nx.mean(X[i], axis=0) for i in range(k)]
+        X = [X[i] - m[i][None, :] for i in range(k)]
+    else:
+        m = [nx.zeros((d[i],), type_as=X[i]) for i in range(k)]
+
+    if w is None:
+        w = [nx.ones((X[i].shape[0], 1), type_as=X[i]) / X[i].shape[0] for i in range(k)]
+
+    C = [
+        nx.dot((X[i] * w[i]).T, X[i]) / nx.sum(w[i]) + reg * nx.eye(d[i], type_as=X[i])
+        for i in range(k)
+    ]
+    m = nx.stack(m, axis=0)
+    C = nx.stack(C, axis=0)
+    if log:
+        mb, Cb, log = bures_wasserstein_barycenter(m, C, weights=weights, num_iter=num_iter, eps=eps, log=log)
+        return mb, Cb, log
+    else:
+        mb, Cb = bures_wasserstein_barycenter(m, C, weights=weights, num_iter=num_iter, eps=eps, log=log)
+        return mb, Cb
+
+
 def gaussian_gromov_wasserstein_distance(Cov_s, Cov_t, log=False):
     r""" Return the Gaussian Gromov-Wasserstein value from [57].
 
From d35a020ab3f79e9ad5fbba8b2c0a9c851bcd551d Mon Sep 17 00:00:00 2001
From: tgnassou
Date: Wed, 29 Nov 2023 14:39:46 +0100
Subject: [PATCH 2/3] add tests

---
 test/test_gaussian.py | 69 +++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 69 insertions(+)

diff --git a/test/test_gaussian.py b/test/test_gaussian.py
index 02b5bbe86..c66d5908c 100644
--- a/test/test_gaussian.py
+++ b/test/test_gaussian.py
@@ -108,6 +108,75 @@ def test_empirical_bures_wasserstein_distance(nx, bias):
     np.testing.assert_allclose(10 * bias, nx.to_numpy(Wb), rtol=1e-2, atol=1e-2)
 
 
+def test_bures_wasserstein_barycenter(nx):
+    n = 50
+    k = 10
+    X = []
+    y = []
+    m = []
+    C = []
+    for _ in range(k):
+        X_, y_ = make_data_classif('3gauss', n)
+        m_ = np.mean(X_, axis=0)
+        C_ = np.cov(X_.T)
+        X.append(X_)
+        y.append(y_)
+        m.append(m_)
+        C.append(C_)
+    m = np.array(m)
+    C = np.array(C)
+    X = nx.from_numpy(*X)
+    m = nx.from_numpy(m)
+    C = nx.from_numpy(C)
+
+    mblog, Cblog, log = ot.gaussian.bures_wasserstein_barycenter(m, C, log=True)
+    mb, Cb = ot.gaussian.bures_wasserstein_barycenter(m, C, log=False)
+
+    np.testing.assert_allclose(Cb, Cblog, rtol=1e-2, atol=1e-2)
+    np.testing.assert_allclose(mb, mblog, rtol=1e-2, atol=1e-2)
+
+    # with uniform weights the barycenter mean is the average of the means
+    mb_expected = np.mean(nx.to_numpy(m), axis=0)
+    np.testing.assert_allclose(nx.to_numpy(mb), mb_expected, rtol=1e-2, atol=1e-2)
+
+    # Test weights argument
+    weights = nx.ones(k) / k
+    mbw, Cbw = ot.gaussian.bures_wasserstein_barycenter(m, C, weights=weights, log=False)
+    np.testing.assert_allclose(Cbw, Cb, rtol=1e-2, atol=1e-2)
+
+    # test with closed form for diagonal covariance matrices
+    Cdiag = [nx.diag(nx.diag(C[i])) for i in range(k)]
+    Cdiag = nx.stack(Cdiag, axis=0)
+    mbdiag, Cbdiag = ot.gaussian.bures_wasserstein_barycenter(m, Cdiag, log=False)
+
+    Cdiag_sqrt = [nx.sqrtm(C) for C in Cdiag]
+    Cdiag_sqrt = nx.stack(Cdiag_sqrt, axis=0)
+    Cdiag_mean = nx.mean(Cdiag_sqrt, axis=0)
+    Cdiag_cf = Cdiag_mean @ Cdiag_mean
+
+    np.testing.assert_allclose(Cbdiag, Cdiag_cf, rtol=1e-2, atol=1e-2)
+
+
+@pytest.mark.parametrize("bias", [True, False])
+def test_empirical_bures_wasserstein_barycenter(nx, bias):
+    n = 50
+    k = 10
+    X = []
+    y = []
+    for _ in range(k):
+        X_, y_ = make_data_classif('3gauss', n)
+        X.append(X_)
+        y.append(y_)
+
+    X = nx.from_numpy(*X)
+
+    mblog, Cblog, log = ot.gaussian.empirical_bures_wasserstein_barycenter(X, log=True, bias=bias)
+    mb, Cb = ot.gaussian.empirical_bures_wasserstein_barycenter(X, log=False, bias=bias)
+
+    np.testing.assert_allclose(Cb, Cblog, rtol=1e-2, atol=1e-2)
+    np.testing.assert_allclose(mb, mblog, rtol=1e-2, atol=1e-2)
+
+
 @pytest.mark.parametrize("d_target", [1, 2, 3, 10])
 def test_gaussian_gromov_wasserstein_distance(nx, d_target):
     ns = 400
From 1dbf96f79a32b909af3b590847bb0bdadb9bba21 Mon Sep 17 00:00:00 2001
From: tgnassou
Date: Wed, 29 Nov 2023 14:42:48 +0100
Subject: [PATCH 3/3] update release

---
 RELEASES.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/RELEASES.md b/RELEASES.md
index 349c56214..5e97c1d5a 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -20,6 +20,7 @@
 + Wrapper for `geomloss`` solver on empirical samples (PR #571)
 + Add `stop_criterion` feature to (un)regularized (f)gw barycenter solvers (PR #578)
 + Add `fixed_structure` and `fixed_features` to entropic fgw barycenter solver (PR #578)
++ Add Bures-Wasserstein barycenter in `ot.gaussian` (PR #582)
 
 #### Closed issues
 - Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504)