Skip to content

Commit 24b0def

Browse files
authored
Merge pull request #95 from wdevazelhes/feat/api_prediction
[MRG] New API should allow prediction functions and scoring
2 parents 13f1535 + b741a9e commit 24b0def

File tree

12 files changed

+451
-106
lines changed

12 files changed

+451
-106
lines changed

metric_learn/base_metric.py

Lines changed: 108 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1-
from numpy.linalg import inv, cholesky
1+
from numpy.linalg import cholesky
22
from sklearn.base import BaseEstimator, TransformerMixin
33
from sklearn.utils.validation import check_array
4+
from sklearn.metrics import roc_auc_score
5+
import numpy as np
46

57

6-
class BaseMetricLearner(BaseEstimator, TransformerMixin):
8+
class BaseMetricLearner(BaseEstimator):
79
def __init__(self):
810
raise NotImplementedError('BaseMetricLearner should not be instantiated')
911

@@ -30,6 +32,9 @@ def transformer(self):
3032
"""
3133
return cholesky(self.metric()).T
3234

35+
36+
class MetricTransformer(TransformerMixin):
37+
3338
def transform(self, X=None):
3439
"""Applies the metric transformation.
3540
@@ -49,3 +54,104 @@ def transform(self, X=None):
4954
X = check_array(X, accept_sparse=True)
5055
L = self.transformer()
5156
return X.dot(L.T)
57+
58+
59+
class _PairsClassifierMixin:
    """Mixin adding pair prediction and scoring to a metric learner.

    The host class must provide a ``metric()`` method returning the learned
    matrix ``M``; distances are computed as ``sqrt((a - b)^T M (a - b))``.
    """

    def predict(self, pairs):
        """Predicts the learned metric between input pairs.

        Returns the learned metric value between samples in every pair. It
        should ideally be low for similar samples and high for dissimilar
        samples.

        Parameters
        ----------
        pairs : array-like, shape=(n_constraints, 2, n_features)
            Input pairs.

        Returns
        -------
        y_predicted : `numpy.ndarray` of floats, shape=(n_constraints,)
            The predicted learned metric value between samples in every pair.
        """
        # Multi-axis slicing below needs an ndarray; convert so that any
        # array-like (e.g. nested lists) is accepted, as documented.
        pairs = np.asarray(pairs)
        pairwise_diffs = pairs[:, 0, :] - pairs[:, 1, :]
        # Row-wise quadratic form diff^T M diff, then square root.
        return np.sqrt(np.sum(pairwise_diffs.dot(self.metric()) *
                              pairwise_diffs, axis=1))

    def decision_function(self, pairs):
        """Returns the learned distance of every pair (same as ``predict``).

        Parameters
        ----------
        pairs : array-like, shape=(n_constraints, 2, n_features)
            Input pairs.

        Returns
        -------
        distances : `numpy.ndarray` of floats, shape=(n_constraints,)
            The learned metric value of every pair (low means similar).
        """
        return self.predict(pairs)

    def score(self, pairs, y):
        """Computes score of pairs similarity prediction.

        Returns the ``roc_auc`` score of the fitted metric learner. It is
        computed in the following way: for every value of a threshold
        ``t`` we classify all pairs of samples where the predicted distance
        is inferior to ``t`` as belonging to the "similar" class, and the
        other as belonging to the "dissimilar" class, and we count false
        positive and true positives as in a classical ``roc_auc`` curve.

        Parameters
        ----------
        pairs : array-like, shape=(n_constraints, 2, n_features)
            Input Pairs.

        y : array-like, shape=(n_constraints,)
            The corresponding labels. Similar pairs are expected to carry
            the larger label (e.g. 1 for similar, -1 for dissimilar).

        Returns
        -------
        score : float
            The ``roc_auc`` score.
        """
        # roc_auc_score treats larger scores as evidence for the positive
        # (larger-label, i.e. "similar") class, whereas a small distance
        # means "similar". Negate the distances so the ranking matches the
        # thresholding rule described above instead of being inverted.
        return roc_auc_score(y, -self.decision_function(pairs))
108+
109+
110+
class _QuadrupletsClassifierMixin:
    """Mixin adding quadruplet prediction and scoring to a metric learner.

    The host class must provide a ``metric()`` method returning the learned
    matrix ``M``; distances are computed as ``sqrt((a - b)^T M (a - b))``.
    """

    def predict(self, quadruplets):
        """Predicts differences between sample distances in input quadruplets.

        For each quadruplet of samples, computes the difference between the
        learned metric of the first pair minus the learned metric of the
        second pair.

        Parameters
        ----------
        quadruplets : array-like, shape=(n_constraints, 4, n_features)
            Input quadruplets.

        Returns
        -------
        prediction : `numpy.ndarray` of floats, shape=(n_constraints,)
            Metric differences.
        """
        # Multi-axis slicing below needs an ndarray; convert so that any
        # array-like (e.g. nested lists) is accepted, as documented.
        quadruplets = np.asarray(quadruplets)
        # Fetch the learned matrix once and reuse it for both pairs.
        metric = self.metric()
        similar_diffs = quadruplets[:, 0, :] - quadruplets[:, 1, :]
        dissimilar_diffs = quadruplets[:, 2, :] - quadruplets[:, 3, :]
        similar_dists = np.sqrt(
            np.sum(similar_diffs.dot(metric) * similar_diffs, axis=1))
        dissimilar_dists = np.sqrt(
            np.sum(dissimilar_diffs.dot(metric) * dissimilar_diffs, axis=1))
        return similar_dists - dissimilar_dists

    def decision_function(self, quadruplets):
        """Returns the distance differences (same as ``predict``).

        Parameters
        ----------
        quadruplets : array-like, shape=(n_constraints, 4, n_features)
            Input quadruplets.

        Returns
        -------
        differences : `numpy.ndarray` of floats, shape=(n_constraints,)
            Distance of the first pair minus distance of the second pair.
        """
        return self.predict(quadruplets)

    def score(self, quadruplets, y=None):
        """Computes score on input quadruplets

        A quadruplet counts as correctly ordered when the learned distance
        between its first two samples is lower than that between its last
        two. Each quadruplet contributes +1 when correctly ordered, -1 when
        wrongly ordered and 0 on a tie, and the mean contribution is
        returned. The value therefore lies in [-1, 1]; without ties it equals
        ``2 * accuracy - 1``.

        Parameters
        ----------
        quadruplets : array-like, shape=(n_constraints, 4, n_features)
            Input quadruplets.

        y : Ignored, for scikit-learn compatibility.

        Returns
        -------
        score : float
            The quadruplets score.
        """
        # A negative decision value means "first pair closer", i.e. correct,
        # hence the leading minus sign.
        return - np.mean(np.sign(self.decision_function(quadruplets)))

metric_learn/covariance.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@
1212
import numpy as np
1313
from sklearn.utils.validation import check_array
1414

15-
from .base_metric import BaseMetricLearner
15+
from .base_metric import BaseMetricLearner, MetricTransformer
1616

1717

18-
class Covariance(BaseMetricLearner):
18+
class Covariance(BaseMetricLearner, MetricTransformer):
1919
def __init__(self):
2020
pass
2121

metric_learn/itml.py

Lines changed: 31 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,13 @@
1919
from sklearn.metrics import pairwise_distances
2020
from sklearn.utils.validation import check_array, check_X_y
2121

22-
from .base_metric import BaseMetricLearner
22+
from .base_metric import (BaseMetricLearner, _PairsClassifierMixin,
23+
MetricTransformer)
2324
from .constraints import Constraints, wrap_pairs
2425
from ._util import vector_norm
2526

2627

27-
class ITML(BaseMetricLearner):
28+
class _BaseITML(BaseMetricLearner):
2829
"""Information Theoretic Metric Learning (ITML)"""
2930
def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3,
3031
A0=None, verbose=False):
@@ -78,24 +79,7 @@ def _process_pairs(self, pairs, y, bounds):
7879
y = np.hstack([np.ones(len(pos_pairs)), - np.ones(len(neg_pairs))])
7980
return pairs, y
8081

81-
82-
def fit(self, pairs, y, bounds=None):
83-
"""Learn the ITML model.
84-
85-
Parameters
86-
----------
87-
pairs: array-like, shape=(n_constraints, 2, n_features)
88-
Array of pairs. Each row corresponds to two points.
89-
y: array-like, of shape (n_constraints,)
90-
Labels of constraints. Should be -1 for dissimilar pair, 1 for similar.
91-
bounds : list (pos,neg) pairs, optional
92-
bounds on similarity, s.t. d(X[a],X[b]) < pos and d(X[c],X[d]) > neg
93-
94-
Returns
95-
-------
96-
self : object
97-
Returns the instance.
98-
"""
82+
def _fit(self, pairs, y, bounds=None):
9983
pairs, y = self._process_pairs(pairs, y, bounds)
10084
gamma = self.gamma
10185
pos_pairs, neg_pairs = pairs[y == 1], pairs[y == -1]
@@ -151,7 +135,29 @@ def metric(self):
151135
return self.A_
152136

153137

154-
class ITML_Supervised(ITML):
138+
class ITML(_BaseITML, _PairsClassifierMixin):
    """ITML learner fitted directly on labelled pairs of points."""

    def fit(self, pairs, y, bounds=None):
        """Fit the ITML model on pair constraints.

        Parameters
        ----------
        pairs : array-like, shape=(n_constraints, 2, n_features)
            Pairs of points; each entry holds the two points of one
            constraint.
        y : array-like, shape=(n_constraints,)
            Constraint labels: 1 for a similar pair, -1 for a dissimilar one.
        bounds : list (pos, neg) pairs, optional
            Bounds on similarity, such that d(X[a], X[b]) < pos and
            d(X[c], X[d]) > neg.

        Returns
        -------
        self : object
            The result of the shared ``_fit`` routine (documented as the
            fitted instance).
        """
        # All the work is done by the training routine on the base class.
        fitted = self._fit(pairs, y, bounds=bounds)
        return fitted
158+
159+
160+
class ITML_Supervised(_BaseITML, MetricTransformer):
155161
"""Information Theoretic Metric Learning (ITML)"""
156162
def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3,
157163
num_labeled=np.inf, num_constraints=None, bounds=None, A0=None,
@@ -175,9 +181,9 @@ def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3,
175181
verbose : bool, optional
176182
if True, prints information while learning
177183
"""
178-
ITML.__init__(self, gamma=gamma, max_iter=max_iter,
179-
convergence_threshold=convergence_threshold,
180-
A0=A0, verbose=verbose)
184+
_BaseITML.__init__(self, gamma=gamma, max_iter=max_iter,
185+
convergence_threshold=convergence_threshold,
186+
A0=A0, verbose=verbose)
181187
self.num_labeled = num_labeled
182188
self.num_constraints = num_constraints
183189
self.bounds = bounds
@@ -207,4 +213,4 @@ def fit(self, X, y, random_state=np.random):
207213
pos_neg = c.positive_negative_pairs(num_constraints,
208214
random_state=random_state)
209215
pairs, y = wrap_pairs(X, pos_neg)
210-
return ITML.fit(self, pairs, y, bounds=self.bounds)
216+
return _BaseITML._fit(self, pairs, y, bounds=self.bounds)

metric_learn/lfda.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@
1818
from sklearn.metrics import pairwise_distances
1919
from sklearn.utils.validation import check_X_y
2020

21-
from .base_metric import BaseMetricLearner
21+
from .base_metric import BaseMetricLearner, MetricTransformer
2222

2323

24-
class LFDA(BaseMetricLearner):
24+
class LFDA(BaseMetricLearner, MetricTransformer):
2525
'''
2626
Local Fisher Discriminant Analysis for Supervised Dimensionality Reduction
2727
Sugiyama, ICML 2006

metric_learn/lmnn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@
1717
from sklearn.utils.validation import check_X_y, check_array
1818
from sklearn.metrics import euclidean_distances
1919

20-
from .base_metric import BaseMetricLearner
20+
from .base_metric import BaseMetricLearner, MetricTransformer
2121

2222

2323
# commonality between LMNN implementations
24-
class _base_LMNN(BaseMetricLearner):
24+
class _base_LMNN(BaseMetricLearner, MetricTransformer):
2525
def __init__(self, k=3, min_iter=50, max_iter=1000, learn_rate=1e-7,
2626
regularization=0.5, convergence_tol=0.001, use_pca=True,
2727
verbose=False):

metric_learn/lsml.py

Lines changed: 34 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,12 @@
1313
from six.moves import xrange
1414
from sklearn.utils.validation import check_array, check_X_y
1515

16-
from .base_metric import BaseMetricLearner
17-
from .constraints import Constraints, wrap_pairs
16+
from .base_metric import (BaseMetricLearner, _QuadrupletsClassifierMixin,
17+
MetricTransformer)
18+
from .constraints import Constraints
1819

1920

20-
class LSML(BaseMetricLearner):
21+
class _BaseLSML(BaseMetricLearner):
2122
def __init__(self, tol=1e-3, max_iter=1000, prior=None, verbose=False):
2223
"""Initialize LSML.
2324
@@ -60,24 +61,7 @@ def _prepare_quadruplets(self, quadruplets, weights):
6061
def metric(self):
6162
return self.M_
6263

63-
def fit(self, quadruplets, weights=None):
64-
"""Learn the LSML model.
65-
66-
Parameters
67-
----------
68-
quadruplets : array-like, shape=(n_constraints, 4, n_features)
69-
Each row corresponds to 4 points. In order to supervise the
70-
algorithm in the right way, we should have the four samples ordered
71-
in a way such that: d(pairs[i, 0],X[i, 1]) < d(X[i, 2], X[i, 3])
72-
for all 0 <= i < n_constraints.
73-
weights : (n_constraints,) array of floats, optional
74-
scale factor for each constraint
75-
76-
Returns
77-
-------
78-
self : object
79-
Returns the instance.
80-
"""
64+
def _fit(self, quadruplets, weights=None):
8165
self._prepare_quadruplets(quadruplets, weights)
8266
step_sizes = np.logspace(-10, 0, 10)
8367
# Keep track of the best step size and the loss at that step.
@@ -140,7 +124,30 @@ def _gradient(self, metric):
140124
return dMetric
141125

142126

143-
class LSML_Supervised(LSML):
127+
class LSML(_BaseLSML, _QuadrupletsClassifierMixin):
    """LSML learner fitted directly on quadruplets of points."""

    def fit(self, quadruplets, weights=None):
        """Fit the LSML model on quadruplet constraints.

        Parameters
        ----------
        quadruplets : array-like, shape=(n_constraints, 4, n_features)
            Each entry holds 4 points. To supervise the algorithm correctly,
            the four samples must be ordered such that
            d(quadruplets[i, 0], quadruplets[i, 1])
            < d(quadruplets[i, 2], quadruplets[i, 3])
            for all 0 <= i < n_constraints.
        weights : (n_constraints,) array of floats, optional
            Scale factor for each constraint.

        Returns
        -------
        self : object
            The result of the shared ``_fit`` routine (documented as the
            fitted instance).
        """
        # All the work is done by the training routine on the base class.
        fitted = self._fit(quadruplets, weights=weights)
        return fitted
148+
149+
150+
class LSML_Supervised(_BaseLSML, MetricTransformer):
144151
def __init__(self, tol=1e-3, max_iter=1000, prior=None, num_labeled=np.inf,
145152
num_constraints=None, weights=None, verbose=False):
146153
"""Initialize the learner.
@@ -160,8 +167,8 @@ def __init__(self, tol=1e-3, max_iter=1000, prior=None, num_labeled=np.inf,
160167
verbose : bool, optional
161168
if True, prints information while learning
162169
"""
163-
LSML.__init__(self, tol=tol, max_iter=max_iter, prior=prior,
164-
verbose=verbose)
170+
_BaseLSML.__init__(self, tol=tol, max_iter=max_iter, prior=prior,
171+
verbose=verbose)
165172
self.num_labeled = num_labeled
166173
self.num_constraints = num_constraints
167174
self.weights = weights
@@ -189,5 +196,6 @@ def fit(self, X, y, random_state=np.random):
189196
c = Constraints.random_subset(y, self.num_labeled,
190197
random_state=random_state)
191198
pos_neg = c.positive_negative_pairs(num_constraints, same_length=True,
192-
random_state=random_state)
193-
return LSML.fit(self, X[np.column_stack(pos_neg)], weights=self.weights)
199+
random_state=random_state)
200+
return _BaseLSML._fit(self, X[np.column_stack(pos_neg)],
201+
weights=self.weights)

metric_learn/mlkr.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@
1313
from sklearn.decomposition import PCA
1414
from sklearn.utils.validation import check_X_y
1515

16-
from .base_metric import BaseMetricLearner
16+
from .base_metric import BaseMetricLearner, MetricTransformer
1717

1818
EPS = np.finfo(float).eps
1919

2020

21-
class MLKR(BaseMetricLearner):
21+
class MLKR(BaseMetricLearner, MetricTransformer):
2222
"""Metric Learning for Kernel Regression (MLKR)"""
2323
def __init__(self, num_dims=None, A0=None, epsilon=0.01, alpha=0.0001,
2424
max_iter=1000):

0 commit comments

Comments
 (0)