diff --git a/doc/conf.py b/doc/conf.py
index 1c8beeab..679f0db0 100644
--- a/doc/conf.py
+++ b/doc/conf.py
@@ -9,6 +9,10 @@
   'numpydoc',
 ]
 
+autodoc_default_flags = ['members', 'inherited-members']
+
+default_role = 'any'
+
 templates_path = ['_templates']
 source_suffix = '.rst'
 master_doc = 'index'
diff --git a/doc/index.rst b/doc/index.rst
index f50781fe..3f3d97c1 100644
--- a/doc/index.rst
+++ b/doc/index.rst
@@ -14,17 +14,34 @@ This package contains efficient Python implementations of several popular
 metric learning algorithms.
 
 .. toctree::
-   :caption: Algorithms
+   :caption: Unsupervised Algorithms
    :maxdepth: 1
 
    metric_learn.covariance
-   metric_learn.lmnn
+
+.. toctree::
+   :caption: Weakly Supervised Algorithms
+   :maxdepth: 1
+
+   metric_learn.weakly_supervised
+   metric_learn.mmc
    metric_learn.itml
    metric_learn.sdml
    metric_learn.lsml
+
+Note that all Weakly Supervised Metric Learners have a supervised version. See
+:ref:`this section <supervised_version>` for more details.
+
+
+.. toctree::
+   :caption: Supervised Algorithms
+   :maxdepth: 1
+
+   metric_learn.lmnn
    metric_learn.nca
    metric_learn.lfda
    metric_learn.rca
+   metric_learn.mlkr
 
 Each metric supports the following methods:
 
@@ -34,6 +51,9 @@ Each metric supports the following methods:
   data matrix :math:`X \in \mathbb{R}^{n \times d}` to the
   :math:`D`-dimensional learned metric space :math:`X L^{\top}`,
   in which standard Euclidean distances may be used.
+
+.. _transform_ml:
+
 - ``transform(X)``, which applies the aforementioned transformation.
 - ``metric()``, which returns a Mahalanobis matrix
   :math:`M = L^{\top}L` such that distance between vectors ``x`` and
diff --git a/doc/metric_learn.base_metric.rst b/doc/metric_learn.base_metric.rst
index 050a360b..d252d77b 100644
--- a/doc/metric_learn.base_metric.rst
+++ b/doc/metric_learn.base_metric.rst
@@ -5,3 +5,4 @@ metric_learn.base_metric module
     :members:
     :undoc-members:
     :show-inheritance:
+
diff --git a/doc/metric_learn.constrained_dataset.rst b/doc/metric_learn.constrained_dataset.rst
new file mode 100644
index 00000000..727a6f1e
--- /dev/null
+++ b/doc/metric_learn.constrained_dataset.rst
@@ -0,0 +1,8 @@
+ConstrainedDataset
+==================
+
+.. autoclass:: metric_learn.constraints.ConstrainedDataset
+    :members:
+    :undoc-members:
+    :show-inheritance:
+
diff --git a/doc/metric_learn.weakly_supervised.rst b/doc/metric_learn.weakly_supervised.rst
new file mode 100644
index 00000000..2def13ed
--- /dev/null
+++ b/doc/metric_learn.weakly_supervised.rst
@@ -0,0 +1,220 @@
+.. _wsml:
+
+Weakly Supervised Learning (General information)
+================================================
+
+Introduction
+------------
+
+In Distance Metric Learning, we are interested in learning a metric between
+points that takes into account some supervised information about the
+similarity between those points. If each point has a class, we can use this
+information by saying that all intra-class points are similar, and inter-class
+points are dissimilar.
+
+However, sometimes we do not have a class for each sample. Instead, we can
+have pairs of points and a label for each pair saying whether the points in it
+are similar or not. Indeed, for a hand-labeled dataset of images with a huge
+number of classes, it is easier for a human to say whether two images are
+similar than to tell, among the huge number of classes, which one a shown
+image belongs to. We can also have a dataset of triplets of points where we
+know that the first sample is more similar to the second than to the third. We
+could also have quadruplets of points where the first two points are more
+similar to each other than the last two are. In fact, some metric learning
+algorithms are designed to use these kinds of data. These are Weakly
+Supervised Metric Learners. For instance, `ITML`, `MMC` and `SDML` work on
+labeled pairs, and `LSML` works on unlabeled quadruplets.
+
+In the ``metric-learn`` package, we use an object called `ConstrainedDataset`
+to store these kinds of datasets, where each sample/line is a tuple of points
+from an initial dataset. Contrary to a 3D numpy array where each line would be
+a tuple of ``t`` points from an initial dataset, `ConstrainedDataset` is
+memory-efficient as it does not duplicate points in the underlying memory.
+Instead, it stores the indices of the points involved in every tuple, along
+with the initial dataset. It also supports slicing on tuples, which makes it
+compatible with scikit-learn utilities for cross-validation (see
+:ref:`performance_ws`).
+
+See the documentation of `ConstrainedDataset` for more information.
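+
+As a quick illustration, here is a minimal sketch on toy data (the values and
+indices below are arbitrary); note that slicing selects tuples without copying
+the underlying points:
+
+.. code:: python
+
+    import numpy as np
+    from metric_learn import ConstrainedDataset
+
+    X = np.array([[1., 5., 6.], [7., 5., 2.], [9., 2., 0.]])
+    c = np.array([[0, 1], [1, 2], [0, 2]])  # indices of the points in each pair
+    X_constrained = ConstrainedDataset(X, c)
+    X_constrained[[0, 2]]                       # keeps only the first and third pairs
+    X_constrained[[0, 2]].X is X_constrained.X  # True: X itself is not duplicated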
+
+
+.. _workflow_ws:
+
+Basic workflow
+--------------
+
+Let us see how we can use weakly supervised metric learners in a basic
+scikit-learn like workflow, with ``fit``, ``predict``, ``transform``,
+``score``, etc.
+
+- Fitting
+
+Let's say we have a dataset of samples and we also know for some pairs of them
+whether they are similar or dissimilar. We want to fit a metric learner on
+this data. First, we recognize this data is made of labeled pairs. What we
+need to do first is therefore to build a `ConstrainedDataset` with as input
+the points ``X`` (an array of shape ``(n_samples, n_features)``) and the
+constraints ``c`` (an array of shape ``(n_constraints, 2)``) of indices of
+pairs. We also need a vector ``y_constraints`` of shape ``(n_constraints,)``
+where each ``y_constraints_i`` is 0 if sample ``X[c[i, 0]]`` is similar to
+sample ``X[c[i, 1]]``, and 1 if they are dissimilar.
+
+.. code:: python
+
+    from metric_learn import ConstrainedDataset
+    X_constrained = ConstrainedDataset(X, c)
+
+Then we can fit a Weakly Supervised Metric Learner (here one that inherits
+from `PairsMixin`) on this data (let's use `MMC` for instance):
+
+.. code:: python
+
+    from metric_learn import MMC
+    mmc = MMC()
+    mmc.fit(X_constrained, y_constraints)
+
+.. _transform_ws:
+
+- Transforming
+
+Weakly supervised metric learners can also be used as transformers. Let us say
+we have a fitted estimator. At ``transform`` time, it can be used on arrays of
+samples as well as on `ConstrainedDataset`s. Indeed, it returns transformed
+samples and thus only needs input samples (it will ignore any information on
+constraints in the input). The transformed samples are the new points in an
+embedded space. See :ref:`this section <transform_ml>` for more details about
+this transformation.
+
+.. code:: python
+
+    mmc.transform(X)
+
+- Predicting
+
+Weakly Supervised Metric Learners work on lines of data where each line is a
+tuple of points of an original dataset. For some of these, we should also have
+a label for each line (for instance in the case of learning on pairs, each
+label ``y_constraints_i`` should tell whether the pair in line ``i`` is a
+similar or a dissimilar pair). So for these algorithms, applying ``predict``
+to an input `ConstrainedDataset` will predict one scalar related to this task
+for each tuple. For instance in the case of pairs, ``predict`` will return for
+each input pair a float measuring the learned distance between the samples in
+the pair: it should be low for similar pairs and high for dissimilar ones.
+
+.. code:: python
+
+    mmc.predict(X_constrained)
+
+See the API documentation for `WeaklySupervisedMixin`'s children
+(`PairsMixin`, `TripletsMixin`, `QuadrupletsMixin`) for the particular
+prediction functions of each type of Weakly Supervised Metric Learner.
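+
+As an illustration, one possible way to turn the distances returned by
+``predict`` into binary similar/dissimilar decisions is to threshold them.
+This is only a sketch: the threshold below is arbitrary, and in practice it
+should be tuned, for example on held-out pairs:
+
+.. code:: python
+
+    import numpy as np
+
+    distances = mmc.predict(X_constrained)  # learned distance for each pair
+    threshold = np.median(distances)  # arbitrary choice, for illustration only
+    y_predicted = (distances > threshold).astype(int)  # 0: similar, 1: dissimilar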
+
+- Scoring
+
+We can also call the default scoring function of the Weakly Supervised Metric
+Learner we use:
+
+.. code:: python
+
+    mmc.score(X_constrained, y_constraints)
+
+The type of score depends on the type of Weakly Supervised Metric Learner
+used. See the API documentation for `WeaklySupervisedMixin`'s children
+(`PairsMixin`, `TripletsMixin`, `QuadrupletsMixin`) for the particular
+default scoring functions of each type of estimator.
+
+See also :ref:`performance_ws` for how to use scikit-learn's
+cross-validation routines with Weakly Supervised Metric Learners.
+
+
+.. _supervised_version:
+
+Supervised Version
+------------------
+
+Weakly Supervised Metric Learners can also be used in a supervised way: the
+corresponding supervised algorithm will create a `ConstrainedDataset`
+``X_constrained`` and tuple labels ``y_constraints`` from a supervised dataset
+of points and labels. For instance, if we want to use the algorithm `MMC` on a
+dataset of points and labels (``X`` and ``y``), we should use
+``MMC_Supervised``: the underlying code will create pairs of samples from the
+same class with labels saying that they are similar, and pairs of samples from
+different classes with labels saying that they are dissimilar, before calling
+`MMC`.
+
+Example:
+
+.. code:: python
+
+    from sklearn.datasets import make_classification
+    from metric_learn import MMC_Supervised
+
+    X, y = make_classification()
+    mmc_supervised = MMC_Supervised()
+    mmc_supervised.fit_transform(X, y)
+
+
+.. _performance_ws:
+
+Evaluating the performance of weakly supervised metric learning algorithms
+--------------------------------------------------------------------------
+
+To evaluate the performance of a classical supervised algorithm that takes in
+an input dataset ``X`` and some labels ``y``, we can compute a cross-validation
+score. However, weakly supervised algorithms cannot ``predict`` on one sample,
+so we cannot split on samples to make a training set and a test set the same
+way as we do with usual estimators. Instead, metric learning algorithms output
+a score on a **tuple** of samples: for instance a similarity score on pairs of
+samples. So doing cross-validation scoring for metric learning algorithms
+implies splitting on **tuples** of samples. Fortunately, `ConstrainedDataset`
+allows us to do so naturally.
+
+Here is how we would get the cross-validation score for the ``MMC`` algorithm:
+
+.. code:: python
+
+    from sklearn.model_selection import cross_val_score
+    cross_val_score(mmc, X_constrained, y_constraints)
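+
+Since `ConstrainedDataset` supports slicing on tuples, holding out a test set
+of tuples also works with scikit-learn's usual utilities. Here is a minimal
+sketch (assuming ``X_constrained`` and ``y_constraints`` are defined as in the
+sections above):
+
+.. code:: python
+
+    from sklearn.model_selection import train_test_split
+
+    (X_constrained_train, X_constrained_test,
+     y_train, y_test) = train_test_split(X_constrained, y_constraints)
+    mmc.fit(X_constrained_train, y_train)
+    mmc.score(X_constrained_test, y_test)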
+
+
+Pipelining
+----------
+
+Weakly Supervised Metric Learners can also be embedded in scikit-learn
+pipelines. However, they can only be combined with Transformers. This is
+because there is already supervision from the constraints, and we cannot add
+the kind of supervision that scikit-learn's supervised estimators would
+require.
+
+For instance, you can combine it with another transformer like PCA or KMeans:
+
+.. code:: python
+
+    from sklearn.decomposition import PCA
+    from sklearn.cluster import KMeans
+    from sklearn.pipeline import make_pipeline
+
+    pipe_pca = make_pipeline(MMC(), PCA())
+    pipe_pca.fit(X_constrained, y_constraints)
+    pipe_clustering = make_pipeline(MMC(), KMeans())
+    pipe_clustering.fit(X_constrained, y_constraints)
+
+There are also some other things to keep in mind:
+
+- The ``X`` type input of the pipeline should be a `ConstrainedDataset` when
+  fitting, but when transforming or predicting it can be an array of samples.
+  Therefore, all the following lines are valid:
+
+  .. code:: python
+
+    pipe_pca.transform(X_constrained)
+    pipe_pca.fit_transform(X_constrained)
+    pipe_pca.transform(X_constrained.X)
+
+- You should also not try to cross-validate those pipelines with scikit-learn's
+  cross-validation functions: their input data is a `ConstrainedDataset`,
+  which, when split, can contain the same points in the train and the test set
+  (though of course not the same tuples of points).
+
diff --git a/metric_learn/__init__.py b/metric_learn/__init__.py
index b86c10e1..d32bda4f 100644
--- a/metric_learn/__init__.py
+++ b/metric_learn/__init__.py
@@ -1,6 +1,6 @@
 from __future__ import absolute_import
 
-from .constraints import Constraints
+from .constraints import Constraints, ConstrainedDataset
 from .covariance import Covariance
 from .itml import ITML, ITML_Supervised
 from .lmnn import LMNN
diff --git a/metric_learn/base_metric.py b/metric_learn/base_metric.py
index 02519de1..f559dd4e 100644
--- a/metric_learn/base_metric.py
+++ b/metric_learn/base_metric.py
@@ -1,9 +1,13 @@
-from numpy.linalg import inv, cholesky
+from sklearn.metrics import roc_auc_score
+
+from metric_learn.constraints import ConstrainedDataset
+from numpy.linalg import cholesky
 from sklearn.base import BaseEstimator, TransformerMixin
 from sklearn.utils.validation import check_array
+import numpy as np
 
-class BaseMetricLearner(BaseEstimator, TransformerMixin):
+class BaseMetricLearner(BaseEstimator):
   def __init__(self):
     raise NotImplementedError('BaseMetricLearner should not be instantiated')
@@ -35,8 +39,9 @@ def transform(self, X=None):
 
     Parameters
     ----------
-    X : (n x d) matrix, optional
+    X : (n x d) matrix or `ConstrainedDataset`, optional
        Data to transform. If not supplied, the training data will be used.
+       In the case of a `ConstrainedDataset`, ``X_constrained.X`` is used.
 
     Returns
     -------
@@ -45,7 +50,287 @@
     """
     if X is None:
       X = self.X_
+    elif type(X) is ConstrainedDataset:
+      X = X.X
     else:
       X = check_array(X, accept_sparse=True)
     L = self.transformer()
     return X.dot(L.T)
+
+
+class SupervisedMixin(TransformerMixin):
+
+  def __init__(self):
+    raise NotImplementedError('SupervisedMixin should not be instantiated')
+
+  def fit(self, X, y):
+    raise NotImplementedError
+
+
+class UnsupervisedMixin(TransformerMixin):
+
+  def __init__(self):
+    raise NotImplementedError('UnsupervisedMixin should not be instantiated')
+
+  def fit(self, X, y=None):
+    raise NotImplementedError
+
+
+class WeaklySupervisedMixin(object):
+
+  def __init__(self):
+    raise NotImplementedError('WeaklySupervisedMixin should not be '
+                              'instantiated')
+
+  def fit_transform(self, X_constrained, y=None, **fit_params):
+    """Fit to data, then transform it.
+
+    Fits the transformer to ``X_constrained`` and ``y`` with optional
+    parameters ``fit_params``, and returns a transformed version of
+    ``X_constrained``.
+
+    Parameters
+    ----------
+    X_constrained : `ConstrainedDataset`, shape=(n_constraints, t, n_features)
+        Training set of ``n_constraints`` tuples of samples.
+
+    y : None, or numpy array of shape [n_constraints]
+        Constraint labels.
+    """
+    if y is None:
+      # fit method of arity 1 (unsupervised transformation)
+      return self.fit(X_constrained, **fit_params).transform(X_constrained)
+    else:
+      # fit method of arity 2 (supervised transformation)
+      return self.fit(X_constrained, y, **fit_params).transform(X_constrained)
+
+  def decision_function(self, X_constrained):
+    return self.predict(X_constrained)
+
+
+class PairsMixin(WeaklySupervisedMixin):
+
+  def __init__(self):
+    raise NotImplementedError('PairsMixin should not be instantiated')
+
+  def fit(self, X_constrained, y_constraints, **kwargs):
+    """Fit a pairs based metric learner.
+
+    Parameters
+    ----------
+    X_constrained : `ConstrainedDataset`, shape=(n_constraints, 2, n_features)
+        Training `ConstrainedDataset`.
+
+    y_constraints : array-like, shape=(n_constraints,)
+        Labels of constraints (0 for similar pairs, 1 for dissimilar).
+
+    kwargs : Any
+        Algorithm specific parameters.
+
+    Returns
+    -------
+    self : The fitted estimator.
+    """
+    return self._fit(X_constrained, y_constraints, **kwargs)
+
+  def predict(self, X_constrained):
+    """Predicts the learned distance between input pairs.
+
+    Returns the learned metric value between samples in every pair. It should
+    ideally be low for similar samples and high for dissimilar samples.
+
+    Parameters
+    ----------
+    X_constrained : `ConstrainedDataset`, shape=(n_constraints, 2, n_features)
+        A constrained dataset of paired samples.
+
+    Returns
+    -------
+    y_predicted : `numpy.ndarray` of floats, shape=(n_constraints,)
+        The predicted learned metric value between samples in every pair.
+    """
+    # TODO: provide better implementation
+    pairwise_diffs = (X_constrained.X[X_constrained.c[:, 0]] -
+                      X_constrained.X[X_constrained.c[:, 1]])
+    return np.sqrt(np.sum(pairwise_diffs.dot(self.metric()) * pairwise_diffs,
+                          axis=1))
+
+  def score(self, X_constrained, y_constraints):
+    """Computes score of pairs similarity prediction.
+
+    Returns the ``roc_auc`` score of the fitted metric learner. It is
+    computed in the following way: for every value of a threshold
+    ``t`` we classify all pairs of samples where the predicted distance is
+    below ``t`` as belonging to the "similar" class, and the others as
+    belonging to the "dissimilar" class, and we count false positives and
+    true positives as in a classical ROC curve.
+
+    Parameters
+    ----------
+    X_constrained : `ConstrainedDataset`, shape=(n_constraints, 2, n_features)
+        Constrained dataset of paired samples.
+
+    y_constraints : array-like, shape=(n_constraints,)
+        The corresponding labels.
+
+    Returns
+    -------
+    score : float
+        The ``roc_auc`` score.
+    """
+    return roc_auc_score(y_constraints, self.decision_function(X_constrained))
+
+
+class TripletsMixin(WeaklySupervisedMixin):
+
+  def __init__(self):
+    raise NotImplementedError('TripletsMixin should not be '
+                              'instantiated')
+
+  def fit(self, X_constrained, y=None, **kwargs):
+    """Fit a triplets based metric learner.
+
+    Parameters
+    ----------
+    X_constrained : `ConstrainedDataset`, shape=(n_constraints, 3, n_features)
+        Training `ConstrainedDataset`. To give the right supervision to the
+        algorithm, in each triplet the first two points should be more similar
+        to each other than the first and the third are.
+
+    y : Ignored, for scikit-learn compatibility.
+
+    kwargs : Any
+        Algorithm specific parameters.
+
+    Returns
+    -------
+    self : The fitted estimator.
+ """ + return self._fit(X_constrained, **kwargs) + + + def predict(self, X_constrained): + """Predict the difference between samples similarities in input triplets. + + For each triplet of samples in ``X_constrained``, returns the + difference between the learned similarity between the first and the + second point, minus the similarity between the first and the third point. + + Parameters + ---------- + X_constrained : `ConstrainedDataset`, shape=(n_constraints, 3, n_features) + Input constrained dataset. + + Returns + ------- + prediction : `numpy.ndarray` of floats, shape=(n_constraints,) + Predictions for each triplet. + """ + # TODO: provide better implementation + similar_diffs = X_constrained.X[X_constrained.c[:, 0]] - \ + X_constrained.X[X_constrained.c[:, 1]] + dissimilar_diffs = X_constrained.X[X_constrained.c[:, 0]] - \ + X_constrained.X[X_constrained.c[:, 2]] + return np.sqrt(np.sum(similar_diffs.dot(self.metric()) * + similar_diffs, axis=1)) - \ + np.sqrt(np.sum(dissimilar_diffs.dot(self.metric()) * + dissimilar_diffs, axis=1)) + + def score(self, X_constrained, y=None): + """Computes score of triplets similarity prediction. + + Returns the accuracy score of the following classification task: a record + is correctly classified if the predicted similarity between the first two + samples is higher than that between the first and the third. + + Parameters + ---------- + X_constrained : `ConstrainedDataset`, shape=(n_constraints, 3, n_features) + Constrained dataset of triplets of samples. + + y: Ignored (for scikit-learn compatibility). + + Returns + ------- + score: float + The triplets score. + """ + predicted_sign = self.decision_function(X_constrained) < 0 + return np.sum(predicted_sign) / predicted_sign.shape[0] + + + +class QuadrupletsMixin(WeaklySupervisedMixin): + + def __init__(self): + raise NotImplementedError('QuadrupletsMixin should not be ' + 'instantiated') + + def fit(self, X_constrained, y=None, **kwargs): + """Fit a quadruplets based metric learner. + + Parameters + ---------- + X_constrained : `ConstrainedDataset`, shape=(n_constraints, 4, n_features) + Training `ConstrainedDataset`. To give the right supervision to the + algorithm, the first two points should be more similar than the last two. + + y : Ignored, for scikit-learn compatibility. + + kwargs : Any + Algorithm specific parameters. + + Returns + ------- + self : The fitted estimator. + """ + return self._fit(X_constrained, **kwargs) + + def predict(self, X_constrained): + """Predicts differences between sample similarities in input quadruplets. + + For each quadruplet of samples in ``X_constrained``, computes the + difference between the learned metric of the first pair minus the learned + metric of the second pair. + + Parameters + ---------- + X_constrained : `ConstrainedDataset`, shape=(n_constraints, 4, n_features) + Input constrained dataset. + + Returns + ------- + prediction : np.ndarray of floats, shape=(n_constraints,) + Metric differences. 
+ """ + similar_diffs = X_constrained.X[X_constrained.c[:, 0]] - \ + X_constrained.X[X_constrained.c[:, 1]] + dissimilar_diffs = X_constrained.X[X_constrained.c[:, 2]] - \ + X_constrained.X[X_constrained.c[:, 3]] + return np.sqrt(np.sum(similar_diffs.dot(self.metric()) * + similar_diffs, axis=1)) - \ + np.sqrt(np.sum(dissimilar_diffs.dot(self.metric()) * + dissimilar_diffs, axis=1)) + + def decision_fuction(self, X_constrained): + return self.predict(X_constrained) + + def score(self, X_constrained, y=None): + """Computes score on an input constrained dataset + + Returns the accuracy score of the following classification task: a record + is correctly classified if the predicted similarity between the first two + samples is higher than that of the last two. + + Parameters + ---------- + X_constrained : `ConstrainedDataset`, shape=(n_constraints, 4, n_features) + Input constrained dataset. + + y : Ignored, for scikit-learn compatibility. + + Returns + ------- + score : float + The quadruplets score. + """ + predicted_sign = self.decision_function(X_constrained) < 0 + return np.sum(predicted_sign) / predicted_sign.shape[0] \ No newline at end of file diff --git a/metric_learn/constraints.py b/metric_learn/constraints.py index 8824450a..6c3b6333 100644 --- a/metric_learn/constraints.py +++ b/metric_learn/constraints.py @@ -5,9 +5,10 @@ import numpy as np import warnings from six.moves import xrange -from scipy.sparse import coo_matrix +from scipy.sparse import coo_matrix, issparse +from sklearn.utils import check_array -__all__ = ['Constraints'] +__all__ = ['Constraints', 'ConstrainedDataset'] class Constraints(object): @@ -18,17 +19,6 @@ def __init__(self, partial_labels): self.known_label_idx, = np.where(partial_labels >= 0) self.known_labels = partial_labels[self.known_label_idx] - def adjacency_matrix(self, num_constraints, random_state=np.random): - a, b, c, d = self.positive_negative_pairs(num_constraints, - random_state=random_state) - row = np.concatenate((a, c)) - col = np.concatenate((b, d)) - data = np.ones_like(row, dtype=int) - data[len(a):] = -1 - adj = coo_matrix((data, (row, col)), shape=(self.num_points,)*2) - # symmetrize - return adj + adj.T - def positive_negative_pairs(self, num_constraints, same_length=False, random_state=np.random): a, b = self._pairs(num_constraints, same_label=True, @@ -100,3 +90,144 @@ def random_subset(all_labels, num_preserved=np.inf, random_state=np.random): partial_labels = np.array(all_labels, copy=True) partial_labels[idx] = -1 return Constraints(partial_labels) + + +class ConstrainedDataset(object): + """Constrained Dataset + + This is what weakly supervised metric learning algorithms take as input. It + wraps a dataset ``X`` and some constraints ``c``. It mocks a 3D array of + shape ``(n_constraints, t, n_features)``, where each line contains t + samples from ``X``. + + Read more in the :ref:`User Guide `. + + Parameters + ---------- + X: array-like, shape=(n_samples, n_features) + Dataset of samples. + + c: array-like of integers between 0 and n_samples, shape=(n_constraints, t) + Array of indexes of the ``t`` samples to consider in each constraint. + + Attributes + ---------- + X: array-like, shape=(n_samples, n_features) + The dataset ``X`` stored in the `ConstrainedDataset`. + + c: array-like, shape=(n_constraints, t) + The current array of indices that is stored in the `ConstrainedDataset`. + + shape: tuple, len==3. + The shape of the `ConstrainedDataset`. 
+      It is ``(n_constraints, t, n_features)``, where ``t`` is the number of
+      samples in each tuple.
+
+  Examples
+  --------
+  ``X`` is a regular array-like dataset, with 4 samples of 3 features each.
+  Let us say we also have pair constraints.
+
+  >>> X = [[1., 5., 6.], [7., 5., 2.], [9., 2., 0.], [2., 8., 4.]]
+  >>> constraints = [[0, 2], [1, 3], [2, 3]]
+
+  The first element of the new dataset will be the pair of sample 0 and
+  sample 2. We can later have a labels array ``y_constraints`` which will
+  say whether each pair is positive (similar samples) or negative.
+
+  >>> X_constrained = ConstrainedDataset(X, constraints)
+  >>> X_constrained.toarray()
+  array([[[ 1.,  5.,  6.],
+          [ 9.,  2.,  0.]],
+         [[ 7.,  5.,  2.],
+          [ 2.,  8.,  4.]],
+         [[ 9.,  2.,  0.],
+          [ 2.,  8.,  4.]]])
+
+  """
+
+  def __init__(self, X, c):
+    # we convert the data to a suitable format
+    self.X = check_array(X, accept_sparse=True, warn_on_dtype=True)
+    self.c = check_array(c, dtype=['int'] + np.sctypes['int']
+                         + np.sctypes['uint'],
+                         # we add 'int' at the beginning to tell it is the
+                         # default format we want in case of conversion
+                         ensure_2d=False, ensure_min_samples=False,
+                         ensure_min_features=False, warn_on_dtype=True)
+    self._check_index(self.X.shape[0], self.c)
+    self.shape = (len(c) if hasattr(c, '__len__') else 0,
+                  self.c.shape[1] if len(self.c.shape) > 1 else 0,
+                  self.X.shape[1])
+
+  def __getitem__(self, item):
+    return ConstrainedDataset(self.X, self.c[item])
+
+  def __len__(self):
+    return self.shape[0]
+
+  def __str__(self):
+    return self.toarray().__str__()
+
+  def __repr__(self):
+    return self.toarray().__repr__()
+
+  def toarray(self):
+    if issparse(self.X):
+      # if X is sparse we convert it to dense because sparse arrays cannot
+      # be 3D
+      return self.X.A[self.c]
+    else:
+      return self.X[self.c]
+
+  @staticmethod
+  def _check_index(length, indices):
+    max_index = np.max(indices)
+    min_index = np.min(indices)
+    pb_index = None
+    if max_index >= length:
+      pb_index = max_index
+    elif min_index < -length:
+      pb_index = min_index
+    if pb_index is not None:
+      raise IndexError("ConstrainedDataset cannot be created: the length of "
+                       "the dataset is {}, so index {} is out of range."
+ .format(length, pb_index)) + + @staticmethod + def pairs_from_labels(y): + # TODO: to be implemented + raise NotImplementedError + + @staticmethod + def triplets_from_labels(y): + # TODO: to be implemented + raise NotImplementedError + + +def unwrap_pairs(X_constrained, y): + y_zero = (y == 0).ravel() + a, b = X_constrained.c[y_zero].T + c, d = X_constrained.c[~y_zero].T + X = X_constrained.X + return X, [a, b, c, d] + +def wrap_pairs(X, constraints): + a = np.array(constraints[0]) + b = np.array(constraints[1]) + c = np.array(constraints[2]) + d = np.array(constraints[3]) + constraints = np.vstack((np.column_stack((a, b)), np.column_stack((c, d)))) + y = np.vstack([np.zeros((len(a), 1)), np.ones((len(c), 1))]) + X_constrained = ConstrainedDataset(X, constraints) + return X_constrained, y + +def unwrap_to_graph(X_constrained, y): + + X, [a, b, c, d] = unwrap_pairs(X_constrained, y) + row = np.concatenate((a, c)) + col = np.concatenate((b, d)) + data = np.ones_like(row, dtype=int) + data[len(a):] = -1 + adj = coo_matrix((data, (row, col)), shape=(X_constrained.X.shape[0],) + * 2) + return X_constrained.X, adj + adj.T \ No newline at end of file diff --git a/metric_learn/covariance.py b/metric_learn/covariance.py index 8fc07873..e25a1894 100644 --- a/metric_learn/covariance.py +++ b/metric_learn/covariance.py @@ -12,10 +12,10 @@ import numpy as np from sklearn.utils.validation import check_array -from .base_metric import BaseMetricLearner +from .base_metric import BaseMetricLearner, UnsupervisedMixin -class Covariance(BaseMetricLearner): +class Covariance(BaseMetricLearner, UnsupervisedMixin): def __init__(self): pass diff --git a/metric_learn/itml.py b/metric_learn/itml.py index 4d27c412..74085219 100644 --- a/metric_learn/itml.py +++ b/metric_learn/itml.py @@ -19,12 +19,12 @@ from sklearn.metrics import pairwise_distances from sklearn.utils.validation import check_array, check_X_y -from .base_metric import BaseMetricLearner -from .constraints import Constraints +from .base_metric import BaseMetricLearner, PairsMixin, SupervisedMixin +from .constraints import Constraints, unwrap_pairs, wrap_pairs from ._util import vector_norm -class ITML(BaseMetricLearner): +class _ITML(BaseMetricLearner): """Information Theoretic Metric Learning (ITML)""" def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3, A0=None, verbose=False): @@ -73,19 +73,19 @@ def _process_inputs(self, X, constraints, bounds): self.A_ = check_array(self.A0) return a,b,c,d - def fit(self, X, constraints, bounds=None): + def _fit(self, X_constrained, y_constraints, bounds=None): """Learn the ITML model. Parameters ---------- - X : (n x d) data matrix - each row corresponds to a single instance - constraints : 4-tuple of arrays - (a,b,c,d) indices into X, with (a,b) specifying positive and (c,d) - negative pairs + X_constrained : ConstrainedDataset + with constraints being an array of shape [n_constraints, 2] + y_constraints : array-like, shape (n_constraints x 1) + labels of the constraints bounds : list (pos,neg) pairs, optional bounds on similarity, s.t. 
d(X[a],X[b]) < pos and d(X[c],X[d]) > neg """ + X, constraints = unwrap_pairs(X_constrained, y_constraints) a,b,c,d = self._process_inputs(X, constraints, bounds) gamma = self.gamma num_pos = len(a) @@ -140,7 +140,7 @@ def metric(self): return self.A_ -class ITML_Supervised(ITML): +class ITML_Supervised(_ITML, SupervisedMixin): """Information Theoretic Metric Learning (ITML)""" def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3, num_labeled=np.inf, num_constraints=None, bounds=None, A0=None, @@ -164,7 +164,7 @@ def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3, verbose : bool, optional if True, prints information while learning """ - ITML.__init__(self, gamma=gamma, max_iter=max_iter, + _ITML.__init__(self, gamma=gamma, max_iter=max_iter, convergence_threshold=convergence_threshold, A0=A0, verbose=verbose) self.num_labeled = num_labeled @@ -195,4 +195,9 @@ def fit(self, X, y, random_state=np.random): random_state=random_state) pos_neg = c.positive_negative_pairs(num_constraints, random_state=random_state) - return ITML.fit(self, X, pos_neg, bounds=self.bounds) + X_constrained, y = wrap_pairs(X, pos_neg) + return _ITML._fit(self, X_constrained, y, bounds=self.bounds) + +class ITML(_ITML, PairsMixin): + + pass diff --git a/metric_learn/lfda.py b/metric_learn/lfda.py index dbe5aa4f..16cb5634 100644 --- a/metric_learn/lfda.py +++ b/metric_learn/lfda.py @@ -18,10 +18,10 @@ from sklearn.metrics import pairwise_distances from sklearn.utils.validation import check_X_y -from .base_metric import BaseMetricLearner +from .base_metric import BaseMetricLearner, SupervisedMixin -class LFDA(BaseMetricLearner): +class LFDA(BaseMetricLearner, SupervisedMixin): ''' Local Fisher Discriminant Analysis for Supervised Dimensionality Reduction Sugiyama, ICML 2006 diff --git a/metric_learn/lmnn.py b/metric_learn/lmnn.py index dea12f0c..7b1e8613 100644 --- a/metric_learn/lmnn.py +++ b/metric_learn/lmnn.py @@ -17,7 +17,7 @@ from sklearn.utils.validation import check_X_y, check_array from sklearn.metrics import euclidean_distances -from .base_metric import BaseMetricLearner +from .base_metric import BaseMetricLearner, SupervisedMixin # commonality between LMNN implementations @@ -49,7 +49,7 @@ def transformer(self): # slower Python version -class python_LMNN(_base_LMNN): +class python_LMNN(_base_LMNN, SupervisedMixin): def _process_inputs(self, X, labels): self.X_ = check_array(X, dtype=float) @@ -239,13 +239,12 @@ def _sum_outer_products(data, a_inds, b_inds, weights=None): return np.dot(Xab.T, Xab * weights[:,None]) return np.dot(Xab.T, Xab) - try: # use the fast C++ version, if available from modshogun import LMNN as shogun_LMNN from modshogun import RealFeatures, MulticlassLabels - class LMNN(_base_LMNN): + class LMNN(_base_LMNN, SupervisedMixin): def fit(self, X, y): self.X_, y = check_X_y(X, y, dtype=float) @@ -262,5 +261,7 @@ def fit(self, X, y): self.L_ = self._lmnn.get_linear_transform() return self + except ImportError: + LMNN = python_LMNN diff --git a/metric_learn/lsml.py b/metric_learn/lsml.py index 404fe286..4a24c77e 100644 --- a/metric_learn/lsml.py +++ b/metric_learn/lsml.py @@ -13,11 +13,11 @@ from six.moves import xrange from sklearn.utils.validation import check_array, check_X_y -from .base_metric import BaseMetricLearner -from .constraints import Constraints +from .base_metric import BaseMetricLearner, SupervisedMixin, QuadrupletsMixin +from .constraints import Constraints, ConstrainedDataset -class LSML(BaseMetricLearner): +class _LSML(BaseMetricLearner): 
   def __init__(self, tol=1e-3, max_iter=1000, prior=None, verbose=False):
     """Initialize LSML.
 
@@ -57,18 +57,23 @@ def _prepare_inputs(self, X, constraints, weights):
   def metric(self):
     return self.M_
 
-  def fit(self, X, constraints, weights=None):
+  def _fit(self, X_constrained, weights=None):
     """Learn the LSML model.
 
     Parameters
     ----------
-    X : (n x d) data matrix
-        each row corresponds to a single instance
-    constraints : 4-tuple of arrays
-        (a,b,c,d) indices into X, such that d(X[a],X[b]) < d(X[c],X[d])
+    X_constrained : `ConstrainedDataset`
+        with constraints being an array of shape [n_constraints, 4]. It
+        should be the concatenation of 4 column vectors a, b, c and d,
+        such that ``d(X[a[i]], X[b[i]]) < d(X[c[i]], X[d[i]])`` for every
+        constraint index ``i``.
     weights : (m,) array of floats, optional
         scale factor for each constraint
     """
+    X = X_constrained.X
+    constraints = [X_constrained.c[:, i].ravel() for i in range(4)]
     self._prepare_inputs(X, constraints, weights)
     step_sizes = np.logspace(-10, 0, 10)
     # Keep track of the best step size and the loss at that step.
@@ -131,7 +136,7 @@ def _gradient(self, metric):
     return dMetric
 
-class LSML_Supervised(LSML):
+class LSML_Supervised(_LSML, SupervisedMixin):
   def __init__(self, tol=1e-3, max_iter=1000, prior=None, num_labeled=np.inf,
                num_constraints=None, weights=None, verbose=False):
     """Initialize the learner.
@@ -151,7 +156,7 @@ def __init__(self, tol=1e-3, max_iter=1000, prior=None, num_labeled=np.inf,
     verbose : bool, optional
         if True, prints information while learning
     """
-    LSML.__init__(self, tol=tol, max_iter=max_iter, prior=prior,
+    _LSML.__init__(self, tol=tol, max_iter=max_iter, prior=prior,
                   verbose=verbose)
     self.num_labeled = num_labeled
     self.num_constraints = num_constraints
@@ -181,4 +186,10 @@ def fit(self, X, y, random_state=np.random):
                                 random_state=random_state)
     pairs = c.positive_negative_pairs(num_constraints, same_length=True,
                                       random_state=random_state)
-    return LSML.fit(self, X, pairs, weights=self.weights)
+    X_constrained = ConstrainedDataset(X, np.column_stack(pairs))
+    return _LSML._fit(self, X_constrained, weights=self.weights)
+
+
+class LSML(_LSML, QuadrupletsMixin):
+
+  pass
diff --git a/metric_learn/mlkr.py b/metric_learn/mlkr.py
index 35b80495..c557d27a 100644
--- a/metric_learn/mlkr.py
+++ b/metric_learn/mlkr.py
@@ -13,12 +13,12 @@
 from sklearn.decomposition import PCA
 from sklearn.utils.validation import check_X_y
 
-from .base_metric import BaseMetricLearner
+from .base_metric import BaseMetricLearner, SupervisedMixin
 
 EPS = np.finfo(float).eps
 
 
-class MLKR(BaseMetricLearner):
+class MLKR(BaseMetricLearner, SupervisedMixin):
   """Metric Learning for Kernel Regression (MLKR)"""
   def __init__(self, num_dims=None, A0=None, epsilon=0.01, alpha=0.0001,
                max_iter=1000):
@@ -110,3 +110,4 @@ def _loss(flatA, X, y, dX):
     M = (dX.T * W.ravel()).dot(dX)
     grad = 2 * A.dot(M)
     return cost, grad.ravel()
+
diff --git a/metric_learn/mmc.py b/metric_learn/mmc.py
index efe33c38..7b986ec3 100644
--- a/metric_learn/mmc.py
+++ b/metric_learn/mmc.py
@@ -19,16 +19,16 @@
 from __future__ import print_function, absolute_import, division
 import numpy as np
 from six.moves import xrange
-from sklearn.metrics import pairwise_distances
 from sklearn.utils.validation import check_array, check_X_y
 
-from .base_metric import BaseMetricLearner
-from .constraints import Constraints
+from .base_metric import PairsMixin, BaseMetricLearner, SupervisedMixin
+from .constraints import (Constraints,
ConstrainedDataset, + unwrap_pairs, wrap_pairs) from ._util import vector_norm -class MMC(BaseMetricLearner): +class _MMC(BaseMetricLearner): """Mahalanobis Metric for Clustering (MMC)""" def __init__(self, max_iter=100, max_proj=10000, convergence_threshold=1e-3, A0=None, diagonal=False, diagonal_c=1.0, verbose=False): @@ -58,17 +58,17 @@ def __init__(self, max_iter=100, max_proj=10000, convergence_threshold=1e-3, self.diagonal_c = diagonal_c self.verbose = verbose - def fit(self, X, constraints): + def _fit(self, X_constrained, y_constraints): """Learn the MMC model. Parameters ---------- - X : (n x d) data matrix - each row corresponds to a single instance - constraints : 4-tuple of arrays - (a,b,c,d) indices into X, with (a,b) specifying similar and (c,d) - dissimilar pairs + X_constrained : ConstrainedDataset + with constraints being an array of shape [n_constraints, 2] + y_constraints : array-like, shape (n_constraints x 1) + labels of the constraints """ + X, constraints = unwrap_pairs(X_constrained, y_constraints) constraints = self._process_inputs(X, constraints) if self.diagonal: return self._fit_diag(X, constraints) @@ -380,7 +380,7 @@ def transformer(self): return V.T * np.sqrt(np.maximum(0, w[:,None])) -class MMC_Supervised(MMC): +class MMC_Supervised(_MMC, SupervisedMixin): """Mahalanobis Metric for Clustering (MMC)""" def __init__(self, max_iter=100, max_proj=10000, convergence_threshold=1e-6, num_labeled=np.inf, num_constraints=None, @@ -408,7 +408,7 @@ def __init__(self, max_iter=100, max_proj=10000, convergence_threshold=1e-6, verbose : bool, optional if True, prints information while learning """ - MMC.__init__(self, max_iter=max_iter, max_proj=max_proj, + _MMC.__init__(self, max_iter=max_iter, max_proj=max_proj, convergence_threshold=convergence_threshold, A0=A0, diagonal=diagonal, diagonal_c=diagonal_c, verbose=verbose) @@ -437,4 +437,9 @@ def fit(self, X, y, random_state=np.random): random_state=random_state) pos_neg = c.positive_negative_pairs(num_constraints, random_state=random_state) - return MMC.fit(self, X, pos_neg) + X_constrained, y = wrap_pairs(X, pos_neg) + return _MMC._fit(self, X_constrained, y) + +class MMC(_MMC, PairsMixin): + + pass diff --git a/metric_learn/nca.py b/metric_learn/nca.py index 40757d23..0314a21a 100644 --- a/metric_learn/nca.py +++ b/metric_learn/nca.py @@ -8,12 +8,12 @@ from six.moves import xrange from sklearn.utils.validation import check_X_y -from .base_metric import BaseMetricLearner +from .base_metric import BaseMetricLearner, SupervisedMixin EPS = np.finfo(float).eps -class NCA(BaseMetricLearner): +class NCA(BaseMetricLearner, SupervisedMixin): def __init__(self, num_dims=None, max_iter=100, learning_rate=0.01): self.num_dims = num_dims self.max_iter = max_iter diff --git a/metric_learn/rca.py b/metric_learn/rca.py index 0d9b3620..b353aac8 100644 --- a/metric_learn/rca.py +++ b/metric_learn/rca.py @@ -18,7 +18,7 @@ from sklearn import decomposition from sklearn.utils.validation import check_array -from .base_metric import BaseMetricLearner +from .base_metric import BaseMetricLearner, PairsMixin, SupervisedMixin from .constraints import Constraints @@ -35,7 +35,7 @@ def _chunk_mean_centering(data, chunks): return chunk_mask, chunk_data -class RCA(BaseMetricLearner): +class RCA(BaseMetricLearner, SupervisedMixin): """Relevant Components Analysis (RCA)""" def __init__(self, num_dims=None, pca_comps=None): """Initialize the learner. 
diff --git a/metric_learn/sdml.py b/metric_learn/sdml.py index 93280334..d0bf5f89 100644 --- a/metric_learn/sdml.py +++ b/metric_learn/sdml.py @@ -15,11 +15,11 @@ from sklearn.utils.extmath import pinvh from sklearn.utils.validation import check_array -from .base_metric import BaseMetricLearner -from .constraints import Constraints +from .base_metric import PairsMixin, SupervisedMixin, BaseMetricLearner +from .constraints import Constraints, wrap_pairs, unwrap_to_graph -class SDML(BaseMetricLearner): +class _SDML(BaseMetricLearner): def __init__(self, balance_param=0.5, sparsity_param=0.01, use_cov=True, verbose=False): """ @@ -56,21 +56,22 @@ def _prepare_inputs(self, X, W): def metric(self): return self.M_ - def fit(self, X, W): + def _fit(self, X_constrained, y_constraints): """Learn the SDML model. Parameters ---------- - X : array-like, shape (n, d) - data matrix, where each row corresponds to a single instance - W : array-like, shape (n, n) - connectivity graph, with +1 for positive pairs and -1 for negative + X_constrained : ConstrainedDataset + with constraints being an array of shape [n_constraints, 2] + y_constraints : array-like, shape (n_constraints x 1) + labels of the constraints Returns ------- self : object Returns the instance. """ + X, W = unwrap_to_graph(X_constrained, y_constraints) loss_matrix = self._prepare_inputs(X, W) P = self.M_ + self.balance_param * loss_matrix emp_cov = pinvh(P) @@ -80,7 +81,7 @@ def fit(self, X, W): return self -class SDML_Supervised(SDML): +class SDML_Supervised(_SDML, SupervisedMixin): def __init__(self, balance_param=0.5, sparsity_param=0.01, use_cov=True, num_labeled=np.inf, num_constraints=None, verbose=False): """ @@ -99,7 +100,7 @@ def __init__(self, balance_param=0.5, sparsity_param=0.01, use_cov=True, verbose : bool, optional if True, prints information while learning """ - SDML.__init__(self, balance_param=balance_param, + _SDML.__init__(self, balance_param=balance_param, sparsity_param=sparsity_param, use_cov=use_cov, verbose=verbose) self.num_labeled = num_labeled @@ -131,5 +132,11 @@ def fit(self, X, y, random_state=np.random): c = Constraints.random_subset(y, self.num_labeled, random_state=random_state) - adj = c.adjacency_matrix(num_constraints, random_state=random_state) - return SDML.fit(self, X, adj) + pos_neg = c.positive_negative_pairs(num_constraints, + random_state=random_state) + X_constrained, y = wrap_pairs(X, pos_neg) + return _SDML._fit(self, X_constrained, y) + +class SDML(_SDML, PairsMixin): + + pass diff --git a/test/metric_learn_test.py b/test/metric_learn_test.py index 6d78c657..98988bd8 100644 --- a/test/metric_learn_test.py +++ b/test/metric_learn_test.py @@ -1,15 +1,16 @@ import unittest import numpy as np +from metric_learn.constraints import wrap_pairs from six.moves import xrange from sklearn.metrics import pairwise_distances from sklearn.datasets import load_iris from numpy.testing import assert_array_almost_equal from metric_learn import ( - LMNN, NCA, LFDA, Covariance, MLKR, MMC, + NCA, LFDA, Covariance, MLKR, MMC, LSML_Supervised, ITML_Supervised, SDML_Supervised, RCA_Supervised, MMC_Supervised) # Import this specially for testing. 
-from metric_learn.lmnn import python_LMNN +from metric_learn.lmnn import python_LMNN, LMNN def class_separation(X, labels): @@ -160,7 +161,7 @@ def test_iris(self): # Full metric mmc = MMC(convergence_threshold=0.01) - mmc.fit(self.iris_points, [a,b,c,d]) + mmc.fit(*wrap_pairs(self.iris_points, [a,b,c,d])) expected = [[+0.00046504, +0.00083371, -0.00111959, -0.00165265], [+0.00083371, +0.00149466, -0.00200719, -0.00296284], [-0.00111959, -0.00200719, +0.00269546, +0.00397881], @@ -169,7 +170,7 @@ def test_iris(self): # Diagonal metric mmc = MMC(diagonal=True) - mmc.fit(self.iris_points, [a,b,c,d]) + mmc.fit(*wrap_pairs(self.iris_points, [a,b,c,d])) expected = [0, 0, 1.21045968, 1.22552608] assert_array_almost_equal(np.diag(expected), mmc.metric(), decimal=6) diff --git a/test/test_base_metric.py b/test/test_base_metric.py index 31db4e6f..d73af4f9 100644 --- a/test/test_base_metric.py +++ b/test/test_base_metric.py @@ -9,10 +9,10 @@ def test_covariance(self): def test_lmnn(self): self.assertRegexpMatches( - str(metric_learn.LMNN()), - r"(python_)?LMNN\(convergence_tol=0.001, k=3, learn_rate=1e-07, " - r"max_iter=1000,\n min_iter=50, regularization=0.5, " - r"use_pca=True, verbose=False\)") + str(metric_learn.LMNN()), + r"^(python_)?LMNN\(convergence_tol=0.001, k=3, learn_rate=1e-07, " + r"max_iter=1000,(\n )? min_iter=50,(\n )? regularization=0.5, " + r"use_pca=True, verbose=False\)$") def test_nca(self): self.assertEqual(str(metric_learn.NCA()), diff --git a/test/test_constrained_dataset.py b/test/test_constrained_dataset.py new file mode 100644 index 00000000..176bc491 --- /dev/null +++ b/test/test_constrained_dataset.py @@ -0,0 +1,166 @@ +import unittest +import numpy as np +import scipy +from metric_learn.constraints import ConstrainedDataset +from numpy.testing import assert_array_equal +from sklearn.model_selection import StratifiedKFold, KFold +from sklearn.utils import check_random_state +from sklearn.utils.testing import assert_raise_message +from sklearn.utils.mocking import MockDataFrame + + +class _BaseTestConstrainedDataset(object): + + def setUp(self): + self.num_points = 20 + self.num_features = 5 + self.num_constraints = 15 + self.RNG = check_random_state(0) + + self.c = self.RNG.randint(0, self.num_points, + (self.num_constraints, 2)) + self.y = self.RNG.randint(0, 2, self.num_constraints) + self.group = self.RNG.randint(0, 3, self.num_constraints) + + def check_indexing(self, idx): + # checks that an indexing returns the data we expect + np.testing.assert_array_equal(self.X_constrained[idx].c, self.c[idx]) + np.testing.assert_array_equal(self.X_constrained[idx].toarray(), + self.X[self.c[idx]]) + np.testing.assert_array_equal(self.X_constrained[idx].toarray(), + self.X[self.c][idx]) + # checks that slicing does not copy the initial X + self.assertTrue(self.X_constrained[idx].X is self.X_constrained.X) + + def test_allowed_inputs(self): + # test the allowed ways to create a ConstrainedDataset + ConstrainedDataset(self.X, self.c) + + def test_invalid_inputs(self): + # test the invalid ways to create a ConstrainedDataset + two_points = [[1, 2], [3, 5]] + out_of_range_constraints = [[1, 2], [0, 1]] + msg = ("ConstrainedDataset cannot be created: the length of " + "the dataset is 2, so index 2 is out of " + "range.") + assert_raise_message(IndexError, msg, ConstrainedDataset, two_points, + out_of_range_constraints) + + def test_getitem(self): + # test different types of slicing + i = self.RNG.randint(1, self.num_constraints - 1) + begin = self.RNG.randint(1, 
self.num_constraints - 1) + end = self.RNG.randint(begin + 1, self.num_constraints) + fancy_index = self.RNG.randint(0, self.num_constraints, 20) + binary_index = self.RNG.randint(0, 2, self.num_constraints) + boolean_index = binary_index.astype(bool) + items = [0, self.num_constraints - 1, i, slice(i), slice(0, begin), + slice(begin, end), slice(end, self.num_constraints), + slice(0, self.num_constraints), fancy_index, + fancy_index.tolist(), binary_index, binary_index.tolist(), + boolean_index, boolean_index.tolist()] + for item in items: + self.check_indexing(item) + + def test_repr(self): + self.assertEqual(repr(self.X_constrained), repr(self.X[self.c])) + + def test_str(self): + self.assertEqual(str(self.X_constrained), str(self.X[self.c])) + + def test_shape(self): + self.assertEqual(self.X_constrained.shape, (self.c.shape[0], + self.c.shape[1], + self.X.shape[1])) + self.assertEqual(self.X_constrained[0, 0].shape, + (0, 0, self.X.shape[1])) + + def test_len(self): + self.assertEqual(len(self.X_constrained), self.c.shape[0]) + + def test_toarray(self): + X = self.X_constrained.X + assert_array_equal(self.X_constrained.toarray(), X[self.c]) + + def test_folding(self): + # test that ConstrainedDataset is compatible with scikit-learn folding + shuffle_list = [True, False] + groups_list = [self.group, None] + for alg in [KFold, StratifiedKFold]: + for shuffle_i in shuffle_list: + for group_i in groups_list: + for train_idx, test_idx \ + in alg(shuffle=shuffle_i).split(self.X_constrained, + self.y, + group_i): + self.check_indexing(train_idx) + self.check_indexing(test_idx) + + +class TestDenseConstrainedDataset(_BaseTestConstrainedDataset, + unittest.TestCase): + + def setUp(self): + super(TestDenseConstrainedDataset, self).setUp() + self.X = self.RNG.randn(self.num_points, self.num_features) + self.X_constrained = ConstrainedDataset(self.X, self.c) + + def test_init(self): + """ + Test alternative ways to initialize a ConstrainedDataset + (where the remaining X will stay dense) + """ + X_list = [self.X, self.X.tolist(), list(self.X), MockDataFrame(self.X)] + c_list = [self.c, self.c.tolist(), list(self.c), MockDataFrame(self.c)] + for X in X_list: + for c in c_list: + X_constrained = ConstrainedDataset(X, c) + + +class TestSparseConstrainedDataset(_BaseTestConstrainedDataset, + unittest.TestCase): + + def setUp(self): + super(TestSparseConstrainedDataset, self).setUp() + self.X = scipy.sparse.random(self.num_points, self.num_features, + format='csr', random_state=self.RNG) + # todo: for now we test only csr but we should test all sparse types + # in the future + self.X_constrained = ConstrainedDataset(self.X, self.c) + + def check_indexing(self, idx): + # checks that an indexing returns the data we expect + np.testing.assert_array_equal(self.X_constrained[idx].c, self.c[idx]) + np.testing.assert_array_equal(self.X_constrained[idx].toarray(), + self.X.A[self.c[idx]]) + np.testing.assert_array_equal(self.X_constrained[idx].toarray(), + self.X.A[self.c][idx]) + # checks that slicing does not copy the initial X + self.assertTrue(self.X_constrained[idx].X is self.X_constrained.X) + + def test_repr(self): + self.assertEqual(repr(self.X_constrained), repr(self.X.A[self.c])) + + def test_str(self): + self.assertEqual(str(self.X_constrained), str(self.X.A[self.c])) + + def test_toarray(self): + X = self.X_constrained.X + assert_array_equal(self.X_constrained.toarray(), X.A[self.c]) + + def test_folding(self): + # test that ConstrainedDataset is compatible with scikit-learn folding + shuffle_list 
= [True, False] + groups_list = [self.group, None] + for alg in [KFold, StratifiedKFold]: + for shuffle_i in shuffle_list: + for group_i in groups_list: + for train_idx, test_idx \ + in alg(shuffle=shuffle_i).split(self.X_constrained, + self.y, + group_i): + self.check_indexing(train_idx) + self.check_indexing(test_idx) + +if __name__=='__main__': + unittest.main() \ No newline at end of file diff --git a/test/test_weakly_supervised.py b/test/test_weakly_supervised.py new file mode 100644 index 00000000..70dbdddd --- /dev/null +++ b/test/test_weakly_supervised.py @@ -0,0 +1,229 @@ +import unittest +from sklearn import clone +from sklearn.cluster import KMeans +from sklearn.datasets import load_iris +from sklearn.model_selection import cross_val_score +from sklearn.decomposition import PCA +from sklearn.model_selection import train_test_split +from sklearn.pipeline import make_pipeline +from sklearn.utils.estimator_checks import is_public_parameter +from sklearn.utils.testing import set_random_state, assert_true, \ + assert_allclose_dense_sparse, assert_dict_equal, assert_false + +from metric_learn import ITML, LSML, MMC, SDML +from metric_learn.constraints import ConstrainedDataset, Constraints, \ + wrap_pairs +from sklearn.utils import check_random_state, shuffle +import numpy as np + +class _TestWeaklySupervisedBase(object): + + def setUp(self): + self.RNG = check_random_state(0) + set_random_state(self.estimator) + dataset = load_iris() + self.X, y = shuffle(dataset.data, dataset.target, random_state=self.RNG) + self.X, y = self.X[:20], y[:20] + num_constraints = 20 + constraints = Constraints.random_subset(y, random_state=self.RNG) + self.pairs = constraints.positive_negative_pairs(num_constraints, + same_length=True, + random_state=self.RNG) + + def test_cross_validation(self): + # test that you can do cross validation on a ConstrainedDataset with + # a WeaklySupervisedMetricLearner + estimator = clone(self.estimator) + self.assertTrue(np.isfinite(cross_val_score(estimator, + self.X_constrained, self.y)).all()) + + def check_score(self, estimator, X_constrained, y): + score = estimator.score(X_constrained, y) + self.assertTrue(np.isfinite(score)) + + def check_predict(self, estimator, X_constrained): + y_predicted = estimator.predict(X_constrained) + self.assertEqual(len(y_predicted), len(X_constrained)) + + def check_transform(self, estimator, X_constrained): + X_transformed = estimator.transform(X_constrained) + self.assertEqual(len(X_transformed), len(X_constrained.X)) + + def test_simple_estimator(self): + estimator = clone(self.estimator) + estimator.fit(self.X_constrained_train, self.y_train) + self.check_score(estimator, self.X_constrained_test, self.y_test) + self.check_predict(estimator, self.X_constrained_test) + self.check_transform(estimator, self.X_constrained_test) + + def test_pipelining_with_transformer(self): + """ + Test that weakly supervised estimators fit well into pipelines + """ + # test in a pipeline with KMeans + estimator = clone(self.estimator) + pipe = make_pipeline(estimator, KMeans()) + pipe.fit(self.X_constrained_train, self.y_train) + self.check_score(pipe, self.X_constrained_test, self.y_test) + self.check_transform(pipe, self.X_constrained_test) + # we cannot use check_predict because in this case the shape of the + # output is the shape of X_constrained.X, not X_constrained + y_predicted = pipe.predict(self.X_constrained) + self.assertEqual(len(y_predicted), len(self.X_constrained.X)) + + # test in a pipeline with PCA + estimator = 
clone(self.estimator) + pipe = make_pipeline(estimator, PCA()) + pipe.fit(self.X_constrained_train, self.y_train) + self.check_transform(pipe, self.X_constrained_test) + + def test_no_fit_attributes_set_in_init(self): + """Check that Estimator.__init__ doesn't set trailing-_ attributes.""" + # From scikit-learn + estimator = clone(self.estimator) + for attr in dir(estimator): + if attr.endswith("_") and not attr.startswith("__"): + # This check is for properties, they can be listed in dir + # while at the same time have hasattr return False as long + # as the property getter raises an AttributeError + assert_false( + hasattr(estimator, attr), + "By convention, attributes ending with '_' are " + 'estimated from data in scikit-learn. Consequently they ' + 'should not be initialized in the constructor of an ' + 'estimator but in the fit method. Attribute {!r} ' + 'was found in estimator {}'.format( + attr, type(estimator).__name__)) + + def test_estimators_fit_returns_self(self): + """Check if self is returned when calling fit""" + # From scikit-learn + estimator = clone(self.estimator) + assert_true(estimator.fit(self.X_constrained, self.y) is estimator) + + def test_pipeline_consistency(self): + # From scikit learn + # check that make_pipeline(est) gives same score as est + estimator = clone(self.estimator) + pipeline = make_pipeline(estimator) + estimator.fit(self.X_constrained, self.y) + pipeline.fit(self.X_constrained, self.y) + + funcs = ["score", "fit_transform"] + + for func_name in funcs: + func = getattr(estimator, func_name, None) + if func is not None: + func_pipeline = getattr(pipeline, func_name) + result = func(self.X_constrained, self.y) + result_pipe = func_pipeline(self.X_constrained, self.y) + assert_allclose_dense_sparse(result, result_pipe) + + def test_dict_unchanged(self): + # From scikit-learn + estimator = clone(self.estimator) + if hasattr(estimator, "n_components"): + estimator.n_components = 1 + estimator.fit(self.X_constrained, self.y) + for method in ["predict", "transform", "decision_function", + "predict_proba"]: + if hasattr(estimator, method): + dict_before = estimator.__dict__.copy() + getattr(estimator, method)(self.X_constrained) + assert_dict_equal(estimator.__dict__, dict_before, + 'Estimator changes __dict__ during %s' + % method) + + def test_dont_overwrite_parameters(self): + # From scikit-learn + # check that fit method only changes or sets private attributes + estimator = clone(self.estimator) + if hasattr(estimator, "n_components"): + estimator.n_components = 1 + dict_before_fit = estimator.__dict__.copy() + + estimator.fit(self.X_constrained, self.y) + dict_after_fit = estimator.__dict__ + + public_keys_after_fit = [key for key in dict_after_fit.keys() + if is_public_parameter(key)] + + attrs_added_by_fit = [key for key in public_keys_after_fit + if key not in dict_before_fit.keys()] + + # check that fit doesn't add any public attribute + assert_true(not attrs_added_by_fit, + ('Estimator adds public attribute(s) during' + ' the fit method.' + ' Estimators are only allowed to add private ' + 'attributes' + ' either started with _ or ended' + ' with _ but %s added' % ', '.join( + attrs_added_by_fit))) + + # check that fit doesn't change any public attribute + attrs_changed_by_fit = [key for key in public_keys_after_fit + if (dict_before_fit[key] + is not dict_after_fit[key])] + + assert_true(not attrs_changed_by_fit, + ('Estimator changes public attribute(s) during' + ' the fit method. 
Estimators are only allowed' + ' to change attributes started' + ' or ended with _, but' + ' %s changed' % ', '.join(attrs_changed_by_fit))) + + +class _TestPairsBase(_TestWeaklySupervisedBase): + + def setUp(self): + super(_TestPairsBase, self).setUp() + self.X_constrained, self.y = wrap_pairs(self.X, self.pairs) + self.X_constrained, self.y = shuffle(self.X_constrained, self.y, + random_state=self.RNG) + self.X_constrained_train, self.X_constrained_test, self.y_train, \ + self.y_test = train_test_split(self.X_constrained, self.y) + + +class _TestQuadrupletsBase(_TestWeaklySupervisedBase): + + def setUp(self): + super(_TestQuadrupletsBase, self).setUp() + c = np.column_stack(self.pairs) + self.X_constrained = ConstrainedDataset(self.X, c) + self.X_constrained = shuffle(self.X_constrained) + self.y, self.y_train, self.y_test = None, None, None + self.X_constrained_train, self.X_constrained_test = train_test_split( + self.X_constrained) + + +class TestITML(_TestPairsBase, unittest.TestCase): + + def setUp(self): + self.estimator = ITML() + super(TestITML, self).setUp() + + +class TestLSML(_TestQuadrupletsBase, unittest.TestCase): + + def setUp(self): + self.estimator = LSML() + super(TestLSML, self).setUp() + + +class TestMMC(_TestPairsBase, unittest.TestCase): + + def setUp(self): + self.estimator = MMC() + super(TestMMC, self).setUp() + + +class TestSDML(_TestPairsBase, unittest.TestCase): + + def setUp(self): + self.estimator = SDML() + super(TestSDML, self).setUp() + + +if __name__ == '__main__': + unittest.main()