Skip to content

Commit ec49397

Browse files
mvargas33mvargas33
authored and
mvargas33
committed
Two implementations for score_pairs
1 parent ec09f59 commit ec49397

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

metric_learn/base_metric.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -180,8 +180,14 @@ def score_pairs(self, pairs):
180180
pairs = check_input(pairs, type_of_inputs='tuples',
181181
preprocessor=self.preprocessor_,
182182
estimator=self, tuple_size=2)
183-
184-
return [np.dot(np.dot(u, self.components_), v.T) for u,v in zip(pairs[:, 1, :], pairs[:, 0, :])]
183+
184+
# Note: For bilinear order matters, dist(a,b) != dist(b,a)
185+
# We always choose first pair first, then second pair
186+
# (In contrast with Mahalanobis implementation)
187+
188+
# I dont know wich implementation performs better
189+
return np.diagonal(np.dot(np.dot(pairs[:, 0, :], self.components_), pairs[:, 1, :].T))
190+
return [np.dot(np.dot(u.T, self.components_), v) for u,v in zip(pairs[:, 0, :], pairs[:, 1, :])]
185191

186192
def get_metric(self):
187193
check_is_fitted(self, 'components_')
@@ -206,7 +212,7 @@ def metric_fun(u, v):
206212
"""
207213
u = validate_vector(u)
208214
v = validate_vector(v)
209-
return np.dot(np.dot(u, components), v.T)
215+
return np.dot(np.dot(u.T, components), v)
210216

211217
return metric_fun
212218

0 commit comments

Comments
 (0)