Skip to content

Commit 0147c0c

Browse files
mvargas33mvargas33
authored and
mvargas33
committed
FIrst draft of bilinear mixin
1 parent 4b7cdec commit 0147c0c

File tree

3 files changed

+87
-0
lines changed

3 files changed

+87
-0
lines changed

metric_learn/base_metric.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,58 @@ def transform(self, X):
160160
Input data transformed to the metric space by :math:`XL^{\\top}`
161161
"""
162162

163+
class BilinearMixin(BaseMetricLearner, metaclass=ABCMeta):
164+
165+
def score_pairs(self, pairs):
166+
r"""
167+
Parameters
168+
----------
169+
pairs : array-like, shape=(n_pairs, 2, n_features) or (n_pairs, 2)
170+
3D Array of pairs to score, with each row corresponding to two points,
171+
for 2D array of indices of pairs if the metric learner uses a
172+
preprocessor.
173+
174+
Returns
175+
-------
176+
scores : `numpy.ndarray` of shape=(n_pairs,)
177+
The learned Mahalanobis distance for every pair.
178+
"""
179+
check_is_fitted(self, ['preprocessor_', 'components_'])
180+
pairs = check_input(pairs, type_of_inputs='tuples',
181+
preprocessor=self.preprocessor_,
182+
estimator=self, tuple_size=2)
183+
return np.dot(np.dot(pairs[:, 1, :], self.components_), pairs[:, 0, :].T)
184+
185+
def get_metric(self):
186+
check_is_fitted(self, 'components_')
187+
components = self.components_.copy()
188+
189+
def metric_fun(u, v):
190+
"""This function computes the metric between u and v, according to the
191+
previously learned metric.
192+
193+
Parameters
194+
----------
195+
u : array-like, shape=(n_features,)
196+
The first point involved in the distance computation.
197+
198+
v : array-like, shape=(n_features,)
199+
The second point involved in the distance computation.
200+
201+
Returns
202+
-------
203+
distance : float
204+
The distance between u and v according to the new metric.
205+
"""
206+
u = validate_vector(u)
207+
v = validate_vector(v)
208+
return np.dot(np.dot(u, components), v.T)
209+
210+
return metric_fun
211+
212+
def get_bilinear_matrix(self):
213+
check_is_fitted(self, 'components_')
214+
return self.components_
163215

164216
class MahalanobisMixin(BaseMetricLearner, MetricTransformer,
165217
metaclass=ABCMeta):

metric_learn/oasis.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from .base_metric import BilinearMixin
2+
import numpy as np
3+
4+
class OASIS(BilinearMixin):
5+
6+
def __init__(self, preprocessor=None):
7+
super().__init__(preprocessor=preprocessor)
8+
9+
def fit(self, X, y):
10+
"""
11+
Fit OASIS model
12+
13+
Parameters
14+
----------
15+
X : (n x d) array of samples
16+
y : (n) data labels
17+
"""
18+
X = self._prepare_inputs(X, y, ensure_min_samples=2)
19+
self.components_ = np.identity(np.shape(X[0])[-1]) # Identity matrix
20+
return self

test_bilinear.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
from metric_learn.oasis import OASIS
2+
import numpy as np
3+
4+
def test_toy_distance():
5+
u = np.array([0, 1, 2])
6+
v = np.array([3, 4, 5])
7+
8+
mixin = OASIS()
9+
mixin.fit([u, v], [0, 0])
10+
#mixin.components_ = np.array([[1, 0, 0],[0, 1, 0],[0, 0, 1]])
11+
12+
dist = mixin.score_pairs([[u, v]])
13+
print(dist)
14+
15+
test_toy_distance()

0 commit comments

Comments
 (0)