From 2b84679073e687761e15cd759bf1b64c16945121 Mon Sep 17 00:00:00 2001 From: Oleksii Kachaiev Date: Sun, 27 Aug 2023 20:15:40 +0200 Subject: [PATCH 1/3] No need to parametrize for nx param --- test/test_1d_solver.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py index 265fab5b1..c2a10685f 100644 --- a/test/test_1d_solver.py +++ b/test/test_1d_solver.py @@ -9,13 +9,11 @@ import pytest import ot +from ot.backend import tf from ot.lp import wasserstein_1d -from ot.backend import get_backend_list, tf from scipy.stats import wasserstein_distance -backend_list = get_backend_list() - def test_emd_1d_emd2_1d_with_weights(): # test emd1d gives similar results as emd @@ -53,10 +51,7 @@ def test_emd_1d_emd2_1d_with_weights(): np.testing.assert_allclose(w_v, G.sum(0)) -@pytest.mark.parametrize('nx', backend_list) def test_wasserstein_1d(nx): - from scipy.stats import wasserstein_distance - rng = np.random.RandomState(0) n = 100 From cf99e345844882fb789a198abda4a1fc957e19d8 Mon Sep 17 00:00:00 2001 From: Oleksii Kachaiev Date: Sun, 27 Aug 2023 20:21:05 +0200 Subject: [PATCH 2/3] Remove redundant check for TF backend --- test/test_1d_solver.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py index c2a10685f..131757610 100644 --- a/test/test_1d_solver.py +++ b/test/test_1d_solver.py @@ -100,8 +100,6 @@ def test_wasserstein_1d_type_devices(nx): @pytest.mark.skipif(not tf, reason="tf not installed") def test_wasserstein_1d_device_tf(): - if not tf: - return nx = ot.backend.TensorflowBackend() rng = np.random.RandomState(0) n = 10 From be0b5f9c85fb5870b1ac4fd91191c7729eb942ec Mon Sep 17 00:00:00 2001 From: Oleksii Kachaiev Date: Fri, 1 Sep 2023 19:48:48 +0200 Subject: [PATCH 3/3] Remove redundant TF check from emd2 test --- test/test_ot.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/test_ot.py b/test/test_ot.py index cbb63185a..5c6e6732b 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -103,8 +103,6 @@ def test_emd_emd2_types_devices(nx): @pytest.mark.skipif(not tf, reason="tf not installed") def test_emd_emd2_devices_tf(): - if not tf: - return nx = ot.backend.TensorflowBackend() n_samples = 100