Skip to content

Commit 8ffd998

Browse files
belletperimosocordiae
authored andcommitted
[MRG] Move transformer_from_metric to util (#151)
* move method to util and update classes accordingly * remove forgotten self * typo
1 parent d00196d commit 8ffd998

File tree

7 files changed

+47
-44
lines changed

7 files changed

+47
-44
lines changed

metric_learn/_util.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,3 +322,30 @@ def check_collapsed_pairs(pairs):
322322
raise ValueError("{} collapsed pairs found (where the left element is "
323323
"the same as the right element), out of {} pairs "
324324
"in total.".format(num_ident, pairs.shape[0]))
325+
326+
327+
def transformer_from_metric(metric):
328+
"""Computes the transformation matrix from the Mahalanobis matrix.
329+
330+
Since by definition the metric `M` is positive semi-definite (PSD), it
331+
admits a Cholesky decomposition: L = cholesky(M).T. However, currently the
332+
computation of the Cholesky decomposition used does not support
333+
non-definite matrices. If the metric is not definite, this method will
334+
return L = V.T w^( -1/2), with M = V*w*V.T being the eigenvector
335+
decomposition of M with the eigenvalues in the diagonal matrix w and the
336+
columns of V being the eigenvectors. If M is diagonal, this method will
337+
just return its elementwise square root (since the diagonalization of
338+
the matrix is itself).
339+
340+
Returns
341+
-------
342+
L : (d x d) matrix
343+
"""
344+
345+
if np.allclose(metric, np.diag(np.diag(metric))):
346+
return np.sqrt(metric)
347+
elif not np.isclose(np.linalg.det(metric), 0):
348+
return np.linalg.cholesky(metric).T
349+
else:
350+
w, V = np.linalg.eigh(metric)
351+
return V.T * np.sqrt(np.maximum(0, w[:, None]))

metric_learn/base_metric.py

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from numpy.linalg import cholesky
21
from sklearn.base import BaseEstimator
32
from sklearn.utils.validation import _is_arraylike
43
from sklearn.metrics import roc_auc_score
@@ -181,32 +180,6 @@ def transform(self, X):
181180
def metric(self):
182181
return self.transformer_.T.dot(self.transformer_)
183182

184-
def transformer_from_metric(self, metric):
185-
"""Computes the transformation matrix from the Mahalanobis matrix.
186-
187-
Since by definition the metric `M` is positive semi-definite (PSD), it
188-
admits a Cholesky decomposition: L = cholesky(M).T. However, currently the
189-
computation of the Cholesky decomposition used does not support
190-
non-definite matrices. If the metric is not definite, this method will
191-
return L = V.T w^( -1/2), with M = V*w*V.T being the eigenvector
192-
decomposition of M with the eigenvalues in the diagonal matrix w and the
193-
columns of V being the eigenvectors. If M is diagonal, this method will
194-
just return its elementwise square root (since the diagonalization of
195-
the matrix is itself).
196-
197-
Returns
198-
-------
199-
L : (d x d) matrix
200-
"""
201-
202-
if np.allclose(metric, np.diag(np.diag(metric))):
203-
return np.sqrt(metric)
204-
elif not np.isclose(np.linalg.det(metric), 0):
205-
return cholesky(metric).T
206-
else:
207-
w, V = np.linalg.eigh(metric)
208-
return V.T * np.sqrt(np.maximum(0, w[:, None]))
209-
210183

211184
class _PairsClassifierMixin(BaseMetricLearner):
212185

metric_learn/covariance.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from sklearn.base import TransformerMixin
1414

1515
from .base_metric import MahalanobisMixin
16+
from ._util import transformer_from_metric
1617

1718

1819
class Covariance(MahalanobisMixin, TransformerMixin):
@@ -22,7 +23,7 @@ class Covariance(MahalanobisMixin, TransformerMixin):
2223
----------
2324
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
2425
The linear transformation ``L`` deduced from the learned Mahalanobis
25-
metric (See :meth:`transformer_from_metric`.)
26+
metric (See function `transformer_from_metric`.)
2627
"""
2728

2829
def __init__(self, preprocessor=None):
@@ -40,5 +41,5 @@ def fit(self, X, y=None):
4041
else:
4142
M = np.linalg.inv(M)
4243

43-
self.transformer_ = self.transformer_from_metric(np.atleast_2d(M))
44+
self.transformer_ = transformer_from_metric(np.atleast_2d(M))
4445
return self

metric_learn/itml.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from sklearn.base import TransformerMixin
2323
from .base_metric import _PairsClassifierMixin, MahalanobisMixin
2424
from .constraints import Constraints, wrap_pairs
25-
from ._util import vector_norm
25+
from ._util import vector_norm, transformer_from_metric
2626

2727

2828
class _BaseITML(MahalanobisMixin):
@@ -125,7 +125,7 @@ def _fit(self, pairs, y, bounds=None):
125125
print('itml converged at iter: %d, conv = %f' % (it, conv))
126126
self.n_iter_ = it
127127

128-
self.transformer_ = self.transformer_from_metric(self.A_)
128+
self.transformer_ = transformer_from_metric(self.A_)
129129
return self
130130

131131

@@ -136,7 +136,7 @@ class ITML(_BaseITML, _PairsClassifierMixin):
136136
----------
137137
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
138138
The linear transformation ``L`` deduced from the learned Mahalanobis
139-
metric (See :meth:`transformer_from_metric`.)
139+
metric (See function `transformer_from_metric`.)
140140
"""
141141

142142
def fit(self, pairs, y, bounds=None):
@@ -169,7 +169,7 @@ class ITML_Supervised(_BaseITML, TransformerMixin):
169169
----------
170170
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
171171
The linear transformation ``L`` deduced from the learned Mahalanobis
172-
metric (See `transformer_from_metric`.)
172+
metric (See function `transformer_from_metric`.)
173173
"""
174174

175175
def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3,

metric_learn/lsml.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from .base_metric import _QuadrupletsClassifierMixin, MahalanobisMixin
1818
from .constraints import Constraints
19+
from ._util import transformer_from_metric
1920

2021

2122
class _BaseLSML(MahalanobisMixin):
@@ -101,7 +102,7 @@ def _fit(self, quadruplets, y=None, weights=None):
101102
print("Didn't converge after", it, "iterations. Final loss:", s_best)
102103
self.n_iter_ = it
103104

104-
self.transformer_ = self.transformer_from_metric(self.M_)
105+
self.transformer_ = transformer_from_metric(self.M_)
105106
return self
106107

107108
def _comparison_loss(self, metric):
@@ -137,7 +138,7 @@ class LSML(_BaseLSML, _QuadrupletsClassifierMixin):
137138
----------
138139
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
139140
The linear transformation ``L`` deduced from the learned Mahalanobis
140-
metric (See :meth:`transformer_from_metric`.)
141+
metric (See function `transformer_from_metric`.)
141142
"""
142143

143144
def fit(self, quadruplets, weights=None):
@@ -170,7 +171,7 @@ class LSML_Supervised(_BaseLSML, TransformerMixin):
170171
----------
171172
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
172173
The linear transformation ``L`` deduced from the learned Mahalanobis
173-
metric (See :meth:`transformer_from_metric`.)
174+
metric (See function `transformer_from_metric`.)
174175
"""
175176

176177
def __init__(self, tol=1e-3, max_iter=1000, prior=None,

metric_learn/mmc.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626
from .base_metric import _PairsClassifierMixin, MahalanobisMixin
2727
from .constraints import Constraints, wrap_pairs
28-
from ._util import vector_norm
28+
from ._util import vector_norm, transformer_from_metric
2929

3030

3131
class _BaseMMC(MahalanobisMixin):
@@ -206,7 +206,7 @@ def _fit_full(self, pairs, y):
206206
self.A_[:] = A_old
207207
self.n_iter_ = cycle
208208

209-
self.transformer_ = self.transformer_from_metric(self.A_)
209+
self.transformer_ = transformer_from_metric(self.A_)
210210
return self
211211

212212
def _fit_diag(self, pairs, y):
@@ -267,7 +267,7 @@ def _fit_diag(self, pairs, y):
267267

268268
self.A_ = np.diag(w)
269269

270-
self.transformer_ = self.transformer_from_metric(self.A_)
270+
self.transformer_ = transformer_from_metric(self.A_)
271271
return self
272272

273273
def _fD(self, neg_pairs, A):
@@ -355,7 +355,7 @@ class MMC(_BaseMMC, _PairsClassifierMixin):
355355
----------
356356
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
357357
The linear transformation ``L`` deduced from the learned Mahalanobis
358-
metric (See :meth:`transformer_from_metric`.)
358+
metric (See function `transformer_from_metric`.)
359359
"""
360360

361361
def fit(self, pairs, y):
@@ -386,7 +386,7 @@ class MMC_Supervised(_BaseMMC, TransformerMixin):
386386
----------
387387
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
388388
The linear transformation ``L`` deduced from the learned Mahalanobis
389-
metric (See :meth:`transformer_from_metric`.)
389+
metric (See function `transformer_from_metric`.)
390390
"""
391391

392392
def __init__(self, max_iter=100, max_proj=10000, convergence_threshold=1e-6,

metric_learn/sdml.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
from .base_metric import MahalanobisMixin, _PairsClassifierMixin
1919
from .constraints import Constraints, wrap_pairs
20+
from ._util import transformer_from_metric
2021

2122

2223
class _BaseSDML(MahalanobisMixin):
@@ -68,7 +69,7 @@ def _fit(self, pairs, y):
6869
emp_cov = emp_cov.T.dot(emp_cov)
6970
_, self.M_ = graph_lasso(emp_cov, self.sparsity_param, verbose=self.verbose)
7071

71-
self.transformer_ = self.transformer_from_metric(self.M_)
72+
self.transformer_ = transformer_from_metric(self.M_)
7273
return self
7374

7475

@@ -79,7 +80,7 @@ class SDML(_BaseSDML, _PairsClassifierMixin):
7980
----------
8081
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
8182
The linear transformation ``L`` deduced from the learned Mahalanobis
82-
metric (See :meth:`transformer_from_metric`.)
83+
metric (See function `transformer_from_metric`.)
8384
"""
8485

8586
def fit(self, pairs, y):
@@ -110,7 +111,7 @@ class SDML_Supervised(_BaseSDML, TransformerMixin):
110111
----------
111112
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
112113
The linear transformation ``L`` deduced from the learned Mahalanobis
113-
metric (See :meth:`transformer_from_metric`.)
114+
metric (See function `transformer_from_metric`.)
114115
"""
115116

116117
def __init__(self, balance_param=0.5, sparsity_param=0.01, use_cov=True,

0 commit comments

Comments
 (0)