Skip to content

Commit 9f0e69c

Browse files
committed
loss_fun validation in GW utils
1 parent 20cc202 commit 9f0e69c

File tree

2 files changed

+44
-0
lines changed

2 files changed

+44
-0
lines changed

ot/gromov/_utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ def init_matrix(C1, C2, p, q, loss_fun='square_loss', nx=None):
7272
Name of loss function to use: either 'square_loss' or 'kl_loss' (default='square_loss')
7373
nx : backend, optional
7474
If let to its default value None, a backend test will be conducted.
75+
7576
Returns
7677
-------
7778
constC : array-like, shape (ns, nt)
@@ -118,6 +119,8 @@ def h1(a):
118119

119120
def h2(b):
120121
return nx.log(b + 1e-15)
122+
else:
123+
raise ValueError(f"Unknown `loss_fun='{loss_fun}'`. Use one of: {'square_loss', 'kl_loss'}.")
121124

122125
constC1 = nx.dot(
123126
nx.dot(f1(C1), nx.reshape(p, (-1, 1))),
@@ -407,6 +410,7 @@ def init_matrix_semirelaxed(C1, C2, p, loss_fun='square_loss', nx=None):
407410
p : array-like, shape (ns,)
408411
nx : backend, optional
409412
If let to its default value None, a backend test will be conducted.
413+
410414
Returns
411415
-------
412416
constC : array-like, shape (ns, nt)
@@ -446,6 +450,10 @@ def h1(a):
446450

447451
def h2(b):
448452
return 2 * b
453+
elif loss_fun == 'kl_loss':
454+
raise NotImplementedError()
455+
else:
456+
raise ValueError(f"Unknown `loss_fun='{loss_fun}'`. Only 'square_loss' is supported.")
449457

450458
constC = nx.dot(nx.dot(f1(C1), nx.reshape(p, (-1, 1))),
451459
nx.ones((1, C2.shape[0]), type_as=p))

test/test_gromov.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,24 @@ def line_search(cost, G, deltaG, Mi, cost_G):
291291
np.testing.assert_allclose(res, Gb, atol=1e-06)
292292

293293

294+
@pytest.mark.parametrize('loss_fun', [
295+
'square_loss',
296+
'kl_loss',
297+
pytest.param('unknown_loss', marks=pytest.mark.xfail(raises=ValueError)),
298+
])
299+
def test_gw_helper_validation(loss_fun):
300+
n_samples = 20 # nb samples
301+
mu = np.array([0, 0])
302+
cov = np.array([[1, 0], [0, 1]])
303+
xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=0)
304+
xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=1)
305+
p = ot.unif(n_samples)
306+
q = ot.unif(n_samples)
307+
C1 = ot.dist(xs, xs)
308+
C2 = ot.dist(xt, xt)
309+
ot.gromov.init_matrix(C1, C2, p, q, loss_fun=loss_fun)
310+
311+
294312
@pytest.skip_backend("jax", reason="test very slow with jax backend")
295313
@pytest.skip_backend("tf", reason="test very slow with tf backend")
296314
def test_entropic_gromov(nx):
@@ -2026,6 +2044,24 @@ def line_search(cost, G, deltaG, Mi, cost_G):
20262044
np.testing.assert_allclose(res, Gb, atol=1e-06)
20272045

20282046

2047+
@pytest.mark.parametrize('loss_fun', [
2048+
'square_loss',
2049+
pytest.param('kl_loss', marks=pytest.mark.xfail(raises=NotImplementedError)),
2050+
pytest.param('unknown_loss', marks=pytest.mark.xfail(raises=ValueError)),
2051+
])
2052+
def test_gw_semirelaxed_helper_validation(loss_fun):
2053+
n_samples = 20 # nb samples
2054+
mu = np.array([0, 0])
2055+
cov = np.array([[1, 0], [0, 1]])
2056+
xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=0)
2057+
xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=1)
2058+
p = ot.unif(n_samples)
2059+
q = ot.unif(n_samples)
2060+
C1 = ot.dist(xs, xs)
2061+
C2 = ot.dist(xt, xt)
2062+
ot.gromov.init_matrix_semirelaxed(C1, C2, p, loss_fun=loss_fun)
2063+
2064+
20292065
def test_semirelaxed_fgw(nx):
20302066
rng = np.random.RandomState(0)
20312067
list_n = [16, 8]

0 commit comments

Comments
 (0)