Skip to content

Commit 9a10e06

Browse files
author
mvargas33
committed
Found an efficient way to compute Bilinear Sim for n pairs
1 parent ee5c5ee commit 9a10e06

File tree

2 files changed

+11
-12
lines changed

2 files changed

+11
-12
lines changed

metric_learn/base_metric.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -184,12 +184,7 @@ def score_pairs(self, pairs):
184184
# Note: For bilinear order matters, dist(a,b) != dist(b,a)
185185
# We always choose first pair first, then second pair
186186
# (In contrast with Mahalanobis implementation)
187-
# I dont know wich implementation performs better
188-
return np.diagonal(np.dot(
189-
np.dot(pairs[:, 0, :], self.components_),
190-
pairs[:, 1, :].T))
191-
return np.array([np.dot(np.dot(u.T, self.components_), v)
192-
for u, v in zip(pairs[:, 0, :], pairs[:, 1, :])])
187+
return (np.dot(pairs[:, 0, :], self.components_) * pairs[:, 1, :]).sum(-1)
193188

194189
def get_metric(self):
195190
check_is_fitted(self, 'components_')

test_bilinear.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def test_bilinar_properties():
5959
def test_performace():
6060

6161
features = int(1e4)
62-
samples = int(1e3)
62+
samples = int(1e4)
6363

6464
a = [np.random.rand(features) for i in range(samples)]
6565
b = [np.random.rand(features) for i in range(samples)]
@@ -75,6 +75,9 @@ def op_2(pairs, components):
7575
return np.array([np.dot(np.dot(u.T, components), v)
7676
for u, v in zip(pairs[:, 0, :], pairs[:, 1, :])])
7777

78+
def op_3(pairs, components):
79+
return (np.dot(pairs[:, 0, :], components) * pairs[:, 1, :]).sum(-1)
80+
7881
# Test first method
7982
start = timer()
8083
op_1(pairs, components)
@@ -86,8 +89,9 @@ def op_2(pairs, components):
8689
op_2(pairs, components)
8790
end = timer()
8891
print(f'Second method took {end - start}')
89-
90-
91-
# test_toy_distance()
92-
# test_bilinar_properties()
93-
test_performace()
92+
93+
# Test second method
94+
start = timer()
95+
op_3(pairs, components)
96+
end = timer()
97+
print(f'Third method took {end - start}')

0 commit comments

Comments
 (0)