From 7a4dda0ae998c0a5f3d5c4a5deb5f7e58830fc69 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Mon, 17 Jun 2024 18:36:37 +0200 Subject: [PATCH 1/2] correct bugs with gw barycenter on 1 input --- ot/gromov/_bregman.py | 19 ++++++--- ot/gromov/_gw.py | 20 +++++++--- test/gromov/test_bregman.py | 75 ++++++++++++++++++++++++++++++++++++ test/gromov/test_gw.py | 77 ++++++++++++++++++++++++++++++++++++- 4 files changed, 179 insertions(+), 12 deletions(-) diff --git a/ot/gromov/_bregman.py b/ot/gromov/_bregman.py index df4ba0ae3..6bb7a675a 100644 --- a/ot/gromov/_bregman.py +++ b/ot/gromov/_bregman.py @@ -735,10 +735,15 @@ def entropic_gromov_barycenters( if stop_criterion not in ['barycenter', 'loss']: raise ValueError(f"Unknown `stop_criterion='{stop_criterion}'`. Use one of: {'barycenter', 'loss'}.") - Cs = list_to_array(*Cs) + if isinstance(Cs[0], list): + raise ValueError("Deprecated feature in POT 0.9.4: structures Cs[i] are lists and should be arrays from a supported backend (e.g numpy).") + arr = [*Cs] if ps is not None: - arr += list_to_array(*ps) + if isinstance(ps[0], list): + raise ValueError("Deprecated feature in POT 0.9.4: weights ps[i] are lists and should be arrays from a supported backend (e.g numpy).") + + arr += [*ps] else: ps = [unif(C.shape[0], type_as=C) for C in Cs] if p is not None: @@ -1620,11 +1625,15 @@ def entropic_fused_gromov_barycenters( if stop_criterion not in ['barycenter', 'loss']: raise ValueError(f"Unknown `stop_criterion='{stop_criterion}'`. Use one of: {'barycenter', 'loss'}.") - Cs = list_to_array(*Cs) - Ys = list_to_array(*Ys) + if isinstance(Cs[0], list) or isinstance(Ys[0], list): + raise ValueError("Deprecated feature in POT 0.9.4: structures Cs[i] and/or features Ys[i] are lists and should be arrays from a supported backend (e.g numpy).") + arr = [*Cs, *Ys] if ps is not None: - arr += list_to_array(*ps) + if isinstance(ps[0], list): + raise ValueError("Deprecated feature in POT 0.9.4: weights ps[i] are lists and should be arrays from a supported backend (e.g numpy).") + + arr += [*ps] else: ps = [unif(C.shape[0], type_as=C) for C in Cs] if p is not None: diff --git a/ot/gromov/_gw.py b/ot/gromov/_gw.py index 46e1ddfe8..9dbc6b19e 100644 --- a/ot/gromov/_gw.py +++ b/ot/gromov/_gw.py @@ -808,13 +808,19 @@ def gromov_barycenters( if stop_criterion not in ['barycenter', 'loss']: raise ValueError(f"Unknown `stop_criterion='{stop_criterion}'`. Use one of: {'barycenter', 'loss'}.") - Cs = list_to_array(*Cs) + if isinstance(Cs[0], list): + raise ValueError("Deprecated feature in POT 0.9.4: structures Cs[i] are lists and should be arrays from a supported backend (e.g numpy).") + arr = [*Cs] if ps is not None: - arr += list_to_array(*ps) + if isinstance(ps[0], list): + raise ValueError("Deprecated feature in POT 0.9.4: weights ps[i] are lists and should be arrays from a supported backend (e.g numpy).") + + arr += [*ps] else: ps = [unif(C.shape[0], type_as=C) for C in Cs] if p is not None: + arr.append(list_to_array(p)) else: p = unif(N, type_as=Cs[0]) @@ -1014,11 +1020,15 @@ def fgw_barycenters( if stop_criterion not in ['barycenter', 'loss']: raise ValueError(f"Unknown `stop_criterion='{stop_criterion}'`. Use one of: {'barycenter', 'loss'}.") - Cs = list_to_array(*Cs) - Ys = list_to_array(*Ys) + if isinstance(Cs[0], list) or isinstance(Ys[0], list): + raise ValueError("Deprecated feature in POT 0.9.4: structures Cs[i] and/or features Ys[i] are lists and should be arrays from a supported backend (e.g numpy).") + arr = [*Cs, *Ys] if ps is not None: - arr += list_to_array(*ps) + if isinstance(ps[0], list): + raise ValueError("Deprecated feature in POT 0.9.4: weights ps[i] are lists and should be arrays from a supported backend (e.g numpy).") + + arr += [*ps] else: ps = [unif(C.shape[0], type_as=C) for C in Cs] if p is not None: diff --git a/test/gromov/test_bregman.py b/test/gromov/test_bregman.py index 4baf3ce10..71e55b1ce 100644 --- a/test/gromov/test_bregman.py +++ b/test/gromov/test_bregman.py @@ -792,6 +792,46 @@ def test_entropic_fgw_barycenter(nx): np.testing.assert_allclose(C.shape, (n_samples, n_samples)) np.testing.assert_allclose(Xb, init_Yb) + # test edge cases for fgw barycenters: + # C1 as list + with pytest.raises(ValueError): + C1_list = [list(c) for c in C1b] + _, _, _ = ot.gromov.entropic_fused_gromov_barycenters( + n_samples, [ysb], [C1_list], [p1b], lambdas=None, + fixed_structure=False, fixed_features=False, + init_Y=None, p=pb, max_iter=10, tol=1e-3, + warmstartT=True, log=True, random_state=98765, verbose=True + ) + + # p1, p2 as lists + with pytest.raises(ValueError): + p1_list = list(p1b) + p2_list = list(p2b) + _, _, _ = ot.gromov.entropic_fused_gromov_barycenters( + n_samples, [ysb, ytb], [C1b, C2b], [p1_list, p2_list], lambdas=[0.5, 0.5], + fixed_structure=False, fixed_features=False, + init_Y=None, p=pb, max_iter=10, tol=1e-3, + warmstartT=True, log=True, random_state=98765, verbose=True + ) + + # unique input structure + X, C = ot.gromov.entropic_fused_gromov_barycenters( + n_samples, [ys], [C1], [p1], lambdas=None, + fixed_structure=False, fixed_features=False, + init_Y=init_Y, p=p, max_iter=10, tol=1e-3, + warmstartT=True, log=False, random_state=98765, verbose=True + ) + + Xb, Cb = ot.gromov.entropic_fused_gromov_barycenters( + n_samples, [ysb], [C1b], [p1b], lambdas=None, + fixed_structure=False, fixed_features=False, + init_Y=init_Yb, p=pb, max_iter=10, tol=1e-3, + warmstartT=True, log=False, random_state=98765, verbose=True + ) + + np.testing.assert_allclose(C, Cb, atol=1e-06) + np.testing.assert_allclose(X, Xb, atol=1e-06) + @pytest.mark.filterwarnings("ignore:divide") def test_gromov_entropic_barycenter(nx): @@ -886,6 +926,41 @@ def test_gromov_entropic_barycenter(nx): np.testing.assert_array_almost_equal(err2_['err'], nx.to_numpy(*err2b_['err'])) np.testing.assert_allclose(Cb2b_.shape, (n_samples, n_samples)) + # test edge cases for gw barycenters: + # C1 as list + with pytest.raises(ValueError): + C1_list = [list(c) for c in C1b] + _, _ = ot.gromov.entropic_gromov_barycenters( + n_samples, [C1_list], [p1b], pb, None, 'square_loss', 1e-3, + max_iter=10, tol=1e-3, warmstartT=True, verbose=True, + random_state=42, init_C=None, log=True + ) + + # p1, p2 as lists + with pytest.raises(ValueError): + p1_list = list(p1b) + p2_list = list(p2b) + _, _ = ot.gromov.entropic_gromov_barycenters( + n_samples, [C1b, C2b], [p1_list, p2_list], pb, None, + 'kl_loss', 1e-3, max_iter=10, tol=1e-3, warmstartT=True, + verbose=True, random_state=42, init_Cb=None, log=True + ) + + # unique input structure + Cb = ot.gromov.entropic_gromov_barycenters( + n_samples, [C1], [p1], p, None, 'square_loss', 1e-3, + max_iter=10, tol=1e-3, warmstartT=True, verbose=True, random_state=42, + init_C=None, log=False) + + Cbb = ot.gromov.entropic_gromov_barycenters( + n_samples, [C1b], [p1b], pb, [1.], 'square_loss', 1e-3, + max_iter=10, tol=1e-3, warmstartT=True, verbose=True, + random_state=42, init_Cb=None, log=False + ) + + np.testing.assert_allclose(Cb, Cbb, atol=1e-06) + np.testing.assert_allclose(Cbb.shape, (n_samples, n_samples)) + def test_not_implemented_solver(): # test sinkhorn diff --git a/test/gromov/test_gw.py b/test/gromov/test_gw.py index 0008cebce..e76a33dcf 100644 --- a/test/gromov/test_gw.py +++ b/test/gromov/test_gw.py @@ -429,6 +429,40 @@ def test_gromov_barycenter(nx): np.testing.assert_array_almost_equal(err2_['err'], nx.to_numpy(*err2b_['err'])) np.testing.assert_allclose(Cb2b_.shape, (n_samples, n_samples)) + # test edge cases for gw barycenters: + # C1 as list + with pytest.raises(ValueError): + C1_list = [list(c) for c in C1] + _ = ot.gromov.gromov_barycenters( + n_samples, [C1_list], None, p, None, 'square_loss', max_iter=10, + tol=1e-3, stop_criterion=stop_criterion, verbose=False, + random_state=42 + ) + + # p1, p2 as lists + with pytest.raises(ValueError): + p1_list = list(p1) + p2_list = list(p2) + _ = ot.gromov.gromov_barycenters( + n_samples, [C1, C2], [p1_list, p2_list], p, None, 'square_loss', max_iter=10, + tol=1e-3, stop_criterion=stop_criterion, verbose=False, + random_state=42 + ) + + # unique input structure + Cb = ot.gromov.gromov_barycenters( + n_samples, [C1], None, p, None, 'square_loss', max_iter=10, + tol=1e-3, stop_criterion=stop_criterion, verbose=False, + random_state=42 + ) + Cbb = nx.to_numpy(ot.gromov.gromov_barycenters( + n_samples, [C1b], None, None, [1.], 'square_loss', + max_iter=10, tol=1e-3, stop_criterion=stop_criterion, + verbose=False, random_state=42 + )) + np.testing.assert_allclose(Cb, Cbb, atol=1e-06) + np.testing.assert_allclose(Cbb.shape, (n_samples, n_samples)) + def test_fgw(nx): n_samples = 20 # nb samples @@ -815,7 +849,7 @@ def test_fgw_barycenter(nx): X, C, log = ot.gromov.fgw_barycenters( n_samples, [ys, yt], [C1, C2], [p1, p2], [.5, .5], 0.5, fixed_structure=False, fixed_features=False, p=p, loss_fun='kl_loss', - max_iter=100, tol=1e-3, stop_criterion=stop_criterion, init_C=C, + max_iter=10, tol=1e-3, stop_criterion=stop_criterion, init_C=C, init_X=X, warmstartT=True, random_state=12345, log=True ) @@ -823,7 +857,7 @@ def test_fgw_barycenter(nx): X, C, log = ot.gromov.fgw_barycenters( n_samples, [ys, yt], [C1, C2], [p1, p2], [.5, .5], 0.5, fixed_structure=False, fixed_features=False, p=p, loss_fun='kl_loss', - max_iter=100, tol=1e-3, stop_criterion=stop_criterion, init_C=C, + max_iter=10, tol=1e-3, stop_criterion=stop_criterion, init_C=C, init_X=X, warmstartT=True, random_state=12345, log=True, verbose=True ) np.testing.assert_allclose(C.shape, (n_samples, n_samples)) @@ -832,3 +866,42 @@ def test_fgw_barycenter(nx): # test correspondance with utils function recovered_C = ot.gromov.update_kl_loss(p, lambdas, log['T'], [C1, C2]) np.testing.assert_allclose(C, recovered_C) + + # test edge cases for fgw barycenters: + # C1 as list + with pytest.raises(ValueError): + C1b_list = [list(c) for c in C1b] + _, _, _ = ot.gromov.fgw_barycenters( + n_samples, [ysb], [C1b_list], [p1b], None, 0.5, + fixed_structure=False, fixed_features=False, p=pb, loss_fun='square_loss', + max_iter=10, tol=1e-3, stop_criterion=stop_criterion, init_C=Cb, + init_X=Xb, warmstartT=True, random_state=12345, log=True, verbose=True + ) + + # p1, p2 as lists + with pytest.raises(ValueError): + p1_list = list(p1) + p2_list = list(p2) + _, _, _ = ot.gromov.fgw_barycenters( + n_samples, [ysb, ytb], [C1b, C2b], [p1_list, p2_list], None, 0.5, + fixed_structure=False, fixed_features=False, p=p, loss_fun='kl_loss', + max_iter=10, tol=1e-3, stop_criterion=stop_criterion, init_C=Cb, + init_X=Xb, warmstartT=True, random_state=12345, log=True, verbose=True + ) + + # unique input structure + X, C = ot.gromov.fgw_barycenters( + n_samples, [ys], [C1], [p1], None, 0.5, + fixed_structure=False, fixed_features=False, p=p, loss_fun='square_loss', + max_iter=10, tol=1e-3, stop_criterion=stop_criterion, + warmstartT=True, random_state=12345, log=False, verbose=False + ) + Xb, Cb = ot.gromov.fgw_barycenters( + n_samples, [ysb], [C1b], [p1b], [1.], 0.5, + fixed_structure=False, fixed_features=False, p=pb, loss_fun='square_loss', + max_iter=10, tol=1e-3, stop_criterion=stop_criterion, + warmstartT=True, random_state=12345, log=False, verbose=False + ) + + np.testing.assert_allclose(C, Cb, atol=1e-06) + np.testing.assert_allclose(X, Xb, atol=1e-06) From 2f41b5d9241055a181100832e5607c9abd49a5d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Wed, 19 Jun 2024 16:49:10 +0200 Subject: [PATCH 2/2] merge --- RELEASES.md | 1 + 1 file changed, 1 insertion(+) diff --git a/RELEASES.md b/RELEASES.md index 3908d079c..e6e6ff4d4 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -18,6 +18,7 @@ - Fix a sign error regarding the gradient of `ot.gromov._gw.fused_gromov_wasserstein2` and `ot.gromov._gw.gromov_wasserstein2` for the kl loss (PR #610) - Fix same sign error for sr(F)GW conditional gradient solvers (PR #611) - Split `test/test_gromov.py` into `test/gromov/` (PR #619) +- Fix (F)GW barycenter functions to support computing barycenter on 1 input + deprecate structures as lists (PR #628) ## 0.9.3 *January 2024*