Skip to content

Commit 31072d3

Browse files
author
William de Vazelhes
committed
TST: make test about 1 feature arrays more readable
1 parent a7ed1bb commit 31072d3

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

test/test_mahalanobis_mixin.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
from sklearn.utils.testing import set_random_state
1111

1212
from metric_learn._util import make_context
13+
from metric_learn.base_metric import (_QuadrupletsClassifierMixin,
14+
_PairsClassifierMixin)
1315

1416
from test.test_utils import ids_metric_learners, metric_learners
1517

@@ -283,13 +285,17 @@ def test_transformer_is_2D(estimator, build_dataset):
283285
trunc_data = input_data[..., :1]
284286
# we drop duplicates that might have been formed, i.e. of the form
285287
# aabc or abcc or aabb for quadruplets, and aa for pairs.
286-
slices = {4: [slice(0, 2), slice(2, 4)], 2: [slice(0, 2)]}
287-
if trunc_data.ndim == 3:
288-
for slice_idx in slices[trunc_data.shape[1]]:
288+
if isinstance(estimator, _QuadrupletsClassifierMixin):
289+
for slice_idx in [slice(0, 2), slice(2, 4)]:
289290
pairs = trunc_data[:, slice_idx, :]
290291
diffs = pairs[:, 1, :] - pairs[:, 0, :]
291-
to_keep = np.nonzero(diffs.ravel())
292+
to_keep = np.where(np.abs(diffs.ravel()) > 1e-9)
292293
trunc_data = trunc_data[to_keep]
293294
labels = labels[to_keep]
295+
elif isinstance(estimator, _PairsClassifierMixin):
296+
diffs = trunc_data[:, 1, :] - trunc_data[:, 0, :]
297+
to_keep = np.where(np.abs(diffs.ravel()) > 1e-9)
298+
trunc_data = trunc_data[to_keep]
299+
labels = labels[to_keep]
294300
model.fit(trunc_data, labels)
295301
assert model.transformer_.shape == (1, 1) # the transformer must be 2D

0 commit comments

Comments
 (0)