Skip to content

Commit dbf5257

Browse files
author
William de Vazelhes
committed
FIX: fix for sdml by reducing balance parameter
1 parent 45d3b7b commit dbf5257

File tree

1 file changed

+22
-23
lines changed

1 file changed

+22
-23
lines changed

test/test_mahalanobis_mixin.py

Lines changed: 22 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -272,26 +272,25 @@ def test_get_squared_metric(estimator, build_dataset):
272272
ids=ids_metric_learners)
273273
def test_transformer_is_2D(estimator, build_dataset):
274274
"""Tests that the transformer of metric learners is 2D"""
275-
# TODO: remove this check when SDML has become robust to 1D elements,
276-
# or when the 1D case is dealt with separately
277-
if not str(estimator).startswith('SDML'):
278-
input_data, labels, _, X = build_dataset()
279-
model = clone(estimator)
280-
set_random_state(model)
281-
# test that it works for X.shape[1] features
282-
model.fit(input_data, labels)
283-
assert model.transformer_.shape == (X.shape[1], X.shape[1])
284-
285-
# test that it works for 1 feature
286-
trunc_data = input_data[..., :1]
287-
# we drop duplicates that might have been formed, i.e. of the form
288-
# aabc or abcc or aabb for quadruplets, and aa for pairs.
289-
slices = {4: [slice(0, 2), slice(2, 4)], 2: [slice(0, 2)]}
290-
if trunc_data.ndim == 3:
291-
for slice_idx in slices[trunc_data.shape[1]]:
292-
_, indices = np.unique(trunc_data[:, slice_idx, :], axis=2,
293-
return_index=True)
294-
trunc_data = trunc_data[indices]
295-
labels = labels[indices]
296-
model.fit(trunc_data, labels)
297-
assert model.transformer_.shape == (1, 1) # the transformer must be 2D
275+
input_data, labels, _, X = build_dataset()
276+
model = clone(estimator)
277+
if model.__class__.__name__.startswith('SDML'):
278+
model.set_params(use_cov=False, balance_param=1e-3)
279+
set_random_state(model)
280+
# test that it works for X.shape[1] features
281+
model.fit(input_data, labels)
282+
assert model.transformer_.shape == (X.shape[1], X.shape[1])
283+
284+
# test that it works for 1 feature
285+
trunc_data = input_data[..., :1]
286+
# we drop duplicates that might have been formed, i.e. of the form
287+
# aabc or abcc or aabb for quadruplets, and aa for pairs.
288+
slices = {4: [slice(0, 2), slice(2, 4)], 2: [slice(0, 2)]}
289+
if trunc_data.ndim == 3:
290+
for slice_idx in slices[trunc_data.shape[1]]:
291+
_, indices = np.unique(trunc_data[:, slice_idx, :], axis=2,
292+
return_index=True)
293+
trunc_data = trunc_data[indices]
294+
labels = labels[indices]
295+
model.fit(trunc_data, labels)
296+
assert model.transformer_.shape == (1, 1) # the transformer must be 2D

0 commit comments

Comments
 (0)