diff --git a/ot/gaussian.py b/ot/gaussian.py index e83d5eee8..708f9eb16 100644 --- a/ot/gaussian.py +++ b/ot/gaussian.py @@ -8,9 +8,10 @@ # # License: MIT License +import warnings + from .backend import get_backend -from .utils import dots -from .utils import list_to_array +from .utils import dots, is_all_finite, list_to_array def bures_wasserstein_mapping(ms, mt, Cs, Ct, log=False): @@ -155,6 +156,7 @@ def empirical_bures_wasserstein_mapping(xs, xt, reg=1e-6, ws=None, """ xs, xt = list_to_array(xs, xt) nx = get_backend(xs, xt) + is_input_finite = is_all_finite(xs, xt) d = xs.shape[1] @@ -179,11 +181,19 @@ def empirical_bures_wasserstein_mapping(xs, xt, reg=1e-6, ws=None, if log: A, b, log = bures_wasserstein_mapping(mxs, mxt, Cs, Ct, log=log) + else: + A, b = bures_wasserstein_mapping(mxs, mxt, Cs, Ct) + + if is_input_finite and not is_all_finite(A, b): + warnings.warn( + "Numerical errors were encountered in ot.gaussian.empirical_bures_wasserstein_mapping. " + "Consider increasing the regularization parameter `reg`.") + + if log: log['Cs'] = Cs log['Ct'] = Ct return A, b, log else: - A, b = bures_wasserstein_mapping(mxs, mxt, Cs, Ct) return A, b @@ -240,6 +250,7 @@ def bures_wasserstein_distance(ms, mt, Cs, Ct, log=False): B = nx.trace(Cs + Ct - 2 * nx.sqrtm(dots(Cs12, Ct, Cs12))) W = nx.sqrt(nx.norm(ms - mt)**2 + B) + if log: log = {} log['Cs12'] = Cs12 diff --git a/ot/utils.py b/ot/utils.py index 1559eaaa6..72df4294f 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -384,6 +384,12 @@ def dots(*args): return reduce(nx.dot, args) +def is_all_finite(*args): + r"""Tests element-wise for finiteness in all arguments.""" + nx = get_backend(*args) + return all(not nx.any(~nx.isfinite(arg)) for arg in args) + + def label_normalization(y, start=0): r""" Transform labels to start at a given value diff --git a/test/test_gaussian.py b/test/test_gaussian.py index 5a021d004..4e3c2df7b 100644 --- a/test/test_gaussian.py +++ b/test/test_gaussian.py @@ -11,6 +11,7 @@ import ot from ot.datasets import make_data_classif +from ot.utils import is_all_finite def test_bures_wasserstein_mapping(nx): @@ -70,6 +71,15 @@ def test_empirical_bures_wasserstein_mapping(nx, bias): np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2) +def test_empirical_bures_wasserstein_mapping_numerical_error_warning(): + rng = np.random.RandomState(42) + Xs = rng.rand(766, 800) * 5 + Xt = rng.rand(295, 800) * 2 + with pytest.warns(): + A, b = ot.gaussian.empirical_bures_wasserstein_mapping(Xs, Xt, reg=1e-8) + assert not is_all_finite(A, b) + + def test_bures_wasserstein_distance(nx): ms, mt = np.array([0]), np.array([10]) Cs, Ct = np.array([[1]]).astype(np.float32), np.array([[1]]).astype(np.float32)