Skip to content

Commit b1edc46

Browse files
author
mvargas33
committed
Update method's descriptions
1 parent 9a10e06 commit b1edc46

File tree

2 files changed

+53
-15
lines changed

2 files changed

+53
-15
lines changed

metric_learn/base_metric.py

Lines changed: 50 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -162,20 +162,50 @@ def transform(self, X):
162162

163163

164164
class BilinearMixin(BaseMetricLearner, metaclass=ABCMeta):
165+
r"""Bilinear similarity learning algorithms.
166+
167+
Algorithm that learns a Bilinear (pseudo) similarity :math:`s_M(x, x')`,
168+
defined between two column vectors :math:`x` and :math:`x'` by: :math:
169+
`s_M(x, x') = x M x'`, where :math:`M` is a learned matrix. This matrix
170+
is not guaranteed to be symmetric nor positive semi-definite (PSD). Thus
171+
it cannot be seen as learning a linear transformation of the original
172+
space like Mahalanobis learning algorithms.
173+
174+
Attributes
175+
----------
176+
components_ : `numpy.ndarray`, shape=(n_components, n_features)
177+
The learned bilinear matrix ``M``.
178+
"""
165179

166180
def score_pairs(self, pairs):
167-
r"""
181+
r"""Returns the learned Bilinear similarity between pairs.
182+
183+
This similarity is defined as: :math:`s_M(x, x') = x M x'`
184+
where ``M`` is the learned Bilinear matrix, for every pair of points
185+
``x`` and ``x'``.
186+
168187
Parameters
169188
----------
170189
pairs : array-like, shape=(n_pairs, 2, n_features) or (n_pairs, 2)
171190
3D Array of pairs to score, with each row corresponding to two points,
172-
for 2D array of indices of pairs if the metric learner uses a
191+
for 2D array of indices of pairs if the similarity learner uses a
173192
preprocessor.
174193
175194
Returns
176195
-------
177196
scores : `numpy.ndarray` of shape=(n_pairs,)
178-
The learned Mahalanobis distance for every pair.
197+
The learned Bilinear similarity for every pair.
198+
199+
See Also
200+
--------
201+
get_metric : a method that returns a function to compute the similarity
202+
between two points. The difference with `score_pairs` is that it works
203+
on two 1D arrays and cannot use a preprocessor. Besides, the returned
204+
function is independent of the similarity learner and hence is not
205+
modified if the similarity learner is.
206+
207+
:ref:`Bilinear_similarity` : The section of the project documentation
208+
that describes Bilinear similarity.
179209
"""
180210
check_is_fitted(self, ['preprocessor_', 'components_'])
181211
pairs = check_input(pairs, type_of_inputs='tuples',
@@ -184,36 +214,44 @@ def score_pairs(self, pairs):
184214
# Note: For bilinear order matters, dist(a,b) != dist(b,a)
185215
# We always choose first pair first, then second pair
186216
# (In contrast with Mahalanobis implementation)
187-
return (np.dot(pairs[:, 0, :], self.components_) * pairs[:, 1, :]).sum(-1)
217+
return np.sum(np.dot(pairs[:, 0, :], self.components_) * pairs[:, 1, :],
218+
axis=-1)
188219

189220
def get_metric(self):
190221
check_is_fitted(self, 'components_')
191222
components = self.components_.copy()
192223

193-
def metric_fun(u, v):
194-
"""This function computes the metric between u and v, according to the
195-
previously learned metric.
224+
def similarity_fun(u, v):
225+
"""This function computes the similarity between u and v, according to the
226+
previously learned similarity.
196227
197228
Parameters
198229
----------
199230
u : array-like, shape=(n_features,)
200-
The first point involved in the distance computation.
231+
The first point involved in the similarity computation.
201232
202233
v : array-like, shape=(n_features,)
203-
The second point involved in the distance computation.
234+
The second point involved in the similarity computation.
204235
205236
Returns
206237
-------
207-
distance : float
208-
The distance between u and v according to the new metric.
238+
similarity : float
239+
The similarity between u and v according to the new similarity.
209240
"""
210241
u = validate_vector(u)
211242
v = validate_vector(v)
212243
return np.dot(np.dot(u.T, components), v)
213244

214-
return metric_fun
245+
return similarity_fun
215246

216247
def get_bilinear_matrix(self):
248+
"""Returns a copy of the Bilinear matrix learned by the similarity learner.
249+
250+
Returns
251+
-------
252+
M : `numpy.ndarray`, shape=(n_features, n_features)
253+
The copy of the learned Bilinear matrix.
254+
"""
217255
check_is_fitted(self, 'components_')
218256
return self.components_
219257

test_bilinear.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,8 @@ def op_2(pairs, components):
7676
for u, v in zip(pairs[:, 0, :], pairs[:, 1, :])])
7777

7878
def op_3(pairs, components):
79-
return (np.dot(pairs[:, 0, :], components) * pairs[:, 1, :]).sum(-1)
79+
return np.sum(np.dot(pairs[:, 0, :], components) * pairs[:, 1, :],
80+
axis=-1)
8081

8182
# Test first method
8283
start = timer()
@@ -89,9 +90,8 @@ def op_3(pairs, components):
8990
op_2(pairs, components)
9091
end = timer()
9192
print(f'Second method took {end - start}')
92-
9393
# Test second method
9494
start = timer()
9595
op_3(pairs, components)
9696
end = timer()
97-
print(f'Third method took {end - start}')
97+
print(f'Third method took {end - start}')

0 commit comments

Comments
 (0)