Skip to content

Commit 064898d

Browse files
kachayevrflamary
andauthored
Remove redundant parametrization from wasserstein_1d tests (#517)
* No need to parametrize for nx param * Remove redundant check for TF backend * Remove redundant TF check from emd2 test --------- Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
1 parent 5331480 commit 064898d

File tree

2 files changed

+1
-10
lines changed

2 files changed

+1
-10
lines changed

test/test_1d_solver.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,11 @@
99
import pytest
1010

1111
import ot
12+
from ot.backend import tf
1213
from ot.lp import wasserstein_1d
1314

14-
from ot.backend import get_backend_list, tf
1515
from scipy.stats import wasserstein_distance
1616

17-
backend_list = get_backend_list()
18-
1917

2018
def test_emd_1d_emd2_1d_with_weights():
2119
# test emd1d gives similar results as emd
@@ -53,10 +51,7 @@ def test_emd_1d_emd2_1d_with_weights():
5351
np.testing.assert_allclose(w_v, G.sum(0))
5452

5553

56-
@pytest.mark.parametrize('nx', backend_list)
5754
def test_wasserstein_1d(nx):
58-
from scipy.stats import wasserstein_distance
59-
6055
rng = np.random.RandomState(0)
6156

6257
n = 100
@@ -105,8 +100,6 @@ def test_wasserstein_1d_type_devices(nx):
105100

106101
@pytest.mark.skipif(not tf, reason="tf not installed")
107102
def test_wasserstein_1d_device_tf():
108-
if not tf:
109-
return
110103
nx = ot.backend.TensorflowBackend()
111104
rng = np.random.RandomState(0)
112105
n = 10

test/test_ot.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,6 @@ def test_emd_emd2_types_devices(nx):
103103

104104
@pytest.mark.skipif(not tf, reason="tf not installed")
105105
def test_emd_emd2_devices_tf():
106-
if not tf:
107-
return
108106
nx = ot.backend.TensorflowBackend()
109107

110108
n_samples = 100

0 commit comments

Comments
 (0)