@@ -150,112 +150,112 @@ def test_no_twice_same_objective(capsys):
150
150
151
151
class TestSDML (MetricTestCase ):
152
152
153
- @pytest .mark .skipif (has_installed_skggm (),
154
- reason = "The warning will be thrown only if skggm is "
155
- "not installed." )
156
- def test_raises_warning_msg_not_installed_skggm (self ):
157
- """Tests that the right warning message is raised if someone tries to
158
- use SDML but has not installed skggm"""
159
- # TODO: remove if we don't need skggm anymore
160
- pairs = np .array ([[[- 10. , 0. ], [10. , 0. ]], [[0. , - 55. ], [0. , - 60 ]]])
161
- y_pairs = [1 , - 1 ]
162
- X , y = make_classification (random_state = 42 )
153
+ @pytest .mark .skipif (has_installed_skggm (),
154
+ reason = "The warning will be thrown only if skggm is "
155
+ "not installed." )
156
+ def test_raises_warning_msg_not_installed_skggm (self ):
157
+ """Tests that the right warning message is raised if someone tries to
158
+ use SDML but has not installed skggm"""
159
+ # TODO: remove if we don't need skggm anymore
160
+ pairs = np .array ([[[- 10. , 0. ], [10. , 0. ]], [[0. , - 55. ], [0. , - 60 ]]])
161
+ y_pairs = [1 , - 1 ]
162
+ X , y = make_classification (random_state = 42 )
163
+ sdml = SDML ()
164
+ sdml_supervised = SDML_Supervised (use_cov = False , balance_param = 1e-5 )
165
+ msg = ("Warning, skggm is not installed, so SDML will use "
166
+ "scikit-learn's graphical_lasso method. It can fail to converge"
167
+ "on some non SPD matrices where skggm would converge. If so, "
168
+ "try to install skggm. (see the README.md for the right "
169
+ "version.)" )
170
+ with pytest .warns (None ) as record :
171
+ sdml .fit (pairs , y_pairs )
172
+ assert str (record [0 ].message ) == msg
173
+ with pytest .warns (None ) as record :
174
+ sdml_supervised .fit (X , y )
175
+ assert str (record [0 ].message ) == msg
176
+
177
+ @pytest .mark .skipif (not has_installed_skggm (),
178
+ reason = "It's only in the case where skggm is installed"
179
+ "that no warning should be thrown." )
180
+ def test_raises_no_warning_installed_skggm (self ):
181
+ # otherwise we should be able to instantiate and fit SDML and it
182
+ # should raise no warning
183
+ pairs = np .array ([[[- 10. , 0. ], [10. , 0. ]], [[0. , - 55. ], [0. , - 60 ]]])
184
+ y_pairs = [1 , - 1 ]
185
+ X , y = make_classification (random_state = 42 )
186
+ with pytest .warns (None ) as record :
163
187
sdml = SDML ()
164
- sdml_supervised = SDML_Supervised (use_cov = False , balance_param = 1e-5 )
165
- msg = ("Warning, skggm is not installed, so SDML will use "
166
- "scikit-learn's graphical_lasso method. It can fail to converge"
167
- "on some non SPD matrices where skggm would converge. If so, "
168
- "try to install skggm. (see the README.md for the right "
169
- "version.)" )
170
- with pytest .warns (None ) as record :
171
- sdml .fit (pairs , y_pairs )
172
- assert str (record [0 ].message ) == msg
173
- with pytest .warns (None ) as record :
174
- sdml_supervised .fit (X , y )
175
- assert str (record [0 ].message ) == msg
176
-
177
- @pytest .mark .skipif (not has_installed_skggm (),
178
- reason = "It's only in the case where skggm is installed"
179
- "that no warning should be thrown." )
180
- def test_raises_no_warning_installed_skggm (self ):
181
- # otherwise we should be able to instantiate and fit SDML and it
182
- # should raise no warning
183
- pairs = np .array ([[[- 10. , 0. ], [10. , 0. ]], [[0. , - 55. ], [0. , - 60 ]]])
184
- y_pairs = [1 , - 1 ]
185
- X , y = make_classification (random_state = 42 )
186
- with pytest .warns (None ) as record :
187
- sdml = SDML ()
188
- sdml .fit (pairs , y_pairs )
189
- assert len (record ) == 0
190
- with pytest .warns (None ) as record :
191
- sdml = SDML_Supervised (use_cov = False , balance_param = 1e-5 )
192
- sdml .fit (X , y )
193
- assert len (record ) == 0
194
-
195
- def test_iris (self ):
196
- # Note: this is a flaky test, which fails for certain seeds.
197
- # TODO: un-flake it!
198
- rs = np .random .RandomState (5555 )
199
-
200
- sdml = SDML_Supervised (num_constraints = 1500 , use_cov = False ,
201
- balance_param = 5e-5 )
202
- sdml .fit (self .iris_points , self .iris_labels , random_state = rs )
203
- csep = class_separation (sdml .transform (self .iris_points ),
204
- self .iris_labels )
205
- self .assertLess (csep , 0.22 )
206
-
207
- def test_deprecation_num_labeled (self ):
208
- # test that a deprecation message is thrown if num_labeled is set at
209
- # initialization
210
- # TODO: remove in v.0.6
211
- X , y = make_classification (random_state = 42 )
212
- sdml_supervised = SDML_Supervised (num_labeled = np .inf , use_cov = False ,
213
- balance_param = 5e-5 )
214
- msg = ('"num_labeled" parameter is not used.'
215
- ' It has been deprecated in version 0.5.0 and will be'
216
- 'removed in 0.6.0' )
217
- assert_warns_message (DeprecationWarning , msg , sdml_supervised .fit , X , y )
218
-
219
- def test_sdml_raises_warning_non_psd (self ):
220
- """Tests that SDML raises a warning on a toy example where we know the
221
- pseudo-covariance matrix is not PSD"""
222
- pairs = np .array ([[[- 10. , 0. ], [10. , 0. ]], [[0. , 50. ], [0. , - 60 ]]])
223
- y = [1 , - 1 ]
224
- sdml = SDML (use_cov = True , sparsity_param = 0.01 , balance_param = 0.5 )
225
- msg = ("Warning, the input matrix of graphical lasso is not "
226
- "positive semi-definite (PSD). The algorithm may diverge, "
227
- "and lead to degenerate solutions. "
228
- "To prevent that, try to decrease the balance parameter "
229
- "`balance_param` and/or to set use_covariance=False." )
230
- with pytest .warns (ConvergenceWarning ) as raised_warning :
231
- try :
232
- sdml .fit (pairs , y )
233
- except Exception :
234
- pass
235
- # we assert that this warning is in one of the warning raised by the
236
- # estimator
237
- assert msg in list (map (lambda w : str (w .message ), raised_warning ))
238
-
239
- def test_sdml_converges_if_psd (self ):
240
- """Tests that sdml converges on a simple problem where we know the
241
- pseudo-covariance matrix is PSD"""
242
- pairs = np .array ([[[- 10. , 0. ], [10. , 0. ]], [[0. , - 55. ], [0. , - 60 ]]])
243
- y = [1 , - 1 ]
244
- sdml = SDML (use_cov = True , sparsity_param = 0.01 , balance_param = 0.5 )
245
- sdml .fit (pairs , y )
246
- assert np .isfinite (sdml .get_mahalanobis_matrix ()).all ()
247
-
248
- @pytest .mark .skipif (not has_installed_skggm (),
249
- reason = "sklearn's graphical_lasso can sometimes not "
250
- "work on some non SPD problems. We test that "
251
- "is works only if skggm is installed." )
252
- def test_sdml_works_on_non_spd_pb_with_skggm (self ):
253
- """Test that SDML works on a certain non SPD problem on which we know
254
- it should work, but scikit-learn's graphical_lasso does not work"""
255
- X , y = load_iris (return_X_y = True )
256
- sdml = SDML_Supervised (balance_param = 0.5 , sparsity_param = 0.01 ,
257
- use_cov = True )
188
+ sdml .fit (pairs , y_pairs )
189
+ assert len (record ) == 0
190
+ with pytest .warns (None ) as record :
191
+ sdml = SDML_Supervised (use_cov = False , balance_param = 1e-5 )
258
192
sdml .fit (X , y )
193
+ assert len (record ) == 0
194
+
195
+ def test_iris (self ):
196
+ # Note: this is a flaky test, which fails for certain seeds.
197
+ # TODO: un-flake it!
198
+ rs = np .random .RandomState (5555 )
199
+
200
+ sdml = SDML_Supervised (num_constraints = 1500 , use_cov = False ,
201
+ balance_param = 5e-5 )
202
+ sdml .fit (self .iris_points , self .iris_labels , random_state = rs )
203
+ csep = class_separation (sdml .transform (self .iris_points ),
204
+ self .iris_labels )
205
+ self .assertLess (csep , 0.22 )
206
+
207
+ def test_deprecation_num_labeled (self ):
208
+ # test that a deprecation message is thrown if num_labeled is set at
209
+ # initialization
210
+ # TODO: remove in v.0.6
211
+ X , y = make_classification (random_state = 42 )
212
+ sdml_supervised = SDML_Supervised (num_labeled = np .inf , use_cov = False ,
213
+ balance_param = 5e-5 )
214
+ msg = ('"num_labeled" parameter is not used.'
215
+ ' It has been deprecated in version 0.5.0 and will be'
216
+ 'removed in 0.6.0' )
217
+ assert_warns_message (DeprecationWarning , msg , sdml_supervised .fit , X , y )
218
+
219
+ def test_sdml_raises_warning_non_psd (self ):
220
+ """Tests that SDML raises a warning on a toy example where we know the
221
+ pseudo-covariance matrix is not PSD"""
222
+ pairs = np .array ([[[- 10. , 0. ], [10. , 0. ]], [[0. , 50. ], [0. , - 60 ]]])
223
+ y = [1 , - 1 ]
224
+ sdml = SDML (use_cov = True , sparsity_param = 0.01 , balance_param = 0.5 )
225
+ msg = ("Warning, the input matrix of graphical lasso is not "
226
+ "positive semi-definite (PSD). The algorithm may diverge, "
227
+ "and lead to degenerate solutions. "
228
+ "To prevent that, try to decrease the balance parameter "
229
+ "`balance_param` and/or to set use_covariance=False." )
230
+ with pytest .warns (ConvergenceWarning ) as raised_warning :
231
+ try :
232
+ sdml .fit (pairs , y )
233
+ except Exception :
234
+ pass
235
+ # we assert that this warning is in one of the warning raised by the
236
+ # estimator
237
+ assert msg in list (map (lambda w : str (w .message ), raised_warning ))
238
+
239
+ def test_sdml_converges_if_psd (self ):
240
+ """Tests that sdml converges on a simple problem where we know the
241
+ pseudo-covariance matrix is PSD"""
242
+ pairs = np .array ([[[- 10. , 0. ], [10. , 0. ]], [[0. , - 55. ], [0. , - 60 ]]])
243
+ y = [1 , - 1 ]
244
+ sdml = SDML (use_cov = True , sparsity_param = 0.01 , balance_param = 0.5 )
245
+ sdml .fit (pairs , y )
246
+ assert np .isfinite (sdml .get_mahalanobis_matrix ()).all ()
247
+
248
+ @pytest .mark .skipif (not has_installed_skggm (),
249
+ reason = "sklearn's graphical_lasso can sometimes not "
250
+ "work on some non SPD problems. We test that "
251
+ "is works only if skggm is installed." )
252
+ def test_sdml_works_on_non_spd_pb_with_skggm (self ):
253
+ """Test that SDML works on a certain non SPD problem on which we know
254
+ it should work, but scikit-learn's graphical_lasso does not work"""
255
+ X , y = load_iris (return_X_y = True )
256
+ sdml = SDML_Supervised (balance_param = 0.5 , sparsity_param = 0.01 ,
257
+ use_cov = True )
258
+ sdml .fit (X , y )
259
259
260
260
261
261
@pytest .mark .skipif (not has_installed_skggm (),
0 commit comments