Skip to content

[MRG] add kl_loss to all semi-relaxed (f)gw solvers #559

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
+ Upgraded unbalanced OT solvers for more flexibility (PR #539)
+ Add LazyTensor for modeling plans and low rank tensor in large scale OT (PR #544)
+ Add exact line-search for `gromov_wasserstein` and `fused_gromov_wasserstein` with KL loss (PR #556)
+ Add KL loss to all semi-relaxed (Fused) Gromov-Wasserstein solvers (PR #559)

#### Closed issues
- Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504)
Expand Down
80 changes: 41 additions & 39 deletions ot/gromov/_semirelaxed.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ def semirelaxed_gromov_wasserstein(C1, C2, p=None, loss_fun='square_loss', symme
If let to its default value None, uniform distribution is taken.
loss_fun : str
loss function used for the solver either 'square_loss' or 'kl_loss'.
'kl_loss' is not implemented yet and will raise an error.
symmetric : bool, optional
Either C1 and C2 are to be assumed symmetric or not.
If let to its default None value, a symmetry test will be conducted.
Expand Down Expand Up @@ -92,8 +91,6 @@ def semirelaxed_gromov_wasserstein(C1, C2, p=None, loss_fun='square_loss', symme
"Semi-relaxed Gromov-Wasserstein divergence and applications on graphs"
International Conference on Learning Representations (ICLR), 2022.
"""
if loss_fun == 'kl_loss':
raise NotImplementedError()
arr = [C1, C2]
if p is not None:
arr.append(list_to_array(p))
Expand Down Expand Up @@ -139,7 +136,7 @@ def df(G):
return 0.5 * (gwggrad(constC + marginal_product_1, hC1, hC2, G, nx) + gwggrad(constCt + marginal_product_2, hC1t, hC2t, G, nx))

def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
return solve_semirelaxed_gromov_linesearch(G, deltaG, cost_G, C1, C2, ones_p, M=0., reg=1., nx=nx, **kwargs)
return solve_semirelaxed_gromov_linesearch(G, deltaG, cost_G, hC1, hC2, ones_p, M=0., reg=1., fC2t=fC2t, nx=nx, **kwargs)

if log:
res, log = semirelaxed_cg(p, q, 0., 1., f, df, G0, line_search, log=True, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs)
Expand Down Expand Up @@ -190,7 +187,6 @@ def semirelaxed_gromov_wasserstein2(C1, C2, p=None, loss_fun='square_loss', symm
If let to its default value None, uniform distribution is taken.
loss_fun : str
loss function used for the solver either 'square_loss' or 'kl_loss'.
'kl_loss' is not implemented yet and will raise an error.
symmetric : bool, optional
Either C1 and C2 are to be assumed symmetric or not.
If let to its default None value, a symmetry test will be conducted.
Expand Down Expand Up @@ -243,7 +239,12 @@ def semirelaxed_gromov_wasserstein2(C1, C2, p=None, loss_fun='square_loss', symm
if loss_fun == 'square_loss':
gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T))
gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T))
srgw = nx.set_gradients(srgw, (C1, C2), (gC1, gC2))

elif loss_fun == 'kl_loss':
gC1 = nx.log(C1 + 1e-15) * nx.outer(p, p) - nx.dot(T, nx.dot(nx.log(C2 + 1e-15), T.T))
gC2 = nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q)

srgw = nx.set_gradients(srgw, (C1, C2), (gC1, gC2))

if log:
return srgw, log_srgw
Expand Down Expand Up @@ -291,7 +292,6 @@ def semirelaxed_fused_gromov_wasserstein(
If let to its default value None, uniform distribution is taken.
loss_fun : str
loss function used for the solver either 'square_loss' or 'kl_loss'.
'kl_loss' is not implemented yet and will raise an error.
symmetric : bool, optional
Either C1 and C2 are to be assumed symmetric or not.
If let to its default None value, a symmetry test will be conducted.
Expand Down Expand Up @@ -332,9 +332,6 @@ def semirelaxed_fused_gromov_wasserstein(
"Semi-relaxed Gromov-Wasserstein divergence and applications on graphs"
International Conference on Learning Representations (ICLR), 2022.
"""
if loss_fun == 'kl_loss':
raise NotImplementedError()

arr = [M, C1, C2]
if p is not None:
arr.append(list_to_array(p))
Expand Down Expand Up @@ -382,7 +379,7 @@ def df(G):

def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
return solve_semirelaxed_gromov_linesearch(
G, deltaG, cost_G, C1, C2, ones_p, M=(1 - alpha) * M, reg=alpha, nx=nx, **kwargs)
G, deltaG, cost_G, hC1, hC2, ones_p, M=(1 - alpha) * M, reg=alpha, fC2t=fC2t, nx=nx, **kwargs)

if log:
res, log = semirelaxed_cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, log=True, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs)
Expand Down Expand Up @@ -434,7 +431,6 @@ def semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p=None, loss_fun='square_lo
If let to its default value None, uniform distribution is taken.
loss_fun : str, optional
loss function used for the solver either 'square_loss' or 'kl_loss'.
'kl_loss' is not implemented yet and will raise an error.
symmetric : bool, optional
Either C1 and C2 are to be assumed symmetric or not.
If let to its default None value, a symmetry test will be conducted.
Expand Down Expand Up @@ -494,15 +490,20 @@ def semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p=None, loss_fun='square_lo
if loss_fun == 'square_loss':
gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T))
gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T))
if isinstance(alpha, int) or isinstance(alpha, float):
srfgw_dist = nx.set_gradients(srfgw_dist, (C1, C2, M),
(alpha * gC1, alpha * gC2, (1 - alpha) * T))
else:
lin_term = nx.sum(T * M)
srgw_term = (srfgw_dist - (1 - alpha) * lin_term) / alpha
srfgw_dist = nx.set_gradients(srfgw_dist, (C1, C2, M, alpha),
(alpha * gC1, alpha * gC2, (1 - alpha) * T,
srgw_term - lin_term))

elif loss_fun == 'kl_loss':
gC1 = nx.log(C1 + 1e-15) * nx.outer(p, p) - nx.dot(T, nx.dot(nx.log(C2 + 1e-15), T.T))
gC2 = nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q)

if isinstance(alpha, int) or isinstance(alpha, float):
srfgw_dist = nx.set_gradients(srfgw_dist, (C1, C2, M),
(alpha * gC1, alpha * gC2, (1 - alpha) * T))
else:
lin_term = nx.sum(T * M)
srgw_term = (srfgw_dist - (1 - alpha) * lin_term) / alpha
srfgw_dist = nx.set_gradients(srfgw_dist, (C1, C2, M, alpha),
(alpha * gC1, alpha * gC2, (1 - alpha) * T,
srgw_term - lin_term))

if log:
return srfgw_dist, log_fgw
Expand All @@ -511,7 +512,7 @@ def semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p=None, loss_fun='square_lo


def solve_semirelaxed_gromov_linesearch(G, deltaG, cost_G, C1, C2, ones_p,
M, reg, alpha_min=None, alpha_max=None, nx=None, **kwargs):
M, reg, fC2t=None, alpha_min=None, alpha_max=None, nx=None, **kwargs):
"""
Solve the linesearch in the Conditional Gradient iterations for the semi-relaxed Gromov-Wasserstein divergence.

Expand All @@ -524,16 +525,22 @@ def solve_semirelaxed_gromov_linesearch(G, deltaG, cost_G, C1, C2, ones_p,
Difference between the optimal map found by linearization in the FW algorithm and the value at a given iteration
cost_G : float
Value of the cost at `G`
C1 : array-like (ns,ns)
Structure matrix in the source domain.
C2 : array-like (nt,nt)
Structure matrix in the target domain.
C1 : array-like (ns,ns), optional
Transformed Structure matrix in the source domain.
Note that for the 'square_loss' and 'kl_loss', we provide hC1 from ot.gromov.init_matrix_semirelaxed
C2 : array-like (nt,nt), optional
Transformed Structure matrix in the source domain.
Note that for the 'square_loss' and 'kl_loss', we provide hC2 from ot.gromov.init_matrix_semirelaxed
ones_p: array-like (ns,1)
Array of ones of size ns
M : array-like (ns,nt)
Cost matrix between the features.
reg : float
Regularization parameter.
fC2t: array-like (nt,nt), optional
Transformed Structure matrix in the source domain.
Note that for the 'square_loss' and 'kl_loss', we provide fC2t from ot.gromov.init_matrix_semirelaxed.
If fC2t is not provided, it is by default fC2t corresponding to the 'square_loss'.
alpha_min : float, optional
Minimum value for alpha
alpha_max : float, optional
Expand Down Expand Up @@ -565,11 +572,14 @@ def solve_semirelaxed_gromov_linesearch(G, deltaG, cost_G, C1, C2, ones_p,

qG, qdeltaG = nx.sum(G, 0), nx.sum(deltaG, 0)
dot = nx.dot(nx.dot(C1, deltaG), C2.T)
C2t_square = C2.T ** 2
dot_qG = nx.dot(nx.outer(ones_p, qG), C2t_square)
dot_qdeltaG = nx.dot(nx.outer(ones_p, qdeltaG), C2t_square)
a = reg * nx.sum((dot_qdeltaG - 2 * dot) * deltaG)
b = nx.sum(M * deltaG) + reg * (nx.sum((dot_qdeltaG - 2 * dot) * G) + nx.sum((dot_qG - 2 * nx.dot(nx.dot(C1, G), C2.T)) * deltaG))
if fC2t is None:
fC2t = C2.T ** 2
dot_qG = nx.dot(nx.outer(ones_p, qG), fC2t)
dot_qdeltaG = nx.dot(nx.outer(ones_p, qdeltaG), fC2t)

a = reg * nx.sum((dot_qdeltaG - dot) * deltaG)
b = nx.sum(M * deltaG) + reg * (nx.sum((dot_qdeltaG - dot) * G) + nx.sum((dot_qG - nx.dot(nx.dot(C1, G), C2.T)) * deltaG))

alpha = solve_1d_linesearch_quad(a, b)
if alpha_min is not None or alpha_max is not None:
alpha = np.clip(alpha, alpha_min, alpha_max)
Expand Down Expand Up @@ -620,7 +630,6 @@ def entropic_semirelaxed_gromov_wasserstein(
If let to its default value None, uniform distribution is taken.
loss_fun : str
loss function used for the solver either 'square_loss' or 'kl_loss'.
'kl_loss' is not implemented yet and will raise an error.
epsilon : float
Regularization term >0
symmetric : bool, optional
Expand Down Expand Up @@ -655,8 +664,6 @@ def entropic_semirelaxed_gromov_wasserstein(
"Semi-relaxed Gromov-Wasserstein divergence and applications on graphs"
International Conference on Learning Representations (ICLR), 2022.
"""
if loss_fun == 'kl_loss':
raise NotImplementedError()
arr = [C1, C2]
if p is not None:
arr.append(list_to_array(p))
Expand Down Expand Up @@ -777,7 +784,6 @@ def entropic_semirelaxed_gromov_wasserstein2(
If let to its default value None, uniform distribution is taken.
loss_fun : str
loss function used for the solver either 'square_loss' or 'kl_loss'.
'kl_loss' is not implemented yet and will raise an error.
epsilon : float
Regularization term >0
symmetric : bool, optional
Expand Down Expand Up @@ -869,7 +875,6 @@ def entropic_semirelaxed_fused_gromov_wasserstein(
If let to its default value None, uniform distribution is taken.
loss_fun : str
loss function used for the solver either 'square_loss' or 'kl_loss'.
'kl_loss' is not implemented yet and will raise an error.
epsilon : float
Regularization term >0
symmetric : bool, optional
Expand Down Expand Up @@ -907,8 +912,6 @@ def entropic_semirelaxed_fused_gromov_wasserstein(
"Semi-relaxed Gromov-Wasserstein divergence and applications on graphs"
International Conference on Learning Representations (ICLR), 2022.
"""
if loss_fun == 'kl_loss':
raise NotImplementedError()
arr = [M, C1, C2]
if p is not None:
arr.append(list_to_array(p))
Expand Down Expand Up @@ -1032,7 +1035,6 @@ def entropic_semirelaxed_fused_gromov_wasserstein2(
If let to its default value None, uniform distribution is taken.
loss_fun : str, optional
loss function used for the solver either 'square_loss' or 'kl_loss'.
'kl_loss' is not implemented yet and will raise an error.
epsilon : float
Regularization term >0
symmetric : bool, optional
Expand Down
27 changes: 25 additions & 2 deletions ot/gromov/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,19 @@ def init_matrix_semirelaxed(C1, C2, p, loss_fun='square_loss', nx=None):

h_2(b) &= 2b

The kl-loss function :math:`L(a, b) = a \log\left(\frac{a}{b}\right) - a + b` is read as :

.. math::

L(a, b) = f_1(a) + f_2(b) - h_1(a) h_2(b)

\mathrm{with} \ f_1(a) &= a \log(a) - a

f_2(b) &= b

h_1(a) &= a

h_2(b) &= \log(b)
Parameters
----------
C1 : array-like, shape (ns, ns)
Expand Down Expand Up @@ -451,9 +464,19 @@ def h1(a):
def h2(b):
return 2 * b
elif loss_fun == 'kl_loss':
raise NotImplementedError()
def f1(a):
return a * nx.log(a + 1e-15) - a

def f2(b):
return b

def h1(a):
return a

def h2(b):
return nx.log(b + 1e-15)
else:
raise ValueError(f"Unknown `loss_fun='{loss_fun}'`. Only 'square_loss' is supported.")
raise ValueError(f"Unknown `loss_fun='{loss_fun}'`. Use one of: {'square_loss', 'kl_loss'}.")

constC = nx.dot(nx.dot(f1(C1), nx.reshape(p, (-1, 1))),
nx.ones((1, C2.shape[0]), type_as=p))
Expand Down
Loading