[MRG] Refactor the metric() method #152


Merged (26 commits, Jan 29, 2019)

Commits
a3384b1
MAINT Rename metric() into get_mahalanobis_matrix()
Jan 9, 2019
8e0d197
ENH: refactor methods to get the metric
Jan 9, 2019
6dd118e
DOC: change description of distance into pseudo-metric
Jan 9, 2019
c7e40f6
MAINT: make description clearer
Jan 9, 2019
1947ea5
ENH: enhance description
Jan 9, 2019
bee6902
MAINT: remove the 1D part in case we allow 2D
Jan 9, 2019
646cf97
FIX: fix expression for mahalanobis distance
Jan 10, 2019
00d37c9
TST: Add tests
Jan 10, 2019
c9eefb4
ENH: deal with the 1D case
Jan 10, 2019
bd6aac0
Rename forgotten point 1 and point 2 to u and v
Jan 10, 2019
22141f5
Merge branch 'master' into feat/metric_func
Jan 10, 2019
9e447f6
STY: Fix PEP8 errors
Jan 10, 2019
201320b
Address all comments
Jan 15, 2019
4b660fa
Revert changes in metric_plotting included by mistake
Jan 15, 2019
61a33cc
FIX: use custom validate_vector
Jan 15, 2019
72153ed
TST: fix syntax error for assert in test
Jan 15, 2019
d943406
Add tolerance for triangular inequality because MMC probably projecte…
Jan 16, 2019
d2c0614
MAINT: address comments from review https://github.com/metric-learn/m…
Jan 22, 2019
5e29295
ENH: add squared option
Jan 22, 2019
92669ae
FIX: fix test that was failing due to a non-2D transformer
Jan 23, 2019
a2955e0
FIX: remove message that is not supported anymore by python newer ver…
Jan 24, 2019
c8708b2
TST: make shape testing more precise
Jan 28, 2019
7d4efd9
TST: enforce the 2d transformer test for everyone, and make it pass f…
Jan 28, 2019
0c7c5dc
TST: fix typo in removing
Jan 28, 2019
7dfd874
Remove unnecessary calls of np.atleast2d
Jan 29, 2019
80c2943
Add functions to commented doc
Jan 29, 2019
6 changes: 5 additions & 1 deletion doc/conf.py
@@ -7,7 +7,8 @@
'sphinx.ext.viewcode',
'sphinx.ext.mathjax',
'numpydoc',
'sphinx_gallery.gen_gallery'
'sphinx_gallery.gen_gallery',
'sphinx.ext.doctest'
]

templates_path = ['_templates']
@@ -35,3 +36,6 @@
# Option to only need single backticks to refer to symbols
default_role = 'any'

# Option to hide doctests comments in the documentation (like # doctest:
# +NORMALIZE_WHITESPACE for instance)
trim_doctest_flags = True
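
As context for this option, a minimal sketch (not part of the diff): with trim_doctest_flags = True, Sphinx strips doctest directive comments from the rendered HTML, so a snippet like the following is displayed without its trailing flag, while the flag still applies when the doctests are executed.

    >>> print("a    b")  # doctest: +NORMALIZE_WHITESPACE
    a b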
6 changes: 5 additions & 1 deletion doc/introduction.rst
@@ -38,6 +38,8 @@ generally formulated as an optimization problem where one seeks to find the
parameters of a distance function that optimize some objective function
measuring the agreement with the training data.

.. _mahalanobis_distances:

Mahalanobis Distances
=====================

@@ -124,7 +126,9 @@ to the following resources:
.. Currently, each metric learning algorithm supports the following methods:

.. - ``fit(...)``, which learns the model.
.. - ``metric()``, which returns a Mahalanobis matrix
.. - ``get_mahalanobis_matrix()``, which returns a Mahalanobis matrix
.. - ``get_metric()``, which returns a function that takes as input two 1D
arrays and outputs the learned metric score on these two points
.. :math:`M = L^{\top}L` such that the distance between vectors ``x`` and
.. ``y`` can be computed as :math:`\sqrt{\left(x-y\right)^{\top}M\left(x-y\right)}`.
.. - ``transformer_from_metric(metric)``, which returns a transformation matrix
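
A minimal sketch of the pseudo-metric described above (L is a hypothetical learned transformation; with M = L^T L, the learned distance is the Euclidean distance after applying L):

    import numpy as np

    # Hypothetical learned transformation L; M = L^T L is the Mahalanobis matrix.
    L = np.array([[1.0, 0.5],
                  [0.0, 2.0]])
    M = L.T.dot(L)

    x = np.array([1.0, 2.0])
    y = np.array([3.0, 0.5])
    diff = x - y

    # d(x, y) = sqrt((x - y)^T M (x - y)) ...
    d_M = np.sqrt(diff.dot(M).dot(diff))
    # ... equals the Euclidean distance between the transformed points.
    d_L = np.linalg.norm(L.dot(x) - L.dot(y))
    assert np.isclose(d_M, d_L)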
11 changes: 11 additions & 0 deletions metric_learn/_util.py
@@ -349,3 +349,14 @@ def transformer_from_metric(metric):
else:
w, V = np.linalg.eigh(metric)
return V.T * np.sqrt(np.maximum(0, w[:, None]))


def validate_vector(u, dtype=None):
# replica of scipy.spatial.distance._validate_vector, for making scipy
# compatible functions on vectors (such as distances computations)
u = np.asarray(u, dtype=dtype, order='c').squeeze()
# Ensure values such as u=1 and u=[1] still return 1-D arrays.
u = np.atleast_1d(u)
if u.ndim > 1:
raise ValueError("Input vector should be 1-D.")
return u
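
A quick sketch of the intended behavior of this helper (assuming it is importable as below once the PR is merged):

    import numpy as np
    from metric_learn._util import validate_vector

    validate_vector(1)         # array([1]): scalars are promoted to 1-D
    validate_vector([[5.2]])   # array([5.2]): singleton dimensions are squeezed
    try:
        validate_vector([[1.0, 2.0],
                         [3.0, 4.0]])  # genuinely 2-D input
    except ValueError as e:
        print(e)               # Input vector should be 1-D.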
115 changes: 114 additions & 1 deletion metric_learn/base_metric.py
@@ -1,10 +1,13 @@
from numpy.linalg import cholesky
from scipy.spatial.distance import euclidean
from sklearn.base import BaseEstimator
from sklearn.utils.validation import _is_arraylike
from sklearn.metrics import roc_auc_score
import numpy as np
from abc import ABCMeta, abstractmethod
import six
from ._util import ArrayIndexer, check_input
from ._util import ArrayIndexer, check_input, validate_vector
import warnings


class BaseMetricLearner(six.with_metaclass(ABCMeta, BaseEstimator)):
@@ -34,6 +37,14 @@ def score_pairs(self, pairs):
-------
scores: `numpy.ndarray` of shape=(n_pairs,)
The score of every pair.

See Also
--------
get_metric : a method that returns a function to compute the metric between
two points. The difference with `score_pairs` is that it works on two 1D
arrays and cannot use a preprocessor. Besides, the returned function is
independent of the metric learner and hence is not modified if the metric
learner is.
"""

def check_preprocessor(self):
@@ -85,6 +96,47 @@ def _prepare_inputs(self, X, y=None, type_of_inputs='classic',
tuple_size=getattr(self, '_tuple_size', None),
**kwargs)

@abstractmethod
def get_metric(self):
"""Returns a function that takes as input two 1D arrays and outputs the
learned metric score on these two points.

This function will be independent of the metric learner that learned it
(it will not be modified if the initial metric learner is modified),
and it can be directly plugged into the `metric` argument of
scikit-learn's estimators.

Returns
-------
metric_fun : function
The function described above.


Examples
--------
.. doctest::

>>> from metric_learn import NCA
>>> from sklearn.datasets import make_classification
>>> from sklearn.neighbors import KNeighborsClassifier
>>> nca = NCA()
>>> X, y = make_classification()
>>> nca.fit(X, y)
>>> knn = KNeighborsClassifier(metric=nca.get_metric())
>>> knn.fit(X, y) # doctest: +NORMALIZE_WHITESPACE
KNeighborsClassifier(algorithm='auto', leaf_size=30,
metric=<function MahalanobisMixin.get_metric.<locals>.metric_fun
at 0x...>,
metric_params=None, n_jobs=None, n_neighbors=5, p=2,
weights='uniform')

See Also
--------
score_pairs : a method that returns the metric score between several pairs
of points. Unlike `get_metric`, this is a method of the metric learner
and therefore can change if the metric learner changes. Besides, it can
use the metric learner's preprocessor, and works on concatenated arrays.
"""

class MetricTransformer(six.with_metaclass(ABCMeta)):

@@ -146,6 +198,17 @@ def score_pairs(self, pairs):
-------
scores: `numpy.ndarray` of shape=(n_pairs,)
The learned Mahalanobis distance for every pair.

See Also
--------
get_metric : a method that returns a function to compute the metric between
two points. The difference with `score_pairs` is that it works on two 1D
arrays and cannot use a preprocessor. Besides, the returned function is
independent of the metric learner and hence is not modified if the metric
learner is.

:ref:`mahalanobis_distances` : The section of the project documentation
that describes Mahalanobis Distances.
"""
pairs = check_input(pairs, type_of_inputs='tuples',
preprocessor=self.preprocessor_,
@@ -177,7 +240,57 @@ def transform(self, X):
accept_sparse=True)
return X_checked.dot(self.transformer_.T)

def get_metric(self):
transformer_T = self.transformer_.T.copy()

def metric_fun(u, v, squared=False):
"""This function computes the metric between u and v, according to the
previously learned metric.

Parameters
----------
u : array-like, shape=(n_features,)
The first point involved in the distance computation.

v : array-like, shape=(n_features,)
The second point involved in the distance computation.

squared : `bool`
If True, the function will return the squared metric between u and
v, which is faster to compute.

Returns
-------
distance: float
The distance between u and v according to the new metric.
"""
u = validate_vector(u)
v = validate_vector(v)
transformed_diff = (u - v).dot(transformer_T)
dist = np.dot(transformed_diff, transformed_diff.T)
if not squared:
dist = np.sqrt(dist)
return dist

return metric_fun

get_metric.__doc__ = BaseMetricLearner.get_metric.__doc__

def metric(self):
# TODO: remove this method in version 0.6.0
warnings.warn(("`metric` is deprecated since version 0.5.0 and will be "
"removed in 0.6.0. Use `get_mahalanobis_matrix` instead."),
DeprecationWarning)
return self.get_mahalanobis_matrix()

def get_mahalanobis_matrix(self):
"""Returns a copy of the Mahalanobis matrix learned by the metric learner.

Returns
-------
M : `numpy.ndarray`, shape=(n_features, n_features)
The copy of the learned Mahalanobis matrix.
"""
return self.transformer_.T.dot(self.transformer_)


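
A sketch tying together the methods added above (assuming a fitted learner; NCA is used for illustration): score_pairs consumes stacked pairs, get_metric returns a standalone function, and the deprecated metric() forwards to get_mahalanobis_matrix().

    import warnings
    import numpy as np
    from metric_learn import NCA
    from sklearn.datasets import make_classification

    X, y = make_classification(random_state=0)
    nca = NCA().fit(X, y)

    # score_pairs works on stacked pairs, shape (n_pairs, 2, n_features)
    d_pairs = nca.score_pairs(np.array([[X[0], X[1]]]))
    # get_metric returns a plain function of two 1-D points
    metric = nca.get_metric()
    assert np.isclose(d_pairs[0], metric(X[0], X[1]))

    # the deprecated spelling still works, warns, and agrees with the new one
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        M_old = nca.metric()
    assert issubclass(w[0].category, DeprecationWarning)
    assert np.allclose(M_old, nca.get_mahalanobis_matrix())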
6 changes: 3 additions & 3 deletions metric_learn/rca.py
@@ -112,7 +112,7 @@ def fit(self, X, chunks):
chunks = np.asanyarray(chunks, dtype=int)
chunk_mask, chunked_data = _chunk_mean_centering(X_t, chunks)

inner_cov = np.cov(chunked_data, rowvar=0, bias=1)
inner_cov = np.atleast_2d(np.cov(chunked_data, rowvar=0, bias=1))
dim = self._check_dimension(np.linalg.matrix_rank(inner_cov), X_t)

# Fisher Linear Discriminant projection
@@ -122,13 +122,13 @@
vals, vecs = np.linalg.eig(tmp)
inds = np.argsort(vals)[:dim]
A = vecs[:, inds]
inner_cov = A.T.dot(inner_cov).dot(A)
inner_cov = np.atleast_2d(A.T.dot(inner_cov).dot(A))
self.transformer_ = _inv_sqrtm(inner_cov).dot(A.T)
else:
self.transformer_ = _inv_sqrtm(inner_cov).T

if M_pca is not None:
self.transformer_ = self.transformer_.dot(M_pca)
self.transformer_ = np.atleast_2d(self.transformer_.dot(M_pca))

return self

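
The reason for these np.atleast_2d guards, in a sketch that is easy to verify interactively: np.cov collapses to a 0-d array when there is a single variable, which breaks downstream code expecting a matrix (the same rationale applies to the sdml.py change below).

    import numpy as np

    single_feature = np.array([[1.0], [2.0], [3.0]])  # one variable, three samples
    c = np.cov(single_feature, rowvar=0, bias=1)
    print(c.ndim)                  # 0 -- a scalar array, not a matrix
    print(np.atleast_2d(c).shape)  # (1, 1) -- safe for matrix operations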
2 changes: 1 addition & 1 deletion metric_learn/sdml.py
@@ -58,7 +58,7 @@ def _fit(self, pairs, y):
# set up prior M
if self.use_cov:
X = np.vstack({tuple(row) for row in pairs.reshape(-1, pairs.shape[2])})
self.M_ = pinvh(np.cov(X, rowvar = False))
self.M_ = pinvh(np.atleast_2d(np.cov(X, rowvar = False)))
else:
self.M_ = np.identity(pairs.shape[2])
diff = pairs[:, 0] - pairs[:, 1]
10 changes: 6 additions & 4 deletions test/metric_learn_test.py
@@ -273,7 +273,7 @@ def test_iris(self):
self.assertLess(csep, 0.15)

# Sanity checks for learned matrices.
self.assertEqual(lfda.metric().shape, (4, 4))
self.assertEqual(lfda.get_mahalanobis_matrix().shape, (4, 4))
self.assertEqual(lfda.transformer_.shape, (2, 4))


@@ -348,14 +348,16 @@ def test_iris(self):
[+0.000868, +0.001468, -0.002021, -0.002879],
[-0.001195, -0.002021, +0.002782, +0.003964],
[-0.001703, -0.002879, +0.003964, +0.005648]]
assert_array_almost_equal(expected, mmc.metric(), decimal=6)
assert_array_almost_equal(expected, mmc.get_mahalanobis_matrix(),
decimal=6)

# Diagonal metric
mmc = MMC(diagonal=True)
mmc.fit(*wrap_pairs(self.iris_points, [a,b,c,d]))
expected = [0, 0, 1.210220, 1.228596]
assert_array_almost_equal(np.diag(expected), mmc.metric(), decimal=6)

assert_array_almost_equal(np.diag(expected), mmc.get_mahalanobis_matrix(),
decimal=6)

# Supervised Full
mmc = MMC_Supervised()
mmc.fit(self.iris_points, self.iris_labels)
82 changes: 82 additions & 0 deletions test/test_base_metric.py
@@ -1,5 +1,10 @@
import pytest
import unittest
import metric_learn
import numpy as np
from sklearn import clone
from sklearn.utils.testing import set_random_state
from test.test_utils import ids_metric_learners, metric_learners


class TestStringRepr(unittest.TestCase):
@@ -81,5 +86,82 @@ def test_mmc(self):
num_labeled='deprecated', preprocessor=None, verbose=False)
""".strip('\n'))


@pytest.mark.parametrize('estimator, build_dataset', metric_learners,
ids=ids_metric_learners)
def test_get_metric_is_independent_from_metric_learner(estimator,
build_dataset):
"""Tests that the get_metric method returns a function that is independent
from the original metric learner"""
input_data, labels, _, X = build_dataset()
model = clone(estimator)
set_random_state(model)

# we fit the metric learner on it and then we compute the metric on some
# points
model.fit(input_data, labels)
metric = model.get_metric()
score = metric(X[0], X[1])

# then we refit the estimator on another dataset
model.fit(np.sin(input_data), labels)

# we recompute the distance between the two points: it should be the same
score_bis = metric(X[0], X[1])
assert score_bis == score


@pytest.mark.parametrize('estimator, build_dataset', metric_learners,
ids=ids_metric_learners)
def test_get_metric_raises_error(estimator, build_dataset):
"""Tests that the metric returned by get_metric raises errors similar to
the distance functions in scipy.spatial.distance"""
input_data, labels, _, X = build_dataset()
model = clone(estimator)
set_random_state(model)
model.fit(input_data, labels)
metric = model.get_metric()

list_test_get_metric_raises = [(X[0].tolist() + [5.2], X[1]), # vectors with
# different dimensions
(X[0:4], X[1:5]), # 2D vectors
(X[0].tolist() + [5.2], X[1].tolist() + [7.2])]
# vectors of same dimension but incompatible with what the metric learner
# was trained on

for u, v in list_test_get_metric_raises:
with pytest.raises(ValueError):
metric(u, v)


@pytest.mark.parametrize('estimator, build_dataset', metric_learners,
ids=ids_metric_learners)
def test_get_metric_works_does_not_raise(estimator, build_dataset):
"""Tests that the metric returned by get_metric does not raise errors (or
warnings) similarly to the distance functions in scipy.spatial.distance"""
input_data, labels, _, X = build_dataset()
model = clone(estimator)
set_random_state(model)
model.fit(input_data, labels)
metric = model.get_metric()

list_test_get_metric_doesnt_raise = [(X[0], X[1]),
(X[0].tolist(), X[1].tolist()),
(X[0][None], X[1][None])]

for u, v in list_test_get_metric_doesnt_raise:
with pytest.warns(None) as record:
metric(u, v)
assert len(record) == 0

# Test that the scalar case works
model.transformer_ = np.array([3.1])
metric = model.get_metric()
for u, v in [(5, 6.7), ([5], [6.7]), ([[5]], [[6.7]])]:
with pytest.warns(None) as record:
metric(u, v)
assert len(record) == 0


if __name__ == '__main__':
unittest.main()