Skip to content

Commit d2c0614

Browse files
author
William de Vazelhes
committed
MAINT: address comments from review #152 (review)
1 parent d943406 commit d2c0614

File tree

2 files changed

+12
-50
lines changed

2 files changed

+12
-50
lines changed

metric_learn/base_metric.py

Lines changed: 4 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -241,47 +241,6 @@ def transform(self, X):
241241
return X_checked.dot(self.transformer_.T)
242242

243243
def get_metric(self):
244-
"""Returns a function that takes as input two 1D arrays and outputs the
245-
learned metric score on these two points.
246-
247-
This function will be independent from the metric learner that learned it
248-
(it will not be modified if the initial metric learner is modified),
249-
and it can be directly plugged into the `metric` argument of
250-
scikit-learn's estimators.
251-
252-
Returns
253-
-------
254-
metric_fun : function
255-
The function described above.
256-
257-
Examples
258-
--------
259-
.. doctest::
260-
261-
>>> from metric_learn import NCA
262-
>>> from sklearn.datasets import make_classification
263-
>>> from sklearn.neighbors import KNeighborsClassifier
264-
>>> nca = NCA()
265-
>>> X, y = make_classification()
266-
>>> nca.fit(X, y)
267-
>>> knn = KNeighborsClassifier(metric=nca.get_metric())
268-
>>> knn.fit(X, y) # doctest: +NORMALIZE_WHITESPACE
269-
KNeighborsClassifier(algorithm='auto', leaf_size=30,
270-
metric=<function MahalanobisMixin.get_metric.<locals>.metric_fun
271-
at 0x...>,
272-
metric_params=None, n_jobs=None, n_neighbors=5, p=2,
273-
weights='uniform')
274-
275-
See Also
276-
--------
277-
score_pairs : a method that returns the metric score between several pairs
278-
of points. Unlike `get_metric`, this is a method of the metric learner
279-
and therefore can change if the metric learner changes. Besides, it can
280-
use the metric learner's preprocessor, and works on concatenated arrays.
281-
282-
:ref:`mahalanobis_distances` : The section of the project documentation
283-
that describes Mahalanobis Distances.
284-
"""
285244
transformer_T = self.transformer_.T.copy()
286245

287246
def metric_fun(u, v):
@@ -304,12 +263,14 @@ def metric_fun(u, v):
304263
return euclidean(u.dot(transformer_T), v.dot(transformer_T))
305264
return metric_fun
306265

266+
get_metric.__doc__ = BaseMetricLearner.get_metric.__doc__
267+
307268
def metric(self):
308269
# TODO: remove this method in version 0.6.0
309270
warnings.warn(("`metric` is deprecated since version 0.5.0 and will be "
310271
"removed in 0.6.0. Use `get_mahalanobis_matrix` instead."),
311272
DeprecationWarning)
312-
return self.transformer_.T.dot(self.transformer_)
273+
return self.get_mahalanobis_matrix()
313274

314275
def get_mahalanobis_matrix(self):
315276
"""Returns a copy of the Mahalanobis matrix learned by the metric learner.
@@ -319,7 +280,7 @@ def get_mahalanobis_matrix(self):
319280
M : `numpy.ndarray`, shape=(n_components, n_features)
320281
The copy of the learned Mahalanobis matrix.
321282
"""
322-
return self.transformer_.T.dot(self.transformer_).copy()
283+
return self.transformer_.T.dot(self.transformer_)
323284

324285

325286
class _PairsClassifierMixin(BaseMetricLearner):

test/test_mahalanobis_mixin.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import pytest
44
import numpy as np
55
from numpy.testing import assert_array_almost_equal, assert_allclose
6-
from scipy.spatial.distance import pdist, squareform, euclidean
6+
from scipy.spatial.distance import pdist, squareform, mahalanobis
77
from sklearn import clone
88
from sklearn.cluster import DBSCAN
99
from sklearn.utils import check_random_state
@@ -172,10 +172,10 @@ def test_embed_is_linear(estimator, build_dataset):
172172

173173
@pytest.mark.parametrize('estimator, build_dataset', metric_learners,
174174
ids=ids_metric_learners)
175-
def test_get_metric_equivalent_to_transform_and_euclidean(estimator,
176-
build_dataset):
177-
"""Tests that the get_metric method of mahalanobis metric learners is the
178-
euclidean distance in the transformed space
175+
def test_get_metric_equivalent_to_explicit_mahalanobis(estimator,
176+
build_dataset):
177+
"""Tests that using the get_metric method of mahalanobis metric learners is
178+
equivalent to explicitely calling scipy's mahalanobis metric
179179
"""
180180
rng = np.random.RandomState(42)
181181
input_data, labels, _, X = build_dataset()
@@ -185,8 +185,9 @@ def test_get_metric_equivalent_to_transform_and_euclidean(estimator,
185185
metric = model.get_metric()
186186
n_features = X.shape[1]
187187
a, b = (rng.randn(n_features), rng.randn(n_features))
188-
euc_dist = euclidean(model.transform(a[None]), model.transform(b[None]))
189-
assert_allclose(metric(a, b), euc_dist, rtol=1e-15)
188+
expected_dist = mahalanobis(a[None], b[None],
189+
VI=model.get_mahalanobis_matrix())
190+
assert_allclose(metric(a, b), expected_dist, rtol=1e-15)
190191

191192

192193
@pytest.mark.parametrize('estimator, build_dataset', metric_learners,

0 commit comments

Comments
 (0)