@@ -285,26 +285,6 @@ def test_deprecation_num_labeled(self):
285
285
'removed in 0.6.0' )
286
286
assert_warns_message (DeprecationWarning , msg , sdml_supervised .fit , X , y )
287
287
288
- def test_sdml_raises_warning_non_psd (self ):
289
- """Tests that SDML raises a warning on a toy example where we know the
290
- pseudo-covariance matrix is not PSD"""
291
- pairs = np .array ([[[- 10. , 0. ], [10. , 0. ]], [[0. , 50. ], [0. , - 60 ]]])
292
- y = [1 , - 1 ]
293
- sdml = SDML (use_cov = True , sparsity_param = 0.01 , balance_param = 0.5 )
294
- msg = ("Warning, the input matrix of graphical lasso is not "
295
- "positive semi-definite (PSD). The algorithm may diverge, "
296
- "and lead to degenerate solutions. "
297
- "To prevent that, try to decrease the balance parameter "
298
- "`balance_param` and/or to set use_covariance=False." )
299
- with pytest .warns (ConvergenceWarning ) as raised_warning :
300
- try :
301
- sdml .fit (pairs , y )
302
- except Exception :
303
- pass
304
- # we assert that this warning is in one of the warning raised by the
305
- # estimator
306
- assert msg in list (map (lambda w : str (w .message ), raised_warning ))
307
-
308
288
def test_sdml_converges_if_psd (self ):
309
289
"""Tests that sdml converges on a simple problem where we know the
310
290
pseudo-covariance matrix is PSD"""
@@ -385,6 +365,27 @@ def test_verbose_has_not_installed_skggm_sdml_supervised(capsys):
385
365
assert "SDML will use scikit-learn's graphical lasso solver." in out
386
366
387
367
368
+ def test_sdml_raises_warning_non_psd ():
369
+ """Tests that SDML raises a warning on a toy example where we know the
370
+ pseudo-covariance matrix is not PSD"""
371
+ pairs = np .array ([[[- 10. , 0. ], [10. , 0. ]], [[0. , 50. ], [0. , - 60 ]]])
372
+ y = [1 , - 1 ]
373
+ sdml = SDML (use_cov = True , sparsity_param = 0.01 , balance_param = 0.5 )
374
+ msg = ("Warning, the input matrix of graphical lasso is not "
375
+ "positive semi-definite (PSD). The algorithm may diverge, "
376
+ "and lead to degenerate solutions. "
377
+ "To prevent that, try to decrease the balance parameter "
378
+ "`balance_param` and/or to set use_covariance=False." )
379
+ with pytest .warns (ConvergenceWarning ) as raised_warning :
380
+ try :
381
+ sdml .fit (pairs , y )
382
+ except Exception :
383
+ pass
384
+ # we assert that this warning is in one of the warning raised by the
385
+ # estimator
386
+ assert msg in list (map (lambda w : str (w .message ), raised_warning ))
387
+
388
+
388
389
class TestNCA (MetricTestCase ):
389
390
def test_iris (self ):
390
391
n = self .iris_points .shape [0 ]
0 commit comments