@@ -311,7 +311,12 @@ def test_gw_helper_validation(loss_fun):
 
 @pytest.skip_backend("jax", reason="test very slow with jax backend")
 @pytest.skip_backend("tf", reason="test very slow with tf backend")
-def test_entropic_gromov(nx):
+@pytest.mark.parametrize('loss_fun', [
+    'square_loss',
+    'kl_loss',
+    pytest.param('unknown_loss', marks=pytest.mark.xfail(raises=ValueError)),
+])
+def test_entropic_gromov(nx, loss_fun):
     n_samples = 10  # nb samples
 
     mu_s = np.array([0, 0])
@@ -333,10 +338,10 @@ def test_entropic_gromov(nx):
     C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0)
 
     G, log = ot.gromov.entropic_gromov_wasserstein(
-        C1, C2, None, q, 'square_loss', symmetric=None, G0=G0,
+        C1, C2, None, q, loss_fun, symmetric=None, G0=G0,
         epsilon=1e-2, max_iter=10, verbose=True, log=True)
     Gb = nx.to_numpy(ot.gromov.entropic_gromov_wasserstein(
-        C1b, C2b, pb, None, 'square_loss', symmetric=True, G0=None,
+        C1b, C2b, pb, None, loss_fun, symmetric=True, G0=None,
         epsilon=1e-2, max_iter=10, verbose=True, log=False
     ))
 
@@ -347,11 +352,40 @@ def test_entropic_gromov(nx):
     np.testing.assert_allclose(
         q, Gb.sum(0), atol=1e-04)  # cf convergence gromov
 
+
+@pytest.skip_backend("jax", reason="test very slow with jax backend")
+@pytest.skip_backend("tf", reason="test very slow with tf backend")
+@pytest.mark.parametrize('loss_fun', [
+    'square_loss',
+    'kl_loss',
+    pytest.param('unknown_loss', marks=pytest.mark.xfail(raises=ValueError)),
+])
+def test_entropic_gromov2(nx, loss_fun):
+    n_samples = 10  # nb samples
+
+    mu_s = np.array([0, 0])
+    cov_s = np.array([[1, 0], [0, 1]])
+
+    xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42)
+
+    xt = xs[::-1].copy()
+
+    p = ot.unif(n_samples)
+    q = ot.unif(n_samples)
+    G0 = p[:, None] * q[None, :]
+    C1 = ot.dist(xs, xs)
+    C2 = ot.dist(xt, xt)
+
+    C1 /= C1.max()
+    C2 /= C2.max()
+
+    C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0)
+
     gw, log = ot.gromov.entropic_gromov_wasserstein2(
-        C1, C2, p, None, 'kl_loss', symmetric=True, G0=None,
+        C1, C2, p, None, loss_fun, symmetric=True, G0=None,
         max_iter=10, epsilon=1e-2, log=True)
     gwb, logb = ot.gromov.entropic_gromov_wasserstein2(
-        C1b, C2b, None, qb, 'kl_loss', symmetric=None, G0=G0b,
+        C1b, C2b, None, qb, loss_fun, symmetric=None, G0=G0b,
         max_iter=10, epsilon=1e-2, log=True)
     gwb = nx.to_numpy(gwb)
 
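For readers unfamiliar with the idiom, here is a minimal, self-contained sketch of the pytest.mark.parametrize / pytest.param(..., marks=pytest.mark.xfail(raises=ValueError)) pattern introduced in this diff. The _check_loss helper is a hypothetical stand-in for the loss-name validation performed inside ot.gromov.entropic_gromov_wasserstein; it only illustrates how the 'unknown_loss' case is expected to raise ValueError while the two supported losses run normally.

import pytest


def _check_loss(loss_fun):
    # Hypothetical stand-in for the validation done by the POT solvers:
    # accept the two supported losses, raise ValueError for anything else.
    if loss_fun not in ('square_loss', 'kl_loss'):
        raise ValueError(f"Unknown loss_fun='{loss_fun}'")
    return loss_fun


@pytest.mark.parametrize('loss_fun', [
    'square_loss',
    'kl_loss',
    # xfail(raises=ValueError) tells pytest this parameter is expected to
    # fail with ValueError, so the error path is covered by the same test.
    pytest.param('unknown_loss', marks=pytest.mark.xfail(raises=ValueError)),
])
def test_loss_fun_validation(loss_fun):
    assert _check_loss(loss_fun) == loss_fun

With this layout, one parametrized test exercises both supported losses on the same data and still checks that an unsupported loss name is rejected, instead of hard-coding 'square_loss' or 'kl_loss' in each call as the tests did before.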