Skip to content

Commit eb95719

Browse files
author
William de Vazelhes
committed
ENH: Deal better with errors and skggm/scikit-learn
1 parent 60866cb commit eb95719

File tree

2 files changed

+147
-36
lines changed

2 files changed

+147
-36
lines changed

metric_learn/sdml.py

Lines changed: 37 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -60,14 +60,11 @@ def __init__(self, balance_param=0.5, sparsity_param=0.01, use_cov=True,
6060

6161
def _fit(self, pairs, y):
6262
if not HAS_SKGGM:
63-
msg = ("Warning, skggm is not installed, so SDML will use "
64-
"scikit-learn's graphical_lasso method. It can fail to converge"
65-
"on some non SPD matrices where skggm would converge. If so, "
66-
"try to install skggm. (see the README.md for the right "
67-
"version.)")
68-
warnings.warn(msg)
63+
if self.verbose:
64+
print("SDML will use scikit-learn's graphical lasso solver.")
6965
else:
70-
print("SDML will use skggm's solver.")
66+
if self.verbose:
67+
print("SDML will use skggm's graphical lasso solver.")
7168
pairs, y = self._prepare_inputs(pairs, y,
7269
type_of_inputs='tuples')
7370

@@ -93,15 +90,39 @@ def _fit(self, pairs, y):
9390
"`balance_param` and/or to set use_covariance=False.",
9491
ConvergenceWarning)
9592
sigma0 = (V * (w - min(0, np.min(w)) + 1e-10)).dot(V.T)
96-
if HAS_SKGGM:
97-
theta0 = pinvh(sigma0)
98-
M, _, _, _, _, _ = quic(emp_cov, lam=self.sparsity_param,
99-
msg=self.verbose,
100-
Theta0=theta0, Sigma0=sigma0)
101-
else:
102-
_, M = graphical_lasso(emp_cov, alpha=self.sparsity_param,
103-
verbose=self.verbose,
104-
cov_init=sigma0)
93+
try:
94+
if HAS_SKGGM:
95+
theta0 = pinvh(sigma0)
96+
M, _, _, _, _, _ = quic(emp_cov, lam=self.sparsity_param,
97+
msg=self.verbose,
98+
Theta0=theta0, Sigma0=sigma0)
99+
else:
100+
_, M = graphical_lasso(emp_cov, alpha=self.sparsity_param,
101+
verbose=self.verbose,
102+
cov_init=sigma0)
103+
raised_error = None
104+
w_mahalanobis, _ = np.linalg.eigh(M)
105+
not_spd = any(w_mahalanobis < 0.)
106+
not_finite = not np.isfinite(M).all()
107+
except Exception as e:
108+
raised_error = e
109+
not_spd = False # not_spd not applicable here so we set to False
110+
not_finite = False # not_finite not applicable here so we set to False
111+
if raised_error is not None or not_spd or not_finite:
112+
msg = ("There was a problem in SDML when using {}'s graphical "
113+
"lasso solver.").format("skggm" if HAS_SKGGM else "scikit-learn")
114+
if not HAS_SKGGM:
115+
skggm_advice = (" skggm's graphical lasso can sometimes converge "
116+
"on non SPD cases where scikit-learn's graphical "
117+
"lasso fails to converge. Try to install skggm and "
118+
"rerun the algorithm (see the README.md for the "
119+
"right version of skggm).")
120+
msg += skggm_advice
121+
if raised_error is not None:
122+
msg += " The following error message was thrown: {}.".format(
123+
raised_error)
124+
raise RuntimeError(msg)
125+
105126
self.transformer_ = transformer_from_metric(np.atleast_2d(M))
106127
return self
107128

test/metric_learn_test.py

Lines changed: 110 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -155,28 +155,89 @@ def test_no_twice_same_objective(capsys):
155155
class TestSDML(MetricTestCase):
156156

157157
@pytest.mark.skipif(HAS_SKGGM,
158-
reason="The warning will be thrown only if skggm is "
158+
reason="The warning can be thrown only if skggm is "
159159
"not installed.")
160-
def test_raises_warning_msg_not_installed_skggm(self):
160+
def test_sdml_supervised_raises_warning_msg_not_installed_skggm(self):
161161
"""Tests that the right warning message is raised if someone tries to
162-
use SDML but has not installed skggm"""
162+
use SDML_Supervised but has not installed skggm, and that the algorithm
163+
fails to converge"""
163164
# TODO: remove if we don't need skggm anymore
164-
pairs = np.array([[[-10., 0.], [10., 0.]], [[0., -55.], [0., -60]]])
165+
# load_iris: dataset where we know scikit-learn's graphical lasso fails
166+
# with a Floating Point error
167+
X, y = load_iris(return_X_y=True)
168+
sdml_supervised = SDML_Supervised(balance_param=0.5, use_cov=True,
169+
sparsity_param=0.01)
170+
msg = ("There was a problem in SDML when using scikit-learn's graphical "
171+
"lasso solver. skggm's graphical lasso can sometimes converge on "
172+
"non SPD cases where scikit-learn's graphical lasso fails to "
173+
"converge. Try to install skggm and rerun the algorithm (see "
174+
"the README.md for the right version of skggm). The following "
175+
"error message was thrown:")
176+
with pytest.raises(RuntimeError) as raised_error:
177+
sdml_supervised.fit(X, y)
178+
assert str(raised_error.value).startswith(msg)
179+
180+
@pytest.mark.skipif(HAS_SKGGM,
181+
reason="The warning can be thrown only if skggm is "
182+
"not installed.")
183+
def test_sdml_raises_warning_msg_not_installed_skggm(self):
184+
"""Tests that the right warning message is raised if someone tries to
185+
use SDML but has not installed skggm, and that the algorithm fails to
186+
converge"""
187+
# TODO: remove if we don't need skggm anymore
188+
# case on which we know that scikit-learn's graphical lasso fails
189+
# because it will return a non SPD matrix
190+
pairs = np.array([[[-10., 0.], [10., 0.]], [[0., 50.], [0., -60]]])
165191
y_pairs = [1, -1]
166-
X, y = make_classification(random_state=42)
167-
sdml = SDML()
168-
sdml_supervised = SDML_Supervised(use_cov=False, balance_param=1e-5)
169-
msg = ("Warning, skggm is not installed, so SDML will use "
170-
"scikit-learn's graphical_lasso method. It can fail to converge"
171-
"on some non SPD matrices where skggm would converge. If so, "
172-
"try to install skggm. (see the README.md for the right "
173-
"version.)")
174-
with pytest.warns(None) as record:
192+
sdml = SDML(use_cov=False, balance_param=100, verbose=True)
193+
194+
msg = ("There was a problem in SDML when using scikit-learn's graphical "
195+
"lasso solver. skggm's graphical lasso can sometimes converge on "
196+
"non SPD cases where scikit-learn's graphical lasso fails to "
197+
"converge. Try to install skggm and rerun the algorithm (see "
198+
"the README.md for the right version of skggm).")
199+
with pytest.raises(RuntimeError) as raised_error:
175200
sdml.fit(pairs, y_pairs)
176-
assert str(record[0].message) == msg
177-
with pytest.warns(None) as record:
201+
assert msg == str(raised_error.value)
202+
203+
@pytest.mark.skipif(not HAS_SKGGM,
204+
reason="The warning can be thrown only if skggm is "
205+
"installed.")
206+
def test_sdml_raises_warning_msg_installed_skggm(self):
207+
"""Tests that the right warning message is raised if someone tries to
208+
use SDML with skggm installed, and that the algorithm fails to
209+
converge"""
210+
# TODO: remove if we don't need skggm anymore
211+
# case on which we know that skggm's graphical lasso fails
212+
# because it will return non finite values
213+
pairs = np.array([[[-10., 0.], [10., 0.]], [[0., 50.], [0., -60]]])
214+
y_pairs = [1, -1]
215+
sdml = SDML(use_cov=False, balance_param=100, verbose=True)
216+
217+
msg = ("There was a problem in SDML when using skggm's graphical "
218+
"lasso solver.")
219+
with pytest.raises(RuntimeError) as raised_error:
220+
sdml.fit(pairs, y_pairs)
221+
assert msg == str(raised_error.value)
222+
223+
@pytest.mark.skipif(not HAS_SKGGM,
224+
reason="The warning can be thrown only if skggm is "
225+
"installed.")
226+
def test_sdml_supervised_raises_warning_msg_installed_skggm(self):
227+
"""Tests that the right warning message is raised if someone tries to
228+
use SDML_Supervised with skggm installed, and that the algorithm
229+
fails to converge"""
230+
# TODO: remove if we don't need skggm anymore
231+
# case on which we know that skggm's graphical lasso fails
232+
# because it will return non finite values
233+
X, y = load_iris(return_X_y=True)
234+
sdml_supervised = SDML_Supervised(balance_param=0.5, use_cov=True,
235+
sparsity_param=0.01)
236+
msg = ("There was a problem in SDML when using skggm's graphical "
237+
"lasso solver.")
238+
with pytest.raises(RuntimeError) as raised_error:
178239
sdml_supervised.fit(X, y)
179-
assert str(record[0].message) == msg
240+
assert msg == str(raised_error.value)
180241

181242
@pytest.mark.skipif(not HAS_SKGGM,
182243
reason="It's only in the case where skggm is installed"
@@ -271,10 +332,10 @@ def test_verbose_has_installed_skggm_sdml(capsys):
271332
# TODO: remove if we don't need skggm anymore
272333
pairs = np.array([[[-10., 0.], [10., 0.]], [[0., -55.], [0., -60]]])
273334
y_pairs = [1, -1]
274-
sdml = SDML()
335+
sdml = SDML(verbose=True)
275336
sdml.fit(pairs, y_pairs)
276337
out, _ = capsys.readouterr()
277-
assert "SDML will use skggm's solver." in out
338+
assert "SDML will use skggm's graphical lasso solver." in out
278339

279340

280341
@pytest.mark.skipif(not HAS_SKGGM,
@@ -285,10 +346,39 @@ def test_verbose_has_installed_skggm_sdml_supervised(capsys):
285346
# skggm's solver is used (when they use SDML_Supervised)
286347
# TODO: remove if we don't need skggm anymore
287348
X, y = make_classification(random_state=42)
288-
sdml = SDML_Supervised()
349+
sdml = SDML_Supervised(verbose=True)
350+
sdml.fit(X, y)
351+
out, _ = capsys.readouterr()
352+
assert "SDML will use skggm's graphical lasso solver." in out
353+
354+
355+
@pytest.mark.skipif(HAS_SKGGM,
356+
reason='The message should be printed only if skggm is '
357+
'not installed.')
358+
def test_verbose_has_not_installed_skggm_sdml(capsys):
359+
# Test that if users have not installed skggm, a message is printed telling
360+
# them scikit-learn's solver is used (when they use SDML)
361+
# TODO: remove if we don't need skggm anymore
362+
pairs = np.array([[[-10., 0.], [10., 0.]], [[0., -55.], [0., -60]]])
363+
y_pairs = [1, -1]
364+
sdml = SDML(verbose=True)
365+
sdml.fit(pairs, y_pairs)
366+
out, _ = capsys.readouterr()
367+
assert "SDML will use scikit-learn's graphical lasso solver." in out
368+
369+
370+
@pytest.mark.skipif(HAS_SKGGM,
371+
reason='The message should be printed only if skggm is '
372+
'not installed.')
373+
def test_verbose_has_not_installed_skggm_sdml_supervised(capsys):
374+
# Test that if users have not installed skggm, a message is printed telling
375+
# them scikit-learn's solver is used (when they use SDML_Supervised)
376+
# TODO: remove if we don't need skggm anymore
377+
X, y = make_classification(random_state=42)
378+
sdml = SDML_Supervised(verbose=True, balance_param=1e-5, use_cov=False)
289379
sdml.fit(X, y)
290380
out, _ = capsys.readouterr()
291-
assert "SDML will use skggm's solver." in out
381+
assert "SDML will use scikit-learn's graphical lasso solver." in out
292382

293383

294384
class TestNCA(MetricTestCase):

0 commit comments

Comments
 (0)