diff --git a/examples/sandwich.py b/examples/sandwich.py index 34b48a00..08ec17c5 100644 --- a/examples/sandwich.py +++ b/examples/sandwich.py @@ -30,7 +30,7 @@ def sandwich_demo(): for ax_num, ml in enumerate(mls, start=3): ml.fit(x, y) - tx = ml.transform() + tx = ml.transform(x) ml_knn = nearest_neighbors(tx, k=2) ax = plt.subplot(3, 2, ax_num) plot_sandwich_data(tx, y, axis=ax) diff --git a/metric_learn/_util.py b/metric_learn/_util.py index b34860d6..e7f24e7d 100644 --- a/metric_learn/_util.py +++ b/metric_learn/_util.py @@ -9,4 +9,42 @@ def vector_norm(X): return np.apply_along_axis(np.linalg.norm, 1, X) else: def vector_norm(X): - return np.linalg.norm(X, axis=1) \ No newline at end of file + return np.linalg.norm(X, axis=1) + + +def check_tuples(tuples): + """Check that the input is a valid 3D array representing a dataset of tuples. + + Equivalent of `check_array` in scikit-learn. + + Parameters + ---------- + tuples : object + The tuples to check. + + Returns + ------- + tuples_valid : object + The validated input. + """ + # If input is scalar raise error + if np.isscalar(tuples): + raise ValueError( + "Expected 3D array, got scalar instead. Cannot apply this function on " + "scalars.") + # If input is 1D raise error + if len(tuples.shape) == 1: + raise ValueError( + "Expected 3D array, got 1D array instead:\ntuples={}.\n" + "Reshape your data using tuples.reshape(1, -1, 1) if it contains a " + "single tuple and the points in the tuple have a single " + "feature.".format(tuples)) + # If input is 2D raise error + if len(tuples.shape) == 2: + raise ValueError( + "Expected 3D array, got 2D array instead:\ntuples={}.\n" + "Reshape your data either using tuples.reshape(-1, {}, 1) if " + "your data has a single feature or tuples.reshape(1, {}, -1) " + "if it contains a single tuple.".format(tuples, tuples.shape[1], + tuples.shape[0])) + return tuples diff --git a/metric_learn/base_metric.py b/metric_learn/base_metric.py index 889de999..4044f223 100644 --- a/metric_learn/base_metric.py +++ b/metric_learn/base_metric.py @@ -1,62 +1,148 @@ from numpy.linalg import cholesky -from sklearn.base import BaseEstimator, TransformerMixin +from sklearn.base import BaseEstimator from sklearn.utils.validation import check_array from sklearn.metrics import roc_auc_score import numpy as np +from abc import ABCMeta, abstractmethod +import six +from ._util import check_tuples -class BaseMetricLearner(BaseEstimator): - def __init__(self): - raise NotImplementedError('BaseMetricLearner should not be instantiated') +class BaseMetricLearner(six.with_metaclass(ABCMeta, BaseEstimator)): - def metric(self): - """Computes the Mahalanobis matrix from the transformation matrix. + @abstractmethod + def score_pairs(self, pairs): + """Returns the score between pairs + (can be a similarity, or a distance/metric depending on the algorithm) - .. math:: M = L^{\\top} L + Parameters + ---------- + pairs : `numpy.ndarray`, shape=(n_samples, 2, n_features) + 3D array of pairs. Returns ------- - M : (d x d) matrix + scores: `numpy.ndarray` of shape=(n_pairs,) + The score of every pair. """ - L = self.transformer() - return L.T.dot(L) - def transformer(self): - """Computes the transformation matrix from the Mahalanobis matrix. - L = cholesky(M).T +class MetricTransformer(six.with_metaclass(ABCMeta)): + + @abstractmethod + def transform(self, X): + """Applies the metric transformation. + + Parameters + ---------- + X : (n x d) matrix + Data to transform. 
Returns ------- - L : upper triangular (d x d) matrix + transformed : (n x d) matrix + Input data transformed to the metric space by :math:`XL^{\\top}` """ - return cholesky(self.metric()).T -class MetricTransformer(TransformerMixin): +class MahalanobisMixin(six.with_metaclass(ABCMeta, BaseMetricLearner, + MetricTransformer)): + """Mahalanobis metric learning algorithms. + + Algorithm that learns a Mahalanobis (pseudo) distance :math:`d_M(x, x')`, + defined between two column vectors :math:`x` and :math:`x'` by: :math:`d_M(x, + x') = \sqrt{(x-x')^T M (x-x')}`, where :math:`M` is a learned symmetric + positive semi-definite (PSD) matrix. The metric between points can then be + expressed as the euclidean distance between points embedded in a new space + through a linear transformation. Indeed, the above matrix can be decomposed + into the product of two transpose matrices (through SVD or Cholesky + decomposition): :math:`d_M(x, x')^2 = (x-x')^T M (x-x') = (x-x')^T L^T L + (x-x') = (L x - L x')^T (L x- L x')` + + Attributes + ---------- + transformer_ : `numpy.ndarray`, shape=(num_dims, n_features) + The learned linear transformation ``L``. + """ + + def score_pairs(self, pairs): + """Returns the learned Mahalanobis distance between pairs. + + This distance is defined as: :math:`d_M(x, x') = \sqrt{(x-x')^T M (x-x')}` + where ``M`` is the learned Mahalanobis matrix, for every pair of points + ``x`` and ``x'``. This corresponds to the euclidean distance between + embeddings of the points in a new space, obtained through a linear + transformation. Indeed, we have also: :math:`d_M(x, x') = \sqrt{(x_e - + x_e')^T (x_e- x_e')}`, with :math:`x_e = L x` (See + :class:`MahalanobisMixin`). - def transform(self, X=None): - """Applies the metric transformation. + Parameters + ---------- + pairs : `numpy.ndarray`, shape=(n_samples, 2, n_features) + 3D array of pairs, or 2D array of one pair. + + Returns + ------- + scores: `numpy.ndarray` of shape=(n_pairs,) + The learned Mahalanobis distance for every pair. + """ + pairs = check_tuples(pairs) + pairwise_diffs = self.transform(pairs[:, 1, :] - pairs[:, 0, :]) + # (for MahalanobisMixin, the embedding is linear so we can just embed the + # difference) + return np.sqrt(np.sum(pairwise_diffs**2, axis=-1)) + + def transform(self, X): + """Embeds data points in the learned linear embedding space. + + Transforms samples in ``X`` into ``X_embedded``, samples inside a new + embedding space such that: ``X_embedded = X.dot(L.T)``, where ``L`` is + the learned linear transformation (See :class:`MahalanobisMixin`). Parameters ---------- - X : (n x d) matrix, optional - Data to transform. If not supplied, the training data will be used. + X : `numpy.ndarray`, shape=(n_samples, n_features) + The data points to embed. Returns ------- - transformed : (n x d) matrix - Input data transformed to the metric space by :math:`XL^{\\top}` + X_embedded : `numpy.ndarray`, shape=(n_samples, num_dims) + The embedded data points. + """ + X_checked = check_array(X, accept_sparse=True) + return X_checked.dot(self.transformer_.T) + + def metric(self): + return self.transformer_.T.dot(self.transformer_) + + def transformer_from_metric(self, metric): + """Computes the transformation matrix from the Mahalanobis matrix. + + Since by definition the metric `M` is positive semi-definite (PSD), it + admits a Cholesky decomposition: L = cholesky(M).T. However, currently the + computation of the Cholesky decomposition used does not support + non-definite matrices. 
If the metric is not definite, this method will + return L = V.T w^( -1/2), with M = V*w*V.T being the eigenvector + decomposition of M with the eigenvalues in the diagonal matrix w and the + columns of V being the eigenvectors. If M is diagonal, this method will + just return its elementwise square root (since the diagonalization of + the matrix is itself). + + Returns + ------- + L : (d x d) matrix """ - if X is None: - X = self.X_ + + if np.allclose(metric, np.diag(np.diag(metric))): + return np.sqrt(metric) + elif not np.isclose(np.linalg.det(metric), 0): + return cholesky(metric).T else: - X = check_array(X, accept_sparse=True) - L = self.transformer() - return X.dot(L.T) + w, V = np.linalg.eigh(metric) + return V.T * np.sqrt(np.maximum(0, w[:, None])) -class _PairsClassifierMixin: +class _PairsClassifierMixin(BaseMetricLearner): def predict(self, pairs): """Predicts the learned metric between input pairs. @@ -74,11 +160,11 @@ def predict(self, pairs): y_predicted : `numpy.ndarray` of floats, shape=(n_constraints,) The predicted learned metric value between samples in every pair. """ - pairwise_diffs = pairs[:, 0, :] - pairs[:, 1, :] - return np.sqrt(np.sum(pairwise_diffs.dot(self.metric()) * pairwise_diffs, - axis=1)) + pairs = check_tuples(pairs) + return self.score_pairs(pairs) def decision_function(self, pairs): + pairs = check_tuples(pairs) return self.predict(pairs) def score(self, pairs, y): @@ -104,12 +190,32 @@ def score(self, pairs, y): score : float The ``roc_auc`` score. """ + pairs = check_tuples(pairs) return roc_auc_score(y, self.decision_function(pairs)) -class _QuadrupletsClassifierMixin: +class _QuadrupletsClassifierMixin(BaseMetricLearner): def predict(self, quadruplets): + """Predicts the ordering between sample distances in input quadruplets. + + For each quadruplet, returns 1 if the quadruplet is in the right order ( + first pair is more similar than second pair), and -1 if not. + + Parameters + ---------- + quadruplets : array-like, shape=(n_constraints, 4, n_features) + Input quadruplets. + + Returns + ------- + prediction : `numpy.ndarray` of floats, shape=(n_constraints,) + Predictions of the ordering of pairs, for each quadruplet. + """ + quadruplets = check_tuples(quadruplets) + return np.sign(self.decision_function(quadruplets)) + + def decision_function(self, quadruplets): """Predicts differences between sample distances in input quadruplets. For each quadruplet of samples, computes the difference between the learned @@ -122,18 +228,12 @@ def predict(self, quadruplets): Returns ------- - prediction : `numpy.ndarray` of floats, shape=(n_constraints,) + decision_function : `numpy.ndarray` of floats, shape=(n_constraints,) Metric differences. """ - similar_diffs = quadruplets[:, 0, :] - quadruplets[:, 1, :] - dissimilar_diffs = quadruplets[:, 2, :] - quadruplets[:, 3, :] - return (np.sqrt(np.sum(similar_diffs.dot(self.metric()) * - similar_diffs, axis=1)) - - np.sqrt(np.sum(dissimilar_diffs.dot(self.metric()) * - dissimilar_diffs, axis=1))) - - def decision_function(self, quadruplets): - return self.predict(quadruplets) + quadruplets = check_tuples(quadruplets) + return (self.score_pairs(quadruplets[:, :2, :]) - + self.score_pairs(quadruplets[:, 2:, :])) def score(self, quadruplets, y=None): """Computes score on input quadruplets @@ -154,4 +254,5 @@ def score(self, quadruplets, y=None): score : float The quadruplets score. 
""" - return - np.mean(np.sign(self.decision_function(quadruplets))) + quadruplets = check_tuples(quadruplets) + return -np.mean(self.predict(quadruplets)) diff --git a/metric_learn/covariance.py b/metric_learn/covariance.py index 689650b4..4e8c1a0f 100644 --- a/metric_learn/covariance.py +++ b/metric_learn/covariance.py @@ -11,17 +11,24 @@ from __future__ import absolute_import import numpy as np from sklearn.utils.validation import check_array +from sklearn.base import TransformerMixin -from .base_metric import BaseMetricLearner, MetricTransformer +from .base_metric import MahalanobisMixin -class Covariance(BaseMetricLearner, MetricTransformer): +class Covariance(MahalanobisMixin, TransformerMixin): + """Covariance metric (baseline method) + + Attributes + ---------- + transformer_ : `numpy.ndarray`, shape=(num_dims, n_features) + The linear transformation ``L`` deduced from the learned Mahalanobis + metric (See :meth:`transformer_from_metric`.) + """ + def __init__(self): pass - def metric(self): - return self.M_ - def fit(self, X, y=None): """ X : data matrix, (n x d) @@ -33,4 +40,6 @@ def fit(self, X, y=None): self.M_ = 1./self.M_ else: self.M_ = np.linalg.inv(self.M_) + + self.transformer_ = self.transformer_from_metric(check_array(self.M_)) return self diff --git a/metric_learn/itml.py b/metric_learn/itml.py index fc839611..d8bd24c2 100644 --- a/metric_learn/itml.py +++ b/metric_learn/itml.py @@ -18,14 +18,13 @@ from six.moves import xrange from sklearn.metrics import pairwise_distances from sklearn.utils.validation import check_array, check_X_y - -from .base_metric import (BaseMetricLearner, _PairsClassifierMixin, - MetricTransformer) +from sklearn.base import TransformerMixin +from .base_metric import _PairsClassifierMixin, MahalanobisMixin from .constraints import Constraints, wrap_pairs -from ._util import vector_norm +from ._util import vector_norm, check_tuples -class _BaseITML(BaseMetricLearner): +class _BaseITML(MahalanobisMixin): """Information Theoretic Metric Learning (ITML)""" def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3, A0=None, verbose=False): @@ -53,8 +52,11 @@ def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3, self.verbose = verbose def _process_pairs(self, pairs, y, bounds): + # for now we check_X_y and check_tuples but we should only + # check_tuples_y in the future pairs, y = check_X_y(pairs, y, accept_sparse=False, - ensure_2d=False, allow_nd=True) + ensure_2d=False, allow_nd=True) + pairs = check_tuples(pairs) # check to make sure that no two constrained vectors are identical pos_pairs, neg_pairs = pairs[y == 1], pairs[y == -1] @@ -129,13 +131,20 @@ def _fit(self, pairs, y, bounds=None): if self.verbose: print('itml converged at iter: %d, conv = %f' % (it, conv)) self.n_iter_ = it - return self - def metric(self): - return self.A_ + self.transformer_ = self.transformer_from_metric(self.A_) + return self class ITML(_BaseITML, _PairsClassifierMixin): + """Information Theoretic Metric Learning (ITML) + + Attributes + ---------- + transformer_ : `numpy.ndarray`, shape=(num_dims, n_features) + The linear transformation ``L`` deduced from the learned Mahalanobis + metric (See :meth:`transformer_from_metric`.) + """ def fit(self, pairs, y, bounds=None): """Learn the ITML model. 
@@ -157,8 +166,16 @@ def fit(self, pairs, y, bounds=None): return self._fit(pairs, y, bounds=bounds) -class ITML_Supervised(_BaseITML, MetricTransformer): - """Information Theoretic Metric Learning (ITML)""" +class ITML_Supervised(_BaseITML, TransformerMixin): + """Supervised version of Information Theoretic Metric Learning (ITML) + + Attributes + ---------- + transformer_ : `numpy.ndarray`, shape=(num_dims, n_features) + The linear transformation ``L`` deduced from the learned Mahalanobis + metric (See `transformer_from_metric`.) + """ + def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3, num_labeled=np.inf, num_constraints=None, bounds=None, A0=None, verbose=False): @@ -191,6 +208,7 @@ def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3, def fit(self, X, y, random_state=np.random): """Create constraints from labels and learn the ITML model. + Parameters ---------- X : (n x d) matrix diff --git a/metric_learn/lfda.py b/metric_learn/lfda.py index 03df5f24..c06fca91 100644 --- a/metric_learn/lfda.py +++ b/metric_learn/lfda.py @@ -17,15 +17,21 @@ from six.moves import xrange from sklearn.metrics import pairwise_distances from sklearn.utils.validation import check_X_y +from sklearn.base import TransformerMixin +from .base_metric import MahalanobisMixin -from .base_metric import BaseMetricLearner, MetricTransformer - -class LFDA(BaseMetricLearner, MetricTransformer): +class LFDA(MahalanobisMixin, TransformerMixin): ''' Local Fisher Discriminant Analysis for Supervised Dimensionality Reduction Sugiyama, ICML 2006 + + Attributes + ---------- + transformer_ : `numpy.ndarray`, shape=(num_dims, n_features) + The learned linear transformation ``L``. ''' + def __init__(self, num_dims=None, k=None, embedding_type='weighted'): ''' Initialize LFDA. 
@@ -51,9 +57,6 @@ def __init__(self, num_dims=None, k=None, embedding_type='weighted'): self.embedding_type = embedding_type self.k = k - def transformer(self): - return self.transformer_ - def _process_inputs(self, X, y): unique_classes, y = np.unique(y, return_inverse=True) self.X_, y = check_X_y(X, y) diff --git a/metric_learn/lmnn.py b/metric_learn/lmnn.py index 581dc72a..7ce4d051 100644 --- a/metric_learn/lmnn.py +++ b/metric_learn/lmnn.py @@ -16,12 +16,12 @@ from six.moves import xrange from sklearn.utils.validation import check_X_y, check_array from sklearn.metrics import euclidean_distances - -from .base_metric import BaseMetricLearner, MetricTransformer +from sklearn.base import TransformerMixin +from .base_metric import MahalanobisMixin # commonality between LMNN implementations -class _base_LMNN(BaseMetricLearner, MetricTransformer): +class _base_LMNN(MahalanobisMixin, TransformerMixin): def __init__(self, k=3, min_iter=50, max_iter=1000, learn_rate=1e-7, regularization=0.5, convergence_tol=0.001, use_pca=True, verbose=False): @@ -44,9 +44,6 @@ def __init__(self, k=3, min_iter=50, max_iter=1000, learn_rate=1e-7, self.use_pca = use_pca self.verbose = verbose - def transformer(self): - return self.L_ - # slower Python version class python_LMNN(_base_LMNN): @@ -60,7 +57,7 @@ def _process_inputs(self, X, labels): self.labels_ = np.arange(len(unique_labels)) if self.use_pca: warnings.warn('use_pca does nothing for the python_LMNN implementation') - self.L_ = np.eye(num_dims) + self.transformer_ = np.eye(num_dims) required_k = np.bincount(self.label_inds_).min() if self.k > required_k: raise ValueError('not enough class labels for specified k' @@ -92,7 +89,7 @@ def fit(self, X, y): # initialize gradient and L G = dfG * reg + df * (1-reg) - L = self.L_ + L = self.transformer_ objective = np.inf # main loop @@ -177,7 +174,7 @@ def fit(self, X, y): print("LMNN didn't converge in %d steps." % self.max_iter) # store the last L - self.L_ = L + self.transformer_ = L self.n_iter_ = it return self @@ -192,7 +189,7 @@ def _select_targets(self): return target_neighbors def _find_impostors(self, furthest_neighbors): - Lx = self.transform() + Lx = self.transform(self.X_) margin_radii = 1 + _inplace_paired_L2(Lx[furthest_neighbors], Lx) impostors = [] for label in self.labels_[:-1]: @@ -246,6 +243,13 @@ def _sum_outer_products(data, a_inds, b_inds, weights=None): from modshogun import RealFeatures, MulticlassLabels class LMNN(_base_LMNN): + """Large Margin Nearest Neighbor (LMNN) + + Attributes + ---------- + transformer_ : `numpy.ndarray`, shape=(num_dims, n_features) + The learned linear transformation ``L``. 
+ """ def fit(self, X, y): self.X_, y = check_X_y(X, y, dtype=float) @@ -259,7 +263,7 @@ def fit(self, X, y): self._lmnn.train() else: self._lmnn.train(np.eye(X.shape[1])) - self.L_ = self._lmnn.get_linear_transform() + self.L_ = self._lmnn.get_linear_transform(X) return self except ImportError: diff --git a/metric_learn/lsml.py b/metric_learn/lsml.py index cdbc75d5..0e8b3513 100644 --- a/metric_learn/lsml.py +++ b/metric_learn/lsml.py @@ -11,14 +11,15 @@ import numpy as np import scipy.linalg from six.moves import xrange +from sklearn.base import TransformerMixin from sklearn.utils.validation import check_array, check_X_y +from ._util import check_tuples -from .base_metric import (BaseMetricLearner, _QuadrupletsClassifierMixin, - MetricTransformer) +from .base_metric import _QuadrupletsClassifierMixin, MahalanobisMixin from .constraints import Constraints -class _BaseLSML(BaseMetricLearner): +class _BaseLSML(MahalanobisMixin): def __init__(self, tol=1e-3, max_iter=1000, prior=None, verbose=False): """Initialize LSML. @@ -37,8 +38,11 @@ def __init__(self, tol=1e-3, max_iter=1000, prior=None, verbose=False): self.verbose = verbose def _prepare_quadruplets(self, quadruplets, weights): - pairs = check_array(quadruplets, accept_sparse=False, - ensure_2d=False, allow_nd=True) + # for now we check_array and check_tuples but we should only + # check_tuples in the future (with enhanced check_tuples) + quadruplets = check_array(quadruplets, accept_sparse=False, + ensure_2d=False, allow_nd=True) + quadruplets = check_tuples(quadruplets) # check to make sure that no two constrained vectors are identical self.vab_ = quadruplets[:, 0, :] - quadruplets[:, 1, :] @@ -51,16 +55,14 @@ def _prepare_quadruplets(self, quadruplets, weights): self.w_ = weights self.w_ /= self.w_.sum() # weights must sum to 1 if self.prior is None: - X = np.vstack({tuple(row) for row in pairs.reshape(-1, pairs.shape[2])}) + X = np.vstack({tuple(row) for row in + quadruplets.reshape(-1, quadruplets.shape[2])}) self.prior_inv_ = np.atleast_2d(np.cov(X, rowvar=False)) self.M_ = np.linalg.inv(self.prior_inv_) else: self.M_ = self.prior self.prior_inv_ = np.linalg.inv(self.prior) - def metric(self): - return self.M_ - def _fit(self, quadruplets, weights=None): self._prepare_quadruplets(quadruplets, weights) step_sizes = np.logspace(-10, 0, 10) @@ -96,6 +98,8 @@ def _fit(self, quadruplets, weights=None): if self.verbose: print("Didn't converge after", it, "iterations. Final loss:", s_best) self.n_iter_ = it + + self.transformer_ = self.transformer_from_metric(self.M_) return self def _comparison_loss(self, metric): @@ -125,6 +129,14 @@ def _gradient(self, metric): class LSML(_BaseLSML, _QuadrupletsClassifierMixin): + """Least Squared-residual Metric Learning (LSML) + + Attributes + ---------- + transformer_ : `numpy.ndarray`, shape=(num_dims, n_features) + The linear transformation ``L`` deduced from the learned Mahalanobis + metric (See :meth:`transformer_from_metric`.) + """ def fit(self, quadruplets, weights=None): """Learn the LSML model. @@ -147,7 +159,16 @@ def fit(self, quadruplets, weights=None): return self._fit(quadruplets, weights=weights) -class LSML_Supervised(_BaseLSML, MetricTransformer): +class LSML_Supervised(_BaseLSML, TransformerMixin): + """Supervised version of Least Squared-residual Metric Learning (LSML) + + Attributes + ---------- + transformer_ : `numpy.ndarray`, shape=(num_dims, n_features) + The linear transformation ``L`` deduced from the learned Mahalanobis + metric (See :meth:`transformer_from_metric`.) 
+ """ + def __init__(self, tol=1e-3, max_iter=1000, prior=None, num_labeled=np.inf, num_constraints=None, weights=None, verbose=False): """Initialize the learner. diff --git a/metric_learn/mlkr.py b/metric_learn/mlkr.py index a16c40aa..9f774322 100644 --- a/metric_learn/mlkr.py +++ b/metric_learn/mlkr.py @@ -10,16 +10,25 @@ import numpy as np from scipy.optimize import minimize from scipy.spatial.distance import pdist, squareform +from sklearn.base import TransformerMixin from sklearn.decomposition import PCA + from sklearn.utils.validation import check_X_y -from .base_metric import BaseMetricLearner, MetricTransformer +from .base_metric import MahalanobisMixin EPS = np.finfo(float).eps -class MLKR(BaseMetricLearner, MetricTransformer): - """Metric Learning for Kernel Regression (MLKR)""" +class MLKR(MahalanobisMixin, TransformerMixin): + """Metric Learning for Kernel Regression (MLKR) + + Attributes + ---------- + transformer_ : `numpy.ndarray`, shape=(num_dims, n_features) + The learned linear transformation ``L``. + """ + def __init__(self, num_dims=None, A0=None, epsilon=0.01, alpha=0.0001, max_iter=1000): """ @@ -90,9 +99,6 @@ def fit(self, X, y): self.n_iter_ = res.nit return self - def transformer(self): - return self.transformer_ - def _loss(flatA, X, y, dX): A = flatA.reshape((-1, X.shape[1])) diff --git a/metric_learn/mmc.py b/metric_learn/mmc.py index f61bb1c7..2f2ee400 100644 --- a/metric_learn/mmc.py +++ b/metric_learn/mmc.py @@ -19,15 +19,15 @@ from __future__ import print_function, absolute_import, division import numpy as np from six.moves import xrange +from sklearn.base import TransformerMixin from sklearn.utils.validation import check_array, check_X_y -from .base_metric import (BaseMetricLearner, _PairsClassifierMixin, - MetricTransformer) +from .base_metric import _PairsClassifierMixin, MahalanobisMixin from .constraints import Constraints, wrap_pairs -from ._util import vector_norm +from ._util import vector_norm, check_tuples -class _BaseMMC(BaseMetricLearner): +class _BaseMMC(MahalanobisMixin): """Mahalanobis Metric for Clustering (MMC)""" def __init__(self, max_iter=100, max_proj=10000, convergence_threshold=1e-3, A0=None, diagonal=False, diagonal_c=1.0, verbose=False): @@ -65,8 +65,11 @@ def _fit(self, pairs, y): return self._fit_full(pairs, y) def _process_pairs(self, pairs, y): + # for now we check_X_y and check_tuples but we should only + # check_tuples_y in the future pairs, y = check_X_y(pairs, y, accept_sparse=False, - ensure_2d=False, allow_nd=True) + ensure_2d=False, allow_nd=True) + pairs = check_tuples(pairs) # check to make sure that no two constrained vectors are identical pos_pairs, neg_pairs = pairs[y == 1], pairs[y == -1] @@ -213,6 +216,8 @@ def _fit_full(self, pairs, y): print('mmc converged at iter %d, conv = %f' % (cycle, delta)) self.A_[:] = A_old self.n_iter_ = cycle + + self.transformer_ = self.transformer_from_metric(self.A_) return self def _fit_diag(self, pairs, y): @@ -271,6 +276,8 @@ def _fit_diag(self, pairs, y): it += 1 self.A_ = np.diag(w) + + self.transformer_ = self.transformer_from_metric(self.A_) return self def _fD(self, neg_pairs, A): @@ -350,29 +357,16 @@ def _D_constraint(self, neg_pairs, w): sum_deri2 / sum_dist - np.outer(sum_deri1, sum_deri1) / (sum_dist * sum_dist) ) - def metric(self): - return self.A_ - - def transformer(self): - """Computes the transformation matrix from the Mahalanobis matrix. 
- L = V.T * w^(-1/2), with A = V*w*V.T being the eigenvector decomposition of A with - the eigenvalues in the diagonal matrix w and the columns of V being the eigenvectors. - - The Cholesky decomposition cannot be applied here, since MMC learns only a positive - *semi*-definite Mahalanobis matrix. - - Returns - ------- - L : (d x d) matrix - """ - if self.diagonal: - return np.sqrt(self.A_) - else: - w, V = np.linalg.eigh(self.A_) - return V.T * np.sqrt(np.maximum(0, w[:,None])) - class MMC(_BaseMMC, _PairsClassifierMixin): + """Mahalanobis Metric for Clustering (MMC) + + Attributes + ---------- + transformer_ : `numpy.ndarray`, shape=(num_dims, n_features) + The linear transformation ``L`` deduced from the learned Mahalanobis + metric (See :meth:`transformer_from_metric`.) + """ def fit(self, pairs, y): """Learn the MMC model. @@ -392,8 +386,16 @@ def fit(self, pairs, y): return self._fit(pairs, y) -class MMC_Supervised(_BaseMMC, MetricTransformer): - """Mahalanobis Metric for Clustering (MMC)""" +class MMC_Supervised(_BaseMMC, TransformerMixin): + """Supervised version of Mahalanobis Metric for Clustering (MMC) + + Attributes + ---------- + transformer_ : `numpy.ndarray`, shape=(num_dims, n_features) + The linear transformation ``L`` deduced from the learned Mahalanobis + metric (See :meth:`transformer_from_metric`.) + """ + def __init__(self, max_iter=100, max_proj=10000, convergence_threshold=1e-6, num_labeled=np.inf, num_constraints=None, A0=None, diagonal=False, diagonal_c=1.0, verbose=False): diff --git a/metric_learn/nca.py b/metric_learn/nca.py index 9a6af0c3..19e016ec 100644 --- a/metric_learn/nca.py +++ b/metric_learn/nca.py @@ -6,22 +6,28 @@ from __future__ import absolute_import import numpy as np from six.moves import xrange +from sklearn.base import TransformerMixin from sklearn.utils.validation import check_X_y -from .base_metric import BaseMetricLearner, MetricTransformer +from .base_metric import MahalanobisMixin EPS = np.finfo(float).eps -class NCA(BaseMetricLearner, MetricTransformer): +class NCA(MahalanobisMixin, TransformerMixin): + """Neighborhood Components Analysis (NCA) + + Attributes + ---------- + transformer_ : `numpy.ndarray`, shape=(num_dims, n_features) + The learned linear transformation ``L``. + """ + def __init__(self, num_dims=None, max_iter=100, learning_rate=0.01): self.num_dims = num_dims self.max_iter = max_iter self.learning_rate = learning_rate - def transformer(self): - return self.A_ - def fit(self, X, y): """ X: data matrix, (n x d) @@ -54,6 +60,6 @@ def fit(self, X, y): A += self.learning_rate * A.dot(d) self.X_ = X - self.A_ = A + self.transformer_ = A self.n_iter_ = it return self diff --git a/metric_learn/rca.py b/metric_learn/rca.py index 36dd0aae..170e21f8 100644 --- a/metric_learn/rca.py +++ b/metric_learn/rca.py @@ -16,9 +16,10 @@ import warnings from six.moves import xrange from sklearn import decomposition +from sklearn.base import TransformerMixin from sklearn.utils.validation import check_array -from .base_metric import BaseMetricLearner, MetricTransformer +from .base_metric import MahalanobisMixin from .constraints import Constraints @@ -35,8 +36,15 @@ def _chunk_mean_centering(data, chunks): return chunk_mask, chunk_data -class RCA(BaseMetricLearner, MetricTransformer): - """Relevant Components Analysis (RCA)""" +class RCA(MahalanobisMixin, TransformerMixin): + """Relevant Components Analysis (RCA) + + Attributes + ---------- + transformer_ : `numpy.ndarray`, shape=(num_dims, n_features) + The learned linear transformation ``L``. 
+ """ + def __init__(self, num_dims=None, pca_comps=None): """Initialize the learner. @@ -55,9 +63,6 @@ def __init__(self, num_dims=None, pca_comps=None): self.num_dims = num_dims self.pca_comps = pca_comps - def transformer(self): - return self.transformer_ - def _process_data(self, X): self.X_ = X = check_array(X) @@ -136,6 +141,14 @@ def _inv_sqrtm(x): class RCA_Supervised(RCA): + """Supervised version of Relevant Components Analysis (RCA) + + Attributes + ---------- + transformer_ : `numpy.ndarray`, shape=(num_dims, n_features) + The learned linear transformation ``L``. + """ + def __init__(self, num_dims=None, pca_comps=None, num_chunks=100, chunk_size=2): """Initialize the learner. diff --git a/metric_learn/sdml.py b/metric_learn/sdml.py index 2e40ad91..0d3c8b92 100644 --- a/metric_learn/sdml.py +++ b/metric_learn/sdml.py @@ -10,17 +10,17 @@ from __future__ import absolute_import import numpy as np -from scipy.sparse.csgraph import laplacian +from sklearn.base import TransformerMixin from sklearn.covariance import graph_lasso from sklearn.utils.extmath import pinvh from sklearn.utils.validation import check_array, check_X_y -from .base_metric import (BaseMetricLearner, _PairsClassifierMixin, - MetricTransformer) +from .base_metric import MahalanobisMixin, _PairsClassifierMixin from .constraints import Constraints, wrap_pairs +from ._util import check_tuples -class _BaseSDML(BaseMetricLearner): +class _BaseSDML(MahalanobisMixin): def __init__(self, balance_param=0.5, sparsity_param=0.01, use_cov=True, verbose=False): """ @@ -44,8 +44,12 @@ def __init__(self, balance_param=0.5, sparsity_param=0.01, use_cov=True, self.verbose = verbose def _prepare_pairs(self, pairs, y): + # for now we check_X_y and check_tuples but we should only + # check_tuples_y in the future pairs, y = check_X_y(pairs, y, accept_sparse=False, - ensure_2d=False, allow_nd=True) + ensure_2d=False, allow_nd=True) + pairs = check_tuples(pairs) + # set up prior M if self.use_cov: X = np.vstack({tuple(row) for row in pairs.reshape(-1, pairs.shape[2])}) @@ -55,9 +59,6 @@ def _prepare_pairs(self, pairs, y): diff = pairs[:, 0] - pairs[:, 1] return (diff.T * y).dot(diff) - def metric(self): - return self.M_ - def _fit(self, pairs, y): loss_matrix = self._prepare_pairs(pairs, y) P = self.M_ + self.balance_param * loss_matrix @@ -65,10 +66,20 @@ def _fit(self, pairs, y): # hack: ensure positive semidefinite emp_cov = emp_cov.T.dot(emp_cov) _, self.M_ = graph_lasso(emp_cov, self.sparsity_param, verbose=self.verbose) + + self.transformer_ = self.transformer_from_metric(self.M_) return self class SDML(_BaseSDML, _PairsClassifierMixin): + """Sparse Distance Metric Learning (SDML) + + Attributes + ---------- + transformer_ : `numpy.ndarray`, shape=(num_dims, n_features) + The linear transformation ``L`` deduced from the learned Mahalanobis + metric (See :meth:`transformer_from_metric`.) + """ def fit(self, pairs, y): """Learn the SDML model. @@ -88,7 +99,16 @@ def fit(self, pairs, y): return self._fit(pairs, y) -class SDML_Supervised(_BaseSDML, MetricTransformer): +class SDML_Supervised(_BaseSDML, TransformerMixin): + """Supervised version of Sparse Distance Metric Learning (SDML) + + Attributes + ---------- + transformer_ : `numpy.ndarray`, shape=(num_dims, n_features) + The linear transformation ``L`` deduced from the learned Mahalanobis + metric (See :meth:`transformer_from_metric`.) 
+ """ + def __init__(self, balance_param=0.5, sparsity_param=0.01, use_cov=True, num_labeled=np.inf, num_constraints=None, verbose=False): """ diff --git a/test/metric_learn_test.py b/test/metric_learn_test.py index 1756b105..1671c8ef 100644 --- a/test/metric_learn_test.py +++ b/test/metric_learn_test.py @@ -38,7 +38,7 @@ def test_iris(self): cov = Covariance() cov.fit(self.iris_points) - csep = class_separation(cov.transform(), self.iris_labels) + csep = class_separation(cov.transform(self.iris_points), self.iris_labels) # deterministic result self.assertAlmostEqual(csep, 0.73068122) @@ -68,7 +68,8 @@ def test_iris(self): lmnn = LMNN_cls(k=5, learn_rate=1e-6, verbose=False) lmnn.fit(self.iris_points, self.iris_labels) - csep = class_separation(lmnn.transform(), self.iris_labels) + csep = class_separation(lmnn.transform(self.iris_points), + self.iris_labels) self.assertLess(csep, 0.25) @@ -97,12 +98,12 @@ def test_iris(self): [+0.2532, 0.5835, -0.8461, -0.8915], [-0.729, -0.6386, 1.767, 1.832], [-0.9405, -0.8461, 2.281, 2.794]] - assert_array_almost_equal(expected, nca.transformer(), decimal=3) + assert_array_almost_equal(expected, nca.transformer_, decimal=3) # With dimension reduction nca = NCA(max_iter=(100000//n), learning_rate=0.01, num_dims=2) nca.fit(self.iris_points, self.iris_labels) - csep = class_separation(nca.transform(), self.iris_labels) + csep = class_separation(nca.transform(self.iris_points), self.iris_labels) self.assertLess(csep, 0.15) @@ -110,19 +111,19 @@ class TestLFDA(MetricTestCase): def test_iris(self): lfda = LFDA(k=2, num_dims=2) lfda.fit(self.iris_points, self.iris_labels) - csep = class_separation(lfda.transform(), self.iris_labels) + csep = class_separation(lfda.transform(self.iris_points), self.iris_labels) self.assertLess(csep, 0.15) # Sanity checks for learned matrices. 
self.assertEqual(lfda.metric().shape, (4, 4)) - self.assertEqual(lfda.transformer().shape, (2, 4)) + self.assertEqual(lfda.transformer_.shape, (2, 4)) class TestRCA(MetricTestCase): def test_iris(self): rca = RCA_Supervised(num_dims=2, num_chunks=30, chunk_size=2) rca.fit(self.iris_points, self.iris_labels) - csep = class_separation(rca.transform(), self.iris_labels) + csep = class_separation(rca.transform(self.iris_points), self.iris_labels) self.assertLess(csep, 0.25) def test_feature_null_variance(self): @@ -131,14 +132,14 @@ def test_feature_null_variance(self): # Apply PCA with the number of components rca = RCA_Supervised(num_dims=2, pca_comps=3, num_chunks=30, chunk_size=2) rca.fit(X, self.iris_labels) - csep = class_separation(rca.transform(), self.iris_labels) + csep = class_separation(rca.transform(X), self.iris_labels) self.assertLess(csep, 0.30) # Apply PCA with the minimum variance ratio rca = RCA_Supervised(num_dims=2, pca_comps=0.95, num_chunks=30, chunk_size=2) rca.fit(X, self.iris_labels) - csep = class_separation(rca.transform(), self.iris_labels) + csep = class_separation(rca.transform(X), self.iris_labels) self.assertLess(csep, 0.30) @@ -146,7 +147,7 @@ class TestMLKR(MetricTestCase): def test_iris(self): mlkr = MLKR() mlkr.fit(self.iris_points, self.iris_labels) - csep = class_separation(mlkr.transform(), self.iris_labels) + csep = class_separation(mlkr.transform(self.iris_points), self.iris_labels) self.assertLess(csep, 0.25) diff --git a/test/test_fit_transform.py b/test/test_fit_transform.py index d239ec95..f898a0fe 100644 --- a/test/test_fit_transform.py +++ b/test/test_fit_transform.py @@ -19,7 +19,7 @@ def setUpClass(self): def test_cov(self): cov = Covariance() cov.fit(self.X) - res_1 = cov.transform() + res_1 = cov.transform(self.X) cov = Covariance() res_2 = cov.fit_transform(self.X) @@ -53,7 +53,7 @@ def test_itml_supervised(self): def test_lmnn(self): lmnn = LMNN(k=5, learn_rate=1e-6, verbose=False) lmnn.fit(self.X, self.y) - res_1 = lmnn.transform() + res_1 = lmnn.transform(self.X) lmnn = LMNN(k=5, learn_rate=1e-6, verbose=False) res_2 = lmnn.fit_transform(self.X, self.y) @@ -76,7 +76,7 @@ def test_nca(self): n = self.X.shape[0] nca = NCA(max_iter=(100000//n), learning_rate=0.01) nca.fit(self.X, self.y) - res_1 = nca.transform() + res_1 = nca.transform(self.X) nca = NCA(max_iter=(100000//n), learning_rate=0.01) res_2 = nca.fit_transform(self.X, self.y) @@ -86,7 +86,7 @@ def test_nca(self): def test_lfda(self): lfda = LFDA(k=2, num_dims=2) lfda.fit(self.X, self.y) - res_1 = lfda.transform() + res_1 = lfda.transform(self.X) lfda = LFDA(k=2, num_dims=2) res_2 = lfda.fit_transform(self.X, self.y) @@ -100,7 +100,7 @@ def test_rca_supervised(self): seed = np.random.RandomState(1234) rca = RCA_Supervised(num_dims=2, num_chunks=30, chunk_size=2) rca.fit(self.X, self.y, random_state=seed) - res_1 = rca.transform() + res_1 = rca.transform(self.X) seed = np.random.RandomState(1234) rca = RCA_Supervised(num_dims=2, num_chunks=30, chunk_size=2) @@ -111,7 +111,7 @@ def test_rca_supervised(self): def test_mlkr(self): mlkr = MLKR(num_dims=2) mlkr.fit(self.X, self.y) - res_1 = mlkr.transform() + res_1 = mlkr.transform(self.X) mlkr = MLKR(num_dims=2) res_2 = mlkr.fit_transform(self.X, self.y) diff --git a/test/test_mahalanobis_mixin.py b/test/test_mahalanobis_mixin.py new file mode 100644 index 00000000..09a98ece --- /dev/null +++ b/test/test_mahalanobis_mixin.py @@ -0,0 +1,231 @@ +from itertools import product + +import pytest +import numpy as np +from numpy.testing import 
assert_array_almost_equal +from scipy.spatial.distance import pdist, squareform +from sklearn import clone +from sklearn.datasets import load_iris +from sklearn.utils import check_random_state, shuffle +from sklearn.utils.testing import set_random_state + +from metric_learn import (Constraints, ITML, LSML, MMC, SDML, Covariance, LFDA, + LMNN, MLKR, NCA, RCA) +from metric_learn.constraints import wrap_pairs +from functools import partial + +RNG = check_random_state(0) + +def build_data(): + dataset = load_iris() + X, y = shuffle(dataset.data, dataset.target, random_state=RNG) + num_constraints = 20 + constraints = Constraints.random_subset(y, random_state=RNG) + pairs = constraints.positive_negative_pairs(num_constraints, + same_length=True, + random_state=RNG) + return X, pairs + + +def build_pairs(): + # test that you can do cross validation on tuples of points with + # a WeaklySupervisedMetricLearner + X, pairs = build_data() + pairs, y = wrap_pairs(X, pairs) + pairs, y = shuffle(pairs, y, random_state=RNG) + return pairs, y + + +def build_quadruplets(): + # test that you can do cross validation on a tuples of points with + # a WeaklySupervisedMetricLearner + X, pairs = build_data() + c = np.column_stack(pairs) + quadruplets = X[c] + quadruplets = shuffle(quadruplets, random_state=RNG) + return quadruplets, None + + +list_estimators = [(Covariance(), build_data), + (ITML(), build_pairs), + (LFDA(), partial(load_iris, return_X_y=True)), + (LMNN(), partial(load_iris, return_X_y=True)), + (LSML(), build_quadruplets), + (MLKR(), partial(load_iris, return_X_y=True)), + (MMC(), build_pairs), + (NCA(), partial(load_iris, return_X_y=True)), + (RCA(), partial(load_iris, return_X_y=True)), + (SDML(), build_pairs) + ] + +ids_estimators = ['covariance', + 'itml', + 'lfda', + 'lmnn', + 'lsml', + 'mlkr', + 'mmc', + 'nca', + 'rca', + 'sdml', + ] + + +@pytest.mark.parametrize('estimator, build_dataset', list_estimators, + ids=ids_estimators) +def test_score_pairs_pairwise(estimator, build_dataset): + # Computing pairwise scores should return a euclidean distance matrix. 
+ inputs, labels = build_dataset() + X, _ = load_iris(return_X_y=True) + n_samples = 20 + X = X[:n_samples] + model = clone(estimator) + set_random_state(model) + model.fit(inputs, labels) + + pairwise = model.score_pairs(np.array(list(product(X, X))))\ + .reshape(n_samples, n_samples) + + check_is_distance_matrix(pairwise) + + # a necessary condition for euclidean distance matrices: (see + # https://en.wikipedia.org/wiki/Euclidean_distance_matrix) + assert np.linalg.matrix_rank(pairwise**2) <= min(X.shape) + 2 + + # assert that this distance is coherent with pdist on embeddings + assert_array_almost_equal(squareform(pairwise), pdist(model.transform(X))) + + +@pytest.mark.parametrize('estimator, build_dataset', list_estimators, + ids=ids_estimators) +def test_score_pairs_toy_example(estimator, build_dataset): + # Checks that score_pairs works on a toy example + inputs, labels = build_dataset() + X, _ = load_iris(return_X_y=True) + n_samples = 20 + X = X[:n_samples] + model = clone(estimator) + set_random_state(model) + model.fit(inputs, labels) + pairs = np.stack([X[:10], X[10:20]], axis=1) + embedded_pairs = pairs.dot(model.transformer_.T) + distances = np.sqrt(np.sum((embedded_pairs[:, 1] - + embedded_pairs[:, 0])**2, + axis=-1)) + assert_array_almost_equal(model.score_pairs(pairs), distances) + + +@pytest.mark.parametrize('estimator, build_dataset', list_estimators, + ids=ids_estimators) +def test_score_pairs_finite(estimator, build_dataset): + # tests that the score is finite + inputs, labels = build_dataset() + model = clone(estimator) + set_random_state(model) + model.fit(inputs, labels) + X, _ = load_iris(return_X_y=True) + pairs = np.array(list(product(X, X))) + assert np.isfinite(model.score_pairs(pairs)).all() + + +@pytest.mark.parametrize('estimator, build_dataset', list_estimators, + ids=ids_estimators) +def test_score_pairs_dim(estimator, build_dataset): + # scoring of 3D arrays should return 1D array (several tuples), + # and scoring of 2D arrays (one tuple) should return an error (like + # scikit-learn's error when scoring 1D arrays) + inputs, labels = build_dataset() + model = clone(estimator) + set_random_state(model) + model.fit(inputs, labels) + X, _ = load_iris(return_X_y=True) + tuples = np.array(list(product(X, X))) + assert model.score_pairs(tuples).shape == (tuples.shape[0],) + msg = ("Expected 3D array, got 2D array instead:\ntuples={}.\n" + "Reshape your data either using tuples.reshape(-1, {}, 1) if " + "your data has a single feature or tuples.reshape(1, {}, -1) " + "if it contains a single tuple.".format(tuples, tuples.shape[1], + tuples.shape[0])) + with pytest.raises(ValueError, message=msg): + model.score_pairs(tuples[1]) + + +def check_is_distance_matrix(pairwise): + assert (pairwise >= 0).all() # positivity + assert np.array_equal(pairwise, pairwise.T) # symmetry + assert (pairwise.diagonal() == 0).all() # identity + # triangular inequality + tol = 1e-15 + assert (pairwise <= pairwise[:, :, np.newaxis] + + pairwise[:, np.newaxis, :] + tol).all() + + +@pytest.mark.parametrize('estimator, build_dataset', list_estimators, + ids=ids_estimators) +def test_embed_toy_example(estimator, build_dataset): + # Checks that embed works on a toy example + inputs, labels = build_dataset() + X, _ = load_iris(return_X_y=True) + n_samples = 20 + X = X[:n_samples] + model = clone(estimator) + set_random_state(model) + model.fit(inputs, labels) + embedded_points = X.dot(model.transformer_.T) + assert_array_almost_equal(model.transform(X), embedded_points) + + 
+@pytest.mark.parametrize('estimator, build_dataset', list_estimators, + ids=ids_estimators) +def test_embed_dim(estimator, build_dataset): + # Checks that the the dimension of the output space is as expected + inputs, labels = build_dataset() + model = clone(estimator) + set_random_state(model) + model.fit(inputs, labels) + X, _ = load_iris(return_X_y=True) + assert model.transform(X).shape == X.shape + + # assert that ValueError is thrown if input shape is 1D + err_msg = ("Expected 2D array, got 1D array instead:\narray={}.\n" + "Reshape your data either using array.reshape(-1, 1) if " + "your data has a single feature or array.reshape(1, -1) " + "if it contains a single sample.".format(X)) + with pytest.raises(ValueError, message=err_msg): + model.score_pairs(model.transform(X[0, :])) + # we test that the shape is also OK when doing dimensionality reduction + if type(model).__name__ in {'LFDA', 'MLKR', 'NCA', 'RCA'}: + model.set_params(num_dims=2) + model.fit(inputs, labels) + assert model.transform(X).shape == (X.shape[0], 2) + # assert that ValueError is thrown if input shape is 1D + with pytest.raises(ValueError, message=err_msg): + model.transform(model.transform(X[0, :])) + + +@pytest.mark.parametrize('estimator, build_dataset', list_estimators, + ids=ids_estimators) +def test_embed_finite(estimator, build_dataset): + # Checks that embed returns vectors with finite values + inputs, labels = build_dataset() + model = clone(estimator) + set_random_state(model) + model.fit(inputs, labels) + X, _ = load_iris(return_X_y=True) + assert np.isfinite(model.transform(X)).all() + + +@pytest.mark.parametrize('estimator, build_dataset', list_estimators, + ids=ids_estimators) +def test_embed_is_linear(estimator, build_dataset): + # Checks that the embedding is linear + inputs, labels = build_dataset() + model = clone(estimator) + set_random_state(model) + model.fit(inputs, labels) + X, _ = load_iris(return_X_y=True) + assert_array_almost_equal(model.transform(X[:10] + X[10:20]), + model.transform(X[:10]) + + model.transform(X[10:20])) + assert_array_almost_equal(model.transform(5 * X[:10]), + 5 * model.transform(X[:10])) diff --git a/test/test_transformer_metric_conversion.py b/test/test_transformer_metric_conversion.py index e027d176..3b8f9e0e 100644 --- a/test/test_transformer_metric_conversion.py +++ b/test/test_transformer_metric_conversion.py @@ -19,60 +19,60 @@ def setUpClass(self): def test_cov(self): cov = Covariance() cov.fit(self.X) - L = cov.transformer() + L = cov.transformer_ assert_array_almost_equal(L.T.dot(L), cov.metric()) def test_lsml_supervised(self): seed = np.random.RandomState(1234) lsml = LSML_Supervised(num_constraints=200) lsml.fit(self.X, self.y, random_state=seed) - L = lsml.transformer() + L = lsml.transformer_ assert_array_almost_equal(L.T.dot(L), lsml.metric()) def test_itml_supervised(self): seed = np.random.RandomState(1234) itml = ITML_Supervised(num_constraints=200) itml.fit(self.X, self.y, random_state=seed) - L = itml.transformer() + L = itml.transformer_ assert_array_almost_equal(L.T.dot(L), itml.metric()) def test_lmnn(self): lmnn = LMNN(k=5, learn_rate=1e-6, verbose=False) lmnn.fit(self.X, self.y) - L = lmnn.transformer() + L = lmnn.transformer_ assert_array_almost_equal(L.T.dot(L), lmnn.metric()) def test_sdml_supervised(self): seed = np.random.RandomState(1234) sdml = SDML_Supervised(num_constraints=1500) sdml.fit(self.X, self.y, random_state=seed) - L = sdml.transformer() + L = sdml.transformer_ assert_array_almost_equal(L.T.dot(L), sdml.metric()) def 
test_nca(self): n = self.X.shape[0] nca = NCA(max_iter=(100000//n), learning_rate=0.01) nca.fit(self.X, self.y) - L = nca.transformer() + L = nca.transformer_ assert_array_almost_equal(L.T.dot(L), nca.metric()) def test_lfda(self): lfda = LFDA(k=2, num_dims=2) lfda.fit(self.X, self.y) - L = lfda.transformer() + L = lfda.transformer_ assert_array_almost_equal(L.T.dot(L), lfda.metric()) def test_rca_supervised(self): seed = np.random.RandomState(1234) rca = RCA_Supervised(num_dims=2, num_chunks=30, chunk_size=2) rca.fit(self.X, self.y, random_state=seed) - L = rca.transformer() + L = rca.transformer_ assert_array_almost_equal(L.T.dot(L), rca.metric()) def test_mlkr(self): mlkr = MLKR(num_dims=2) mlkr.fit(self.X, self.y) - L = mlkr.transformer() + L = mlkr.transformer_ assert_array_almost_equal(L.T.dot(L), mlkr.metric()) diff --git a/test/test_utils.py b/test/test_utils.py new file mode 100644 index 00000000..8ca3aac3 --- /dev/null +++ b/test/test_utils.py @@ -0,0 +1,29 @@ +import numpy as np +import pytest +from metric_learn._util import check_tuples + + +def test_check_tuples(): + X = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) + check_tuples(X) + + X = 5 + msg = ("Expected 3D array, got scalar instead. Cannot apply this function " + "on scalars.") + with pytest.raises(ValueError, message=msg): + check_tuples(X) + + X = np.array([1, 2, 3]) + msg = ("Expected 3D array, got 1D array instead:\ntuples=[1, 2, 3].\n" + "Reshape your data using tuples.reshape(1, -1, 1) if it contains a " + "single tuple and the points in the tuple have a single feature.") + with pytest.raises(ValueError, message=msg): + check_tuples(X) + + X = np.array([[1, 2, 3], [2, 3, 5]]) + msg = ("Expected 3D array, got 2D array instead:\ntuples=[[1, 2, 3], " + "[2, 3, 5]].\nReshape your data either using " + "tuples.reshape(-1, 3, 1) if your data has a single feature or " + "tuples.reshape(1, 2, -1) if it contains a single tuple.") + with pytest.raises(ValueError, message=msg): + check_tuples(X) diff --git a/test/test_weakly_supervised.py b/test/test_weakly_supervised.py index 6386d22a..8cae4bfc 100644 --- a/test/test_weakly_supervised.py +++ b/test/test_weakly_supervised.py @@ -5,6 +5,7 @@ from sklearn.utils.estimator_checks import is_public_parameter from sklearn.utils.testing import (assert_allclose_dense_sparse, set_random_state) +from sklearn.utils.fixes import signature from metric_learn import ITML, MMC, SDML, LSML from metric_learn.constraints import wrap_pairs, Constraints @@ -12,13 +13,13 @@ import numpy as np from sklearn.model_selection import cross_val_score, train_test_split +RNG = check_random_state(0) def build_data(): - RNG = check_random_state(0) dataset = load_iris() X, y = shuffle(dataset.data, dataset.target, random_state=RNG) num_constraints = 20 - constraints = Constraints.random_subset(y) + constraints = Constraints.random_subset(y, random_state=RNG) pairs = constraints.positive_negative_pairs(num_constraints, same_length=True, random_state=RNG) @@ -30,9 +31,9 @@ def build_pairs(): # a WeaklySupervisedMetricLearner X, pairs = build_data() pairs, y = wrap_pairs(X, pairs) - pairs, y = shuffle(pairs, y) + pairs, y = shuffle(pairs, y, random_state=RNG) (pairs_train, pairs_test, y_train, - y_test) = train_test_split(pairs, y) + y_test) = train_test_split(pairs, y, random_state=RNG) return (pairs, y, pairs_train, pairs_test, y_train, y_test) @@ -43,9 +44,10 @@ def build_quadruplets(): X, pairs = build_data() c = np.column_stack(pairs) quadruplets = X[c] - quadruplets = shuffle(quadruplets) + quadruplets 
= shuffle(quadruplets, random_state=RNG) y = y_train = y_test = None - quadruplets_train, quadruplets_test = train_test_split(quadruplets) + quadruplets_train, quadruplets_test = train_test_split(quadruplets, + random_state=RNG) return (quadruplets, y, quadruplets_train, quadruplets_test, y_train, y_test) @@ -99,22 +101,32 @@ def test_simple_estimator(estimator, build_dataset): @pytest.mark.parametrize('estimator', [est[0] for est in list_estimators], ids=ids_estimators) -def test_no_fit_attributes_set_in_init(estimator): - """Check that Estimator.__init__ doesn't set trailing-_ attributes.""" - # From scikit-learn - estimator = clone(estimator) - for attr in dir(estimator): - if attr.endswith("_") and not attr.startswith("__"): - # This check is for properties, they can be listed in dir - # while at the same time have hasattr return False as long - # as the property getter raises an AttributeError - assert hasattr(estimator, attr), \ - ("By convention, attributes ending with '_' are " - "estimated from data in scikit-learn. Consequently they " - "should not be initialized in the constructor of an " - "estimator but in the fit method. Attribute {!r} " - "was found in estimator {}".format( - attr, type(estimator).__name__)) +def test_no_attributes_set_in_init(estimator): + """Check setting during init. Taken from scikit-learn.""" + estimator = clone(estimator) + if hasattr(type(estimator).__init__, "deprecated_original"): + return + + init_params = _get_args(type(estimator).__init__) + parents_init_params = [param for params_parent in + (_get_args(parent) for parent in + type(estimator).__mro__) + for param in params_parent] + + # Test for no setting apart from parameters during init + invalid_attr = (set(vars(estimator)) - set(init_params) - + set(parents_init_params)) + assert not invalid_attr, \ + ("Estimator %s should not set any attribute apart" + " from parameters during init. Found attributes %s." + % (type(estimator).__name__, sorted(invalid_attr))) + # Ensure that each parameter is set in init + invalid_attr = (set(init_params) - set(vars(estimator)) - + set(["self"])) + assert not invalid_attr, \ + ("Estimator %s should store all parameters" + " as an attribute during init. Did not find " + "attributes %s." 
% (type(estimator).__name__, sorted(invalid_attr))) @pytest.mark.parametrize('estimator, build_dataset', list_estimators, @@ -158,17 +170,24 @@ def test_dict_unchanged(estimator, build_dataset): (tuples, y, tuples_train, tuples_test, y_train, y_test) = build_dataset() estimator = clone(estimator) - if hasattr(estimator, "n_components"): - estimator.n_components = 1 + if hasattr(estimator, "num_dims"): + estimator.num_dims = 1 estimator.fit(tuples, y) - for method in ["predict", "transform", "decision_function", - "predict_proba"]: + for method in ["predict", "decision_function", "predict_proba"]: if hasattr(estimator, method): dict_before = estimator.__dict__.copy() getattr(estimator, method)(tuples) assert estimator.__dict__ == dict_before, \ ("Estimator changes __dict__ during %s" % method) + for method in ["transform"]: + if hasattr(estimator, method): + dict_before = estimator.__dict__.copy() + # we transform only 2D arrays (dataset of points) + getattr(estimator, method)(tuples[:, 0, :]) + assert estimator.__dict__ == dict_before, \ + ("Estimator changes __dict__ during %s" + % method) @pytest.mark.parametrize('estimator, build_dataset', list_estimators, @@ -179,8 +198,8 @@ def test_dont_overwrite_parameters(estimator, build_dataset): (tuples, y, tuples_train, tuples_test, y_train, y_test) = build_dataset() estimator = clone(estimator) - if hasattr(estimator, "n_components"): - estimator.n_components = 1 + if hasattr(estimator, "num_dims"): + estimator.num_dims = 1 dict_before_fit = estimator.__dict__.copy() estimator.fit(tuples, y) @@ -212,3 +231,23 @@ def test_dont_overwrite_parameters(estimator, build_dataset): " to change attributes started" " or ended with _, but" " %s changed" % ', '.join(attrs_changed_by_fit)) + + +def _get_args(function, varargs=False): + """Helper to get function arguments""" + + try: + params = signature(function).parameters + except ValueError: + # Error on builtin C function + return [] + args = [key for key, param in params.items() + if param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)] + if varargs: + varargs = [param.name for param in params.values() + if param.kind == param.VAR_POSITIONAL] + if len(varargs) == 0: + varargs = None + return args, varargs + else: + return args
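
Below, a minimal sketch of the 3D tuple convention that the new check_tuples helper enforces, assuming this branch is installed so that metric_learn._util.check_tuples is importable (as the new test_utils.py does): tuple datasets are arrays of shape (n_tuples, tuple_size, n_features), and scalar, 1D or 2D inputs are rejected with a reshape hint.

    import numpy as np
    from metric_learn._util import check_tuples  # helper introduced in this diff

    # a valid dataset of 2 pairs of 2-dimensional points: shape (n_tuples=2, 2, n_features=2)
    pairs = np.array([[[1., 2.], [3., 4.]],
                      [[5., 6.], [7., 8.]]])
    check_tuples(pairs)  # passes and returns the array unchanged

    # a 2D array is ambiguous (several points, or a single tuple?), so it is rejected
    flat = np.array([[1., 2., 3.],
                     [2., 3., 5.]])
    try:
        check_tuples(flat)
    except ValueError as err:
        print(err)  # message suggests tuples.reshape(-1, 3, 1) or tuples.reshape(1, 2, -1)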
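
The following pure-numpy sketch illustrates the identity the MahalanobisMixin docstrings rely on, with a made-up matrix L standing in for a learned transformer_: the Mahalanobis distance under M = L^T L equals the Euclidean distance between the points embedded by L, which is exactly what score_pairs computes via transform.

    import numpy as np

    rng = np.random.RandomState(0)
    L = rng.randn(2, 4)            # hypothetical learned transformer_, shape (num_dims, n_features)
    M = L.T.dot(L)                 # the corresponding Mahalanobis matrix, as returned by metric()

    x, x_prime = rng.randn(4), rng.randn(4)
    diff = x - x_prime

    d_mahalanobis = np.sqrt(diff.dot(M).dot(diff))          # sqrt((x - x')^T M (x - x'))
    d_embedded = np.linalg.norm(L.dot(x) - L.dot(x_prime))  # ||L x - L x'||

    assert np.isclose(d_mahalanobis, d_embedded)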
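
A standalone sketch of the decomposition strategy described in transformer_from_metric (illustrative only, not the method itself): take the elementwise square root when the metric is diagonal, the transposed Cholesky factor when it is definite, and otherwise an eigendecomposition with eigenvalues clipped at zero; in every case L.T.dot(L) recovers the metric.

    import numpy as np

    def transformer_from_metric_sketch(M):
        """Illustrative re-statement of the fallback logic described in the docstring."""
        if np.allclose(M, np.diag(np.diag(M))):       # diagonal metric: elementwise square root
            return np.sqrt(M)
        elif not np.isclose(np.linalg.det(M), 0):     # definite metric: Cholesky factor
            return np.linalg.cholesky(M).T
        else:                                         # PSD but singular: eigendecomposition
            w, V = np.linalg.eigh(M)
            return V.T * np.sqrt(np.maximum(0, w[:, None]))

    rng = np.random.RandomState(0)

    A = rng.randn(3, 3)
    M_definite = A.T.dot(A)                           # full-rank PSD matrix -> Cholesky branch
    L = transformer_from_metric_sketch(M_definite)
    assert np.allclose(L.T.dot(L), M_definite)

    B = rng.randn(3, 2)
    M_singular = B.dot(B.T)                           # rank-deficient PSD matrix -> eigh branch
    L = transformer_from_metric_sketch(M_singular)
    assert np.allclose(L.T.dot(L), M_singular)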
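
Finally, a hedged end-to-end sketch of the reworked public surface (fit, transform with an explicit X, score_pairs, and the transformer_/metric() pair), using the baseline Covariance learner on iris much like the new tests do; it assumes this branch of metric-learn and scikit-learn are installed.

    import numpy as np
    from sklearn.datasets import load_iris
    from metric_learn import Covariance

    X, _ = load_iris(return_X_y=True)
    cov = Covariance().fit(X)

    # transform now requires the data to embed (no implicit fall-back to the training data)
    X_embedded = cov.transform(X)

    # score_pairs takes a 3D array of pairs and returns one learned distance per pair
    pairs = np.stack([X[:10], X[10:20]], axis=1)      # shape (10, 2, n_features)
    distances = cov.score_pairs(pairs)

    # the learned transformer_ and the Mahalanobis matrix are consistent,
    # and pair scores match Euclidean distances in the embedded space
    assert np.allclose(cov.transformer_.T.dot(cov.transformer_), cov.metric())
    assert np.allclose(distances,
                       np.linalg.norm(X_embedded[:10] - X_embedded[10:20], axis=1))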