@@ -272,26 +272,25 @@ def test_get_squared_metric(estimator, build_dataset):
272
272
ids = ids_metric_learners )
273
273
def test_transformer_is_2D (estimator , build_dataset ):
274
274
"""Tests that the transformer of metric learners is 2D"""
275
- # TODO: remove this check when SDML has become robust to 1D elements,
276
- # or when the 1D case is dealt with separately
277
- if not str (estimator ).startswith ('SDML' ):
278
- input_data , labels , _ , X = build_dataset ()
279
- model = clone (estimator )
280
- set_random_state (model )
281
- # test that it works for X.shape[1] features
282
- model .fit (input_data , labels )
283
- assert model .transformer_ .shape == (X .shape [1 ], X .shape [1 ])
284
-
285
- # test that it works for 1 feature
286
- trunc_data = input_data [..., :1 ]
287
- # we drop duplicates that might have been formed, i.e. of the form
288
- # aabc or abcc or aabb for quadruplets, and aa for pairs.
289
- slices = {4 : [slice (0 , 2 ), slice (2 , 4 )], 2 : [slice (0 , 2 )]}
290
- if trunc_data .ndim == 3 :
291
- for slice_idx in slices [trunc_data .shape [1 ]]:
292
- _ , indices = np .unique (trunc_data [:, slice_idx , :], axis = 2 ,
293
- return_index = True )
294
- trunc_data = trunc_data [indices ]
295
- labels = labels [indices ]
296
- model .fit (trunc_data , labels )
297
- assert model .transformer_ .shape == (1 , 1 ) # the transformer must be 2D
275
+ input_data , labels , _ , X = build_dataset ()
276
+ model = clone (estimator )
277
+ if model .__class__ .__name__ .startswith ('SDML' ):
278
+ model .set_params (use_cov = False , balance_param = 1e-3 )
279
+ set_random_state (model )
280
+ # test that it works for X.shape[1] features
281
+ model .fit (input_data , labels )
282
+ assert model .transformer_ .shape == (X .shape [1 ], X .shape [1 ])
283
+
284
+ # test that it works for 1 feature
285
+ trunc_data = input_data [..., :1 ]
286
+ # we drop duplicates that might have been formed, i.e. of the form
287
+ # aabc or abcc or aabb for quadruplets, and aa for pairs.
288
+ slices = {4 : [slice (0 , 2 ), slice (2 , 4 )], 2 : [slice (0 , 2 )]}
289
+ if trunc_data .ndim == 3 :
290
+ for slice_idx in slices [trunc_data .shape [1 ]]:
291
+ _ , indices = np .unique (trunc_data [:, slice_idx , :], axis = 2 ,
292
+ return_index = True )
293
+ trunc_data = trunc_data [indices ]
294
+ labels = labels [indices ]
295
+ model .fit (trunc_data , labels )
296
+ assert model .transformer_ .shape == (1 , 1 ) # the transformer must be 2D
0 commit comments