@@ -180,8 +180,14 @@ def score_pairs(self, pairs):
180
180
pairs = check_input (pairs , type_of_inputs = 'tuples' ,
181
181
preprocessor = self .preprocessor_ ,
182
182
estimator = self , tuple_size = 2 )
183
-
184
- return [np .dot (np .dot (u , self .components_ ), v .T ) for u ,v in zip (pairs [:, 1 , :], pairs [:, 0 , :])]
183
+
184
+ # Note: For bilinear order matters, dist(a,b) != dist(b,a)
185
+ # We always choose first pair first, then second pair
186
+ # (In contrast with Mahalanobis implementation)
187
+
188
+ # I dont know wich implementation performs better
189
+ return np .diagonal (np .dot (np .dot (pairs [:, 0 , :], self .components_ ), pairs [:, 1 , :].T ))
190
+ return [np .dot (np .dot (u .T , self .components_ ), v ) for u ,v in zip (pairs [:, 0 , :], pairs [:, 1 , :])]
185
191
186
192
def get_metric (self ):
187
193
check_is_fitted (self , 'components_' )
@@ -206,7 +212,7 @@ def metric_fun(u, v):
206
212
"""
207
213
u = validate_vector (u )
208
214
v = validate_vector (v )
209
- return np .dot (np .dot (u , components ), v . T )
215
+ return np .dot (np .dot (u . T , components ), v )
210
216
211
217
return metric_fun
212
218
0 commit comments