[MRG] New API should allow prediction functions and scoring #95
Changes from 6 commits
776ab91
106cbd2
237d467
c124ee6
2dae03e
a70d1a8
b741a9e
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,11 @@ | ||
from numpy.linalg import inv, cholesky | ||
from numpy.linalg import cholesky | ||
from sklearn.base import BaseEstimator, TransformerMixin | ||
from sklearn.utils.validation import check_array | ||
from sklearn.metrics import roc_auc_score | ||
import numpy as np | ||
|
||
|
||
class BaseMetricLearner(BaseEstimator, TransformerMixin): | ||
class BaseMetricLearner(BaseEstimator): | ||
def __init__(self): | ||
raise NotImplementedError('BaseMetricLearner should not be instantiated') | ||
|
||
|
@@ -30,6 +32,9 @@ def transformer(self): | |
""" | ||
return cholesky(self.metric()).T | ||
|
||
|
||
class MetricTransformer(TransformerMixin): | ||
|
||
def transform(self, X=None): | ||
"""Applies the metric transformation. | ||
|
||
|
@@ -49,3 +54,105 @@ def transform(self, X=None): | |
X = check_array(X, accept_sparse=True) | ||
L = self.transformer() | ||
return X.dot(L.T) | ||
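Note on `transformer()` and `transform()`: `cholesky` returns a lower-triangular factor `C` with `M = C C^T`, so `L = C.T` satisfies `M = L^T L`, and Euclidean distances between rows of `X.dot(L.T)` match the distances under `M`. The following is a minimal sketch of that identity on made-up data; the matrix `M` and the array `X` are illustrative assumptions, not values from this PR.

```python
import numpy as np
from numpy.linalg import cholesky

# Illustrative data only: a random symmetric positive definite metric M.
rng = np.random.default_rng(0)
A = rng.normal(size=(3, 3))
M = A.T @ A + 1e-3 * np.eye(3)

# Same construction as transformer() in the diff: L satisfies M == L.T @ L.
L = cholesky(M).T

# Same projection as transform() in the diff.
X = rng.normal(size=(5, 3))
X_embedded = X.dot(L.T)

# Distance under M between two samples, computed directly...
diff = X[0] - X[1]
d_metric = np.sqrt(diff @ M @ diff)

# ...and as a plain Euclidean distance in the embedded space.
d_embedded = np.linalg.norm(X_embedded[0] - X_embedded[1])
```

Both `d_metric` and `d_embedded` should agree up to floating-point error, which is why the embedding can stand in for the metric in downstream estimators.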
|
||
|
||
class _PairsClassifierMixin: | ||
|
||
def predict(self, pairs): | ||
"""Predicts the learned similarity between input pairs. | ||
|
||
Returns the learned metric value between samples in every pair. It should | ||
ideally be low for similar samples and high for dissimilar samples. | ||
|
||
Parameters | ||
---------- | ||
pairs : array-like, shape=(n_constraints, 2, n_features) | ||
A constrained dataset of paired samples. | ||
|
||
Returns | ||
------- | ||
y_predicted : `numpy.ndarray` of floats, shape=(n_constraints,) | ||
The predicted learned metric value between samples in every pair. | ||
""" | ||
pairwise_diffs = pairs[:, 0, :] - pairs[:, 1, :] | ||
return np.sqrt(np.sum(pairwise_diffs.dot(self.metric()) * pairwise_diffs, | ||
axis=1)) | ||
|
||
def decision_function(self, pairs): | ||
return self.predict(pairs) | ||
|
||
def score(self, pairs, y): | ||
"""Computes score of pairs similarity prediction. | ||
|
||
Returns the ``roc_auc`` score of the fitted metric learner. It is | ||
computed in the following way: for every value of a threshold | ||
``t`` we classify all pairs of samples where the predicted distance is | ||
inferior to ``t`` as belonging to the "similar" class, and the other as | ||
belonging to the "dissimilar" class, and we count false positive and | ||
true positives as in a classical ``roc_auc`` curve. | ||
|
||
Parameters | ||
---------- | ||
pairs : array-like, shape=(n_constraints, 2, n_features) | ||
Input Pairs. | ||
|
||
y : array-like, shape=(n_constraints,) | ||
The corresponding labels. | ||
|
||
Returns | ||
------- | ||
score : float | ||
The ``roc_auc`` score. | ||
""" | ||
return roc_auc_score(y, self.decision_function(pairs)) | ||
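Note on the sign convention: `roc_auc_score` treats higher scores as evidence for the positive class, while `decision_function` here returns a distance that is low for similar pairs. The diff does not state how `y` is encoded, so the sketch below assumes (hypothetically) that `1` marks a similar pair and `0` a dissimilar one, and shows how the AUC flips depending on whether the distance or its negation is passed as the score.

```python
import numpy as np
from sklearn.metrics import roc_auc_score

# Assumed label convention (not stated in the PR): 1 = similar, 0 = dissimilar.
y = np.array([1, 1, 0, 0])

# Made-up learned distances: small for the similar pairs, large for the others.
distances = np.array([0.2, 0.5, 0.9, 1.4])

# Raw distances rank similar pairs lowest, so every pair is ordered wrongly.
auc_raw = roc_auc_score(y, distances)       # 0.0

# Negating the distance turns "small distance" into "high score".
auc_negated = roc_auc_score(y, -distances)  # 1.0
```

Under the opposite encoding (positive label for dissimilar pairs), passing the raw distance would be the consistent choice, so the label convention is worth documenting in the `score` docstring.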
|
||
|
||
class _QuadrupletsClassifierMixin: | ||
|
||
def predict(self, quadruplets): | ||
"""Predicts differences between sample similarities in input quadruplets. | ||
distances?
Yes, thanks |
||
|
||
For each quadruplet of samples, computes the difference between the learned | ||
metric of the first pair minus the learned metric of the second pair. | ||
|
||
Parameters | ||
---------- | ||
quadruplets : array-like, shape=(n_constraints, 4, n_features) | ||
Input quadruplets. | ||
|
||
Returns | ||
------- | ||
prediction : np.ndarray of floats, shape=(n_constraints,) | ||
Metric differences. | ||
""" | ||
similar_diffs = quadruplets[:, 0, :] - quadruplets[:, 1, :] | ||
dissimilar_diffs = quadruplets[:, 2, :] - quadruplets[:, 3, :] | ||
return (np.sqrt(np.sum(similar_diffs.dot(self.metric()) * | ||
similar_diffs, axis=1)) - | ||
np.sqrt(np.sum(dissimilar_diffs.dot(self.metric()) * | ||
dissimilar_diffs, axis=1))) | ||
This pattern, distance under some metric, seems like it should be factored out.
Yes indeed, the function will call function
(this should ultimately be in the Mahalanobis Mixin) |
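A rough sketch of what factoring out the distance computation could look like. The function names `mahalanobis_pair_distances`, `predict_pairs`, and `predict_quadruplets` are hypothetical and chosen for illustration; they are not the names used in this PR or in the library, and the metric is passed explicitly rather than read from `self.metric()`.

```python
import numpy as np


def mahalanobis_pair_distances(first, second, metric):
    """Row-wise sqrt((a - b)^T M (a - b)); hypothetical helper name."""
    diffs = np.asarray(first) - np.asarray(second)
    return np.sqrt(np.sum(diffs.dot(metric) * diffs, axis=1))


def predict_pairs(pairs, metric):
    """Distance between the two samples of each pair."""
    pairs = np.asarray(pairs)
    return mahalanobis_pair_distances(pairs[:, 0, :], pairs[:, 1, :], metric)


def predict_quadruplets(quadruplets, metric):
    """Distance of the first pair minus distance of the second pair."""
    q = np.asarray(quadruplets)
    return (mahalanobis_pair_distances(q[:, 0, :], q[:, 1, :], metric)
            - mahalanobis_pair_distances(q[:, 2, :], q[:, 3, :], metric))


# Usage with an identity metric, where the distance is plain Euclidean.
identity = np.eye(2)
pairs = np.array([[[0.0, 0.0], [3.0, 4.0]]])
quads = np.array([[[0.0, 0.0], [3.0, 4.0], [0.0, 0.0], [0.0, 1.0]]])
pair_result = predict_pairs(pairs, identity)        # array([5.])
quad_result = predict_quadruplets(quads, identity)  # array([4.])
```

With a single helper, both mixins share one definition of the distance, which would make a later move into a Mahalanobis-specific mixin a one-place change.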
||
|
||
def decision_function(self, quadruplets): | ||
return self.predict(quadruplets) | ||
|
||
def score(self, quadruplets, y=None): | ||
"""Computes score on an input constrained dataset | ||
|
||
Returns the accuracy score of the following classification task: a record | ||
is correctly classified if the predicted similarity between the first two | ||
samples is higher than that of the last two. | ||
|
||
Parameters | ||
---------- | ||
quadruplets : array-like, shape=(n_constraints, 4, n_features) | ||
Input quadruplets. | ||
|
||
y : Ignored, for scikit-learn compatibility. | ||
|
||
Returns | ||
------- | ||
score : float | ||
The quadruplets score. | ||
""" | ||
predicted_sign = self.decision_function(quadruplets) < 0 | ||
return np.sum(predicted_sign) / predicted_sign.shape[0] | ||
Why not
Much cleaner indeed, thanks! |
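The expression suggested in the comment above did not survive in this capture. As one possible simplification, offered as an assumption and not as the reviewer's actual suggestion, the fraction of `True` values in a boolean array can be taken with `np.mean`. The decision values below are made up for illustration.

```python
import numpy as np

# Made-up decision values: negative means the first pair is closer.
decision_values = np.array([-0.5, 0.3, -1.2, -0.1])
predicted_sign = decision_values < 0

# Form used in the diff: count of True values divided by the number of rows.
score_sum = np.sum(predicted_sign) / predicted_sign.shape[0]

# Equivalent form: the mean of a boolean array is the fraction of True values.
score_mean = np.mean(predicted_sign)
```

Here three of the four quadruplets are ordered correctly, so both expressions evaluate to 0.75.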
should be metric instead of similarity here
Yes indeed, thanks