@@ -291,6 +291,24 @@ def line_search(cost, G, deltaG, Mi, cost_G):
291
291
np .testing .assert_allclose (res , Gb , atol = 1e-06 )
292
292
293
293
294
+ @pytest .mark .parametrize ('loss_fun' , [
295
+ 'square_loss' ,
296
+ 'kl_loss' ,
297
+ pytest .param ('unknown_loss' , marks = pytest .mark .xfail (raises = ValueError )),
298
+ ])
299
+ def test_gw_helper_validation (loss_fun ):
300
+ n_samples = 20 # nb samples
301
+ mu = np .array ([0 , 0 ])
302
+ cov = np .array ([[1 , 0 ], [0 , 1 ]])
303
+ xs = ot .datasets .make_2D_samples_gauss (n_samples , mu , cov , random_state = 0 )
304
+ xt = ot .datasets .make_2D_samples_gauss (n_samples , mu , cov , random_state = 1 )
305
+ p = ot .unif (n_samples )
306
+ q = ot .unif (n_samples )
307
+ C1 = ot .dist (xs , xs )
308
+ C2 = ot .dist (xt , xt )
309
+ ot .gromov .init_matrix (C1 , C2 , p , q , loss_fun = loss_fun )
310
+
311
+
294
312
@pytest .skip_backend ("jax" , reason = "test very slow with jax backend" )
295
313
@pytest .skip_backend ("tf" , reason = "test very slow with tf backend" )
296
314
def test_entropic_gromov (nx ):
@@ -2026,6 +2044,24 @@ def line_search(cost, G, deltaG, Mi, cost_G):
2026
2044
np .testing .assert_allclose (res , Gb , atol = 1e-06 )
2027
2045
2028
2046
2047
+ @pytest .mark .parametrize ('loss_fun' , [
2048
+ 'square_loss' ,
2049
+ pytest .param ('kl_loss' , marks = pytest .mark .xfail (raises = NotImplementedError )),
2050
+ pytest .param ('unknown_loss' , marks = pytest .mark .xfail (raises = ValueError )),
2051
+ ])
2052
+ def test_gw_semirelaxed_helper_validation (loss_fun ):
2053
+ n_samples = 20 # nb samples
2054
+ mu = np .array ([0 , 0 ])
2055
+ cov = np .array ([[1 , 0 ], [0 , 1 ]])
2056
+ xs = ot .datasets .make_2D_samples_gauss (n_samples , mu , cov , random_state = 0 )
2057
+ xt = ot .datasets .make_2D_samples_gauss (n_samples , mu , cov , random_state = 1 )
2058
+ p = ot .unif (n_samples )
2059
+ q = ot .unif (n_samples )
2060
+ C1 = ot .dist (xs , xs )
2061
+ C2 = ot .dist (xt , xt )
2062
+ ot .gromov .init_matrix_semirelaxed (C1 , C2 , p , loss_fun = loss_fun )
2063
+
2064
+
2029
2065
def test_semirelaxed_fgw (nx ):
2030
2066
rng = np .random .RandomState (0 )
2031
2067
list_n = [16 , 8 ]
0 commit comments