Skip to content

Commit 1071759

Browse files
add kl_loss to all semi-relaxed (f)gw solvers (#559)
1 parent a73ad08 commit 1071759

File tree

4 files changed

+230
-210
lines changed

4 files changed

+230
-210
lines changed

RELEASES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
+ Upgraded unbalanced OT solvers for more flexibility (PR #539)
1515
+ Add LazyTensor for modeling plans and low rank tensor in large scale OT (PR #544)
1616
+ Add exact line-search for `gromov_wasserstein` and `fused_gromov_wasserstein` with KL loss (PR #556)
17+
+ Add KL loss to all semi-relaxed (Fused) Gromov-Wasserstein solvers (PR #559)
1718

1819
#### Closed issues
1920
- Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504)

ot/gromov/_semirelaxed.py

Lines changed: 41 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@ def semirelaxed_gromov_wasserstein(C1, C2, p=None, loss_fun='square_loss', symme
5656
If let to its default value None, uniform distribution is taken.
5757
loss_fun : str
5858
loss function used for the solver either 'square_loss' or 'kl_loss'.
59-
'kl_loss' is not implemented yet and will raise an error.
6059
symmetric : bool, optional
6160
Either C1 and C2 are to be assumed symmetric or not.
6261
If let to its default None value, a symmetry test will be conducted.
@@ -92,8 +91,6 @@ def semirelaxed_gromov_wasserstein(C1, C2, p=None, loss_fun='square_loss', symme
9291
"Semi-relaxed Gromov-Wasserstein divergence and applications on graphs"
9392
International Conference on Learning Representations (ICLR), 2022.
9493
"""
95-
if loss_fun == 'kl_loss':
96-
raise NotImplementedError()
9794
arr = [C1, C2]
9895
if p is not None:
9996
arr.append(list_to_array(p))
@@ -139,7 +136,7 @@ def df(G):
139136
return 0.5 * (gwggrad(constC + marginal_product_1, hC1, hC2, G, nx) + gwggrad(constCt + marginal_product_2, hC1t, hC2t, G, nx))
140137

141138
def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
142-
return solve_semirelaxed_gromov_linesearch(G, deltaG, cost_G, C1, C2, ones_p, M=0., reg=1., nx=nx, **kwargs)
139+
return solve_semirelaxed_gromov_linesearch(G, deltaG, cost_G, hC1, hC2, ones_p, M=0., reg=1., fC2t=fC2t, nx=nx, **kwargs)
143140

144141
if log:
145142
res, log = semirelaxed_cg(p, q, 0., 1., f, df, G0, line_search, log=True, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs)
@@ -190,7 +187,6 @@ def semirelaxed_gromov_wasserstein2(C1, C2, p=None, loss_fun='square_loss', symm
190187
If let to its default value None, uniform distribution is taken.
191188
loss_fun : str
192189
loss function used for the solver either 'square_loss' or 'kl_loss'.
193-
'kl_loss' is not implemented yet and will raise an error.
194190
symmetric : bool, optional
195191
Either C1 and C2 are to be assumed symmetric or not.
196192
If let to its default None value, a symmetry test will be conducted.
@@ -243,7 +239,12 @@ def semirelaxed_gromov_wasserstein2(C1, C2, p=None, loss_fun='square_loss', symm
243239
if loss_fun == 'square_loss':
244240
gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T))
245241
gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T))
246-
srgw = nx.set_gradients(srgw, (C1, C2), (gC1, gC2))
242+
243+
elif loss_fun == 'kl_loss':
244+
gC1 = nx.log(C1 + 1e-15) * nx.outer(p, p) - nx.dot(T, nx.dot(nx.log(C2 + 1e-15), T.T))
245+
gC2 = nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q)
246+
247+
srgw = nx.set_gradients(srgw, (C1, C2), (gC1, gC2))
247248

248249
if log:
249250
return srgw, log_srgw
@@ -291,7 +292,6 @@ def semirelaxed_fused_gromov_wasserstein(
291292
If let to its default value None, uniform distribution is taken.
292293
loss_fun : str
293294
loss function used for the solver either 'square_loss' or 'kl_loss'.
294-
'kl_loss' is not implemented yet and will raise an error.
295295
symmetric : bool, optional
296296
Either C1 and C2 are to be assumed symmetric or not.
297297
If let to its default None value, a symmetry test will be conducted.
@@ -332,9 +332,6 @@ def semirelaxed_fused_gromov_wasserstein(
332332
"Semi-relaxed Gromov-Wasserstein divergence and applications on graphs"
333333
International Conference on Learning Representations (ICLR), 2022.
334334
"""
335-
if loss_fun == 'kl_loss':
336-
raise NotImplementedError()
337-
338335
arr = [M, C1, C2]
339336
if p is not None:
340337
arr.append(list_to_array(p))
@@ -382,7 +379,7 @@ def df(G):
382379

383380
def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
384381
return solve_semirelaxed_gromov_linesearch(
385-
G, deltaG, cost_G, C1, C2, ones_p, M=(1 - alpha) * M, reg=alpha, nx=nx, **kwargs)
382+
G, deltaG, cost_G, hC1, hC2, ones_p, M=(1 - alpha) * M, reg=alpha, fC2t=fC2t, nx=nx, **kwargs)
386383

387384
if log:
388385
res, log = semirelaxed_cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, log=True, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs)
@@ -434,7 +431,6 @@ def semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p=None, loss_fun='square_lo
434431
If let to its default value None, uniform distribution is taken.
435432
loss_fun : str, optional
436433
loss function used for the solver either 'square_loss' or 'kl_loss'.
437-
'kl_loss' is not implemented yet and will raise an error.
438434
symmetric : bool, optional
439435
Either C1 and C2 are to be assumed symmetric or not.
440436
If let to its default None value, a symmetry test will be conducted.
@@ -494,15 +490,20 @@ def semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p=None, loss_fun='square_lo
494490
if loss_fun == 'square_loss':
495491
gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T))
496492
gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T))
497-
if isinstance(alpha, int) or isinstance(alpha, float):
498-
srfgw_dist = nx.set_gradients(srfgw_dist, (C1, C2, M),
499-
(alpha * gC1, alpha * gC2, (1 - alpha) * T))
500-
else:
501-
lin_term = nx.sum(T * M)
502-
srgw_term = (srfgw_dist - (1 - alpha) * lin_term) / alpha
503-
srfgw_dist = nx.set_gradients(srfgw_dist, (C1, C2, M, alpha),
504-
(alpha * gC1, alpha * gC2, (1 - alpha) * T,
505-
srgw_term - lin_term))
493+
494+
elif loss_fun == 'kl_loss':
495+
gC1 = nx.log(C1 + 1e-15) * nx.outer(p, p) - nx.dot(T, nx.dot(nx.log(C2 + 1e-15), T.T))
496+
gC2 = nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q)
497+
498+
if isinstance(alpha, int) or isinstance(alpha, float):
499+
srfgw_dist = nx.set_gradients(srfgw_dist, (C1, C2, M),
500+
(alpha * gC1, alpha * gC2, (1 - alpha) * T))
501+
else:
502+
lin_term = nx.sum(T * M)
503+
srgw_term = (srfgw_dist - (1 - alpha) * lin_term) / alpha
504+
srfgw_dist = nx.set_gradients(srfgw_dist, (C1, C2, M, alpha),
505+
(alpha * gC1, alpha * gC2, (1 - alpha) * T,
506+
srgw_term - lin_term))
506507

507508
if log:
508509
return srfgw_dist, log_fgw
@@ -511,7 +512,7 @@ def semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p=None, loss_fun='square_lo
511512

512513

513514
def solve_semirelaxed_gromov_linesearch(G, deltaG, cost_G, C1, C2, ones_p,
514-
M, reg, alpha_min=None, alpha_max=None, nx=None, **kwargs):
515+
M, reg, fC2t=None, alpha_min=None, alpha_max=None, nx=None, **kwargs):
515516
"""
516517
Solve the linesearch in the Conditional Gradient iterations for the semi-relaxed Gromov-Wasserstein divergence.
517518
@@ -524,16 +525,22 @@ def solve_semirelaxed_gromov_linesearch(G, deltaG, cost_G, C1, C2, ones_p,
524525
Difference between the optimal map found by linearization in the FW algorithm and the value at a given iteration
525526
cost_G : float
526527
Value of the cost at `G`
527-
C1 : array-like (ns,ns)
528-
Structure matrix in the source domain.
529-
C2 : array-like (nt,nt)
530-
Structure matrix in the target domain.
528+
C1 : array-like (ns,ns), optional
529+
Transformed Structure matrix in the source domain.
530+
Note that for the 'square_loss' and 'kl_loss', we provide hC1 from ot.gromov.init_matrix_semirelaxed
531+
C2 : array-like (nt,nt), optional
532+
Transformed Structure matrix in the target domain.
533+
Note that for the 'square_loss' and 'kl_loss', we provide hC2 from ot.gromov.init_matrix_semirelaxed
531534
ones_p: array-like (ns,1)
532535
Array of ones of size ns
533536
M : array-like (ns,nt)
534537
Cost matrix between the features.
535538
reg : float
536539
Regularization parameter.
540+
fC2t: array-like (nt,nt), optional
541+
Transformed Structure matrix in the target domain.
542+
Note that for the 'square_loss' and 'kl_loss', we provide fC2t from ot.gromov.init_matrix_semirelaxed.
543+
If fC2t is not provided, it defaults to the fC2t corresponding to the 'square_loss'.
537544
alpha_min : float, optional
538545
Minimum value for alpha
539546
alpha_max : float, optional
@@ -565,11 +572,14 @@ def solve_semirelaxed_gromov_linesearch(G, deltaG, cost_G, C1, C2, ones_p,
565572

566573
qG, qdeltaG = nx.sum(G, 0), nx.sum(deltaG, 0)
567574
dot = nx.dot(nx.dot(C1, deltaG), C2.T)
568-
C2t_square = C2.T ** 2
569-
dot_qG = nx.dot(nx.outer(ones_p, qG), C2t_square)
570-
dot_qdeltaG = nx.dot(nx.outer(ones_p, qdeltaG), C2t_square)
571-
a = reg * nx.sum((dot_qdeltaG - 2 * dot) * deltaG)
572-
b = nx.sum(M * deltaG) + reg * (nx.sum((dot_qdeltaG - 2 * dot) * G) + nx.sum((dot_qG - 2 * nx.dot(nx.dot(C1, G), C2.T)) * deltaG))
575+
if fC2t is None:
576+
fC2t = C2.T ** 2
577+
dot_qG = nx.dot(nx.outer(ones_p, qG), fC2t)
578+
dot_qdeltaG = nx.dot(nx.outer(ones_p, qdeltaG), fC2t)
579+
580+
a = reg * nx.sum((dot_qdeltaG - dot) * deltaG)
581+
b = nx.sum(M * deltaG) + reg * (nx.sum((dot_qdeltaG - dot) * G) + nx.sum((dot_qG - nx.dot(nx.dot(C1, G), C2.T)) * deltaG))
582+
573583
alpha = solve_1d_linesearch_quad(a, b)
574584
if alpha_min is not None or alpha_max is not None:
575585
alpha = np.clip(alpha, alpha_min, alpha_max)
@@ -620,7 +630,6 @@ def entropic_semirelaxed_gromov_wasserstein(
620630
If let to its default value None, uniform distribution is taken.
621631
loss_fun : str
622632
loss function used for the solver either 'square_loss' or 'kl_loss'.
623-
'kl_loss' is not implemented yet and will raise an error.
624633
epsilon : float
625634
Regularization term >0
626635
symmetric : bool, optional
@@ -655,8 +664,6 @@ def entropic_semirelaxed_gromov_wasserstein(
655664
"Semi-relaxed Gromov-Wasserstein divergence and applications on graphs"
656665
International Conference on Learning Representations (ICLR), 2022.
657666
"""
658-
if loss_fun == 'kl_loss':
659-
raise NotImplementedError()
660667
arr = [C1, C2]
661668
if p is not None:
662669
arr.append(list_to_array(p))
@@ -777,7 +784,6 @@ def entropic_semirelaxed_gromov_wasserstein2(
777784
If let to its default value None, uniform distribution is taken.
778785
loss_fun : str
779786
loss function used for the solver either 'square_loss' or 'kl_loss'.
780-
'kl_loss' is not implemented yet and will raise an error.
781787
epsilon : float
782788
Regularization term >0
783789
symmetric : bool, optional
@@ -869,7 +875,6 @@ def entropic_semirelaxed_fused_gromov_wasserstein(
869875
If let to its default value None, uniform distribution is taken.
870876
loss_fun : str
871877
loss function used for the solver either 'square_loss' or 'kl_loss'.
872-
'kl_loss' is not implemented yet and will raise an error.
873878
epsilon : float
874879
Regularization term >0
875880
symmetric : bool, optional
@@ -907,8 +912,6 @@ def entropic_semirelaxed_fused_gromov_wasserstein(
907912
"Semi-relaxed Gromov-Wasserstein divergence and applications on graphs"
908913
International Conference on Learning Representations (ICLR), 2022.
909914
"""
910-
if loss_fun == 'kl_loss':
911-
raise NotImplementedError()
912915
arr = [M, C1, C2]
913916
if p is not None:
914917
arr.append(list_to_array(p))
@@ -1032,7 +1035,6 @@ def entropic_semirelaxed_fused_gromov_wasserstein2(
10321035
If let to its default value None, uniform distribution is taken.
10331036
loss_fun : str, optional
10341037
loss function used for the solver either 'square_loss' or 'kl_loss'.
1035-
'kl_loss' is not implemented yet and will raise an error.
10361038
epsilon : float
10371039
Regularization term >0
10381040
symmetric : bool, optional

ot/gromov/_utils.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,19 @@ def init_matrix_semirelaxed(C1, C2, p, loss_fun='square_loss', nx=None):
399399
400400
h_2(b) &= 2b
401401
402+
The kl-loss function :math:`L(a, b) = a \log\left(\frac{a}{b}\right) - a + b` is read as:
403+
404+
.. math::
405+
406+
L(a, b) = f_1(a) + f_2(b) - h_1(a) h_2(b)
407+
408+
\mathrm{with} \ f_1(a) &= a \log(a) - a
409+
410+
f_2(b) &= b
411+
412+
h_1(a) &= a
413+
414+
h_2(b) &= \log(b)
402415
Parameters
403416
----------
404417
C1 : array-like, shape (ns, ns)
@@ -451,9 +464,19 @@ def h1(a):
451464
def h2(b):
452465
return 2 * b
453466
elif loss_fun == 'kl_loss':
454-
raise NotImplementedError()
467+
def f1(a):
468+
return a * nx.log(a + 1e-15) - a
469+
470+
def f2(b):
471+
return b
472+
473+
def h1(a):
474+
return a
475+
476+
def h2(b):
477+
return nx.log(b + 1e-15)
455478
else:
456-
raise ValueError(f"Unknown `loss_fun='{loss_fun}'`. Only 'square_loss' is supported.")
479+
raise ValueError(f"Unknown `loss_fun='{loss_fun}'`. Use one of: {'square_loss', 'kl_loss'}.")
457480

458481
constC = nx.dot(nx.dot(f1(C1), nx.reshape(p, (-1, 1))),
459482
nx.ones((1, C2.shape[0]), type_as=p))

0 commit comments

Comments
 (0)