|
10 | 10 | from sklearn.utils.testing import set_random_state
|
11 | 11 |
|
12 | 12 | from metric_learn._util import make_context
|
| 13 | +from metric_learn.base_metric import (_QuadrupletsClassifierMixin, |
| 14 | + _PairsClassifierMixin) |
13 | 15 |
|
14 | 16 | from test.test_utils import ids_metric_learners, metric_learners
|
15 | 17 |
|
@@ -283,13 +285,17 @@ def test_transformer_is_2D(estimator, build_dataset):
|
283 | 285 | trunc_data = input_data[..., :1]
|
284 | 286 | # we drop duplicates that might have been formed, i.e. of the form
|
285 | 287 | # aabc or abcc or aabb for quadruplets, and aa for pairs.
|
286 |
| - slices = {4: [slice(0, 2), slice(2, 4)], 2: [slice(0, 2)]} |
287 |
| - if trunc_data.ndim == 3: |
288 |
| - for slice_idx in slices[trunc_data.shape[1]]: |
| 288 | + if isinstance(estimator, _QuadrupletsClassifierMixin): |
| 289 | + for slice_idx in [slice(0, 2), slice(2, 4)]: |
289 | 290 | pairs = trunc_data[:, slice_idx, :]
|
290 | 291 | diffs = pairs[:, 1, :] - pairs[:, 0, :]
|
291 |
| - to_keep = np.nonzero(diffs.ravel()) |
| 292 | + to_keep = np.where(np.abs(diffs.ravel()) > 1e-9) |
292 | 293 | trunc_data = trunc_data[to_keep]
|
293 | 294 | labels = labels[to_keep]
|
| 295 | + elif isinstance(estimator, _PairsClassifierMixin): |
| 296 | + diffs = trunc_data[:, 1, :] - trunc_data[:, 0, :] |
| 297 | + to_keep = np.where(np.abs(diffs.ravel()) > 1e-9) |
| 298 | + trunc_data = trunc_data[to_keep] |
| 299 | + labels = labels[to_keep] |
294 | 300 | model.fit(trunc_data, labels)
|
295 | 301 | assert model.transformer_.shape == (1, 1) # the transformer must be 2D
|
0 commit comments