
Commit e4685b1

[MRG] Create new Mahalanobis mixin (#96)
* WIP create MahalanobisMixin
* ENH Update algorithms with MahalanobisMixin:
  - Make them inherit from MahalanobisMixin, and implement the metric_ property
  - Improve the metric_ property by checking whether it exists and raising the appropriate warning if not
  - Make tests work by replacing metric() with metric_
* FIX: add missing import
* FIX: update sklearn's check_no_fit_attributes_set_in_init to the new check_no_attributes_set_in_init. This function was introduced through PR scikit-learn/scikit-learn#9450 in scikit-learn. It also allows tests to pass that would otherwise fail: having abstract attributes as properties used to throw an error, but the new check function handles this property inheritance well.
* FIX: take the function ``_get_args`` from scikit-learn's PR scikit-learn/scikit-learn#9450, where it is modified to support Python 2. This should solve the CI error.
* ENH: add transformer_ attribute and improve docstring
* WIP: move transform() in BaseMetricLearner to transformer_from_metric() in MahalanobisMixin
* WIP: refactor metric to the original formulation: a function, with the result computed from the transformer
* WIP: make all Mahalanobis metric learning algorithms have transformer_ and metric()
* ENH Add score_pairs function:
  - Make MahalanobisMixin inherit from BaseMetricLearner to give a concrete implementation of score_pairs
  - Use score_pairs to compute predict more easily
  - Add docstring
  - TST, for every algorithm: test that using score_pairs pairwise returns a euclidean distance matrix; test that score_pairs works on 3D arrays of several pairs as well as 2D arrays of one pair (returning a scalar in that case); test that score_pairs always returns a finite output
* TST add test on toy example for score_pairs
* ENH Add embed function:
  - add the function and its docstring
  - use it for score_pairs
  - TST: output should be finite and have the right dimension, the embedding should be linear, and it should work on a toy example
* FIX fix error in slicing of quadruplets
* FIX minor corrections
* FIX minor corrections: remove a stray 's' from test function names; remove redundant parentheses
* FIX fix PEP8 errors
* FIX remove possible one-sample scoring from docstring for now
* REF rename n_features_out to num_dims to be more consistent with current algorithms
* MAINT: Address #96 (review):
  - replace embed by transform, and always pass the input X when calling the function
  - mutualize _transformer_from_metric so it is not overwritten in MMC
  - improve test_mahalanobis_mixin.test_score_pairs_pairwise according to #96 (comment)
  - improve test_mahalanobis_mixin.check_is_distance_matrix
  - correct typos and nitpicks
* ENH: Add check_tuples
* FIX: fix parentheses
* ENH: put the docstring of transformer_ in each metric learner
* FIX: style nitpicks to make the transformer_ docstring uniform with its children
* FIX: make transformer_from_metric public
* Address #96 (review)
* FIX: fix pairwise distances check
* FIX: ensure the random state is set in all tests
* FIX: fix test with a real value to test in check_tuples
* FIX: update MetricTransformer to be abstract and have the appropriate doc
* MAINT: make BaseMetricLearner and MetricTransformer abstract
* MAINT: remove the __init__ method from BaseMetricLearner; since it is an abstract class, it already raises an error at instantiation: TypeError: Can't instantiate abstract class BaseMetricLearner with abstract methods score_pairs
1 parent: 24b0def

19 files changed (+737 / -196 lines)
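Before the per-file diffs, the net effect of the commit is easiest to see end to end. Below is a minimal sketch of the new API surface, assuming a supervised Mahalanobis learner such as LMNN with its usual `k` parameter; the data is illustrative:

```python
import numpy as np
from metric_learn import LMNN

rng = np.random.RandomState(42)
X = rng.randn(40, 5)
y = rng.randint(0, 2, 40)

ml = LMNN(k=3).fit(X, y)

# New in this commit: every Mahalanobis learner exposes the learned linear
# map as `transformer_` and scores tuples of points via `score_pairs`.
L = ml.transformer_                           # shape (num_dims, n_features)
X_e = ml.transform(X)                         # embeds points: X.dot(L.T)
pairs = np.stack([X[:10], X[10:20]], axis=1)  # 3D array of 10 pairs
d = ml.score_pairs(pairs)                     # learned distance per pair
```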

examples/sandwich.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -30,7 +30,7 @@ def sandwich_demo():
   for ax_num, ml in enumerate(mls, start=3):
     ml.fit(x, y)
-    tx = ml.transform()
+    tx = ml.transform(x)
     ml_knn = nearest_neighbors(tx, k=2)
     ax = plt.subplot(3, 2, ax_num)
     plot_sandwich_data(tx, y, axis=ax)
```
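This one-line change reflects the API break made in base_metric.py below: `transform` no longer falls back to the stored training data and must be given the points to embed explicitly, matching the scikit-learn convention. Continuing the example's variables (with `x_test` a hypothetical held-out set):

```python
tx_train = ml.transform(x)       # embedding of the training points
tx_test = ml.transform(x_test)   # the same call now works for unseen points
```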

metric_learn/_util.py

Lines changed: 39 additions & 1 deletion
```diff
@@ -9,4 +9,42 @@ def vector_norm(X):
     return np.apply_along_axis(np.linalg.norm, 1, X)
 else:
   def vector_norm(X):
-    return np.linalg.norm(X, axis=1)
+    return np.linalg.norm(X, axis=1)
+
+
+def check_tuples(tuples):
+  """Check that the input is a valid 3D array representing a dataset of tuples.
+
+  Equivalent of `check_array` in scikit-learn.
+
+  Parameters
+  ----------
+  tuples : object
+    The tuples to check.
+
+  Returns
+  -------
+  tuples_valid : object
+    The validated input.
+  """
+  # If input is scalar raise error
+  if np.isscalar(tuples):
+    raise ValueError(
+        "Expected 3D array, got scalar instead. Cannot apply this function on "
+        "scalars.")
+  # If input is 1D raise error
+  if len(tuples.shape) == 1:
+    raise ValueError(
+        "Expected 3D array, got 1D array instead:\ntuples={}.\n"
+        "Reshape your data using tuples.reshape(1, -1, 1) if it contains a "
+        "single tuple and the points in the tuple have a single "
+        "feature.".format(tuples))
+  # If input is 2D raise error
+  if len(tuples.shape) == 2:
+    raise ValueError(
+        "Expected 3D array, got 2D array instead:\ntuples={}.\n"
+        "Reshape your data either using tuples.reshape(-1, {}, 1) if "
+        "your data has a single feature or tuples.reshape(1, {}, -1) "
+        "if it contains a single tuple.".format(tuples, tuples.shape[1],
+                                                tuples.shape[0]))
+  return tuples
```
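A quick sketch of `check_tuples` in action, assuming it is imported from `metric_learn._util`: a 3D array passes through unchanged, while scalar, 1D, and 2D inputs raise a `ValueError` with a reshape hint:

```python
import numpy as np
from metric_learn._util import check_tuples

pairs = np.random.randn(5, 2, 3)     # 5 tuples of 2 points with 3 features
assert check_tuples(pairs) is pairs  # valid input is returned as-is

for bad in (3.14, np.zeros(4), np.zeros((4, 3))):  # scalar, 1D, 2D
    try:
        check_tuples(bad)
    except ValueError as e:
        print(e)                     # "Expected 3D array, got ..."
```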

metric_learn/base_metric.py

Lines changed: 144 additions & 43 deletions
```diff
@@ -1,62 +1,148 @@
 from numpy.linalg import cholesky
-from sklearn.base import BaseEstimator, TransformerMixin
+from sklearn.base import BaseEstimator
 from sklearn.utils.validation import check_array
 from sklearn.metrics import roc_auc_score
 import numpy as np
+from abc import ABCMeta, abstractmethod
+import six
+from ._util import check_tuples
 
 
-class BaseMetricLearner(BaseEstimator):
-  def __init__(self):
-    raise NotImplementedError('BaseMetricLearner should not be instantiated')
-
-  def metric(self):
-    """Computes the Mahalanobis matrix from the transformation matrix.
-
-    .. math:: M = L^{\\top} L
-
-    Returns
-    -------
-    M : (d x d) matrix
-    """
-    L = self.transformer()
-    return L.T.dot(L)
-
-  def transformer(self):
-    """Computes the transformation matrix from the Mahalanobis matrix.
-
-    L = cholesky(M).T
-
-    Returns
-    -------
-    L : upper triangular (d x d) matrix
-    """
-    return cholesky(self.metric()).T
-
-
-class MetricTransformer(TransformerMixin):
-
-  def transform(self, X=None):
-    """Applies the metric transformation.
-
-    Parameters
-    ----------
-    X : (n x d) matrix, optional
-      Data to transform. If not supplied, the training data will be used.
-
-    Returns
-    -------
-    transformed : (n x d) matrix
-      Input data transformed to the metric space by :math:`XL^{\\top}`
-    """
-    if X is None:
-      X = self.X_
-    else:
-      X = check_array(X, accept_sparse=True)
-    L = self.transformer()
-    return X.dot(L.T)
+class BaseMetricLearner(six.with_metaclass(ABCMeta, BaseEstimator)):
+
+  @abstractmethod
+  def score_pairs(self, pairs):
+    """Returns the score between pairs
+    (can be a similarity, or a distance/metric depending on the algorithm)
+
+    Parameters
+    ----------
+    pairs : `numpy.ndarray`, shape=(n_samples, 2, n_features)
+      3D array of pairs.
+
+    Returns
+    -------
+    scores: `numpy.ndarray` of shape=(n_pairs,)
+      The score of every pair.
+    """
+
+
+class MetricTransformer(six.with_metaclass(ABCMeta)):
+
+  @abstractmethod
+  def transform(self, X):
+    """Applies the metric transformation.
+
+    Parameters
+    ----------
+    X : (n x d) matrix
+      Data to transform.
+
+    Returns
+    -------
+    transformed : (n x d) matrix
+      Input data transformed to the metric space by :math:`XL^{\\top}`
+    """
+
+
+class MahalanobisMixin(six.with_metaclass(ABCMeta, BaseMetricLearner,
+                                          MetricTransformer)):
+  """Mahalanobis metric learning algorithms.
+
+  Algorithm that learns a Mahalanobis (pseudo) distance :math:`d_M(x, x')`,
+  defined between two column vectors :math:`x` and :math:`x'` by:
+  :math:`d_M(x, x') = \sqrt{(x-x')^T M (x-x')}`, where :math:`M` is a learned
+  symmetric positive semi-definite (PSD) matrix. The metric between points can
+  then be expressed as the euclidean distance between points embedded in a new
+  space through a linear transformation. Indeed, the above matrix can be
+  decomposed into the product of two transpose matrices (through SVD or
+  Cholesky decomposition): :math:`d_M(x, x')^2 = (x-x')^T M (x-x') =
+  (x-x')^T L^T L (x-x') = (L x - L x')^T (L x - L x')`
+
+  Attributes
+  ----------
+  transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
+    The learned linear transformation ``L``.
+  """
+
+  def score_pairs(self, pairs):
+    """Returns the learned Mahalanobis distance between pairs.
+
+    This distance is defined as: :math:`d_M(x, x') = \sqrt{(x-x')^T M (x-x')}`
+    where ``M`` is the learned Mahalanobis matrix, for every pair of points
+    ``x`` and ``x'``. This corresponds to the euclidean distance between
+    embeddings of the points in a new space, obtained through a linear
+    transformation. Indeed, we also have: :math:`d_M(x, x') = \sqrt{(x_e -
+    x_e')^T (x_e - x_e')}`, with :math:`x_e = L x` (See
+    :class:`MahalanobisMixin`).
+
+    Parameters
+    ----------
+    pairs : `numpy.ndarray`, shape=(n_samples, 2, n_features)
+      3D array of pairs, or 2D array of one pair.
+
+    Returns
+    -------
+    scores: `numpy.ndarray` of shape=(n_pairs,)
+      The learned Mahalanobis distance for every pair.
+    """
+    pairs = check_tuples(pairs)
+    pairwise_diffs = self.transform(pairs[:, 1, :] - pairs[:, 0, :])
+    # (for MahalanobisMixin, the embedding is linear so we can just embed the
+    # difference)
+    return np.sqrt(np.sum(pairwise_diffs**2, axis=-1))
+
+  def transform(self, X):
+    """Embeds data points in the learned linear embedding space.
+
+    Transforms samples in ``X`` into ``X_embedded``, samples inside a new
+    embedding space such that: ``X_embedded = X.dot(L.T)``, where ``L`` is
+    the learned linear transformation (See :class:`MahalanobisMixin`).
+
+    Parameters
+    ----------
+    X : `numpy.ndarray`, shape=(n_samples, n_features)
+      The data points to embed.
+
+    Returns
+    -------
+    X_embedded : `numpy.ndarray`, shape=(n_samples, num_dims)
+      The embedded data points.
+    """
+    X_checked = check_array(X, accept_sparse=True)
+    return X_checked.dot(self.transformer_.T)
+
+  def metric(self):
+    return self.transformer_.T.dot(self.transformer_)
+
+  def transformer_from_metric(self, metric):
+    """Computes the transformation matrix from the Mahalanobis matrix.
+
+    Since by definition the metric `M` is positive semi-definite (PSD), it
+    admits a Cholesky decomposition: L = cholesky(M).T. However, the Cholesky
+    routine used here does not support non-definite matrices. If the metric is
+    not definite, this method will return L = w^(1/2) V.T, with M = V*w*V.T
+    being the eigenvector decomposition of M, with the eigenvalues in the
+    diagonal matrix w and the columns of V being the eigenvectors. If M is
+    diagonal, this method will just return its elementwise square root (since
+    the diagonalization of the matrix is itself).
+
+    Returns
+    -------
+    L : (d x d) matrix
+    """
+
+    if np.allclose(metric, np.diag(np.diag(metric))):
+      return np.sqrt(metric)
+    elif not np.isclose(np.linalg.det(metric), 0):
+      return cholesky(metric).T
+    else:
+      w, V = np.linalg.eigh(metric)
+      return V.T * np.sqrt(np.maximum(0, w[:, None]))
 
 
-class _PairsClassifierMixin:
+class _PairsClassifierMixin(BaseMetricLearner):
 
   def predict(self, pairs):
     """Predicts the learned metric between input pairs.
```
```diff
@@ -74,11 +160,11 @@ def predict(self, pairs):
     y_predicted : `numpy.ndarray` of floats, shape=(n_constraints,)
       The predicted learned metric value between samples in every pair.
     """
-    pairwise_diffs = pairs[:, 0, :] - pairs[:, 1, :]
-    return np.sqrt(np.sum(pairwise_diffs.dot(self.metric()) * pairwise_diffs,
-                          axis=1))
+    pairs = check_tuples(pairs)
+    return self.score_pairs(pairs)
 
   def decision_function(self, pairs):
+    pairs = check_tuples(pairs)
     return self.predict(pairs)
 
   def score(self, pairs, y):
```
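After this refactor the pairs classifier no longer re-implements the quadratic form: `predict` (and `decision_function` with it) delegates to the mixin's `score_pairs`. A self-contained sanity check with a hypothetical toy learner that combines the two mixins through an identity metric:

```python
import numpy as np
from metric_learn.base_metric import MahalanobisMixin, _PairsClassifierMixin

class IdentityPairsLearner(MahalanobisMixin, _PairsClassifierMixin):
    """Toy learner for illustration: 'learns' the identity transformation."""
    def fit(self, X, y=None):
        self.transformer_ = np.eye(X.shape[1])
        return self

X = np.random.RandomState(0).randn(20, 3)
ml = IdentityPairsLearner().fit(X)
pairs = np.stack([X[:5], X[5:10]], axis=1)
# predict now just delegates to the mixin's score_pairs:
assert np.allclose(ml.predict(pairs), ml.score_pairs(pairs))
```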
```diff
@@ -104,12 +190,32 @@ def score(self, pairs, y):
     score : float
       The ``roc_auc`` score.
     """
+    pairs = check_tuples(pairs)
     return roc_auc_score(y, self.decision_function(pairs))
 
 
-class _QuadrupletsClassifierMixin:
+class _QuadrupletsClassifierMixin(BaseMetricLearner):
 
   def predict(self, quadruplets):
+    """Predicts the ordering between sample distances in input quadruplets.
+
+    For each quadruplet, returns 1 if the quadruplet is in the right order (
+    first pair is more similar than second pair), and -1 if not.
+
+    Parameters
+    ----------
+    quadruplets : array-like, shape=(n_constraints, 4, n_features)
+      Input quadruplets.
+
+    Returns
+    -------
+    prediction : `numpy.ndarray` of floats, shape=(n_constraints,)
+      Predictions of the ordering of pairs, for each quadruplet.
+    """
+    quadruplets = check_tuples(quadruplets)
+    return np.sign(self.decision_function(quadruplets))
+
+  def decision_function(self, quadruplets):
     """Predicts differences between sample distances in input quadruplets.
 
     For each quadruplet of samples, computes the difference between the learned
```
```diff
@@ -122,18 +228,12 @@ def predict(self, quadruplets):
 
     Returns
     -------
-    prediction : `numpy.ndarray` of floats, shape=(n_constraints,)
+    decision_function : `numpy.ndarray` of floats, shape=(n_constraints,)
       Metric differences.
     """
-    similar_diffs = quadruplets[:, 0, :] - quadruplets[:, 1, :]
-    dissimilar_diffs = quadruplets[:, 2, :] - quadruplets[:, 3, :]
-    return (np.sqrt(np.sum(similar_diffs.dot(self.metric()) *
-                           similar_diffs, axis=1)) -
-            np.sqrt(np.sum(dissimilar_diffs.dot(self.metric()) *
-                           dissimilar_diffs, axis=1)))
-
-  def decision_function(self, quadruplets):
-    return self.predict(quadruplets)
+    quadruplets = check_tuples(quadruplets)
+    return (self.score_pairs(quadruplets[:, :2, :]) -
+            self.score_pairs(quadruplets[:, 2:, :]))
 
   def score(self, quadruplets, y=None):
     """Computes score on input quadruplets
```
```diff
@@ -154,4 +254,5 @@ def score(self, quadruplets, y=None):
     score : float
       The quadruplets score.
     """
-    return - np.mean(np.sign(self.decision_function(quadruplets)))
+    quadruplets = check_tuples(quadruplets)
+    return -np.mean(self.predict(quadruplets))
```

metric_learn/covariance.py

Lines changed: 14 additions & 5 deletions
```diff
@@ -11,17 +11,24 @@
 from __future__ import absolute_import
 import numpy as np
 from sklearn.utils.validation import check_array
+from sklearn.base import TransformerMixin
 
-from .base_metric import BaseMetricLearner, MetricTransformer
+from .base_metric import MahalanobisMixin
 
 
-class Covariance(BaseMetricLearner, MetricTransformer):
+class Covariance(MahalanobisMixin, TransformerMixin):
+  """Covariance metric (baseline method)
+
+  Attributes
+  ----------
+  transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
+    The linear transformation ``L`` deduced from the learned Mahalanobis
+    metric (See :meth:`transformer_from_metric`.)
+  """
+
   def __init__(self):
     pass
 
-  def metric(self):
-    return self.M_
-
   def fit(self, X, y=None):
     """
     X : data matrix, (n x d)
```
```diff
@@ -33,4 +40,6 @@ def fit(self, X, y=None):
       self.M_ = 1./self.M_
     else:
       self.M_ = np.linalg.inv(self.M_)
+
+    self.transformer_ = self.transformer_from_metric(check_array(self.M_))
     return self
```
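With these changes `Covariance` inherits `transform`, `metric` and `score_pairs` from the mixin; its `fit` only has to set `transformer_`. A short usage sketch showing that euclidean distance in the embedding equals the learned metric distance:

```python
import numpy as np
from metric_learn import Covariance

X = np.random.RandomState(0).randn(100, 4)
cov = Covariance().fit(X)

X_e = cov.transform(X)                    # X.dot(L.T)
pair = np.stack([X[:1], X[1:2]], axis=1)  # one pair, shape (1, 2, 4)
assert np.isclose(cov.score_pairs(pair)[0],
                  np.linalg.norm(X_e[0] - X_e[1]))
```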
