Skip to content

Commit 882202c

Browse files
take into account comments
1 parent b2760c9 commit 882202c

File tree

2 files changed

+14
-31
lines changed

2 files changed

+14
-31
lines changed

ot/gromov/_gw.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -166,10 +166,6 @@ def df(G):
166166
def df(G):
167167
return 0.5 * (gwggrad(constC, hC1, hC2, G, np_) + gwggrad(constCt, hC1t, hC2t, G, np_))
168168

169-
# removed since 0.9.2
170-
# if loss_fun == 'kl_loss':
171-
# armijo = True # there is no closed form line-search with KL
172-
173169
if armijo:
174170
def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
175171
return line_search_armijo(cost, G, deltaG, Mi, cost_G, nx=np_, **kwargs)
@@ -478,10 +474,6 @@ def df(G):
478474
def df(G):
479475
return 0.5 * (gwggrad(constC, hC1, hC2, G, np_) + gwggrad(constCt, hC1t, hC2t, G, np_))
480476

481-
# removed since 0.9.2
482-
# if loss_fun == 'kl_loss':
483-
# armijo = True # there is no closed form line-search with KL
484-
485477
if armijo:
486478
def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
487479
return line_search_armijo(cost, G, deltaG, Mi, cost_G, nx=np_, **kwargs)
@@ -827,10 +819,6 @@ def gromov_barycenters(
827819
else:
828820
C = init_C
829821

830-
# removed since 0.9.2
831-
# if loss_fun == 'kl_loss':
832-
# armijo = True
833-
834822
cpt = 0
835823
err = 1
836824

@@ -1014,10 +1002,6 @@ def fgw_barycenters(
10141002
if warmstartT:
10151003
T = [nx.outer(p, q) for q in ps]
10161004

1017-
# removed since 0.9.2
1018-
# if loss_fun == 'kl_loss':
1019-
# armijo = True
1020-
10211005
cpt = 0
10221006
err_feature = 1
10231007
err_structure = 1

test/test_gromov.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1425,39 +1425,38 @@ def test_fgw_barycenter(nx):
14251425
init_C /= init_C.max()
14261426
init_Cb = nx.from_numpy(init_C)
14271427

1428-
try: # to raise warning when `fixed_structure=True`and `init_C=None`
1428+
with pytest.raises(ot.utils.UndefinedParameter): # to raise warning when `fixed_structure=True`and `init_C=None`
14291429
Xb, Cb = ot.gromov.fgw_barycenters(
14301430
n_samples, [ysb, ytb], [C1b, C2b], ps=[p1b, p2b], lambdas=None,
14311431
alpha=0.5, fixed_structure=True, init_C=None, fixed_features=False,
14321432
p=None, loss_fun='square_loss', max_iter=100, tol=1e-3
14331433
)
1434-
except ot.utils.UndefinedParameter:
1435-
Xb, Cb = ot.gromov.fgw_barycenters(
1436-
n_samples, [ysb, ytb], [C1b, C2b], ps=[p1b, p2b], lambdas=None,
1437-
alpha=0.5, fixed_structure=True, init_C=init_Cb, fixed_features=False,
1438-
p=None, loss_fun='square_loss', max_iter=100, tol=1e-3
1439-
)
1434+
1435+
Xb, Cb = ot.gromov.fgw_barycenters(
1436+
n_samples, [ysb, ytb], [C1b, C2b], ps=[p1b, p2b], lambdas=None,
1437+
alpha=0.5, fixed_structure=True, init_C=init_Cb, fixed_features=False,
1438+
p=None, loss_fun='square_loss', max_iter=100, tol=1e-3
1439+
)
14401440
Xb, Cb = nx.to_numpy(Xb), nx.to_numpy(Cb)
14411441
np.testing.assert_allclose(Cb.shape, (n_samples, n_samples))
14421442
np.testing.assert_allclose(Xb.shape, (n_samples, ys.shape[1]))
14431443

14441444
init_X = rng.randn(n_samples, ys.shape[1])
14451445
init_Xb = nx.from_numpy(init_X)
14461446

1447-
try: # to raise warning when `fixed_features=True`and `init_X=None`
1447+
with pytest.raises(ot.utils.UndefinedParameter): # to raise warning when `fixed_features=True`and `init_X=None`
14481448
Xb, Cb, logb = ot.gromov.fgw_barycenters(
14491449
n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5,
14501450
fixed_structure=False, fixed_features=True, init_X=None,
14511451
p=pb, loss_fun='square_loss', max_iter=100, tol=1e-3,
14521452
warmstartT=True, log=True, random_state=98765, verbose=True
14531453
)
1454-
except ot.utils.UndefinedParameter:
1455-
Xb, Cb, logb = ot.gromov.fgw_barycenters(
1456-
n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5,
1457-
fixed_structure=False, fixed_features=True, init_X=init_Xb,
1458-
p=pb, loss_fun='square_loss', max_iter=100, tol=1e-3,
1459-
warmstartT=True, log=True, random_state=98765, verbose=True
1460-
)
1454+
Xb, Cb, logb = ot.gromov.fgw_barycenters(
1455+
n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5,
1456+
fixed_structure=False, fixed_features=True, init_X=init_Xb,
1457+
p=pb, loss_fun='square_loss', max_iter=100, tol=1e-3,
1458+
warmstartT=True, log=True, random_state=98765, verbose=True
1459+
)
14611460

14621461
X, C = nx.to_numpy(Xb), nx.to_numpy(Cb)
14631462
np.testing.assert_allclose(C.shape, (n_samples, n_samples))

0 commit comments

Comments
 (0)