Skip to content

Commit 829ce41

Browse files
authored
loss_fun parameter validationf for entropic GW functions (#515)
1 parent 7fa8438 commit 829ce41

File tree

2 files changed

+66
-22
lines changed

2 files changed

+66
-22
lines changed

ot/gromov/_bregman.py

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def entropic_gromov_wasserstein(
8181
q : array-like, shape (nt,), optional
8282
Distribution in the target space.
8383
If let to its default value None, uniform distribution is taken.
84-
loss_fun : string, optional
84+
loss_fun : string, optional (default='square_loss')
8585
Loss function used for the solver either 'square_loss' or 'kl_loss'
8686
epsilon : float, optional
8787
Regularization term >0
@@ -92,8 +92,8 @@ def entropic_gromov_wasserstein(
9292
G0: array-like, shape (ns,nt), optional
9393
If None the initial transport plan of the solver is pq^T.
9494
Otherwise G0 will be used as initial transport of the solver. G0 is not
95-
required to satisfy marginal constraints but we strongly recommand it
96-
to correcly estimate the GW distance.
95+
required to satisfy marginal constraints but we strongly recommend it
96+
to correctly estimate the GW distance.
9797
max_iter : int, optional
9898
Max number of iterations
9999
tol : float, optional
@@ -135,6 +135,9 @@ def entropic_gromov_wasserstein(
135135
if solver not in ['PGD', 'PPA']:
136136
raise ValueError("Unknown solver '%s'. Pick one in ['PGD', 'PPA']." % solver)
137137

138+
if loss_fun not in ('square_loss', 'kl_loss'):
139+
raise ValueError(f"Unknown `loss_fun='{loss_fun}'`. Use one of: {'square_loss', 'kl_loss'}.")
140+
138141
C1, C2 = list_to_array(C1, C2)
139142
arr = [C1, C2]
140143
if p is not None:
@@ -280,7 +283,7 @@ def entropic_gromov_wasserstein2(
280283
q : array-like, shape (nt,), optional
281284
Distribution in the target space.
282285
If let to its default value None, uniform distribution is taken.
283-
loss_fun : string, optional
286+
loss_fun : string, optional (default='square_loss')
284287
Loss function used for the solver either 'square_loss' or 'kl_loss'
285288
epsilon : float, optional
286289
Regularization term >0
@@ -373,8 +376,8 @@ def entropic_gromov_barycenters(
373376
lambdas : list of float, optional
374377
List of the `S` spaces' weights.
375378
If let to its default value None, uniform weights are taken.
376-
loss_fun : callable, optional
377-
tensor-matrix multiplication function based on specific loss function
379+
loss_fun : string, optional (default='square_loss')
380+
Loss function used for the solver either 'square_loss' or 'kl_loss'
378381
epsilon : float, optional
379382
Regularization term >0
380383
symmetric : bool, optional.
@@ -411,6 +414,9 @@ def entropic_gromov_barycenters(
411414
"Gromov-Wasserstein averaging of kernel and distance matrices."
412415
International Conference on Machine Learning (ICML). 2016.
413416
"""
417+
if loss_fun not in ('square_loss', 'kl_loss'):
418+
raise ValueError(f"Unknown `loss_fun='{loss_fun}'`. Use one of: {'square_loss', 'kl_loss'}.")
419+
414420
Cs = list_to_array(*Cs)
415421
arr = [*Cs]
416422
if ps is not None:
@@ -459,7 +465,6 @@ def entropic_gromov_barycenters(
459465

460466
if loss_fun == 'square_loss':
461467
C = update_square_loss(p, lambdas, T, Cs)
462-
463468
elif loss_fun == 'kl_loss':
464469
C = update_kl_loss(p, lambdas, T, Cs)
465470

@@ -550,21 +555,21 @@ def entropic_fused_gromov_wasserstein(
550555
q : array-like, shape (nt,), optional
551556
Distribution in the target space.
552557
If let to its default value None, uniform distribution is taken.
553-
loss_fun : string, optional
558+
loss_fun : string, optional (default='square_loss')
554559
Loss function used for the solver either 'square_loss' or 'kl_loss'
555560
epsilon : float, optional
556561
Regularization term >0
557562
symmetric : bool, optional
558563
Either C1 and C2 are to be assumed symmetric or not.
559564
If let to its default None value, a symmetry test will be conducted.
560-
Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymetric).
565+
Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric).
561566
alpha : float, optional
562567
Trade-off parameter (0 < alpha < 1)
563568
G0: array-like, shape (ns,nt), optional
564569
If None the initial transport plan of the solver is pq^T.
565570
Otherwise G0 will be used as initial transport of the solver. G0 is not
566-
required to satisfy marginal constraints but we strongly recommand it
567-
to correcly estimate the GW distance.
571+
required to satisfy marginal constraints but we strongly recommend it
572+
to correctly estimate the GW distance.
568573
max_iter : int, optional
569574
Max number of iterations
570575
tol : float, optional
@@ -611,6 +616,9 @@ def entropic_fused_gromov_wasserstein(
611616
if solver not in ['PGD', 'PPA']:
612617
raise ValueError("Unknown solver '%s'. Pick one in ['PGD', 'PPA']." % solver)
613618

619+
if loss_fun not in ('square_loss', 'kl_loss'):
620+
raise ValueError(f"Unknown `loss_fun='{loss_fun}'`. Use one of: {'square_loss', 'kl_loss'}.")
621+
614622
M, C1, C2 = list_to_array(M, C1, C2)
615623
arr = [M, C1, C2]
616624
if p is not None:
@@ -762,7 +770,7 @@ def entropic_fused_gromov_wasserstein2(
762770
q : array-like, shape (nt,), optional
763771
Distribution in the target space.
764772
If let to its default value None, uniform distribution is taken.
765-
loss_fun : string, optional
773+
loss_fun : string, optional (default='square_loss')
766774
Loss function used for the solver either 'square_loss' or 'kl_loss'
767775
epsilon : float, optional
768776
Regularization term >0
@@ -775,8 +783,8 @@ def entropic_fused_gromov_wasserstein2(
775783
G0: array-like, shape (ns,nt), optional
776784
If None the initial transport plan of the solver is pq^T.
777785
Otherwise G0 will be used as initial transport of the solver. G0 is not
778-
required to satisfy marginal constraints but we strongly recommand it
779-
to correcly estimate the GW distance.
786+
required to satisfy marginal constraints but we strongly recommend it
787+
to correctly estimate the GW distance.
780788
max_iter : int, optional
781789
Max number of iterations
782790
tol : float, optional
@@ -857,8 +865,8 @@ def entropic_fused_gromov_barycenters(
857865
lambdas : list of float, optional
858866
List of the `S` spaces' weights.
859867
If let to its default value None, uniform weights are taken.
860-
loss_fun : callable, optional
861-
tensor-matrix multiplication function based on specific loss function
868+
loss_fun : string, optional (default='square_loss')
869+
Loss function used for the solver either 'square_loss' or 'kl_loss'
862870
epsilon : float, optional
863871
Regularization term >0
864872
symmetric : bool, optional.
@@ -907,6 +915,9 @@ def entropic_fused_gromov_barycenters(
907915
"Optimal Transport for structured data with application on graphs"
908916
International Conference on Machine Learning (ICML). 2019.
909917
"""
918+
if loss_fun not in ('square_loss', 'kl_loss'):
919+
raise ValueError(f"Unknown `loss_fun='{loss_fun}'`. Use one of: {'square_loss', 'kl_loss'}.")
920+
910921
Cs = list_to_array(*Cs)
911922
Ys = list_to_array(*Ys)
912923
arr = [*Cs, *Ys]
@@ -977,7 +988,6 @@ def entropic_fused_gromov_barycenters(
977988

978989
if loss_fun == 'square_loss':
979990
C = update_square_loss(p, lambdas, T, Cs)
980-
981991
elif loss_fun == 'kl_loss':
982992
C = update_kl_loss(p, lambdas, T, Cs)
983993

test/test_gromov.py

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,12 @@ def test_gw_helper_validation(loss_fun):
311311

312312
@pytest.skip_backend("jax", reason="test very slow with jax backend")
313313
@pytest.skip_backend("tf", reason="test very slow with tf backend")
314-
def test_entropic_gromov(nx):
314+
@pytest.mark.parametrize('loss_fun', [
315+
'square_loss',
316+
'kl_loss',
317+
pytest.param('unknown_loss', marks=pytest.mark.xfail(raises=ValueError)),
318+
])
319+
def test_entropic_gromov(nx, loss_fun):
315320
n_samples = 10 # nb samples
316321

317322
mu_s = np.array([0, 0])
@@ -333,10 +338,10 @@ def test_entropic_gromov(nx):
333338
C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0)
334339

335340
G, log = ot.gromov.entropic_gromov_wasserstein(
336-
C1, C2, None, q, 'square_loss', symmetric=None, G0=G0,
341+
C1, C2, None, q, loss_fun, symmetric=None, G0=G0,
337342
epsilon=1e-2, max_iter=10, verbose=True, log=True)
338343
Gb = nx.to_numpy(ot.gromov.entropic_gromov_wasserstein(
339-
C1b, C2b, pb, None, 'square_loss', symmetric=True, G0=None,
344+
C1b, C2b, pb, None, loss_fun, symmetric=True, G0=None,
340345
epsilon=1e-2, max_iter=10, verbose=True, log=False
341346
))
342347

@@ -347,11 +352,40 @@ def test_entropic_gromov(nx):
347352
np.testing.assert_allclose(
348353
q, Gb.sum(0), atol=1e-04) # cf convergence gromov
349354

355+
356+
@pytest.skip_backend("jax", reason="test very slow with jax backend")
357+
@pytest.skip_backend("tf", reason="test very slow with tf backend")
358+
@pytest.mark.parametrize('loss_fun', [
359+
'square_loss',
360+
'kl_loss',
361+
pytest.param('unknown_loss', marks=pytest.mark.xfail(raises=ValueError)),
362+
])
363+
def test_entropic_gromov2(nx, loss_fun):
364+
n_samples = 10 # nb samples
365+
366+
mu_s = np.array([0, 0])
367+
cov_s = np.array([[1, 0], [0, 1]])
368+
369+
xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42)
370+
371+
xt = xs[::-1].copy()
372+
373+
p = ot.unif(n_samples)
374+
q = ot.unif(n_samples)
375+
G0 = p[:, None] * q[None, :]
376+
C1 = ot.dist(xs, xs)
377+
C2 = ot.dist(xt, xt)
378+
379+
C1 /= C1.max()
380+
C2 /= C2.max()
381+
382+
C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0)
383+
350384
gw, log = ot.gromov.entropic_gromov_wasserstein2(
351-
C1, C2, p, None, 'kl_loss', symmetric=True, G0=None,
385+
C1, C2, p, None, loss_fun, symmetric=True, G0=None,
352386
max_iter=10, epsilon=1e-2, log=True)
353387
gwb, logb = ot.gromov.entropic_gromov_wasserstein2(
354-
C1b, C2b, None, qb, 'kl_loss', symmetric=None, G0=G0b,
388+
C1b, C2b, None, qb, loss_fun, symmetric=None, G0=G0b,
355389
max_iter=10, epsilon=1e-2, log=True)
356390
gwb = nx.to_numpy(gwb)
357391

0 commit comments

Comments
 (0)