diff --git a/RELEASES.md b/RELEASES.md index e6e6ff4d4..77863b640 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -19,6 +19,7 @@ - Fix same sign error for sr(F)GW conditional gradient solvers (PR #611) - Split `test/test_gromov.py` into `test/gromov/` (PR #619) - Fix (F)GW barycenter functions to support computing barycenter on 1 input + deprecate structures as lists (PR #628) +- Fix line-search in partial GW and change default init to the interior of partial transport plans (PR #602) ## 0.9.3 *January 2024* diff --git a/ot/partial.py b/ot/partial.py index 85635c9ba..a3b25a856 100755 --- a/ot/partial.py +++ b/ot/partial.py @@ -4,12 +4,15 @@ """ # Author: Laetitia Chapel -# License: MIT License +# Yikun Bai < yikun.bai@vanderbilt.edu > +# Cédric Vincent-Cuaz -import numpy as np -from .lp import emd -from .backend import get_backend from .utils import list_to_array +from .backend import get_backend +from .lp import emd +import numpy as np + +# License: MIT License def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False, @@ -581,7 +584,7 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None, " equal than min(|a|_1, |b|_1).") if G0 is None: - G0 = np.outer(p, q) + G0 = np.outer(p, q) * m / (np.sum(p) * np.sum(q)) # make sure |G0|=m, G01_m\leq p, G0.T1_n\leq q. dim_G_extended = (len(p) + nb_dummies, len(q) + nb_dummies) q_extended = np.append(q, [(np.sum(p) - m) / nb_dummies] * nb_dummies) @@ -597,7 +600,7 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None, Gprev = np.copy(G0) - M = gwgrad_partial(C1, C2, G0) + M = 0.5 * gwgrad_partial(C1, C2, G0) # rescaling the gradient with 0.5 for line-search while not changing Gc M_emd = np.zeros(dim_G_extended) M_emd[:len(p), :len(q)] = M M_emd[-nb_dummies:, -nb_dummies:] = np.max(M) * 1e2