Skip to content

loss_fun parameter validation for entropic GW functions #515

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 27 additions & 17 deletions ot/gromov/_bregman.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def entropic_gromov_wasserstein(
q : array-like, shape (nt,), optional
Distribution in the target space.
If left to its default value None, uniform distribution is taken.
loss_fun : string, optional
loss_fun : string, optional (default='square_loss')
Loss function used for the solver either 'square_loss' or 'kl_loss'
epsilon : float, optional
Regularization term >0
Expand All @@ -92,8 +92,8 @@ def entropic_gromov_wasserstein(
G0: array-like, shape (ns,nt), optional
If None the initial transport plan of the solver is pq^T.
Otherwise G0 will be used as initial transport of the solver. G0 is not
required to satisfy marginal constraints but we strongly recommand it
to correcly estimate the GW distance.
required to satisfy marginal constraints but we strongly recommend it
to correctly estimate the GW distance.
max_iter : int, optional
Max number of iterations
tol : float, optional
Expand Down Expand Up @@ -135,6 +135,9 @@ def entropic_gromov_wasserstein(
if solver not in ['PGD', 'PPA']:
raise ValueError("Unknown solver '%s'. Pick one in ['PGD', 'PPA']." % solver)

if loss_fun not in ('square_loss', 'kl_loss'):
raise ValueError(f"Unknown `loss_fun='{loss_fun}'`. Use one of: {'square_loss', 'kl_loss'}.")

C1, C2 = list_to_array(C1, C2)
arr = [C1, C2]
if p is not None:
Expand Down Expand Up @@ -280,7 +283,7 @@ def entropic_gromov_wasserstein2(
q : array-like, shape (nt,), optional
Distribution in the target space.
If left to its default value None, uniform distribution is taken.
loss_fun : string, optional
loss_fun : string, optional (default='square_loss')
Loss function used for the solver either 'square_loss' or 'kl_loss'
epsilon : float, optional
Regularization term >0
Expand Down Expand Up @@ -373,8 +376,8 @@ def entropic_gromov_barycenters(
lambdas : list of float, optional
List of the `S` spaces' weights.
If left to its default value None, uniform weights are taken.
loss_fun : callable, optional
tensor-matrix multiplication function based on specific loss function
loss_fun : string, optional (default='square_loss')
Loss function used for the solver either 'square_loss' or 'kl_loss'
epsilon : float, optional
Regularization term >0
symmetric : bool, optional.
Expand Down Expand Up @@ -411,6 +414,9 @@ def entropic_gromov_barycenters(
"Gromov-Wasserstein averaging of kernel and distance matrices."
International Conference on Machine Learning (ICML). 2016.
"""
if loss_fun not in ('square_loss', 'kl_loss'):
raise ValueError(f"Unknown `loss_fun='{loss_fun}'`. Use one of: {'square_loss', 'kl_loss'}.")

Cs = list_to_array(*Cs)
arr = [*Cs]
if ps is not None:
Expand Down Expand Up @@ -459,7 +465,6 @@ def entropic_gromov_barycenters(

if loss_fun == 'square_loss':
C = update_square_loss(p, lambdas, T, Cs)

elif loss_fun == 'kl_loss':
C = update_kl_loss(p, lambdas, T, Cs)

Expand Down Expand Up @@ -550,21 +555,21 @@ def entropic_fused_gromov_wasserstein(
q : array-like, shape (nt,), optional
Distribution in the target space.
If left to its default value None, uniform distribution is taken.
loss_fun : string, optional
loss_fun : string, optional (default='square_loss')
Loss function used for the solver either 'square_loss' or 'kl_loss'
epsilon : float, optional
Regularization term >0
symmetric : bool, optional
Either C1 and C2 are to be assumed symmetric or not.
If left to its default value None, a symmetry test will be conducted.
Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymetric).
Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric).
alpha : float, optional
Trade-off parameter (0 < alpha < 1)
G0: array-like, shape (ns,nt), optional
If None the initial transport plan of the solver is pq^T.
Otherwise G0 will be used as initial transport of the solver. G0 is not
required to satisfy marginal constraints but we strongly recommand it
to correcly estimate the GW distance.
required to satisfy marginal constraints but we strongly recommend it
to correctly estimate the GW distance.
max_iter : int, optional
Max number of iterations
tol : float, optional
Expand Down Expand Up @@ -611,6 +616,9 @@ def entropic_fused_gromov_wasserstein(
if solver not in ['PGD', 'PPA']:
raise ValueError("Unknown solver '%s'. Pick one in ['PGD', 'PPA']." % solver)

if loss_fun not in ('square_loss', 'kl_loss'):
raise ValueError(f"Unknown `loss_fun='{loss_fun}'`. Use one of: {'square_loss', 'kl_loss'}.")

M, C1, C2 = list_to_array(M, C1, C2)
arr = [M, C1, C2]
if p is not None:
Expand Down Expand Up @@ -762,7 +770,7 @@ def entropic_fused_gromov_wasserstein2(
q : array-like, shape (nt,), optional
Distribution in the target space.
If left to its default value None, uniform distribution is taken.
loss_fun : string, optional
loss_fun : string, optional (default='square_loss')
Loss function used for the solver either 'square_loss' or 'kl_loss'
epsilon : float, optional
Regularization term >0
Expand All @@ -775,8 +783,8 @@ def entropic_fused_gromov_wasserstein2(
G0: array-like, shape (ns,nt), optional
If None the initial transport plan of the solver is pq^T.
Otherwise G0 will be used as initial transport of the solver. G0 is not
required to satisfy marginal constraints but we strongly recommand it
to correcly estimate the GW distance.
required to satisfy marginal constraints but we strongly recommend it
to correctly estimate the GW distance.
max_iter : int, optional
Max number of iterations
tol : float, optional
Expand Down Expand Up @@ -857,8 +865,8 @@ def entropic_fused_gromov_barycenters(
lambdas : list of float, optional
List of the `S` spaces' weights.
If left to its default value None, uniform weights are taken.
loss_fun : callable, optional
tensor-matrix multiplication function based on specific loss function
loss_fun : string, optional (default='square_loss')
Loss function used for the solver either 'square_loss' or 'kl_loss'
epsilon : float, optional
Regularization term >0
symmetric : bool, optional.
Expand Down Expand Up @@ -907,6 +915,9 @@ def entropic_fused_gromov_barycenters(
"Optimal Transport for structured data with application on graphs"
International Conference on Machine Learning (ICML). 2019.
"""
if loss_fun not in ('square_loss', 'kl_loss'):
raise ValueError(f"Unknown `loss_fun='{loss_fun}'`. Use one of: {'square_loss', 'kl_loss'}.")

Cs = list_to_array(*Cs)
Ys = list_to_array(*Ys)
arr = [*Cs, *Ys]
Expand Down Expand Up @@ -977,7 +988,6 @@ def entropic_fused_gromov_barycenters(

if loss_fun == 'square_loss':
C = update_square_loss(p, lambdas, T, Cs)

elif loss_fun == 'kl_loss':
C = update_kl_loss(p, lambdas, T, Cs)

Expand Down
44 changes: 39 additions & 5 deletions test/test_gromov.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,12 @@ def test_gw_helper_validation(loss_fun):

@pytest.skip_backend("jax", reason="test very slow with jax backend")
@pytest.skip_backend("tf", reason="test very slow with tf backend")
def test_entropic_gromov(nx):
@pytest.mark.parametrize('loss_fun', [
'square_loss',
'kl_loss',
pytest.param('unknown_loss', marks=pytest.mark.xfail(raises=ValueError)),
])
def test_entropic_gromov(nx, loss_fun):
n_samples = 10 # nb samples

mu_s = np.array([0, 0])
Expand All @@ -333,10 +338,10 @@ def test_entropic_gromov(nx):
C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0)

G, log = ot.gromov.entropic_gromov_wasserstein(
C1, C2, None, q, 'square_loss', symmetric=None, G0=G0,
C1, C2, None, q, loss_fun, symmetric=None, G0=G0,
epsilon=1e-2, max_iter=10, verbose=True, log=True)
Gb = nx.to_numpy(ot.gromov.entropic_gromov_wasserstein(
C1b, C2b, pb, None, 'square_loss', symmetric=True, G0=None,
C1b, C2b, pb, None, loss_fun, symmetric=True, G0=None,
epsilon=1e-2, max_iter=10, verbose=True, log=False
))

Expand All @@ -347,11 +352,40 @@ def test_entropic_gromov(nx):
np.testing.assert_allclose(
q, Gb.sum(0), atol=1e-04) # cf convergence gromov


@pytest.skip_backend("jax", reason="test very slow with jax backend")
@pytest.skip_backend("tf", reason="test very slow with tf backend")
@pytest.mark.parametrize('loss_fun', [
'square_loss',
'kl_loss',
pytest.param('unknown_loss', marks=pytest.mark.xfail(raises=ValueError)),
])
def test_entropic_gromov2(nx, loss_fun):
n_samples = 10 # nb samples

mu_s = np.array([0, 0])
cov_s = np.array([[1, 0], [0, 1]])

xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42)

xt = xs[::-1].copy()

p = ot.unif(n_samples)
q = ot.unif(n_samples)
G0 = p[:, None] * q[None, :]
C1 = ot.dist(xs, xs)
C2 = ot.dist(xt, xt)

C1 /= C1.max()
C2 /= C2.max()

C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0)

gw, log = ot.gromov.entropic_gromov_wasserstein2(
C1, C2, p, None, 'kl_loss', symmetric=True, G0=None,
C1, C2, p, None, loss_fun, symmetric=True, G0=None,
max_iter=10, epsilon=1e-2, log=True)
gwb, logb = ot.gromov.entropic_gromov_wasserstein2(
C1b, C2b, None, qb, 'kl_loss', symmetric=None, G0=G0b,
C1b, C2b, None, qb, loss_fun, symmetric=None, G0=G0b,
max_iter=10, epsilon=1e-2, log=True)
gwb = nx.to_numpy(gwb)

Expand Down