Skip to content

Commit db76ee1

Browse files
committed
Update tests to cover unknown loss_fun
1 parent 61a6849 commit db76ee1

File tree

1 file changed

+39
-5
lines changed

1 file changed

+39
-5
lines changed

test/test_gromov.py

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,12 @@ def test_gw_helper_validation(loss_fun):
311311

312312
@pytest.skip_backend("jax", reason="test very slow with jax backend")
313313
@pytest.skip_backend("tf", reason="test very slow with tf backend")
314-
def test_entropic_gromov(nx):
314+
@pytest.mark.parametrize('loss_fun', [
315+
'square_loss',
316+
'kl_loss',
317+
pytest.param('unknown_loss', marks=pytest.mark.xfail(raises=ValueError)),
318+
])
319+
def test_entropic_gromov(nx, loss_fun):
315320
n_samples = 10 # nb samples
316321

317322
mu_s = np.array([0, 0])
@@ -333,10 +338,10 @@ def test_entropic_gromov(nx):
333338
C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0)
334339

335340
G, log = ot.gromov.entropic_gromov_wasserstein(
336-
C1, C2, None, q, 'square_loss', symmetric=None, G0=G0,
341+
C1, C2, None, q, loss_fun, symmetric=None, G0=G0,
337342
epsilon=1e-2, max_iter=10, verbose=True, log=True)
338343
Gb = nx.to_numpy(ot.gromov.entropic_gromov_wasserstein(
339-
C1b, C2b, pb, None, 'square_loss', symmetric=True, G0=None,
344+
C1b, C2b, pb, None, loss_fun, symmetric=True, G0=None,
340345
epsilon=1e-2, max_iter=10, verbose=True, log=False
341346
))
342347

@@ -347,11 +352,40 @@ def test_entropic_gromov(nx):
347352
np.testing.assert_allclose(
348353
q, Gb.sum(0), atol=1e-04) # cf convergence gromov
349354

355+
356+
@pytest.skip_backend("jax", reason="test very slow with jax backend")
357+
@pytest.skip_backend("tf", reason="test very slow with tf backend")
358+
@pytest.mark.parametrize('loss_fun', [
359+
'square_loss',
360+
'kl_loss',
361+
pytest.param('unknown_loss', marks=pytest.mark.xfail(raises=ValueError)),
362+
])
363+
def test_entropic_gromov2(nx, loss_fun):
364+
n_samples = 10 # nb samples
365+
366+
mu_s = np.array([0, 0])
367+
cov_s = np.array([[1, 0], [0, 1]])
368+
369+
xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42)
370+
371+
xt = xs[::-1].copy()
372+
373+
p = ot.unif(n_samples)
374+
q = ot.unif(n_samples)
375+
G0 = p[:, None] * q[None, :]
376+
C1 = ot.dist(xs, xs)
377+
C2 = ot.dist(xt, xt)
378+
379+
C1 /= C1.max()
380+
C2 /= C2.max()
381+
382+
C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0)
383+
350384
gw, log = ot.gromov.entropic_gromov_wasserstein2(
351-
C1, C2, p, None, 'kl_loss', symmetric=True, G0=None,
385+
C1, C2, p, None, loss_fun, symmetric=True, G0=None,
352386
max_iter=10, epsilon=1e-2, log=True)
353387
gwb, logb = ot.gromov.entropic_gromov_wasserstein2(
354-
C1b, C2b, None, qb, 'kl_loss', symmetric=None, G0=G0b,
388+
C1b, C2b, None, qb, loss_fun, symmetric=None, G0=G0b,
355389
max_iter=10, epsilon=1e-2, log=True)
356390
gwb = nx.to_numpy(gwb)
357391

0 commit comments

Comments
 (0)