Skip to content

Commit 44be909

Browse files
authored
Disallow 0 on Triplets predictions (#331)
* Remove 3.9 from compatibility * Fix Triplets predict function. Made a test to show the point. * Fix indentation * Simplified prediction as suggested * Resolved code review comments * Fix weird commit * Simplified assertion
1 parent 7a2a49d commit 44be909

File tree

2 files changed

+45
-1
lines changed

2 files changed

+45
-1
lines changed

metric_learn/base_metric.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -602,7 +602,7 @@ def predict(self, triplets):
602602
prediction : `numpy.ndarray` of floats, shape=(n_constraints,)
603603
Predictions of the ordering of pairs, for each triplet.
604604
"""
605-
return np.sign(self.decision_function(triplets))
605+
return 2 * (self.decision_function(triplets) > 0) - 1
606606

607607
def decision_function(self, triplets):
608608
"""Predicts differences between sample distances in input triplets.

test/test_triplets_classifiers.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from metric_learn.sklearn_shims import set_random_state
77
from sklearn import clone
88
import numpy as np
9+
from numpy.testing import assert_array_equal
910

1011

1112
@pytest.mark.parametrize('with_preprocessor', [True, False])
@@ -26,6 +27,49 @@ def test_predict_only_one_or_minus_one(estimator, build_dataset,
2627
assert len(not_valid) == 0
2728

2829

30+
@pytest.mark.parametrize('estimator, build_dataset', triplets_learners,
31+
ids=ids_triplets_learners)
32+
def test_no_zero_prediction(estimator, build_dataset):
33+
"""
34+
Test that all predicted values are not zero, even when the
35+
distances d(x,y) and d(x,z) are the same for a triplet of the
36+
form (x, y, z), i.e. border cases.
37+
"""
38+
triplets, _, _, X = build_dataset(with_preprocessor=False)
39+
# Force 3 dimensions only, to use cross product and get easy orthogonal vec.
40+
triplets = np.array([[t[0][:3], t[1][:3], t[2][:3]] for t in triplets])
41+
X = X[:, :3]
42+
# Dummy fit
43+
estimator = clone(estimator)
44+
set_random_state(estimator)
45+
estimator.fit(triplets)
46+
# We force the transformation to be identity, to force euclidean distance
47+
estimator.components_ = np.eye(X.shape[1])
48+
49+
# Get two orthogonal vectors with respect to X[1]
50+
k = X[1] / np.linalg.norm(X[1]) # Normalize first vector
51+
x = X[2] - X[2].dot(k) * k # Get random orthogonal vector
52+
x /= np.linalg.norm(x) # Normalize
53+
y = np.cross(k, x) # Get vector orthogonal to both k and x
54+
# Assert these orthogonal vectors are different
55+
with pytest.raises(AssertionError):
56+
assert_array_equal(X[1], x)
57+
with pytest.raises(AssertionError):
58+
assert_array_equal(X[1], y)
59+
# Assert the distance is the same for both
60+
assert estimator.get_metric()(X[1], x) == estimator.get_metric()(X[1], y)
61+
62+
# Form the three scenarios where predict() gives 0 with numpy.sign
63+
triplets_test = np.array( # Critical examples
64+
[[X[0], X[2], X[2]],
65+
[X[1], X[1], X[1]],
66+
[X[1], x, y]])
67+
# Predict
68+
predictions = estimator.predict(triplets_test)
69+
# Check there are no zero values
70+
assert np.sum(predictions == 0) == 0
71+
72+
2973
@pytest.mark.parametrize('with_preprocessor', [True, False])
3074
@pytest.mark.parametrize('estimator, build_dataset', triplets_learners,
3175
ids=ids_triplets_learners)

0 commit comments

Comments
 (0)