-
Notifications
You must be signed in to change notification settings - Fork 229
[MRG] Refactor the metric() method #152
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 17 commits
a3384b1
8e0d197
6dd118e
c7e40f6
1947ea5
bee6902
646cf97
00d37c9
c9eefb4
bd6aac0
22141f5
9e447f6
201320b
4b660fa
61a33cc
72153ed
d943406
d2c0614
5e29295
92669ae
a2955e0
c8708b2
7d4efd9
0c7c5dc
7dfd874
80c2943
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,13 @@ | ||
from numpy.linalg import cholesky | ||
from scipy.spatial.distance import euclidean | ||
from sklearn.base import BaseEstimator | ||
from sklearn.utils.validation import _is_arraylike | ||
from sklearn.metrics import roc_auc_score | ||
import numpy as np | ||
from abc import ABCMeta, abstractmethod | ||
import six | ||
from ._util import ArrayIndexer, check_input | ||
from ._util import ArrayIndexer, check_input, validate_vector | ||
import warnings | ||
|
||
|
||
class BaseMetricLearner(six.with_metaclass(ABCMeta, BaseEstimator)): | ||
|
@@ -34,6 +37,14 @@ def score_pairs(self, pairs): | |
------- | ||
scores: `numpy.ndarray` of shape=(n_pairs,) | ||
The score of every pair. | ||
|
||
See Also | ||
-------- | ||
get_metric : a method that returns a function to compute the metric between | ||
two points. The difference with `score_pairs` is that it works on two 1D | ||
arrays and cannot use a preprocessor. Besides, the returned function is | ||
independent of the metric learner and hence is not modified if the metric | ||
learner is. | ||
""" | ||
|
||
def check_preprocessor(self): | ||
|
@@ -85,6 +96,47 @@ def _prepare_inputs(self, X, y=None, type_of_inputs='classic', | |
tuple_size=getattr(self, '_tuple_size', None), | ||
**kwargs) | ||
|
||
@abstractmethod | ||
def get_metric(self): | ||
"""Returns a function that takes as input two 1D arrays and outputs the | ||
learned metric score on these two points. | ||
|
||
This function will be independent from the metric learner that learned it | ||
(it will not be modified if the initial metric learner is modified), | ||
and it can be directly plugged into the `metric` argument of | ||
scikit-learn's estimators. | ||
|
||
Returns | ||
------- | ||
metric_fun : function | ||
The function described above. | ||
|
||
|
||
Examples | ||
-------- | ||
.. doctest:: | ||
|
||
>>> from metric_learn import NCA | ||
>>> from sklearn.datasets import make_classification | ||
>>> from sklearn.neighbors import KNeighborsClassifier | ||
>>> nca = NCA() | ||
>>> X, y = make_classification() | ||
>>> nca.fit(X, y) | ||
>>> knn = KNeighborsClassifier(metric=nca.get_metric()) | ||
>>> knn.fit(X, y) # doctest: +NORMALIZE_WHITESPACE | ||
KNeighborsClassifier(algorithm='auto', leaf_size=30, | ||
metric=<function MahalanobisMixin.get_metric.<locals>.metric_fun | ||
at 0x...>, | ||
metric_params=None, n_jobs=None, n_neighbors=5, p=2, | ||
weights='uniform') | ||
|
||
See Also | ||
-------- | ||
score_pairs : a method that returns the metric score between several pairs | ||
of points. Unlike `get_metric`, this is a method of the metric learner | ||
and therefore can change if the metric learner changes. Besides, it can | ||
use the metric learner's preprocessor, and works on concatenated arrays. | ||
""" | ||
|
||
class MetricTransformer(six.with_metaclass(ABCMeta)): | ||
|
||
|
@@ -146,6 +198,17 @@ def score_pairs(self, pairs): | |
------- | ||
scores: `numpy.ndarray` of shape=(n_pairs,) | ||
The learned Mahalanobis distance for every pair. | ||
|
||
See Also | ||
-------- | ||
get_metric : a method that returns a function to compute the metric between | ||
two points. The difference with `score_pairs` is that it works on two 1D | ||
arrays and cannot use a preprocessor. Besides, the returned function is | ||
independent of the metric learner and hence is not modified if the metric | ||
learner is. | ||
|
||
:ref:`mahalanobis_distances` : The section of the project documentation | ||
that describes Mahalanobis Distances. | ||
""" | ||
pairs = check_input(pairs, type_of_inputs='tuples', | ||
preprocessor=self.preprocessor_, | ||
|
@@ -177,9 +240,87 @@ def transform(self, X): | |
accept_sparse=True) | ||
return X_checked.dot(self.transformer_.T) | ||
|
||
def get_metric(self): | ||
"""Returns a function that takes as input two 1D arrays and outputs the | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It would be nice to only write this docstring once. Unfortunately Python doesn't do inherited docstrings, so we have to assign it manually, by setting the overriding method's `__doc__` to the base-class method's `__doc__`.
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Cool, thanks! I always wondered how to avoid copying docstrings. |
||
learned metric score on these two points. | ||
|
||
This function will be independent from the metric learner that learned it | ||
(it will not be modified if the initial metric learner is modified), | ||
and it can be directly plugged into the `metric` argument of | ||
scikit-learn's estimators. | ||
|
||
Returns | ||
------- | ||
metric_fun : function | ||
The function described above. | ||
|
||
Examples | ||
-------- | ||
.. doctest:: | ||
|
||
>>> from metric_learn import NCA | ||
>>> from sklearn.datasets import make_classification | ||
>>> from sklearn.neighbors import KNeighborsClassifier | ||
>>> nca = NCA() | ||
>>> X, y = make_classification() | ||
>>> nca.fit(X, y) | ||
>>> knn = KNeighborsClassifier(metric=nca.get_metric()) | ||
>>> knn.fit(X, y) # doctest: +NORMALIZE_WHITESPACE | ||
KNeighborsClassifier(algorithm='auto', leaf_size=30, | ||
metric=<function MahalanobisMixin.get_metric.<locals>.metric_fun | ||
at 0x...>, | ||
metric_params=None, n_jobs=None, n_neighbors=5, p=2, | ||
weights='uniform') | ||
|
||
See Also | ||
-------- | ||
score_pairs : a method that returns the metric score between several pairs | ||
of points. Unlike `get_metric`, this is a method of the metric learner | ||
and therefore can change if the metric learner changes. Besides, it can | ||
use the metric learner's preprocessor, and works on concatenated arrays. | ||
|
||
:ref:`mahalanobis_distances` : The section of the project documentation | ||
that describes Mahalanobis Distances. | ||
""" | ||
transformer_T = self.transformer_.T.copy() | ||
|
||
def metric_fun(u, v): | ||
"""This function computes the metric between u and v, according to the | ||
previously learned metric. | ||
|
||
Parameters | ||
---------- | ||
u : array-like, shape=(n_features,) | ||
The first point involved in the distance computation. | ||
v : array-like, shape=(n_features,) | ||
The second point involved in the distance computation. | ||
Returns | ||
------- | ||
distance: float | ||
The distance between u and v according to the new metric. | ||
""" | ||
u = validate_vector(u) | ||
v = validate_vector(v) | ||
return euclidean(u.dot(transformer_T), v.dot(transformer_T)) | ||
return metric_fun | ||
|
||
def metric(self): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Maybe we should keep this for now but mark it as deprecated, and point to the new `get_mahalanobis_matrix` method. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes, I agree, I forgot about that. |
||
# TODO: remove this method in version 0.6.0 | ||
warnings.warn(("`metric` is deprecated since version 0.5.0 and will be " | ||
"removed in 0.6.0. Use `get_mahalanobis_matrix` instead."), | ||
DeprecationWarning) | ||
return self.transformer_.T.dot(self.transformer_) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Might as well call `self.get_mahalanobis_matrix()` here. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. That's right, thanks. |
||
|
||
def get_mahalanobis_matrix(self): | ||
"""Returns a copy of the Mahalanobis matrix learned by the metric learner. | ||
|
||
Returns | ||
------- | ||
M : `numpy.ndarray`, shape=(n_features, n_features) | ||
The copy of the learned Mahalanobis matrix. | ||
""" | ||
return self.transformer_.T.dot(self.transformer_).copy() | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. There's no need for a `.copy()` here, since `dot` already returns a newly allocated array. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. That's right, thanks, I don't know why I left the copy there... |
||
|
||
|
||
class _PairsClassifierMixin(BaseMetricLearner): | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same updates as above