|
14 | 14 | from ot.unbalanced import barycenter_unbalanced
|
15 | 15 |
|
16 | 16 |
|
17 |
| -@pytest.mark.parametrize("method,reg_type", itertools.product(["sinkhorn", "sinkhorn_stabilized"], ["kl", "entropy"])) |
| 17 | +@pytest.mark.parametrize("method,reg_type", itertools.product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling"], ["kl", "entropy"])) |
18 | 18 | def test_unbalanced_convergence(nx, method, reg_type):
|
19 | 19 | # test generalized sinkhorn for unbalanced OT
|
20 | 20 | n = 100
|
@@ -78,7 +78,7 @@ def test_unbalanced_convergence(nx, method, reg_type):
|
78 | 78 | np.testing.assert_allclose(G_np, nx.to_numpy(G))
|
79 | 79 |
|
80 | 80 |
|
81 |
| -@pytest.mark.parametrize("method,reg_type", itertools.product(["sinkhorn", "sinkhorn_stabilized"], ["kl", "entropy"])) |
| 81 | +@pytest.mark.parametrize("method,reg_type", itertools.product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling"], ["kl", "entropy"])) |
82 | 82 | def test_unbalanced_warmstart(nx, method, reg_type):
|
83 | 83 | # test generalized sinkhorn for unbalanced OT
|
84 | 84 | n = 100
|
@@ -140,7 +140,7 @@ def test_unbalanced_warmstart(nx, method, reg_type):
|
140 | 140 | np.testing.assert_allclose(nx.to_numpy(loss0), nx.to_numpy(loss1), atol=1e-5)
|
141 | 141 |
|
142 | 142 |
|
143 |
| -@pytest.mark.parametrize("method,reg_m", itertools.product(["sinkhorn", "sinkhorn_stabilized"], [1, float("inf")])) |
| 143 | +@pytest.mark.parametrize("method,reg_m", itertools.product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling"], [1, float("inf")])) |
144 | 144 | def test_unbalanced_relaxation_parameters(nx, method, reg_m):
|
145 | 145 | # test generalized sinkhorn for unbalanced OT
|
146 | 146 | n = 100
|
@@ -184,7 +184,7 @@ def test_unbalanced_relaxation_parameters(nx, method, reg_m):
|
184 | 184 | nx.to_numpy(loss), nx.to_numpy(loss_opt), atol=1e-05)
|
185 | 185 |
|
186 | 186 |
|
187 |
| -@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) |
| 187 | +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling"]) |
188 | 188 | def test_unbalanced_multiple_inputs(nx, method):
|
189 | 189 | # test generalized sinkhorn for unbalanced OT
|
190 | 190 | n = 100
|
|
0 commit comments