File tree Expand file tree Collapse file tree 2 files changed +1
-10
lines changed Expand file tree Collapse file tree 2 files changed +1
-10
lines changed Original file line number Diff line number Diff line change 9
9
import pytest
10
10
11
11
import ot
12
+ from ot .backend import tf
12
13
from ot .lp import wasserstein_1d
13
14
14
- from ot .backend import get_backend_list , tf
15
15
from scipy .stats import wasserstein_distance
16
16
17
- backend_list = get_backend_list ()
18
-
19
17
20
18
def test_emd_1d_emd2_1d_with_weights ():
21
19
# test emd1d gives similar results as emd
@@ -53,10 +51,7 @@ def test_emd_1d_emd2_1d_with_weights():
53
51
np .testing .assert_allclose (w_v , G .sum (0 ))
54
52
55
53
56
- @pytest .mark .parametrize ('nx' , backend_list )
57
54
def test_wasserstein_1d (nx ):
58
- from scipy .stats import wasserstein_distance
59
-
60
55
rng = np .random .RandomState (0 )
61
56
62
57
n = 100
@@ -105,8 +100,6 @@ def test_wasserstein_1d_type_devices(nx):
105
100
106
101
@pytest .mark .skipif (not tf , reason = "tf not installed" )
107
102
def test_wasserstein_1d_device_tf ():
108
- if not tf :
109
- return
110
103
nx = ot .backend .TensorflowBackend ()
111
104
rng = np .random .RandomState (0 )
112
105
n = 10
Original file line number Diff line number Diff line change @@ -103,8 +103,6 @@ def test_emd_emd2_types_devices(nx):
103
103
104
104
@pytest .mark .skipif (not tf , reason = "tf not installed" )
105
105
def test_emd_emd2_devices_tf ():
106
- if not tf :
107
- return
108
106
nx = ot .backend .TensorflowBackend ()
109
107
110
108
n_samples = 100
You can’t perform that action at this time.
0 commit comments