Skip to content

Commit dbe2a7a

Browse files
author
mvargas33
committed
Fix identation for bilinear
1 parent c21d283 commit dbe2a7a

File tree

3 files changed

+22
-15
lines changed

3 files changed

+22
-15
lines changed

metric_learn/base_metric.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ def transform(self, X):
160160
Input data transformed to the metric space by :math:`XL^{\\top}`
161161
"""
162162

163+
163164
class BilinearMixin(BaseMetricLearner, metaclass=ABCMeta):
164165

165166
def score_pairs(self, pairs):
@@ -180,14 +181,15 @@ def score_pairs(self, pairs):
180181
pairs = check_input(pairs, type_of_inputs='tuples',
181182
preprocessor=self.preprocessor_,
182183
estimator=self, tuple_size=2)
183-
184184
# Note: For bilinear order matters, dist(a,b) != dist(b,a)
185185
# We always choose first pair first, then second pair
186186
# (In contrast with Mahalanobis implementation)
187-
188187
# I dont know wich implementation performs better
189-
return np.diagonal(np.dot(np.dot(pairs[:, 0, :], self.components_), pairs[:, 1, :].T))
190-
return [np.dot(np.dot(u.T, self.components_), v) for u,v in zip(pairs[:, 0, :], pairs[:, 1, :])]
188+
return np.diagonal(np.dot(
189+
np.dot(pairs[:, 0, :], self.components_),
190+
pairs[:, 1, :].T))
191+
return np.array([np.dot(np.dot(u.T, self.components_), v)
192+
for u, v in zip(pairs[:, 0, :], pairs[:, 1, :])])
191193

192194
def get_metric(self):
193195
check_is_fitted(self, 'components_')
@@ -220,6 +222,7 @@ def get_bilinear_matrix(self):
220222
check_is_fitted(self, 'components_')
221223
return self.components_
222224

225+
223226
class MahalanobisMixin(BaseMetricLearner, MetricTransformer,
224227
metaclass=ABCMeta):
225228
r"""Mahalanobis metric learning algorithms.

metric_learn/oasis.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from .base_metric import BilinearMixin
22
import numpy as np
33

4+
45
class OASIS(BilinearMixin):
56

67
def __init__(self, preprocessor=None):
@@ -16,7 +17,8 @@ def fit(self, X, y):
1617
y : (n) data labels
1718
"""
1819
X = self._prepare_inputs(X, y, ensure_min_samples=2)
19-
20+
2021
# Dummy fit
21-
self.components_ = np.random.rand(np.shape(X[0])[-1], np.shape(X[0])[-1])
22-
return self
22+
self.components_ = np.random.rand(
23+
np.shape(X[0])[-1], np.shape(X[0])[-1])
24+
return self

test_bilinear.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import numpy as np
33
from numpy.testing import assert_array_almost_equal
44

5+
56
def test_toy_distance():
67
# Random generalized test for 2 points
78
d = 100
@@ -10,33 +11,34 @@ def test_toy_distance():
1011
v = np.random.rand(d)
1112

1213
mixin = OASIS()
13-
mixin.fit([u, v], [0, 0]) # Dummy fit
14+
mixin.fit([u, v], [0, 0]) # Dummy fit
1415

1516
# The distances must match, whether calc with get_metric() or score_pairs()
1617
dist1 = mixin.score_pairs([[u, v], [v, u]])
1718
dist2 = [mixin.get_metric()(u, v), mixin.get_metric()(v, u)]
18-
19+
1920
u_v = (np.dot(np.dot(u.T, mixin.get_bilinear_matrix()), v))
2021
v_u = (np.dot(np.dot(v.T, mixin.get_bilinear_matrix()), u))
2122
desired = [u_v, v_u]
22-
23+
2324
assert_array_almost_equal(dist1, desired)
2425
assert_array_almost_equal(dist2, desired)
2526

2627
# Handmade example
27-
u = np.array([0, 1 ,2])
28+
u = np.array([0, 1, 2])
2829
v = np.array([3, 4, 5])
2930

30-
mixin.components_= np.array([[2,4,6], [6,4,2], [1, 2, 3]])
31+
mixin.components_ = np.array([[2, 4, 6], [6, 4, 2], [1, 2, 3]])
3132
dists = mixin.score_pairs([[u, v], [v, u]])
3233
assert_array_almost_equal(dists, [96, 120])
3334

3435
# Symetric example
35-
u = np.array([0, 1 ,2])
36+
u = np.array([0, 1, 2])
3637
v = np.array([3, 4, 5])
3738

38-
mixin.components_= np.identity(3) # Identity matrix
39+
mixin.components_ = np.identity(3) # Identity matrix
3940
dists = mixin.score_pairs([[u, v], [v, u]])
4041
assert_array_almost_equal(dists, [14, 14])
4142

42-
test_toy_distance()
43+
44+
test_toy_distance()

0 commit comments

Comments
 (0)