diff --git a/RELEASES.md b/RELEASES.md index 998d56836..a695d6a70 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -2,6 +2,9 @@ ## 0.9.3 +#### New features ++ `ot.gromov._gw.solve_gromov_linesearch` now has an argument to specifify if the matrices are symmetric in which case the computation can be done faster. + #### Closed issues - Fixed an issue with cost correction for mismatched labels in `ot.da.BaseTransport` fit methods. This fix addresses the original issue introduced PR #587 (PR #593) - Fix gpu compatibility of sr(F)GW solvers when `G0 is not None`(PR #596) diff --git a/ot/gromov/_gw.py b/ot/gromov/_gw.py index 69dd3df0c..3d7a47480 100644 --- a/ot/gromov/_gw.py +++ b/ot/gromov/_gw.py @@ -171,7 +171,7 @@ def line_search(cost, G, deltaG, Mi, cost_G, **kwargs): return line_search_armijo(cost, G, deltaG, Mi, cost_G, nx=np_, **kwargs) else: def line_search(cost, G, deltaG, Mi, cost_G, **kwargs): - return solve_gromov_linesearch(G, deltaG, cost_G, hC1, hC2, M=0., reg=1., nx=np_, **kwargs) + return solve_gromov_linesearch(G, deltaG, cost_G, hC1, hC2, M=0., reg=1., nx=np_, symmetric=symmetric, **kwargs) if not nx.is_floating_point(C10): warnings.warn( @@ -479,7 +479,7 @@ def line_search(cost, G, deltaG, Mi, cost_G, **kwargs): return line_search_armijo(cost, G, deltaG, Mi, cost_G, nx=np_, **kwargs) else: def line_search(cost, G, deltaG, Mi, cost_G, **kwargs): - return solve_gromov_linesearch(G, deltaG, cost_G, hC1, hC2, M=(1 - alpha) * M, reg=alpha, nx=np_, **kwargs) + return solve_gromov_linesearch(G, deltaG, cost_G, hC1, hC2, M=(1 - alpha) * M, reg=alpha, nx=np_, symmetric=symmetric, **kwargs) if not nx.is_floating_point(M0): warnings.warn( "Input feature matrix consists of integer. The transport plan will be " @@ -647,7 +647,7 @@ def fused_gromov_wasserstein2(M, C1, C2, p=None, q=None, loss_fun='square_loss', def solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M, reg, - alpha_min=None, alpha_max=None, nx=None, **kwargs): + alpha_min=None, alpha_max=None, nx=None, symmetric=False, **kwargs): """ Solve the linesearch in the FW iterations for any inner loss that decomposes as in Proposition 1 in :ref:`[12] `. @@ -676,6 +676,10 @@ def solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M, reg, Maximum value for alpha nx : backend, optional If let to its default value None, a backend test will be conducted. + symmetric : bool, optional + Either structures are to be assumed symmetric or not. Default value is False. + Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric). + Returns ------- alpha : float @@ -708,7 +712,10 @@ def solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M, reg, dot = nx.dot(nx.dot(C1, deltaG), C2.T) a = - reg * nx.sum(dot * deltaG) - b = nx.sum(M * deltaG) - reg * (nx.sum(dot * G) + nx.sum(nx.dot(nx.dot(C1, G), C2.T) * deltaG)) + if symmetric: + b = nx.sum(M * deltaG) - 2 * reg * nx.sum(dot * G) + else: + b = nx.sum(M * deltaG) - reg * (nx.sum(dot * G) + nx.sum(nx.dot(nx.dot(C1, G), C2.T) * deltaG)) alpha = solve_1d_linesearch_quad(a, b) if alpha_min is not None or alpha_max is not None: