Skip to content

Commit 49f3b9e

Browse files
author
William de Vazelhes
committed
Put back right indent
1 parent b3bf6a8 commit 49f3b9e

File tree

1 file changed

+104
-104
lines changed

1 file changed

+104
-104
lines changed

test/metric_learn_test.py

Lines changed: 104 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -150,112 +150,112 @@ def test_no_twice_same_objective(capsys):
150150

151151
class TestSDML(MetricTestCase):
152152

153-
@pytest.mark.skipif(has_installed_skggm(),
154-
reason="The warning will be thrown only if skggm is "
155-
"not installed.")
156-
def test_raises_warning_msg_not_installed_skggm(self):
157-
"""Tests that the right warning message is raised if someone tries to
158-
use SDML but has not installed skggm"""
159-
# TODO: remove if we don't need skggm anymore
160-
pairs = np.array([[[-10., 0.], [10., 0.]], [[0., -55.], [0., -60]]])
161-
y_pairs = [1, -1]
162-
X, y = make_classification(random_state=42)
153+
@pytest.mark.skipif(has_installed_skggm(),
154+
reason="The warning will be thrown only if skggm is "
155+
"not installed.")
156+
def test_raises_warning_msg_not_installed_skggm(self):
157+
"""Tests that the right warning message is raised if someone tries to
158+
use SDML but has not installed skggm"""
159+
# TODO: remove if we don't need skggm anymore
160+
pairs = np.array([[[-10., 0.], [10., 0.]], [[0., -55.], [0., -60]]])
161+
y_pairs = [1, -1]
162+
X, y = make_classification(random_state=42)
163+
sdml = SDML()
164+
sdml_supervised = SDML_Supervised(use_cov=False, balance_param=1e-5)
165+
msg = ("Warning, skggm is not installed, so SDML will use "
166+
"scikit-learn's graphical_lasso method. It can fail to converge"
167+
"on some non SPD matrices where skggm would converge. If so, "
168+
"try to install skggm. (see the README.md for the right "
169+
"version.)")
170+
with pytest.warns(None) as record:
171+
sdml.fit(pairs, y_pairs)
172+
assert str(record[0].message) == msg
173+
with pytest.warns(None) as record:
174+
sdml_supervised.fit(X, y)
175+
assert str(record[0].message) == msg
176+
177+
@pytest.mark.skipif(not has_installed_skggm(),
178+
reason="It's only in the case where skggm is installed"
179+
"that no warning should be thrown.")
180+
def test_raises_no_warning_installed_skggm(self):
181+
# otherwise we should be able to instantiate and fit SDML and it
182+
# should raise no warning
183+
pairs = np.array([[[-10., 0.], [10., 0.]], [[0., -55.], [0., -60]]])
184+
y_pairs = [1, -1]
185+
X, y = make_classification(random_state=42)
186+
with pytest.warns(None) as record:
163187
sdml = SDML()
164-
sdml_supervised = SDML_Supervised(use_cov=False, balance_param=1e-5)
165-
msg = ("Warning, skggm is not installed, so SDML will use "
166-
"scikit-learn's graphical_lasso method. It can fail to converge"
167-
"on some non SPD matrices where skggm would converge. If so, "
168-
"try to install skggm. (see the README.md for the right "
169-
"version.)")
170-
with pytest.warns(None) as record:
171-
sdml.fit(pairs, y_pairs)
172-
assert str(record[0].message) == msg
173-
with pytest.warns(None) as record:
174-
sdml_supervised.fit(X, y)
175-
assert str(record[0].message) == msg
176-
177-
@pytest.mark.skipif(not has_installed_skggm(),
178-
reason="It's only in the case where skggm is installed"
179-
"that no warning should be thrown.")
180-
def test_raises_no_warning_installed_skggm(self):
181-
# otherwise we should be able to instantiate and fit SDML and it
182-
# should raise no warning
183-
pairs = np.array([[[-10., 0.], [10., 0.]], [[0., -55.], [0., -60]]])
184-
y_pairs = [1, -1]
185-
X, y = make_classification(random_state=42)
186-
with pytest.warns(None) as record:
187-
sdml = SDML()
188-
sdml.fit(pairs, y_pairs)
189-
assert len(record) == 0
190-
with pytest.warns(None) as record:
191-
sdml = SDML_Supervised(use_cov=False, balance_param=1e-5)
192-
sdml.fit(X, y)
193-
assert len(record) == 0
194-
195-
def test_iris(self):
196-
# Note: this is a flaky test, which fails for certain seeds.
197-
# TODO: un-flake it!
198-
rs = np.random.RandomState(5555)
199-
200-
sdml = SDML_Supervised(num_constraints=1500, use_cov=False,
201-
balance_param=5e-5)
202-
sdml.fit(self.iris_points, self.iris_labels, random_state=rs)
203-
csep = class_separation(sdml.transform(self.iris_points),
204-
self.iris_labels)
205-
self.assertLess(csep, 0.22)
206-
207-
def test_deprecation_num_labeled(self):
208-
# test that a deprecation message is thrown if num_labeled is set at
209-
# initialization
210-
# TODO: remove in v.0.6
211-
X, y = make_classification(random_state=42)
212-
sdml_supervised = SDML_Supervised(num_labeled=np.inf, use_cov=False,
213-
balance_param=5e-5)
214-
msg = ('"num_labeled" parameter is not used.'
215-
' It has been deprecated in version 0.5.0 and will be'
216-
'removed in 0.6.0')
217-
assert_warns_message(DeprecationWarning, msg, sdml_supervised.fit, X, y)
218-
219-
def test_sdml_raises_warning_non_psd(self):
220-
"""Tests that SDML raises a warning on a toy example where we know the
221-
pseudo-covariance matrix is not PSD"""
222-
pairs = np.array([[[-10., 0.], [10., 0.]], [[0., 50.], [0., -60]]])
223-
y = [1, -1]
224-
sdml = SDML(use_cov=True, sparsity_param=0.01, balance_param=0.5)
225-
msg = ("Warning, the input matrix of graphical lasso is not "
226-
"positive semi-definite (PSD). The algorithm may diverge, "
227-
"and lead to degenerate solutions. "
228-
"To prevent that, try to decrease the balance parameter "
229-
"`balance_param` and/or to set use_covariance=False.")
230-
with pytest.warns(ConvergenceWarning) as raised_warning:
231-
try:
232-
sdml.fit(pairs, y)
233-
except Exception:
234-
pass
235-
# we assert that this warning is in one of the warning raised by the
236-
# estimator
237-
assert msg in list(map(lambda w: str(w.message), raised_warning))
238-
239-
def test_sdml_converges_if_psd(self):
240-
"""Tests that sdml converges on a simple problem where we know the
241-
pseudo-covariance matrix is PSD"""
242-
pairs = np.array([[[-10., 0.], [10., 0.]], [[0., -55.], [0., -60]]])
243-
y = [1, -1]
244-
sdml = SDML(use_cov=True, sparsity_param=0.01, balance_param=0.5)
245-
sdml.fit(pairs, y)
246-
assert np.isfinite(sdml.get_mahalanobis_matrix()).all()
247-
248-
@pytest.mark.skipif(not has_installed_skggm(),
249-
reason="sklearn's graphical_lasso can sometimes not "
250-
"work on some non SPD problems. We test that "
251-
"is works only if skggm is installed.")
252-
def test_sdml_works_on_non_spd_pb_with_skggm(self):
253-
"""Test that SDML works on a certain non SPD problem on which we know
254-
it should work, but scikit-learn's graphical_lasso does not work"""
255-
X, y = load_iris(return_X_y=True)
256-
sdml = SDML_Supervised(balance_param=0.5, sparsity_param=0.01,
257-
use_cov=True)
188+
sdml.fit(pairs, y_pairs)
189+
assert len(record) == 0
190+
with pytest.warns(None) as record:
191+
sdml = SDML_Supervised(use_cov=False, balance_param=1e-5)
258192
sdml.fit(X, y)
193+
assert len(record) == 0
194+
195+
def test_iris(self):
196+
# Note: this is a flaky test, which fails for certain seeds.
197+
# TODO: un-flake it!
198+
rs = np.random.RandomState(5555)
199+
200+
sdml = SDML_Supervised(num_constraints=1500, use_cov=False,
201+
balance_param=5e-5)
202+
sdml.fit(self.iris_points, self.iris_labels, random_state=rs)
203+
csep = class_separation(sdml.transform(self.iris_points),
204+
self.iris_labels)
205+
self.assertLess(csep, 0.22)
206+
207+
def test_deprecation_num_labeled(self):
208+
# test that a deprecation message is thrown if num_labeled is set at
209+
# initialization
210+
# TODO: remove in v.0.6
211+
X, y = make_classification(random_state=42)
212+
sdml_supervised = SDML_Supervised(num_labeled=np.inf, use_cov=False,
213+
balance_param=5e-5)
214+
msg = ('"num_labeled" parameter is not used.'
215+
' It has been deprecated in version 0.5.0 and will be'
216+
'removed in 0.6.0')
217+
assert_warns_message(DeprecationWarning, msg, sdml_supervised.fit, X, y)
218+
219+
def test_sdml_raises_warning_non_psd(self):
220+
"""Tests that SDML raises a warning on a toy example where we know the
221+
pseudo-covariance matrix is not PSD"""
222+
pairs = np.array([[[-10., 0.], [10., 0.]], [[0., 50.], [0., -60]]])
223+
y = [1, -1]
224+
sdml = SDML(use_cov=True, sparsity_param=0.01, balance_param=0.5)
225+
msg = ("Warning, the input matrix of graphical lasso is not "
226+
"positive semi-definite (PSD). The algorithm may diverge, "
227+
"and lead to degenerate solutions. "
228+
"To prevent that, try to decrease the balance parameter "
229+
"`balance_param` and/or to set use_covariance=False.")
230+
with pytest.warns(ConvergenceWarning) as raised_warning:
231+
try:
232+
sdml.fit(pairs, y)
233+
except Exception:
234+
pass
235+
# we assert that this warning is in one of the warning raised by the
236+
# estimator
237+
assert msg in list(map(lambda w: str(w.message), raised_warning))
238+
239+
def test_sdml_converges_if_psd(self):
240+
"""Tests that sdml converges on a simple problem where we know the
241+
pseudo-covariance matrix is PSD"""
242+
pairs = np.array([[[-10., 0.], [10., 0.]], [[0., -55.], [0., -60]]])
243+
y = [1, -1]
244+
sdml = SDML(use_cov=True, sparsity_param=0.01, balance_param=0.5)
245+
sdml.fit(pairs, y)
246+
assert np.isfinite(sdml.get_mahalanobis_matrix()).all()
247+
248+
@pytest.mark.skipif(not has_installed_skggm(),
249+
reason="sklearn's graphical_lasso can sometimes not "
250+
"work on some non SPD problems. We test that "
251+
"is works only if skggm is installed.")
252+
def test_sdml_works_on_non_spd_pb_with_skggm(self):
253+
"""Test that SDML works on a certain non SPD problem on which we know
254+
it should work, but scikit-learn's graphical_lasso does not work"""
255+
X, y = load_iris(return_X_y=True)
256+
sdml = SDML_Supervised(balance_param=0.5, sparsity_param=0.01,
257+
use_cov=True)
258+
sdml.fit(X, y)
259259

260260

261261
@pytest.mark.skipif(not has_installed_skggm(),

0 commit comments

Comments
 (0)