From 81d0b15b109643b959d50a7c4a22dffd6480b0c2 Mon Sep 17 00:00:00 2001 From: Oleksii Kachaiev Date: Sat, 12 Aug 2023 22:43:41 +0200 Subject: [PATCH 1/4] Initial implementation --- ot/gaussian.py | 14 ++++++++++++-- ot/utils.py | 6 ++++++ test/test_gaussian.py | 9 +++++++++ 3 files changed, 27 insertions(+), 2 deletions(-) diff --git a/ot/gaussian.py b/ot/gaussian.py index e83d5eee8..6334ca699 100644 --- a/ot/gaussian.py +++ b/ot/gaussian.py @@ -8,9 +8,10 @@ # # License: MIT License +import warnings + from .backend import get_backend -from .utils import dots -from .utils import list_to_array +from .utils import dots, is_all_finite, list_to_array def bures_wasserstein_mapping(ms, mt, Cs, Ct, log=False): @@ -72,6 +73,7 @@ def bures_wasserstein_mapping(ms, mt, Cs, Ct, log=False): """ ms, mt, Cs, Ct = list_to_array(ms, mt, Cs, Ct) nx = get_backend(ms, mt, Cs, Ct) + is_input_finite = is_all_finite(ms, mt, Cs, Ct) Cs12 = nx.sqrtm(Cs) Cs12inv = nx.inv(Cs12) @@ -82,6 +84,9 @@ def bures_wasserstein_mapping(ms, mt, Cs, Ct, log=False): b = mt - nx.dot(ms, A) + if is_input_finite and not is_all_finite(A, b): + warnings.warn("Warning: 'bures_wasserstein_mapping' caused numerical errors.") + if log: log = {} log['Cs12'] = Cs12 @@ -235,11 +240,16 @@ def bures_wasserstein_distance(ms, mt, Cs, Ct, log=False): """ ms, mt, Cs, Ct = list_to_array(ms, mt, Cs, Ct) nx = get_backend(ms, mt, Cs, Ct) + is_input_finite = is_all_finite(ms, mt, Cs, Ct) Cs12 = nx.sqrtm(Cs) B = nx.trace(Cs + Ct - 2 * nx.sqrtm(dots(Cs12, Ct, Cs12))) W = nx.sqrt(nx.norm(ms - mt)**2 + B) + + if is_input_finite and not is_all_finite(W): + warnings.warn("Warning: 'bures_wasserstein_distance' caused numerical errors.") + if log: log = {} log['Cs12'] = Cs12 diff --git a/ot/utils.py b/ot/utils.py index 1559eaaa6..d40e48a59 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -384,6 +384,12 @@ def dots(*args): return reduce(nx.dot, args) +def is_all_finite(*args): + r"""Tests element-wise for finiteness in all arguments.""" + nx = get_backend(*args) + return reduce(lambda f: nx.all(nx.isfinite(f)), args) + + def label_normalization(y, start=0): r""" Transform labels to start at a given value diff --git a/test/test_gaussian.py b/test/test_gaussian.py index 5a021d004..29cfc4084 100644 --- a/test/test_gaussian.py +++ b/test/test_gaussian.py @@ -11,6 +11,7 @@ import ot from ot.datasets import make_data_classif +from ot.utils import is_all_finite def test_bures_wasserstein_mapping(nx): @@ -70,6 +71,14 @@ def test_empirical_bures_wasserstein_mapping(nx, bias): np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2) +def test_empirical_bures_wasserstein_mapping_instabilities_warning(): + Xs = np.random.rand(10, 1000) + Xt = np.random.rand(10, 1000) + with pytest.warns(): + A, b = ot.gaussian.empirical_bures_wasserstein_mapping(Xs, Xt) + assert not is_all_finite(A, b) + + def test_bures_wasserstein_distance(nx): ms, mt = np.array([0]), np.array([10]) Cs, Ct = np.array([[1]]).astype(np.float32), np.array([[1]]).astype(np.float32) From 16b38de15827eaa33568232ed4b29d56e16b2bb7 Mon Sep 17 00:00:00 2001 From: Oleksii Kachaiev Date: Tue, 15 Aug 2023 23:19:23 +0200 Subject: [PATCH 2/4] nx.all is not defined, switching to nx.any --- ot/gaussian.py | 2 +- ot/utils.py | 2 +- test/test_gaussian.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/ot/gaussian.py b/ot/gaussian.py index 6334ca699..17916981e 100644 --- a/ot/gaussian.py +++ b/ot/gaussian.py @@ -247,7 +247,7 @@ def bures_wasserstein_distance(ms, mt, Cs, Ct, log=False): B = nx.trace(Cs + Ct - 2 * nx.sqrtm(dots(Cs12, Ct, Cs12))) W = nx.sqrt(nx.norm(ms - mt)**2 + B) - if is_input_finite and not is_all_finite(W): + if is_input_finite and not nx.isfinite(W): warnings.warn("Warning: 'bures_wasserstein_distance' caused numerical errors.") if log: diff --git a/ot/utils.py b/ot/utils.py index d40e48a59..72df4294f 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -387,7 +387,7 @@ def dots(*args): def is_all_finite(*args): r"""Tests element-wise for finiteness in all arguments.""" nx = get_backend(*args) - return reduce(lambda f: nx.all(nx.isfinite(f)), args) + return all(not nx.any(~nx.isfinite(arg)) for arg in args) def label_normalization(y, start=0): diff --git a/test/test_gaussian.py b/test/test_gaussian.py index 29cfc4084..09d534f8f 100644 --- a/test/test_gaussian.py +++ b/test/test_gaussian.py @@ -72,8 +72,8 @@ def test_empirical_bures_wasserstein_mapping(nx, bias): def test_empirical_bures_wasserstein_mapping_instabilities_warning(): - Xs = np.random.rand(10, 1000) - Xt = np.random.rand(10, 1000) + Xs = np.random.rand(2000, 1000) + Xt = np.random.rand(2000, 1000) with pytest.warns(): A, b = ot.gaussian.empirical_bures_wasserstein_mapping(Xs, Xt) assert not is_all_finite(A, b) From a912de7bc66103514fba72ef73f222ee7deb59d0 Mon Sep 17 00:00:00 2001 From: Oleksii Kachaiev Date: Wed, 16 Aug 2023 11:37:32 +0200 Subject: [PATCH 3/4] Rename warnings to be consistent with numpy --- ot/gaussian.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ot/gaussian.py b/ot/gaussian.py index 17916981e..eb3a25c61 100644 --- a/ot/gaussian.py +++ b/ot/gaussian.py @@ -85,7 +85,7 @@ def bures_wasserstein_mapping(ms, mt, Cs, Ct, log=False): b = mt - nx.dot(ms, A) if is_input_finite and not is_all_finite(A, b): - warnings.warn("Warning: 'bures_wasserstein_mapping' caused numerical errors.") + warnings.warn("Numerical errors encountered in ot.gaussian.bures_wasserstein_mapping") if log: log = {} @@ -248,7 +248,7 @@ def bures_wasserstein_distance(ms, mt, Cs, Ct, log=False): W = nx.sqrt(nx.norm(ms - mt)**2 + B) if is_input_finite and not nx.isfinite(W): - warnings.warn("Warning: 'bures_wasserstein_distance' caused numerical errors.") + warnings.warn("Numerical errors encountered in ot.gaussian.bures_wasserstein_distance") if log: log = {} From 5853a55839d368764d6bd803c2df9d54a5d7f672 Mon Sep 17 00:00:00 2001 From: Oleksii Kachaiev Date: Wed, 16 Aug 2023 22:23:52 +0200 Subject: [PATCH 4/4] Finally found example where empirical cov is ill-formed --- ot/gaussian.py | 19 ++++++++++--------- test/test_gaussian.py | 9 +++++---- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/ot/gaussian.py b/ot/gaussian.py index eb3a25c61..708f9eb16 100644 --- a/ot/gaussian.py +++ b/ot/gaussian.py @@ -73,7 +73,6 @@ def bures_wasserstein_mapping(ms, mt, Cs, Ct, log=False): """ ms, mt, Cs, Ct = list_to_array(ms, mt, Cs, Ct) nx = get_backend(ms, mt, Cs, Ct) - is_input_finite = is_all_finite(ms, mt, Cs, Ct) Cs12 = nx.sqrtm(Cs) Cs12inv = nx.inv(Cs12) @@ -84,9 +83,6 @@ def bures_wasserstein_mapping(ms, mt, Cs, Ct, log=False): b = mt - nx.dot(ms, A) - if is_input_finite and not is_all_finite(A, b): - warnings.warn("Numerical errors encountered in ot.gaussian.bures_wasserstein_mapping") - if log: log = {} log['Cs12'] = Cs12 @@ -160,6 +156,7 @@ def empirical_bures_wasserstein_mapping(xs, xt, reg=1e-6, ws=None, """ xs, xt = list_to_array(xs, xt) nx = get_backend(xs, xt) + is_input_finite = is_all_finite(xs, xt) d = xs.shape[1] @@ -184,11 +181,19 @@ def empirical_bures_wasserstein_mapping(xs, xt, reg=1e-6, ws=None, if log: A, b, log = bures_wasserstein_mapping(mxs, mxt, Cs, Ct, log=log) + else: + A, b = bures_wasserstein_mapping(mxs, mxt, Cs, Ct) + + if is_input_finite and not is_all_finite(A, b): + warnings.warn( + "Numerical errors were encountered in ot.gaussian.empirical_bures_wasserstein_mapping. " + "Consider increasing the regularization parameter `reg`.") + + if log: log['Cs'] = Cs log['Ct'] = Ct return A, b, log else: - A, b = bures_wasserstein_mapping(mxs, mxt, Cs, Ct) return A, b @@ -240,16 +245,12 @@ def bures_wasserstein_distance(ms, mt, Cs, Ct, log=False): """ ms, mt, Cs, Ct = list_to_array(ms, mt, Cs, Ct) nx = get_backend(ms, mt, Cs, Ct) - is_input_finite = is_all_finite(ms, mt, Cs, Ct) Cs12 = nx.sqrtm(Cs) B = nx.trace(Cs + Ct - 2 * nx.sqrtm(dots(Cs12, Ct, Cs12))) W = nx.sqrt(nx.norm(ms - mt)**2 + B) - if is_input_finite and not nx.isfinite(W): - warnings.warn("Numerical errors encountered in ot.gaussian.bures_wasserstein_distance") - if log: log = {} log['Cs12'] = Cs12 diff --git a/test/test_gaussian.py b/test/test_gaussian.py index 09d534f8f..4e3c2df7b 100644 --- a/test/test_gaussian.py +++ b/test/test_gaussian.py @@ -71,11 +71,12 @@ def test_empirical_bures_wasserstein_mapping(nx, bias): np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2) -def test_empirical_bures_wasserstein_mapping_instabilities_warning(): - Xs = np.random.rand(2000, 1000) - Xt = np.random.rand(2000, 1000) +def test_empirical_bures_wasserstein_mapping_numerical_error_warning(): + rng = np.random.RandomState(42) + Xs = rng.rand(766, 800) * 5 + Xt = rng.rand(295, 800) * 2 with pytest.warns(): - A, b = ot.gaussian.empirical_bures_wasserstein_mapping(Xs, Xt) + A, b = ot.gaussian.empirical_bures_wasserstein_mapping(Xs, Xt, reg=1e-8) assert not is_all_finite(A, b)