Skip to content

Commit 810d191

Browse files
author
William de Vazelhes
committed
MAINT: Address #96 (review)
- replace embed with transform and always pass the input X when calling it - move _transformer_from_metric into the shared Mahalanobis mixin so that it no longer needs to be overridden in MMC - improve test_mahalanobis_mixin.test_score_pairs_pairwise according to #96 (comment) - improve test_mahalanobis_mixin.check_is_distance_matrix - correct typos and nitpicks
1 parent eff278e commit 810d191

15 files changed

+125
-108
lines changed

examples/sandwich.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def sandwich_demo():
3030

3131
for ax_num, ml in enumerate(mls, start=3):
3232
ml.fit(x, y)
33-
tx = ml.transform()
33+
tx = ml.transform(x)
3434
ml_knn = nearest_neighbors(tx, k=2)
3535
ax = plt.subplot(3, 2, ax_num)
3636
plot_sandwich_data(tx, y, axis=ax)

metric_learn/base_metric.py

Lines changed: 49 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from numpy.linalg import cholesky
2-
from sklearn.base import BaseEstimator, TransformerMixin
2+
from sklearn.base import BaseEstimator
33
from sklearn.utils.validation import check_array
44
from sklearn.metrics import roc_auc_score
55
import numpy as np
@@ -28,9 +28,9 @@ def score_pairs(self, pairs):
2828
"""
2929

3030

31-
class MetricTransformer(TransformerMixin):
31+
class MetricTransformer():
3232

33-
def transform(self, X=None):
33+
def transform(self, X):
3434
"""Applies the metric transformation.
3535
3636
Parameters
@@ -43,15 +43,10 @@ def transform(self, X=None):
4343
transformed : (n x d) matrix
4444
Input data transformed to the metric space by :math:`XL^{\\top}`
4545
"""
46-
if X is None:
47-
X = self.X_
48-
else:
49-
X = check_array(X, accept_sparse=True)
50-
L = self.transformer_
51-
return X.dot(L.T)
5246

5347

54-
class MahalanobisMixin(six.with_metaclass(ABCMeta, BaseMetricLearner)):
48+
class MahalanobisMixin(six.with_metaclass(ABCMeta, BaseMetricLearner,
49+
MetricTransformer)):
5550
"""Mahalanobis metric learning algorithms.
5651
5752
Algorithm that learns a Mahalanobis (pseudo) distance :math:`d_M(x, x')`,
@@ -91,12 +86,12 @@ def score_pairs(self, pairs):
9186
scores: `numpy.ndarray` of shape=(n_pairs,)
9287
The learned Mahalanobis distance for every pair.
9388
"""
94-
pairwise_diffs = self.embed(pairs[..., 1, :] - pairs[..., 0, :]) # (for
95-
# MahalanobisMixin, the embedding is linear so we can just embed the
89+
pairwise_diffs = self.transform(pairs[..., 1, :] - pairs[..., 0, :])
90+
# (for MahalanobisMixin, the embedding is linear so we can just embed the
9691
# difference)
9792
return np.sqrt(np.sum(pairwise_diffs**2, axis=-1))
9893

99-
def embed(self, X):
94+
def transform(self, X):
10095
"""Embeds data points in the learned linear embedding space.
10196
10297
Transforms samples in ``X`` into ``X_embedded``, samples inside a new
@@ -113,21 +108,37 @@ def embed(self, X):
113108
X_embedded : `numpy.ndarray`, shape=(n_samples, num_dims)
114109
The embedded data points.
115110
"""
116-
return X.dot(self.transformer_.T)
111+
X_checked = check_array(X, accept_sparse=True, ensure_2d=False)
112+
return X_checked.dot(self.transformer_.T)
117113

118114
def metric(self):
119115
return self.transformer_.T.dot(self.transformer_)
120116

121-
def transformer_from_metric(self, metric):
117+
def _transformer_from_metric(self, metric):
122118
"""Computes the transformation matrix from the Mahalanobis matrix.
123119
124-
L = cholesky(M).T
120+
Since by definition the metric `M` is positive semi-definite (PSD), it
121+
admits a Cholesky decomposition: L = cholesky(M).T. However, currently the
122+
computation of the Cholesky decomposition used does not support
123+
non-definite matrices. If the metric is not definite, this method will
124+
return L = w^(1/2) V.T, with M = V*w*V.T being the eigenvector
125+
decomposition of M with the eigenvalues in the diagonal matrix w and the
126+
columns of V being the eigenvectors. If M is diagonal, this method will
127+
just return its elementwise square root (since the diagonalization of
128+
the matrix is itself).
125129
126130
Returns
127131
-------
128-
L : upper triangular (d x d) matrix
132+
L : (d x d) matrix
129133
"""
130-
return cholesky(metric).T
134+
135+
if np.allclose(metric, np.diag(np.diag(metric))):
136+
return np.sqrt(metric)
137+
elif not np.isclose(np.linalg.det(metric), 0):
138+
return cholesky(metric).T
139+
else:
140+
w, V = np.linalg.eigh(metric)
141+
return V.T * np.sqrt(np.maximum(0, w[:, None]))
131142

132143

133144
class _PairsClassifierMixin(BaseMetricLearner):
@@ -182,6 +193,24 @@ def score(self, pairs, y):
182193
class _QuadrupletsClassifierMixin(BaseMetricLearner):
183194

184195
def predict(self, quadruplets):
196+
"""Predicts the ordering between sample distances in input quadruplets.
197+
198+
For each quadruplet, returns -1 if the quadruplet is in the right order (
199+
first pair is more similar than second pair), and 1 if not.
200+
201+
Parameters
202+
----------
203+
quadruplets : array-like, shape=(n_constraints, 4, n_features)
204+
Input quadruplets.
205+
206+
Returns
207+
-------
208+
prediction : `numpy.ndarray` of floats, shape=(n_constraints,)
209+
Predictions of the ordering of pairs, for each quadruplet.
210+
"""
211+
return np.sign(self.decision_function(quadruplets))
212+
213+
def decision_function(self, quadruplets):
185214
"""Predicts differences between sample distances in input quadruplets.
186215
187216
For each quadruplet of samples, computes the difference between the learned
@@ -194,15 +223,12 @@ def predict(self, quadruplets):
194223
195224
Returns
196225
-------
197-
prediction : `numpy.ndarray` of floats, shape=(n_constraints,)
226+
decision_function : `numpy.ndarray` of floats, shape=(n_constraints,)
198227
Metric differences.
199228
"""
200229
return (self.score_pairs(quadruplets[..., :2, :]) -
201230
self.score_pairs(quadruplets[..., 2:, :]))
202231

203-
def decision_function(self, quadruplets):
204-
return self.predict(quadruplets)
205-
206232
def score(self, quadruplets, y=None):
207233
"""Computes score on input quadruplets
208234
@@ -222,4 +248,4 @@ def score(self, quadruplets, y=None):
222248
score : float
223249
The quadruplets score.
224250
"""
225-
return - np.mean(np.sign(self.decision_function(quadruplets)))
251+
return - np.mean(self.predict(quadruplets))

metric_learn/covariance.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,12 @@
1111
from __future__ import absolute_import
1212
import numpy as np
1313
from sklearn.utils.validation import check_array
14+
from sklearn.base import TransformerMixin
1415

15-
from .base_metric import MahalanobisMixin, MetricTransformer
16+
from .base_metric import MahalanobisMixin
1617

1718

18-
class Covariance(MetricTransformer, MahalanobisMixin):
19+
class Covariance(MahalanobisMixin, TransformerMixin):
1920
def __init__(self):
2021
pass
2122

@@ -31,5 +32,5 @@ def fit(self, X, y=None):
3132
else:
3233
self.M_ = np.linalg.inv(self.M_)
3334

34-
self.transformer_ = self.transformer_from_metric(check_array(self.M_))
35+
self.transformer_ = self._transformer_from_metric(check_array(self.M_))
3536
return self

metric_learn/itml.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
from six.moves import xrange
1919
from sklearn.metrics import pairwise_distances
2020
from sklearn.utils.validation import check_array, check_X_y
21-
from .base_metric import (_PairsClassifierMixin, MetricTransformer,
22-
MahalanobisMixin)
21+
from sklearn.base import TransformerMixin
22+
from .base_metric import _PairsClassifierMixin, MahalanobisMixin
2323
from .constraints import Constraints, wrap_pairs
2424
from ._util import vector_norm
2525

@@ -53,7 +53,7 @@ def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3,
5353

5454
def _process_pairs(self, pairs, y, bounds):
5555
pairs, y = check_X_y(pairs, y, accept_sparse=False,
56-
ensure_2d=False, allow_nd=True)
56+
ensure_2d=False, allow_nd=True)
5757

5858
# check to make sure that no two constrained vectors are identical
5959
pos_pairs, neg_pairs = pairs[y == 1], pairs[y == -1]
@@ -129,7 +129,7 @@ def _fit(self, pairs, y, bounds=None):
129129
print('itml converged at iter: %d, conv = %f' % (it, conv))
130130
self.n_iter_ = it
131131

132-
self.transformer_ = self.transformer_from_metric(self.A_)
132+
self.transformer_ = self._transformer_from_metric(self.A_)
133133
return self
134134

135135

@@ -155,7 +155,7 @@ def fit(self, pairs, y, bounds=None):
155155
return self._fit(pairs, y, bounds=bounds)
156156

157157

158-
class ITML_Supervised(_BaseITML, MetricTransformer):
158+
class ITML_Supervised(_BaseITML, TransformerMixin):
159159
"""Information Theoretic Metric Learning (ITML)"""
160160
def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3,
161161
num_labeled=np.inf, num_constraints=None, bounds=None, A0=None,

metric_learn/lfda.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@
1717
from six.moves import xrange
1818
from sklearn.metrics import pairwise_distances
1919
from sklearn.utils.validation import check_X_y
20+
from sklearn.base import TransformerMixin
21+
from .base_metric import MahalanobisMixin
2022

21-
from .base_metric import MahalanobisMixin, MetricTransformer
2223

23-
24-
class LFDA(MahalanobisMixin, MetricTransformer):
24+
class LFDA(MahalanobisMixin, TransformerMixin):
2525
'''
2626
Local Fisher Discriminant Analysis for Supervised Dimensionality Reduction
2727
Sugiyama, ICML 2006

metric_learn/lmnn.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,12 @@
1616
from six.moves import xrange
1717
from sklearn.utils.validation import check_X_y, check_array
1818
from sklearn.metrics import euclidean_distances
19-
20-
from .base_metric import MahalanobisMixin, MetricTransformer
19+
from sklearn.base import TransformerMixin
20+
from .base_metric import MahalanobisMixin
2121

2222

2323
# commonality between LMNN implementations
24-
class _base_LMNN(MahalanobisMixin, MetricTransformer):
24+
class _base_LMNN(MahalanobisMixin, TransformerMixin):
2525
def __init__(self, k=3, min_iter=50, max_iter=1000, learn_rate=1e-7,
2626
regularization=0.5, convergence_tol=0.001, use_pca=True,
2727
verbose=False):
@@ -189,7 +189,7 @@ def _select_targets(self):
189189
return target_neighbors
190190

191191
def _find_impostors(self, furthest_neighbors):
192-
Lx = self.transform()
192+
Lx = self.transform(self.X_)
193193
margin_radii = 1 + _inplace_paired_L2(Lx[furthest_neighbors], Lx)
194194
impostors = []
195195
for label in self.labels_[:-1]:
@@ -256,7 +256,7 @@ def fit(self, X, y):
256256
self._lmnn.train()
257257
else:
258258
self._lmnn.train(np.eye(X.shape[1]))
259-
self.L_ = self._lmnn.get_linear_transform()
259+
self.L_ = self._lmnn.get_linear_transform(X)
260260
return self
261261

262262
except ImportError:

metric_learn/lsml.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,10 @@
1111
import numpy as np
1212
import scipy.linalg
1313
from six.moves import xrange
14-
14+
from sklearn.base import TransformerMixin
1515
from sklearn.utils.validation import check_array, check_X_y
1616

17-
from .base_metric import (_QuadrupletsClassifierMixin, MetricTransformer,
18-
MahalanobisMixin)
17+
from .base_metric import _QuadrupletsClassifierMixin, MahalanobisMixin
1918
from .constraints import Constraints
2019

2120

@@ -95,7 +94,7 @@ def _fit(self, quadruplets, weights=None):
9594
print("Didn't converge after", it, "iterations. Final loss:", s_best)
9695
self.n_iter_ = it
9796

98-
self.transformer_ = self.transformer_from_metric(self.M_)
97+
self.transformer_ = self._transformer_from_metric(self.M_)
9998
return self
10099

101100
def _comparison_loss(self, metric):
@@ -147,7 +146,7 @@ def fit(self, quadruplets, weights=None):
147146
return self._fit(quadruplets, weights=weights)
148147

149148

150-
class LSML_Supervised(_BaseLSML, MetricTransformer):
149+
class LSML_Supervised(_BaseLSML, TransformerMixin):
151150
def __init__(self, tol=1e-3, max_iter=1000, prior=None, num_labeled=np.inf,
152151
num_constraints=None, weights=None, verbose=False):
153152
"""Initialize the learner.

metric_learn/mlkr.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,17 @@
1010
import numpy as np
1111
from scipy.optimize import minimize
1212
from scipy.spatial.distance import pdist, squareform
13+
from sklearn.base import TransformerMixin
1314
from sklearn.decomposition import PCA
1415

1516
from sklearn.utils.validation import check_X_y
1617

17-
from .base_metric import MahalanobisMixin, MetricTransformer
18+
from .base_metric import MahalanobisMixin
1819

1920
EPS = np.finfo(float).eps
2021

2122

22-
class MLKR(MahalanobisMixin, MetricTransformer):
23+
class MLKR(MahalanobisMixin, TransformerMixin):
2324
"""Metric Learning for Kernel Regression (MLKR)"""
2425
def __init__(self, num_dims=None, A0=None, epsilon=0.01, alpha=0.0001,
2526
max_iter=1000):

metric_learn/mmc.py

Lines changed: 5 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,10 @@
1919
from __future__ import print_function, absolute_import, division
2020
import numpy as np
2121
from six.moves import xrange
22-
22+
from sklearn.base import TransformerMixin
2323
from sklearn.utils.validation import check_array, check_X_y
2424

25-
from .base_metric import (_PairsClassifierMixin, MahalanobisMixin,
26-
MetricTransformer)
25+
from .base_metric import _PairsClassifierMixin, MahalanobisMixin
2726
from .constraints import Constraints, wrap_pairs
2827
from ._util import vector_norm
2928

@@ -215,7 +214,7 @@ def _fit_full(self, pairs, y):
215214
self.A_[:] = A_old
216215
self.n_iter_ = cycle
217216

218-
self.transformer_ = self.transformer_from_metric(self.A_)
217+
self.transformer_ = self._transformer_from_metric(self.A_)
219218
return self
220219

221220
def _fit_diag(self, pairs, y):
@@ -275,7 +274,7 @@ def _fit_diag(self, pairs, y):
275274

276275
self.A_ = np.diag(w)
277276

278-
self.transformer_ = self.transformer_from_metric(self.A_)
277+
self.transformer_ = self._transformer_from_metric(self.A_)
279278
return self
280279

281280
def _fD(self, neg_pairs, A):
@@ -355,24 +354,6 @@ def _D_constraint(self, neg_pairs, w):
355354
sum_deri2 / sum_dist - np.outer(sum_deri1, sum_deri1) / (sum_dist * sum_dist)
356355
)
357356

358-
def transformer_from_metric(self, metric):
359-
"""Computes the transformation matrix from the Mahalanobis matrix.
360-
L = V.T * w^(-1/2), with A = V*w*V.T being the eigenvector decomposition of A with
361-
the eigenvalues in the diagonal matrix w and the columns of V being the eigenvectors.
362-
363-
The Cholesky decomposition cannot be applied here, since MMC learns only a positive
364-
*semi*-definite Mahalanobis matrix.
365-
366-
Returns
367-
-------
368-
L : (d x d) matrix
369-
"""
370-
if self.diagonal:
371-
return np.sqrt(metric)
372-
else:
373-
w, V = np.linalg.eigh(metric)
374-
return V.T * np.sqrt(np.maximum(0, w[:, None]))
375-
376357

377358
class MMC(_BaseMMC, _PairsClassifierMixin):
378359

@@ -394,7 +375,7 @@ def fit(self, pairs, y):
394375
return self._fit(pairs, y)
395376

396377

397-
class MMC_Supervised(_BaseMMC, MetricTransformer):
378+
class MMC_Supervised(_BaseMMC, TransformerMixin):
398379
"""Mahalanobis Metric for Clustering (MMC)"""
399380
def __init__(self, max_iter=100, max_proj=10000, convergence_threshold=1e-6,
400381
num_labeled=np.inf, num_constraints=None,

metric_learn/nca.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,15 @@
66
from __future__ import absolute_import
77
import numpy as np
88
from six.moves import xrange
9+
from sklearn.base import TransformerMixin
910
from sklearn.utils.validation import check_X_y
1011

11-
from .base_metric import MahalanobisMixin, MetricTransformer
12+
from .base_metric import MahalanobisMixin
1213

1314
EPS = np.finfo(float).eps
1415

1516

16-
class NCA(MahalanobisMixin, MetricTransformer):
17+
class NCA(MahalanobisMixin, TransformerMixin):
1718
def __init__(self, num_dims=None, max_iter=100, learning_rate=0.01):
1819
self.num_dims = num_dims
1920
self.max_iter = max_iter

0 commit comments

Comments
 (0)