
Commit 34c159e

Author: William de Vazelhes (authored and committed)
Revert "[MRG] Move transformer_from_metric to util (scikit-learn-contrib#151)"
This reverts commit 8ffd998.
1 parent f60a8d7 commit 34c159e

File tree: 7 files changed (+44, −47 lines)


metric_learn/_util.py

Lines changed: 0 additions & 27 deletions
@@ -322,30 +322,3 @@ def check_collapsed_pairs(pairs):
     raise ValueError("{} collapsed pairs found (where the left element is "
                      "the same as the right element), out of {} pairs "
                      "in total.".format(num_ident, pairs.shape[0]))
-
-
-def transformer_from_metric(metric):
-  """Computes the transformation matrix from the Mahalanobis matrix.
-
-  Since by definition the metric `M` is positive semi-definite (PSD), it
-  admits a Cholesky decomposition: L = cholesky(M).T. However, currently the
-  computation of the Cholesky decomposition used does not support
-  non-definite matrices. If the metric is not definite, this method will
-  return L = V.T w^( -1/2), with M = V*w*V.T being the eigenvector
-  decomposition of M with the eigenvalues in the diagonal matrix w and the
-  columns of V being the eigenvectors. If M is diagonal, this method will
-  just return its elementwise square root (since the diagonalization of
-  the matrix is itself).
-
-  Returns
-  -------
-  L : (d x d) matrix
-  """
-
-  if np.allclose(metric, np.diag(np.diag(metric))):
-    return np.sqrt(metric)
-  elif not np.isclose(np.linalg.det(metric), 0):
-    return np.linalg.cholesky(metric).T
-  else:
-    w, V = np.linalg.eigh(metric)
-    return V.T * np.sqrt(np.maximum(0, w[:, None]))

metric_learn/base_metric.py

Lines changed: 27 additions & 0 deletions
@@ -1,3 +1,4 @@
+from numpy.linalg import cholesky
 from sklearn.base import BaseEstimator
 from sklearn.utils.validation import _is_arraylike
 from sklearn.metrics import roc_auc_score
@@ -180,6 +181,32 @@ def transform(self, X):
   def metric(self):
     return self.transformer_.T.dot(self.transformer_)
 
+  def transformer_from_metric(self, metric):
+    """Computes the transformation matrix from the Mahalanobis matrix.
+
+    Since by definition the metric `M` is positive semi-definite (PSD), it
+    admits a Cholesky decomposition: L = cholesky(M).T. However, currently the
+    computation of the Cholesky decomposition used does not support
+    non-definite matrices. If the metric is not definite, this method will
+    return L = V.T w^( -1/2), with M = V*w*V.T being the eigenvector
+    decomposition of M with the eigenvalues in the diagonal matrix w and the
+    columns of V being the eigenvectors. If M is diagonal, this method will
+    just return its elementwise square root (since the diagonalization of
+    the matrix is itself).
+
+    Returns
+    -------
+    L : (d x d) matrix
+    """
+
+    if np.allclose(metric, np.diag(np.diag(metric))):
+      return np.sqrt(metric)
+    elif not np.isclose(np.linalg.det(metric), 0):
+      return cholesky(metric).T
+    else:
+      w, V = np.linalg.eigh(metric)
+      return V.T * np.sqrt(np.maximum(0, w[:, None]))
+
 
 class _PairsClassifierMixin(BaseMetricLearner):
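
The method added above has three branches, and each returns a factor L such that L.T.dot(L) recovers the input metric M (in the eigendecomposition branch the value actually returned is diag(w)**0.5 times V.T). The following standalone numpy sketch, not part of the commit, reproduces the same logic and checks that property for all three branches:

import numpy as np
from numpy.linalg import cholesky

def transformer_from_metric(metric):
    # Same logic as the method added to MahalanobisMixin above.
    if np.allclose(metric, np.diag(np.diag(metric))):
        return np.sqrt(metric)                    # diagonal metric: elementwise sqrt
    elif not np.isclose(np.linalg.det(metric), 0):
        return cholesky(metric).T                 # definite metric: Cholesky factor
    else:
        w, V = np.linalg.eigh(metric)             # PSD but singular: eigendecomposition
        return V.T * np.sqrt(np.maximum(0, w[:, None]))

rng = np.random.RandomState(42)
M_diag = np.diag([4., 1., 0.25])                  # exercises the diagonal branch
A = rng.randn(3, 3)
M_def = A.dot(A.T) + 3 * np.eye(3)                # exercises the Cholesky branch
B = rng.randn(3, 2)
M_sing = B.dot(B.T)                               # rank-deficient PSD, eigh branch

for M in (M_diag, M_def, M_sing):
    L = transformer_from_metric(M)
    print(np.allclose(L.T.dot(L), M))             # True in all three cases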

metric_learn/covariance.py

Lines changed: 2 additions & 3 deletions
@@ -13,7 +13,6 @@
 from sklearn.base import TransformerMixin
 
 from .base_metric import MahalanobisMixin
-from ._util import transformer_from_metric
 
 
 class Covariance(MahalanobisMixin, TransformerMixin):
@@ -23,7 +22,7 @@ class Covariance(MahalanobisMixin, TransformerMixin):
   ----------
   transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
     The linear transformation ``L`` deduced from the learned Mahalanobis
-    metric (See function `transformer_from_metric`.)
+    metric (See :meth:`transformer_from_metric`.)
   """
 
   def __init__(self, preprocessor=None):
@@ -41,5 +40,5 @@ def fit(self, X, y=None):
     else:
       M = np.linalg.inv(M)
 
-    self.transformer_ = transformer_from_metric(np.atleast_2d(M))
+    self.transformer_ = self.transformer_from_metric(np.atleast_2d(M))
     return self
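
As a brief usage sketch (assuming the public Covariance API at this commit; the iris data is only illustrative), the documented relationship between transformer_ and the learned metric can be checked after fit:

import numpy as np
from sklearn.datasets import load_iris
from metric_learn import Covariance

X, _ = load_iris(return_X_y=True)
cov = Covariance().fit(X)          # M is the inverse of the data covariance, per fit above

L = cov.transformer_               # documented transformer_ attribute
M = cov.metric()                   # Mahalanobis matrix, defined as transformer_.T.dot(transformer_)
print(np.allclose(L.T.dot(L), M))  # True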

metric_learn/itml.py

Lines changed: 4 additions & 4 deletions
@@ -22,7 +22,7 @@
 from sklearn.base import TransformerMixin
 from .base_metric import _PairsClassifierMixin, MahalanobisMixin
 from .constraints import Constraints, wrap_pairs
-from ._util import vector_norm, transformer_from_metric
+from ._util import vector_norm
 
 
 class _BaseITML(MahalanobisMixin):
@@ -125,7 +125,7 @@ def _fit(self, pairs, y, bounds=None):
       print('itml converged at iter: %d, conv = %f' % (it, conv))
     self.n_iter_ = it
 
-    self.transformer_ = transformer_from_metric(self.A_)
+    self.transformer_ = self.transformer_from_metric(self.A_)
     return self
 
 
@@ -136,7 +136,7 @@ class ITML(_BaseITML, _PairsClassifierMixin):
   ----------
   transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
     The linear transformation ``L`` deduced from the learned Mahalanobis
-    metric (See function `transformer_from_metric`.)
+    metric (See :meth:`transformer_from_metric`.)
   """
 
   def fit(self, pairs, y, bounds=None):
@@ -169,7 +169,7 @@ class ITML_Supervised(_BaseITML, TransformerMixin):
   ----------
   transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
     The linear transformation ``L`` deduced from the learned Mahalanobis
-    metric (See function `transformer_from_metric`.)
+    metric (See `transformer_from_metric`.)
   """
 
   def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3,
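
The change is transparent to users of the supervised wrapper. A hedged usage sketch, assuming the ITML_Supervised defaults shown (truncated) in the constructor above and using iris purely as example data:

import numpy as np
from sklearn.datasets import load_iris
from metric_learn import ITML_Supervised

X, y = load_iris(return_X_y=True)
itml = ITML_Supervised()                        # defaults from the constructor above
itml.fit(X, y)                                  # similarity constraints are built from the labels

L = itml.transformer_                           # linear transformation documented above
print(np.allclose(itml.metric(), L.T.dot(L)))   # True: metric() is transformer_.T.dot(transformer_)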

metric_learn/lsml.py

Lines changed: 3 additions & 4 deletions
@@ -16,7 +16,6 @@
 
 from .base_metric import _QuadrupletsClassifierMixin, MahalanobisMixin
 from .constraints import Constraints
-from ._util import transformer_from_metric
 
 
 class _BaseLSML(MahalanobisMixin):
@@ -102,7 +101,7 @@ def _fit(self, quadruplets, y=None, weights=None):
       print("Didn't converge after", it, "iterations. Final loss:", s_best)
     self.n_iter_ = it
 
-    self.transformer_ = transformer_from_metric(self.M_)
+    self.transformer_ = self.transformer_from_metric(self.M_)
     return self
 
   def _comparison_loss(self, metric):
@@ -138,7 +137,7 @@ class LSML(_BaseLSML, _QuadrupletsClassifierMixin):
   ----------
   transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
     The linear transformation ``L`` deduced from the learned Mahalanobis
-    metric (See function `transformer_from_metric`.)
+    metric (See :meth:`transformer_from_metric`.)
   """
 
   def fit(self, quadruplets, weights=None):
@@ -171,7 +170,7 @@ class LSML_Supervised(_BaseLSML, TransformerMixin):
   ----------
   transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
     The linear transformation ``L`` deduced from the learned Mahalanobis
-    metric (See function `transformer_from_metric`.)
+    metric (See :meth:`transformer_from_metric`.)
   """
 
   def __init__(self, tol=1e-3, max_iter=1000, prior=None,

metric_learn/mmc.py

Lines changed: 5 additions & 5 deletions
@@ -25,7 +25,7 @@
 
 from .base_metric import _PairsClassifierMixin, MahalanobisMixin
 from .constraints import Constraints, wrap_pairs
-from ._util import vector_norm, transformer_from_metric
+from ._util import vector_norm
 
 
 class _BaseMMC(MahalanobisMixin):
@@ -206,7 +206,7 @@ def _fit_full(self, pairs, y):
       self.A_[:] = A_old
     self.n_iter_ = cycle
 
-    self.transformer_ = transformer_from_metric(self.A_)
+    self.transformer_ = self.transformer_from_metric(self.A_)
     return self
 
   def _fit_diag(self, pairs, y):
@@ -267,7 +267,7 @@ def _fit_diag(self, pairs, y):
 
     self.A_ = np.diag(w)
 
-    self.transformer_ = transformer_from_metric(self.A_)
+    self.transformer_ = self.transformer_from_metric(self.A_)
     return self
 
   def _fD(self, neg_pairs, A):
@@ -355,7 +355,7 @@ class MMC(_BaseMMC, _PairsClassifierMixin):
   ----------
   transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
     The linear transformation ``L`` deduced from the learned Mahalanobis
-    metric (See function `transformer_from_metric`.)
+    metric (See :meth:`transformer_from_metric`.)
   """
 
   def fit(self, pairs, y):
@@ -386,7 +386,7 @@ class MMC_Supervised(_BaseMMC, TransformerMixin):
   ----------
   transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
     The linear transformation ``L`` deduced from the learned Mahalanobis
-    metric (See function `transformer_from_metric`.)
+    metric (See :meth:`transformer_from_metric`.)
   """
 
   def __init__(self, max_iter=100, max_proj=10000, convergence_threshold=1e-6,
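
The _fit_diag path above produces a diagonal Mahalanobis matrix A_, for which transformer_from_metric reduces to an elementwise square root. The short numpy sketch below (illustrative values only, not taken from the library) shows that this diagonal factor rescales feature i by sqrt(w[i]) and that the Mahalanobis distance under A equals the Euclidean distance after applying L:

import numpy as np

w = np.array([0.5, 2.0, 4.0])             # illustrative nonnegative diagonal weights
A = np.diag(w)                            # diagonal Mahalanobis matrix, as produced by _fit_diag
L = np.sqrt(A)                            # diagonal branch of transformer_from_metric
print(np.allclose(L.T.dot(L), A))         # True

x = np.array([1., 0., 2.])
y = np.array([0., 1., 1.])
d_mahalanobis = np.sqrt((x - y).dot(A).dot(x - y))
d_transformed = np.linalg.norm(L.dot(x) - L.dot(y))
print(np.isclose(d_mahalanobis, d_transformed))   # True: L turns the metric into plain Euclidean distance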

metric_learn/sdml.py

Lines changed: 3 additions & 4 deletions
@@ -17,7 +17,6 @@
 
 from .base_metric import MahalanobisMixin, _PairsClassifierMixin
 from .constraints import Constraints, wrap_pairs
-from ._util import transformer_from_metric
 
 
 class _BaseSDML(MahalanobisMixin):
@@ -66,7 +65,7 @@ def _fit(self, pairs, y):
     emp_cov = self.M_ + self.balance_param * loss_matrix
     _, self.M_ = graph_lasso(emp_cov, self.sparsity_param, verbose=self.verbose)
 
-    self.transformer_ = transformer_from_metric(self.M_)
+    self.transformer_ = self.transformer_from_metric(self.M_)
     return self
 
 
@@ -77,7 +76,7 @@ class SDML(_BaseSDML, _PairsClassifierMixin):
   ----------
   transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
     The linear transformation ``L`` deduced from the learned Mahalanobis
-    metric (See function `transformer_from_metric`.)
+    metric (See :meth:`transformer_from_metric`.)
   """
 
   def fit(self, pairs, y):
@@ -108,7 +107,7 @@ class SDML_Supervised(_BaseSDML, TransformerMixin):
   ----------
   transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
     The linear transformation ``L`` deduced from the learned Mahalanobis
-    metric (See function `transformer_from_metric`.)
+    metric (See :meth:`transformer_from_metric`.)
   """
 
   def __init__(self, balance_param=0.5, sparsity_param=0.01, use_cov=True,
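
Every file above follows the same pattern: the learner's _fit computes a Mahalanobis matrix and then calls the inherited self.transformer_from_metric(...) instead of importing a helper from _util. The toy estimator below is hypothetical (the class name and its trivial "learning" rule are invented for illustration, and it assumes MahalanobisMixin can be subclassed directly, as the estimators above do):

import numpy as np
from sklearn.base import TransformerMixin
from metric_learn.base_metric import MahalanobisMixin


class IdentityMetricLearner(MahalanobisMixin, TransformerMixin):
    """Hypothetical learner that 'learns' the identity Mahalanobis matrix."""

    def __init__(self, preprocessor=None):
        self.preprocessor = preprocessor

    def fit(self, X, y=None):
        X = np.asarray(X, dtype=float)
        M = np.eye(X.shape[1])                               # trivial Mahalanobis matrix
        self.transformer_ = self.transformer_from_metric(M)  # inherited from MahalanobisMixin
        return self


learner = IdentityMetricLearner().fit(np.random.randn(10, 3))
print(learner.metric())   # identity, recovered as transformer_.T.dot(transformer_)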
