From 2878b6e7ce7b3bd6f7de433a1b5947ead91de6e6 Mon Sep 17 00:00:00 2001 From: KrzakalaPaul <62666596+KrzakalaPaul@users.noreply.github.com> Date: Thu, 29 Feb 2024 11:45:44 +0100 Subject: [PATCH 1/4] improved gw linsearch for symmetric case --- ot/gromov/_gw.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/ot/gromov/_gw.py b/ot/gromov/_gw.py index 69dd3df0c..3d7a47480 100644 --- a/ot/gromov/_gw.py +++ b/ot/gromov/_gw.py @@ -171,7 +171,7 @@ def line_search(cost, G, deltaG, Mi, cost_G, **kwargs): return line_search_armijo(cost, G, deltaG, Mi, cost_G, nx=np_, **kwargs) else: def line_search(cost, G, deltaG, Mi, cost_G, **kwargs): - return solve_gromov_linesearch(G, deltaG, cost_G, hC1, hC2, M=0., reg=1., nx=np_, **kwargs) + return solve_gromov_linesearch(G, deltaG, cost_G, hC1, hC2, M=0., reg=1., nx=np_, symmetric=symmetric, **kwargs) if not nx.is_floating_point(C10): warnings.warn( @@ -479,7 +479,7 @@ def line_search(cost, G, deltaG, Mi, cost_G, **kwargs): return line_search_armijo(cost, G, deltaG, Mi, cost_G, nx=np_, **kwargs) else: def line_search(cost, G, deltaG, Mi, cost_G, **kwargs): - return solve_gromov_linesearch(G, deltaG, cost_G, hC1, hC2, M=(1 - alpha) * M, reg=alpha, nx=np_, **kwargs) + return solve_gromov_linesearch(G, deltaG, cost_G, hC1, hC2, M=(1 - alpha) * M, reg=alpha, nx=np_, symmetric=symmetric, **kwargs) if not nx.is_floating_point(M0): warnings.warn( "Input feature matrix consists of integer. The transport plan will be " @@ -647,7 +647,7 @@ def fused_gromov_wasserstein2(M, C1, C2, p=None, q=None, loss_fun='square_loss', def solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M, reg, - alpha_min=None, alpha_max=None, nx=None, **kwargs): + alpha_min=None, alpha_max=None, nx=None, symmetric=False, **kwargs): """ Solve the linesearch in the FW iterations for any inner loss that decomposes as in Proposition 1 in :ref:`[12] `. @@ -676,6 +676,10 @@ def solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M, reg, Maximum value for alpha nx : backend, optional If let to its default value None, a backend test will be conducted. + symmetric : bool, optional + Either structures are to be assumed symmetric or not. Default value is False. + Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric). + Returns ------- alpha : float @@ -708,7 +712,10 @@ def solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M, reg, dot = nx.dot(nx.dot(C1, deltaG), C2.T) a = - reg * nx.sum(dot * deltaG) - b = nx.sum(M * deltaG) - reg * (nx.sum(dot * G) + nx.sum(nx.dot(nx.dot(C1, G), C2.T) * deltaG)) + if symmetric: + b = nx.sum(M * deltaG) - 2 * reg * nx.sum(dot * G) + else: + b = nx.sum(M * deltaG) - reg * (nx.sum(dot * G) + nx.sum(nx.dot(nx.dot(C1, G), C2.T) * deltaG)) alpha = solve_1d_linesearch_quad(a, b) if alpha_min is not None or alpha_max is not None: From 354f8ff3e56a28e133665bcda1a317c7b4fbb3f7 Mon Sep 17 00:00:00 2001 From: KrzakalaPaul <62666596+KrzakalaPaul@users.noreply.github.com> Date: Thu, 29 Feb 2024 12:08:33 +0100 Subject: [PATCH 2/4] add a demo in examples --- examples/gromov/symmetric_linesearch.py | 27 +++++++++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 examples/gromov/symmetric_linesearch.py diff --git a/examples/gromov/symmetric_linesearch.py b/examples/gromov/symmetric_linesearch.py new file mode 100644 index 000000000..116ad8bca --- /dev/null +++ b/examples/gromov/symmetric_linesearch.py @@ -0,0 +1,27 @@ +import numpy as np +from ot.gromov._gw import solve_gromov_linesearch +from time import time + +n = 1000 + +C1 = np.random.random((n, n)) +C1 = C1 + C1.T +C2 = np.random.random((n, n)) +C2 = C2 + C2.T + +G1 = np.ones((n, n)) / (n**2) +G2 = np.eye(n) / n + +tic = time() +alpha2, _, _ = solve_gromov_linesearch(G1, G2 - G1, cost_G=0, C1=C1, C2=C2, M=0, reg=1, symmetric=False) +tac = time() + +print(f'Linesearch time without symmetric assumption {tac-tic}') + +tic = time() +alpha1, _, _ = solve_gromov_linesearch(G1, G2 - G1, cost_G=0, C1=C1, C2=C2, M=0, reg=1, symmetric=True) +tac = time() + +print(f'Linesearch time with symmetric assumption {tac-tic}') + +assert alpha1 == alpha2 \ No newline at end of file From 1f341f3d0add5578672d9a6bc7509e8a4d72b29d Mon Sep 17 00:00:00 2001 From: KrzakalaPaul <62666596+KrzakalaPaul@users.noreply.github.com> Date: Thu, 29 Feb 2024 16:18:25 +0100 Subject: [PATCH 3/4] remove example --- examples/gromov/symmetric_linesearch.py | 27 ------------------------- 1 file changed, 27 deletions(-) delete mode 100644 examples/gromov/symmetric_linesearch.py diff --git a/examples/gromov/symmetric_linesearch.py b/examples/gromov/symmetric_linesearch.py deleted file mode 100644 index 116ad8bca..000000000 --- a/examples/gromov/symmetric_linesearch.py +++ /dev/null @@ -1,27 +0,0 @@ -import numpy as np -from ot.gromov._gw import solve_gromov_linesearch -from time import time - -n = 1000 - -C1 = np.random.random((n, n)) -C1 = C1 + C1.T -C2 = np.random.random((n, n)) -C2 = C2 + C2.T - -G1 = np.ones((n, n)) / (n**2) -G2 = np.eye(n) / n - -tic = time() -alpha2, _, _ = solve_gromov_linesearch(G1, G2 - G1, cost_G=0, C1=C1, C2=C2, M=0, reg=1, symmetric=False) -tac = time() - -print(f'Linesearch time without symmetric assumption {tac-tic}') - -tic = time() -alpha1, _, _ = solve_gromov_linesearch(G1, G2 - G1, cost_G=0, C1=C1, C2=C2, M=0, reg=1, symmetric=True) -tac = time() - -print(f'Linesearch time with symmetric assumption {tac-tic}') - -assert alpha1 == alpha2 \ No newline at end of file From a26827a6800ea31ddb0c6e6228f54182bc36ea02 Mon Sep 17 00:00:00 2001 From: KrzakalaPaul <62666596+KrzakalaPaul@users.noreply.github.com> Date: Thu, 29 Feb 2024 18:48:28 +0100 Subject: [PATCH 4/4] releases.md updated --- RELEASES.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/RELEASES.md b/RELEASES.md index 998d56836..a695d6a70 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -2,6 +2,9 @@ ## 0.9.3 +#### New features ++ `ot.gromov._gw.solve_gromov_linesearch` now has an argument to specifify if the matrices are symmetric in which case the computation can be done faster. + #### Closed issues - Fixed an issue with cost correction for mismatched labels in `ot.da.BaseTransport` fit methods. This fix addresses the original issue introduced PR #587 (PR #593) - Fix gpu compatibility of sr(F)GW solvers when `G0 is not None`(PR #596)