[MRG] Add exact line-search for (f)gw solvers with kl_loss #556

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions RELEASES.md
@@ -13,6 +13,7 @@
+ Update wheels to Python 3.12 and remove the old i686 arch that does not have scipy wheels (PR #543)
+ Upgraded unbalanced OT solvers for more flexibility (PR #539)
+ Add LazyTensor for modeling plans and low rank tensor in large scale OT (PR #544)
+ Add exact line-search for `gromov_wasserstein` and `fused_gromov_wasserstein` with KL loss (PR #556)

#### Closed issues
- Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504)
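For readers of the release note above, a minimal usage sketch (hypothetical toy data, not from this PR): with this change, the default `armijo=False` path runs the exact closed-form line-search for `'kl_loss'` instead of silently falling back to Armijo.

```python
import numpy as np
import ot

# Hypothetical toy data: two point clouds and their intra-domain distance matrices.
rng = np.random.RandomState(0)
xs, xt = rng.randn(20, 2), rng.randn(25, 2)
C1, C2 = ot.dist(xs, xs), ot.dist(xt, xt)
C1, C2 = C1 / C1.max(), C2 / C2.max()
p, q = ot.unif(20), ot.unif(25)

# armijo=False (the default) now selects the exact line-search for 'kl_loss'.
T, log = ot.gromov.gromov_wasserstein(
    C1, C2, p, q, loss_fun='kl_loss', armijo=False, log=True)
```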
78 changes: 48 additions & 30 deletions ot/gromov/_gw.py
@@ -160,15 +160,17 @@ def df(G):

def df(G):
return 0.5 * (gwggrad(constC, hC1, hC2, G, np_) + gwggrad(constCt, hC1t, hC2t, G, np_))
if loss_fun == 'kl_loss':
armijo = True # there is no closed form line-search with KL

# removed since 0.9.2
#if loss_fun == 'kl_loss':
# armijo = True # there is no closed form line-search with KL

if armijo:
def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
return line_search_armijo(cost, G, deltaG, Mi, cost_G, nx=np_, **kwargs)
else:
def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
return solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M=0., reg=1., nx=np_, **kwargs)
return solve_gromov_linesearch(G, deltaG, cost_G, hC1, hC2, M=0., reg=1., nx=np_, **kwargs)
if log:
res, log = cg(p, q, 0., 1., f, df, G0, line_search, log=True, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs)
log['gw_dist'] = nx.from_numpy(log['loss'][-1], type_as=C10)
@@ -296,9 +298,13 @@ def gromov_wasserstein2(C1, C2, p=None, q=None, loss_fun='square_loss', symmetri
if loss_fun == 'square_loss':
gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T))
gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T))
gw = nx.set_gradients(gw, (p, q, C1, C2),
(log_gw['u'] - nx.mean(log_gw['u']),
log_gw['v'] - nx.mean(log_gw['v']), gC1, gC2))
elif loss_fun == 'kl_loss':
gC1 = nx.log(C1 + 1e-15) * nx.outer(p, p) - nx.dot(T, nx.dot(nx.log(C2 + 1e-15), T.T))
gC2 = -nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q)

gw = nx.set_gradients(gw, (p, q, C1, C2),
(log_gw['u'] - nx.mean(log_gw['u']),
log_gw['v'] - nx.mean(log_gw['v']), gC1, gC2))

if log:
return gw, log_gw
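For reference (not part of the diff), the KL gradients above follow from POT's inner loss $L(a, b) = a \log(a/b) - a + b$: with $\partial L / \partial a = \log(a/b)$ and $\partial L / \partial b = 1 - a/b$, differentiating $\mathrm{GW} = \sum_{i,j,k,l} L(C^{(1)}_{ik}, C^{(2)}_{jl})\, T_{ij} T_{kl}$ gives, elementwise,

```latex
\nabla_{C_1} \mathrm{GW} = \log(C_1) \odot p\,p^\top - T \log(C_2)\, T^\top,
\qquad
\nabla_{C_2} \mathrm{GW} = q\,q^\top - \frac{T^\top C_1\, T}{C_2},
```

which matches the code above up to the $10^{-15}$ stabilizer.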
@@ -449,15 +455,16 @@ def df(G):
def df(G):
return 0.5 * (gwggrad(constC, hC1, hC2, G, np_) + gwggrad(constCt, hC1t, hC2t, G, np_))

if loss_fun == 'kl_loss':
armijo = True # there is no closed form line-search with KL
# removed since 0.9.2
#if loss_fun == 'kl_loss':
# armijo = True # there is no closed form line-search with KL

if armijo:
def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
return line_search_armijo(cost, G, deltaG, Mi, cost_G, nx=np_, **kwargs)
else:
def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
return solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M=(1 - alpha) * M, reg=alpha, nx=np_, **kwargs)
return solve_gromov_linesearch(G, deltaG, cost_G, hC1, hC2, M=(1 - alpha) * M, reg=alpha, nx=np_, **kwargs)
if log:
res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, log=True, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs)
log['fgw_dist'] = nx.from_numpy(log['loss'][-1], type_as=C10)
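A matching sketch for the fused solver (hypothetical toy data; `alpha` weighs the structure term against the feature cost `M`):

```python
import numpy as np
import ot

rng = np.random.RandomState(0)
xs, xt = rng.randn(20, 2), rng.randn(25, 2)   # structure coordinates
ys, yt = rng.randn(20, 3), rng.randn(25, 3)   # node features
C1, C2 = ot.dist(xs, xs), ot.dist(xt, xt)
M = ot.dist(ys, yt)                           # feature cost matrix
p, q = ot.unif(20), ot.unif(25)

# Exact line-search (armijo=False, the default) is now used with 'kl_loss'.
T, log = ot.gromov.fused_gromov_wasserstein(
    M, C1, C2, p, q, loss_fun='kl_loss', alpha=0.5, armijo=False, log=True)
```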
@@ -591,18 +598,20 @@ def fused_gromov_wasserstein2(M, C1, C2, p=None, q=None, loss_fun='square_loss',
if loss_fun == 'square_loss':
gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T))
gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T))
if isinstance(alpha, int) or isinstance(alpha, float):
fgw_dist = nx.set_gradients(fgw_dist, (p, q, C1, C2, M),
(log_fgw['u'] - nx.mean(log_fgw['u']),
log_fgw['v'] - nx.mean(log_fgw['v']),
alpha * gC1, alpha * gC2, (1 - alpha) * T))
else:

fgw_dist = nx.set_gradients(fgw_dist, (p, q, C1, C2, M, alpha),
(log_fgw['u'] - nx.mean(log_fgw['u']),
log_fgw['v'] - nx.mean(log_fgw['v']),
alpha * gC1, alpha * gC2, (1 - alpha) * T,
gw_term - lin_term))
elif loss_fun == 'kl_loss':
gC1 = nx.log(C1 + 1e-15) * nx.outer(p, p) - nx.dot(T, nx.dot(nx.log(C2 + 1e-15), T.T))
gC2 = -nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q)
if isinstance(alpha, int) or isinstance(alpha, float):
fgw_dist = nx.set_gradients(fgw_dist, (p, q, C1, C2, M),
(log_fgw['u'] - nx.mean(log_fgw['u']),
log_fgw['v'] - nx.mean(log_fgw['v']),
alpha * gC1, alpha * gC2, (1 - alpha) * T))
else:
fgw_dist = nx.set_gradients(fgw_dist, (p, q, C1, C2, M, alpha),
(log_fgw['u'] - nx.mean(log_fgw['u']),
log_fgw['v'] - nx.mean(log_fgw['v']),
alpha * gC1, alpha * gC2, (1 - alpha) * T,
gw_term - lin_term))

if log:
return fgw_dist, log_fgw
@@ -613,7 +622,7 @@ def fused_gromov_wasserstein2(M, C1, C2, p=None, q=None, loss_fun='square_loss',
def solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M, reg,
alpha_min=None, alpha_max=None, nx=None, **kwargs):
"""
Solve the linesearch in the FW iterations
Solve the linesearch in the FW iterations for any inner loss that decomposes as in Proposition 1 in :ref:`[12] <references-solve-linesearch>`.

Parameters
----------
@@ -625,9 +634,11 @@ def solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M, reg,
cost_G : float
Value of the cost at `G`
C1 : array-like (ns,ns), optional
Structure matrix in the source domain.
Transformed structure matrix in the source domain.
For 'square_loss' and 'kl_loss', this is hC1 returned by ot.gromov.init_matrix.
C2 : array-like (nt,nt), optional
Structure matrix in the target domain.
Transformed structure matrix in the target domain.
For 'square_loss' and 'kl_loss', this is hC2 returned by ot.gromov.init_matrix.
M : array-like (ns,nt)
Cost matrix between the features.
reg : float
@@ -649,11 +660,16 @@ def solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M, reg,


.. _references-solve-linesearch:

References
----------
.. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas
"Optimal Transport for structured data with application on graphs"
International Conference on Machine Learning (ICML). 2019.
.. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
"Gromov-Wasserstein averaging of kernel and distance matrices."
International Conference on Machine Learning (ICML). 2016.

"""
if nx is None:
G, deltaG, C1, C2, M = list_to_array(G, deltaG, C1, C2, M)
@@ -664,8 +680,8 @@ def solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M, reg,
nx = get_backend(G, deltaG, C1, C2, M)

dot = nx.dot(nx.dot(C1, deltaG), C2.T)
a = -2 * reg * nx.sum(dot * deltaG)
b = nx.sum(M * deltaG) - 2 * reg * (nx.sum(dot * G) + nx.sum(nx.dot(nx.dot(C1, G), C2.T) * deltaG))
a = - reg * nx.sum(dot * deltaG)
b = nx.sum(M * deltaG) - reg * (nx.sum(dot * G) + nx.sum(nx.dot(nx.dot(C1, G), C2.T) * deltaG))

alpha = solve_1d_linesearch_quad(a, b)
if alpha_min is not None or alpha_max is not None:
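The coefficients lose their factor 2 because the callers now pass the transformed matrices from `ot.gromov.init_matrix` (`constC, hC1, hC2 = ot.gromov.init_matrix(C1, C2, p, q, loss_fun)`), which absorb the loss-specific constants; for `'square_loss'`, `hC1 = C1` and `hC2 = 2 * C2`, so the two forms agree. The step size then minimizes a quadratic in $t$ over $[0, 1]$. A minimal sketch of the assumed semantics of `solve_1d_linesearch_quad` (illustrative, not POT's actual implementation):

```python
def solve_1d_linesearch_quad_sketch(a, b):
    """Minimize f(t) = a * t**2 + b * t over t in [0, 1] (illustrative sketch)."""
    if a > 0:
        # Convex parabola: clip the unconstrained minimizer -b / (2a) to [0, 1].
        return min(1.0, max(0.0, -b / (2.0 * a)))
    # Concave or linear case: the minimum is attained at an endpoint.
    return 1.0 if a + b < 0.0 else 0.0
```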
@@ -776,8 +792,9 @@ def gromov_barycenters(
else:
C = init_C

if loss_fun == 'kl_loss':
armijo = True
# removed since 0.9.2
#if loss_fun == 'kl_loss':
# armijo = True

cpt = 0
err = 1
@@ -960,8 +977,9 @@

Ms = [dist(X, Ys[s]) for s in range(len(Ys))]

if loss_fun == 'kl_loss':
armijo = True
# removed since 0.9.2
#if loss_fun == 'kl_loss':
# armijo = True

cpt = 0
err_feature = 1
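Since the barycenter solvers no longer force Armijo for `'kl_loss'`, they also default to the exact line-search. A hypothetical sketch, assuming the 0.9.x signature where `ps`, `p` and `lambdas` default to uniform:

```python
import numpy as np
import ot

rng = np.random.RandomState(1)
Cs = [ot.dist(rng.randn(n, 2)) for n in (10, 12)]
Cs = [C / C.max() for C in Cs]

# Barycenter structure matrix with 8 nodes under the KL inner loss;
# armijo=False (the default) now uses the exact line-search.
C = ot.gromov.gromov_barycenters(
    8, Cs, loss_fun='kl_loss', armijo=False, random_state=42)
```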
97 changes: 49 additions & 48 deletions test/test_gromov.py
@@ -52,31 +52,31 @@ def test_gromov(nx):
Id = (1 / (1.0 * n_samples)) * np.eye(n_samples, n_samples)

np.testing.assert_allclose(Gb, np.flipud(Id), atol=1e-04)
for armijo in [False, True]:
gw, log = ot.gromov.gromov_wasserstein2(C1, C2, None, q, 'kl_loss', armijo=armijo, log=True)
gwb, logb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, None, 'kl_loss', armijo=armijo, log=True)
gwb = nx.to_numpy(gwb)

gw_val = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', armijo=armijo, G0=G0, log=False)
gw_valb = nx.to_numpy(
ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', armijo=armijo, G0=G0b, log=False)
)

gw, log = ot.gromov.gromov_wasserstein2(C1, C2, None, q, 'kl_loss', armijo=True, log=True)
gwb, logb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, None, 'kl_loss', armijo=True, log=True)
gwb = nx.to_numpy(gwb)
G = log['T']
Gb = nx.to_numpy(logb['T'])

gw_val = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', armijo=True, G0=G0, log=False)
gw_valb = nx.to_numpy(
ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', armijo=True, G0=G0b, log=False)
)
np.testing.assert_allclose(gw, gwb, atol=1e-06)
np.testing.assert_allclose(gwb, 0, atol=1e-1, rtol=1e-1)

G = log['T']
Gb = nx.to_numpy(logb['T'])
np.testing.assert_allclose(gw_val, gw_valb, atol=1e-06)
np.testing.assert_allclose(gwb, gw_valb, atol=1e-1, rtol=1e-1) # cf log=False

np.testing.assert_allclose(gw, gwb, atol=1e-06)
np.testing.assert_allclose(gwb, 0, atol=1e-1, rtol=1e-1)

np.testing.assert_allclose(gw_val, gw_valb, atol=1e-06)
np.testing.assert_allclose(gwb, gw_valb, atol=1e-1, rtol=1e-1) # cf log=False

# check constraints
np.testing.assert_allclose(G, Gb, atol=1e-06)
np.testing.assert_allclose(
p, Gb.sum(1), atol=1e-04) # cf convergence gromov
np.testing.assert_allclose(
q, Gb.sum(0), atol=1e-04) # cf convergence gromov
# check constraints
np.testing.assert_allclose(G, Gb, atol=1e-06)
np.testing.assert_allclose(
p, Gb.sum(1), atol=1e-04) # cf convergence gromov
np.testing.assert_allclose(
q, Gb.sum(0), atol=1e-04) # cf convergence gromov


def test_asymmetric_gromov(nx):
@@ -1191,33 +1191,34 @@ def test_asymmetric_fgw(nx):
np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04)

# Tests with kl-loss:
G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'kl_loss', alpha=0.5, G0=G0, log=True, symmetric=False, verbose=True)
Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'kl_loss', alpha=0.5, log=True, symmetric=None, G0=G0b, verbose=True)
Gb = nx.to_numpy(Gb)
# check constraints
np.testing.assert_allclose(G, Gb, atol=1e-06)
np.testing.assert_allclose(
p, Gb.sum(1), atol=1e-04) # cf convergence gromov
np.testing.assert_allclose(
q, Gb.sum(0), atol=1e-04) # cf convergence gromov

np.testing.assert_allclose(log['fgw_dist'], 0., atol=1e-04)
np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04)

fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'kl_loss', alpha=0.5, G0=G0, log=True, symmetric=None, verbose=True)
fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'kl_loss', alpha=0.5, log=True, symmetric=False, G0=G0b, verbose=True)

G = log['T']
Gb = nx.to_numpy(logb['T'])
# check constraints
np.testing.assert_allclose(G, Gb, atol=1e-06)
np.testing.assert_allclose(
p, Gb.sum(1), atol=1e-04) # cf convergence gromov
np.testing.assert_allclose(
q, Gb.sum(0), atol=1e-04) # cf convergence gromov

np.testing.assert_allclose(log['fgw_dist'], 0., atol=1e-04)
np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04)
for armijo in [False, True]:
G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'kl_loss', alpha=0.5, armijo=armijo, G0=G0, log=True, symmetric=False, verbose=True)
Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'kl_loss', alpha=0.5, armijo=armijo, log=True, symmetric=None, G0=G0b, verbose=True)
Gb = nx.to_numpy(Gb)
# check constraints
np.testing.assert_allclose(G, Gb, atol=1e-06)
np.testing.assert_allclose(
p, Gb.sum(1), atol=1e-04) # cf convergence gromov
np.testing.assert_allclose(
q, Gb.sum(0), atol=1e-04) # cf convergence gromov

np.testing.assert_allclose(log['fgw_dist'], 0., atol=1e-04)
np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04)

fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'kl_loss', alpha=0.5, G0=G0, log=True, symmetric=None, verbose=True)
fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'kl_loss', alpha=0.5, log=True, symmetric=False, G0=G0b, verbose=True)

G = log['T']
Gb = nx.to_numpy(logb['T'])
# check constraints
np.testing.assert_allclose(G, Gb, atol=1e-06)
np.testing.assert_allclose(
p, Gb.sum(1), atol=1e-04) # cf convergence gromov
np.testing.assert_allclose(
q, Gb.sum(0), atol=1e-04) # cf convergence gromov

np.testing.assert_allclose(log['fgw_dist'], 0., atol=1e-04)
np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04)


def test_fgw2_gradients():