Skip to content

Commit d3620bb

Browse files
wdevazelhes authored and bellet committed
[MRG] Refactor the metric() method (#152)
* MAINT Rename metric() into get_mahalanobis_matrix()
* ENH: refactor methods to get the metric
* DOC: change description of distance into pseudo-metric
* MAINT: make description clearer
* ENH: enhance description
* MAINT: remove the 1D part in case we allow 2D
* FIX: fix expression for mahalanobis distance
* TST: Add tests
* ENH: deal with the 1D case
* Rename forgotten point 1 and point 2 to u and v
* STY: Fix PEP8 errors
* Address all comments
* Revert changes in metric_plotting included by mistake
* FIX: use custom validate_vector
* TST: fix syntax error for assert in test
* Add tolerance for triangular inequality because MMC probably projected onto a line
* MAINT: address comments from review #152 (review)
* ENH: add squared option
* FIX fix test that was failing due to a non 2D transformer:
  - ensure that the transformer_ fitted is always 2D:
  - in the result returned from transformer_from_metric
  - in the code of metric learners, for metric learners that don't call transformer_from_metric
  - for metric learners that cannot work on 1 feature, ensure it when checking the input
  - add a test to check this behaviour
* FIX: remove message that is not supported anymore by python newer versions and replace it by str
* TST: make shape testing more precise
* TST: enforce the 2d transformer test for everyone, and make it pass for RCA and SDML
* TST: fix typo in removing
* Remove unnecessary calls of np.atleast2d
* Add functions to commented doc
1 parent 8ffd998 commit d3620bb

11 files changed

+385
-23
lines changed

doc/conf.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
'sphinx.ext.viewcode',
88
'sphinx.ext.mathjax',
99
'numpydoc',
10-
'sphinx_gallery.gen_gallery'
10+
'sphinx_gallery.gen_gallery',
11+
'sphinx.ext.doctest'
1112
]
1213

1314
templates_path = ['_templates']
@@ -35,3 +36,6 @@
3536
# Option to only need single backticks to refer to symbols
3637
default_role = 'any'
3738

39+
# Option to hide doctest flag comments in the documentation (like # doctest:
40+
# +NORMALIZE_WHITESPACE for instance)
41+
trim_doctest_flags = True

doc/introduction.rst

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ generally formulated as an optimization problem where one seeks to find the
3838
parameters of a distance function that optimize some objective function
3939
measuring the agreement with the training data.
4040

41+
.. _mahalanobis_distances:
42+
4143
Mahalanobis Distances
4244
=====================
4345

@@ -124,7 +126,9 @@ to the following resources:
124126
.. Currently, each metric learning algorithm supports the following methods:
125127
126128
.. - ``fit(...)``, which learns the model.
127-
.. - ``metric()``, which returns a Mahalanobis matrix
129+
.. - ``get_mahalanobis_matrix()``, which returns a Mahalanobis matrix
130+
.. - ``get_metric()``, which returns a function that takes as input two 1D
131+
arrays and outputs the learned metric score on these two points
128132
.. :math:`M = L^{\top}L` such that distance between vectors ``x`` and
129133
.. ``y`` can be computed as :math:`\sqrt{\left(x-y\right)^{\top}M\left(x-y\right)}`.
130134
.. - ``transformer_from_metric(metric)``, which returns a transformation matrix

metric_learn/_util.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,3 +349,14 @@ def transformer_from_metric(metric):
349349
else:
350350
w, V = np.linalg.eigh(metric)
351351
return V.T * np.sqrt(np.maximum(0, w[:, None]))
352+
353+
354+
def validate_vector(u, dtype=None):
355+
# replica of scipy.spatial.distance._validate_vector, for making scipy
356+
# compatible functions on vectors (such as distances computations)
357+
u = np.asarray(u, dtype=dtype, order='c').squeeze()
358+
# Ensure values such as u=1 and u=[1] still return 1-D arrays.
359+
u = np.atleast_1d(u)
360+
if u.ndim > 1:
361+
raise ValueError("Input vector should be 1-D.")
362+
return u

metric_learn/base_metric.py

Lines changed: 114 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1+
from numpy.linalg import cholesky
2+
from scipy.spatial.distance import euclidean
13
from sklearn.base import BaseEstimator
24
from sklearn.utils.validation import _is_arraylike
35
from sklearn.metrics import roc_auc_score
46
import numpy as np
57
from abc import ABCMeta, abstractmethod
68
import six
7-
from ._util import ArrayIndexer, check_input
9+
from ._util import ArrayIndexer, check_input, validate_vector
10+
import warnings
811

912

1013
class BaseMetricLearner(six.with_metaclass(ABCMeta, BaseEstimator)):
@@ -34,6 +37,14 @@ def score_pairs(self, pairs):
3437
-------
3538
scores: `numpy.ndarray` of shape=(n_pairs,)
3639
The score of every pair.
40+
41+
See Also
42+
--------
43+
get_metric : a method that returns a function to compute the metric between
44+
two points. The difference with `score_pairs` is that it works on two 1D
45+
arrays and cannot use a preprocessor. Besides, the returned function is
46+
independent of the metric learner and hence is not modified if the metric
47+
learner is.
3748
"""
3849

3950
def check_preprocessor(self):
@@ -85,6 +96,47 @@ def _prepare_inputs(self, X, y=None, type_of_inputs='classic',
8596
tuple_size=getattr(self, '_tuple_size', None),
8697
**kwargs)
8798

99+
@abstractmethod
100+
def get_metric(self):
101+
"""Returns a function that takes as input two 1D arrays and outputs the
102+
learned metric score on these two points.
103+
104+
This function will be independent from the metric learner that learned it
105+
(it will not be modified if the initial metric learner is modified),
106+
and it can be directly plugged into the `metric` argument of
107+
scikit-learn's estimators.
108+
109+
Returns
110+
-------
111+
metric_fun : function
112+
The function described above.
113+
114+
115+
Examples
116+
--------
117+
.. doctest::
118+
119+
>>> from metric_learn import NCA
120+
>>> from sklearn.datasets import make_classification
121+
>>> from sklearn.neighbors import KNeighborsClassifier
122+
>>> nca = NCA()
123+
>>> X, y = make_classification()
124+
>>> nca.fit(X, y)
125+
>>> knn = KNeighborsClassifier(metric=nca.get_metric())
126+
>>> knn.fit(X, y) # doctest: +NORMALIZE_WHITESPACE
127+
KNeighborsClassifier(algorithm='auto', leaf_size=30,
128+
metric=<function MahalanobisMixin.get_metric.<locals>.metric_fun
129+
at 0x...>,
130+
metric_params=None, n_jobs=None, n_neighbors=5, p=2,
131+
weights='uniform')
132+
133+
See Also
134+
--------
135+
score_pairs : a method that returns the metric score between several pairs
136+
of points. Unlike `get_metric`, this is a method of the metric learner
137+
and therefore can change if the metric learner changes. Besides, it can
138+
use the metric learner's preprocessor, and works on concatenated arrays.
139+
"""
88140

89141
class MetricTransformer(six.with_metaclass(ABCMeta)):
90142

@@ -146,6 +198,17 @@ def score_pairs(self, pairs):
146198
-------
147199
scores: `numpy.ndarray` of shape=(n_pairs,)
148200
The learned Mahalanobis distance for every pair.
201+
202+
See Also
203+
--------
204+
get_metric : a method that returns a function to compute the metric between
205+
two points. The difference with `score_pairs` is that it works on two 1D
206+
arrays and cannot use a preprocessor. Besides, the returned function is
207+
independent of the metric learner and hence is not modified if the metric
208+
learner is.
209+
210+
:ref:`mahalanobis_distances` : The section of the project documentation
211+
that describes Mahalanobis Distances.
149212
"""
150213
pairs = check_input(pairs, type_of_inputs='tuples',
151214
preprocessor=self.preprocessor_,
@@ -177,7 +240,57 @@ def transform(self, X):
177240
accept_sparse=True)
178241
return X_checked.dot(self.transformer_.T)
179242

243+
def get_metric(self):
244+
transformer_T = self.transformer_.T.copy()
245+
246+
def metric_fun(u, v, squared=False):
247+
"""This function computes the metric between u and v, according to the
248+
previously learned metric.
249+
250+
Parameters
251+
----------
252+
u : array-like, shape=(n_features,)
253+
The first point involved in the distance computation.
254+
255+
v : array-like, shape=(n_features,)
256+
The second point involved in the distance computation.
257+
258+
squared : `bool`
259+
If True, the function will return the squared metric between u and
260+
v, which is faster to compute.
261+
262+
Returns
263+
-------
264+
distance: float
265+
The distance between u and v according to the new metric.
266+
"""
267+
u = validate_vector(u)
268+
v = validate_vector(v)
269+
transformed_diff = (u - v).dot(transformer_T)
270+
dist = np.dot(transformed_diff, transformed_diff.T)
271+
if not squared:
272+
dist = np.sqrt(dist)
273+
return dist
274+
275+
return metric_fun
276+
277+
get_metric.__doc__ = BaseMetricLearner.get_metric.__doc__
278+
180279
def metric(self):
280+
# TODO: remove this method in version 0.6.0
281+
warnings.warn(("`metric` is deprecated since version 0.5.0 and will be "
282+
"removed in 0.6.0. Use `get_mahalanobis_matrix` instead."),
283+
DeprecationWarning)
284+
return self.get_mahalanobis_matrix()
285+
286+
def get_mahalanobis_matrix(self):
287+
"""Returns a copy of the Mahalanobis matrix learned by the metric learner.
288+
289+
Returns
290+
-------
291+
M : `numpy.ndarray`, shape=(n_components, n_features)
292+
The copy of the learned Mahalanobis matrix.
293+
"""
181294
return self.transformer_.T.dot(self.transformer_)
182295

183296

metric_learn/rca.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def fit(self, X, chunks):
112112
chunks = np.asanyarray(chunks, dtype=int)
113113
chunk_mask, chunked_data = _chunk_mean_centering(X_t, chunks)
114114

115-
inner_cov = np.cov(chunked_data, rowvar=0, bias=1)
115+
inner_cov = np.atleast_2d(np.cov(chunked_data, rowvar=0, bias=1))
116116
dim = self._check_dimension(np.linalg.matrix_rank(inner_cov), X_t)
117117

118118
# Fisher Linear Discriminant projection
@@ -122,13 +122,13 @@ def fit(self, X, chunks):
122122
vals, vecs = np.linalg.eig(tmp)
123123
inds = np.argsort(vals)[:dim]
124124
A = vecs[:, inds]
125-
inner_cov = A.T.dot(inner_cov).dot(A)
125+
inner_cov = np.atleast_2d(A.T.dot(inner_cov).dot(A))
126126
self.transformer_ = _inv_sqrtm(inner_cov).dot(A.T)
127127
else:
128128
self.transformer_ = _inv_sqrtm(inner_cov).T
129129

130130
if M_pca is not None:
131-
self.transformer_ = self.transformer_.dot(M_pca)
131+
self.transformer_ = np.atleast_2d(self.transformer_.dot(M_pca))
132132

133133
return self
134134

metric_learn/sdml.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def _fit(self, pairs, y):
5858
# set up prior M
5959
if self.use_cov:
6060
X = np.vstack({tuple(row) for row in pairs.reshape(-1, pairs.shape[2])})
61-
self.M_ = pinvh(np.cov(X, rowvar = False))
61+
self.M_ = pinvh(np.atleast_2d(np.cov(X, rowvar = False)))
6262
else:
6363
self.M_ = np.identity(pairs.shape[2])
6464
diff = pairs[:, 0] - pairs[:, 1]

test/metric_learn_test.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ def test_iris(self):
273273
self.assertLess(csep, 0.15)
274274

275275
# Sanity checks for learned matrices.
276-
self.assertEqual(lfda.metric().shape, (4, 4))
276+
self.assertEqual(lfda.get_mahalanobis_matrix().shape, (4, 4))
277277
self.assertEqual(lfda.transformer_.shape, (2, 4))
278278

279279

@@ -348,14 +348,16 @@ def test_iris(self):
348348
[+0.000868, +0.001468, -0.002021, -0.002879],
349349
[-0.001195, -0.002021, +0.002782, +0.003964],
350350
[-0.001703, -0.002879, +0.003964, +0.005648]]
351-
assert_array_almost_equal(expected, mmc.metric(), decimal=6)
351+
assert_array_almost_equal(expected, mmc.get_mahalanobis_matrix(),
352+
decimal=6)
352353

353354
# Diagonal metric
354355
mmc = MMC(diagonal=True)
355356
mmc.fit(*wrap_pairs(self.iris_points, [a,b,c,d]))
356357
expected = [0, 0, 1.210220, 1.228596]
357-
assert_array_almost_equal(np.diag(expected), mmc.metric(), decimal=6)
358-
358+
assert_array_almost_equal(np.diag(expected), mmc.get_mahalanobis_matrix(),
359+
decimal=6)
360+
359361
# Supervised Full
360362
mmc = MMC_Supervised()
361363
mmc.fit(self.iris_points, self.iris_labels)

test/test_base_metric.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
1+
import pytest
12
import unittest
23
import metric_learn
4+
import numpy as np
5+
from sklearn import clone
6+
from sklearn.utils.testing import set_random_state
7+
from test.test_utils import ids_metric_learners, metric_learners
38

49

510
class TestStringRepr(unittest.TestCase):
@@ -81,5 +86,82 @@ def test_mmc(self):
8186
num_labeled='deprecated', preprocessor=None, verbose=False)
8287
""".strip('\n'))
8388

89+
90+
@pytest.mark.parametrize('estimator, build_dataset', metric_learners,
91+
ids=ids_metric_learners)
92+
def test_get_metric_is_independent_from_metric_learner(estimator,
93+
build_dataset):
94+
"""Tests that the get_metric method returns a function that is independent
95+
from the original metric learner"""
96+
input_data, labels, _, X = build_dataset()
97+
model = clone(estimator)
98+
set_random_state(model)
99+
100+
# we fit the metric learner on it and then we compute the metric on some
101+
# points
102+
model.fit(input_data, labels)
103+
metric = model.get_metric()
104+
score = metric(X[0], X[1])
105+
106+
# then we refit the estimator on another dataset
107+
model.fit(np.sin(input_data), labels)
108+
109+
# we recompute the distance between the two points: it should be the same
110+
score_bis = metric(X[0], X[1])
111+
assert score_bis == score
112+
113+
114+
@pytest.mark.parametrize('estimator, build_dataset', metric_learners,
115+
ids=ids_metric_learners)
116+
def test_get_metric_raises_error(estimator, build_dataset):
117+
"""Tests that the metric returned by get_metric raises errors similar to
118+
the distance functions in scipy.spatial.distance"""
119+
input_data, labels, _, X = build_dataset()
120+
model = clone(estimator)
121+
set_random_state(model)
122+
model.fit(input_data, labels)
123+
metric = model.get_metric()
124+
125+
list_test_get_metric_raises = [(X[0].tolist() + [5.2], X[1]), # vectors with
126+
# different dimensions
127+
(X[0:4], X[1:5]), # 2D vectors
128+
(X[0].tolist() + [5.2], X[1] + [7.2])]
129+
# vectors of same dimension but incompatible with what the metric learner
130+
# was trained on
131+
132+
for u, v in list_test_get_metric_raises:
133+
with pytest.raises(ValueError):
134+
metric(u, v)
135+
136+
137+
@pytest.mark.parametrize('estimator, build_dataset', metric_learners,
138+
ids=ids_metric_learners)
139+
def test_get_metric_works_does_not_raise(estimator, build_dataset):
140+
"""Tests that the metric returned by get_metric does not raise errors (or
141+
warnings) similarly to the distance functions in scipy.spatial.distance"""
142+
input_data, labels, _, X = build_dataset()
143+
model = clone(estimator)
144+
set_random_state(model)
145+
model.fit(input_data, labels)
146+
metric = model.get_metric()
147+
148+
list_test_get_metric_doesnt_raise = [(X[0], X[1]),
149+
(X[0].tolist(), X[1].tolist()),
150+
(X[0][None], X[1][None])]
151+
152+
for u, v in list_test_get_metric_doesnt_raise:
153+
with pytest.warns(None) as record:
154+
metric(u, v)
155+
assert len(record) == 0
156+
157+
# Test that the scalar case works
158+
model.transformer_ = np.array([3.1])
159+
metric = model.get_metric()
160+
for u, v in [(5, 6.7), ([5], [6.7]), ([[5]], [[6.7]])]:
161+
with pytest.warns(None) as record:
162+
metric(u, v)
163+
assert len(record) == 0
164+
165+
84166
if __name__ == '__main__':
85167
unittest.main()

0 commit comments

Comments
 (0)