Skip to content

Commit 6f4a40d

Browse files
[WIP] Fix matrix feature shape in entropic FGW barycenters (#575)
* fix matrix feature shape in entropic FGW barycenter * fix matrix feature shape in entropic FGW barycenter * complete tests for gromov.bregman
1 parent fcd8f05 commit 6f4a40d

File tree

3 files changed

+35
-9
lines changed

3 files changed

+35
-9
lines changed

RELEASES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
- Correct independence of `fgw_barycenters` to `init_C` and `init_X` (Issue #547, PR #566)
2525
- Avoid precision change when computing norm using PyTorch backend (Discussion #570, PR #572)
2626
- Create `ot/bregman/`repository (Issue #567, PR #569)
27+
- Fix matrix feature shape in `entropic_fused_gromov_barycenters`(Issue #574, PR #573)
2728

2829

2930
## 0.9.1

ot/gromov/_bregman.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -961,7 +961,8 @@ def entropic_fused_gromov_barycenters(
961961
else:
962962
Y = init_Y
963963

964-
T = [nx.outer(p_, p) for p_ in ps]
964+
if warmstartT:
965+
T = [nx.outer(p_, p) for p_ in ps]
965966

966967
Ms = [dist(Ys[s], Y) for s in range(len(Ys))]
967968

@@ -971,9 +972,6 @@ def entropic_fused_gromov_barycenters(
971972
err_feature = 1
972973
err_structure = 1
973974

974-
if warmstartT:
975-
T = [None] * S
976-
977975
if log:
978976
log_ = {}
979977
log_['err_feature'] = []
@@ -987,7 +985,7 @@ def entropic_fused_gromov_barycenters(
987985
if warmstartT:
988986
T = [entropic_fused_gromov_wasserstein(
989987
Ms[s], Cs[s], C, ps[s], p, loss_fun, epsilon, symmetric, alpha,
990-
None, max_iter, 1e-4, verbose=verbose, log=False, **kwargs) for s in range(S)]
988+
T[s], max_iter, 1e-4, verbose=verbose, log=False, **kwargs) for s in range(S)]
991989

992990
else:
993991
T = [entropic_fused_gromov_wasserstein(
@@ -1001,7 +999,7 @@ def entropic_fused_gromov_barycenters(
1001999

10021000
Ys_temp = [y.T for y in Ys]
10031001
T_temp = [Ts.T for Ts in T]
1004-
Y = update_feature_matrix(lambdas, Ys_temp, T_temp, p)
1002+
Y = update_feature_matrix(lambdas, Ys_temp, T_temp, p).T
10051003
Ms = [dist(Ys[s], Y) for s in range(len(Ys))]
10061004

10071005
if cpt % 10 == 0:

test/test_gromov.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,12 @@ def test_entropic_proximal_gromov(nx):
459459

460460
C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0)
461461

462+
with pytest.raises(ValueError):
463+
loss_fun = 'weird_loss_fun'
464+
G, log = ot.gromov.entropic_gromov_wasserstein(
465+
C1, C2, None, q, loss_fun, symmetric=None, G0=G0,
466+
epsilon=1e-1, max_iter=50, solver='PPA', verbose=True, log=True, numItermax=1)
467+
462468
G, log = ot.gromov.entropic_gromov_wasserstein(
463469
C1, C2, None, q, 'square_loss', symmetric=None, G0=G0,
464470
epsilon=1e-1, max_iter=50, solver='PPA', verbose=True, log=True, numItermax=1)
@@ -606,6 +612,12 @@ def test_entropic_fgw(nx):
606612

607613
Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0)
608614

615+
with pytest.raises(ValueError):
616+
loss_fun = 'weird_loss_fun'
617+
G, log = ot.gromov.entropic_fused_gromov_wasserstein(
618+
M, C1, C2, None, None, loss_fun, symmetric=None, G0=G0,
619+
epsilon=1e-1, max_iter=10, verbose=True, log=True)
620+
609621
G, log = ot.gromov.entropic_fused_gromov_wasserstein(
610622
M, C1, C2, None, None, 'square_loss', symmetric=None, G0=G0,
611623
epsilon=1e-1, max_iter=10, verbose=True, log=True)
@@ -812,20 +824,28 @@ def test_entropic_fgw_barycenter(nx):
812824
C2 = ot.dist(Xt)
813825
p1 = ot.unif(ns)
814826
p2 = ot.unif(nt)
815-
n_samples = 2
827+
n_samples = 3
816828
p = ot.unif(n_samples)
817829

818830
ysb, ytb, C1b, C2b, p1b, p2b, pb = nx.from_numpy(ys, yt, C1, C2, p1, p2, p)
819831

832+
with pytest.raises(ValueError):
833+
loss_fun = 'weird_loss_fun'
834+
X, C, log = ot.gromov.entropic_fused_gromov_barycenters(
835+
n_samples, [ys, yt], [C1, C2], None, p, [.5, .5], loss_fun, 0.1,
836+
max_iter=10, tol=1e-3, verbose=True, warmstartT=True, random_state=42,
837+
solver='PPA', numItermax=10, log=True
838+
)
839+
820840
X, C, log = ot.gromov.entropic_fused_gromov_barycenters(
821841
n_samples, [ys, yt], [C1, C2], None, p, [.5, .5], 'square_loss', 0.1,
822842
max_iter=10, tol=1e-3, verbose=True, warmstartT=True, random_state=42,
823-
solver='PPA', numItermax=1, log=True
843+
solver='PPA', numItermax=10, log=True
824844
)
825845
Xb, Cb = ot.gromov.entropic_fused_gromov_barycenters(
826846
n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], None, [.5, .5], 'square_loss', 0.1,
827847
max_iter=10, tol=1e-3, verbose=False, warmstartT=True, random_state=42,
828-
solver='PPA', numItermax=1, log=False)
848+
solver='PPA', numItermax=10, log=False)
829849
Xb, Cb = nx.to_numpy(Xb, Cb)
830850

831851
np.testing.assert_allclose(C, Cb, atol=1e-06)
@@ -1052,6 +1072,13 @@ def test_gromov_entropic_barycenter(nx):
10521072

10531073
C1b, C2b, p1b, p2b, pb = nx.from_numpy(C1, C2, p1, p2, p)
10541074

1075+
with pytest.raises(ValueError):
1076+
loss_fun = 'weird_loss_fun'
1077+
Cb = ot.gromov.entropic_gromov_barycenters(
1078+
n_samples, [C1, C2], None, p, [.5, .5], loss_fun, 1e-3,
1079+
max_iter=10, tol=1e-3, verbose=True, warmstartT=True, random_state=42
1080+
)
1081+
10551082
Cb = ot.gromov.entropic_gromov_barycenters(
10561083
n_samples, [C1, C2], None, p, [.5, .5], 'square_loss', 1e-3,
10571084
max_iter=10, tol=1e-3, verbose=True, warmstartT=True, random_state=42

0 commit comments

Comments
 (0)