Skip to content

Commit a73ad08

Browse files
add exact line-search for (f)gw solvers with kl_loss (#556)
1 parent 53dde7a commit a73ad08

File tree

3 files changed

+98
-78
lines changed

3 files changed

+98
-78
lines changed

RELEASES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
+ Update wheels to Python 3.12 and remove old i686 arch that do not have scipy wheels (PR #543)
1414
+ Upgraded unbalanced OT solvers for more flexibility (PR #539)
1515
+ Add LazyTensor for modeling plans and low rank tensor in large scale OT (PR #544)
16+
+ Add exact line-search for `gromov_wasserstein` and `fused_gromov_wasserstein` with KL loss (PR #556)
1617

1718
#### Closed issues
1819
- Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504)

ot/gromov/_gw.py

Lines changed: 48 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -160,15 +160,17 @@ def df(G):
160160

161161
def df(G):
162162
return 0.5 * (gwggrad(constC, hC1, hC2, G, np_) + gwggrad(constCt, hC1t, hC2t, G, np_))
163-
if loss_fun == 'kl_loss':
164-
armijo = True # there is no closed form line-search with KL
163+
164+
# removed since 0.9.2
165+
#if loss_fun == 'kl_loss':
166+
# armijo = True # there is no closed form line-search with KL
165167

166168
if armijo:
167169
def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
168170
return line_search_armijo(cost, G, deltaG, Mi, cost_G, nx=np_, **kwargs)
169171
else:
170172
def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
171-
return solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M=0., reg=1., nx=np_, **kwargs)
173+
return solve_gromov_linesearch(G, deltaG, cost_G, hC1, hC2, M=0., reg=1., nx=np_, **kwargs)
172174
if log:
173175
res, log = cg(p, q, 0., 1., f, df, G0, line_search, log=True, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs)
174176
log['gw_dist'] = nx.from_numpy(log['loss'][-1], type_as=C10)
@@ -296,9 +298,13 @@ def gromov_wasserstein2(C1, C2, p=None, q=None, loss_fun='square_loss', symmetri
296298
if loss_fun == 'square_loss':
297299
gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T))
298300
gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T))
299-
gw = nx.set_gradients(gw, (p, q, C1, C2),
300-
(log_gw['u'] - nx.mean(log_gw['u']),
301-
log_gw['v'] - nx.mean(log_gw['v']), gC1, gC2))
301+
elif loss_fun == 'kl_loss':
302+
gC1 = nx.log(C1 + 1e-15) * nx.outer(p, p) - nx.dot(T, nx.dot(nx.log(C2 + 1e-15), T.T))
303+
gC2 = nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q)
304+
305+
gw = nx.set_gradients(gw, (p, q, C1, C2),
306+
(log_gw['u'] - nx.mean(log_gw['u']),
307+
log_gw['v'] - nx.mean(log_gw['v']), gC1, gC2))
302308

303309
if log:
304310
return gw, log_gw
@@ -449,15 +455,16 @@ def df(G):
449455
def df(G):
450456
return 0.5 * (gwggrad(constC, hC1, hC2, G, np_) + gwggrad(constCt, hC1t, hC2t, G, np_))
451457

452-
if loss_fun == 'kl_loss':
453-
armijo = True # there is no closed form line-search with KL
458+
# removed since 0.9.2
459+
#if loss_fun == 'kl_loss':
460+
# armijo = True # there is no closed form line-search with KL
454461

455462
if armijo:
456463
def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
457464
return line_search_armijo(cost, G, deltaG, Mi, cost_G, nx=np_, **kwargs)
458465
else:
459466
def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
460-
return solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M=(1 - alpha) * M, reg=alpha, nx=np_, **kwargs)
467+
return solve_gromov_linesearch(G, deltaG, cost_G, hC1, hC2, M=(1 - alpha) * M, reg=alpha, nx=np_, **kwargs)
461468
if log:
462469
res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, log=True, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs)
463470
log['fgw_dist'] = nx.from_numpy(log['loss'][-1], type_as=C10)
@@ -591,18 +598,20 @@ def fused_gromov_wasserstein2(M, C1, C2, p=None, q=None, loss_fun='square_loss',
591598
if loss_fun == 'square_loss':
592599
gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T))
593600
gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T))
594-
if isinstance(alpha, int) or isinstance(alpha, float):
595-
fgw_dist = nx.set_gradients(fgw_dist, (p, q, C1, C2, M),
596-
(log_fgw['u'] - nx.mean(log_fgw['u']),
597-
log_fgw['v'] - nx.mean(log_fgw['v']),
598-
alpha * gC1, alpha * gC2, (1 - alpha) * T))
599-
else:
600-
601-
fgw_dist = nx.set_gradients(fgw_dist, (p, q, C1, C2, M, alpha),
602-
(log_fgw['u'] - nx.mean(log_fgw['u']),
603-
log_fgw['v'] - nx.mean(log_fgw['v']),
604-
alpha * gC1, alpha * gC2, (1 - alpha) * T,
605-
gw_term - lin_term))
601+
elif loss_fun == 'kl_loss':
602+
gC1 = nx.log(C1 + 1e-15) * nx.outer(p, p) - nx.dot(T, nx.dot(nx.log(C2 + 1e-15), T.T))
603+
gC2 = nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q)
604+
if isinstance(alpha, int) or isinstance(alpha, float):
605+
fgw_dist = nx.set_gradients(fgw_dist, (p, q, C1, C2, M),
606+
(log_fgw['u'] - nx.mean(log_fgw['u']),
607+
log_fgw['v'] - nx.mean(log_fgw['v']),
608+
alpha * gC1, alpha * gC2, (1 - alpha) * T))
609+
else:
610+
fgw_dist = nx.set_gradients(fgw_dist, (p, q, C1, C2, M, alpha),
611+
(log_fgw['u'] - nx.mean(log_fgw['u']),
612+
log_fgw['v'] - nx.mean(log_fgw['v']),
613+
alpha * gC1, alpha * gC2, (1 - alpha) * T,
614+
gw_term - lin_term))
606615

607616
if log:
608617
return fgw_dist, log_fgw
@@ -613,7 +622,7 @@ def fused_gromov_wasserstein2(M, C1, C2, p=None, q=None, loss_fun='square_loss',
613622
def solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M, reg,
614623
alpha_min=None, alpha_max=None, nx=None, **kwargs):
615624
"""
616-
Solve the linesearch in the FW iterations
625+
Solve the linesearch in the FW iterations for any inner loss that decomposes as in Proposition 1 in :ref:`[12] <references-solve-linesearch>`.
617626
618627
Parameters
619628
----------
@@ -625,9 +634,11 @@ def solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M, reg,
625634
cost_G : float
626635
Value of the cost at `G`
627636
C1 : array-like (ns,ns), optional
628-
Structure matrix in the source domain.
637+
Transformed structure matrix in the source domain.
638+
For the 'square_loss' and 'kl_loss', we provide hC1 from ot.gromov.init_matrix
629639
C2 : array-like (nt,nt), optional
630-
Structure matrix in the target domain.
640+
Transformed structure matrix in the target domain.
641+
For the 'square_loss' and 'kl_loss', we provide hC2 from ot.gromov.init_matrix
631642
M : array-like (ns,nt)
632643
Cost matrix between the features.
633644
reg : float
@@ -649,11 +660,16 @@ def solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M, reg,
649660
650661
651662
.. _references-solve-linesearch:
663+
652664
References
653665
----------
654666
.. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas
655667
"Optimal Transport for structured data with application on graphs"
656668
International Conference on Machine Learning (ICML). 2019.
669+
.. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
670+
"Gromov-Wasserstein averaging of kernel and distance matrices."
671+
International Conference on Machine Learning (ICML). 2016.
672+
657673
"""
658674
if nx is None:
659675
G, deltaG, C1, C2, M = list_to_array(G, deltaG, C1, C2, M)
@@ -664,8 +680,8 @@ def solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M, reg,
664680
nx = get_backend(G, deltaG, C1, C2, M)
665681

666682
dot = nx.dot(nx.dot(C1, deltaG), C2.T)
667-
a = -2 * reg * nx.sum(dot * deltaG)
668-
b = nx.sum(M * deltaG) - 2 * reg * (nx.sum(dot * G) + nx.sum(nx.dot(nx.dot(C1, G), C2.T) * deltaG))
683+
a = - reg * nx.sum(dot * deltaG)
684+
b = nx.sum(M * deltaG) - reg * (nx.sum(dot * G) + nx.sum(nx.dot(nx.dot(C1, G), C2.T) * deltaG))
669685

670686
alpha = solve_1d_linesearch_quad(a, b)
671687
if alpha_min is not None or alpha_max is not None:
@@ -776,8 +792,9 @@ def gromov_barycenters(
776792
else:
777793
C = init_C
778794

779-
if loss_fun == 'kl_loss':
780-
armijo = True
795+
# removed since 0.9.2
796+
#if loss_fun == 'kl_loss':
797+
# armijo = True
781798

782799
cpt = 0
783800
err = 1
@@ -960,8 +977,9 @@ def fgw_barycenters(
960977

961978
Ms = [dist(X, Ys[s]) for s in range(len(Ys))]
962979

963-
if loss_fun == 'kl_loss':
964-
armijo = True
980+
# removed since 0.9.2
981+
#if loss_fun == 'kl_loss':
982+
# armijo = True
965983

966984
cpt = 0
967985
err_feature = 1

test/test_gromov.py

Lines changed: 49 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -52,31 +52,31 @@ def test_gromov(nx):
5252
Id = (1 / (1.0 * n_samples)) * np.eye(n_samples, n_samples)
5353

5454
np.testing.assert_allclose(Gb, np.flipud(Id), atol=1e-04)
55+
for armijo in [False, True]:
56+
gw, log = ot.gromov.gromov_wasserstein2(C1, C2, None, q, 'kl_loss', armijo=armijo, log=True)
57+
gwb, logb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, None, 'kl_loss', armijo=armijo, log=True)
58+
gwb = nx.to_numpy(gwb)
59+
60+
gw_val = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', armijo=armijo, G0=G0, log=False)
61+
gw_valb = nx.to_numpy(
62+
ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', armijo=armijo, G0=G0b, log=False)
63+
)
5564

56-
gw, log = ot.gromov.gromov_wasserstein2(C1, C2, None, q, 'kl_loss', armijo=True, log=True)
57-
gwb, logb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, None, 'kl_loss', armijo=True, log=True)
58-
gwb = nx.to_numpy(gwb)
65+
G = log['T']
66+
Gb = nx.to_numpy(logb['T'])
5967

60-
gw_val = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', armijo=True, G0=G0, log=False)
61-
gw_valb = nx.to_numpy(
62-
ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', armijo=True, G0=G0b, log=False)
63-
)
68+
np.testing.assert_allclose(gw, gwb, atol=1e-06)
69+
np.testing.assert_allclose(gwb, 0, atol=1e-1, rtol=1e-1)
6470

65-
G = log['T']
66-
Gb = nx.to_numpy(logb['T'])
71+
np.testing.assert_allclose(gw_val, gw_valb, atol=1e-06)
72+
np.testing.assert_allclose(gwb, gw_valb, atol=1e-1, rtol=1e-1) # cf log=False
6773

68-
np.testing.assert_allclose(gw, gwb, atol=1e-06)
69-
np.testing.assert_allclose(gwb, 0, atol=1e-1, rtol=1e-1)
70-
71-
np.testing.assert_allclose(gw_val, gw_valb, atol=1e-06)
72-
np.testing.assert_allclose(gwb, gw_valb, atol=1e-1, rtol=1e-1) # cf log=False
73-
74-
# check constraints
75-
np.testing.assert_allclose(G, Gb, atol=1e-06)
76-
np.testing.assert_allclose(
77-
p, Gb.sum(1), atol=1e-04) # cf convergence gromov
78-
np.testing.assert_allclose(
79-
q, Gb.sum(0), atol=1e-04) # cf convergence gromov
74+
# check constraints
75+
np.testing.assert_allclose(G, Gb, atol=1e-06)
76+
np.testing.assert_allclose(
77+
p, Gb.sum(1), atol=1e-04) # cf convergence gromov
78+
np.testing.assert_allclose(
79+
q, Gb.sum(0), atol=1e-04) # cf convergence gromov
8080

8181

8282
def test_asymmetric_gromov(nx):
@@ -1191,33 +1191,34 @@ def test_asymmetric_fgw(nx):
11911191
np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04)
11921192

11931193
# Tests with kl-loss:
1194-
G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'kl_loss', alpha=0.5, G0=G0, log=True, symmetric=False, verbose=True)
1195-
Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'kl_loss', alpha=0.5, log=True, symmetric=None, G0=G0b, verbose=True)
1196-
Gb = nx.to_numpy(Gb)
1197-
# check constraints
1198-
np.testing.assert_allclose(G, Gb, atol=1e-06)
1199-
np.testing.assert_allclose(
1200-
p, Gb.sum(1), atol=1e-04) # cf convergence gromov
1201-
np.testing.assert_allclose(
1202-
q, Gb.sum(0), atol=1e-04) # cf convergence gromov
1203-
1204-
np.testing.assert_allclose(log['fgw_dist'], 0., atol=1e-04)
1205-
np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04)
1206-
1207-
fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'kl_loss', alpha=0.5, G0=G0, log=True, symmetric=None, verbose=True)
1208-
fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'kl_loss', alpha=0.5, log=True, symmetric=False, G0=G0b, verbose=True)
1209-
1210-
G = log['T']
1211-
Gb = nx.to_numpy(logb['T'])
1212-
# check constraints
1213-
np.testing.assert_allclose(G, Gb, atol=1e-06)
1214-
np.testing.assert_allclose(
1215-
p, Gb.sum(1), atol=1e-04) # cf convergence gromov
1216-
np.testing.assert_allclose(
1217-
q, Gb.sum(0), atol=1e-04) # cf convergence gromov
1218-
1219-
np.testing.assert_allclose(log['fgw_dist'], 0., atol=1e-04)
1220-
np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04)
1194+
for armijo in [False, True]:
1195+
G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'kl_loss', alpha=0.5, armijo=armijo, G0=G0, log=True, symmetric=False, verbose=True)
1196+
Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'kl_loss', alpha=0.5, armijo=armijo, log=True, symmetric=None, G0=G0b, verbose=True)
1197+
Gb = nx.to_numpy(Gb)
1198+
# check constraints
1199+
np.testing.assert_allclose(G, Gb, atol=1e-06)
1200+
np.testing.assert_allclose(
1201+
p, Gb.sum(1), atol=1e-04) # cf convergence gromov
1202+
np.testing.assert_allclose(
1203+
q, Gb.sum(0), atol=1e-04) # cf convergence gromov
1204+
1205+
np.testing.assert_allclose(log['fgw_dist'], 0., atol=1e-04)
1206+
np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04)
1207+
1208+
fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'kl_loss', alpha=0.5, G0=G0, log=True, symmetric=None, verbose=True)
1209+
fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'kl_loss', alpha=0.5, log=True, symmetric=False, G0=G0b, verbose=True)
1210+
1211+
G = log['T']
1212+
Gb = nx.to_numpy(logb['T'])
1213+
# check constraints
1214+
np.testing.assert_allclose(G, Gb, atol=1e-06)
1215+
np.testing.assert_allclose(
1216+
p, Gb.sum(1), atol=1e-04) # cf convergence gromov
1217+
np.testing.assert_allclose(
1218+
q, Gb.sum(0), atol=1e-04) # cf convergence gromov
1219+
1220+
np.testing.assert_allclose(log['fgw_dist'], 0., atol=1e-04)
1221+
np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04)
12211222

12221223

12231224
def test_fgw2_gradients():

0 commit comments

Comments
 (0)