Skip to content

Commit 106cbd2

Browse files
author
William de Vazelhes
committed
Implement scoring functions (and make tests work):
- Make PairsClassifierMixin and QuadrupletsClassifierMixin classes, to implement scoring functions - Implement a new API for supervised wrappers of weakly supervised learning estimators (through the use of base classes, (ex: BaseMMC), from which inherit child classes (ex: MMC and MMC_Supervised) (which is the same idea as in PR scikit-learn-contrib#85 - Delete tests that use tuples learners as transformers (as we do not want to support this behaviour anymore: it is too complicated to allow such different input types (tuples or points) for the same estimator
1 parent 776ab91 commit 106cbd2

File tree

12 files changed

+183
-84
lines changed

12 files changed

+183
-84
lines changed

metric_learn/base_metric.py

Lines changed: 108 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
from numpy.linalg import inv, cholesky
22
from sklearn.base import BaseEstimator, TransformerMixin
33
from sklearn.utils.validation import check_array
4+
from sklearn.metrics import roc_auc_score
5+
import numpy as np
46

57

6-
class BaseMetricLearner(BaseEstimator, TransformerMixin):
8+
class BaseMetricLearner(BaseEstimator):
79
def __init__(self):
810
raise NotImplementedError('BaseMetricLearner should not be instantiated')
911

@@ -19,6 +21,9 @@ def metric(self):
1921
L = self.transformer()
2022
return L.T.dot(L)
2123

24+
25+
class MetricTransformer(TransformerMixin):
26+
2227
def transformer(self):
2328
"""Computes the transformation matrix from the Mahalanobis matrix.
2429
@@ -49,3 +54,105 @@ def transform(self, X=None):
4954
X = check_array(X, accept_sparse=True)
5055
L = self.transformer()
5156
return X.dot(L.T)
57+
58+
59+
class _PairsClassifierMixin:
60+
61+
def predict(self, pairs):
62+
"""Predicts the learned similarity between input pairs.
63+
64+
Returns the learned metric value between samples in every pair. It should
65+
ideally be low for similar samples and high for dissimilar samples.
66+
67+
Parameters
68+
----------
69+
pairs : array-like, shape=(n_constraints, 2, n_features)
70+
A constrained dataset of paired samples.
71+
72+
Returns
73+
-------
74+
y_predicted : `numpy.ndarray` of floats, shape=(n_constraints,)
75+
The predicted learned metric value between samples in every pair.
76+
"""
77+
pairwise_diffs = pairs[:, 0, :] - pairs[:, 1, :]
78+
return np.sqrt(np.sum(pairwise_diffs.dot(self.metric()) * pairwise_diffs,
79+
axis=1))
80+
81+
def decision_function(self, pairs):
82+
return self.predict(pairs)
83+
84+
def score(self, pairs, y):
85+
"""Computes score of pairs similarity prediction.
86+
87+
Returns the ``roc_auc`` score of the fitted metric learner. It is
88+
computed in the following way: for every value of a threshold
89+
``t`` we classify all pairs of samples where the predicted distance is
90+
inferior to ``t`` as belonging to the "similar" class, and the other as
91+
belonging to the "dissimilar" class, and we count false positive and
92+
true positives as in a classical ``roc_auc`` curve.
93+
94+
Parameters
95+
----------
96+
pairs : array-like, shape=(n_constraints, 2, n_features)
97+
Input Pairs.
98+
99+
y : array-like, shape=(n_constraints,)
100+
The corresponding labels.
101+
102+
Returns
103+
-------
104+
score : float
105+
The ``roc_auc`` score.
106+
"""
107+
return roc_auc_score(y, self.decision_function(pairs))
108+
109+
110+
class _QuadrupletsClassifierMixin:
111+
112+
def predict(self, quadruplets):
113+
"""Predicts differences between sample similarities in input quadruplets.
114+
115+
For each quadruplet of samples, computes the difference between the learned
116+
metric of the first pair minus the learned metric of the second pair.
117+
118+
Parameters
119+
----------
120+
quadruplets : array-like, shape=(n_constraints, 4, n_features)
121+
Input quadruplets.
122+
123+
Returns
124+
-------
125+
prediction : np.ndarray of floats, shape=(n_constraints,)
126+
Metric differences.
127+
"""
128+
similar_diffs = quadruplets[:, 0, :] - quadruplets[:, 1, :]
129+
dissimilar_diffs = quadruplets[:, 2, :] - quadruplets[:, 3, :]
130+
return (np.sqrt(np.sum(similar_diffs.dot(self.metric()) *
131+
similar_diffs, axis=1)) -
132+
np.sqrt(np.sum(dissimilar_diffs.dot(self.metric()) *
133+
dissimilar_diffs, axis=1)))
134+
135+
def decision_function(self, quadruplets):
136+
return self.predict(quadruplets)
137+
138+
def score(self, quadruplets, y=None):
139+
"""Computes score on an input constrained dataset
140+
141+
Returns the accuracy score of the following classification task: a record
142+
is correctly classified if the predicted similarity between the first two
143+
samples is higher than that of the last two.
144+
145+
Parameters
146+
----------
147+
quadruplets : array-like, shape=(n_constraints, 4, n_features)
148+
Input quadruplets.
149+
150+
y : Ignored, for scikit-learn compatibility.
151+
152+
Returns
153+
-------
154+
score : float
155+
The quadruplets score.
156+
"""
157+
predicted_sign = self.decision_function(quadruplets) < 0
158+
return np.sum(predicted_sign) / predicted_sign.shape[0]

metric_learn/covariance.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@
1212
import numpy as np
1313
from sklearn.utils.validation import check_array
1414

15-
from .base_metric import BaseMetricLearner
15+
from .base_metric import BaseMetricLearner, MetricTransformer
1616

1717

18-
class Covariance(BaseMetricLearner):
18+
class Covariance(BaseMetricLearner, MetricTransformer):
1919
def __init__(self):
2020
pass
2121

metric_learn/itml.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,13 @@
1919
from sklearn.metrics import pairwise_distances
2020
from sklearn.utils.validation import check_array, check_X_y
2121

22-
from .base_metric import BaseMetricLearner
22+
from .base_metric import (BaseMetricLearner, _PairsClassifierMixin,
23+
MetricTransformer)
2324
from .constraints import Constraints, wrap_pairs
2425
from ._util import vector_norm
2526

2627

27-
class ITML(BaseMetricLearner):
28+
class _BaseITML(BaseMetricLearner):
2829
"""Information Theoretic Metric Learning (ITML)"""
2930
def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3,
3031
A0=None, verbose=False):
@@ -80,8 +81,7 @@ def _process_pairs(self, pairs, y, bounds):
8081
y = y.astype(bool)
8182
return pairs, y
8283

83-
84-
def fit(self, pairs, y, bounds=None):
84+
def _fit(self, pairs, y, bounds=None):
8585
"""Learn the ITML model.
8686
8787
Parameters
@@ -153,7 +153,13 @@ def metric(self):
153153
return self.A_
154154

155155

156-
class ITML_Supervised(ITML):
156+
class ITML(_BaseITML, _PairsClassifierMixin):
157+
158+
def fit(self, pairs, y, bounds=None):
159+
return self._fit(pairs, y, bounds=bounds)
160+
161+
162+
class ITML_Supervised(_BaseITML, MetricTransformer):
157163
"""Information Theoretic Metric Learning (ITML)"""
158164
def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3,
159165
num_labeled=np.inf, num_constraints=None, bounds=None, A0=None,
@@ -177,9 +183,9 @@ def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3,
177183
verbose : bool, optional
178184
if True, prints information while learning
179185
"""
180-
ITML.__init__(self, gamma=gamma, max_iter=max_iter,
181-
convergence_threshold=convergence_threshold,
182-
A0=A0, verbose=verbose)
186+
_BaseITML.__init__(self, gamma=gamma, max_iter=max_iter,
187+
convergence_threshold=convergence_threshold,
188+
A0=A0, verbose=verbose)
183189
self.num_labeled = num_labeled
184190
self.num_constraints = num_constraints
185191
self.bounds = bounds
@@ -209,4 +215,4 @@ def fit(self, X, y, random_state=np.random):
209215
pos_neg = c.positive_negative_pairs(num_constraints,
210216
random_state=random_state)
211217
pairs, y = wrap_pairs(X, pos_neg)
212-
return ITML.fit(self, pairs, y, bounds=self.bounds)
218+
return _BaseITML._fit(self, pairs, y, bounds=self.bounds)

metric_learn/lfda.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@
1818
from sklearn.metrics import pairwise_distances
1919
from sklearn.utils.validation import check_X_y
2020

21-
from .base_metric import BaseMetricLearner
21+
from .base_metric import BaseMetricLearner, MetricTransformer
2222

2323

24-
class LFDA(BaseMetricLearner):
24+
class LFDA(BaseMetricLearner, MetricTransformer):
2525
'''
2626
Local Fisher Discriminant Analysis for Supervised Dimensionality Reduction
2727
Sugiyama, ICML 2006

metric_learn/lmnn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@
1717
from sklearn.utils.validation import check_X_y, check_array
1818
from sklearn.metrics import euclidean_distances
1919

20-
from .base_metric import BaseMetricLearner
20+
from .base_metric import BaseMetricLearner, MetricTransformer
2121

2222

2323
# commonality between LMNN implementations
24-
class _base_LMNN(BaseMetricLearner):
24+
class _base_LMNN(BaseMetricLearner, MetricTransformer):
2525
def __init__(self, k=3, min_iter=50, max_iter=1000, learn_rate=1e-7,
2626
regularization=0.5, convergence_tol=0.001, use_pca=True,
2727
verbose=False):

metric_learn/lsml.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,12 @@
1313
from six.moves import xrange
1414
from sklearn.utils.validation import check_array, check_X_y
1515

16-
from .base_metric import BaseMetricLearner
16+
from .base_metric import (BaseMetricLearner, _QuadrupletsClassifierMixin,
17+
MetricTransformer)
1718
from .constraints import Constraints, wrap_pairs
1819

1920

20-
class LSML(BaseMetricLearner):
21+
class _BaseLSML(BaseMetricLearner):
2122
def __init__(self, tol=1e-3, max_iter=1000, prior=None, verbose=False):
2223
"""Initialize LSML.
2324
@@ -60,7 +61,7 @@ def _prepare_quadruplets(self, quadruplets, weights):
6061
def metric(self):
6162
return self.M_
6263

63-
def fit(self, quadruplets, weights=None):
64+
def _fit(self, quadruplets, weights=None):
6465
"""Learn the LSML model.
6566
6667
Parameters
@@ -140,7 +141,13 @@ def _gradient(self, metric):
140141
return dMetric
141142

142143

143-
class LSML_Supervised(LSML):
144+
class LSML(_BaseLSML, _QuadrupletsClassifierMixin):
145+
146+
def fit(self, quadruplets, weights=None):
147+
return self._fit(quadruplets, weights=weights)
148+
149+
150+
class LSML_Supervised(_BaseLSML, MetricTransformer):
144151
def __init__(self, tol=1e-3, max_iter=1000, prior=None, num_labeled=np.inf,
145152
num_constraints=None, weights=None, verbose=False):
146153
"""Initialize the learner.
@@ -160,8 +167,8 @@ def __init__(self, tol=1e-3, max_iter=1000, prior=None, num_labeled=np.inf,
160167
verbose : bool, optional
161168
if True, prints information while learning
162169
"""
163-
LSML.__init__(self, tol=tol, max_iter=max_iter, prior=prior,
164-
verbose=verbose)
170+
_BaseLSML.__init__(self, tol=tol, max_iter=max_iter, prior=prior,
171+
verbose=verbose)
165172
self.num_labeled = num_labeled
166173
self.num_constraints = num_constraints
167174
self.weights = weights
@@ -189,5 +196,6 @@ def fit(self, X, y, random_state=np.random):
189196
c = Constraints.random_subset(y, self.num_labeled,
190197
random_state=random_state)
191198
pos_neg = c.positive_negative_pairs(num_constraints, same_length=True,
192-
random_state=random_state)
193-
return LSML.fit(self, X[np.column_stack(pos_neg)], weights=self.weights)
199+
random_state=random_state)
200+
return _BaseLSML._fit(self, X[np.column_stack(pos_neg)],
201+
weights=self.weights)

metric_learn/mlkr.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@
1313
from sklearn.decomposition import PCA
1414
from sklearn.utils.validation import check_X_y
1515

16-
from .base_metric import BaseMetricLearner
16+
from .base_metric import BaseMetricLearner, MetricTransformer
1717

1818
EPS = np.finfo(float).eps
1919

2020

21-
class MLKR(BaseMetricLearner):
21+
class MLKR(BaseMetricLearner, MetricTransformer):
2222
"""Metric Learning for Kernel Regression (MLKR)"""
2323
def __init__(self, num_dims=None, A0=None, epsilon=0.01, alpha=0.0001,
2424
max_iter=1000):

metric_learn/mmc.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,13 @@
2222
from sklearn.metrics import pairwise_distances
2323
from sklearn.utils.validation import check_array, check_X_y
2424

25-
from .base_metric import BaseMetricLearner
25+
from .base_metric import (BaseMetricLearner, _PairsClassifierMixin,
26+
MetricTransformer)
2627
from .constraints import Constraints, wrap_pairs
2728
from ._util import vector_norm
2829

2930

30-
31-
class MMC(BaseMetricLearner):
31+
class _BaseMMC(BaseMetricLearner):
3232
"""Mahalanobis Metric for Clustering (MMC)"""
3333
def __init__(self, max_iter=100, max_proj=10000, convergence_threshold=1e-3,
3434
A0=None, diagonal=False, diagonal_c=1.0, verbose=False):
@@ -58,8 +58,7 @@ def __init__(self, max_iter=100, max_proj=10000, convergence_threshold=1e-3,
5858
self.diagonal_c = diagonal_c
5959
self.verbose = verbose
6060

61-
62-
def fit(self, pairs, y):
61+
def _fit(self, pairs, y):
6362
"""Learn the MMC model.
6463
6564
Parameters
@@ -390,7 +389,13 @@ def transformer(self):
390389
return V.T * np.sqrt(np.maximum(0, w[:,None]))
391390

392391

393-
class MMC_Supervised(MMC):
392+
class MMC(_BaseMMC, _PairsClassifierMixin):
393+
394+
def fit(self, pairs, y):
395+
return self._fit(pairs, y)
396+
397+
398+
class MMC_Supervised(_BaseMMC, MetricTransformer):
394399
"""Mahalanobis Metric for Clustering (MMC)"""
395400
def __init__(self, max_iter=100, max_proj=10000, convergence_threshold=1e-6,
396401
num_labeled=np.inf, num_constraints=None,
@@ -418,10 +423,10 @@ def __init__(self, max_iter=100, max_proj=10000, convergence_threshold=1e-6,
418423
verbose : bool, optional
419424
if True, prints information while learning
420425
"""
421-
MMC.__init__(self, max_iter=max_iter, max_proj=max_proj,
422-
convergence_threshold=convergence_threshold,
423-
A0=A0, diagonal=diagonal, diagonal_c=diagonal_c,
424-
verbose=verbose)
426+
_BaseMMC.__init__(self, max_iter=max_iter, max_proj=max_proj,
427+
convergence_threshold=convergence_threshold,
428+
A0=A0, diagonal=diagonal, diagonal_c=diagonal_c,
429+
verbose=verbose)
425430
self.num_labeled = num_labeled
426431
self.num_constraints = num_constraints
427432

@@ -448,4 +453,4 @@ def fit(self, X, y, random_state=np.random):
448453
pos_neg = c.positive_negative_pairs(num_constraints,
449454
random_state=random_state)
450455
pairs, y = wrap_pairs(X, pos_neg)
451-
return MMC.fit(self, pairs, y)
456+
return _BaseMMC._fit(self, pairs, y)

metric_learn/nca.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,12 @@
88
from six.moves import xrange
99
from sklearn.utils.validation import check_X_y
1010

11-
from .base_metric import BaseMetricLearner
11+
from .base_metric import BaseMetricLearner, MetricTransformer
1212

1313
EPS = np.finfo(float).eps
1414

1515

16-
class NCA(BaseMetricLearner):
16+
class NCA(BaseMetricLearner, MetricTransformer):
1717
def __init__(self, num_dims=None, max_iter=100, learning_rate=0.01):
1818
self.num_dims = num_dims
1919
self.max_iter = max_iter

metric_learn/rca.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from sklearn import decomposition
1919
from sklearn.utils.validation import check_array
2020

21-
from .base_metric import BaseMetricLearner
21+
from .base_metric import BaseMetricLearner, MetricTransformer
2222
from .constraints import Constraints
2323

2424

@@ -35,7 +35,7 @@ def _chunk_mean_centering(data, chunks):
3535
return chunk_mask, chunk_data
3636

3737

38-
class RCA(BaseMetricLearner):
38+
class RCA(BaseMetricLearner, MetricTransformer):
3939
"""Relevant Components Analysis (RCA)"""
4040
def __init__(self, num_dims=None, pca_comps=None):
4141
"""Initialize the learner.

0 commit comments

Comments
 (0)