1
+ from metric_learn .base_metric import BilinearMixin
2
+ import numpy as np
3
+ from numpy .testing import assert_array_almost_equal
4
+
5
+ class IdentityBilinearMixin (BilinearMixin ):
6
+ """A simple Identity bilinear mixin that returns an identity matrix
7
+ M as learned. Can change M for a random matrix calling random_M.
8
+ Class for testing purposes.
9
+ """
10
+ def __init__ (self , preprocessor = None ):
11
+ super ().__init__ (preprocessor = preprocessor )
12
+
13
+ def fit (self , X , y ):
14
+ X , y = self ._prepare_inputs (X , y , ensure_min_samples = 2 )
15
+ self .d = np .shape (X [0 ])[- 1 ]
16
+ self .components_ = np .identity (self .d )
17
+ return self
18
+
19
+ def random_M (self ):
20
+ self .components_ = np .random .rand (self .d , self .d )
21
+
22
+ def test_same_similarity_with_two_methods ():
23
+ d = 100
24
+ u = np .random .rand (d )
25
+ v = np .random .rand (d )
26
+ mixin = IdentityBilinearMixin ()
27
+ mixin .fit ([u , v ], [0 , 0 ]) # Dummy fit
28
+ mixin .random_M ()
29
+
30
+ # The distances must match, whether calc with get_metric() or score_pairs()
31
+ dist1 = mixin .score_pairs ([[u , v ], [v , u ]])
32
+ dist2 = [mixin .get_metric ()(u , v ), mixin .get_metric ()(v , u )]
33
+
34
+ assert_array_almost_equal (dist1 , dist2 )
35
+
36
+ def test_check_correctness_similarity ():
37
+ d = 100
38
+ u = np .random .rand (d )
39
+ v = np .random .rand (d )
40
+ mixin = IdentityBilinearMixin ()
41
+ mixin .fit ([u , v ], [0 , 0 ]) # Dummy fit
42
+ dist1 = mixin .score_pairs ([[u , v ], [v , u ]])
43
+ u_v = np .dot (np .dot (u .T , np .identity (d )), v )
44
+ v_u = np .dot (np .dot (v .T , np .identity (d )), u )
45
+ desired = [u_v , v_u ]
46
+ assert_array_almost_equal (dist1 , desired )
47
+
48
+ def test_check_handmade_example ():
49
+ u = np .array ([0 , 1 , 2 ])
50
+ v = np .array ([3 , 4 , 5 ])
51
+ mixin = IdentityBilinearMixin ()
52
+ mixin .fit ([u , v ], [0 , 0 ])
53
+ c = np .array ([[2 , 4 , 6 ], [6 , 4 , 2 ], [1 , 2 , 3 ]])
54
+ mixin .components_ = c # Force a components_
55
+ dists = mixin .score_pairs ([[u , v ], [v , u ]])
56
+ assert_array_almost_equal (dists , [96 , 120 ])
57
+
58
+ def test_check_handmade_symmetric_example ():
59
+ u = np .array ([0 , 1 , 2 ])
60
+ v = np .array ([3 , 4 , 5 ])
61
+ mixin = IdentityBilinearMixin ()
62
+ mixin .fit ([u , v ], [0 , 0 ])
63
+ dists = mixin .score_pairs ([[u , v ], [v , u ]])
64
+ assert_array_almost_equal (dists , [14 , 14 ])
0 commit comments