Skip to content

Commit f54ec56

Browse files
committed
big update tests
1 parent b1182c8 commit f54ec56

File tree

3 files changed

+76
-40
lines changed

3 files changed

+76
-40
lines changed

test/test_bregman.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1078,10 +1078,10 @@ def test_lazy_empirical_sinkhorn(nx):
10781078
sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1))
10791079

10801080
f, g, log_es = ot.bregman.empirical_sinkhorn(
1081-
X_sb, X_tb, 0.1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True)
1081+
X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=5, log=True)
10821082
f, g = nx.to_numpy(f), nx.to_numpy(g)
1083-
G_log = np.exp(f[:, None] + g[None, :] - M / 0.1)
1084-
sinkhorn_log, log_s = ot.sinkhorn(ab, bb, M_nx, 0.1, log=True)
1083+
G_log = np.exp(f[:, None] + g[None, :] - M / 1)
1084+
sinkhorn_log, log_s = ot.sinkhorn(ab, bb, M_nx, 1, log=True)
10851085
sinkhorn_log = nx.to_numpy(sinkhorn_log)
10861086

10871087
f, g = ot.bregman.empirical_sinkhorn(
@@ -1091,10 +1091,14 @@ def test_lazy_empirical_sinkhorn(nx):
10911091
sinkhorn_m = nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1))
10921092

10931093
loss_emp_sinkhorn, log = ot.bregman.empirical_sinkhorn2(
1094-
X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True)
1094+
X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=5, log=True)
1095+
G_lazy = nx.to_numpy(log['lazy_plan'][:])
10951096
loss_emp_sinkhorn = nx.to_numpy(loss_emp_sinkhorn)
10961097
loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, M_nx, 1))
10971098

1099+
loss_emp_sinkhorn = ot.bregman.empirical_sinkhorn2(
1100+
X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=False)
1101+
10981102
# check constraints
10991103
np.testing.assert_allclose(
11001104
sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05) # metric sqeuclidian
@@ -1109,6 +1113,7 @@ def test_lazy_empirical_sinkhorn(nx):
11091113
np.testing.assert_allclose(
11101114
sinkhorn_m.sum(0), G_m.sum(0), atol=1e-05) # metric euclidian
11111115
np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05)
1116+
np.testing.assert_allclose(G_log, G_lazy, atol=1e-05)
11121117

11131118

11141119
def test_empirical_sinkhorn_divergence(nx):

test/test_solvers.py

Lines changed: 52 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,24 @@
2525

2626
lst_method_params_solve_sample = [
2727
{'method': '1d'},
28+
{'method': '1d', 'metric': 'euclidean'},
2829
{'method': 'gaussian'},
2930
{'method': 'gaussian', 'reg': 1},
3031
{'method': 'factored', 'rank': 10},
3132
]
33+
34+
lst_parameters_solve_sample_NotImplemented = [
35+
{'method': '1d', 'metric': 'any other one'}, # fail 1d on weird metrics
36+
{'method': 'gaussian', 'metric': 'euclidean'}, # fail gaussian on metric not euclidean
37+
{'method': 'factored', 'metric': 'euclidean'}, # fail factored on metric not euclidean
38+
{'lazy': True}, # fail lazy for non regularized
39+
{'lazy': True, 'unbalanced': 1}, # fail lazy for non regularized unbalanced
40+
{'lazy': True, 'reg': 1, 'unbalanced': 1}, # fail lazy for unbalanced and regularized
41+
]
42+
3243
# set readable ids for each param
3344
lst_method_params_solve_sample = [pytest.param(param, id=str(param)) for param in lst_method_params_solve_sample]
45+
lst_parameters_solve_sample_NotImplemented = [pytest.param(param, id=str(param)) for param in lst_parameters_solve_sample_NotImplemented]
3446

3547

3648
def assert_allclose_sol(sol1, sol2):
@@ -268,7 +280,7 @@ def test_solve_gromov_not_implemented(nx):
268280

269281
def test_solve_sample(nx):
270282
# test solve_sample when is_Lazy = False
271-
n = 100
283+
n = 20
272284
X_s = np.reshape(1.0 * np.arange(n), (n, 1))
273285
X_t = np.reshape(1.0 * np.arange(0, n), (n, 1))
274286

@@ -310,6 +322,32 @@ def test_solve_sample(nx):
310322
sol0 = ot.solve_sample(X_s, X_t, reg=1, reg_type='cryptic divergence')
311323

312324

325+
def test_solve_sample_lazy(nx):
326+
# test solve_sample when is_Lazy = False
327+
n = 20
328+
X_s = np.reshape(1.0 * np.arange(n), (n, 1))
329+
X_t = np.reshape(1.0 * np.arange(0, n), (n, 1))
330+
331+
a = ot.utils.unif(X_s.shape[0])
332+
b = ot.utils.unif(X_t.shape[0])
333+
334+
X_s, X_t, a, b = nx.from_numpy(X_s, X_t, a, b)
335+
336+
M = ot.dist(X_s, X_t)
337+
338+
# solve with ot.solve
339+
sol00 = ot.solve(M, a, b, reg=1)
340+
341+
sol0 = ot.solve_sample(X_s, X_t, a, b, reg=1)
342+
343+
# solve signe weights
344+
sol = ot.solve_sample(X_s, X_t, a, b, reg=1, lazy=True)
345+
346+
assert_allclose_sol(sol0, sol00)
347+
348+
np.testing.assert_allclose(sol0.plan, sol.lazy_plan[:], rtol=1e-5, atol=1e-5)
349+
350+
313351
@pytest.mark.parametrize("method_params", lst_method_params_solve_sample)
314352
def test_solve_sample_methods(nx, method_params):
315353

@@ -336,41 +374,20 @@ def test_solve_sample_methods(nx, method_params):
336374
np.testing.assert_allclose(sol2.value, 0)
337375

338376

339-
# def test_lazy_solve_sample(nx):
340-
# # test solve_sample when is_Lazy = True
341-
# n = 100
342-
# X_s = np.reshape(1.0 * np.arange(n), (n, 1))
343-
# X_t = np.reshape(1.0 * np.arange(0, n), (n, 1))
344-
345-
# a = ot.utils.unif(X_s.shape[0])
346-
# b = ot.utils.unif(X_t.shape[0])
347-
348-
# # solve unif weights
349-
# sol0 = ot.solve_sample(X_s, X_t, reg=0.1, lazy=True) # reg != 0 or None since no implementation yet for is_Lazy=True
377+
@pytest.mark.parametrize("method_params", lst_parameters_solve_sample_NotImplemented)
378+
def test_solve_sample_NotImplemented(nx, method_params):
350379

351-
# # solve signe weights
352-
# sol = ot.solve_sample(X_s, X_t, a, b, reg=0.1, lazy=True)
353-
354-
# # check some attributes
355-
# sol.potentials
356-
# sol.lazy_plan
357-
358-
# assert_allclose_sol(sol0, sol)
359-
360-
# # solve in backend
361-
# X_sb, X_tb, ab, bb = nx.from_numpy(X_s, X_t, a, b)
362-
# solb = ot.solve_sample(X_sb, X_tb, ab, bb, reg=0.1, lazy=True)
363-
364-
# assert_allclose_sol(sol, solb)
380+
n_samples_s = 20
381+
n_samples_t = 7
382+
n_features = 2
383+
rng = np.random.RandomState(0)
365384

366-
# # test not implemented reg==0 (or None) + balanced and check raise
367-
# with pytest.raises(NotImplementedError):
368-
# sol0 = ot.solve_sample(X_s, X_t, lazy=True) # reg == 0 (or None) + unbalanced= None are default
385+
x = rng.randn(n_samples_s, n_features)
386+
y = rng.randn(n_samples_t, n_features)
387+
a = ot.utils.unif(n_samples_s)
388+
b = ot.utils.unif(n_samples_t)
369389

370-
# # test not implemented reg==0 (or None) + unbalanced_type and check raise
371-
# with pytest.raises(NotImplementedError):
372-
# sol0 = ot.solve_sample(X_s, X_t, unbalanced_type="kl", lazy=True) # reg == 0 (or None) is default
390+
xb, yb, ab, bb = nx.from_numpy(x, y, a, b)
373391

374-
# # test not implemented reg != 0 + unbalanced_type and check raise
375-
# with pytest.raises(NotImplementedError):
376-
# sol0 = ot.solve_sample(X_s, X_t, reg=0.1, unbalanced_type="kl", lazy=True)
392+
with pytest.raises(NotImplementedError):
393+
ot.solve_sample(xb, yb, ab, bb, **method_params)

test/test_utils.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,9 @@ def test_cost_normalization(nx):
318318
M1 = nx.to_numpy(M)
319319
np.testing.assert_allclose(M1.max(), np.log(1 + np.log(1 + C)).max())
320320

321+
with pytest.raises(ValueError):
322+
ot.utils.cost_normalization(C1, 'error')
323+
321324

322325
def test_check_params():
323326

@@ -328,6 +331,16 @@ def test_check_params():
328331
assert res0 is False
329332

330333

334+
def test_check_random_state_error():
335+
with pytest.raises(ValueError):
336+
ot.utils.check_random_state('error')
337+
338+
339+
def test_get_parameter_pairs_error():
340+
with pytest.raises(ValueError):
341+
ot.utils.get_parameter_pairs((1, 2, 3)) # not pair ;)
342+
343+
331344
def test_deprecated_func():
332345

333346
@ot.utils.deprecated('deprecated text for fun')
@@ -408,7 +421,8 @@ def test_OTResult():
408421
'status',
409422
'value',
410423
'value_linear',
411-
'value_quad']
424+
'value_quad',
425+
'log']
412426
for at in lst_attributes:
413427
print(at)
414428
with pytest.raises(NotImplementedError):

0 commit comments

Comments
 (0)