Skip to content

Commit f1fe593

Browse files
[MRG] Faster gromov-wasserstein linesearch for symmetric matrices (#607)
* improved gw linsearch for symmetric case * add a demo in examples * remove example * releases.md updated --------- Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
1 parent 737b20d commit f1fe593

File tree

2 files changed

+14
-4
lines changed

2 files changed

+14
-4
lines changed

RELEASES.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22

33
## 0.9.3
44

5+
#### New features
6+
+ `ot.gromov._gw.solve_gromov_linesearch` now has an argument to specifify if the matrices are symmetric in which case the computation can be done faster.
7+
58
#### Closed issues
69
- Fixed an issue with cost correction for mismatched labels in `ot.da.BaseTransport` fit methods. This fix addresses the original issue introduced PR #587 (PR #593)
710
- Fix gpu compatibility of sr(F)GW solvers when `G0 is not None`(PR #596)

ot/gromov/_gw.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
171171
return line_search_armijo(cost, G, deltaG, Mi, cost_G, nx=np_, **kwargs)
172172
else:
173173
def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
174-
return solve_gromov_linesearch(G, deltaG, cost_G, hC1, hC2, M=0., reg=1., nx=np_, **kwargs)
174+
return solve_gromov_linesearch(G, deltaG, cost_G, hC1, hC2, M=0., reg=1., nx=np_, symmetric=symmetric, **kwargs)
175175

176176
if not nx.is_floating_point(C10):
177177
warnings.warn(
@@ -479,7 +479,7 @@ def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
479479
return line_search_armijo(cost, G, deltaG, Mi, cost_G, nx=np_, **kwargs)
480480
else:
481481
def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
482-
return solve_gromov_linesearch(G, deltaG, cost_G, hC1, hC2, M=(1 - alpha) * M, reg=alpha, nx=np_, **kwargs)
482+
return solve_gromov_linesearch(G, deltaG, cost_G, hC1, hC2, M=(1 - alpha) * M, reg=alpha, nx=np_, symmetric=symmetric, **kwargs)
483483
if not nx.is_floating_point(M0):
484484
warnings.warn(
485485
"Input feature matrix consists of integer. The transport plan will be "
@@ -647,7 +647,7 @@ def fused_gromov_wasserstein2(M, C1, C2, p=None, q=None, loss_fun='square_loss',
647647

648648

649649
def solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M, reg,
650-
alpha_min=None, alpha_max=None, nx=None, **kwargs):
650+
alpha_min=None, alpha_max=None, nx=None, symmetric=False, **kwargs):
651651
"""
652652
Solve the linesearch in the FW iterations for any inner loss that decomposes as in Proposition 1 in :ref:`[12] <references-solve-linesearch>`.
653653
@@ -676,6 +676,10 @@ def solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M, reg,
676676
Maximum value for alpha
677677
nx : backend, optional
678678
If let to its default value None, a backend test will be conducted.
679+
symmetric : bool, optional
680+
Either structures are to be assumed symmetric or not. Default value is False.
681+
Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric).
682+
679683
Returns
680684
-------
681685
alpha : float
@@ -708,7 +712,10 @@ def solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M, reg,
708712

709713
dot = nx.dot(nx.dot(C1, deltaG), C2.T)
710714
a = - reg * nx.sum(dot * deltaG)
711-
b = nx.sum(M * deltaG) - reg * (nx.sum(dot * G) + nx.sum(nx.dot(nx.dot(C1, G), C2.T) * deltaG))
715+
if symmetric:
716+
b = nx.sum(M * deltaG) - 2 * reg * nx.sum(dot * G)
717+
else:
718+
b = nx.sum(M * deltaG) - reg * (nx.sum(dot * G) + nx.sum(nx.dot(nx.dot(C1, G), C2.T) * deltaG))
712719

713720
alpha = solve_1d_linesearch_quad(a, b)
714721
if alpha_min is not None or alpha_max is not None:

0 commit comments

Comments
 (0)