6
6
from metric_learn .sklearn_shims import set_random_state
7
7
from sklearn import clone
8
8
import numpy as np
9
+ from numpy .testing import assert_array_equal
9
10
10
11
11
12
@pytest .mark .parametrize ('with_preprocessor' , [True , False ])
@@ -26,6 +27,49 @@ def test_predict_only_one_or_minus_one(estimator, build_dataset,
26
27
assert len (not_valid ) == 0
27
28
28
29
30
+ @pytest .mark .parametrize ('estimator, build_dataset' , triplets_learners ,
31
+ ids = ids_triplets_learners )
32
+ def test_no_zero_prediction (estimator , build_dataset ):
33
+ """
34
+ Test that all predicted values are not zero, even when the
35
+ distance d(x,y) and d(x,z) is the same for a triplet of the
36
+ form (x, y, z). i.e border cases.
37
+ """
38
+ triplets , _ , _ , X = build_dataset (with_preprocessor = False )
39
+ # Force 3 dimentions only, to use cross product and get easy orthogonal vec.
40
+ triplets = np .array ([[t [0 ][:3 ], t [1 ][:3 ], t [2 ][:3 ]] for t in triplets ])
41
+ X = X [:, :3 ]
42
+ # Dummy fit
43
+ estimator = clone (estimator )
44
+ set_random_state (estimator )
45
+ estimator .fit (triplets )
46
+ # We force the transformation to be identity, to force euclidean distance
47
+ estimator .components_ = np .eye (X .shape [1 ])
48
+
49
+ # Get two orthogonal vectors in respect to X[1]
50
+ k = X [1 ] / np .linalg .norm (X [1 ]) # Normalize first vector
51
+ x = X [2 ] - X [2 ].dot (k ) * k # Get random orthogonal vector
52
+ x /= np .linalg .norm (x ) # Normalize
53
+ y = np .cross (k , x ) # Get orthogonal vector to x
54
+ # Assert these orthogonal vectors are different
55
+ with pytest .raises (AssertionError ):
56
+ assert_array_equal (X [1 ], x )
57
+ with pytest .raises (AssertionError ):
58
+ assert_array_equal (X [1 ], y )
59
+ # Assert the distance is the same for both
60
+ assert estimator .get_metric ()(X [1 ], x ) == estimator .get_metric ()(X [1 ], y )
61
+
62
+ # Form the three scenarios where predict() gives 0 with numpy.sign
63
+ triplets_test = np .array ( # Critical examples
64
+ [[X [0 ], X [2 ], X [2 ]],
65
+ [X [1 ], X [1 ], X [1 ]],
66
+ [X [1 ], x , y ]])
67
+ # Predict
68
+ predictions = estimator .predict (triplets_test )
69
+ # Check there are no zero values
70
+ assert np .sum (predictions == 0 ) == 0
71
+
72
+
29
73
@pytest .mark .parametrize ('with_preprocessor' , [True , False ])
30
74
@pytest .mark .parametrize ('estimator, build_dataset' , triplets_learners ,
31
75
ids = ids_triplets_learners )
0 commit comments