@@ -155,28 +155,89 @@ def test_no_twice_same_objective(capsys):
155
155
class TestSDML (MetricTestCase ):
156
156
157
157
@pytest .mark .skipif (HAS_SKGGM ,
158
- reason = "The warning will be thrown only if skggm is "
158
+ reason = "The warning can be thrown only if skggm is "
159
159
"not installed." )
160
- def test_raises_warning_msg_not_installed_skggm (self ):
160
+ def test_sdml_supervised_raises_warning_msg_not_installed_skggm (self ):
161
161
"""Tests that the right warning message is raised if someone tries to
162
- use SDML but has not installed skggm"""
162
+ use SDML_Supervised but has not installed skggm, and that the algorithm
163
+ fails to converge"""
163
164
# TODO: remove if we don't need skggm anymore
164
- pairs = np .array ([[[- 10. , 0. ], [10. , 0. ]], [[0. , - 55. ], [0. , - 60 ]]])
165
+ # load_iris: dataset where we know scikit-learn's graphical lasso fails
166
+ # with a Floating Point error
167
+ X , y = load_iris (return_X_y = True )
168
+ sdml_supervised = SDML_Supervised (balance_param = 0.5 , use_cov = True ,
169
+ sparsity_param = 0.01 )
170
+ msg = ("There was a problem in SDML when using scikit-learn's graphical "
171
+ "lasso solver. skggm's graphical lasso can sometimes converge on "
172
+ "non SPD cases where scikit-learn's graphical lasso fails to "
173
+ "converge. Try to install skggm and rerun the algorithm (see "
174
+ "the README.md for the right version of skggm). The following "
175
+ "error message was thrown:" )
176
+ with pytest .raises (RuntimeError ) as raised_error :
177
+ sdml_supervised .fit (X , y )
178
+ assert str (raised_error .value ).startswith (msg )
179
+
180
+ @pytest .mark .skipif (HAS_SKGGM ,
181
+ reason = "The warning can be thrown only if skggm is "
182
+ "not installed." )
183
+ def test_sdml_raises_warning_msg_not_installed_skggm (self ):
184
+ """Tests that the right warning message is raised if someone tries to
185
+ use SDML but has not installed skggm, and that the algorithm fails to
186
+ converge"""
187
+ # TODO: remove if we don't need skggm anymore
188
+ # case on which we know that scikit-learn's graphical lasso fails
189
+ # because it will return a non SPD matrix
190
+ pairs = np .array ([[[- 10. , 0. ], [10. , 0. ]], [[0. , 50. ], [0. , - 60 ]]])
165
191
y_pairs = [1 , - 1 ]
166
- X , y = make_classification (random_state = 42 )
167
- sdml = SDML ()
168
- sdml_supervised = SDML_Supervised (use_cov = False , balance_param = 1e-5 )
169
- msg = ("Warning, skggm is not installed, so SDML will use "
170
- "scikit-learn's graphical_lasso method. It can fail to converge"
171
- "on some non SPD matrices where skggm would converge. If so, "
172
- "try to install skggm. (see the README.md for the right "
173
- "version.)" )
174
- with pytest .warns (None ) as record :
192
+ sdml = SDML (use_cov = False , balance_param = 100 , verbose = True )
193
+
194
+ msg = ("There was a problem in SDML when using scikit-learn's graphical "
195
+ "lasso solver. skggm's graphical lasso can sometimes converge on "
196
+ "non SPD cases where scikit-learn's graphical lasso fails to "
197
+ "converge. Try to install skggm and rerun the algorithm (see "
198
+ "the README.md for the right version of skggm)." )
199
+ with pytest .raises (RuntimeError ) as raised_error :
175
200
sdml .fit (pairs , y_pairs )
176
- assert str (record [0 ].message ) == msg
177
- with pytest .warns (None ) as record :
201
+ assert msg == str (raised_error .value )
202
+
203
+ @pytest .mark .skipif (not HAS_SKGGM ,
204
+ reason = "The warning can be thrown only if skggm is "
205
+ "installed." )
206
+ def test_sdml_raises_warning_msg_installed_skggm (self ):
207
+ """Tests that the right warning message is raised if someone tries to
208
+ use SDML but has not installed skggm, and that the algorithm fails to
209
+ converge"""
210
+ # TODO: remove if we don't need skggm anymore
211
+ # case on which we know that skggm's graphical lasso fails
212
+ # because it will return non finite values
213
+ pairs = np .array ([[[- 10. , 0. ], [10. , 0. ]], [[0. , 50. ], [0. , - 60 ]]])
214
+ y_pairs = [1 , - 1 ]
215
+ sdml = SDML (use_cov = False , balance_param = 100 , verbose = True )
216
+
217
+ msg = ("There was a problem in SDML when using skggm's graphical "
218
+ "lasso solver." )
219
+ with pytest .raises (RuntimeError ) as raised_error :
220
+ sdml .fit (pairs , y_pairs )
221
+ assert msg == str (raised_error .value )
222
+
223
+ @pytest .mark .skipif (not HAS_SKGGM ,
224
+ reason = "The warning can be thrown only if skggm is "
225
+ "installed." )
226
+ def test_sdml_supervised_raises_warning_msg_installed_skggm (self ):
227
+ """Tests that the right warning message is raised if someone tries to
228
+ use SDML_Supervised but has not installed skggm, and that the algorithm
229
+ fails to converge"""
230
+ # TODO: remove if we don't need skggm anymore
231
+ # case on which we know that skggm's graphical lasso fails
232
+ # because it will return non finite values
233
+ X , y = load_iris (return_X_y = True )
234
+ sdml_supervised = SDML_Supervised (balance_param = 0.5 , use_cov = True ,
235
+ sparsity_param = 0.01 )
236
+ msg = ("There was a problem in SDML when using skggm's graphical "
237
+ "lasso solver." )
238
+ with pytest .raises (RuntimeError ) as raised_error :
178
239
sdml_supervised .fit (X , y )
179
- assert str ( record [ 0 ]. message ) == msg
240
+ assert msg == str ( raised_error . value )
180
241
181
242
@pytest .mark .skipif (not HAS_SKGGM ,
182
243
reason = "It's only in the case where skggm is installed"
@@ -271,10 +332,10 @@ def test_verbose_has_installed_skggm_sdml(capsys):
271
332
# TODO: remove if we don't need skggm anymore
272
333
pairs = np .array ([[[- 10. , 0. ], [10. , 0. ]], [[0. , - 55. ], [0. , - 60 ]]])
273
334
y_pairs = [1 , - 1 ]
274
- sdml = SDML ()
335
+ sdml = SDML (verbose = True )
275
336
sdml .fit (pairs , y_pairs )
276
337
out , _ = capsys .readouterr ()
277
- assert "SDML will use skggm's solver." in out
338
+ assert "SDML will use skggm's graphical lasso solver." in out
278
339
279
340
280
341
@pytest .mark .skipif (not HAS_SKGGM ,
@@ -285,10 +346,39 @@ def test_verbose_has_installed_skggm_sdml_supervised(capsys):
285
346
# skggm's solver is used (when they use SDML_Supervised)
286
347
# TODO: remove if we don't need skggm anymore
287
348
X , y = make_classification (random_state = 42 )
288
- sdml = SDML_Supervised ()
349
+ sdml = SDML_Supervised (verbose = True )
350
+ sdml .fit (X , y )
351
+ out , _ = capsys .readouterr ()
352
+ assert "SDML will use skggm's graphical lasso solver." in out
353
+
354
+
355
+ @pytest .mark .skipif (HAS_SKGGM ,
356
+ reason = 'The message should be printed only if skggm is '
357
+ 'not installed.' )
358
+ def test_verbose_has_not_installed_skggm_sdml (capsys ):
359
+ # Test that if users have installed skggm, a message is printed telling them
360
+ # skggm's solver is used (when they use SDML)
361
+ # TODO: remove if we don't need skggm anymore
362
+ pairs = np .array ([[[- 10. , 0. ], [10. , 0. ]], [[0. , - 55. ], [0. , - 60 ]]])
363
+ y_pairs = [1 , - 1 ]
364
+ sdml = SDML (verbose = True )
365
+ sdml .fit (pairs , y_pairs )
366
+ out , _ = capsys .readouterr ()
367
+ assert "SDML will use scikit-learn's graphical lasso solver." in out
368
+
369
+
370
+ @pytest .mark .skipif (HAS_SKGGM ,
371
+ reason = 'The message should be printed only if skggm is '
372
+ 'not installed.' )
373
+ def test_verbose_has_not_installed_skggm_sdml_supervised (capsys ):
374
+ # Test that if users have installed skggm, a message is printed telling them
375
+ # skggm's solver is used (when they use SDML_Supervised)
376
+ # TODO: remove if we don't need skggm anymore
377
+ X , y = make_classification (random_state = 42 )
378
+ sdml = SDML_Supervised (verbose = True , balance_param = 1e-5 , use_cov = False )
289
379
sdml .fit (X , y )
290
380
out , _ = capsys .readouterr ()
291
- assert "SDML will use skggm's solver." in out
381
+ assert "SDML will use scikit-learn's graphical lasso solver." in out
292
382
293
383
294
384
class TestNCA (MetricTestCase ):
0 commit comments