From 9f0e69c97b88d3a1ea01157d8ba84ddf3dae275d Mon Sep 17 00:00:00 2001 From: Oleksii Kachaiev Date: Fri, 25 Aug 2023 21:52:16 +0200 Subject: [PATCH 1/3] loss_fun validation in GW utils --- ot/gromov/_utils.py | 8 ++++++++ test/test_gromov.py | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+) diff --git a/ot/gromov/_utils.py b/ot/gromov/_utils.py index 0b8bb00fb..1b391d531 100644 --- a/ot/gromov/_utils.py +++ b/ot/gromov/_utils.py @@ -72,6 +72,7 @@ def init_matrix(C1, C2, p, q, loss_fun='square_loss', nx=None): Name of loss function to use: either 'square_loss' or 'kl_loss' (default='square_loss') nx : backend, optional If let to its default value None, a backend test will be conducted. + Returns ------- constC : array-like, shape (ns, nt) @@ -118,6 +119,8 @@ def h1(a): def h2(b): return nx.log(b + 1e-15) + else: + raise ValueError(f"Unknown `loss_fun='{loss_fun}'`. Use one of: {'square_loss', 'kl_loss'}.") constC1 = nx.dot( nx.dot(f1(C1), nx.reshape(p, (-1, 1))), @@ -407,6 +410,7 @@ def init_matrix_semirelaxed(C1, C2, p, loss_fun='square_loss', nx=None): p : array-like, shape (ns,) nx : backend, optional If let to its default value None, a backend test will be conducted. + Returns ------- constC : array-like, shape (ns, nt) @@ -446,6 +450,10 @@ def h1(a): def h2(b): return 2 * b + elif loss_fun == 'kl_loss': + raise NotImplementedError() + else: + raise ValueError(f"Unknown `loss_fun='{loss_fun}'`. Only 'square_loss' is supported.") constC = nx.dot(nx.dot(f1(C1), nx.reshape(p, (-1, 1))), nx.ones((1, C2.shape[0]), type_as=p)) diff --git a/test/test_gromov.py b/test/test_gromov.py index be4f659fe..af451c90d 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -291,6 +291,24 @@ def line_search(cost, G, deltaG, Mi, cost_G): np.testing.assert_allclose(res, Gb, atol=1e-06) +@pytest.mark.parametrize('loss_fun', [ + 'square_loss', + 'kl_loss', + pytest.param('unknown_loss', marks=pytest.mark.xfail(raises=ValueError)), +]) +def test_gw_helper_validation(loss_fun): + n_samples = 20 # nb samples + mu = np.array([0, 0]) + cov = np.array([[1, 0], [0, 1]]) + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=0) + xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=1) + p = ot.unif(n_samples) + q = ot.unif(n_samples) + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + ot.gromov.init_matrix(C1, C2, p, q, loss_fun=loss_fun) + + @pytest.skip_backend("jax", reason="test very slow with jax backend") @pytest.skip_backend("tf", reason="test very slow with tf backend") def test_entropic_gromov(nx): @@ -2026,6 +2044,24 @@ def line_search(cost, G, deltaG, Mi, cost_G): np.testing.assert_allclose(res, Gb, atol=1e-06) +@pytest.mark.parametrize('loss_fun', [ + 'square_loss', + pytest.param('kl_loss', marks=pytest.mark.xfail(raises=NotImplementedError)), + pytest.param('unknown_loss', marks=pytest.mark.xfail(raises=ValueError)), +]) +def test_gw_semirelaxed_helper_validation(loss_fun): + n_samples = 20 # nb samples + mu = np.array([0, 0]) + cov = np.array([[1, 0], [0, 1]]) + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=0) + xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=1) + p = ot.unif(n_samples) + q = ot.unif(n_samples) + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + ot.gromov.init_matrix_semirelaxed(C1, C2, p, loss_fun=loss_fun) + + def test_semirelaxed_fgw(nx): rng = np.random.RandomState(0) list_n = [16, 8] From 9f06a34b2362b7c75e6b999696fb84af2f73afc7 Mon Sep 17 00:00:00 2001 From: Oleksii Kachaiev Date: Fri, 25 Aug 2023 21:58:01 +0200 Subject: [PATCH 2/3] Fix docstring for semirelaxed use case --- ot/gromov/_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ot/gromov/_utils.py b/ot/gromov/_utils.py index 1b391d531..d77e44f9e 100644 --- a/ot/gromov/_utils.py +++ b/ot/gromov/_utils.py @@ -405,9 +405,9 @@ def init_matrix_semirelaxed(C1, C2, p, loss_fun='square_loss', nx=None): Metric cost matrix in the source space C2 : array-like, shape (nt, nt) Metric cost matrix in the target space - T : array-like, shape (ns, nt) - Coupling between source and target spaces p : array-like, shape (ns,) + loss_fun : str, optional + Name of loss function to use: either 'square_loss' or 'kl_loss' (default='square_loss') nx : backend, optional If let to its default value None, a backend test will be conducted. From 3f873169c41c15c73893be078a5a0cad2313add3 Mon Sep 17 00:00:00 2001 From: Oleksii Kachaiev Date: Fri, 25 Aug 2023 22:22:24 +0200 Subject: [PATCH 3/3] Remove unused variable --- test/test_gromov.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/test_gromov.py b/test/test_gromov.py index af451c90d..8104ad03f 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -2056,7 +2056,6 @@ def test_gw_semirelaxed_helper_validation(loss_fun): xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=0) xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=1) p = ot.unif(n_samples) - q = ot.unif(n_samples) C1 = ot.dist(xs, xs) C2 = ot.dist(xt, xt) ot.gromov.init_matrix_semirelaxed(C1, C2, p, loss_fun=loss_fun)