diff --git a/metric_learn/_util.py b/metric_learn/_util.py
index e7f24e7d..27707be9 100644
--- a/metric_learn/_util.py
+++ b/metric_learn/_util.py
@@ -1,5 +1,8 @@
 import numpy as np
-
+import six
+from sklearn.utils import check_array
+from sklearn.utils.validation import check_X_y
+from metric_learn.exceptions import PreprocessorError
 
 # hack around lack of axis kwarg in older numpy versions
 try:
@@ -12,39 +15,310 @@ def vector_norm(X):
     return np.linalg.norm(X, axis=1)
 
 
-def check_tuples(tuples):
-  """Check that the input is a valid 3D array representing a dataset of tuples.
-
-  Equivalent of `check_array` in scikit-learn.
+def check_input(input_data, y=None, preprocessor=None,
+                type_of_inputs='classic', tuple_size=None, accept_sparse=False,
+                dtype='numeric', order=None,
+                copy=False, force_all_finite=True,
+                multi_output=False, ensure_min_samples=1,
+                ensure_min_features=1, y_numeric=False,
+                warn_on_dtype=False, estimator=None):
+  """Checks that the input format is valid, and converts it if specified
+  (this is the equivalent of scikit-learn's `check_array` or `check_X_y`).
+  All arguments following tuple_size are scikit-learn's `check_X_y`
+  arguments that will be enforced on the data and labels arrays. If
+  indicators are given as an input data array, the returned data array
+  will be the formed points/tuples, using the given preprocessor.
 
   Parameters
   ----------
-  tuples : object
-    The tuples to check.
+  input_data : array-like
+    The input data array to check.
+
+  y : array-like
+    The input labels array to check.
+
+  preprocessor : callable (default=`None`)
+    The preprocessor to use. If None, no preprocessor is used.
+
+  type_of_inputs : `str` {'classic', 'tuples'}
+    The type of inputs to check. If 'classic', the input should be
+    a 2D array-like of points or a 1D array-like of indicators of points. If
+    'tuples', the input should be a 3D array-like of tuples or a 2D
+    array-like of indicators of tuples.
+
+  tuple_size : int
+    The number of elements in a tuple (e.g. 2 for pairs).
+
+  accept_sparse : `bool`
+    Set to True to allow sparse inputs (only works for sparse inputs with
+    dim < 3).
+
+  dtype : string, type, list of types or None (default='numeric')
+    Data type of result. If None, the dtype of the input is preserved.
+    If 'numeric', dtype is preserved unless array.dtype is object.
+    If dtype is a list of types, conversion on the first type is only
+    performed if the dtype of the input is not in the list.
+
+  order : 'F', 'C' or None (default=`None`)
+    Whether an array will be forced to be fortran or c-style.
+
+  copy : boolean (default=False)
+    Whether a forced copy will be triggered. If copy=False, a copy might
+    be triggered by a conversion.
+
+  force_all_finite : boolean or 'allow-nan', (default=True)
+    Whether to raise an error on np.inf and np.nan in X. This parameter
+    does not influence whether y can have np.inf or np.nan values.
+    The possibilities are:
+     - True: Force all values of X to be finite.
+     - False: accept both np.inf and np.nan in X.
+     - 'allow-nan': accept only np.nan values in X. Values cannot be
+       infinite.
+
+  ensure_min_samples : int (default=1)
+    Make sure that X has a minimum number of samples in its first
+    axis (rows for a 2D array).
+
+  ensure_min_features : int (default=1)
+    Make sure that the 2D array has some minimum number of features
+    (columns). The default value of 1 rejects empty datasets.
+    This check is only enforced when X has effectively 2 dimensions or
+    is originally 1D and ``ensure_2d`` is True. Setting to 0 disables
+    this check.
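+
+  multi_output : boolean (default=False)
+    Whether to allow 2D y (array or sparse matrix). If false, y will be
+    validated as a vector.
+
+  y_numeric : boolean (default=False)
+    Whether to ensure that y has a numeric type. If dtype of y is object,
+    it is converted to float64. Should only be used for regression
+    algorithms.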
+ + warn_on_dtype : boolean (default=False) + Raise DataConversionWarning if the dtype of the input data structure + does not match the requested dtype, causing a memory copy. + + estimator : str or estimator instance (default=`None`) + If passed, include the name of the estimator in warning messages. Returns ------- - tuples_valid : object - The validated input. + X : `numpy.ndarray` + The checked input data array. + + y: `numpy.ndarray` (optional) + The checked input labels array. """ - # If input is scalar raise error - if np.isscalar(tuples): - raise ValueError( - "Expected 3D array, got scalar instead. Cannot apply this function on " - "scalars.") - # If input is 1D raise error - if len(tuples.shape) == 1: - raise ValueError( - "Expected 3D array, got 1D array instead:\ntuples={}.\n" - "Reshape your data using tuples.reshape(1, -1, 1) if it contains a " - "single tuple and the points in the tuple have a single " - "feature.".format(tuples)) - # If input is 2D raise error - if len(tuples.shape) == 2: - raise ValueError( - "Expected 3D array, got 2D array instead:\ntuples={}.\n" - "Reshape your data either using tuples.reshape(-1, {}, 1) if " - "your data has a single feature or tuples.reshape(1, {}, -1) " - "if it contains a single tuple.".format(tuples, tuples.shape[1], - tuples.shape[0])) + + context = make_context(estimator) + + args_for_sk_checks = dict(accept_sparse=accept_sparse, + dtype=dtype, order=order, + copy=copy, force_all_finite=force_all_finite, + ensure_min_samples=ensure_min_samples, + ensure_min_features=ensure_min_features, + warn_on_dtype=warn_on_dtype, estimator=estimator) + + # We need to convert input_data into a numpy.ndarray if possible, before + # any further checks or conversions, and deal with y if needed. Therefore + # we use check_array/check_X_y with fixed permissive arguments. + if y is None: + input_data = check_array(input_data, ensure_2d=False, allow_nd=True, + copy=False, force_all_finite=False, + accept_sparse=True, dtype=None, + ensure_min_features=0, ensure_min_samples=0) + else: + input_data, y = check_X_y(input_data, y, ensure_2d=False, allow_nd=True, + copy=False, force_all_finite=False, + accept_sparse=True, dtype=None, + ensure_min_features=0, ensure_min_samples=0, + multi_output=multi_output, + y_numeric=y_numeric) + + if type_of_inputs == 'classic': + input_data = check_input_classic(input_data, context, preprocessor, + args_for_sk_checks) + + elif type_of_inputs == 'tuples': + input_data = check_input_tuples(input_data, context, preprocessor, + args_for_sk_checks, tuple_size) + + else: + raise ValueError("Unknown value {} for type_of_inputs. 
Valid values are " + "'classic' or 'tuples'.".format(type_of_inputs)) + + return input_data if y is None else (input_data, y) + + +def check_input_tuples(input_data, context, preprocessor, args_for_sk_checks, + tuple_size): + preprocessor_has_been_applied = False + if input_data.ndim == 2: + if preprocessor is not None: + input_data = preprocess_tuples(input_data, preprocessor) + preprocessor_has_been_applied = True + else: + make_error_input(201, input_data, context) + elif input_data.ndim == 3: + pass + else: + if preprocessor is not None: + make_error_input(420, input_data, context) + else: + make_error_input(200, input_data, context) + input_data = check_array(input_data, allow_nd=True, ensure_2d=False, + **args_for_sk_checks) + # we need to check num_features because check_array does not check it + # for 3D inputs: + if args_for_sk_checks['ensure_min_features'] > 0: + n_features = input_data.shape[2] + if n_features < args_for_sk_checks['ensure_min_features']: + raise ValueError("Found array with {} feature(s) (shape={}) while" + " a minimum of {} is required{}." + .format(n_features, input_data.shape, + args_for_sk_checks['ensure_min_features'], + context)) + # normally we don't need to check_tuple_size too because tuple_size + # shouldn't be able to be modified by any preprocessor + if input_data.ndim != 3: + # we have to ensure this because check_array above does not + if preprocessor_has_been_applied: + make_error_input(211, input_data, context) + else: + make_error_input(201, input_data, context) + check_tuple_size(input_data, tuple_size, context) + return input_data + + +def check_input_classic(input_data, context, preprocessor, args_for_sk_checks): + preprocessor_has_been_applied = False + if input_data.ndim == 1: + if preprocessor is not None: + input_data = preprocess_points(input_data, preprocessor) + preprocessor_has_been_applied = True + else: + make_error_input(101, input_data, context) + elif input_data.ndim == 2: + pass # OK + else: + if preprocessor is not None: + make_error_input(320, input_data, context) + else: + make_error_input(100, input_data, context) + + input_data = check_array(input_data, allow_nd=True, ensure_2d=False, + **args_for_sk_checks) + if input_data.ndim != 2: + # we have to ensure this because check_array above does not + if preprocessor_has_been_applied: + make_error_input(111, input_data, context) + else: + make_error_input(101, input_data, context) + return input_data + + +def make_error_input(code, input_data, context): + code_str = {'expected_input': {'1': '2D array of formed points', + '2': '3D array of formed tuples', + '3': ('1D array of indicators or 2D array of ' + 'formed points'), + '4': ('2D array of indicators or 3D array ' + 'of formed tuples')}, + 'additional_context': {'0': '', + '2': ' when using a preprocessor', + '1': (' after the preprocessor has been ' + 'applied')}, + 'possible_preprocessor': {'0': '', + '1': ' and/or use a preprocessor' + }} + code_list = str(code) + err_args = dict(expected_input=code_str['expected_input'][code_list[0]], + additional_context=code_str['additional_context'] + [code_list[1]], + possible_preprocessor=code_str['possible_preprocessor'] + [code_list[2]], + input_data=input_data, context=context, + found_size=input_data.ndim) + err_msg = ('{expected_input} expected' + '{context}{additional_context}. Found {found_size}D array ' + 'instead:\ninput={input_data}. 
Reshape your data'
+             '{possible_preprocessor}.\n')
+  raise ValueError(err_msg.format(**err_args))
+
+
+def preprocess_tuples(tuples, preprocessor):
+  try:
+    tuples = np.column_stack([preprocessor(tuples[:, i])[:, np.newaxis] for
+                              i in range(tuples.shape[1])])
+  except Exception as e:
+    raise PreprocessorError(e)
   return tuples
+
+
+def preprocess_points(points, preprocessor):
+  """form points if there is a preprocessor else keep them as such (assumes
+  that the points have already been validated by `check_input`)"""
+  try:
+    points = preprocessor(points)
+  except Exception as e:
+    raise PreprocessorError(e)
+  return points
+
+
+def make_context(estimator):
+  """Helper function to create a string with the estimator name.
+  Taken from check_array function in scikit-learn.
+  Will return the following for instance:
+  NCA: ' by NCA'
+  'NCA': ' by NCA'
+  None: ''
+  """
+  estimator_name = make_name(estimator)
+  context = (' by ' + estimator_name) if estimator_name is not None else ''
+  return context
+
+
+def make_name(estimator):
+  """Helper function that returns the name of the estimator, or the given
+  string if a string is given
+  """
+  if estimator is not None:
+    if isinstance(estimator, six.string_types):
+      estimator_name = estimator
+    else:
+      estimator_name = estimator.__class__.__name__
+  else:
+    estimator_name = None
+  return estimator_name
+
+
+def check_tuple_size(tuples, tuple_size, context):
+  """Helper function to check that the number of points in each tuple is
+  equal to tuple_size (e.g. 2 for pairs), and raise a `ValueError` otherwise"""
+  if tuple_size is not None and tuples.shape[1] != tuple_size:
+    msg_t = (("Tuples of {} element(s) expected{}. Got tuples of {} "
+              "element(s) instead (shape={}):\ninput={}.\n")
+             .format(tuple_size, context, tuples.shape[1], tuples.shape,
+                     tuples))
+    raise ValueError(msg_t)
+
+
+class ArrayIndexer:
+
+  def __init__(self, X):
+    # we check the array-like preprocessor here, and we are as permissive
+    # as possible (because the user will check for the desired
+    # format with arguments in check_input, and only this latter function
+    # should return the appropriate errors). We do this only to have a numpy
+    # array object which can be indexed by another numpy array object.
+    X = check_array(X,
+                    accept_sparse=True, dtype=None,
+                    force_all_finite=False,
+                    ensure_2d=False, allow_nd=True,
+                    ensure_min_samples=0,
+                    ensure_min_features=0,
+                    warn_on_dtype=False, estimator=None)
+    self.X = X
+
+  def __call__(self, indices):
+    return self.X[indices]
+
+
+def check_collapsed_pairs(pairs):
+  num_ident = (vector_norm(pairs[:, 0] - pairs[:, 1]) < 1e-9).sum()
+  if num_ident:
+    raise ValueError("{} collapsed pairs found (where the left element is "
+                     "the same as the right element), out of {} pairs "
+                     "in total.".format(num_ident, pairs.shape[0]))
diff --git a/metric_learn/base_metric.py b/metric_learn/base_metric.py
index 4044f223..9af79ecc 100644
--- a/metric_learn/base_metric.py
+++ b/metric_learn/base_metric.py
@@ -1,15 +1,26 @@
 from numpy.linalg import cholesky
 from sklearn.base import BaseEstimator
-from sklearn.utils.validation import check_array
+from sklearn.utils.validation import _is_arraylike
 from sklearn.metrics import roc_auc_score
 import numpy as np
 from abc import ABCMeta, abstractmethod
 import six
-from ._util import check_tuples
+from ._util import ArrayIndexer, check_input
 
 
 class BaseMetricLearner(six.with_metaclass(ABCMeta, BaseEstimator)):
 
+  def __init__(self, preprocessor=None):
+    """Initialize the metric learner.
+
+    Parameters
+    ----------
+    preprocessor : array-like, shape=(n_samples, n_features) or callable
+      The preprocessor to call to get tuples from indices. If array-like,
+      tuples will be formed like this: X[indices].
+    """
+    self.preprocessor = preprocessor
+
   @abstractmethod
   def score_pairs(self, pairs):
     """Returns the score between pairs
@@ -26,6 +37,55 @@ def score_pairs(self, pairs):
       The score of every pair.
     """
 
+  def check_preprocessor(self):
+    """Initializes the preprocessor"""
+    if _is_arraylike(self.preprocessor):
+      self.preprocessor_ = ArrayIndexer(self.preprocessor)
+    elif callable(self.preprocessor) or self.preprocessor is None:
+      self.preprocessor_ = self.preprocessor
+    else:
+      raise ValueError("Invalid type for the preprocessor: {}. You should "
+                       "provide either None, an array-like object, "
+                       "or a callable.".format(type(self.preprocessor)))
+
+  def _prepare_inputs(self, X, y=None, type_of_inputs='classic',
+                      **kwargs):
+    """Initializes the preprocessor and processes inputs. See `check_input`
+    for more details.
+
+    Parameters
+    ----------
+    X : array-like
+      The input data array to check.
+
+    y : array-like
+      The input labels array to check.
+
+    type_of_inputs : `str` {'classic', 'tuples'}
+      The type of inputs to check. If 'classic', the input should be
+      a 2D array-like of points or a 1D array-like of indicators of points.
+      If 'tuples', the input should be a 3D array-like of tuples or a 2D
+      array-like of indicators of tuples.
+
+    **kwargs : dict
+      Arguments to pass to check_input.
+
+    Returns
+    -------
+    X : `numpy.ndarray`
+      The checked input data array.
+
+    y : `numpy.ndarray` (optional)
+      The checked input labels array.
+    """
+    self.check_preprocessor()
+    return check_input(X, y,
+                       type_of_inputs=type_of_inputs,
+                       preprocessor=self.preprocessor_,
+                       estimator=self,
+                       tuple_size=getattr(self, '_tuple_size', None),
+                       **kwargs)
+
 
 class MetricTransformer(six.with_metaclass(ABCMeta)):
 
@@ -78,15 +138,19 @@ def score_pairs(self, pairs):
 
     Parameters
     ----------
-    pairs : `numpy.ndarray`, shape=(n_samples, 2, n_features)
-      3D array of pairs, or 2D array of one pair.
+    pairs : array-like, shape=(n_pairs, 2, n_features) or (n_pairs, 2)
+      3D Array of pairs to score, with each row corresponding to two points,
+      or 2D array of indices of pairs if the metric learner uses a
+      preprocessor.
 
     Returns
     -------
     scores: `numpy.ndarray` of shape=(n_pairs,)
       The learned Mahalanobis distance for every pair.
     """
-    pairs = check_tuples(pairs)
+    pairs = check_input(pairs, type_of_inputs='tuples',
+                        preprocessor=self.preprocessor_,
+                        estimator=self, tuple_size=2)
     pairwise_diffs = self.transform(pairs[:, 1, :] - pairs[:, 0, :])
     # (for MahalanobisMixin, the embedding is linear so we can just embed the
     # difference)
@@ -109,7 +173,9 @@ def transform(self, X):
     X_embedded : `numpy.ndarray`, shape=(n_samples, num_dims)
       The embedded data points.
     """
-    X_checked = check_array(X, accept_sparse=True)
+    X_checked = check_input(X, type_of_inputs='classic', estimator=self,
+                            preprocessor=self.preprocessor_,
+                            accept_sparse=True)
     return X_checked.dot(self.transformer_.T)
 
   def metric(self):
@@ -144,28 +210,51 @@ def transformer_from_metric(self, metric):
 
 class _PairsClassifierMixin(BaseMetricLearner):
 
+  _tuple_size = 2  # number of points in a tuple, 2 for pairs
+
   def predict(self, pairs):
-    """Predicts the learned metric between input pairs.
+    """Predicts the learned metric between input pairs. (For now it just
+    calls decision_function.)
 
     Returns the learned metric value between samples in every pair. It should
     ideally be low for similar samples and high for dissimilar samples.
 
     Parameters
     ----------
-    pairs : array-like, shape=(n_constraints, 2, n_features)
-      Input pairs.
+    pairs : array-like, shape=(n_pairs, 2, n_features) or (n_pairs, 2)
+      3D Array of pairs to predict, with each row corresponding to two
+      points, or 2D array of indices of pairs if the metric learner uses a
+      preprocessor.
 
     Returns
     -------
     y_predicted : `numpy.ndarray` of floats, shape=(n_constraints,)
       The predicted learned metric value between samples in every pair.
     """
-    pairs = check_tuples(pairs)
-    return self.score_pairs(pairs)
+    return self.decision_function(pairs)
 
   def decision_function(self, pairs):
-    pairs = check_tuples(pairs)
-    return self.predict(pairs)
+    """Returns the learned metric between input pairs.
+
+    Returns the learned metric value between samples in every pair. It should
+    ideally be low for similar samples and high for dissimilar samples.
+
+    Parameters
+    ----------
+    pairs : array-like, shape=(n_pairs, 2, n_features) or (n_pairs, 2)
+      3D Array of pairs to predict, with each row corresponding to two
+      points, or 2D array of indices of pairs if the metric learner uses a
+      preprocessor.
+
+    Returns
+    -------
+    y_predicted : `numpy.ndarray` of floats, shape=(n_constraints,)
+      The predicted learned metric value between samples in every pair.
+    """
+    pairs = check_input(pairs, type_of_inputs='tuples',
+                        preprocessor=self.preprocessor_,
+                        estimator=self, tuple_size=self._tuple_size)
+    return self.score_pairs(pairs)
 
   def score(self, pairs, y):
     """Computes score of pairs similarity prediction.
@@ -179,8 +268,10 @@
 
     Parameters
     ----------
-    pairs : array-like, shape=(n_constraints, 2, n_features)
-      Input Pairs.
+    pairs : array-like, shape=(n_pairs, 2, n_features) or (n_pairs, 2)
+      3D Array of pairs, with each row corresponding to two points,
+      or 2D array of indices of pairs if the metric learner uses a
+      preprocessor.
 
     y : array-like, shape=(n_constraints,)
       The corresponding labels.
 
     Returns
     -------
     score : float
       The ``roc_auc`` score.
""" - pairs = check_tuples(pairs) return roc_auc_score(y, self.decision_function(pairs)) class _QuadrupletsClassifierMixin(BaseMetricLearner): + _tuple_size = 4 # number of points in a tuple, 4 for quadruplets + def predict(self, quadruplets): """Predicts the ordering between sample distances in input quadruplets. @@ -204,15 +296,20 @@ def predict(self, quadruplets): Parameters ---------- - quadruplets : array-like, shape=(n_constraints, 4, n_features) - Input quadruplets. + quadruplets : array-like, shape=(n_quadruplets, 4, n_features) or + (n_quadruplets, 4) + 3D Array of quadruplets to predict, with each row corresponding to four + points, or 2D array of indices of quadruplets if the metric learner + uses a preprocessor. Returns ------- prediction : `numpy.ndarray` of floats, shape=(n_constraints,) Predictions of the ordering of pairs, for each quadruplet. """ - quadruplets = check_tuples(quadruplets) + quadruplets = check_input(quadruplets, type_of_inputs='tuples', + preprocessor=self.preprocessor_, + estimator=self, tuple_size=self._tuple_size) return np.sign(self.decision_function(quadruplets)) def decision_function(self, quadruplets): @@ -223,17 +320,19 @@ def decision_function(self, quadruplets): Parameters ---------- - quadruplets : array-like, shape=(n_constraints, 4, n_features) - Input quadruplets. + quadruplets : array-like, shape=(n_quadruplets, 4, n_features) or + (n_quadruplets, 4) + 3D Array of quadruplets to predict, with each row corresponding to four + points, or 2D array of indices of quadruplets if the metric learner + uses a preprocessor. Returns ------- decision_function : `numpy.ndarray` of floats, shape=(n_constraints,) Metric differences. """ - quadruplets = check_tuples(quadruplets) - return (self.score_pairs(quadruplets[:, :2, :]) - - self.score_pairs(quadruplets[:, 2:, :])) + return (self.score_pairs(quadruplets[:, :2]) - + self.score_pairs(quadruplets[:, 2:])) def score(self, quadruplets, y=None): """Computes score on input quadruplets @@ -244,8 +343,11 @@ def score(self, quadruplets, y=None): Parameters ---------- - quadruplets : array-like, shape=(n_constraints, 4, n_features) - Input quadruplets. + quadruplets : array-like, shape=(n_quadruplets, 4, n_features) or + (n_quadruplets, 4) + 3D Array of quadruplets to score, with each row corresponding to four + points, or 2D array of indices of quadruplets if the metric learner + uses a preprocessor. y : Ignored, for scikit-learn compatibility. @@ -254,5 +356,4 @@ def score(self, quadruplets, y=None): score : float The quadruplets score. """ - quadruplets = check_tuples(quadruplets) return -np.mean(self.predict(quadruplets)) diff --git a/metric_learn/covariance.py b/metric_learn/covariance.py index 4e8c1a0f..a828feb6 100644 --- a/metric_learn/covariance.py +++ b/metric_learn/covariance.py @@ -10,7 +10,6 @@ from __future__ import absolute_import import numpy as np -from sklearn.utils.validation import check_array from sklearn.base import TransformerMixin from .base_metric import MahalanobisMixin @@ -26,20 +25,20 @@ class Covariance(MahalanobisMixin, TransformerMixin): metric (See :meth:`transformer_from_metric`.) 
""" - def __init__(self): - pass + def __init__(self, preprocessor=None): + super(Covariance, self).__init__(preprocessor) def fit(self, X, y=None): """ X : data matrix, (n x d) y : unused """ - self.X_ = check_array(X, ensure_min_samples=2) + self.X_ = self._prepare_inputs(X, ensure_min_samples=2) self.M_ = np.cov(self.X_, rowvar = False) if self.M_.ndim == 0: self.M_ = 1./self.M_ else: self.M_ = np.linalg.inv(self.M_) - self.transformer_ = self.transformer_from_metric(check_array(self.M_)) + self.transformer_ = self.transformer_from_metric(np.atleast_2d(self.M_)) return self diff --git a/metric_learn/exceptions.py b/metric_learn/exceptions.py new file mode 100644 index 00000000..424d2c4f --- /dev/null +++ b/metric_learn/exceptions.py @@ -0,0 +1,12 @@ +""" +The :mod:`metric_learn.exceptions` module includes all custom warnings and +error classes used across metric-learn. +""" + + +class PreprocessorError(Exception): + + def __init__(self, original_error): + err_msg = ("An error occurred when trying to use the " + "preprocessor: {}").format(repr(original_error)) + super(PreprocessorError, self).__init__(err_msg) diff --git a/metric_learn/itml.py b/metric_learn/itml.py index d8bd24c2..4ce550fb 100644 --- a/metric_learn/itml.py +++ b/metric_learn/itml.py @@ -17,17 +17,20 @@ import numpy as np from six.moves import xrange from sklearn.metrics import pairwise_distances -from sklearn.utils.validation import check_array, check_X_y +from sklearn.utils.validation import check_array from sklearn.base import TransformerMixin from .base_metric import _PairsClassifierMixin, MahalanobisMixin from .constraints import Constraints, wrap_pairs -from ._util import vector_norm, check_tuples +from ._util import vector_norm class _BaseITML(MahalanobisMixin): """Information Theoretic Metric Learning (ITML)""" + + _tuple_size = 2 # constraints are pairs + def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3, - A0=None, verbose=False): + A0=None, verbose=False, preprocessor=None): """Initialize ITML. Parameters @@ -44,26 +47,21 @@ def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3, verbose : bool, optional if True, prints information while learning + + preprocessor : array-like, shape=(n_samples, n_features) or callable + The preprocessor to call to get tuples from indices. If array-like, + tuples will be formed like this: X[indices]. 
""" self.gamma = gamma self.max_iter = max_iter self.convergence_threshold = convergence_threshold self.A0 = A0 self.verbose = verbose + super(_BaseITML, self).__init__(preprocessor) - def _process_pairs(self, pairs, y, bounds): - # for now we check_X_y and check_tuples but we should only - # check_tuples_y in the future - pairs, y = check_X_y(pairs, y, accept_sparse=False, - ensure_2d=False, allow_nd=True) - pairs = check_tuples(pairs) - - # check to make sure that no two constrained vectors are identical - pos_pairs, neg_pairs = pairs[y == 1], pairs[y == -1] - pos_no_ident = vector_norm(pos_pairs[:, 0, :] - pos_pairs[:, 1, :]) > 1e-9 - pos_pairs = pos_pairs[pos_no_ident] - neg_no_ident = vector_norm(neg_pairs[:, 0, :] - neg_pairs[:, 1, :]) > 1e-9 - neg_pairs = neg_pairs[neg_no_ident] + def _fit(self, pairs, y, bounds=None): + pairs, y = self._prepare_inputs(pairs, y, + type_of_inputs='tuples') # init bounds if bounds is None: X = np.vstack({tuple(row) for row in pairs.reshape(-1, pairs.shape[2])}) @@ -77,12 +75,6 @@ def _process_pairs(self, pairs, y, bounds): self.A_ = np.identity(pairs.shape[2]) else: self.A_ = check_array(self.A0) - pairs = np.vstack([pos_pairs, neg_pairs]) - y = np.hstack([np.ones(len(pos_pairs)), - np.ones(len(neg_pairs))]) - return pairs, y - - def _fit(self, pairs, y, bounds=None): - pairs, y = self._process_pairs(pairs, y, bounds) gamma = self.gamma pos_pairs, neg_pairs = pairs[y == 1], pairs[y == -1] num_pos = len(pos_pairs) @@ -151,8 +143,11 @@ def fit(self, pairs, y, bounds=None): Parameters ---------- - pairs: array-like, shape=(n_constraints, 2, n_features) - Array of pairs. Each row corresponds to two points. + pairs: array-like, shape=(n_constraints, 2, n_features) or + (n_constraints, 2) + 3D Array of pairs with each row corresponding to two points, + or 2D array of indices of pairs if the metric learner uses a + preprocessor. y: array-like, of shape (n_constraints,) Labels of constraints. Should be -1 for dissimilar pair, 1 for similar. bounds : list (pos,neg) pairs, optional @@ -178,7 +173,7 @@ class ITML_Supervised(_BaseITML, TransformerMixin): def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3, num_labeled=np.inf, num_constraints=None, bounds=None, A0=None, - verbose=False): + verbose=False, preprocessor=None): """Initialize the learner. Parameters @@ -197,10 +192,13 @@ def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3, initial regularization matrix, defaults to identity verbose : bool, optional if True, prints information while learning + preprocessor : array-like, shape=(n_samples, n_features) or callable + The preprocessor to call to get tuples from indices. If array-like, + tuples will be formed like this: X[indices]. """ _BaseITML.__init__(self, gamma=gamma, max_iter=max_iter, convergence_threshold=convergence_threshold, - A0=A0, verbose=verbose) + A0=A0, verbose=verbose, preprocessor=preprocessor) self.num_labeled = num_labeled self.num_constraints = num_constraints self.bounds = bounds @@ -220,7 +218,7 @@ def fit(self, X, y, random_state=np.random): random_state : numpy.random.RandomState, optional If provided, controls random number generation. 
""" - X, y = check_X_y(X, y) + X, y = self._prepare_inputs(X, y, ensure_min_samples=2) num_constraints = self.num_constraints if num_constraints is None: num_classes = len(np.unique(y)) diff --git a/metric_learn/lfda.py b/metric_learn/lfda.py index c06fca91..2feff211 100644 --- a/metric_learn/lfda.py +++ b/metric_learn/lfda.py @@ -16,7 +16,6 @@ import warnings from six.moves import xrange from sklearn.metrics import pairwise_distances -from sklearn.utils.validation import check_X_y from sklearn.base import TransformerMixin from .base_metric import MahalanobisMixin @@ -32,7 +31,8 @@ class LFDA(MahalanobisMixin, TransformerMixin): The learned linear transformation ``L``. ''' - def __init__(self, num_dims=None, k=None, embedding_type='weighted'): + def __init__(self, num_dims=None, k=None, embedding_type='weighted', + preprocessor=None): ''' Initialize LFDA. @@ -50,17 +50,32 @@ def __init__(self, num_dims=None, k=None, embedding_type='weighted'): 'weighted' - weighted eigenvectors 'orthonormalized' - orthonormalized 'plain' - raw eigenvectors + + preprocessor : array-like, shape=(n_samples, n_features) or callable + The preprocessor to call to get tuples from indices. If array-like, + tuples will be formed like this: X[indices]. ''' if embedding_type not in ('weighted', 'orthonormalized', 'plain'): raise ValueError('Invalid embedding_type: %r' % embedding_type) self.num_dims = num_dims self.embedding_type = embedding_type self.k = k + super(LFDA, self).__init__(preprocessor) + + def fit(self, X, y): + '''Fit the LFDA model. - def _process_inputs(self, X, y): + Parameters + ---------- + X : (n, d) array-like + Input data. + + y : (n,) array-like + Class labels, one per point of data. + ''' + X, y = self._prepare_inputs(X, y, ensure_min_samples=2) unique_classes, y = np.unique(y, return_inverse=True) - self.X_, y = check_X_y(X, y) - n, d = self.X_.shape + n, d = X.shape num_classes = len(unique_classes) if self.num_dims is None: @@ -77,21 +92,6 @@ def _process_inputs(self, X, y): k = d - 1 else: k = int(self.k) - - return self.X_, y, num_classes, n, d, dim, k - - def fit(self, X, y): - '''Fit the LFDA model. - - Parameters - ---------- - X : (n, d) array-like - Input data. - - y : (n,) array-like - Class labels, one per point of data. - ''' - X, y, num_classes, n, d, dim, k_ = self._process_inputs(X, y) tSb = np.zeros((d,d)) tSw = np.zeros((d,d)) @@ -102,8 +102,8 @@ def fit(self, X, y): # classwise affinity matrix dist = pairwise_distances(Xc, metric='l2', squared=True) # distances to k-th nearest neighbor - k = min(k_, nc-1) - sigma = np.sqrt(np.partition(dist, k, axis=0)[:,k]) + k = min(k, nc - 1) + sigma = np.sqrt(np.partition(dist, k, axis=0)[:, k]) local_scale = np.outer(sigma, sigma) with np.errstate(divide='ignore', invalid='ignore'): diff --git a/metric_learn/lmnn.py b/metric_learn/lmnn.py index 7ce4d051..d78cf6b6 100644 --- a/metric_learn/lmnn.py +++ b/metric_learn/lmnn.py @@ -14,7 +14,6 @@ import warnings from collections import Counter from six.moves import xrange -from sklearn.utils.validation import check_X_y, check_array from sklearn.metrics import euclidean_distances from sklearn.base import TransformerMixin from .base_metric import MahalanobisMixin @@ -24,7 +23,7 @@ class _base_LMNN(MahalanobisMixin, TransformerMixin): def __init__(self, k=3, min_iter=50, max_iter=1000, learn_rate=1e-7, regularization=0.5, convergence_tol=0.001, use_pca=True, - verbose=False): + verbose=False, preprocessor=None): """Initialize the LMNN object. 
Parameters @@ -34,6 +33,10 @@ def __init__(self, k=3, min_iter=50, max_iter=1000, learn_rate=1e-7, regularization: float, optional Weighting of pull and push terms, with 0.5 meaning equal weight. + + preprocessor : array-like, shape=(n_samples, n_features) or callable + The preprocessor to call to get tuples from indices. If array-like, + tuples will be formed like this: X[indices]. """ self.k = k self.min_iter = min_iter @@ -43,15 +46,21 @@ def __init__(self, k=3, min_iter=50, max_iter=1000, learn_rate=1e-7, self.convergence_tol = convergence_tol self.use_pca = use_pca self.verbose = verbose + super(_base_LMNN, self).__init__(preprocessor) # slower Python version class python_LMNN(_base_LMNN): - def _process_inputs(self, X, labels): - self.X_ = check_array(X, dtype=float) - num_pts, num_dims = self.X_.shape - unique_labels, self.label_inds_ = np.unique(labels, return_inverse=True) + def fit(self, X, y): + k = self.k + reg = self.regularization + learn_rate = self.learn_rate + + X, y = self._prepare_inputs(X, y, dtype=float, + ensure_min_samples=2) + num_pts, num_dims = X.shape + unique_labels, self.label_inds_ = np.unique(y, return_inverse=True) if len(self.label_inds_) != num_pts: raise ValueError('Must have one label per point.') self.labels_ = np.arange(len(unique_labels)) @@ -63,21 +72,15 @@ def _process_inputs(self, X, labels): raise ValueError('not enough class labels for specified k' ' (smallest class has %d)' % required_k) - def fit(self, X, y): - k = self.k - reg = self.regularization - learn_rate = self.learn_rate - self._process_inputs(X, y) - - target_neighbors = self._select_targets() - impostors = self._find_impostors(target_neighbors[:,-1]) + target_neighbors = self._select_targets(X) + impostors = self._find_impostors(target_neighbors[:, -1], X) if len(impostors) == 0: # L has already been initialized to an identity matrix return # sum outer products - dfG = _sum_outer_products(self.X_, target_neighbors.flatten(), - np.repeat(np.arange(self.X_.shape[0]), k)) + dfG = _sum_outer_products(X, target_neighbors.flatten(), + np.repeat(np.arange(X.shape[0]), k)) df = np.zeros_like(dfG) # storage @@ -99,7 +102,7 @@ def fit(self, X, y): a2_old = [a.copy() for a in a2] objective_old = objective # Compute pairwise distances under current metric - Lx = L.dot(self.X_.T).T + Lx = L.dot(X.T).T g0 = _inplace_paired_L2(*Lx[impostors]) Ni = 1 + _inplace_paired_L2(Lx[target_neighbors], Lx[:,None,:]) g1,g2 = Ni[impostors] @@ -124,16 +127,16 @@ def fit(self, X, y): targets = target_neighbors[:,nn_idx] PLUS, pweight = _count_edges(plus1, plus2, impostors, targets) - df += _sum_outer_products(self.X_, PLUS[:,0], PLUS[:,1], pweight) + df += _sum_outer_products(X, PLUS[:, 0], PLUS[:, 1], pweight) MINUS, mweight = _count_edges(minus1, minus2, impostors, targets) - df -= _sum_outer_products(self.X_, MINUS[:,0], MINUS[:,1], mweight) + df -= _sum_outer_products(X, MINUS[:, 0], MINUS[:, 1], mweight) in_imp, out_imp = impostors - df += _sum_outer_products(self.X_, in_imp[minus1], out_imp[minus1]) - df += _sum_outer_products(self.X_, in_imp[minus2], out_imp[minus2]) + df += _sum_outer_products(X, in_imp[minus1], out_imp[minus1]) + df += _sum_outer_products(X, in_imp[minus2], out_imp[minus2]) - df -= _sum_outer_products(self.X_, in_imp[plus1], out_imp[plus1]) - df -= _sum_outer_products(self.X_, in_imp[plus2], out_imp[plus2]) + df -= _sum_outer_products(X, in_imp[plus1], out_imp[plus1]) + df -= _sum_outer_products(X, in_imp[plus2], out_imp[plus2]) a1[nn_idx] = act1 a2[nn_idx] = act2 @@ -178,18 +181,18 @@ 
def fit(self, X, y): self.n_iter_ = it return self - def _select_targets(self): - target_neighbors = np.empty((self.X_.shape[0], self.k), dtype=int) + def _select_targets(self, X): + target_neighbors = np.empty((X.shape[0], self.k), dtype=int) for label in self.labels_: inds, = np.nonzero(self.label_inds_ == label) - dd = euclidean_distances(self.X_[inds], squared=True) + dd = euclidean_distances(X[inds], squared=True) np.fill_diagonal(dd, np.inf) nn = np.argsort(dd)[..., :self.k] target_neighbors[inds] = inds[nn] return target_neighbors - def _find_impostors(self, furthest_neighbors): - Lx = self.transform(self.X_) + def _find_impostors(self, furthest_neighbors, X): + Lx = self.transform(X) margin_radii = 1 + _inplace_paired_L2(Lx[furthest_neighbors], Lx) impostors = [] for label in self.labels_[:-1]: @@ -252,9 +255,10 @@ class LMNN(_base_LMNN): """ def fit(self, X, y): - self.X_, y = check_X_y(X, y, dtype=float) + X, y = self._prepare_inputs(X, y, dtype=float, + ensure_min_samples=2) labels = MulticlassLabels(y) - self._lmnn = shogun_LMNN(RealFeatures(self.X_.T), labels, self.k) + self._lmnn = shogun_LMNN(RealFeatures(X.T), labels, self.k) self._lmnn.set_maxiter(self.max_iter) self._lmnn.set_obj_threshold(self.convergence_tol) self._lmnn.set_regularization(self.regularization) diff --git a/metric_learn/lsml.py b/metric_learn/lsml.py index 0e8b3513..cb2c1f18 100644 --- a/metric_learn/lsml.py +++ b/metric_learn/lsml.py @@ -12,15 +12,17 @@ import scipy.linalg from six.moves import xrange from sklearn.base import TransformerMixin -from sklearn.utils.validation import check_array, check_X_y -from ._util import check_tuples from .base_metric import _QuadrupletsClassifierMixin, MahalanobisMixin from .constraints import Constraints class _BaseLSML(MahalanobisMixin): - def __init__(self, tol=1e-3, max_iter=1000, prior=None, verbose=False): + + _tuple_size = 4 # constraints are quadruplets + + def __init__(self, tol=1e-3, max_iter=1000, prior=None, verbose=False, + preprocessor=None): """Initialize LSML. Parameters @@ -31,18 +33,19 @@ def __init__(self, tol=1e-3, max_iter=1000, prior=None, verbose=False): guess at a metric [default: inv(covariance(X))] verbose : bool, optional if True, prints information while learning + preprocessor : array-like, shape=(n_samples, n_features) or callable + The preprocessor to call to get tuples from indices. If array-like, + tuples will be formed like this: X[indices]. """ self.prior = prior self.tol = tol self.max_iter = max_iter self.verbose = verbose + super(_BaseLSML, self).__init__(preprocessor) - def _prepare_quadruplets(self, quadruplets, weights): - # for now we check_array and check_tuples but we should only - # check_tuples in the future (with enhanced check_tuples) - quadruplets = check_array(quadruplets, accept_sparse=False, - ensure_2d=False, allow_nd=True) - quadruplets = check_tuples(quadruplets) + def _fit(self, quadruplets, y=None, weights=None): + quadruplets = self._prepare_inputs(quadruplets, + type_of_inputs='tuples') # check to make sure that no two constrained vectors are identical self.vab_ = quadruplets[:, 0, :] - quadruplets[:, 1, :] @@ -63,8 +66,6 @@ def _prepare_quadruplets(self, quadruplets, weights): self.M_ = self.prior self.prior_inv_ = np.linalg.inv(self.prior) - def _fit(self, quadruplets, weights=None): - self._prepare_quadruplets(quadruplets, weights) step_sizes = np.logspace(-10, 0, 10) # Keep track of the best step size and the loss at that step. 
    l_best = 0
@@ -143,11 +144,13 @@ def fit(self, quadruplets, weights=None):
 
     Parameters
     ----------
-    quadruplets : array-like, shape=(n_constraints, 4, n_features)
-      Each row corresponds to 4 points. In order to supervise the
-      algorithm in the right way, we should have the four samples ordered
-      in a way such that: d(pairs[i, 0],X[i, 1]) < d(X[i, 2], X[i, 3])
-      for all 0 <= i < n_constraints.
+    quadruplets : array-like, shape=(n_constraints, 4, n_features) or
+                  (n_constraints, 4)
+      3D array-like of quadruplets of points or 2D array of quadruplets of
+      indicators. In order to supervise the algorithm in the right way, we
+      should have the four samples ordered in a way such that:
+      d(quadruplets[i, 0], quadruplets[i, 1]) <
+      d(quadruplets[i, 2], quadruplets[i, 3]) for all 0 <= i < n_constraints.
 
     weights : (n_constraints,) array of floats, optional
       scale factor for each constraint
@@ -170,7 +173,8 @@ class LSML_Supervised(_BaseLSML, TransformerMixin):
   """
 
   def __init__(self, tol=1e-3, max_iter=1000, prior=None, num_labeled=np.inf,
-               num_constraints=None, weights=None, verbose=False):
+               num_constraints=None, weights=None, verbose=False,
+               preprocessor=None):
     """Initialize the learner.
 
     Parameters
     ----------
@@ -187,9 +191,12 @@ def __init__(self, tol=1e-3, max_iter=1000, prior=None, num_labeled=np.inf,
       scale factor for each constraint
     verbose : bool, optional
       if True, prints information while learning
+    preprocessor : array-like, shape=(n_samples, n_features) or callable
+      The preprocessor to call to get tuples from indices. If array-like,
+      tuples will be formed like this: X[indices].
     """
     _BaseLSML.__init__(self, tol=tol, max_iter=max_iter, prior=prior,
-                       verbose=verbose)
+                       verbose=verbose, preprocessor=preprocessor)
     self.num_labeled = num_labeled
     self.num_constraints = num_constraints
     self.weights = weights
@@ -208,7 +215,7 @@ def fit(self, X, y, random_state=np.random):
     random_state : numpy.random.RandomState, optional
       If provided, controls random number generation.
     """
-    X, y = check_X_y(X, y)
+    X, y = self._prepare_inputs(X, y, ensure_min_samples=2)
     num_constraints = self.num_constraints
     if num_constraints is None:
       num_classes = len(np.unique(y))
diff --git a/metric_learn/mlkr.py b/metric_learn/mlkr.py
index 9f774322..8e8af9f2 100644
--- a/metric_learn/mlkr.py
+++ b/metric_learn/mlkr.py
@@ -13,7 +13,6 @@
 
 from sklearn.base import TransformerMixin
 from sklearn.decomposition import PCA
-from sklearn.utils.validation import check_X_y
 
 from .base_metric import MahalanobisMixin
 
@@ -30,7 +29,7 @@ class MLKR(MahalanobisMixin, TransformerMixin):
   """
 
   def __init__(self, num_dims=None, A0=None, epsilon=0.01, alpha=0.0001,
-               max_iter=1000):
+               max_iter=1000, preprocessor=None):
     """
     Initialize MLKR.
 
@@ -50,16 +49,30 @@ def __init__(self, num_dims=None, A0=None, epsilon=0.01, alpha=0.0001,
 
     max_iter: int, optional
       Cap on number of conjugate gradient iterations.
+
+    preprocessor : array-like, shape=(n_samples, n_features) or callable
+      The preprocessor to call to get tuples from indices. If array-like,
+      tuples will be formed like this: X[indices].
""" self.num_dims = num_dims self.A0 = A0 self.epsilon = epsilon self.alpha = alpha self.max_iter = max_iter + super(MLKR, self).__init__(preprocessor) + + def fit(self, X, y): + """ + Fit MLKR model - def _process_inputs(self, X, y): - self.X_, y = check_X_y(X, y) - n, d = self.X_.shape + Parameters + ---------- + X : (n x d) array of samples + y : (n) data labels + """ + X, y = self._prepare_inputs(X, y, y_numeric=True, + ensure_min_samples=2) + n, d = X.shape if y.shape[0] != n: raise ValueError('Data and label lengths mismatch: %d != %d' % (n, y.shape[0])) @@ -75,18 +88,6 @@ def _process_inputs(self, X, y): elif A.shape != (m, d): raise ValueError('A0 needs shape (%d,%d) but got %s' % ( m, d, A.shape)) - return self.X_, y, A - - def fit(self, X, y): - """ - Fit MLKR model - - Parameters - ---------- - X : (n x d) array of samples - y : (n) data labels - """ - X, y, A = self._process_inputs(X, y) # note: this line takes (n*n*d) memory! # for larger datasets, we'll need to compute dX as we go diff --git a/metric_learn/mmc.py b/metric_learn/mmc.py index 2f2ee400..89c18b58 100644 --- a/metric_learn/mmc.py +++ b/metric_learn/mmc.py @@ -20,17 +20,21 @@ import numpy as np from six.moves import xrange from sklearn.base import TransformerMixin -from sklearn.utils.validation import check_array, check_X_y +from sklearn.utils.validation import check_array from .base_metric import _PairsClassifierMixin, MahalanobisMixin from .constraints import Constraints, wrap_pairs -from ._util import vector_norm, check_tuples +from ._util import vector_norm class _BaseMMC(MahalanobisMixin): """Mahalanobis Metric for Clustering (MMC)""" + + _tuple_size = 2 # constraints are pairs + def __init__(self, max_iter=100, max_proj=10000, convergence_threshold=1e-3, - A0=None, diagonal=False, diagonal_c=1.0, verbose=False): + A0=None, diagonal=False, diagonal_c=1.0, verbose=False, + preprocessor=None): """Initialize MMC. Parameters ---------- @@ -48,6 +52,9 @@ def __init__(self, max_iter=100, max_proj=10000, convergence_threshold=1e-3, metric learning verbose : bool, optional if True, prints information while learning + preprocessor : array-like, shape=(n_samples, n_features) or callable + The preprocessor to call to get tuples from indices. If array-like, + tuples will be gotten like this: X[indices]. 
""" self.max_iter = max_iter self.max_proj = max_proj @@ -56,31 +63,11 @@ def __init__(self, max_iter=100, max_proj=10000, convergence_threshold=1e-3, self.diagonal = diagonal self.diagonal_c = diagonal_c self.verbose = verbose + super(_BaseMMC, self).__init__(preprocessor) def _fit(self, pairs, y): - pairs, y = self._process_pairs(pairs, y) - if self.diagonal: - return self._fit_diag(pairs, y) - else: - return self._fit_full(pairs, y) - - def _process_pairs(self, pairs, y): - # for now we check_X_y and check_tuples but we should only - # check_tuples_y in the future - pairs, y = check_X_y(pairs, y, accept_sparse=False, - ensure_2d=False, allow_nd=True) - pairs = check_tuples(pairs) - - # check to make sure that no two constrained vectors are identical - pos_pairs, neg_pairs = pairs[y == 1], pairs[y == -1] - pos_no_ident = vector_norm(pos_pairs[:, 0, :] - pos_pairs[:, 1, :]) > 1e-9 - pos_pairs = pos_pairs[pos_no_ident] - neg_no_ident = vector_norm(neg_pairs[:, 0, :] - neg_pairs[:, 1, :]) > 1e-9 - neg_pairs = neg_pairs[neg_no_ident] - if len(pos_pairs) == 0: - raise ValueError('No non-trivial similarity constraints given for MMC.') - if len(neg_pairs) == 0: - raise ValueError('No non-trivial dissimilarity constraints given for MMC.') + pairs, y = self._prepare_inputs(pairs, y, + type_of_inputs='tuples') # init metric if self.A0 is None: @@ -92,9 +79,10 @@ def _process_pairs(self, pairs, y): else: self.A_ = check_array(self.A0) - pairs = np.vstack([pos_pairs, neg_pairs]) - y = np.hstack([np.ones(len(pos_pairs)), - np.ones(len(neg_pairs))]) - return pairs, y + if self.diagonal: + return self._fit_diag(pairs, y) + else: + return self._fit_full(pairs, y) def _fit_full(self, pairs, y): """Learn full metric using MMC. @@ -373,8 +361,11 @@ def fit(self, pairs, y): Parameters ---------- - pairs: array-like, shape=(n_constraints, 2, n_features) - Array of pairs. Each row corresponds to two points. + pairs: array-like, shape=(n_constraints, 2, n_features) or + (n_constraints, 2) + 3D Array of pairs with each row corresponding to two points, + or 2D array of indices of pairs if the metric learner uses a + preprocessor. y: array-like, of shape (n_constraints,) Labels of constraints. Should be -1 for dissimilar pair, 1 for similar. @@ -398,7 +389,8 @@ class MMC_Supervised(_BaseMMC, TransformerMixin): def __init__(self, max_iter=100, max_proj=10000, convergence_threshold=1e-6, num_labeled=np.inf, num_constraints=None, - A0=None, diagonal=False, diagonal_c=1.0, verbose=False): + A0=None, diagonal=False, diagonal_c=1.0, verbose=False, + preprocessor=None): """Initialize the learner. Parameters @@ -421,11 +413,14 @@ def __init__(self, max_iter=100, max_proj=10000, convergence_threshold=1e-6, metric learning verbose : bool, optional if True, prints information while learning + preprocessor : array-like, shape=(n_samples, n_features) or callable + The preprocessor to call to get tuples from indices. If array-like, + tuples will be formed like this: X[indices]. """ _BaseMMC.__init__(self, max_iter=max_iter, max_proj=max_proj, convergence_threshold=convergence_threshold, A0=A0, diagonal=diagonal, diagonal_c=diagonal_c, - verbose=verbose) + verbose=verbose, preprocessor=preprocessor) self.num_labeled = num_labeled self.num_constraints = num_constraints @@ -441,7 +436,7 @@ def fit(self, X, y, random_state=np.random): random_state : numpy.random.RandomState, optional If provided, controls random number generation. 
""" - X, y = check_X_y(X, y) + X, y = self._prepare_inputs(X, y, ensure_min_samples=2) num_constraints = self.num_constraints if num_constraints is None: num_classes = len(np.unique(y)) diff --git a/metric_learn/nca.py b/metric_learn/nca.py index 19e016ec..791617c5 100644 --- a/metric_learn/nca.py +++ b/metric_learn/nca.py @@ -7,7 +7,6 @@ import numpy as np from six.moves import xrange from sklearn.base import TransformerMixin -from sklearn.utils.validation import check_X_y from .base_metric import MahalanobisMixin @@ -23,17 +22,19 @@ class NCA(MahalanobisMixin, TransformerMixin): The learned linear transformation ``L``. """ - def __init__(self, num_dims=None, max_iter=100, learning_rate=0.01): + def __init__(self, num_dims=None, max_iter=100, learning_rate=0.01, + preprocessor=None): self.num_dims = num_dims self.max_iter = max_iter self.learning_rate = learning_rate + super(NCA, self).__init__(preprocessor) def fit(self, X, y): """ X: data matrix, (n x d) y: scalar labels, (n) """ - X, labels = check_X_y(X, y) + X, labels = self._prepare_inputs(X, y, ensure_min_samples=2) n, d = X.shape num_dims = self.num_dims if num_dims is None: diff --git a/metric_learn/rca.py b/metric_learn/rca.py index 170e21f8..290ea941 100644 --- a/metric_learn/rca.py +++ b/metric_learn/rca.py @@ -17,7 +17,6 @@ from six.moves import xrange from sklearn import decomposition from sklearn.base import TransformerMixin -from sklearn.utils.validation import check_array from .base_metric import MahalanobisMixin from .constraints import Constraints @@ -45,7 +44,7 @@ class RCA(MahalanobisMixin, TransformerMixin): The learned linear transformation ``L``. """ - def __init__(self, num_dims=None, pca_comps=None): + def __init__(self, num_dims=None, pca_comps=None, preprocessor=None): """Initialize the learner. Parameters @@ -59,26 +58,17 @@ def __init__(self, num_dims=None, pca_comps=None): If ``0 < pca_comps < 1``, it is used as the minimum explained variance ratio. See sklearn.decomposition.PCA for more details. + + preprocessor : array-like, shape=(n_samples, n_features) or callable + The preprocessor to call to get tuples from indices. If array-like, + tuples will be formed like this: X[indices]. """ self.num_dims = num_dims self.pca_comps = pca_comps + super(RCA, self).__init__(preprocessor) - def _process_data(self, X): - self.X_ = X = check_array(X) - - # PCA projection to remove noise and redundant information. - if self.pca_comps is not None: - pca = decomposition.PCA(n_components=self.pca_comps) - X = pca.fit_transform(X) - M_pca = pca.components_ - else: - X -= X.mean(axis=0) - M_pca = None - - return X, M_pca - - def _check_dimension(self, rank): - d = self.X_.shape[1] + def _check_dimension(self, rank, X): + d = X.shape[1] if rank < d: warnings.warn('The inner covariance matrix is not invertible, ' 'so the transformation matrix may contain Nan values. ' @@ -97,7 +87,7 @@ def _check_dimension(self, rank): dim = self.num_dims return dim - def fit(self, data, chunks): + def fit(self, X, chunks): """Learn the RCA model. Parameters @@ -108,17 +98,26 @@ def fit(self, data, chunks): When ``chunks[i] == -1``, point i doesn't belong to any chunklet. When ``chunks[i] == j``, point i belongs to chunklet j. """ - data, M_pca = self._process_data(data) + X = self._prepare_inputs(X, ensure_min_samples=2) + + # PCA projection to remove noise and redundant information. 
+ if self.pca_comps is not None: + pca = decomposition.PCA(n_components=self.pca_comps) + X_t = pca.fit_transform(X) + M_pca = pca.components_ + else: + X_t = X - X.mean(axis=0) + M_pca = None chunks = np.asanyarray(chunks, dtype=int) - chunk_mask, chunked_data = _chunk_mean_centering(data, chunks) + chunk_mask, chunked_data = _chunk_mean_centering(X_t, chunks) inner_cov = np.cov(chunked_data, rowvar=0, bias=1) - dim = self._check_dimension(np.linalg.matrix_rank(inner_cov)) + dim = self._check_dimension(np.linalg.matrix_rank(inner_cov), X_t) # Fisher Linear Discriminant projection - if dim < data.shape[1]: - total_cov = np.cov(data[chunk_mask], rowvar=0) + if dim < X_t.shape[1]: + total_cov = np.cov(X_t[chunk_mask], rowvar=0) tmp = np.linalg.lstsq(total_cov, inner_cov)[0] vals, vecs = np.linalg.eig(tmp) inds = np.argsort(vals)[:dim] @@ -150,7 +149,7 @@ class RCA_Supervised(RCA): """ def __init__(self, num_dims=None, pca_comps=None, num_chunks=100, - chunk_size=2): + chunk_size=2, preprocessor=None): """Initialize the learner. Parameters @@ -159,8 +158,12 @@ def __init__(self, num_dims=None, pca_comps=None, num_chunks=100, embedding dimension (default: original dimension of data) num_chunks: int, optional chunk_size: int, optional + preprocessor : array-like, shape=(n_samples, n_features) or callable + The preprocessor to call to get tuples from indices. If array-like, + tuples will be formed like this: X[indices]. """ - RCA.__init__(self, num_dims=num_dims, pca_comps=pca_comps) + RCA.__init__(self, num_dims=num_dims, pca_comps=pca_comps, + preprocessor=preprocessor) self.num_chunks = num_chunks self.chunk_size = chunk_size @@ -175,6 +178,7 @@ def fit(self, X, y, random_state=np.random): y : (n) data labels random_state : a random.seed object to fix the random_state if needed. """ + X, y = self._prepare_inputs(X, y, ensure_min_samples=2) chunks = Constraints(y).chunks(num_chunks=self.num_chunks, chunk_size=self.chunk_size, random_state=random_state) diff --git a/metric_learn/sdml.py b/metric_learn/sdml.py index 0d3c8b92..b7f36bf4 100644 --- a/metric_learn/sdml.py +++ b/metric_learn/sdml.py @@ -13,16 +13,17 @@ from sklearn.base import TransformerMixin from sklearn.covariance import graph_lasso from sklearn.utils.extmath import pinvh -from sklearn.utils.validation import check_array, check_X_y from .base_metric import MahalanobisMixin, _PairsClassifierMixin from .constraints import Constraints, wrap_pairs -from ._util import check_tuples class _BaseSDML(MahalanobisMixin): + + _tuple_size = 2 # constraints are pairs + def __init__(self, balance_param=0.5, sparsity_param=0.01, use_cov=True, - verbose=False): + verbose=False, preprocessor=None): """ Parameters ---------- @@ -37,18 +38,20 @@ def __init__(self, balance_param=0.5, sparsity_param=0.01, use_cov=True, verbose : bool, optional if True, prints information while learning + + preprocessor : array-like, shape=(n_samples, n_features) or callable + The preprocessor to call to get tuples from indices. If array-like, + tuples will be gotten like this: X[indices]. 
""" self.balance_param = balance_param self.sparsity_param = sparsity_param self.use_cov = use_cov self.verbose = verbose + super(_BaseSDML, self).__init__(preprocessor) - def _prepare_pairs(self, pairs, y): - # for now we check_X_y and check_tuples but we should only - # check_tuples_y in the future - pairs, y = check_X_y(pairs, y, accept_sparse=False, - ensure_2d=False, allow_nd=True) - pairs = check_tuples(pairs) + def _fit(self, pairs, y): + pairs, y = self._prepare_inputs(pairs, y, + type_of_inputs='tuples') # set up prior M if self.use_cov: @@ -57,10 +60,7 @@ def _prepare_pairs(self, pairs, y): else: self.M_ = np.identity(pairs.shape[2]) diff = pairs[:, 0] - pairs[:, 1] - return (diff.T * y).dot(diff) - - def _fit(self, pairs, y): - loss_matrix = self._prepare_pairs(pairs, y) + loss_matrix = (diff.T * y).dot(diff) P = self.M_ + self.balance_param * loss_matrix emp_cov = pinvh(P) # hack: ensure positive semidefinite @@ -86,8 +86,11 @@ def fit(self, pairs, y): Parameters ---------- - pairs: array-like, shape=(n_constraints, 2, n_features) - Array of pairs. Each row corresponds to two points. + pairs: array-like, shape=(n_constraints, 2, n_features) or + (n_constraints, 2) + 3D Array of pairs with each row corresponding to two points, + or 2D array of indices of pairs if the metric learner uses a + preprocessor. y: array-like, of shape (n_constraints,) Labels of constraints. Should be -1 for dissimilar pair, 1 for similar. @@ -110,7 +113,8 @@ class SDML_Supervised(_BaseSDML, TransformerMixin): """ def __init__(self, balance_param=0.5, sparsity_param=0.01, use_cov=True, - num_labeled=np.inf, num_constraints=None, verbose=False): + num_labeled=np.inf, num_constraints=None, verbose=False, + preprocessor=None): """ Parameters ---------- @@ -126,10 +130,13 @@ def __init__(self, balance_param=0.5, sparsity_param=0.01, use_cov=True, number of constraints to generate verbose : bool, optional if True, prints information while learning + preprocessor : array-like, shape=(n_samples, n_features) or callable + The preprocessor to call to get tuples from indices. If array-like, + tuples will be formed like this: X[indices]. """ _BaseSDML.__init__(self, balance_param=balance_param, sparsity_param=sparsity_param, use_cov=use_cov, - verbose=verbose) + verbose=verbose, preprocessor=preprocessor) self.num_labeled = num_labeled self.num_constraints = num_constraints @@ -151,7 +158,7 @@ def fit(self, X, y, random_state=np.random): self : object Returns the instance. 
""" - y = check_array(y, ensure_2d=False) + X, y = self._prepare_inputs(X, y, ensure_min_samples=2) num_constraints = self.num_constraints if num_constraints is None: num_classes = len(np.unique(y)) diff --git a/test/metric_learn_test.py b/test/metric_learn_test.py index 1671c8ef..d7a1d935 100644 --- a/test/metric_learn_test.py +++ b/test/metric_learn_test.py @@ -40,7 +40,7 @@ def test_iris(self): csep = class_separation(cov.transform(self.iris_points), self.iris_labels) # deterministic result - self.assertAlmostEqual(csep, 0.73068122) + self.assertAlmostEqual(csep, 0.72981476) class TestLSML(MetricTestCase): @@ -90,21 +90,16 @@ def test_iris(self): n = self.iris_points.shape[0] # Without dimension reduction - nca = NCA(max_iter=(100000//n), learning_rate=0.01) + nca = NCA(max_iter=(100000//n)) nca.fit(self.iris_points, self.iris_labels) - # Result copied from Iris example at - # https://github.com/vomjom/nca/blob/master/README.mkd - expected = [[-0.09935, -0.2215, 0.3383, 0.443], - [+0.2532, 0.5835, -0.8461, -0.8915], - [-0.729, -0.6386, 1.767, 1.832], - [-0.9405, -0.8461, 2.281, 2.794]] - assert_array_almost_equal(expected, nca.transformer_, decimal=3) + csep = class_separation(nca.transform(self.iris_points), self.iris_labels) + self.assertLess(csep, 0.15) # With dimension reduction nca = NCA(max_iter=(100000//n), learning_rate=0.01, num_dims=2) nca.fit(self.iris_points, self.iris_labels) csep = class_separation(nca.transform(self.iris_points), self.iris_labels) - self.assertLess(csep, 0.15) + self.assertLess(csep, 0.20) class TestLFDA(MetricTestCase): @@ -163,16 +158,16 @@ def test_iris(self): # Full metric mmc = MMC(convergence_threshold=0.01) mmc.fit(*wrap_pairs(self.iris_points, [a,b,c,d])) - expected = [[+0.00046504, +0.00083371, -0.00111959, -0.00165265], - [+0.00083371, +0.00149466, -0.00200719, -0.00296284], - [-0.00111959, -0.00200719, +0.00269546, +0.00397881], - [-0.00165265, -0.00296284, +0.00397881, +0.00587320]] + expected = [[+0.000514, +0.000868, -0.001195, -0.001703], + [+0.000868, +0.001468, -0.002021, -0.002879], + [-0.001195, -0.002021, +0.002782, +0.003964], + [-0.001703, -0.002879, +0.003964, +0.005648]] assert_array_almost_equal(expected, mmc.metric(), decimal=6) # Diagonal metric mmc = MMC(diagonal=True) mmc.fit(*wrap_pairs(self.iris_points, [a,b,c,d])) - expected = [0, 0, 1.21045968, 1.22552608] + expected = [0, 0, 1.210220, 1.228596] assert_array_almost_equal(np.diag(expected), mmc.metric(), decimal=6) # Supervised Full diff --git a/test/test_base_metric.py b/test/test_base_metric.py index 31db4e6f..d71bf760 100644 --- a/test/test_base_metric.py +++ b/test/test_base_metric.py @@ -5,73 +5,80 @@ class TestStringRepr(unittest.TestCase): def test_covariance(self): - self.assertEqual(str(metric_learn.Covariance()), "Covariance()") + self.assertEqual(str(metric_learn.Covariance()), + "Covariance(preprocessor=None)") def test_lmnn(self): self.assertRegexpMatches( str(metric_learn.LMNN()), r"(python_)?LMNN\(convergence_tol=0.001, k=3, learn_rate=1e-07, " - r"max_iter=1000,\n min_iter=50, regularization=0.5, " - r"use_pca=True, verbose=False\)") + r"max_iter=1000,\n min_iter=50, preprocessor=None, " + r"regularization=0.5, use_pca=True,\n verbose=False\)") def test_nca(self): self.assertEqual(str(metric_learn.NCA()), - "NCA(learning_rate=0.01, max_iter=100, num_dims=None)") + "NCA(learning_rate=0.01, max_iter=100, num_dims=None, " + "preprocessor=None)") def test_lfda(self): self.assertEqual(str(metric_learn.LFDA()), - "LFDA(embedding_type='weighted', k=None, 
num_dims=None)") + "LFDA(embedding_type='weighted', k=None, num_dims=None, " + "preprocessor=None)") def test_itml(self): self.assertEqual(str(metric_learn.ITML()), """ ITML(A0=None, convergence_threshold=0.001, gamma=1.0, max_iter=1000, - verbose=False) + preprocessor=None, verbose=False) """.strip('\n')) self.assertEqual(str(metric_learn.ITML_Supervised()), """ ITML_Supervised(A0=None, bounds=None, convergence_threshold=0.001, gamma=1.0, max_iter=1000, num_constraints=None, num_labeled=inf, - verbose=False) + preprocessor=None, verbose=False) """.strip('\n')) def test_lsml(self): self.assertEqual( str(metric_learn.LSML()), - "LSML(max_iter=1000, prior=None, tol=0.001, verbose=False)") + "LSML(max_iter=1000, preprocessor=None, prior=None, tol=0.001, " + "verbose=False)") self.assertEqual(str(metric_learn.LSML_Supervised()), """ LSML_Supervised(max_iter=1000, num_constraints=None, num_labeled=inf, - prior=None, tol=0.001, verbose=False, weights=None) + preprocessor=None, prior=None, tol=0.001, verbose=False, + weights=None) """.strip('\n')) def test_sdml(self): self.assertEqual(str(metric_learn.SDML()), - "SDML(balance_param=0.5, sparsity_param=0.01, " - "use_cov=True, verbose=False)") + "SDML(balance_param=0.5, preprocessor=None, " + "sparsity_param=0.01, use_cov=True,\n verbose=False)") self.assertEqual(str(metric_learn.SDML_Supervised()), """ SDML_Supervised(balance_param=0.5, num_constraints=None, num_labeled=inf, - sparsity_param=0.01, use_cov=True, verbose=False) + preprocessor=None, sparsity_param=0.01, use_cov=True, + verbose=False) """.strip('\n')) def test_rca(self): self.assertEqual(str(metric_learn.RCA()), - "RCA(num_dims=None, pca_comps=None)") + "RCA(num_dims=None, pca_comps=None, preprocessor=None)") self.assertEqual(str(metric_learn.RCA_Supervised()), "RCA_Supervised(chunk_size=2, num_chunks=100, " - "num_dims=None, pca_comps=None)") + "num_dims=None, pca_comps=None,\n " + "preprocessor=None)") def test_mlkr(self): self.assertEqual(str(metric_learn.MLKR()), "MLKR(A0=None, alpha=0.0001, epsilon=0.01, " - "max_iter=1000, num_dims=None)") + "max_iter=1000, num_dims=None,\n preprocessor=None)") def test_mmc(self): self.assertEqual(str(metric_learn.MMC()), """ MMC(A0=None, convergence_threshold=0.001, diagonal=False, diagonal_c=1.0, - max_iter=100, max_proj=10000, verbose=False) + max_iter=100, max_proj=10000, preprocessor=None, verbose=False) """.strip('\n')) self.assertEqual(str(metric_learn.MMC_Supervised()), """ MMC_Supervised(A0=None, convergence_threshold=1e-06, diagonal=False, diagonal_c=1.0, max_iter=100, max_proj=10000, num_constraints=None, - num_labeled=inf, verbose=False) + num_labeled=inf, preprocessor=None, verbose=False) """.strip('\n')) if __name__ == '__main__': diff --git a/test/test_mahalanobis_mixin.py b/test/test_mahalanobis_mixin.py index 09a98ece..0d834f10 100644 --- a/test/test_mahalanobis_mixin.py +++ b/test/test_mahalanobis_mixin.py @@ -5,83 +5,26 @@ from numpy.testing import assert_array_almost_equal from scipy.spatial.distance import pdist, squareform from sklearn import clone -from sklearn.datasets import load_iris -from sklearn.utils import check_random_state, shuffle +from sklearn.utils import check_random_state from sklearn.utils.testing import set_random_state -from metric_learn import (Constraints, ITML, LSML, MMC, SDML, Covariance, LFDA, - LMNN, MLKR, NCA, RCA) -from metric_learn.constraints import wrap_pairs -from functools import partial +from metric_learn._util import make_context + +from test.test_utils import ids_metric_learners, 
metric_learners RNG = check_random_state(0) -def build_data(): - dataset = load_iris() - X, y = shuffle(dataset.data, dataset.target, random_state=RNG) - num_constraints = 20 - constraints = Constraints.random_subset(y, random_state=RNG) - pairs = constraints.positive_negative_pairs(num_constraints, - same_length=True, - random_state=RNG) - return X, pairs - - -def build_pairs(): - # test that you can do cross validation on tuples of points with - # a WeaklySupervisedMetricLearner - X, pairs = build_data() - pairs, y = wrap_pairs(X, pairs) - pairs, y = shuffle(pairs, y, random_state=RNG) - return pairs, y - - -def build_quadruplets(): - # test that you can do cross validation on a tuples of points with - # a WeaklySupervisedMetricLearner - X, pairs = build_data() - c = np.column_stack(pairs) - quadruplets = X[c] - quadruplets = shuffle(quadruplets, random_state=RNG) - return quadruplets, None - - -list_estimators = [(Covariance(), build_data), - (ITML(), build_pairs), - (LFDA(), partial(load_iris, return_X_y=True)), - (LMNN(), partial(load_iris, return_X_y=True)), - (LSML(), build_quadruplets), - (MLKR(), partial(load_iris, return_X_y=True)), - (MMC(), build_pairs), - (NCA(), partial(load_iris, return_X_y=True)), - (RCA(), partial(load_iris, return_X_y=True)), - (SDML(), build_pairs) - ] - -ids_estimators = ['covariance', - 'itml', - 'lfda', - 'lmnn', - 'lsml', - 'mlkr', - 'mmc', - 'nca', - 'rca', - 'sdml', - ] - - -@pytest.mark.parametrize('estimator, build_dataset', list_estimators, - ids=ids_estimators) + +@pytest.mark.parametrize('estimator, build_dataset', metric_learners, + ids=ids_metric_learners) def test_score_pairs_pairwise(estimator, build_dataset): # Computing pairwise scores should return a euclidean distance matrix. - inputs, labels = build_dataset() - X, _ = load_iris(return_X_y=True) + input_data, labels, _, X = build_dataset() n_samples = 20 X = X[:n_samples] model = clone(estimator) set_random_state(model) - model.fit(inputs, labels) + model.fit(input_data, labels) pairwise = model.score_pairs(np.array(list(product(X, X))))\ .reshape(n_samples, n_samples) @@ -96,17 +39,16 @@ def test_score_pairs_pairwise(estimator, build_dataset): assert_array_almost_equal(squareform(pairwise), pdist(model.transform(X))) -@pytest.mark.parametrize('estimator, build_dataset', list_estimators, - ids=ids_estimators) +@pytest.mark.parametrize('estimator, build_dataset', metric_learners, + ids=ids_metric_learners) def test_score_pairs_toy_example(estimator, build_dataset): # Checks that score_pairs works on a toy example - inputs, labels = build_dataset() - X, _ = load_iris(return_X_y=True) + input_data, labels, _, X = build_dataset() n_samples = 20 X = X[:n_samples] model = clone(estimator) set_random_state(model) - model.fit(inputs, labels) + model.fit(input_data, labels) pairs = np.stack([X[:10], X[10:20]], axis=1) embedded_pairs = pairs.dot(model.transformer_.T) distances = np.sqrt(np.sum((embedded_pairs[:, 1] - @@ -115,39 +57,37 @@ def test_score_pairs_toy_example(estimator, build_dataset): assert_array_almost_equal(model.score_pairs(pairs), distances) -@pytest.mark.parametrize('estimator, build_dataset', list_estimators, - ids=ids_estimators) +@pytest.mark.parametrize('estimator, build_dataset', metric_learners, + ids=ids_metric_learners) def test_score_pairs_finite(estimator, build_dataset): # tests that the score is finite - inputs, labels = build_dataset() + input_data, labels, _, X = build_dataset() model = clone(estimator) set_random_state(model) - model.fit(inputs, labels) - X, _ = 
load_iris(return_X_y=True)
+  model.fit(input_data, labels)
   pairs = np.array(list(product(X, X)))
   assert np.isfinite(model.score_pairs(pairs)).all()
 
 
-@pytest.mark.parametrize('estimator, build_dataset', list_estimators,
-                         ids=ids_estimators)
+@pytest.mark.parametrize('estimator, build_dataset', metric_learners,
+                         ids=ids_metric_learners)
 def test_score_pairs_dim(estimator, build_dataset):
   # scoring of 3D arrays should return 1D array (several tuples),
   # and scoring of 2D arrays (one tuple) should return an error (like
   # scikit-learn's error when scoring 1D arrays)
-  inputs, labels = build_dataset()
+  input_data, labels, _, X = build_dataset()
   model = clone(estimator)
   set_random_state(model)
-  model.fit(inputs, labels)
-  X, _ = load_iris(return_X_y=True)
+  model.fit(input_data, labels)
   tuples = np.array(list(product(X, X)))
   assert model.score_pairs(tuples).shape == (tuples.shape[0],)
-  msg = ("Expected 3D array, got 2D array instead:\ntuples={}.\n"
-         "Reshape your data either using tuples.reshape(-1, {}, 1) if "
-         "your data has a single feature or tuples.reshape(1, {}, -1) "
-         "if it contains a single tuple.".format(tuples, tuples.shape[1],
-                                                 tuples.shape[0]))
-  with pytest.raises(ValueError, message=msg):
+  context = make_context(estimator)
+  msg = ("3D array of formed tuples expected{}. Found 2D array "
+         "instead:\ninput={}. Reshape your data and/or use a preprocessor.\n"
+         .format(context, tuples[1]))
+  with pytest.raises(ValueError) as raised_error:
     model.score_pairs(tuples[1])
+  assert str(raised_error.value) == msg
 
 
 def check_is_distance_matrix(pairwise):
@@ -156,74 +96,72 @@ def check_is_distance_matrix(pairwise):
   assert (pairwise.diagonal() == 0).all()  # identity
   # triangular inequality
   tol = 1e-15
-  assert (pairwise <= pairwise[:, :, np.newaxis]
-          + pairwise[:, np.newaxis, :] + tol).all()
+  assert (pairwise <= pairwise[:, :, np.newaxis] +
+          pairwise[:, np.newaxis, :] + tol).all()
 
 
-@pytest.mark.parametrize('estimator, build_dataset', list_estimators,
-                         ids=ids_estimators)
+@pytest.mark.parametrize('estimator, build_dataset', metric_learners,
+                         ids=ids_metric_learners)
 def test_embed_toy_example(estimator, build_dataset):
   # Checks that embed works on a toy example
-  inputs, labels = build_dataset()
-  X, _ = load_iris(return_X_y=True)
+  input_data, labels, _, X = build_dataset()
   n_samples = 20
   X = X[:n_samples]
   model = clone(estimator)
   set_random_state(model)
-  model.fit(inputs, labels)
+  model.fit(input_data, labels)
   embedded_points = X.dot(model.transformer_.T)
   assert_array_almost_equal(model.transform(X), embedded_points)
 
 
-@pytest.mark.parametrize('estimator, build_dataset', list_estimators,
-                         ids=ids_estimators)
+@pytest.mark.parametrize('estimator, build_dataset', metric_learners,
+                         ids=ids_metric_learners)
 def test_embed_dim(estimator, build_dataset):
   # Checks that the dimension of the output space is as expected
-  inputs, labels = build_dataset()
+  input_data, labels, _, X = build_dataset()
   model = clone(estimator)
   set_random_state(model)
-  model.fit(inputs, labels)
-  X, _ = load_iris(return_X_y=True)
+  model.fit(input_data, labels)
   assert model.transform(X).shape == X.shape
 
   # assert that ValueError is thrown if input shape is 1D
-  err_msg = ("Expected 2D array, got 1D array instead:\narray={}.\n"
-             "Reshape your data either using array.reshape(-1, 1) if "
-             "your data has a single feature or array.reshape(1, -1) "
-             "if it contains a single sample.".format(X))
-  with pytest.raises(ValueError, message=err_msg):
+  context = make_context(estimator)
+  err_msg = ("2D array of 
formed points expected{}. Found 1D array " + "instead:\ninput={}. Reshape your data and/or use a " + "preprocessor.\n".format(context, X[0])) + with pytest.raises(ValueError) as raised_error: model.score_pairs(model.transform(X[0, :])) + assert str(raised_error.value) == err_msg # we test that the shape is also OK when doing dimensionality reduction if type(model).__name__ in {'LFDA', 'MLKR', 'NCA', 'RCA'}: model.set_params(num_dims=2) - model.fit(inputs, labels) + model.fit(input_data, labels) assert model.transform(X).shape == (X.shape[0], 2) # assert that ValueError is thrown if input shape is 1D - with pytest.raises(ValueError, message=err_msg): + with pytest.raises(ValueError) as raised_error: model.transform(model.transform(X[0, :])) + assert str(raised_error.value) == err_msg -@pytest.mark.parametrize('estimator, build_dataset', list_estimators, - ids=ids_estimators) +@pytest.mark.parametrize('estimator, build_dataset', metric_learners, + ids=ids_metric_learners) def test_embed_finite(estimator, build_dataset): # Checks that embed returns vectors with finite values - inputs, labels = build_dataset() + input_data, labels, _, X = build_dataset() model = clone(estimator) set_random_state(model) - model.fit(inputs, labels) - X, _ = load_iris(return_X_y=True) + model.fit(input_data, labels) assert np.isfinite(model.transform(X)).all() -@pytest.mark.parametrize('estimator, build_dataset', list_estimators, - ids=ids_estimators) +@pytest.mark.parametrize('estimator, build_dataset', metric_learners, + ids=ids_metric_learners) def test_embed_is_linear(estimator, build_dataset): # Checks that the embedding is linear - inputs, labels = build_dataset() + input_data, labels, _, X = build_dataset() model = clone(estimator) set_random_state(model) - model.fit(inputs, labels) - X, _ = load_iris(return_X_y=True) + model.fit(input_data, labels) assert_array_almost_equal(model.transform(X[:10] + X[10:20]), model.transform(X[:10]) + model.transform(X[10:20])) diff --git a/test/test_sklearn_compat.py b/test/test_sklearn_compat.py index f1e1a09d..d9dce685 100644 --- a/test/test_sklearn_compat.py +++ b/test/test_sklearn_compat.py @@ -1,10 +1,23 @@ -import numpy as np +import pytest import unittest from sklearn.utils.estimator_checks import check_estimator +from sklearn.base import TransformerMixin +from sklearn.pipeline import make_pipeline +from sklearn.utils import check_random_state +from sklearn.utils.estimator_checks import is_public_parameter +from sklearn.utils.testing import (assert_allclose_dense_sparse, + set_random_state) -from metric_learn import ( - LMNN, NCA, LFDA, Covariance, MLKR, - LSML_Supervised, ITML_Supervised, SDML_Supervised, RCA_Supervised, MMC_Supervised) +from metric_learn import (Covariance, LFDA, LMNN, MLKR, NCA, + ITML_Supervised, LSML_Supervised, + MMC_Supervised, RCA_Supervised, SDML_Supervised) +from sklearn import clone +import numpy as np +from sklearn.model_selection import (cross_val_score, cross_val_predict, + train_test_split, KFold) +from sklearn.utils.testing import _get_args +from test.test_utils import (metric_learners, ids_metric_learners, + mock_preprocessor) # Wrap the _Supervised methods with a deterministic wrapper for testing. 
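With the patch applied, score_pairs and transform validate their input through check_input, so a single formed pair (a 2D array) is rejected with the new tuple-oriented message tested above rather than scikit-learn's generic reshape hint. A minimal sketch of what these tests exercise, assuming the patched API and the iris data; MMC_Supervised is just one convenient learner for the illustration:

import numpy as np
from sklearn.datasets import load_iris
from metric_learn import MMC_Supervised

X, y = load_iris(return_X_y=True)
model = MMC_Supervised(max_iter=5).fit(X, y)

pairs = np.stack([X[:5], X[5:10]], axis=1)  # shape (5, 2, 4): five formed pairs
model.score_pairs(pairs)                    # fine: returns one distance per pair
model.score_pairs(pairs[0])                 # shape (2, 4): raises ValueError,
                                            # "3D array of formed tuples expected..."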
@@ -68,5 +81,263 @@ def test_mmc(self):
     # check_estimator(RCA_Supervised)
 
 
+RNG = check_random_state(0)
+
+
+# ---------------------- Test scikit-learn compatibility ----------------------
+
+
+@pytest.mark.parametrize('with_preprocessor', [True, False])
+@pytest.mark.parametrize('estimator, build_dataset', metric_learners,
+                         ids=ids_metric_learners)
+def test_cross_validation_is_finite(estimator, build_dataset,
+                                    with_preprocessor):
+  """Tests that cross-validation on metric-learn estimators returns something finite
+  """
+  if any(hasattr(estimator, method) for method in ["predict", "score"]):
+    input_data, labels, preprocessor, _ = build_dataset(with_preprocessor)
+    estimator = clone(estimator)
+    estimator.set_params(preprocessor=preprocessor)
+    set_random_state(estimator)
+    if hasattr(estimator, "score"):
+      assert np.isfinite(cross_val_score(estimator, input_data, labels)).all()
+    if hasattr(estimator, "predict"):
+      assert np.isfinite(cross_val_predict(estimator,
+                                           input_data, labels)).all()
+
+
+@pytest.mark.parametrize('with_preprocessor', [True, False])
+@pytest.mark.parametrize('estimator, build_dataset', metric_learners,
+                         ids=ids_metric_learners)
+def test_cross_validation_manual_vs_scikit(estimator, build_dataset,
+                                           with_preprocessor):
+  """Tests that if we make a manual cross-validation, the result will be the
+  same as scikit-learn's cross-validation (some code for generating the
+  folds is taken from scikit-learn).
+  """
+  if any(hasattr(estimator, method) for method in ["predict", "score"]):
+    input_data, labels, preprocessor, _ = build_dataset(with_preprocessor)
+    estimator = clone(estimator)
+    estimator.set_params(preprocessor=preprocessor)
+    set_random_state(estimator)
+    n_splits = 3
+    kfold = KFold(shuffle=False, n_splits=n_splits)
+    n_samples = input_data.shape[0]
+    fold_sizes = (n_samples // n_splits) * np.ones(n_splits, dtype=np.int)
+    fold_sizes[:n_samples % n_splits] += 1
+    current = 0
+    scores, predictions = [], np.zeros(input_data.shape[0])
+    for fold_size in fold_sizes:
+      start, stop = current, current + fold_size
+      current = stop
+      test_slice = slice(start, stop)
+      train_mask = np.ones(input_data.shape[0], bool)
+      train_mask[test_slice] = False
+      y_train, y_test = labels[train_mask], labels[test_slice]
+      estimator.fit(input_data[train_mask], y_train)
+      if hasattr(estimator, "score"):
+        scores.append(estimator.score(input_data[test_slice], y_test))
+      if hasattr(estimator, "predict"):
+        predictions[test_slice] = estimator.predict(input_data[test_slice])
+    if hasattr(estimator, "score"):
+      assert all(scores == cross_val_score(estimator, input_data, labels,
+                                           cv=kfold))
+    if hasattr(estimator, "predict"):
+      assert all(predictions == cross_val_predict(estimator, input_data,
+                                                  labels,
+                                                  cv=kfold))
+
+
+def check_score(estimator, tuples, y):
+  if hasattr(estimator, "score"):
+    score = estimator.score(tuples, y)
+    assert np.isfinite(score)
+
+
+def check_predict(estimator, tuples):
+  if hasattr(estimator, "predict"):
+    y_predicted = estimator.predict(tuples)
+    assert len(y_predicted) == len(tuples)
+
+
+@pytest.mark.parametrize('with_preprocessor', [True, False])
+@pytest.mark.parametrize('estimator, build_dataset', metric_learners,
+                         ids=ids_metric_learners)
+def test_simple_estimator(estimator, build_dataset, with_preprocessor):
+  """Tests that fit, predict and scoring work. 
+ """ + if any(hasattr(estimator, method) for method in ["predict", "score"]): + input_data, labels, preprocessor, _ = build_dataset(with_preprocessor) + (tuples_train, tuples_test, y_train, + y_test) = train_test_split(input_data, labels, random_state=RNG) + estimator = clone(estimator) + estimator.set_params(preprocessor=preprocessor) + set_random_state(estimator) + + estimator.fit(tuples_train, y_train) + check_score(estimator, tuples_test, y_test) + check_predict(estimator, tuples_test) + + +@pytest.mark.parametrize('estimator', [est[0] for est in metric_learners], + ids=ids_metric_learners) +@pytest.mark.parametrize('preprocessor', [None, mock_preprocessor]) +def test_no_attributes_set_in_init(estimator, preprocessor): + """Check setting during init. Adapted from scikit-learn.""" + estimator = clone(estimator) + estimator.set_params(preprocessor=preprocessor) + if hasattr(type(estimator).__init__, "deprecated_original"): + return + + init_params = _get_args(type(estimator).__init__) + parents_init_params = [param for params_parent in + (_get_args(parent) for parent in + type(estimator).__mro__) + for param in params_parent] + + # Test for no setting apart from parameters during init + invalid_attr = (set(vars(estimator)) - set(init_params) - + set(parents_init_params)) + assert not invalid_attr, \ + ("Estimator %s should not set any attribute apart" + " from parameters during init. Found attributes %s." + % (type(estimator).__name__, sorted(invalid_attr))) + # Ensure that each parameter is set in init + invalid_attr = (set(init_params) - set(vars(estimator)) - + set(["self"])) + assert not invalid_attr, \ + ("Estimator %s should store all parameters" + " as an attribute during init. Did not find " + "attributes %s." % (type(estimator).__name__, sorted(invalid_attr))) + + +@pytest.mark.parametrize('with_preprocessor', [True, False]) +@pytest.mark.parametrize('estimator, build_dataset', metric_learners, + ids=ids_metric_learners) +def test_estimators_fit_returns_self(estimator, build_dataset, + with_preprocessor): + """Check if self is returned when calling fit""" + # Adapted from scikit-learn + input_data, labels, preprocessor, _ = build_dataset(with_preprocessor) + estimator = clone(estimator) + estimator.set_params(preprocessor=preprocessor) + assert estimator.fit(input_data, labels) is estimator + + +@pytest.mark.parametrize('with_preprocessor', [True, False]) +@pytest.mark.parametrize('estimator, build_dataset', metric_learners, + ids=ids_metric_learners) +def test_pipeline_consistency(estimator, build_dataset, + with_preprocessor): + # Adapted from scikit learn + # check that make_pipeline(est) gives same score as est + input_data, y, preprocessor, _ = build_dataset(with_preprocessor) + + def make_random_state(estimator, in_pipeline): + rs = {} + name_estimator = estimator.__class__.__name__ + if name_estimator[-11:] == '_Supervised': + name_param = 'random_state' + if in_pipeline: + name_param = name_estimator.lower() + '__' + name_param + rs[name_param] = check_random_state(0) + return rs + + estimator = clone(estimator) + estimator.set_params(preprocessor=preprocessor) + pipeline = make_pipeline(estimator) + estimator.fit(input_data, y, **make_random_state(estimator, False)) + pipeline.fit(input_data, y, **make_random_state(estimator, True)) + + if hasattr(estimator, 'score'): + result = estimator.score(input_data, y) + result_pipe = pipeline.score(input_data, y) + assert_allclose_dense_sparse(result, result_pipe) + + if hasattr(estimator, 'predict'): + result = 
estimator.predict(input_data) + result_pipe = pipeline.predict(input_data) + assert_allclose_dense_sparse(result, result_pipe) + + if issubclass(estimator.__class__, TransformerMixin): + if hasattr(estimator, 'transform'): + result = estimator.transform(input_data) + result_pipe = pipeline.transform(input_data) + assert_allclose_dense_sparse(result, result_pipe) + + +@pytest.mark.parametrize('with_preprocessor',[True, False]) +@pytest.mark.parametrize('estimator, build_dataset', metric_learners, + ids=ids_metric_learners) +def test_dict_unchanged(estimator, build_dataset, with_preprocessor): + # Adapted from scikit-learn + (input_data, labels, preprocessor, + to_transform) = build_dataset(with_preprocessor) + estimator = clone(estimator) + estimator.set_params(preprocessor=preprocessor) + if hasattr(estimator, "num_dims"): + estimator.num_dims = 1 + estimator.fit(input_data, labels) + + def check_dict(): + assert estimator.__dict__ == dict_before, ( + "Estimator changes __dict__ during %s" % method) + for method in ["predict", "decision_function", "predict_proba"]: + if hasattr(estimator, method): + dict_before = estimator.__dict__.copy() + getattr(estimator, method)(input_data) + check_dict() + if hasattr(estimator, "transform"): + dict_before = estimator.__dict__.copy() + # we transform only dataset of points + estimator.transform(to_transform) + check_dict() + + +@pytest.mark.parametrize('with_preprocessor',[True, False]) +@pytest.mark.parametrize('estimator, build_dataset', metric_learners, + ids=ids_metric_learners) +def test_dont_overwrite_parameters(estimator, build_dataset, + with_preprocessor): + # Adapted from scikit-learn + # check that fit method only changes or sets private attributes + input_data, labels, preprocessor, _ = build_dataset(with_preprocessor) + estimator = clone(estimator) + estimator.set_params(preprocessor=preprocessor) + if hasattr(estimator, "num_dims"): + estimator.num_dims = 1 + dict_before_fit = estimator.__dict__.copy() + + estimator.fit(input_data, labels) + dict_after_fit = estimator.__dict__ + + public_keys_after_fit = [key for key in dict_after_fit.keys() + if is_public_parameter(key)] + + attrs_added_by_fit = [key for key in public_keys_after_fit + if key not in dict_before_fit.keys()] + + # check that fit doesn't add any public attribute + assert not attrs_added_by_fit, ( + "Estimator adds public attribute(s) during" + " the fit method." + " Estimators are only allowed to add private " + "attributes" + " either started with _ or ended" + " with _ but %s added" % ', '.join(attrs_added_by_fit)) + + # check that fit doesn't change any public attribute + attrs_changed_by_fit = [key for key in public_keys_after_fit + if (dict_before_fit[key] + is not dict_after_fit[key])] + + assert not attrs_changed_by_fit, ( + "Estimator changes public attribute(s) during" + " the fit method. Estimators are only allowed" + " to change attributes started" + " or ended with _, but" + " %s changed" % ', '.join(attrs_changed_by_fit)) + + if __name__ == '__main__': unittest.main() diff --git a/test/test_utils.py b/test/test_utils.py index 8ca3aac3..de59e9ff 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1,29 +1,1013 @@ -import numpy as np import pytest -from metric_learn._util import check_tuples - - -def test_check_tuples(): - X = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) - check_tuples(X) - - X = 5 - msg = ("Expected 3D array, got scalar instead. 
Cannot apply this function " - "on scalars.") - with pytest.raises(ValueError, message=msg): - check_tuples(X) - - X = np.array([1, 2, 3]) - msg = ("Expected 3D array, got 1D array instead:\ntuples=[1, 2, 3].\n" - "Reshape your data using tuples.reshape(1, -1, 1) if it contains a " - "single tuple and the points in the tuple have a single feature.") - with pytest.raises(ValueError, message=msg): - check_tuples(X) - - X = np.array([[1, 2, 3], [2, 3, 5]]) - msg = ("Expected 3D array, got 2D array instead:\ntuples=[[1, 2, 3], " - "[2, 3, 5]].\nReshape your data either using " - "tuples.reshape(-1, 3, 1) if your data has a single feature or " - "tuples.reshape(1, 2, -1) if it contains a single tuple.") - with pytest.raises(ValueError, message=msg): - check_tuples(X) +from collections import namedtuple +import numpy as np +from sklearn.model_selection import train_test_split +from sklearn.exceptions import DataConversionWarning +from sklearn.utils import check_random_state, shuffle +from sklearn.utils.testing import set_random_state +from sklearn.base import clone +from metric_learn._util import (check_input, make_context, preprocess_tuples, + make_name, preprocess_points, + check_collapsed_pairs) +from metric_learn import (ITML, LSML, MMC, RCA, SDML, Covariance, LFDA, + LMNN, MLKR, NCA, ITML_Supervised, LSML_Supervised, + MMC_Supervised, RCA_Supervised, SDML_Supervised, + Constraints) +from metric_learn.base_metric import (ArrayIndexer, MahalanobisMixin, + _PairsClassifierMixin, + _QuadrupletsClassifierMixin) +from metric_learn.exceptions import PreprocessorError +from sklearn.datasets import make_regression, make_blobs, load_iris + + +SEED = 42 +RNG = check_random_state(SEED) + +Dataset = namedtuple('Dataset', ('data target preprocessor to_transform')) +# Data and target are what we will fit on. 
Preprocessor is the additional +# data if we use a preprocessor (which should be the default ArrayIndexer), +# and to_transform is some additional data that we would want to transform + + +@pytest.fixture +def build_classification(with_preprocessor=False): + """Basic array for testing when using a preprocessor""" + X, y = shuffle(*make_blobs(random_state=SEED), + random_state=SEED) + indices = shuffle(np.arange(X.shape[0]), random_state=SEED).astype(int) + if with_preprocessor: + return Dataset(indices, y[indices], X, indices) + else: + return Dataset(X[indices], y[indices], None, X[indices]) + + +@pytest.fixture +def build_regression(with_preprocessor=False): + """Basic array for testing when using a preprocessor""" + X, y = shuffle(*make_regression(n_samples=100, n_features=5, + random_state=SEED), + random_state=SEED) + indices = shuffle(np.arange(X.shape[0]), random_state=SEED).astype(int) + if with_preprocessor: + return Dataset(indices, y[indices], X, indices) + else: + return Dataset(X[indices], y[indices], None, X[indices]) + + +def build_data(): + input_data, labels = load_iris(return_X_y=True) + X, y = shuffle(input_data, labels, random_state=SEED) + num_constraints = 50 + constraints = ( + Constraints.random_subset(y, random_state=check_random_state(SEED))) + pairs = ( + constraints + .positive_negative_pairs(num_constraints, same_length=True, + random_state=check_random_state(SEED))) + return X, pairs + + +def build_pairs(with_preprocessor=False): + # builds a toy pairs problem + X, indices = build_data() + c = np.vstack([np.column_stack(indices[:2]), np.column_stack(indices[2:])]) + target = np.concatenate([np.ones(indices[0].shape[0]), + - np.ones(indices[0].shape[0])]) + c, target = shuffle(c, target, random_state=SEED) + if with_preprocessor: + # if preprocessor, we build a 2D array of pairs of indices + return Dataset(c, target, X, c[:, 0]) + else: + # if not, we build a 3D array of pairs of samples + return Dataset(X[c], target, None, X[c[:, 0]]) + + +def build_quadruplets(with_preprocessor=False): + # builds a toy quadruplets problem + X, indices = build_data() + c = np.column_stack(indices) + target = np.ones(c.shape[0]) # quadruplets targets are not used + # anyways + c, target = shuffle(c, target, random_state=SEED) + if with_preprocessor: + # if preprocessor, we build a 2D array of quadruplets of indices + return Dataset(c, target, X, c[:, 0]) + else: + # if not, we build a 3D array of quadruplets of samples + return Dataset(X[c], target, None, X[c[:, 0]]) + + +quadruplets_learners = [(LSML(), build_quadruplets)] +ids_quadruplets_learners = list(map(lambda x: x.__class__.__name__, + [learner for (learner, _) in + quadruplets_learners])) + +pairs_learners = [(ITML(), build_pairs), + (MMC(max_iter=2), build_pairs), # max_iter=2 for faster + (SDML(), build_pairs), + ] +ids_pairs_learners = list(map(lambda x: x.__class__.__name__, + [learner for (learner, _) in + pairs_learners])) + +classifiers = [(Covariance(), build_classification), + (LFDA(), build_classification), + (LMNN(), build_classification), + (NCA(), build_classification), + (RCA(), build_classification), + (ITML_Supervised(max_iter=5), build_classification), + (LSML_Supervised(), build_classification), + (MMC_Supervised(max_iter=5), build_classification), + (RCA_Supervised(num_chunks=10), build_classification), + (SDML_Supervised(), build_classification) + ] +ids_classifiers = list(map(lambda x: x.__class__.__name__, + [learner for (learner, _) in + classifiers])) + +regressors = [(MLKR(), build_regression)] 
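As the builders above illustrate, every dataset comes in two interchangeable flavors: a 2D array of indices plus a preprocessor, or the corresponding pre-formed 3D array of tuples. A minimal sketch of that equivalence with ITML (toy data, illustrative only):

import numpy as np
from metric_learn import ITML

X = np.random.RandomState(42).randn(6, 4)
pairs_idx = np.array([[0, 1], [2, 3], [4, 5]])  # 2D: each row indexes two points
y_pairs = np.array([1, -1, 1])                  # 1: similar, -1: dissimilar

ITML(preprocessor=X).fit(pairs_idx, y_pairs)    # indices, formed via X[indices]
ITML().fit(X[pairs_idx], y_pairs)               # same pairs, formed up front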
+ids_regressors = list(map(lambda x: x.__class__.__name__,
+                          [learner for (learner, _) in regressors]))
+
+WeaklySupervisedClasses = (_PairsClassifierMixin,
+                           _QuadrupletsClassifierMixin)
+
+tuples_learners = pairs_learners + quadruplets_learners
+ids_tuples_learners = ids_pairs_learners + ids_quadruplets_learners
+
+supervised_learners = classifiers + regressors
+ids_supervised_learners = ids_classifiers + ids_regressors
+
+metric_learners = tuples_learners + supervised_learners
+ids_metric_learners = ids_tuples_learners + ids_supervised_learners
+
+
+def mock_preprocessor(indices):
+  """A preprocessor for testing purposes that returns an all ones 3D array
+  """
+  return np.ones((indices.shape[0], 3))
+
+
+@pytest.mark.parametrize('type_of_inputs', ['other', 'tuple', 'classics', 2,
+                                            int, NCA()])
+def test_check_input_invalid_type_of_inputs(type_of_inputs):
+  """Tests that an invalid type of inputs in check_input raises an error."""
+  with pytest.raises(ValueError) as e:
+    check_input([[0.2, 2.1], [0.2, .8]], type_of_inputs=type_of_inputs)
+  msg = ("Unknown value {} for type_of_inputs. Valid values are "
+         "'classic' or 'tuples'.".format(type_of_inputs))
+  assert str(e.value) == msg
+
+
+# ---------------- test check_input with 'tuples' type_of_inputs ------------
+
+
+@pytest.fixture
+def tuples_prep():
+  """Basic array for testing when using a preprocessor"""
+  tuples = np.array([[1, 2],
+                     [2, 3]])
+  return tuples
+
+
+@pytest.fixture
+def tuples_no_prep():
+  """Basic array for testing when using no preprocessor"""
+  tuples = np.array([[[1., 2.3], [2.3, 5.3]],
+                     [[2.3, 4.3], [0.2, 0.4]]])
+  return tuples
+
+
+@pytest.mark.parametrize('estimator, expected',
+                         [(NCA(), " by NCA"), ('NCA', " by NCA"), (None, "")])
+def test_make_context(estimator, expected):
+  """test the make_context function"""
+  assert make_context(estimator) == expected
+
+
+@pytest.mark.parametrize('estimator, expected',
+                         [(NCA(), "NCA"), ('NCA', "NCA"), (None, None)])
+def test_make_name(estimator, expected):
+  """test the make_name function"""
+  assert make_name(estimator) == expected
+
+
+@pytest.mark.parametrize('estimator, context',
+                         [(NCA(), " by NCA"), ('NCA', " by NCA"), (None, "")])
+@pytest.mark.parametrize('load_tuples, preprocessor',
+                         [(tuples_prep, mock_preprocessor),
+                          (tuples_no_prep, None),
+                          (tuples_no_prep, mock_preprocessor)])
+def test_check_tuples_invalid_tuple_size(estimator, context, load_tuples,
+                                         preprocessor):
+  """Checks that an exception is raised if tuple_size is not the one
+  expected"""
+  tuples = load_tuples()
+  preprocessed_tuples = (preprocess_tuples(tuples, preprocessor)
+                         if (preprocessor is not None and
+                             tuples.ndim == 2) else tuples)
+  expected_msg = ("Tuples of 3 element(s) expected{}. Got tuples of 2 "
+                  "element(s) instead (shape={}):\ninput={}.\n"
+                  .format(context, preprocessed_tuples.shape,
+                          preprocessed_tuples))
+  with pytest.raises(ValueError) as raised_error:
+    check_input(tuples, type_of_inputs='tuples', tuple_size=3,
+                preprocessor=preprocessor, estimator=estimator)
+  assert str(raised_error.value) == expected_msg
+
+
+@pytest.mark.parametrize('estimator, context',
+                         [(NCA(), " by NCA"), ('NCA', " by NCA"), (None, "")])
+@pytest.mark.parametrize('tuples, found, expected, preprocessor',
+                         [(5, '0', '2D array of indicators or 3D array of '
+                                   'formed tuples', mock_preprocessor),
+                          (5, '0', '3D array of formed tuples', None),
+                          ([1, 2], '1', '2D array of indicators or 3D array '
+                                        'of formed tuples', mock_preprocessor),
+                          ([1, 2], '1', '3D array of formed tuples', None),
+                          ([[[[5]]]], '4', '2D array of indicators or 3D array'
+                                           ' of formed tuples',
+                           mock_preprocessor),
+                          ([[[[5]]]], '4', '3D array of formed tuples', None),
+                          ([[1], [3]], '2', '3D array of formed '
+                                            'tuples', None)])
+def test_check_tuples_invalid_shape(estimator, context, tuples, found,
+                                    expected, preprocessor):
+  """Checks that a value error with the appropriate message is raised if
+  shape is invalid (not 2D with preprocessor or 3D with no preprocessor)
+  """
+  tuples = np.array(tuples)
+  msg = ("{} expected{}{}. Found {}D array instead:\ninput={}. Reshape your "
+         "data{}.\n"
+         .format(expected, context, ' when using a preprocessor'
+                 if preprocessor else '', found, tuples,
+                 ' and/or use a preprocessor' if
+                 (not preprocessor and tuples.ndim == 2) else ''))
+  with pytest.raises(ValueError) as raised_error:
+    check_input(tuples, type_of_inputs='tuples',
+                preprocessor=preprocessor, ensure_min_samples=0,
+                estimator=estimator)
+  assert str(raised_error.value) == msg
+
+
+@pytest.mark.parametrize('estimator, context',
+                         [(NCA(), " by NCA"), ('NCA', " by NCA"), (None, "")])
+def test_check_tuples_invalid_n_features(estimator, context, tuples_no_prep):
+  """Checks that the right error is raised if there are not enough features.
+  Here we only test the case with no preprocessor (otherwise we don't ensure
+  this)
+  """
+  msg = ("Found array with 2 feature(s) (shape={}) while"
+         " a minimum of 3 is required{}.".format(tuples_no_prep.shape,
+                                                 context))
+  with pytest.raises(ValueError) as raised_error:
+    check_input(tuples_no_prep, type_of_inputs='tuples',
+                preprocessor=None, ensure_min_features=3,
+                estimator=estimator)
+  assert str(raised_error.value) == msg
+
+
+@pytest.mark.parametrize('estimator, context',
+                         [(NCA(), " by NCA"), ('NCA', " by NCA"), (None, "")])
+@pytest.mark.parametrize('load_tuples, preprocessor',
+                         [(tuples_prep, mock_preprocessor),
+                          (tuples_no_prep, None),
+                          (tuples_no_prep, mock_preprocessor)])
+def test_check_tuples_invalid_n_samples(estimator, context, load_tuples,
+                                        preprocessor):
+  """Checks that the right error is raised if n_samples is too small"""
+  tuples = load_tuples()
+  msg = ("Found array with 2 sample(s) (shape={}) while a minimum of 3 "
+         "is required{}.".format((preprocess_tuples(tuples, preprocessor)
+                                 if (preprocessor is not None and
+                                     tuples.ndim == 2) else tuples).shape,
+                                 context))
+  with pytest.raises(ValueError) as raised_error:
+    check_input(tuples, type_of_inputs='tuples',
+                preprocessor=preprocessor,
+                ensure_min_samples=3, estimator=estimator)
+  assert str(raised_error.value) == msg
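The two accepted input forms behind these messages can be summed up as follows; check_input returns the formed 3D array in both cases. A short sketch, with an assumed toy preprocessor (check_input is an internal helper, so this is illustrative rather than public API):

import numpy as np
from metric_learn._util import check_input

formed = np.array([[[1., 2.], [3., 4.]],
                   [[5., 6.], [7., 8.]]])  # already formed: 3D, passed through
check_input(formed, type_of_inputs='tuples', tuple_size=2)

indices = np.array([[0, 1], [1, 0]])       # 2D indices: formed by the preprocessor
check_input(indices, type_of_inputs='tuples', tuple_size=2,
            preprocessor=lambda idx: np.ones((len(idx), 3)))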
+@pytest.mark.parametrize('estimator, context',
+                         [(NCA(), " by NCA"), ('NCA', " by NCA"), (None, "")])
+@pytest.mark.parametrize('load_tuples, preprocessor',
+                         [(tuples_prep, mock_preprocessor),
+                          (tuples_no_prep, None),
+                          (tuples_no_prep, mock_preprocessor)])
+def test_check_tuples_invalid_dtype_convertible(estimator, context,
+                                                load_tuples, preprocessor):
+  """Checks that a warning is raised if a convertible input is converted to
+  float"""
+  tuples = load_tuples().astype(object)  # here the object conversion is
+  # useless for the tuples_prep case, but it allows testing the
+  # tuples_prep case too
+
+  if preprocessor is not None:  # if the preprocessor is not None we
+    # overwrite it to have a preprocessor that returns objects
+    def preprocessor(indices):
+      # preprocessor that returns objects
+      return np.ones((indices.shape[0], 3)).astype(object)
+
+  msg = ("Data with input dtype object was converted to float64{}."
+         .format(context))
+  with pytest.warns(DataConversionWarning) as raised_warning:
+    check_input(tuples, type_of_inputs='tuples',
+                preprocessor=preprocessor, dtype=np.float64,
+                warn_on_dtype=True, estimator=estimator)
+  assert str(raised_warning[0].message) == msg
+
+
+def test_check_tuples_invalid_dtype_not_convertible_with_preprocessor(
+    tuples_prep):
+  """Checks that a value error is thrown if attempting to convert an
+  input not convertible to float, when using a preprocessor
+  """
+
+  def preprocessor(indices):
+    # preprocessor that returns objects
+    return np.full((indices.shape[0], 3), 'a')
+
+  with pytest.raises(ValueError):
+    check_input(tuples_prep, type_of_inputs='tuples',
+                preprocessor=preprocessor, dtype=np.float64)
+
+
+def test_check_tuples_invalid_dtype_not_convertible_without_preprocessor(
+    tuples_no_prep):
+  """Checks that a value error is thrown if attempting to convert an
+  input not convertible to float, when using no preprocessor
+  """
+  tuples = np.full_like(tuples_no_prep, 'a', dtype=object)
+  with pytest.raises(ValueError):
+    check_input(tuples, type_of_inputs='tuples',
+                preprocessor=None, dtype=np.float64)
+
+
+@pytest.mark.parametrize('tuple_size', [2, None])
+def test_check_tuples_valid_tuple_size(tuple_size, tuples_prep, tuples_no_prep):
+  """For inputs that have the right matrix dimension (2D or 3D for instance),
+  checks that checking the number of tuples (pairs, quadruplets, etc.) raises
+  no warning if there is the right number of points in a tuple.
+  """
+  with pytest.warns(None) as record:
+    check_input(tuples_prep, type_of_inputs='tuples',
+                preprocessor=mock_preprocessor, tuple_size=tuple_size)
+    check_input(tuples_no_prep, type_of_inputs='tuples', preprocessor=None,
+                tuple_size=tuple_size)
+  assert len(record) == 0
+
+
+@pytest.mark.parametrize('tuples',
+                         [np.array([[2.5, 0.1, 2.6],
+                                    [1.6, 4.8, 9.1]]),
+                          np.array([[2, 0, 2],
+                                    [1, 4, 9]]),
+                          np.array([["img1.png", "img3.png"],
+                                    ["img2.png", "img4.png"]]),
+                          [[2, 0, 2],
+                           [1, 4, 9]],
+                          [np.array([2, 0, 2]),
+                           np.array([1, 4, 9])],
+                          ((2, 0, 2),
+                           (1, 4, 9)),
+                          np.array([[[1.2, 2.2], [1.4, 3.3]],
+                                    [[2.6, 2.3], [3.4, 5.0]]])])
+def test_check_tuples_valid_with_preprocessor(tuples):
+  """Test that valid inputs when using a preprocessor raises no warning"""
+  with pytest.warns(None) as record:
+    check_input(tuples, type_of_inputs='tuples',
+                preprocessor=mock_preprocessor)
+  assert len(record) == 0
+
+
+@pytest.mark.parametrize('tuples',
+                         [np.array([[[2.5], [0.1], [2.6]],
+                                    [[1.6], [4.8], [9.1]],
+                                    [[5.6], [2.8], [6.1]]]),
+                          np.array([[[2], [0], [2]],
+                                    [[1], [4], [9]],
+                                    [[1], [5], [3]]]),
+                          [[[2], [0], [2]],
+                           [[1], [4], [9]],
+                           [[3], [4], [29]]],
+                          (((2, 1), (0, 2), (2, 3)),
+                           ((1, 2), (4, 4), (9, 3)),
+                           ((3, 1), (4, 4), (29, 4)))])
+def test_check_tuples_valid_without_preprocessor(tuples):
+  """Test that valid inputs when using no preprocessor raises no warning"""
+  with pytest.warns(None) as record:
+    check_input(tuples, type_of_inputs='tuples', preprocessor=None)
+  assert len(record) == 0
+
+
+def test_check_tuples_behaviour_auto_dtype(tuples_no_prep):
+  """Checks that check_input (for tuples) allows by default every type if
+  using a preprocessor, and numeric types if using no preprocessor"""
+  tuples_prep = [['img1.png', 'img2.png'], ['img3.png', 'img5.png']]
+  with pytest.warns(None) as record:
+    check_input(tuples_prep, type_of_inputs='tuples',
+                preprocessor=mock_preprocessor)
+  assert len(record) == 0
+
+  with pytest.warns(None) as record:
+    check_input(tuples_no_prep, type_of_inputs='tuples')  # numeric type
+  assert len(record) == 0
+
+  # not numeric type
+  tuples_no_prep = np.array([[['img1.png'], ['img2.png']],
+                             [['img3.png'], ['img5.png']]])
+  tuples_no_prep = tuples_no_prep.astype(object)
+  with pytest.raises(ValueError):
+    check_input(tuples_no_prep, type_of_inputs='tuples')
+
+
+def test_check_tuples_invalid_complex_data():
+  """Checks that the right error message is thrown if given complex data (
+  this comes from sklearn's check_array's message)"""
+  tuples = np.array([[[1 + 2j, 3 + 4j], [5 + 7j, 5 + 7j]],
+                     [[1 + 3j, 2 + 4j], [5 + 8j, 1 + 7j]]])
+  msg = ("Complex data not supported\n"
+         "{}\n".format(tuples))
+  with pytest.raises(ValueError) as raised_error:
+    check_input(tuples, type_of_inputs='tuples')
+  assert str(raised_error.value) == msg
+
+
+# ------------- test check_input with 'classic' type_of_inputs ----------------
+
+
+@pytest.fixture
+def points_prep():
+  """Basic array for testing when using a preprocessor"""
+  points = np.array([1, 2])
+  return points
+
+
+@pytest.fixture
+def points_no_prep():
+  """Basic array for testing when using no preprocessor"""
+  points = np.array([[1., 2.3],
+                     [2.3, 4.3]])
+  return points
+@pytest.mark.parametrize('estimator, context',
+                         [(NCA(), " by NCA"), ('NCA', " by NCA"), (None, "")])
+@pytest.mark.parametrize('points, found, expected, preprocessor',
+                         [(5, '0', '1D array of indicators or 2D array of '
+                                   'formed points', mock_preprocessor),
+                          (5, '0', '2D array of formed points', None),
+                          ([1, 2], '1', '2D array of formed points', None),
+                          ([[[5]]], '3', '1D array of indicators or 2D '
+                                         'array of formed points',
+                           mock_preprocessor),
+                          ([[[5]]], '3', '2D array of formed points', None)])
+def test_check_classic_invalid_shape(estimator, context, points, found,
+                                     expected, preprocessor):
+  """Checks that a value error with the appropriate message is raised if
+  shape is invalid (valid being 1D or 2D with preprocessor or 2D with no
+  preprocessor)
+  """
+  points = np.array(points)
+  msg = ("{} expected{}{}. Found {}D array instead:\ninput={}. Reshape your "
+         "data{}.\n"
+         .format(expected, context, ' when using a preprocessor'
+                 if preprocessor else '', found, points,
+                 ' and/or use a preprocessor' if
+                 (not preprocessor and points.ndim == 1) else ''))
+  with pytest.raises(ValueError) as raised_error:
+    check_input(points, type_of_inputs='classic', preprocessor=preprocessor,
+                ensure_min_samples=0,
+                estimator=estimator)
+  assert str(raised_error.value) == msg
+
+
+@pytest.mark.parametrize('estimator, context',
+                         [(NCA(), " by NCA"), ('NCA', " by NCA"), (None, "")])
+def test_check_classic_invalid_n_features(estimator, context,
+                                          points_no_prep):
+  """Checks that the right error is raised if there are not enough features.
+  Here we only test the case with no preprocessor (otherwise we don't ensure
+  this)
+  """
+  msg = ("Found array with 2 feature(s) (shape={}) while"
+         " a minimum of 3 is required{}.".format(points_no_prep.shape,
+                                                 context))
+  with pytest.raises(ValueError) as raised_error:
+    check_input(points_no_prep, type_of_inputs='classic', preprocessor=None,
+                ensure_min_features=3,
+                estimator=estimator)
+  assert str(raised_error.value) == msg
+
+
+@pytest.mark.parametrize('estimator, context',
+                         [(NCA(), " by NCA"), ('NCA', " by NCA"), (None, "")])
+@pytest.mark.parametrize('load_points, preprocessor',
+                         [(points_prep, mock_preprocessor),
+                          (points_no_prep, None),
+                          (points_no_prep, mock_preprocessor)])
+def test_check_classic_invalid_n_samples(estimator, context, load_points,
+                                         preprocessor):
+  """Checks that the right error is raised if n_samples is too small"""
+  points = load_points()
+  msg = ("Found array with 2 sample(s) (shape={}) while a minimum of 3 "
+         "is required{}.".format((preprocess_points(points,
+                                                    preprocessor)
+                                 if preprocessor is not None and
+                                 points.ndim == 1 else
+                                 points).shape,
+                                 context))
+  with pytest.raises(ValueError) as raised_error:
+    check_input(points, type_of_inputs='classic', preprocessor=preprocessor,
+                ensure_min_samples=3,
+                estimator=estimator)
+  assert str(raised_error.value) == msg
+
+
+@pytest.mark.parametrize('estimator, context',
+                         [(NCA(), " by NCA"), ('NCA', " by NCA"), (None, "")])
+@pytest.mark.parametrize('load_points, preprocessor',
+                         [(points_prep, mock_preprocessor),
+                          (points_no_prep, None),
+                          (points_no_prep, mock_preprocessor)])
+def test_check_classic_invalid_dtype_convertible(estimator, context,
+                                                 load_points,
+                                                 preprocessor):
+  """Checks that a warning is raised if a convertible input is converted to
+  float"""
+  points = load_points().astype(object)  # here the object conversion is
+  # useless for the points_prep case, but it allows testing the
+  # points_prep case too
+
+  if preprocessor is not None:  # if the preprocessor is not None we
+    # overwrite it to have a preprocessor that returns objects
+    def preprocessor(indices):
+      # preprocessor that returns objects
+      return np.ones((indices.shape[0], 3)).astype(object)
+
+  msg = ("Data with input dtype object was converted to float64{}."
+ .format(context)) + with pytest.warns(DataConversionWarning) as raised_warning: + check_input(points, type_of_inputs='classic', + preprocessor=preprocessor, dtype=np.float64, + warn_on_dtype=True, estimator=estimator) + assert str(raised_warning[0].message) == msg + + +@pytest.mark.parametrize('preprocessor, points', + [(mock_preprocessor, np.array([['a', 'b'], + ['e', 'b']])), + (None, np.array([[['b', 'v'], ['a', 'd']], + [['x', 'u'], ['c', 'a']]]))]) +def test_check_classic_invalid_dtype_not_convertible(preprocessor, points): + """Checks that a value error is thrown if attempting to convert an + input not convertible to float + """ + with pytest.raises(ValueError): + check_input(points, type_of_inputs='classic', + preprocessor=preprocessor, dtype=np.float64) + + +@pytest.mark.parametrize('points', + [["img1.png", "img3.png", "img2.png"], + np.array(["img1.png", "img3.png", "img2.png"]), + [2, 0, 2, 1, 4, 9], + range(10), + np.array([2, 0, 2]), + (2, 0, 2), + np.array([[1.2, 2.2], + [2.6, 2.3]])]) +def test_check_classic_valid_with_preprocessor(points): + """Test that valid inputs when using a preprocessor raises no warning""" + with pytest.warns(None) as record: + check_input(points, type_of_inputs='classic', + preprocessor=mock_preprocessor) + assert len(record) == 0 + + +@pytest.mark.parametrize('points', + [np.array([[2.5, 0.1, 2.6], + [1.6, 4.8, 9.1], + [5.6, 2.8, 6.1]]), + np.array([[2, 0, 2], + [1, 4, 9], + [1, 5, 3]]), + [[2, 0, 2], + [1, 4, 9], + [3, 4, 29]], + ((2, 1, 0, 2, 2, 3), + (1, 2, 4, 4, 9, 3), + (3, 1, 4, 4, 29, 4))]) +def test_check_classic_valid_without_preprocessor(points): + """Test that valid inputs when using no preprocessor raises no warning""" + with pytest.warns(None) as record: + check_input(points, type_of_inputs='classic', preprocessor=None) + assert len(record) == 0 + + +def test_check_classic_by_default(): + """Checks that 'classic' is the default behaviour of check_input""" + assert (check_input([[2, 3], [3, 2]]) == + check_input([[2, 3], [3, 2]], type_of_inputs='classic')).all() + + +def test_check_classic_behaviour_auto_dtype(points_no_prep): + """Checks that check_input (for points) allows by default every type if + using a preprocessor, and numeric types if using no preprocessor""" + points_prep = ['img1.png', 'img2.png', 'img3.png', 'img5.png'] + with pytest.warns(None) as record: + check_input(points_prep, type_of_inputs='classic', + preprocessor=mock_preprocessor) + assert len(record) == 0 + + with pytest.warns(None) as record: + check_input(points_no_prep, type_of_inputs='classic') # numeric type + assert len(record) == 0 + + # not numeric type + points_no_prep = np.array(['img1.png', 'img2.png', 'img3.png', + 'img5.png']) + points_no_prep = points_no_prep.astype(object) + with pytest.raises(ValueError): + check_input(points_no_prep, type_of_inputs='classic') + + +def test_check_classic_invalid_complex_data(): + """Checks that the right error message is thrown if given complex data ( + this comes from sklearn's check_array's message)""" + points = np.array([[[1 + 2j, 3 + 4j], [5 + 7j, 5 + 7j]], + [[1 + 3j, 2 + 4j], [5 + 8j, 1 + 7j]]]) + msg = ("Complex data not supported\n" + "{}\n".format(points)) + with pytest.raises(ValueError) as raised_error: + check_input(points, type_of_inputs='classic') + assert str(raised_error.value) == msg + + +# ----------------------------- Test preprocessor ----------------------------- + + +X = np.array([[0.89, 0.11, 1.48, 0.12], + [2.63, 1.08, 1.68, 0.46], + [1.00, 0.59, 0.62, 1.15]]) + + +class MockFileLoader: 
+ """Preprocessor that takes a root file path at construction and simulates + fetching the file in the specific root folder when given the name of the + file""" + + def __init__(self, root): + self.root = root + self.folders = {'fake_root': {'img0.png': X[0], + 'img1.png': X[1], + 'img2.png': X[2] + }, + 'other_folder': {} # empty folder + } + + def __call__(self, path_list): + images = list() + for path in path_list: + images.append(self.folders[self.root][path]) + return np.array(images) + + +def mock_id_loader(list_of_indicators): + """A preprocessor as a function that takes indicators (strings) and + returns the corresponding samples""" + points = [] + for indicator in list_of_indicators: + points.append(X[int(indicator[2:])]) + return np.array(points) + + +tuples_list = [np.array([[0, 1], + [2, 1]]), + + np.array([['img0.png', 'img1.png'], + ['img2.png', 'img1.png']]), + + np.array([['id0', 'id1'], + ['id2', 'id1']]) + ] + +points_list = [np.array([0, 1, 2, 1]), + + np.array(['img0.png', 'img1.png', 'img2.png', 'img1.png']), + + np.array(['id0', 'id1', 'id2', 'id1']) + ] + +preprocessors = [X, MockFileLoader('fake_root'), mock_id_loader] + + +@pytest.fixture +def y_tuples(): + y = [-1, 1] + return y + + +@pytest.fixture +def y_points(): + y = [0, 1, 0, 0] + return y + + +@pytest.mark.parametrize('preprocessor, tuples', zip(preprocessors, + tuples_list)) +def test_preprocessor_weakly_supervised(preprocessor, tuples, y_tuples): + """Tests different ways to use the preprocessor argument: an array, + a class callable, and a function callable, with a weakly supervised + algorithm + """ + nca = ITML(preprocessor=preprocessor) + nca.fit(tuples, y_tuples) + + +@pytest.mark.parametrize('preprocessor, points', zip(preprocessors, + points_list)) +def test_preprocessor_supervised(preprocessor, points, y_points): + """Tests different ways to use the preprocessor argument: an array, + a class callable, and a function callable, with a supervised algorithm + """ + lfda = LFDA(preprocessor=preprocessor) + lfda.fit(points, y_points) + + +@pytest.mark.parametrize('estimator', ['NCA', NCA(), None]) +def test_preprocess_tuples_invalid_message(estimator): + """Checks that if the preprocessor does some weird stuff, the preprocessed + input is detected as weird. Checks this for preprocess_tuples.""" + + context = make_context(estimator) + (' after the preprocessor ' + 'has been applied') + + def preprocessor(sequence): + return np.ones((len(sequence), 2, 2)) # returns a 3D array instead of 2D + + with pytest.raises(ValueError) as raised_error: + check_input(np.ones((3, 2)), type_of_inputs='tuples', + preprocessor=preprocessor, estimator=estimator) + expected_msg = ("3D array of formed tuples expected{}. Found 4D " + "array instead:\ninput={}. 
Reshape your data{}.\n" + .format(context, np.ones((3, 2, 2, 2)), + ' and/or use a preprocessor' if preprocessor + is not None else '')) + assert str(raised_error.value) == expected_msg + + +@pytest.mark.parametrize('estimator', ['NCA', NCA(), None]) +def test_preprocess_points_invalid_message(estimator): + """Checks that if the preprocessor does some weird stuff, the preprocessed + input is detected as weird.""" + + context = make_context(estimator) + (' after the preprocessor ' + 'has been applied') + + def preprocessor(sequence): + return np.ones((len(sequence), 2, 2)) # returns a 3D array instead of 2D + + with pytest.raises(ValueError) as raised_error: + check_input(np.ones((3,)), type_of_inputs='classic', + preprocessor=preprocessor, estimator=estimator) + expected_msg = ("2D array of formed points expected{}. " + "Found 3D array instead:\ninput={}. Reshape your data{}.\n" + .format(context, np.ones((3, 2, 2)), + ' and/or use a preprocessor' if preprocessor + is not None else '')) + assert str(raised_error.value) == expected_msg + + +def test_preprocessor_error_message(): + """Tests whether the preprocessor returns a preprocessor error when there + is a problem using the preprocessor + """ + preprocessor = ArrayIndexer(np.array([[1.2, 3.3], [3.1, 3.2]])) + + # with tuples + X = np.array([[[2, 3], [3, 3]], [[2, 3], [3, 2]]]) + # There are less samples than the max index we want to preprocess + with pytest.raises(PreprocessorError): + preprocess_tuples(X, preprocessor) + + # with points + X = np.array([[1], [2], [3], [3]]) + with pytest.raises(PreprocessorError): + preprocess_points(X, preprocessor) + + +@pytest.mark.parametrize('input_data', [[[5, 3], [3, 2]], + ((5, 3), (3, 2)) + ]) +@pytest.mark.parametrize('indices', [[0, 1], (1, 0)]) +def test_array_like_indexer_array_like_valid_classic(input_data, indices): + """Checks that any array-like is valid in the 'preprocessor' argument, + and in the indices, for a classic input""" + class MockMetricLearner(MahalanobisMixin): + pass + + mock_algo = MockMetricLearner(preprocessor=input_data) + mock_algo._prepare_inputs(indices, type_of_inputs='classic') + + +@pytest.mark.parametrize('input_data', [[[5, 3], [3, 2]], + ((5, 3), (3, 2)) + ]) +@pytest.mark.parametrize('indices', [[[0, 1], [1, 0]], ((1, 0), (1, 0))]) +def test_array_like_indexer_array_like_valid_tuples(input_data, indices): + """Checks that any array-like is valid in the 'preprocessor' argument, + and in the indices, for a classic input""" + class MockMetricLearner(MahalanobisMixin): + pass + + mock_algo = MockMetricLearner(preprocessor=input_data) + mock_algo._prepare_inputs(indices, type_of_inputs='tuples') + + +@pytest.mark.parametrize('preprocessor', [4, NCA()]) +def test_error_message_check_preprocessor(preprocessor): + """Checks that if the preprocessor given is not an array-like or a + callable, the right error message is returned""" + class MockMetricLearner(MahalanobisMixin): + pass + + mock_algo = MockMetricLearner(preprocessor=preprocessor) + with pytest.raises(ValueError) as e: + mock_algo.check_preprocessor() + assert str(e.value) == ("Invalid type for the preprocessor: {}. 
You should " + "provide either None, an array-like object, " + "or a callable.".format(type(preprocessor))) + + +@pytest.mark.parametrize('estimator', [ITML(), LSML(), MMC(), SDML()], + ids=['ITML', 'LSML', 'MMC', 'SDML']) +def test_error_message_tuple_size(estimator): + """Tests that if a tuples learner is not given the good number of points + per tuple, it throws an error message""" + estimator = clone(estimator) + set_random_state(estimator) + invalid_pairs = np.array([[[1.3, 6.3], [3., 6.8], [6.5, 4.4]], + [[1.9, 5.3], [1., 7.8], [3.2, 1.2]]]) + y = [1, 1] + with pytest.raises(ValueError) as raised_err: + estimator.fit(invalid_pairs, y) + expected_msg = ("Tuples of {} element(s) expected{}. Got tuples of 3 " + "element(s) instead (shape=(2, 3, 2)):\ninput={}.\n" + .format(estimator._tuple_size, make_context(estimator), + invalid_pairs)) + assert str(raised_err.value) == expected_msg + + +@pytest.mark.parametrize('estimator, _', metric_learners, + ids=ids_metric_learners) +def test_error_message_t_score_pairs(estimator, _): + """tests that if you want to score_pairs on triplets for instance, it returns + the right error message + """ + estimator = clone(estimator) + set_random_state(estimator) + estimator.check_preprocessor() + triplets = np.array([[[1.3, 6.3], [3., 6.8], [6.5, 4.4]], + [[1.9, 5.3], [1., 7.8], [3.2, 1.2]]]) + with pytest.raises(ValueError) as raised_err: + estimator.score_pairs(triplets) + expected_msg = ("Tuples of 2 element(s) expected{}. Got tuples of 3 " + "element(s) instead (shape=(2, 3, 2)):\ninput={}.\n" + .format(make_context(estimator), triplets)) + assert str(raised_err.value) == expected_msg + + +def test_preprocess_tuples_simple_example(): + """Test the preprocessor on a very simple example of tuples to ensure the + result is as expected""" + array = np.array([[1, 2], + [2, 3], + [4, 5]]) + + def fun(row): + return np.array([[1, 1], [3, 3], [4, 4]]) + + expected_result = np.array([[[1, 1], [1, 1]], + [[3, 3], [3, 3]], + [[4, 4], [4, 4]]]) + + assert (preprocess_tuples(array, fun) == expected_result).all() + + +def test_preprocess_points_simple_example(): + """Test the preprocessor on very simple examples of points to ensure the + result is as expected""" + array = np.array([1, 2, 4]) + + def fun(row): + return [[1, 1], [3, 3], [4, 4]] + + expected_result = np.array([[1, 1], + [3, 3], + [4, 4]]) + + assert (preprocess_points(array, fun) == expected_result).all() + + +@pytest.mark.parametrize('estimator, build_dataset', metric_learners, + ids=ids_metric_learners) +def test_same_with_or_without_preprocessor(estimator, build_dataset): + """Test that algorithms using a preprocessor behave consistently +# with their no-preprocessor equivalent + """ + dataset_indices = build_dataset(with_preprocessor=True) + dataset_formed = build_dataset(with_preprocessor=False) + X = dataset_indices.preprocessor + indicators_to_transform = dataset_indices.to_transform + formed_points_to_transform = dataset_formed.to_transform + (indices_train, indices_test, y_train, y_test, formed_train, + formed_test) = train_test_split(dataset_indices.data, + dataset_indices.target, + dataset_formed.data, + random_state=SEED) + + def make_random_state(estimator): + rs = {} + if estimator.__class__.__name__[-11:] == '_Supervised': + rs['random_state'] = check_random_state(SEED) + return rs + + estimator_with_preprocessor = clone(estimator) + set_random_state(estimator_with_preprocessor) + estimator_with_preprocessor.set_params(preprocessor=X) + estimator_with_preprocessor.fit(indices_train, 
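The array-like form of the preprocessor argument is what the ArrayIndexer used in test_preprocessor_error_message above implements: calling it simply indexes into the stored array, so preprocessing indicators amounts to X[indices]. A quick sketch of that behavior (internal helpers, shown for illustration):

import numpy as np
from metric_learn._util import preprocess_points
from metric_learn.base_metric import ArrayIndexer

X = np.array([[1.2, 3.3],
              [3.1, 3.2],
              [5.0, 0.2]])
indexer = ArrayIndexer(X)
indexer(np.array([0, 2]))                     # rows 0 and 2 of X
preprocess_points(np.array([0, 2]), indexer)  # same result via the helper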
+
+  estimator_without_preprocessor = clone(estimator)
+  set_random_state(estimator_without_preprocessor)
+  estimator_without_preprocessor.set_params(preprocessor=None)
+  estimator_without_preprocessor.fit(formed_train, y_train,
+                                     **make_random_state(estimator))
+
+  estimator_with_prep_formed = clone(estimator)
+  set_random_state(estimator_with_prep_formed)
+  estimator_with_prep_formed.set_params(preprocessor=X)
+  estimator_with_prep_formed.fit(indices_train, y_train,
+                                 **make_random_state(estimator))
+
+  # test prediction methods
+  for method in ["predict", "decision_function"]:
+    if hasattr(estimator, method):
+      output_with_prep = getattr(estimator_with_preprocessor,
+                                 method)(indices_test)
+      output_without_prep = getattr(estimator_without_preprocessor,
+                                    method)(formed_test)
+      assert np.array(output_with_prep == output_without_prep).all()
+      output_with_prep = getattr(estimator_with_preprocessor,
+                                 method)(indices_test)
+      output_with_prep_formed = getattr(estimator_with_prep_formed,
+                                        method)(formed_test)
+      assert np.array(output_with_prep == output_with_prep_formed).all()
+
+  # test score_pairs
+  output_with_prep = estimator_with_preprocessor.score_pairs(
+      indicators_to_transform[[[[0, 2], [5, 3]]]])
+  output_without_prep = estimator_without_preprocessor.score_pairs(
+      formed_points_to_transform[[[[0, 2], [5, 3]]]])
+  assert np.array(output_with_prep == output_without_prep).all()
+
+  output_with_prep = estimator_with_preprocessor.score_pairs(
+      indicators_to_transform[[[[0, 2], [5, 3]]]])
+  output_with_prep_formed = estimator_with_prep_formed.score_pairs(
+      formed_points_to_transform[[[[0, 2], [5, 3]]]])
+  assert np.array(output_with_prep == output_with_prep_formed).all()
+
+  # test transform
+  output_with_prep = estimator_with_preprocessor.transform(
+      indicators_to_transform)
+  output_without_prep = estimator_without_preprocessor.transform(
+      formed_points_to_transform)
+  assert np.array(output_with_prep == output_without_prep).all()
+
+  output_with_prep = estimator_with_preprocessor.transform(
+      indicators_to_transform)
+  output_with_prep_formed = estimator_with_prep_formed.transform(
+      formed_points_to_transform)
+  assert np.array(output_with_prep == output_with_prep_formed).all()
+
+
+def test_check_collapsed_pairs_raises_no_error():
+  """Checks that check_collapsed_pairs raises no error if no collapsed pairs
+  are present"""
+  pairs_ok = np.array([[[0.1, 3.3], [3.3, 0.1]],
+                       [[0.1, 3.3], [3.3, 0.1]],
+                       [[2.5, 8.1], [0.1, 3.3]]])
+  check_collapsed_pairs(pairs_ok)
+
+
+def test_check_collapsed_pairs_raises_error():
+  """Checks that check_collapsed_pairs raises an error if some collapsed
+  pairs are present"""
+  pairs_not_ok = np.array([[[0.1, 3.3], [0.1, 3.3]],
+                           [[0.1, 3.3], [3.3, 0.1]],
+                           [[2.5, 8.1], [2.5, 8.1]]])
+  with pytest.raises(ValueError) as e:
+    check_collapsed_pairs(pairs_not_ok)
+  assert str(e.value) == ("2 collapsed pairs found (where the left element is "
+                          "the same as the right element), out of 3 pairs in"
+                          " total.")
diff --git a/test/test_weakly_supervised.py b/test/test_weakly_supervised.py
deleted file mode 100644
index 8cae4bfc..00000000
--- a/test/test_weakly_supervised.py
+++ /dev/null
@@ -1,253 +0,0 @@
-import pytest
-from sklearn.datasets import load_iris
-from sklearn.pipeline import make_pipeline
-from sklearn.utils import shuffle, check_random_state
-from sklearn.utils.estimator_checks import is_public_parameter
-from sklearn.utils.testing import (assert_allclose_dense_sparse,
-                                   set_random_state)
-from sklearn.utils.fixes import signature
-
-from metric_learn import ITML, MMC, SDML, LSML
-from metric_learn.constraints import wrap_pairs, Constraints
-from sklearn import clone
-import numpy as np
-from sklearn.model_selection import cross_val_score, train_test_split
-
-RNG = check_random_state(0)
-
-def build_data():
-  dataset = load_iris()
-  X, y = shuffle(dataset.data, dataset.target, random_state=RNG)
-  num_constraints = 20
-  constraints = Constraints.random_subset(y, random_state=RNG)
-  pairs = constraints.positive_negative_pairs(num_constraints,
-                                              same_length=True,
-                                              random_state=RNG)
-  return X, pairs
-
-
-def build_pairs():
-  # test that you can do cross validation on tuples of points with
-  # a WeaklySupervisedMetricLearner
-  X, pairs = build_data()
-  pairs, y = wrap_pairs(X, pairs)
-  pairs, y = shuffle(pairs, y, random_state=RNG)
-  (pairs_train, pairs_test, y_train,
-   y_test) = train_test_split(pairs, y, random_state=RNG)
-  return (pairs, y, pairs_train, pairs_test,
-          y_train, y_test)
-
-
-def build_quadruplets():
-  # test that you can do cross validation on a tuples of points with
-  # a WeaklySupervisedMetricLearner
-  X, pairs = build_data()
-  c = np.column_stack(pairs)
-  quadruplets = X[c]
-  quadruplets = shuffle(quadruplets, random_state=RNG)
-  y = y_train = y_test = None
-  quadruplets_train, quadruplets_test = train_test_split(quadruplets,
-                                                         random_state=RNG)
-  return (quadruplets, y, quadruplets_train, quadruplets_test,
-          y_train, y_test)
-
-
-list_estimators = [(ITML(), build_pairs),
-                   (LSML(), build_quadruplets),
-                   (MMC(), build_pairs),
-                   (SDML(), build_pairs)
-                   ]
-
-ids_estimators = ['itml',
-                  'lsml',
-                  'mmc',
-                  'sdml',
-                  ]
-
-
-@pytest.mark.parametrize('estimator, build_dataset', list_estimators,
-                         ids=ids_estimators)
-def test_cross_validation(estimator, build_dataset):
-  (tuples, y, tuples_train, tuples_test,
-   y_train, y_test) = build_dataset()
-  estimator = clone(estimator)
-  set_random_state(estimator)
-
-  assert np.isfinite(cross_val_score(estimator, tuples, y)).all()
-
-
-def check_score(estimator, tuples, y):
-  score = estimator.score(tuples, y)
-  assert np.isfinite(score)
-
-
-def check_predict(estimator, tuples):
-  y_predicted = estimator.predict(tuples)
-  assert len(y_predicted), len(tuples)
-
-
-@pytest.mark.parametrize('estimator, build_dataset', list_estimators,
-                         ids=ids_estimators)
-def test_simple_estimator(estimator, build_dataset):
-  (tuples, y, tuples_train, tuples_test,
-   y_train, y_test) = build_dataset()
-  estimator = clone(estimator)
-  set_random_state(estimator)
-
-  estimator.fit(tuples_train, y_train)
-  check_score(estimator, tuples_test, y_test)
-  check_predict(estimator, tuples_test)
-
-
-@pytest.mark.parametrize('estimator', [est[0] for est in list_estimators],
-                         ids=ids_estimators)
-def test_no_attributes_set_in_init(estimator):
-  """Check setting during init. Taken from scikit-learn."""
-  estimator = clone(estimator)
-  if hasattr(type(estimator).__init__, "deprecated_original"):
-    return
-
-  init_params = _get_args(type(estimator).__init__)
-  parents_init_params = [param for params_parent in
-                         (_get_args(parent) for parent in
-                          type(estimator).__mro__)
-                         for param in params_parent]
-
-  # Test for no setting apart from parameters during init
-  invalid_attr = (set(vars(estimator)) - set(init_params) -
-                  set(parents_init_params))
-  assert not invalid_attr, \
-      ("Estimator %s should not set any attribute apart"
-       " from parameters during init. Found attributes %s."
- % (type(estimator).__name__, sorted(invalid_attr))) - # Ensure that each parameter is set in init - invalid_attr = (set(init_params) - set(vars(estimator)) - - set(["self"])) - assert not invalid_attr, \ - ("Estimator %s should store all parameters" - " as an attribute during init. Did not find " - "attributes %s." % (type(estimator).__name__, sorted(invalid_attr))) - - -@pytest.mark.parametrize('estimator, build_dataset', list_estimators, - ids=ids_estimators) -def test_estimators_fit_returns_self(estimator, build_dataset): - """Check if self is returned when calling fit""" - # From scikit-learn - (tuples, y, tuples_train, tuples_test, - y_train, y_test) = build_dataset() - estimator = clone(estimator) - assert estimator.fit(tuples, y) is estimator - - -@pytest.mark.parametrize('estimator, build_dataset', list_estimators, - ids=ids_estimators) -def test_pipeline_consistency(estimator, build_dataset): - # From scikit learn - # check that make_pipeline(est) gives same score as est - (tuples, y, tuples_train, tuples_test, - y_train, y_test) = build_dataset() - estimator = clone(estimator) - pipeline = make_pipeline(estimator) - estimator.fit(tuples, y) - pipeline.fit(tuples, y) - - funcs = ["score", "fit_transform"] - - for func_name in funcs: - func = getattr(estimator, func_name, None) - if func is not None: - func_pipeline = getattr(pipeline, func_name) - result = func(tuples, y) - result_pipe = func_pipeline(tuples, y) - assert_allclose_dense_sparse(result, result_pipe) - - -@pytest.mark.parametrize('estimator, build_dataset', list_estimators, - ids=ids_estimators) -def test_dict_unchanged(estimator, build_dataset): - # From scikit-learn - (tuples, y, tuples_train, tuples_test, - y_train, y_test) = build_dataset() - estimator = clone(estimator) - if hasattr(estimator, "num_dims"): - estimator.num_dims = 1 - estimator.fit(tuples, y) - for method in ["predict", "decision_function", "predict_proba"]: - if hasattr(estimator, method): - dict_before = estimator.__dict__.copy() - getattr(estimator, method)(tuples) - assert estimator.__dict__ == dict_before, \ - ("Estimator changes __dict__ during %s" - % method) - for method in ["transform"]: - if hasattr(estimator, method): - dict_before = estimator.__dict__.copy() - # we transform only 2D arrays (dataset of points) - getattr(estimator, method)(tuples[:, 0, :]) - assert estimator.__dict__ == dict_before, \ - ("Estimator changes __dict__ during %s" - % method) - - -@pytest.mark.parametrize('estimator, build_dataset', list_estimators, - ids=ids_estimators) -def test_dont_overwrite_parameters(estimator, build_dataset): - # From scikit-learn - # check that fit method only changes or sets private attributes - (tuples, y, tuples_train, tuples_test, - y_train, y_test) = build_dataset() - estimator = clone(estimator) - if hasattr(estimator, "num_dims"): - estimator.num_dims = 1 - dict_before_fit = estimator.__dict__.copy() - - estimator.fit(tuples, y) - dict_after_fit = estimator.__dict__ - - public_keys_after_fit = [key for key in dict_after_fit.keys() - if is_public_parameter(key)] - - attrs_added_by_fit = [key for key in public_keys_after_fit - if key not in dict_before_fit.keys()] - - # check that fit doesn't add any public attribute - assert not attrs_added_by_fit, \ - ("Estimator adds public attribute(s) during" - " the fit method." 
- " Estimators are only allowed to add private " - "attributes" - " either started with _ or ended" - " with _ but %s added" % ', '.join(attrs_added_by_fit)) - - # check that fit doesn't change any public attribute - attrs_changed_by_fit = [key for key in public_keys_after_fit - if (dict_before_fit[key] - is not dict_after_fit[key])] - - assert not attrs_changed_by_fit, \ - ("Estimator changes public attribute(s) during" - " the fit method. Estimators are only allowed" - " to change attributes started" - " or ended with _, but" - " %s changed" % ', '.join(attrs_changed_by_fit)) - - -def _get_args(function, varargs=False): - """Helper to get function arguments""" - - try: - params = signature(function).parameters - except ValueError: - # Error on builtin C function - return [] - args = [key for key, param in params.items() - if param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)] - if varargs: - varargs = [param.name for param in params.values() - if param.kind == param.VAR_POSITIONAL] - if len(varargs) == 0: - varargs = None - return args, varargs - else: - return args