[MRG] Create new Mahalanobis mixin #96

Merged
Commits (33; the diff below shows changes from 19 commits)
7b3d739  WIP create MahalanobisMixin (May 25, 2018)
f21cc85  ENH Update algorithms with Mahalanobis Mixin: (May 25, 2018)
6f8a115  Merge branch 'new_api_design' into feat/mahalanobis_class (Jun 11, 2018)
f9e3c82  FIX: add missing import (Jun 11, 2018)
1a32c11  FIX: update sklearn's function check_no_fit_attributes_set_in_init to… (Jun 11, 2018)
d0f5019  FIX: take function ``_get_args`` from scikit-learn's PR https://githu… (Jun 11, 2018)
eba2a60  ENH: add transformer_ attribute and improve docstring (Jun 14, 2018)
b5d966f  WIP: move transform() in BaseMetricLearner to transformer_from_metric… (Jun 18, 2018)
ee0d1bd  WIP: refactor metric to original formulation: a function, with result… (Jun 18, 2018)
6b5a3b5  WIP: make all Mahalanobis Metric Learner algorithms have transformer_… (Jun 19, 2018)
6eb65ac  ENH Add score_pairs function (Jun 25, 2018)
35ece36  TST add test on toy example for score_pairs (Jun 26, 2018)
dca6838  ENH Add embed function (Jun 27, 2018)
3254ce3  FIX fix error in slicing of quadruplets (Jun 27, 2018)
e209b21  FIX minor corrections (Jun 27, 2018)
abea7de  FIX minor corrections (Jun 27, 2018)
65e794a  FIX fix PEP8 errors (Jun 27, 2018)
12b5429  FIX remove possible one-sample scoring from docstring for now (Jun 27, 2018)
eff278e  REF rename n_features_out to num_dims to be more coherent with curren… (Jun 27, 2018)
810d191  MAINT: Adress https://github.com/metric-learn/metric-learn/pull/96#pu… (Jul 24, 2018)
585b5d2  ENH: Add check_tuples (Jul 24, 2018)
af0a3ac  FIX: fix parenthesis (Jul 24, 2018)
f2b0163  ENH: put docstring of transformer_ in each metric learner (Aug 22, 2018)
3c37fd7  FIX: style knitpicks to uniformize transformer_ docstring with childs (Aug 22, 2018)
912c1db  FIX: make transformer_from_metric public (Aug 22, 2018)
0e0ebf1  Address https://github.com/metric-learn/metric-learn/pull/96#pullrequ… (Aug 23, 2018)
31350e8  FIX: fix pairwise distances check (Aug 23, 2018)
d1f811b  FIX: ensure random state is set in all tests (Aug 23, 2018)
4dd8990  FIX: fix test with real value to test in check_tuples (Aug 23, 2018)
779a93a  FIX: update MetricTransformer to be abstract method and have the appr… (Sep 3, 2018)
75d4ad2  Merge branch 'feat/mahalanobis_class' of https://github.com/wdevazelh… (Sep 3, 2018)
657cdcd  MAINT: make BaseMetricLearner and MetricTransformer abstract (Sep 3, 2018)
131ccbb  MAINT: remove __init__ method from BaseMetricLearner (Sep 3, 2018)
126 changes: 97 additions & 29 deletions metric_learn/base_metric.py
@@ -3,34 +3,29 @@
from sklearn.utils.validation import check_array
from sklearn.metrics import roc_auc_score
import numpy as np
from abc import ABCMeta, abstractmethod
import six


class BaseMetricLearner(BaseEstimator):
def __init__(self):
raise NotImplementedError('BaseMetricLearner should not be instantiated')

def metric(self):
"""Computes the Mahalanobis matrix from the transformation matrix.
@abstractmethod
def score_pairs(self, pairs):
"""Returns the score between pairs
(can be a similarity, or a distance/metric depending on the algorithm)

.. math:: M = L^{\\top} L
Parameters
----------
pairs : `numpy.ndarray`, shape=(n_samples, 2, n_features)
3D array of pairs.

Returns
-------
M : (d x d) matrix
scores: `numpy.ndarray` of shape=(n_pairs,)
The score of every pair.
"""
L = self.transformer()
return L.T.dot(L)

def transformer(self):
"""Computes the transformation matrix from the Mahalanobis matrix.

L = cholesky(M).T

Returns
-------
L : upper triangular (d x d) matrix
"""
return cholesky(self.metric()).T


class MetricTransformer(TransformerMixin):
@@ -52,11 +47,90 @@ def transform(self, X=None):
X = self.X_
else:
X = check_array(X, accept_sparse=True)
L = self.transformer()
L = self.transformer_
return X.dot(L.T)


class _PairsClassifierMixin:
class MahalanobisMixin(six.with_metaclass(ABCMeta, BaseMetricLearner)):
"""Mahalanobis metric learning algorithms.

Algorithm that learns a Mahalanobis (pseudo) distance :math:`d_M(x, x')`,
defined between two column vectors :math:`x` and :math:`x'` by: :math:`d_M(x,
x') = \sqrt{(x-x')^T M (x-x')}`, where :math:`M` is a learned symmetric
positive semi-definite (PSD) matrix. The metric between points can then be
expressed as the Euclidean distance between points embedded in a new space
through a linear transformation. Indeed, the above matrix can be decomposed
as the product of a matrix and its transpose (through SVD or Cholesky
decomposition): :math:`d_M(x, x')^2 = (x-x')^T M (x-x') = (x-x')^T L^T L
(x-x') = (L x - L x')^T (L x - L x')`

Attributes
----------
transformer_ : `np.ndarray`, shape=(num_dims, n_features)
The learned linear transformation ``L``.
"""

def score_pairs(self, pairs):
"""Returns the learned Mahalanobis distance between pairs.

This distance is defined as :math:`d_M(x, x') = \sqrt{(x-x')^T M (x-x')}`,
where ``M`` is the learned Mahalanobis matrix, for every pair of points
``x`` and ``x'``. This corresponds to the Euclidean distance between
embeddings of the points in a new space, obtained through a linear
transformation. Indeed, we also have :math:`d_M(x, x') = \sqrt{(x_e -
x_e')^T (x_e - x_e')}`, with :math:`x_e = L x` (see
:class:`MahalanobisMixin`).

Parameters
----------
pairs : `numpy.ndarray`, shape=(n_samples, 2, n_features)
3D array of pairs, or 2D array of one pair.

Returns
-------
scores: `numpy.ndarray` of shape=(n_pairs,)
The learned Mahalanobis distance for every pair.
"""
pairwise_diffs = self.embed(pairs[..., 1, :] - pairs[..., 0, :]) # (for
# MahalanobisMixin, the embedding is linear so we can just embed the
# difference)
return np.sqrt(np.sum(pairwise_diffs**2, axis=-1))

def embed(self, X):
"""Embeds data points in the learned linear embedding space.

Transforms samples in ``X`` into ``X_embedded``, samples inside a new
embedding space such that: ``X_embedded = X.dot(L.T)``, where ``L`` is
the learned linear transformation (See :class:`MahalanobisMixin`).

Parameters
----------
X : `numpy.ndarray`, shape=(n_samples, n_features)
The data points to embed.

Returns
-------
X_embedded : `numpy.ndarray`, shape=(n_samples, num_dims)
The embedded data points.
"""
return X.dot(self.transformer_.T)

def metric(self):
"""Returns the learned Mahalanobis matrix ``M = L^T L``."""
return self.transformer_.T.dot(self.transformer_)

def transformer_from_metric(self, metric):
Review comment (Member): this should probably be a private method?

Reply (Member, author): Agreed

"""Computes the transformation matrix from the Mahalanobis matrix.

L = cholesky(M).T

Returns
-------
L : upper triangular (d x d) matrix
"""
return cholesky(metric).T


class _PairsClassifierMixin(BaseMetricLearner):

def predict(self, pairs):
"""Predicts the learned metric between input pairs.
@@ -74,9 +148,7 @@ def predict(self, pairs):
y_predicted : `numpy.ndarray` of floats, shape=(n_constraints,)
The predicted learned metric value between samples in every pair.
"""
pairwise_diffs = pairs[:, 0, :] - pairs[:, 1, :]
return np.sqrt(np.sum(pairwise_diffs.dot(self.metric()) * pairwise_diffs,
axis=1))
return self.score_pairs(pairs)

def decision_function(self, pairs):
return self.predict(pairs)
@@ -107,7 +179,7 @@ def score(self, pairs, y):
return roc_auc_score(y, self.decision_function(pairs))


class _QuadrupletsClassifierMixin:
class _QuadrupletsClassifierMixin(BaseMetricLearner):

def predict(self, quadruplets):
Review comment (Member): it would be more logical if predict would compute accuracy on quadruplets (proportion of quadruplets correctly ordered) and score would compute the difference between distances

Reply (Member, author): I agree that having a predict and a decision_function doing the same thing is not ideal, and in the case of quadruplets this can be fixed easily since there is no threshold to fix like in pairs.
But I did not quite get what you mean?
Since score is the scikit-learn-like function for scoring, it should return one scalar, whereas predict should return a sample-wise output.
I guess the most coherent with scikit-learn would be for predict to output a binary sign (or 0 or 1) depending on the ordering of pairs in the quadruplet, and for decision_function to return the differences between distances, since it is a float-type score (like predict_proba, but without necessarily being between 0 and 1).
score would return the accuracy of the predict function (proportion of quadruplets correctly ordered).

Reply (Member): Yes, this is what I meant, sorry for the confusion.

"""Predicts differences between sample distances in input quadruplets.
@@ -125,12 +197,8 @@ def predict(self, quadruplets):
prediction : `numpy.ndarray` of floats, shape=(n_constraints,)
Metric differences.
"""
similar_diffs = quadruplets[:, 0, :] - quadruplets[:, 1, :]
dissimilar_diffs = quadruplets[:, 2, :] - quadruplets[:, 3, :]
return (np.sqrt(np.sum(similar_diffs.dot(self.metric()) *
similar_diffs, axis=1)) -
np.sqrt(np.sum(dissimilar_diffs.dot(self.metric()) *
dissimilar_diffs, axis=1)))
return (self.score_pairs(quadruplets[..., :2, :]) -
self.score_pairs(quadruplets[..., 2:, :]))

def decision_function(self, quadruplets):
return self.predict(quadruplets)
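To make the new API concrete, here is a short usage sketch (not part of the diff; the ToyMahalanobis subclass and the data are illustrative) checking that score_pairs, embed and metric are mutually consistent: the learned distance is the Euclidean distance between embedded points x_e = L x, with M = L^T L:

import numpy as np
from metric_learn.base_metric import MahalanobisMixin

class ToyMahalanobis(MahalanobisMixin):
    # Illustrative subclass: real learners (ITML, LSML, LMNN, ...) set
    # transformer_ during fit by optimizing a metric.
    def fit(self, X=None, y=None):
        self.transformer_ = np.array([[1., 0.],
                                      [0., 2.]])  # the linear map L
        return self

m = ToyMahalanobis().fit()
pairs = np.array([[[1., 2.], [3., 4.]],
                  [[0., 0.], [1., 1.]]])  # shape (n_pairs, 2, n_features)

d = m.score_pairs(pairs)

# Equivalently, the Euclidean distance between embeddings x_e = L x ...
diffs_e = m.embed(pairs[:, 1, :]) - m.embed(pairs[:, 0, :])
assert np.allclose(d, np.sqrt(np.sum(diffs_e ** 2, axis=-1)))

# ... and the quadratic form with the Mahalanobis matrix M = L^T L.
diffs = pairs[:, 1, :] - pairs[:, 0, :]
assert np.allclose(d, np.sqrt(np.einsum('ij,jk,ik->i',
                                        diffs, m.metric(), diffs)))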
9 changes: 4 additions & 5 deletions metric_learn/covariance.py
@@ -12,16 +12,13 @@
import numpy as np
from sklearn.utils.validation import check_array

from .base_metric import BaseMetricLearner, MetricTransformer
from .base_metric import MahalanobisMixin, MetricTransformer


class Covariance(BaseMetricLearner, MetricTransformer):
class Covariance(MetricTransformer, MahalanobisMixin):
def __init__(self):
pass

def metric(self):
return self.M_

def fit(self, X, y=None):
"""
X : data matrix, (n x d)
@@ -33,4 +30,6 @@ def fit(self, X, y=None):
self.M_ = 1./self.M_
else:
self.M_ = np.linalg.inv(self.M_)

self.transformer_ = self.transformer_from_metric(check_array(self.M_))
return self
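
With this change, Covariance gains the shared Mahalanobis API for free. A usage sketch (random data; assumes this branch of metric-learn is installed):

import numpy as np
from metric_learn import Covariance

X = np.random.RandomState(42).randn(100, 3)
cov = Covariance().fit(X)

L = cov.transformer_                     # from transformer_from_metric(M_)
assert np.allclose(L.T.dot(L), cov.M_)   # M = L^T L, up to roundoff

X_e = cov.embed(X)                       # same as X.dot(L.T)
assert np.allclose(X_e, X.dot(L.T))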
12 changes: 5 additions & 7 deletions metric_learn/itml.py
@@ -18,14 +18,13 @@
from six.moves import xrange
from sklearn.metrics import pairwise_distances
from sklearn.utils.validation import check_array, check_X_y

from .base_metric import (BaseMetricLearner, _PairsClassifierMixin,
MetricTransformer)
from .base_metric import (_PairsClassifierMixin, MetricTransformer,
MahalanobisMixin)
from .constraints import Constraints, wrap_pairs
from ._util import vector_norm


class _BaseITML(BaseMetricLearner):
class _BaseITML(MahalanobisMixin):
"""Information Theoretic Metric Learning (ITML)"""
def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3,
A0=None, verbose=False):
@@ -129,10 +128,9 @@ def _fit(self, pairs, y, bounds=None):
if self.verbose:
print('itml converged at iter: %d, conv = %f' % (it, conv))
self.n_iter_ = it
return self

def metric(self):
return self.A_
self.transformer_ = self.transformer_from_metric(self.A_)
return self


class ITML(_BaseITML, _PairsClassifierMixin):
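ITML shows the pattern every solver-based learner follows after this refactor: optimize a Mahalanobis matrix (here ``A_``), then derive ``transformer_`` once at the end of ``_fit`` via ``transformer_from_metric``. A standalone sketch of that factorization step (example values are my own):

import numpy as np
from numpy.linalg import cholesky

M = np.array([[2.0, 0.5],
              [0.5, 1.0]])  # a learned, positive definite Mahalanobis matrix
L = cholesky(M).T           # upper triangular, so that L^T L = M
assert np.allclose(L.T.dot(L), M)

Note that Cholesky requires M to be strictly positive definite; a merely PSD (singular) matrix would need an eigendecomposition-based factorization instead.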
7 changes: 2 additions & 5 deletions metric_learn/lfda.py
@@ -18,10 +18,10 @@
from sklearn.metrics import pairwise_distances
from sklearn.utils.validation import check_X_y

from .base_metric import BaseMetricLearner, MetricTransformer
from .base_metric import MahalanobisMixin, MetricTransformer


class LFDA(BaseMetricLearner, MetricTransformer):
class LFDA(MahalanobisMixin, MetricTransformer):
'''
Local Fisher Discriminant Analysis for Supervised Dimensionality Reduction
Sugiyama, ICML 2006
@@ -51,9 +51,6 @@ def __init__(self, num_dims=None, k=None, embedding_type='weighted'):
self.embedding_type = embedding_type
self.k = k

def transformer(self):
return self.transformer_

def _process_inputs(self, X, y):
unique_classes, y = np.unique(y, return_inverse=True)
self.X_, y = check_X_y(X, y)
13 changes: 5 additions & 8 deletions metric_learn/lmnn.py
@@ -17,11 +17,11 @@
from sklearn.utils.validation import check_X_y, check_array
from sklearn.metrics import euclidean_distances

from .base_metric import BaseMetricLearner, MetricTransformer
from .base_metric import MahalanobisMixin, MetricTransformer


# commonality between LMNN implementations
class _base_LMNN(BaseMetricLearner, MetricTransformer):
class _base_LMNN(MahalanobisMixin, MetricTransformer):
def __init__(self, k=3, min_iter=50, max_iter=1000, learn_rate=1e-7,
regularization=0.5, convergence_tol=0.001, use_pca=True,
verbose=False):
@@ -44,9 +44,6 @@ def __init__(self, k=3, min_iter=50, max_iter=1000, learn_rate=1e-7,
self.use_pca = use_pca
self.verbose = verbose

def transformer(self):
return self.L_


# slower Python version
class python_LMNN(_base_LMNN):
@@ -60,7 +57,7 @@ def _process_inputs(self, X, labels):
self.labels_ = np.arange(len(unique_labels))
if self.use_pca:
warnings.warn('use_pca does nothing for the python_LMNN implementation')
self.L_ = np.eye(num_dims)
self.transformer_ = np.eye(num_dims)
required_k = np.bincount(self.label_inds_).min()
if self.k > required_k:
raise ValueError('not enough class labels for specified k'
@@ -92,7 +89,7 @@ def fit(self, X, y):

# initialize gradient and L
G = dfG * reg + df * (1-reg)
L = self.L_
L = self.transformer_
objective = np.inf

# main loop
@@ -177,7 +174,7 @@ def fit(self, X, y):
print("LMNN didn't converge in %d steps." % self.max_iter)

# store the last L
self.L_ = L
self.transformer_ = L
self.n_iter_ = it
return self

12 changes: 6 additions & 6 deletions metric_learn/lsml.py
@@ -11,14 +11,15 @@
import numpy as np
import scipy.linalg
from six.moves import xrange

from sklearn.utils.validation import check_array, check_X_y

from .base_metric import (BaseMetricLearner, _QuadrupletsClassifierMixin,
MetricTransformer)
from .base_metric import (_QuadrupletsClassifierMixin, MetricTransformer,
MahalanobisMixin)
from .constraints import Constraints


class _BaseLSML(BaseMetricLearner):
class _BaseLSML(MahalanobisMixin):
def __init__(self, tol=1e-3, max_iter=1000, prior=None, verbose=False):
"""Initialize LSML.

@@ -58,9 +59,6 @@ def _prepare_quadruplets(self, quadruplets, weights):
self.M_ = self.prior
self.prior_inv_ = np.linalg.inv(self.prior)

def metric(self):
return self.M_

def _fit(self, quadruplets, weights=None):
self._prepare_quadruplets(quadruplets, weights)
step_sizes = np.logspace(-10, 0, 10)
@@ -96,6 +94,8 @@ def _fit(self, quadruplets, weights=None):
if self.verbose:
print("Didn't converge after", it, "iterations. Final loss:", s_best)
self.n_iter_ = it

self.transformer_ = self.transformer_from_metric(self.M_)
return self

def _comparison_loss(self, metric):
8 changes: 3 additions & 5 deletions metric_learn/mlkr.py
@@ -11,14 +11,15 @@
from scipy.optimize import minimize
from scipy.spatial.distance import pdist, squareform
from sklearn.decomposition import PCA

from sklearn.utils.validation import check_X_y

from .base_metric import BaseMetricLearner, MetricTransformer
from .base_metric import MahalanobisMixin, MetricTransformer

EPS = np.finfo(float).eps


class MLKR(BaseMetricLearner, MetricTransformer):
class MLKR(MahalanobisMixin, MetricTransformer):
"""Metric Learning for Kernel Regression (MLKR)"""
def __init__(self, num_dims=None, A0=None, epsilon=0.01, alpha=0.0001,
max_iter=1000):
@@ -90,9 +91,6 @@ def fit(self, X, y):
self.n_iter_ = res.nit
return self

def transformer(self):
return self.transformer_


def _loss(flatA, X, y, dX):
A = flatA.reshape((-1, X.shape[1]))