Skip to content

Commit ec09f59

Browse files
mvargas33mvargas33
authored and
mvargas33
committed
Fix score_pairs
1 parent 0147c0c commit ec09f59

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

metric_learn/base_metric.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,8 @@ def score_pairs(self, pairs):
180180
pairs = check_input(pairs, type_of_inputs='tuples',
181181
preprocessor=self.preprocessor_,
182182
estimator=self, tuple_size=2)
183-
return np.dot(np.dot(pairs[:, 1, :], self.components_), pairs[:, 0, :].T)
183+
184+
return [np.dot(np.dot(u, self.components_), v.T) for u,v in zip(pairs[:, 1, :], pairs[:, 0, :])]
184185

185186
def get_metric(self):
186187
check_is_fitted(self, 'components_')

test_bilinear.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ def test_toy_distance():
99
mixin.fit([u, v], [0, 0])
1010
#mixin.components_ = np.array([[1, 0, 0],[0, 1, 0],[0, 0, 1]])
1111

12-
dist = mixin.score_pairs([[u, v]])
12+
dist = mixin.score_pairs([[u, v],[v, u]])
1313
print(dist)
1414

1515
test_toy_distance()

0 commit comments

Comments
 (0)