From e3f35cd61b58148742bd2998ba2a586ef8091cb2 Mon Sep 17 00:00:00 2001
From: clvincen
Date: Thu, 2 Nov 2023 21:36:05 +0100
Subject: [PATCH] add exact line-search for (f)gw solvers with kl_loss

---
 RELEASES.md         |  1 +
 ot/gromov/_gw.py    | 78 ++++++++++++++++++++++--------------
 test/test_gromov.py | 97 +++++++++++++++++++++++----------------------
 3 files changed, 98 insertions(+), 78 deletions(-)

diff --git a/RELEASES.md b/RELEASES.md
index f943886d5..7c090bef8 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -13,6 +13,7 @@
 + Update wheels to Python 3.12 and remove old i686 arch that do not have scipy wheels (PR #543)
 + Upgraded unbalanced OT solvers for more flexibility (PR #539)
 + Add LazyTensor for modeling plans and low rank tensor in large scale OT (PR #544)
++ Add exact line-search for `gromov_wasserstein` and `fused_gromov_wasserstein` with KL loss (PR #556)

 #### Closed issues
 - Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504)

diff --git a/ot/gromov/_gw.py b/ot/gromov/_gw.py
index d5e4c7f13..88b1eb75f 100644
--- a/ot/gromov/_gw.py
+++ b/ot/gromov/_gw.py
@@ -160,15 +160,17 @@ def df(G):
         def df(G):
             return 0.5 * (gwggrad(constC, hC1, hC2, G, np_) + gwggrad(constCt, hC1t, hC2t, G, np_))
-    if loss_fun == 'kl_loss':
-        armijo = True  # there is no closed form line-search with KL
+
+    # removed since 0.9.2
+    # if loss_fun == 'kl_loss':
+    #     armijo = True  # there is no closed form line-search with KL
     if armijo:
         def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
             return line_search_armijo(cost, G, deltaG, Mi, cost_G, nx=np_, **kwargs)
     else:
         def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
-            return solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M=0., reg=1., nx=np_, **kwargs)
+            return solve_gromov_linesearch(G, deltaG, cost_G, hC1, hC2, M=0., reg=1., nx=np_, **kwargs)
     if log:
         res, log = cg(p, q, 0., 1., f, df, G0, line_search, log=True, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs)
         log['gw_dist'] = nx.from_numpy(log['loss'][-1], type_as=C10)
@@ -296,9 +298,13 @@ def gromov_wasserstein2(C1, C2, p=None, q=None, loss_fun='square_loss', symmetri
     if loss_fun == 'square_loss':
         gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T))
         gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T))
-        gw = nx.set_gradients(gw, (p, q, C1, C2),
-                              (log_gw['u'] - nx.mean(log_gw['u']),
-                               log_gw['v'] - nx.mean(log_gw['v']), gC1, gC2))
+    elif loss_fun == 'kl_loss':
+        gC1 = nx.log(C1 + 1e-15) * nx.outer(p, p) - nx.dot(T, nx.dot(nx.log(C2 + 1e-15), T.T))
+        gC2 = nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q)
+
+    gw = nx.set_gradients(gw, (p, q, C1, C2),
+                          (log_gw['u'] - nx.mean(log_gw['u']),
+                           log_gw['v'] - nx.mean(log_gw['v']), gC1, gC2))

     if log:
         return gw, log_gw
@@ -449,15 +455,16 @@ def df(G):
         def df(G):
             return 0.5 * (gwggrad(constC, hC1, hC2, G, np_) + gwggrad(constCt, hC1t, hC2t, G, np_))
-    if loss_fun == 'kl_loss':
-        armijo = True  # there is no closed form line-search with KL
+    # removed since 0.9.2
+    # if loss_fun == 'kl_loss':
+    #     armijo = True  # there is no closed form line-search with KL
     if armijo:
         def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
             return line_search_armijo(cost, G, deltaG, Mi, cost_G, nx=np_, **kwargs)
     else:
         def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
-            return solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M=(1 - alpha) * M, reg=alpha, nx=np_, **kwargs)
+            return solve_gromov_linesearch(G, deltaG, cost_G, hC1, hC2, M=(1 - alpha) * M, reg=alpha, nx=np_, **kwargs)
     if log:
         res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, log=True, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs)
         log['fgw_dist'] = nx.from_numpy(log['loss'][-1], type_as=C10)
@@ -591,18 +598,20 @@ def fused_gromov_wasserstein2(M, C1, C2, p=None, q=None, loss_fun='square_loss',
     if loss_fun == 'square_loss':
         gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T))
         gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T))
-        if isinstance(alpha, int) or isinstance(alpha, float):
-            fgw_dist = nx.set_gradients(fgw_dist, (p, q, C1, C2, M),
-                                        (log_fgw['u'] - nx.mean(log_fgw['u']),
-                                         log_fgw['v'] - nx.mean(log_fgw['v']),
-                                         alpha * gC1, alpha * gC2, (1 - alpha) * T))
-        else:
-
-            fgw_dist = nx.set_gradients(fgw_dist, (p, q, C1, C2, M, alpha),
-                                        (log_fgw['u'] - nx.mean(log_fgw['u']),
-                                         log_fgw['v'] - nx.mean(log_fgw['v']),
-                                         alpha * gC1, alpha * gC2, (1 - alpha) * T,
-                                         gw_term - lin_term))
+    elif loss_fun == 'kl_loss':
+        gC1 = nx.log(C1 + 1e-15) * nx.outer(p, p) - nx.dot(T, nx.dot(nx.log(C2 + 1e-15), T.T))
+        gC2 = nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q)
+    if isinstance(alpha, int) or isinstance(alpha, float):
+        fgw_dist = nx.set_gradients(fgw_dist, (p, q, C1, C2, M),
+                                    (log_fgw['u'] - nx.mean(log_fgw['u']),
+                                     log_fgw['v'] - nx.mean(log_fgw['v']),
+                                     alpha * gC1, alpha * gC2, (1 - alpha) * T))
+    else:
+        fgw_dist = nx.set_gradients(fgw_dist, (p, q, C1, C2, M, alpha),
+                                    (log_fgw['u'] - nx.mean(log_fgw['u']),
+                                     log_fgw['v'] - nx.mean(log_fgw['v']),
+                                     alpha * gC1, alpha * gC2, (1 - alpha) * T,
+                                     gw_term - lin_term))

     if log:
         return fgw_dist, log_fgw
@@ -613,7 +622,7 @@ def fused_gromov_wasserstein2(M, C1, C2, p=None, q=None, loss_fun='square_loss',
 def solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M, reg,
                             alpha_min=None, alpha_max=None, nx=None, **kwargs):
     """
-    Solve the linesearch in the FW iterations
+    Solve the linesearch in the FW iterations for any inner loss that decomposes as in Proposition 1 in :ref:`[12] <references-solve-linesearch>`.

     Parameters
     ----------
@@ -625,9 +634,11 @@ def solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M, reg,
     cost_G : float
         Value of the cost at `G`
     C1 : array-like (ns,ns), optional
-        Structure matrix in the source domain.
+        Transformed structure matrix in the source domain.
+        For 'square_loss' and 'kl_loss', we provide hC1 from ot.gromov.init_matrix.
     C2 : array-like (nt,nt), optional
-        Structure matrix in the target domain.
+        Transformed structure matrix in the target domain.
+        For 'square_loss' and 'kl_loss', we provide hC2 from ot.gromov.init_matrix.
     M : array-like (ns,nt)
         Cost matrix between the features.
     reg : float
@@ -649,11 +660,16 @@ def solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M, reg,

     .. _references-solve-linesearch:
+
     References
     ----------
     .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas
         "Optimal Transport for structured data with application on graphs"
         International Conference on Machine Learning (ICML). 2019.
+    .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
+        "Gromov-Wasserstein averaging of kernel and distance matrices."
+        International Conference on Machine Learning (ICML). 2016.
+ """ if nx is None: G, deltaG, C1, C2, M = list_to_array(G, deltaG, C1, C2, M) @@ -664,8 +680,8 @@ def solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M, reg, nx = get_backend(G, deltaG, C1, C2, M) dot = nx.dot(nx.dot(C1, deltaG), C2.T) - a = -2 * reg * nx.sum(dot * deltaG) - b = nx.sum(M * deltaG) - 2 * reg * (nx.sum(dot * G) + nx.sum(nx.dot(nx.dot(C1, G), C2.T) * deltaG)) + a = - reg * nx.sum(dot * deltaG) + b = nx.sum(M * deltaG) - reg * (nx.sum(dot * G) + nx.sum(nx.dot(nx.dot(C1, G), C2.T) * deltaG)) alpha = solve_1d_linesearch_quad(a, b) if alpha_min is not None or alpha_max is not None: @@ -776,8 +792,9 @@ def gromov_barycenters( else: C = init_C - if loss_fun == 'kl_loss': - armijo = True + # removed since 0.9.2 + #if loss_fun == 'kl_loss': + # armijo = True cpt = 0 err = 1 @@ -960,8 +977,9 @@ def fgw_barycenters( Ms = [dist(X, Ys[s]) for s in range(len(Ys))] - if loss_fun == 'kl_loss': - armijo = True + # removed since 0.9.2 + #if loss_fun == 'kl_loss': + # armijo = True cpt = 0 err_feature = 1 diff --git a/test/test_gromov.py b/test/test_gromov.py index 06f843a4a..8870a5023 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -52,31 +52,31 @@ def test_gromov(nx): Id = (1 / (1.0 * n_samples)) * np.eye(n_samples, n_samples) np.testing.assert_allclose(Gb, np.flipud(Id), atol=1e-04) + for armijo in [False, True]: + gw, log = ot.gromov.gromov_wasserstein2(C1, C2, None, q, 'kl_loss', armijo=armijo, log=True) + gwb, logb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, None, 'kl_loss', armijo=armijo, log=True) + gwb = nx.to_numpy(gwb) + + gw_val = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', armijo=armijo, G0=G0, log=False) + gw_valb = nx.to_numpy( + ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', armijo=armijo, G0=G0b, log=False) + ) - gw, log = ot.gromov.gromov_wasserstein2(C1, C2, None, q, 'kl_loss', armijo=True, log=True) - gwb, logb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, None, 'kl_loss', armijo=True, log=True) - gwb = nx.to_numpy(gwb) + G = log['T'] + Gb = nx.to_numpy(logb['T']) - gw_val = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', armijo=True, G0=G0, log=False) - gw_valb = nx.to_numpy( - ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', armijo=True, G0=G0b, log=False) - ) + np.testing.assert_allclose(gw, gwb, atol=1e-06) + np.testing.assert_allclose(gwb, 0, atol=1e-1, rtol=1e-1) - G = log['T'] - Gb = nx.to_numpy(logb['T']) + np.testing.assert_allclose(gw_val, gw_valb, atol=1e-06) + np.testing.assert_allclose(gwb, gw_valb, atol=1e-1, rtol=1e-1) # cf log=False - np.testing.assert_allclose(gw, gwb, atol=1e-06) - np.testing.assert_allclose(gwb, 0, atol=1e-1, rtol=1e-1) - - np.testing.assert_allclose(gw_val, gw_valb, atol=1e-06) - np.testing.assert_allclose(gwb, gw_valb, atol=1e-1, rtol=1e-1) # cf log=False - - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose( - p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose( - q, Gb.sum(0), atol=1e-04) # cf convergence gromov + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose( + p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q, Gb.sum(0), atol=1e-04) # cf convergence gromov def test_asymmetric_gromov(nx): @@ -1191,33 +1191,34 @@ def test_asymmetric_fgw(nx): np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04) # Tests with kl-loss: - G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'kl_loss', alpha=0.5, G0=G0, 
-    Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'kl_loss', alpha=0.5, log=True, symmetric=None, G0=G0b, verbose=True)
-    Gb = nx.to_numpy(Gb)
-    # check constraints
-    np.testing.assert_allclose(G, Gb, atol=1e-06)
-    np.testing.assert_allclose(
-        p, Gb.sum(1), atol=1e-04)  # cf convergence gromov
-    np.testing.assert_allclose(
-        q, Gb.sum(0), atol=1e-04)  # cf convergence gromov
-
-    np.testing.assert_allclose(log['fgw_dist'], 0., atol=1e-04)
-    np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04)
-
-    fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'kl_loss', alpha=0.5, G0=G0, log=True, symmetric=None, verbose=True)
-    fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'kl_loss', alpha=0.5, log=True, symmetric=False, G0=G0b, verbose=True)
-
-    G = log['T']
-    Gb = nx.to_numpy(logb['T'])
-    # check constraints
-    np.testing.assert_allclose(G, Gb, atol=1e-06)
-    np.testing.assert_allclose(
-        p, Gb.sum(1), atol=1e-04)  # cf convergence gromov
-    np.testing.assert_allclose(
-        q, Gb.sum(0), atol=1e-04)  # cf convergence gromov
-
-    np.testing.assert_allclose(log['fgw_dist'], 0., atol=1e-04)
-    np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04)
+    for armijo in [False, True]:
+        G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'kl_loss', alpha=0.5, armijo=armijo, G0=G0, log=True, symmetric=False, verbose=True)
+        Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'kl_loss', alpha=0.5, armijo=armijo, log=True, symmetric=None, G0=G0b, verbose=True)
+        Gb = nx.to_numpy(Gb)
+        # check constraints
+        np.testing.assert_allclose(G, Gb, atol=1e-06)
+        np.testing.assert_allclose(
+            p, Gb.sum(1), atol=1e-04)  # cf convergence gromov
+        np.testing.assert_allclose(
+            q, Gb.sum(0), atol=1e-04)  # cf convergence gromov
+
+        np.testing.assert_allclose(log['fgw_dist'], 0., atol=1e-04)
+        np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04)
+
+        fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'kl_loss', alpha=0.5, G0=G0, log=True, symmetric=None, verbose=True)
+        fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'kl_loss', alpha=0.5, log=True, symmetric=False, G0=G0b, verbose=True)
+
+        G = log['T']
+        Gb = nx.to_numpy(logb['T'])
+        # check constraints
+        np.testing.assert_allclose(G, Gb, atol=1e-06)
+        np.testing.assert_allclose(
+            p, Gb.sum(1), atol=1e-04)  # cf convergence gromov
+        np.testing.assert_allclose(
+            q, Gb.sum(0), atol=1e-04)  # cf convergence gromov
+
+        np.testing.assert_allclose(log['fgw_dist'], 0., atol=1e-04)
+        np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04)


 def test_fgw2_gradients():
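Background for the new closed-form step in solve_gromov_linesearch (a sketch under the loss decomposition used by ot.gromov.init_matrix, not part of the patch itself): when the inner loss decomposes as L(a, b) = f1(a) + f2(b) - h1(a) h2(b), as in Proposition 1 of [12], the (fused) GW objective restricted to the Frank-Wolfe segment G + t * deltaG is a quadratic in the step size t:

    \[
      g(t) = a\,t^2 + b\,t + g(0), \qquad
      a = -\,\mathrm{reg}\,\big\langle h_1(C_1)\,\Delta G\,h_2(C_2)^\top,\ \Delta G \big\rangle,
    \]
    \[
      b = \langle M, \Delta G \rangle
          - \mathrm{reg}\,\Big( \big\langle h_1(C_1)\,\Delta G\,h_2(C_2)^\top,\ G \big\rangle
          + \big\langle h_1(C_1)\,G\,h_2(C_2)^\top,\ \Delta G \big\rangle \Big),
    \]

with angle brackets denoting the Frobenius inner product. These are exactly the coefficients `a` and `b` handed to solve_1d_linesearch_quad above, with hC1 = h1(C1) and hC2 = h2(C2) returned by ot.gromov.init_matrix. For 'square_loss', h1(a) = a and h2(b) = 2b, so passing hC1, hC2 instead of the raw C1, C2 absorbs the explicit factors of 2 that the old coefficients carried; for 'kl_loss', h1(a) = a and h2(b) = log(b), which is what makes the closed form valid for KL and lets the old `armijo = True` fallback be removed.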
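A minimal usage sketch of the new behavior (synthetic data, values chosen for illustration): with this patch, `loss_fun='kl_loss'` combined with the default `armijo=False` runs the exact line-search instead of silently switching to Armijo backtracking, which remains available via `armijo=True`.

    import numpy as np
    import ot

    rng = np.random.RandomState(42)
    xs = rng.randn(30, 2)   # source samples
    xt = rng.randn(30, 2)   # target samples
    C1 = ot.dist(xs, xs)    # intra-domain structure matrices (squared Euclidean)
    C2 = ot.dist(xt, xt)
    p = ot.unif(30)         # uniform marginals
    q = ot.unif(30)

    # Exact (closed-form) line-search, now also supported with KL loss
    gw_exact = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', armijo=False)

    # Armijo backtracking line-search, kept for comparison
    gw_armijo = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', armijo=True)

Both calls solve the same problem; the exact step typically needs fewer cost evaluations per Frank-Wolfe iteration since it avoids backtracking.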