Skip to content

[MRG] Move transformer_from_metric to util #151

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits on
Jan 7, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions metric_learn/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,3 +322,30 @@ def check_collapsed_pairs(pairs):
raise ValueError("{} collapsed pairs found (where the left element is "
"the same as the right element), out of {} pairs "
"in total.".format(num_ident, pairs.shape[0]))


def transformer_from_metric(metric):
    """Compute the transformation matrix from the Mahalanobis matrix.

    Returns a matrix ``L`` such that ``L.T.dot(L)`` equals ``metric``. Since
    by definition the metric ``M`` is positive semi-definite (PSD), it admits
    a Cholesky decomposition: ``L = cholesky(M).T``. However, the Cholesky
    decomposition does not support non-definite matrices. When ``M`` is
    singular (or turns out not to be positive definite despite a nonzero
    determinant), this function instead returns ``L = w^(1/2) V.T``, with
    ``M = V*w*V.T`` being the eigenvector decomposition of ``M`` (eigenvalues
    in ``w``, eigenvectors in the columns of ``V``), negative eigenvalues
    being clipped to zero. If ``M`` is diagonal, its elementwise square root
    is returned directly (since the matrix is its own diagonalization).

    Parameters
    ----------
    metric : (d x d) symmetric PSD matrix
        The Mahalanobis matrix.

    Returns
    -------
    L : (d x d) matrix
        Transformation matrix with ``L.T.dot(L) == metric``.
    """

    if np.allclose(metric, np.diag(np.diag(metric))):
        # Diagonal metric: the transformation is its elementwise square root.
        return np.sqrt(metric)
    elif not np.isclose(np.linalg.det(metric), 0):
        # Nominally full rank: try the fast Cholesky route, but fall back to
        # the eigendecomposition below if the matrix is not positive definite
        # (a nonzero determinant does not guarantee definiteness).
        try:
            return np.linalg.cholesky(metric).T
        except np.linalg.LinAlgError:
            pass
    w, V = np.linalg.eigh(metric)
    # Scale the rows of V.T by sqrt(w), clipping numerical-noise negative
    # eigenvalues to zero: L = diag(sqrt(w)).dot(V.T), so L.T.dot(L) == M.
    return V.T * np.sqrt(np.maximum(0, w[:, None]))
27 changes: 0 additions & 27 deletions metric_learn/base_metric.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from numpy.linalg import cholesky
from sklearn.base import BaseEstimator
from sklearn.utils.validation import _is_arraylike
from sklearn.metrics import roc_auc_score
Expand Down Expand Up @@ -181,32 +180,6 @@ def transform(self, X):
def metric(self):
return self.transformer_.T.dot(self.transformer_)

def transformer_from_metric(self, metric):
"""Computes the transformation matrix from the Mahalanobis matrix.

Since by definition the metric `M` is positive semi-definite (PSD), it
admits a Cholesky decomposition: L = cholesky(M).T. However, currently the
computation of the Cholesky decomposition used does not support
non-definite matrices. If the metric is not definite, this method will
return L = V.T w^( -1/2), with M = V*w*V.T being the eigenvector
decomposition of M with the eigenvalues in the diagonal matrix w and the
columns of V being the eigenvectors. If M is diagonal, this method will
just return its elementwise square root (since the diagonalization of
the matrix is itself).

Returns
-------
L : (d x d) matrix
"""

if np.allclose(metric, np.diag(np.diag(metric))):
return np.sqrt(metric)
elif not np.isclose(np.linalg.det(metric), 0):
return cholesky(metric).T
else:
w, V = np.linalg.eigh(metric)
return V.T * np.sqrt(np.maximum(0, w[:, None]))


class _PairsClassifierMixin(BaseMetricLearner):

Expand Down
5 changes: 3 additions & 2 deletions metric_learn/covariance.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from sklearn.base import TransformerMixin

from .base_metric import MahalanobisMixin
from ._util import transformer_from_metric


class Covariance(MahalanobisMixin, TransformerMixin):
Expand All @@ -22,7 +23,7 @@ class Covariance(MahalanobisMixin, TransformerMixin):
----------
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
The linear transformation ``L`` deduced from the learned Mahalanobis
metric (See :meth:`transformer_from_metric`.)
metric (See function `transformer_from_metric`.)
"""

def __init__(self, preprocessor=None):
Expand All @@ -40,5 +41,5 @@ def fit(self, X, y=None):
else:
M = np.linalg.inv(M)

self.transformer_ = self.transformer_from_metric(np.atleast_2d(M))
self.transformer_ = transformer_from_metric(np.atleast_2d(M))
return self
8 changes: 4 additions & 4 deletions metric_learn/itml.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from sklearn.base import TransformerMixin
from .base_metric import _PairsClassifierMixin, MahalanobisMixin
from .constraints import Constraints, wrap_pairs
from ._util import vector_norm
from ._util import vector_norm, transformer_from_metric


class _BaseITML(MahalanobisMixin):
Expand Down Expand Up @@ -125,7 +125,7 @@ def _fit(self, pairs, y, bounds=None):
print('itml converged at iter: %d, conv = %f' % (it, conv))
self.n_iter_ = it

self.transformer_ = self.transformer_from_metric(self.A_)
self.transformer_ = transformer_from_metric(self.A_)
return self


Expand All @@ -136,7 +136,7 @@ class ITML(_BaseITML, _PairsClassifierMixin):
----------
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
The linear transformation ``L`` deduced from the learned Mahalanobis
metric (See :meth:`transformer_from_metric`.)
metric (See function `transformer_from_metric`.)
"""

def fit(self, pairs, y, bounds=None):
Expand Down Expand Up @@ -169,7 +169,7 @@ class ITML_Supervised(_BaseITML, TransformerMixin):
----------
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
The linear transformation ``L`` deduced from the learned Mahalanobis
metric (See `transformer_from_metric`.)
metric (See function `transformer_from_metric`.)
"""

def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3,
Expand Down
7 changes: 4 additions & 3 deletions metric_learn/lsml.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from .base_metric import _QuadrupletsClassifierMixin, MahalanobisMixin
from .constraints import Constraints
from ._util import transformer_from_metric


class _BaseLSML(MahalanobisMixin):
Expand Down Expand Up @@ -101,7 +102,7 @@ def _fit(self, quadruplets, y=None, weights=None):
print("Didn't converge after", it, "iterations. Final loss:", s_best)
self.n_iter_ = it

self.transformer_ = self.transformer_from_metric(self.M_)
self.transformer_ = transformer_from_metric(self.M_)
return self

def _comparison_loss(self, metric):
Expand Down Expand Up @@ -137,7 +138,7 @@ class LSML(_BaseLSML, _QuadrupletsClassifierMixin):
----------
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
The linear transformation ``L`` deduced from the learned Mahalanobis
metric (See :meth:`transformer_from_metric`.)
metric (See function `transformer_from_metric`.)
"""

def fit(self, quadruplets, weights=None):
Expand Down Expand Up @@ -170,7 +171,7 @@ class LSML_Supervised(_BaseLSML, TransformerMixin):
----------
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
The linear transformation ``L`` deduced from the learned Mahalanobis
metric (See :meth:`transformer_from_metric`.)
metric (See function `transformer_from_metric`.)
"""

def __init__(self, tol=1e-3, max_iter=1000, prior=None,
Expand Down
10 changes: 5 additions & 5 deletions metric_learn/mmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

from .base_metric import _PairsClassifierMixin, MahalanobisMixin
from .constraints import Constraints, wrap_pairs
from ._util import vector_norm
from ._util import vector_norm, transformer_from_metric


class _BaseMMC(MahalanobisMixin):
Expand Down Expand Up @@ -206,7 +206,7 @@ def _fit_full(self, pairs, y):
self.A_[:] = A_old
self.n_iter_ = cycle

self.transformer_ = self.transformer_from_metric(self.A_)
self.transformer_ = transformer_from_metric(self.A_)
return self

def _fit_diag(self, pairs, y):
Expand Down Expand Up @@ -267,7 +267,7 @@ def _fit_diag(self, pairs, y):

self.A_ = np.diag(w)

self.transformer_ = self.transformer_from_metric(self.A_)
self.transformer_ = transformer_from_metric(self.A_)
return self

def _fD(self, neg_pairs, A):
Expand Down Expand Up @@ -355,7 +355,7 @@ class MMC(_BaseMMC, _PairsClassifierMixin):
----------
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
The linear transformation ``L`` deduced from the learned Mahalanobis
metric (See :meth:`transformer_from_metric`.)
metric (See function `transformer_from_metric`.)
"""

def fit(self, pairs, y):
Expand Down Expand Up @@ -386,7 +386,7 @@ class MMC_Supervised(_BaseMMC, TransformerMixin):
----------
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
The linear transformation ``L`` deduced from the learned Mahalanobis
metric (See :meth:`transformer_from_metric`.)
metric (See function `transformer_from_metric`.)
"""

def __init__(self, max_iter=100, max_proj=10000, convergence_threshold=1e-6,
Expand Down
7 changes: 4 additions & 3 deletions metric_learn/sdml.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from .base_metric import MahalanobisMixin, _PairsClassifierMixin
from .constraints import Constraints, wrap_pairs
from ._util import transformer_from_metric


class _BaseSDML(MahalanobisMixin):
Expand Down Expand Up @@ -68,7 +69,7 @@ def _fit(self, pairs, y):
emp_cov = emp_cov.T.dot(emp_cov)
_, self.M_ = graph_lasso(emp_cov, self.sparsity_param, verbose=self.verbose)

self.transformer_ = self.transformer_from_metric(self.M_)
self.transformer_ = transformer_from_metric(self.M_)
return self


Expand All @@ -79,7 +80,7 @@ class SDML(_BaseSDML, _PairsClassifierMixin):
----------
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
The linear transformation ``L`` deduced from the learned Mahalanobis
metric (See :meth:`transformer_from_metric`.)
metric (See function `transformer_from_metric`.)
"""

def fit(self, pairs, y):
Expand Down Expand Up @@ -110,7 +111,7 @@ class SDML_Supervised(_BaseSDML, TransformerMixin):
----------
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
The linear transformation ``L`` deduced from the learned Mahalanobis
metric (See :meth:`transformer_from_metric`.)
metric (See function `transformer_from_metric`.)
"""

def __init__(self, balance_param=0.5, sparsity_param=0.01, use_cov=True,
Expand Down