Skip to content

Commit a56e1b2

Browse files
[MRG] correct independence of fgw barycenters to init (#566)
* correct independence of fgw barycenters to init * fix pep8 and tests * correct PR id * take into account comments
1 parent 1682b60 commit a56e1b2

File tree

3 files changed

+29
-28
lines changed

3 files changed

+29
-28
lines changed

RELEASES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
- Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504)
2222
- Lazily instantiate backends to avoid unnecessary GPU memory pre-allocations on package import (Issue #516, PR #520)
2323
- Handle documentation and warnings when integers are provided to (f)gw solvers based on cg (Issue #530, PR #559)
24+
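- Fix `fgw_barycenters` ignoring `init_C` and `init_X`: the transport plans are now computed from the provided initial structure and features before the barycenter is first updated (Issue #547, PR #566)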
2425

2526
## 0.9.1
2627
*August 2023*

ot/gromov/_gw.py

Lines changed: 13 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -166,10 +166,6 @@ def df(G):
166166
def df(G):
167167
return 0.5 * (gwggrad(constC, hC1, hC2, G, np_) + gwggrad(constCt, hC1t, hC2t, G, np_))
168168

169-
# removed since 0.9.2
170-
#if loss_fun == 'kl_loss':
171-
# armijo = True # there is no closed form line-search with KL
172-
173169
if armijo:
174170
def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
175171
return line_search_armijo(cost, G, deltaG, Mi, cost_G, nx=np_, **kwargs)
@@ -478,10 +474,6 @@ def df(G):
478474
def df(G):
479475
return 0.5 * (gwggrad(constC, hC1, hC2, G, np_) + gwggrad(constCt, hC1t, hC2t, G, np_))
480476

481-
# removed since 0.9.2
482-
#if loss_fun == 'kl_loss':
483-
# armijo = True # there is no closed form line-search with KL
484-
485477
if armijo:
486478
def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
487479
return line_search_armijo(cost, G, deltaG, Mi, cost_G, nx=np_, **kwargs)
@@ -827,10 +819,6 @@ def gromov_barycenters(
827819
else:
828820
C = init_C
829821

830-
# removed since 0.9.2
831-
#if loss_fun == 'kl_loss':
832-
# armijo = True
833-
834822
cpt = 0
835823
err = 1
836824

@@ -1005,16 +993,14 @@ def fgw_barycenters(
1005993
else:
1006994
if init_X is None:
1007995
X = nx.zeros((N, d), type_as=ps[0])
996+
1008997
else:
1009998
X = init_X
1010999

1011-
T = [nx.outer(p, q) for q in ps]
1012-
10131000
Ms = [dist(X, Ys[s]) for s in range(len(Ys))]
10141001

1015-
# removed since 0.9.2
1016-
#if loss_fun == 'kl_loss':
1017-
# armijo = True
1002+
if warmstartT:
1003+
T = [nx.outer(p, q) for q in ps]
10181004

10191005
cpt = 0
10201006
err_feature = 1
@@ -1030,11 +1016,19 @@ def fgw_barycenters(
10301016
Cprev = C
10311017
Xprev = X
10321018

1019+
if warmstartT:
1020+
T = [fused_gromov_wasserstein(
1021+
Ms[s], C, Cs[s], p, ps[s], loss_fun=loss_fun, alpha=alpha, armijo=armijo, symmetric=symmetric,
1022+
G0=T[s], max_iter=max_iter, tol_rel=1e-5, tol_abs=0., verbose=verbose, **kwargs) for s in range(S)]
1023+
else:
1024+
T = [fused_gromov_wasserstein(
1025+
Ms[s], C, Cs[s], p, ps[s], loss_fun=loss_fun, alpha=alpha, armijo=armijo, symmetric=symmetric,
1026+
G0=None, max_iter=max_iter, tol_rel=1e-5, tol_abs=0., verbose=verbose, **kwargs) for s in range(S)]
1027+
# T is N,ns
10331028
if not fixed_features:
10341029
Ys_temp = [y.T for y in Ys]
10351030
X = update_feature_matrix(lambdas, Ys_temp, T, p).T
1036-
1037-
Ms = [dist(X, Ys[s]) for s in range(len(Ys))]
1031+
Ms = [dist(X, Ys[s]) for s in range(len(Ys))]
10381032

10391033
if not fixed_structure:
10401034
T_temp = [t.T for t in T]
@@ -1044,15 +1038,6 @@ def fgw_barycenters(
10441038
elif loss_fun == 'kl_loss':
10451039
C = update_kl_loss(p, lambdas, T_temp, Cs)
10461040

1047-
if warmstartT:
1048-
T = [fused_gromov_wasserstein(
1049-
Ms[s], C, Cs[s], p, ps[s], loss_fun=loss_fun, alpha=alpha, armijo=armijo, symmetric=symmetric,
1050-
G0=T[s], max_iter=max_iter, tol_rel=1e-5, tol_abs=0., verbose=verbose, **kwargs) for s in range(S)]
1051-
else:
1052-
T = [fused_gromov_wasserstein(
1053-
Ms[s], C, Cs[s], p, ps[s], loss_fun=loss_fun, alpha=alpha, armijo=armijo, symmetric=symmetric,
1054-
G0=None, max_iter=max_iter, tol_rel=1e-5, tol_abs=0., verbose=verbose, **kwargs) for s in range(S)]
1055-
# T is N,ns
10561041
err_feature = nx.norm(X - nx.reshape(Xprev, (N, d)))
10571042
err_structure = nx.norm(C - Cprev)
10581043
if log:

test/test_gromov.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1425,6 +1425,13 @@ def test_fgw_barycenter(nx):
14251425
init_C /= init_C.max()
14261426
init_Cb = nx.from_numpy(init_C)
14271427

1428+
with pytest.raises(ot.utils.UndefinedParameter): # an exception must be raised when `fixed_structure=True` and `init_C=None`
1429+
Xb, Cb = ot.gromov.fgw_barycenters(
1430+
n_samples, [ysb, ytb], [C1b, C2b], ps=[p1b, p2b], lambdas=None,
1431+
alpha=0.5, fixed_structure=True, init_C=None, fixed_features=False,
1432+
p=None, loss_fun='square_loss', max_iter=100, tol=1e-3
1433+
)
1434+
14281435
Xb, Cb = ot.gromov.fgw_barycenters(
14291436
n_samples, [ysb, ytb], [C1b, C2b], ps=[p1b, p2b], lambdas=None,
14301437
alpha=0.5, fixed_structure=True, init_C=init_Cb, fixed_features=False,
@@ -1437,12 +1444,20 @@ def test_fgw_barycenter(nx):
14371444
init_X = rng.randn(n_samples, ys.shape[1])
14381445
init_Xb = nx.from_numpy(init_X)
14391446

1447+
with pytest.raises(ot.utils.UndefinedParameter): # an exception must be raised when `fixed_features=True` and `init_X=None`
1448+
Xb, Cb, logb = ot.gromov.fgw_barycenters(
1449+
n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5,
1450+
fixed_structure=False, fixed_features=True, init_X=None,
1451+
p=pb, loss_fun='square_loss', max_iter=100, tol=1e-3,
1452+
warmstartT=True, log=True, random_state=98765, verbose=True
1453+
)
14401454
Xb, Cb, logb = ot.gromov.fgw_barycenters(
14411455
n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5,
14421456
fixed_structure=False, fixed_features=True, init_X=init_Xb,
14431457
p=pb, loss_fun='square_loss', max_iter=100, tol=1e-3,
14441458
warmstartT=True, log=True, random_state=98765, verbose=True
14451459
)
1460+
14461461
X, C = nx.to_numpy(Xb), nx.to_numpy(Cb)
14471462
np.testing.assert_allclose(C.shape, (n_samples, n_samples))
14481463
np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))

0 commit comments

Comments
 (0)