25
25
26
26
lst_method_params_solve_sample = [
27
27
{'method' : '1d' },
28
+ {'method' : '1d' , 'metric' : 'euclidean' },
28
29
{'method' : 'gaussian' },
29
30
{'method' : 'gaussian' , 'reg' : 1 },
30
31
{'method' : 'factored' , 'rank' : 10 },
31
32
]
33
+
34
+ lst_parameters_solve_sample_NotImplemented = [
35
+ {'method' : '1d' , 'metric' : 'any other one' }, # fail 1d on weird metrics
36
+ {'method' : 'gaussian' , 'metric' : 'euclidean' }, # fail gaussian on metric not euclidean
37
+ {'method' : 'factored' , 'metric' : 'euclidean' }, # fail factored on metric not euclidean
38
+ {'lazy' : True }, # fail lazy for non regularized
39
+ {'lazy' : True , 'unbalanced' : 1 }, # fail lazy for non regularized unbalanced
40
+ {'lazy' : True , 'reg' : 1 , 'unbalanced' : 1 }, # fail lazy for unbalanced and regularized
41
+ ]
42
+
32
43
# set readable ids for each param
33
44
lst_method_params_solve_sample = [pytest .param (param , id = str (param )) for param in lst_method_params_solve_sample ]
45
+ lst_parameters_solve_sample_NotImplemented = [pytest .param (param , id = str (param )) for param in lst_parameters_solve_sample_NotImplemented ]
34
46
35
47
36
48
def assert_allclose_sol (sol1 , sol2 ):
@@ -268,7 +280,7 @@ def test_solve_gromov_not_implemented(nx):
268
280
269
281
def test_solve_sample (nx ):
270
282
# test solve_sample when is_Lazy = False
271
- n = 100
283
+ n = 20
272
284
X_s = np .reshape (1.0 * np .arange (n ), (n , 1 ))
273
285
X_t = np .reshape (1.0 * np .arange (0 , n ), (n , 1 ))
274
286
@@ -310,6 +322,32 @@ def test_solve_sample(nx):
310
322
sol0 = ot .solve_sample (X_s , X_t , reg = 1 , reg_type = 'cryptic divergence' )
311
323
312
324
325
+ def test_solve_sample_lazy (nx ):
326
+ # test solve_sample when is_Lazy = False
327
+ n = 20
328
+ X_s = np .reshape (1.0 * np .arange (n ), (n , 1 ))
329
+ X_t = np .reshape (1.0 * np .arange (0 , n ), (n , 1 ))
330
+
331
+ a = ot .utils .unif (X_s .shape [0 ])
332
+ b = ot .utils .unif (X_t .shape [0 ])
333
+
334
+ X_s , X_t , a , b = nx .from_numpy (X_s , X_t , a , b )
335
+
336
+ M = ot .dist (X_s , X_t )
337
+
338
+ # solve with ot.solve
339
+ sol00 = ot .solve (M , a , b , reg = 1 )
340
+
341
+ sol0 = ot .solve_sample (X_s , X_t , a , b , reg = 1 )
342
+
343
+ # solve signe weights
344
+ sol = ot .solve_sample (X_s , X_t , a , b , reg = 1 , lazy = True )
345
+
346
+ assert_allclose_sol (sol0 , sol00 )
347
+
348
+ np .testing .assert_allclose (sol0 .plan , sol .lazy_plan [:], rtol = 1e-5 , atol = 1e-5 )
349
+
350
+
313
351
@pytest .mark .parametrize ("method_params" , lst_method_params_solve_sample )
314
352
def test_solve_sample_methods (nx , method_params ):
315
353
@@ -336,41 +374,20 @@ def test_solve_sample_methods(nx, method_params):
336
374
np .testing .assert_allclose (sol2 .value , 0 )
337
375
338
376
339
- # def test_lazy_solve_sample(nx):
340
- # # test solve_sample when is_Lazy = True
341
- # n = 100
342
- # X_s = np.reshape(1.0 * np.arange(n), (n, 1))
343
- # X_t = np.reshape(1.0 * np.arange(0, n), (n, 1))
344
-
345
- # a = ot.utils.unif(X_s.shape[0])
346
- # b = ot.utils.unif(X_t.shape[0])
347
-
348
- # # solve unif weights
349
- # sol0 = ot.solve_sample(X_s, X_t, reg=0.1, lazy=True) # reg != 0 or None since no implementation yet for is_Lazy=True
377
+ @pytest .mark .parametrize ("method_params" , lst_parameters_solve_sample_NotImplemented )
378
+ def test_solve_sample_NotImplemented (nx , method_params ):
350
379
351
- # # solve signe weights
352
- # sol = ot.solve_sample(X_s, X_t, a, b, reg=0.1, lazy=True)
353
-
354
- # # check some attributes
355
- # sol.potentials
356
- # sol.lazy_plan
357
-
358
- # assert_allclose_sol(sol0, sol)
359
-
360
- # # solve in backend
361
- # X_sb, X_tb, ab, bb = nx.from_numpy(X_s, X_t, a, b)
362
- # solb = ot.solve_sample(X_sb, X_tb, ab, bb, reg=0.1, lazy=True)
363
-
364
- # assert_allclose_sol(sol, solb)
380
+ n_samples_s = 20
381
+ n_samples_t = 7
382
+ n_features = 2
383
+ rng = np .random .RandomState (0 )
365
384
366
- # # test not implemented reg==0 (or None) + balanced and check raise
367
- # with pytest.raises(NotImplementedError):
368
- # sol0 = ot.solve_sample(X_s, X_t, lazy=True) # reg == 0 (or None) + unbalanced= None are default
385
+ x = rng .randn (n_samples_s , n_features )
386
+ y = rng .randn (n_samples_t , n_features )
387
+ a = ot .utils .unif (n_samples_s )
388
+ b = ot .utils .unif (n_samples_t )
369
389
370
- # # test not implemented reg==0 (or None) + unbalanced_type and check raise
371
- # with pytest.raises(NotImplementedError):
372
- # sol0 = ot.solve_sample(X_s, X_t, unbalanced_type="kl", lazy=True) # reg == 0 (or None) is default
390
+ xb , yb , ab , bb = nx .from_numpy (x , y , a , b )
373
391
374
- # # test not implemented reg != 0 + unbalanced_type and check raise
375
- # with pytest.raises(NotImplementedError):
376
- # sol0 = ot.solve_sample(X_s, X_t, reg=0.1, unbalanced_type="kl", lazy=True)
392
+ with pytest .raises (NotImplementedError ):
393
+ ot .solve_sample (xb , yb , ab , bb , ** method_params )
0 commit comments