Skip to content

Commit 917b445

Browse files
correct independence of fgw barycenters to init
1 parent 1682b60 commit 917b445

File tree

3 files changed

+42
-24
lines changed

3 files changed

+42
-24
lines changed

RELEASES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
- Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504)
2222
- Lazily instantiate backends to avoid unnecessary GPU memory pre-allocations on package import (Issue #516, PR #520)
2323
- Handle documentation and warnings when integers are provided to (f)gw solvers based on cg (Issue #530, PR #559)
24+
- Correct independence of `fgw_barycenters` to `init_C` and `init_X` (Issue #547, PR #564)
2425

2526
## 0.9.1
2627
*August 2023*

ot/gromov/_gw.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1005,13 +1005,15 @@ def fgw_barycenters(
10051005
else:
10061006
if init_X is None:
10071007
X = nx.zeros((N, d), type_as=ps[0])
1008+
10081009
else:
10091010
X = init_X
10101011

1011-
T = [nx.outer(p, q) for q in ps]
1012-
10131012
Ms = [dist(X, Ys[s]) for s in range(len(Ys))]
10141013

1014+
if warmstartT:
1015+
T = [nx.outer(p, q) for q in ps]
1016+
10151017
# removed since 0.9.2
10161018
#if loss_fun == 'kl_loss':
10171019
# armijo = True
@@ -1030,11 +1032,19 @@ def fgw_barycenters(
10301032
Cprev = C
10311033
Xprev = X
10321034

1035+
if warmstartT:
1036+
T = [fused_gromov_wasserstein(
1037+
Ms[s], C, Cs[s], p, ps[s], loss_fun=loss_fun, alpha=alpha, armijo=armijo, symmetric=symmetric,
1038+
G0=T[s], max_iter=max_iter, tol_rel=1e-5, tol_abs=0., verbose=verbose, **kwargs) for s in range(S)]
1039+
else:
1040+
T = [fused_gromov_wasserstein(
1041+
Ms[s], C, Cs[s], p, ps[s], loss_fun=loss_fun, alpha=alpha, armijo=armijo, symmetric=symmetric,
1042+
G0=None, max_iter=max_iter, tol_rel=1e-5, tol_abs=0., verbose=verbose, **kwargs) for s in range(S)]
1043+
# T is N,ns
10331044
if not fixed_features:
10341045
Ys_temp = [y.T for y in Ys]
10351046
X = update_feature_matrix(lambdas, Ys_temp, T, p).T
1036-
1037-
Ms = [dist(X, Ys[s]) for s in range(len(Ys))]
1047+
Ms = [dist(X, Ys[s]) for s in range(len(Ys))]
10381048

10391049
if not fixed_structure:
10401050
T_temp = [t.T for t in T]
@@ -1044,15 +1054,6 @@ def fgw_barycenters(
10441054
elif loss_fun == 'kl_loss':
10451055
C = update_kl_loss(p, lambdas, T_temp, Cs)
10461056

1047-
if warmstartT:
1048-
T = [fused_gromov_wasserstein(
1049-
Ms[s], C, Cs[s], p, ps[s], loss_fun=loss_fun, alpha=alpha, armijo=armijo, symmetric=symmetric,
1050-
G0=T[s], max_iter=max_iter, tol_rel=1e-5, tol_abs=0., verbose=verbose, **kwargs) for s in range(S)]
1051-
else:
1052-
T = [fused_gromov_wasserstein(
1053-
Ms[s], C, Cs[s], p, ps[s], loss_fun=loss_fun, alpha=alpha, armijo=armijo, symmetric=symmetric,
1054-
G0=None, max_iter=max_iter, tol_rel=1e-5, tol_abs=0., verbose=verbose, **kwargs) for s in range(S)]
1055-
# T is N,ns
10561057
err_feature = nx.norm(X - nx.reshape(Xprev, (N, d)))
10571058
err_structure = nx.norm(C - Cprev)
10581059
if log:

test/test_gromov.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1425,24 +1425,40 @@ def test_fgw_barycenter(nx):
14251425
init_C /= init_C.max()
14261426
init_Cb = nx.from_numpy(init_C)
14271427

1428-
Xb, Cb = ot.gromov.fgw_barycenters(
1429-
n_samples, [ysb, ytb], [C1b, C2b], ps=[p1b, p2b], lambdas=None,
1430-
alpha=0.5, fixed_structure=True, init_C=init_Cb, fixed_features=False,
1431-
p=None, loss_fun='square_loss', max_iter=100, tol=1e-3
1432-
)
1428+
try: # to raise warning when `fixed_structure=True` and `init_C=None`
1429+
Xb, Cb = ot.gromov.fgw_barycenters(
1430+
n_samples, [ysb, ytb], [C1b, C2b], ps=[p1b, p2b], lambdas=None,
1431+
alpha=0.5, fixed_structure=True, init_C=None, fixed_features=False,
1432+
p=None, loss_fun='square_loss', max_iter=100, tol=1e-3
1433+
)
1434+
except:
1435+
Xb, Cb = ot.gromov.fgw_barycenters(
1436+
n_samples, [ysb, ytb], [C1b, C2b], ps=[p1b, p2b], lambdas=None,
1437+
alpha=0.5, fixed_structure=True, init_C=init_Cb, fixed_features=False,
1438+
p=None, loss_fun='square_loss', max_iter=100, tol=1e-3
1439+
)
14331440
Xb, Cb = nx.to_numpy(Xb), nx.to_numpy(Cb)
14341441
np.testing.assert_allclose(Cb.shape, (n_samples, n_samples))
14351442
np.testing.assert_allclose(Xb.shape, (n_samples, ys.shape[1]))
14361443

14371444
init_X = rng.randn(n_samples, ys.shape[1])
14381445
init_Xb = nx.from_numpy(init_X)
14391446

1440-
Xb, Cb, logb = ot.gromov.fgw_barycenters(
1441-
n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5,
1442-
fixed_structure=False, fixed_features=True, init_X=init_Xb,
1443-
p=pb, loss_fun='square_loss', max_iter=100, tol=1e-3,
1444-
warmstartT=True, log=True, random_state=98765, verbose=True
1445-
)
1447+
try: # to raise warning when `fixed_features=True` and `init_X=None`
1448+
Xb, Cb, logb = ot.gromov.fgw_barycenters(
1449+
n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5,
1450+
fixed_structure=False, fixed_features=True, init_X=None,
1451+
p=pb, loss_fun='square_loss', max_iter=100, tol=1e-3,
1452+
warmstartT=True, log=True, random_state=98765, verbose=True
1453+
)
1454+
except:
1455+
Xb, Cb, logb = ot.gromov.fgw_barycenters(
1456+
n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5,
1457+
fixed_structure=False, fixed_features=True, init_X=init_Xb,
1458+
p=pb, loss_fun='square_loss', max_iter=100, tol=1e-3,
1459+
warmstartT=True, log=True, random_state=98765, verbose=True
1460+
)
1461+
14461462
X, C = nx.to_numpy(Xb), nx.to_numpy(Cb)
14471463
np.testing.assert_allclose(C.shape, (n_samples, n_samples))
14481464
np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))

0 commit comments

Comments
 (0)