Skip to content

Commit 9dd4d8d

Browse files
authored
[MRG] Issue warning when numerical error happens in empirical_bures_wasserstein_mapping (#503)
* Initial implementation * nx.all is not defined, switching to nx.any * Rename warnings to be consistent with numpy * Finally found example where empirical cov is ill-formed
1 parent 579c7b6 commit 9dd4d8d

File tree

3 files changed

+30
-3
lines changed

3 files changed

+30
-3
lines changed

ot/gaussian.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,10 @@
88
#
99
# License: MIT License
1010

11+
import warnings
12+
1113
from .backend import get_backend
12-
from .utils import dots
13-
from .utils import list_to_array
14+
from .utils import dots, is_all_finite, list_to_array
1415

1516

1617
def bures_wasserstein_mapping(ms, mt, Cs, Ct, log=False):
@@ -155,6 +156,7 @@ def empirical_bures_wasserstein_mapping(xs, xt, reg=1e-6, ws=None,
155156
"""
156157
xs, xt = list_to_array(xs, xt)
157158
nx = get_backend(xs, xt)
159+
is_input_finite = is_all_finite(xs, xt)
158160

159161
d = xs.shape[1]
160162

@@ -179,11 +181,19 @@ def empirical_bures_wasserstein_mapping(xs, xt, reg=1e-6, ws=None,
179181

180182
if log:
181183
A, b, log = bures_wasserstein_mapping(mxs, mxt, Cs, Ct, log=log)
184+
else:
185+
A, b = bures_wasserstein_mapping(mxs, mxt, Cs, Ct)
186+
187+
if is_input_finite and not is_all_finite(A, b):
188+
warnings.warn(
189+
"Numerical errors were encountered in ot.gaussian.empirical_bures_wasserstein_mapping. "
190+
"Consider increasing the regularization parameter `reg`.")
191+
192+
if log:
182193
log['Cs'] = Cs
183194
log['Ct'] = Ct
184195
return A, b, log
185196
else:
186-
A, b = bures_wasserstein_mapping(mxs, mxt, Cs, Ct)
187197
return A, b
188198

189199

@@ -240,6 +250,7 @@ def bures_wasserstein_distance(ms, mt, Cs, Ct, log=False):
240250

241251
B = nx.trace(Cs + Ct - 2 * nx.sqrtm(dots(Cs12, Ct, Cs12)))
242252
W = nx.sqrt(nx.norm(ms - mt)**2 + B)
253+
243254
if log:
244255
log = {}
245256
log['Cs12'] = Cs12

ot/utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,12 @@ def dots(*args):
384384
return reduce(nx.dot, args)
385385

386386

387+
def is_all_finite(*args):
388+
r"""Tests element-wise for finiteness in all arguments."""
389+
nx = get_backend(*args)
390+
return all(not nx.any(~nx.isfinite(arg)) for arg in args)
391+
392+
387393
def label_normalization(y, start=0):
388394
r""" Transform labels to start at a given value
389395

test/test_gaussian.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import ot
1313
from ot.datasets import make_data_classif
14+
from ot.utils import is_all_finite
1415

1516

1617
def test_bures_wasserstein_mapping(nx):
@@ -70,6 +71,15 @@ def test_empirical_bures_wasserstein_mapping(nx, bias):
7071
np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2)
7172

7273

74+
def test_empirical_bures_wasserstein_mapping_numerical_error_warning():
75+
rng = np.random.RandomState(42)
76+
Xs = rng.rand(766, 800) * 5
77+
Xt = rng.rand(295, 800) * 2
78+
with pytest.warns():
79+
A, b = ot.gaussian.empirical_bures_wasserstein_mapping(Xs, Xt, reg=1e-8)
80+
assert not is_all_finite(A, b)
81+
82+
7383
def test_bures_wasserstein_distance(nx):
7484
ms, mt = np.array([0]), np.array([10])
7585
Cs, Ct = np.array([[1]]).astype(np.float32), np.array([[1]]).astype(np.float32)

0 commit comments

Comments
 (0)