From 7ee082deda902ae6782d4c8fa2ae65cc0357cff4 Mon Sep 17 00:00:00 2001 From: Oleksii Kachaiev Date: Sat, 26 Aug 2023 23:22:34 +0200 Subject: [PATCH 1/5] Fix docstring for loss_fun parameter --- ot/gromov/_bregman.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ot/gromov/_bregman.py b/ot/gromov/_bregman.py index 0fef338f5..c4db0433e 100644 --- a/ot/gromov/_bregman.py +++ b/ot/gromov/_bregman.py @@ -373,8 +373,8 @@ def entropic_gromov_barycenters( lambdas : list of float, optional List of the `S` spaces' weights. If let to its default value None, uniform weights are taken. - loss_fun : callable, optional - tensor-matrix multiplication function based on specific loss function + loss_fun : string, optional (default='square_loss') + Loss function used for the solver either 'square_loss' or 'kl_loss' epsilon : float, optional Regularization term >0 symmetric : bool, optional. From a3fd8832517f38a6607df5e35fbc12ab24cf742c Mon Sep 17 00:00:00 2001 From: Oleksii Kachaiev Date: Sat, 26 Aug 2023 23:27:12 +0200 Subject: [PATCH 2/5] Explicitly check for correct loss_fun value --- ot/gromov/_bregman.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/ot/gromov/_bregman.py b/ot/gromov/_bregman.py index c4db0433e..d0b813caa 100644 --- a/ot/gromov/_bregman.py +++ b/ot/gromov/_bregman.py @@ -81,7 +81,7 @@ def entropic_gromov_wasserstein( q : array-like, shape (nt,), optional Distribution in the target space. If let to its default value None, uniform distribution is taken. - loss_fun : string, optional + loss_fun : string, optional (default='square_loss') Loss function used for the solver either 'square_loss' or 'kl_loss' epsilon : float, optional Regularization term >0 @@ -135,6 +135,9 @@ def entropic_gromov_wasserstein( if solver not in ['PGD', 'PPA']: raise ValueError("Unknown solver '%s'. Pick one in ['PGD', 'PPA']." % solver) + if loss_fun not in ('square_loss', 'kl_loss'): + raise ValueError(f"Unknown `loss_fun='{loss_fun}'`. Use one of: {'square_loss', 'kl_loss'}.") + C1, C2 = list_to_array(C1, C2) arr = [C1, C2] if p is not None: @@ -411,6 +414,9 @@ def entropic_gromov_barycenters( "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016. """ + if loss_fun not in ('square_loss', 'kl_loss'): + raise ValueError(f"Unknown `loss_fun='{loss_fun}'`. Use one of: {'square_loss', 'kl_loss'}.") + Cs = list_to_array(*Cs) arr = [*Cs] if ps is not None: @@ -459,7 +465,6 @@ def entropic_gromov_barycenters( if loss_fun == 'square_loss': C = update_square_loss(p, lambdas, T, Cs) - elif loss_fun == 'kl_loss': C = update_kl_loss(p, lambdas, T, Cs) From 61a6849b71ec83f4eb8cd2308c6ecfd9faa4b975 Mon Sep 17 00:00:00 2001 From: Oleksii Kachaiev Date: Sat, 26 Aug 2023 23:27:36 +0200 Subject: [PATCH 3/5] Typo --- ot/gromov/_bregman.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ot/gromov/_bregman.py b/ot/gromov/_bregman.py index d0b813caa..2cd0282fe 100644 --- a/ot/gromov/_bregman.py +++ b/ot/gromov/_bregman.py @@ -93,7 +93,7 @@ def entropic_gromov_wasserstein( If None the initial transport plan of the solver is pq^T. Otherwise G0 will be used as initial transport of the solver. G0 is not required to satisfy marginal constraints but we strongly recommand it - to correcly estimate the GW distance. + to correctly estimate the GW distance. max_iter : int, optional Max number of iterations tol : float, optional From db76ee18f104a9d979217376fa5da91eb27a6818 Mon Sep 17 00:00:00 2001 From: Oleksii Kachaiev Date: Sat, 26 Aug 2023 23:33:36 +0200 Subject: [PATCH 4/5] Update tests to cover unknown loss_fun --- test/test_gromov.py | 44 +++++++++++++++++++++++++++++++++++++++----- 1 file changed, 39 insertions(+), 5 deletions(-) diff --git a/test/test_gromov.py b/test/test_gromov.py index 8104ad03f..559d07855 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -311,7 +311,12 @@ def test_gw_helper_validation(loss_fun): @pytest.skip_backend("jax", reason="test very slow with jax backend") @pytest.skip_backend("tf", reason="test very slow with tf backend") -def test_entropic_gromov(nx): +@pytest.mark.parametrize('loss_fun', [ + 'square_loss', + 'kl_loss', + pytest.param('unknown_loss', marks=pytest.mark.xfail(raises=ValueError)), +]) +def test_entropic_gromov(nx, loss_fun): n_samples = 10 # nb samples mu_s = np.array([0, 0]) @@ -333,10 +338,10 @@ def test_entropic_gromov(nx): C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) G, log = ot.gromov.entropic_gromov_wasserstein( - C1, C2, None, q, 'square_loss', symmetric=None, G0=G0, + C1, C2, None, q, loss_fun, symmetric=None, G0=G0, epsilon=1e-2, max_iter=10, verbose=True, log=True) Gb = nx.to_numpy(ot.gromov.entropic_gromov_wasserstein( - C1b, C2b, pb, None, 'square_loss', symmetric=True, G0=None, + C1b, C2b, pb, None, loss_fun, symmetric=True, G0=None, epsilon=1e-2, max_iter=10, verbose=True, log=False )) @@ -347,11 +352,40 @@ def test_entropic_gromov(nx): np.testing.assert_allclose( q, Gb.sum(0), atol=1e-04) # cf convergence gromov + +@pytest.skip_backend("jax", reason="test very slow with jax backend") +@pytest.skip_backend("tf", reason="test very slow with tf backend") +@pytest.mark.parametrize('loss_fun', [ + 'square_loss', + 'kl_loss', + pytest.param('unknown_loss', marks=pytest.mark.xfail(raises=ValueError)), +]) +def test_entropic_gromov2(nx, loss_fun): + n_samples = 10 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42) + + xt = xs[::-1].copy() + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + G0 = p[:, None] * q[None, :] + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) + gw, log = ot.gromov.entropic_gromov_wasserstein2( - C1, C2, p, None, 'kl_loss', symmetric=True, G0=None, + C1, C2, p, None, loss_fun, symmetric=True, G0=None, max_iter=10, epsilon=1e-2, log=True) gwb, logb = ot.gromov.entropic_gromov_wasserstein2( - C1b, C2b, None, qb, 'kl_loss', symmetric=None, G0=G0b, + C1b, C2b, None, qb, loss_fun, symmetric=None, G0=G0b, max_iter=10, epsilon=1e-2, log=True) gwb = nx.to_numpy(gwb) From 2a74966783a343a899e0868363495f81a21a21df Mon Sep 17 00:00:00 2001 From: Oleksii Kachaiev Date: Sat, 26 Aug 2023 23:49:51 +0200 Subject: [PATCH 5/5] Unify docstring for other functions as well --- ot/gromov/_bregman.py | 33 +++++++++++++++++++-------------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/ot/gromov/_bregman.py b/ot/gromov/_bregman.py index 2cd0282fe..6dc705949 100644 --- a/ot/gromov/_bregman.py +++ b/ot/gromov/_bregman.py @@ -81,7 +81,7 @@ def entropic_gromov_wasserstein( q : array-like, shape (nt,), optional Distribution in the target space. If let to its default value None, uniform distribution is taken. - loss_fun : string, optional (default='square_loss') + loss_fun : string, optional (default='square_loss') Loss function used for the solver either 'square_loss' or 'kl_loss' epsilon : float, optional Regularization term >0 @@ -92,7 +92,7 @@ def entropic_gromov_wasserstein( G0: array-like, shape (ns,nt), optional If None the initial transport plan of the solver is pq^T. Otherwise G0 will be used as initial transport of the solver. G0 is not - required to satisfy marginal constraints but we strongly recommand it + required to satisfy marginal constraints but we strongly recommend it to correctly estimate the GW distance. max_iter : int, optional Max number of iterations @@ -283,7 +283,7 @@ def entropic_gromov_wasserstein2( q : array-like, shape (nt,), optional Distribution in the target space. If let to its default value None, uniform distribution is taken. - loss_fun : string, optional + loss_fun : string, optional (default='square_loss') Loss function used for the solver either 'square_loss' or 'kl_loss' epsilon : float, optional Regularization term >0 @@ -376,7 +376,7 @@ def entropic_gromov_barycenters( lambdas : list of float, optional List of the `S` spaces' weights. If let to its default value None, uniform weights are taken. - loss_fun : string, optional (default='square_loss') + loss_fun : string, optional (default='square_loss') Loss function used for the solver either 'square_loss' or 'kl_loss' epsilon : float, optional Regularization term >0 @@ -555,21 +555,21 @@ def entropic_fused_gromov_wasserstein( q : array-like, shape (nt,), optional Distribution in the target space. If let to its default value None, uniform distribution is taken. - loss_fun : string, optional + loss_fun : string, optional (default='square_loss') Loss function used for the solver either 'square_loss' or 'kl_loss' epsilon : float, optional Regularization term >0 symmetric : bool, optional Either C1 and C2 are to be assumed symmetric or not. If let to its default None value, a symmetry test will be conducted. - Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymetric). + Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric). alpha : float, optional Trade-off parameter (0 < alpha < 1) G0: array-like, shape (ns,nt), optional If None the initial transport plan of the solver is pq^T. Otherwise G0 will be used as initial transport of the solver. G0 is not - required to satisfy marginal constraints but we strongly recommand it - to correcly estimate the GW distance. + required to satisfy marginal constraints but we strongly recommend it + to correctly estimate the GW distance. max_iter : int, optional Max number of iterations tol : float, optional @@ -616,6 +616,9 @@ def entropic_fused_gromov_wasserstein( if solver not in ['PGD', 'PPA']: raise ValueError("Unknown solver '%s'. Pick one in ['PGD', 'PPA']." % solver) + if loss_fun not in ('square_loss', 'kl_loss'): + raise ValueError(f"Unknown `loss_fun='{loss_fun}'`. Use one of: {'square_loss', 'kl_loss'}.") + M, C1, C2 = list_to_array(M, C1, C2) arr = [M, C1, C2] if p is not None: @@ -767,7 +770,7 @@ def entropic_fused_gromov_wasserstein2( q : array-like, shape (nt,), optional Distribution in the target space. If let to its default value None, uniform distribution is taken. - loss_fun : string, optional + loss_fun : string, optional (default='square_loss') Loss function used for the solver either 'square_loss' or 'kl_loss' epsilon : float, optional Regularization term >0 @@ -780,8 +783,8 @@ def entropic_fused_gromov_wasserstein2( G0: array-like, shape (ns,nt), optional If None the initial transport plan of the solver is pq^T. Otherwise G0 will be used as initial transport of the solver. G0 is not - required to satisfy marginal constraints but we strongly recommand it - to correcly estimate the GW distance. + required to satisfy marginal constraints but we strongly recommend it + to correctly estimate the GW distance. max_iter : int, optional Max number of iterations tol : float, optional @@ -862,8 +865,8 @@ def entropic_fused_gromov_barycenters( lambdas : list of float, optional List of the `S` spaces' weights. If let to its default value None, uniform weights are taken. - loss_fun : callable, optional - tensor-matrix multiplication function based on specific loss function + loss_fun : string, optional (default='square_loss') + Loss function used for the solver either 'square_loss' or 'kl_loss' epsilon : float, optional Regularization term >0 symmetric : bool, optional. @@ -912,6 +915,9 @@ def entropic_fused_gromov_barycenters( "Optimal Transport for structured data with application on graphs" International Conference on Machine Learning (ICML). 2019. """ + if loss_fun not in ('square_loss', 'kl_loss'): + raise ValueError(f"Unknown `loss_fun='{loss_fun}'`. Use one of: {'square_loss', 'kl_loss'}.") + Cs = list_to_array(*Cs) Ys = list_to_array(*Ys) arr = [*Cs, *Ys] @@ -982,7 +988,6 @@ def entropic_fused_gromov_barycenters( if loss_fun == 'square_loss': C = update_square_loss(p, lambdas, T, Cs) - elif loss_fun == 'kl_loss': C = update_kl_loss(p, lambdas, T, Cs)