diff --git a/imblearn/pipeline.py b/imblearn/pipeline.py index f20e66a20..88da4e867 100644 --- a/imblearn/pipeline.py +++ b/imblearn/pipeline.py @@ -13,42 +13,16 @@ # Guillaume Lemaitre # License: BSD -from __future__ import division, print_function - -from warnings import warn +from __future__ import division from sklearn import pipeline +from sklearn.base import clone from sklearn.externals import six +from sklearn.externals.joblib import Memory from sklearn.utils import tosequence from sklearn.utils.metaestimators import if_delegate_has_method -__all__ = ['Pipeline'] - - -def _validate_step_methods(step): - - if (not (hasattr(step, "fit") or hasattr(step, "fit_transform") or hasattr( - step, "fit_sample")) or - not (hasattr(step, "transform") or hasattr(step, "sample"))): - raise TypeError( - "All intermediate steps of the chain should " - "be estimators that implement fit and transform or sample" - "(but not both) '%s' (type %s) doesn't)" % (step, type(step))) - - -def _validate_step_behaviour(step): - if (hasattr(step, "fit_sample") and hasattr(step, "fit_transform")) or ( - hasattr(step, "sample") and hasattr(step, "transform")): - raise TypeError( - "All intermediate steps of the chain should " - "be estimators that implement fit and transform or sample." - " '%s' implements both)" % (step)) - - -def _validate_step_class(step): - if isinstance(step, pipeline.Pipeline): - raise TypeError( - "All intermediate steps of the chain should not be Pipelines") +__all__ = ['Pipeline', 'make_pipeline'] class Pipeline(pipeline.Pipeline): @@ -58,6 +32,8 @@ class Pipeline(pipeline.Pipeline): Intermediate steps of the pipeline must be transformers or resamplers, that is, they must implement fit, transform and sample methods. The final estimator only needs to implement fit. + The transformers and samplers in the pipeline can be cached using + ``memory`` argument. The purpose of the pipeline is to assemble several steps that can be cross-validated together while setting different parameters. @@ -71,6 +47,17 @@ class Pipeline(pipeline.Pipeline): fit/transform/fit_sample) that are chained, in the order in which they are chained, with the last object an estimator. + memory : Instance of joblib.Memory or string, optional (default=None) + Used to cache the fitted transformers of the pipeline. By default, + no caching is performed. If a string is given, it is the path to + the caching directory. Enabling caching triggers a clone of + the transformers before fitting. Therefore, the transformer + instance given to the pipeline cannot be inspected + directly. Use the attribute ``named_steps`` or ``steps`` to + inspect estimators within the pipeline. Caching the + transformers is advantageous when fitting is time consuming. + + Attributes ---------- named_steps : dict @@ -114,71 +101,147 @@ class Pipeline(pipeline.Pipeline): # BaseEstimator interface - def __init__(self, steps): - names, estimators = zip(*steps) - if len(dict(steps)) != len(steps): - raise ValueError("Provided step names are not unique: %s" % - (names, )) - + def __init__(self, steps, memory=None): # shallow copy of steps self.steps = tosequence(steps) - transforms = estimators[:-1] + self._validate_steps() + self.memory = memory + + def _validate_steps(self): + names, estimators = zip(*self.steps) + + # validate names + self._validate_names(names) + + # validate estimators + transformers = estimators[:-1] estimator = estimators[-1] - for t in transforms: + for t in transformers: if t is None: continue - _validate_step_methods(t) - _validate_step_behaviour(t) - _validate_step_class(t) - - if not hasattr(estimator, "fit"): - raise TypeError("Last step of chain should implement fit " - "'%s' (type %s) doesn't)" % - (estimator, type(estimator))) + if (not (hasattr(t, "fit") or + hasattr(t, "fit_transform") or + hasattr(t, "fit_sample")) or + not (hasattr(t, "transform") or + hasattr(t, "sample"))): + raise TypeError( + "All intermediate steps of the chain should " + "be estimators that implement fit and transform or sample " + "(but not both) '%s' (type %s) doesn't)" % (t, type(t))) + + if ((hasattr(t, "fit_sample") and + hasattr(t, "fit_transform")) or + (hasattr(t, "sample") and + hasattr(t, "transform"))): + raise TypeError( + "All intermediate steps of the chain should " + "be estimators that implement fit and transform or sample." + " '%s' implements both)" % (t)) + + if isinstance(t, pipeline.Pipeline): + raise TypeError( + "All intermediate steps of the chain should not be" + " Pipelines") + + # We allow last estimator to be None as an identity transformation + if estimator is not None and not hasattr(estimator, "fit"): + raise TypeError("Last step of Pipeline should implement fit. " + "'%s' (type %s) doesn't" + % (estimator, type(estimator))) # Estimator interface - def _pre_transform(self, X, y=None, **fit_params): - fit_params_steps = dict((step, {}) for step, _ in self.steps) + def _fit(self, X, y=None, **fit_params): + self._validate_steps() + # Setup the memory + memory = self.memory + if memory is None: + memory = Memory(cachedir=None, verbose=0) + elif isinstance(memory, six.string_types): + memory = Memory(cachedir=memory, verbose=0) + elif not isinstance(memory, Memory): + raise ValueError("'memory' should either be a string or" + " a joblib.Memory instance, got" + " 'memory={!r}' instead.".format(memory)) + + fit_transform_one_cached = memory.cache(_fit_transform_one) + fit_sample_one_cached = memory.cache(_fit_sample_one) + + fit_params_steps = dict((name, {}) for name, step in self.steps + if step is not None) for pname, pval in six.iteritems(fit_params): step, param = pname.split('__', 1) fit_params_steps[step][param] = pval Xt = X yt = y - for name, transform in self.steps[:-1]: - if transform is None: - continue - if hasattr(transform, "fit_transform"): - Xt = transform.fit_transform(Xt, yt, **fit_params_steps[name]) - elif hasattr(transform, "fit_sample"): - Xt, yt = transform.fit_sample(Xt, yt, **fit_params_steps[name]) + for step_idx, (name, transformer) in enumerate(self.steps[:-1]): + if transformer is None: + pass else: - Xt = transform.fit(Xt, yt, **fit_params_steps[name]) \ - .transform(Xt) + if memory.cachedir is None: + # we do not clone when caching is disabled to preserve + # backward compatibility + cloned_transformer = transformer + else: + cloned_transformer = clone(transformer) + # Fit or load from cache the current transfomer + if (hasattr(cloned_transformer, "transform") or + hasattr(cloned_transformer, "fit_transform")): + Xt, fitted_transformer = fit_transform_one_cached( + cloned_transformer, None, Xt, yt, + **fit_params_steps[name]) + elif hasattr(cloned_transformer, "sample"): + Xt, yt, fitted_transformer = fit_sample_one_cached( + cloned_transformer, Xt, yt, + **fit_params_steps[name]) + # Replace the transformer of the step with the fitted + # transformer. This is necessary when loading the transformer + # from the cache. + self.steps[step_idx] = (name, fitted_transformer) + if self._final_estimator is None: + return Xt, yt, {} return Xt, yt, fit_params_steps[self.steps[-1][0]] def fit(self, X, y=None, **fit_params): - """Fit all the transforms and samples one after the other and transform - the data, then fit the transformed data using the final estimator. + """Fit the model + + Fit all the transforms/samplers one after the other and + transform/sample the data, then fit the transformed/sampled + data using the final estimator. Parameters ---------- X : iterable Training data. Must fulfill input requirements of first step of the pipeline. + y : iterable, default=None Training targets. Must fulfill label requirements for all steps of the pipeline. + + **fit_params : dict of string -> object + Parameters passed to the ``fit`` method of each step, where + each parameter name is prefixed such that parameter ``p`` for step + ``s`` has key ``s__p``. + + Returns + ------- + self : Pipeline + This estimator + """ - Xt, yt, fit_params = self._pre_transform(X, y, **fit_params) - self.steps[-1][-1].fit(Xt, yt, **fit_params) + Xt, yt, fit_params = self._fit(X, y, **fit_params) + if self._final_estimator is not None: + self._final_estimator.fit(Xt, yt, **fit_params) return self def fit_transform(self, X, y=None, **fit_params): - """Fit all the transforms and samples one after the other and - transform or sample the data, then use fit_transform on - transformed data using the final estimator. + """Fit the model and transform with the final estimator + + Fits all the transformers/samplers one after the other and + transform/sample the data, then uses fit_transform on + transformed data with the final estimator. Parameters ---------- @@ -189,18 +252,34 @@ def fit_transform(self, X, y=None, **fit_params): y : iterable, default=None Training targets. Must fulfill label requirements for all steps of the pipeline. + + **fit_params : dict of string -> object + Parameters passed to the ``fit`` method of each step, where + each parameter name is prefixed such that parameter ``p`` for step + ``s`` has key ``s__p``. + + Returns + ------- + Xt : array-like, shape = [n_samples, n_transformed_features] + Transformed samples + """ - Xt, yt, fit_params = self._pre_transform(X, y, **fit_params) - if hasattr(self.steps[-1][-1], 'fit_transform'): - return self.steps[-1][-1].fit_transform(Xt, yt, **fit_params) + last_step = self._final_estimator + Xt, yt, fit_params = self._fit(X, y, **fit_params) + if last_step is None: + return Xt + elif hasattr(last_step, 'fit_transform'): + return last_step.fit_transform(Xt, yt, **fit_params) else: - return self.steps[-1][-1].fit(Xt, yt, **fit_params).transform(Xt) + return last_step.fit(Xt, yt, **fit_params).transform(Xt) @if_delegate_has_method(delegate='_final_estimator') def fit_sample(self, X, y=None, **fit_params): - """Fit all the transforms and samples one after the other and - transform or sample the data, then use fit_sample on - transformed data using the final estimator. + """Fit the model and sample with the final estimator + + Fits all the transformers/samplers one after the other and + transform/sample the data, then uses fit_sample on transformed + data with the final estimator. Parameters ---------- @@ -211,24 +290,45 @@ def fit_sample(self, X, y=None, **fit_params): y : iterable, default=None Training targets. Must fulfill label requirements for all steps of the pipeline. + + **fit_params : dict of string -> object + Parameters passed to the ``fit`` method of each step, where + each parameter name is prefixed such that parameter ``p`` for step + ``s`` has key ``s__p``. + + Returns + ------- + Xt : array-like, shape = [n_samples, n_transformed_features] + Transformed samples + + yt : array-like, shape = [n_samples, n_transformed_features] + Transformed target + """ - Xt, yt, fit_params = self._pre_transform(X, y, **fit_params) - return self.steps[-1][-1].fit_sample(Xt, yt, **fit_params) + last_step = self._final_estimator + Xt, yt, fit_params = self._fit(X, y, **fit_params) + if last_step is None: + return Xt + elif hasattr(last_step, 'fit_sample'): + return last_step.fit_sample(Xt, yt, **fit_params) @if_delegate_has_method(delegate='_final_estimator') def sample(self, X, y): - """Applies transforms to the data, and the sample method of - the final estimator. Valid only if the final estimator - implements sample. + """Sample the data with the final estimator + + Applies transformers/samplers to the data, and the sample + method of the final estimator. Valid only if the final + estimator implements sample. Parameters ---------- X : iterable Data to predict on. Must fulfill input requirements of first step of the pipeline. + """ Xt = X - for _, transform in self.steps[:-1]: + for name, transform in self.steps[:-1]: if transform is None: continue if hasattr(transform, "fit_sample"): @@ -243,15 +343,19 @@ def sample(self, X, y): @if_delegate_has_method(delegate='_final_estimator') def predict(self, X): - """Applies transforms to the data, and the predict method of - the final estimator. Valid only if the final estimator - implements predict. + """Apply transformers/samplers to the data, and predict with the final + estimator Parameters ---------- X : iterable Data to predict on. Must fulfill input requirements of first step of the pipeline. + + Returns + ------- + y_pred : array-like + """ Xt = X for _, transform in self.steps[:-1]: @@ -265,36 +369,49 @@ def predict(self, X): @if_delegate_has_method(delegate='_final_estimator') def fit_predict(self, X, y=None, **fit_params): - """Applies fit_predict of last step in pipeline after transforms - and samples. + """Applies fit_predict of last step in pipeline after transforms. - Applies fit_transforms or fit_samples of a pipeline to the data, - followed by the fit_predict method of the final estimator in the - pipeline. Valid only if the final estimator implements fit_predict. + Applies fit_transforms of a pipeline to the data, followed by the + fit_predict method of the final estimator in the pipeline. Valid + only if the final estimator implements fit_predict. Parameters ---------- X : iterable Training data. Must fulfill input requirements of first step of the pipeline. + y : iterable, default=None Training targets. Must fulfill label requirements for all steps of the pipeline. + + **fit_params : dict of string -> object + Parameters passed to the ``fit`` method of each step, where + each parameter name is prefixed such that parameter ``p`` for step + ``s`` has key ``s__p``. + + Returns + ------- + y_pred : array-like """ - Xt, yt, fit_params = self._pre_transform(X, y, **fit_params) + Xt, yt, fit_params = self._fit(X, y, **fit_params) return self.steps[-1][-1].fit_predict(Xt, yt, **fit_params) @if_delegate_has_method(delegate='_final_estimator') def predict_proba(self, X): - """Applies transforms to the data, and the predict_proba method of the - final estimator. Valid only if the final estimator implements - predict_proba. + """Apply transformers/samplers, and predict_proba of the final + estimator Parameters ---------- X : iterable Data to predict on. Must fulfill input requirements of first step of the pipeline. + + Returns + ------- + y_proba : array-like, shape = [n_samples, n_classes] + """ Xt = X for _, transform in self.steps[:-1]: @@ -308,15 +425,19 @@ def predict_proba(self, X): @if_delegate_has_method(delegate='_final_estimator') def decision_function(self, X): - """Applies transforms to the data, and the decision_function method of - the final estimator. Valid only if the final estimator implements - decision_function. + """Apply transformers/samplers, and decision_function of the final + estimator Parameters ---------- X : iterable Data to predict on. Must fulfill input requirements of first step of the pipeline. + + Returns + ------- + y_score : array-like, shape = [n_samples, n_classes] + """ Xt = X for _, transform in self.steps[:-1]: @@ -330,15 +451,19 @@ def decision_function(self, X): @if_delegate_has_method(delegate='_final_estimator') def predict_log_proba(self, X): - """Applies transforms to the data, and the predict_log_proba method of - the final estimator. Valid only if the final estimator implements - predict_log_proba. + """Apply transformers/samplers, and predict_log_proba of the final + estimator Parameters ---------- X : iterable Data to predict on. Must fulfill input requirements of first step of the pipeline. + + Returns + ------- + y_score : array-like, shape = [n_samples, n_classes] + """ Xt = X for _, transform in self.steps[:-1]: @@ -350,20 +475,31 @@ def predict_log_proba(self, X): Xt = transform.transform(Xt) return self.steps[-1][-1].predict_log_proba(Xt) - @if_delegate_has_method(delegate='_final_estimator') - def transform(self, X): - """Applies transforms to the data, and the transform method of the - final estimator. Valid only if the final estimator implements - transform. + @property + def transform(self): + """Apply transformers/samplers, and transform with the final estimator + + This also works where final estimator is ``None``: all prior + transformations are applied. Parameters ---------- X : iterable - Data to predict on. Must fulfill input requirements of first step + Data to transform. Must fulfill input requirements of first step of the pipeline. + + Returns + ------- + Xt : array-like, shape = [n_samples, n_transformed_features] """ + # _final_estimator is None or has transform, otherwise attribute error + if self._final_estimator is not None: + self._final_estimator.transform + return self._transform + + def _transform(self, X): Xt = X - for _, transform in self.steps: + for name, transform in self.steps: if transform is None: continue if hasattr(transform, "fit_sample"): @@ -372,48 +508,62 @@ def transform(self, X): Xt = transform.transform(Xt) return Xt - @if_delegate_has_method(delegate='_final_estimator') - def inverse_transform(self, X): - """Applies inverse transform to the data. - Starts with the last step of the pipeline and applies - ``inverse_transform`` in inverse order of the pipeline steps. - Valid only if all steps of the pipeline implement inverse_transform. + @property + def inverse_transform(self): + """Apply inverse transformations in reverse order + + All estimators in the pipeline must support ``inverse_transform``. Parameters ---------- - X : iterable - Data to inverse transform. Must fulfill output requirements of the - last step of the pipeline. + Xt : array-like, shape = [n_samples, n_transformed_features] + Data samples, where ``n_samples`` is the number of samples and + ``n_features`` is the number of features. Must fulfill + input requirements of last step of pipeline's + ``inverse_transform`` method. + + Returns + ------- + Xt : array-like, shape = [n_samples, n_features] """ - if X.ndim == 1: - warn("From version 0.19, a 1d X will not be reshaped in" - " pipeline.inverse_transform any more.", FutureWarning) - X = X[None, :] + # raise AttributeError if necessary for hasattr behaviour + for name, transform in self.steps: + if transform is not None: + transform.inverse_transform + return self._inverse_transform + + def _inverse_transform(self, X): Xt = X - for _, step in self.steps[::-1]: - if step is None: + for name, transform in self.steps[::-1]: + if transform is None: continue - if hasattr(step, "fit_sample"): + if hasattr(transform, "fit_sample"): pass else: - Xt = step.inverse_transform(Xt) + Xt = transform.inverse_transform(Xt) return Xt @if_delegate_has_method(delegate='_final_estimator') - def score(self, X, y=None): - """Applies transforms to the data, and the score method of the - final estimator. Valid only if the final estimator implements - score. + def score(self, X, y=None, sample_weight=None): + """Apply transformers/samplers, and score with the final estimator Parameters ---------- X : iterable - Data to score. Must fulfill input requirements of first step of the - pipeline. + Data to predict on. Must fulfill input requirements of first step + of the pipeline. y : iterable, default=None Targets used for scoring. Must fulfill label requirements for all steps of the pipeline. + + sample_weight : array-like, default=None + If not None, this argument is passed as ``sample_weight`` keyword + argument to the ``score`` method of the final estimator. + + Returns + ------- + score : float """ Xt = X for _, transform in self.steps[:-1]: @@ -423,7 +573,28 @@ def score(self, X, y=None): pass else: Xt = transform.transform(Xt) - return self.steps[-1][-1].score(Xt, y) + score_params = {} + if sample_weight is not None: + score_params['sample_weight'] = sample_weight + return self.steps[-1][-1].score(Xt, y, **score_params) + + +def _fit_transform_one(transformer, weight, X, y, + **fit_params): + if hasattr(transformer, 'fit_transform'): + res = transformer.fit_transform(X, y, **fit_params) + else: + res = transformer.fit(X, y, **fit_params).transform(X) + # if we have a weight for this transformer, multiply output + if weight is None: + return res, transformer + return res * weight, transformer + + +def _fit_sample_one(sampler, X, y, **fit_params): + X_res, y_res = sampler.fit_sample(X, y, **fit_params) + + return X_res, y_res, sampler def make_pipeline(*steps): diff --git a/imblearn/tests/test_pipeline.py b/imblearn/tests/test_pipeline.py index 9b7385876..3b3df4874 100644 --- a/imblearn/tests/test_pipeline.py +++ b/imblearn/tests/test_pipeline.py @@ -1,20 +1,33 @@ """ Test the pipeline module. """ + +from tempfile import mkdtemp +import shutil +import time + import numpy as np -from numpy.testing import assert_array_almost_equal -from sklearn.base import clone -from sklearn.cluster import KMeans -from sklearn.datasets import load_iris, make_classification +from sklearn.utils.testing import assert_raises +from sklearn.utils.testing import assert_raises_regex +from sklearn.utils.testing import assert_raise_message +from sklearn.utils.testing import assert_equal +from sklearn.utils.testing import assert_false +from sklearn.utils.testing import assert_true +from sklearn.utils.testing import assert_array_equal +from sklearn.utils.testing import assert_array_almost_equal +from sklearn.utils.testing import assert_dict_equal +from sklearn.utils.testing import assert_allclose + +from sklearn.base import clone, BaseEstimator +from sklearn.svm import SVC from sklearn.decomposition import PCA +from sklearn.linear_model import LogisticRegression +from sklearn.linear_model import LinearRegression +from sklearn.cluster import KMeans from sklearn.feature_selection import SelectKBest, f_classif -from sklearn.linear_model import LinearRegression, LogisticRegression +from sklearn.datasets import load_iris, make_classification from sklearn.preprocessing import StandardScaler -from sklearn.svm import SVC -from sklearn.utils.testing import ( - assert_allclose, assert_array_equal, assert_equal, assert_false, - assert_raise_message, assert_raises, assert_raises_regex, assert_true, - assert_warns_message) +from sklearn.externals.joblib import Memory from imblearn.pipeline import Pipeline, make_pipeline from imblearn.under_sampling import (RandomUnderSampler, @@ -26,11 +39,13 @@ "the the pizza beer beer copyright", "the burger beer beer copyright", "the coke burger coke copyright", - "the coke burger burger", ) + "the coke burger burger", +) + R_TOL = 1e-4 -class IncorrectT(object): +class NoFit(object): """Small class to test parameter dispatching. """ @@ -39,7 +54,7 @@ def __init__(self, a=None, b=None): self.b = b -class T(IncorrectT): +class NoTrans(NoFit): def fit(self, X, y): return self @@ -52,8 +67,12 @@ def set_params(self, **params): return self -class TransfT(T): +class NoInvTransf(NoTrans): + def transform(self, X, y=None): + return X + +class Transf(NoInvTransf): def transform(self, X, y=None): return X @@ -61,7 +80,36 @@ def inverse_transform(self, X): return X -class FitParamT(object): +class TransfFitParams(Transf): + + def fit(self, X, y, **fit_params): + self.fit_params = fit_params + return self + + +class Mult(BaseEstimator): + def __init__(self, mult=1): + self.mult = mult + + def fit(self, X, y): + return self + + def transform(self, X): + return np.asarray(X) * self.mult + + def inverse_transform(self, X): + return np.asarray(X) / self.mult + + def predict(self, X): + return (np.asarray(X) * self.mult).sum(axis=1) + + predict_proba = predict_log_proba = decision_function = predict + + def score(self, X, y=None): + return np.sum(X) + + +class FitParamT(BaseEstimator): """Mock classifier """ @@ -74,9 +122,46 @@ def fit(self, X, y, should_succeed=False): def predict(self, X): return self.successful + def fit_predict(self, X, y, should_succeed=False): + self.fit(X, y, should_succeed=should_succeed) + return self.predict(X) -class FitTransformSample(T): - """Mock classifier + def score(self, X, y=None, sample_weight=None): + if sample_weight is not None: + X = X * sample_weight + return np.sum(X) + + +class DummyTransf(Transf): + """Transformer which store the column means""" + + def fit(self, X, y): + self.means_ = np.mean(X, axis=0) + # store timestamp to figure out whether the result of 'fit' has been + # cached or not + self.timestamp_ = time.time() + return self + + +class DummySampler(NoTrans): + """Samplers which returns a balanced number of samples""" + + def fit(self, X, y): + self.means_ = np.mean(X, axis=0) + # store timestamp to figure out whether the result of 'fit' has been + # cached or not + self.timestamp_ = time.time() + return self + + def sample(self, X, y): + return X, y + + def fit_sample(self, X, y): + return self.fit(X, y).sample(X, y) + + +class FitTransformSample(NoTrans): + """Estimator implementing both transform and sample """ def fit(self, X, y, should_succeed=False): @@ -94,14 +179,16 @@ def test_pipeline_init(): assert_raises(TypeError, Pipeline) # Check that we can't instantiate pipelines with objects without fit # method - pipe = assert_raises(TypeError, Pipeline, [('svc', IncorrectT)]) + assert_raises_regex(TypeError, + 'Last step of Pipeline should implement fit. ' + '.*NoFit.*', + Pipeline, [('clf', NoFit())]) # Smoke test with only an estimator - clf = T() + clf = NoTrans() pipe = Pipeline([('svc', clf)]) - assert_equal( - pipe.get_params(deep=True), - dict( - svc__a=None, svc__b=None, svc=clf, **pipe.get_params(deep=False))) + assert_equal(pipe.get_params(deep=True), + dict(svc__a=None, svc__b=None, svc=clf, + **pipe.get_params(deep=False))) # Check that params are set pipe.set_params(svc__a=0.1) @@ -115,8 +202,11 @@ def test_pipeline_init(): filter1 = SelectKBest(f_classif) pipe = Pipeline([('anova', filter1), ('svc', clf)]) - # Check that we can't use the same stage name twice - assert_raises(ValueError, Pipeline, [('svc', SVC()), ('svc', SVC())]) + # Check that we can't instantiate with non-transformers on the way + # Note that NoTrans implements fit, but not transform + assert_raises_regex(TypeError, + 'implement fit and transform or sample', + Pipeline, [('t', NoTrans()), ('svc', clf)]) # Check that params are set pipe.set_params(svc__C=0.1) @@ -167,13 +257,44 @@ def test_pipeline_methods_anova(): def test_pipeline_fit_params(): # Test that the pipeline can take fit parameters - pipe = Pipeline([('transf', TransfT()), ('clf', FitParamT())]) + pipe = Pipeline([('transf', Transf()), ('clf', FitParamT())]) pipe.fit(X=None, y=None, clf__should_succeed=True) # classifier should return True assert_true(pipe.predict(None)) # and transformer params should not be changed assert_true(pipe.named_steps['transf'].a is None) assert_true(pipe.named_steps['transf'].b is None) + # invalid parameters should raise an error message + assert_raise_message( + TypeError, + "fit() got an unexpected keyword argument 'bad'", + pipe.fit, None, None, clf__bad=True + ) + + +def test_pipeline_sample_weight_supported(): + # Pipeline should pass sample_weight + X = np.array([[1, 2]]) + pipe = Pipeline([('transf', Transf()), ('clf', FitParamT())]) + pipe.fit(X, y=None) + assert_equal(pipe.score(X), 3) + assert_equal(pipe.score(X, y=None), 3) + assert_equal(pipe.score(X, y=None, sample_weight=None), 3) + assert_equal(pipe.score(X, sample_weight=np.array([2, 3])), 8) + + +def test_pipeline_sample_weight_unsupported(): + # When sample_weight is None it shouldn't be passed + X = np.array([[1, 2]]) + pipe = Pipeline([('transf', Transf()), ('clf', Mult())]) + pipe.fit(X, y=None) + assert_equal(pipe.score(X), 3) + assert_equal(pipe.score(X, sample_weight=None), 3) + assert_raise_message( + TypeError, + "score() got an unexpected keyword argument 'sample_weight'", + pipe.score, X, sample_weight=np.array([2, 3]) + ) def test_pipeline_raise_set_params_error(): @@ -185,18 +306,16 @@ def test_pipeline_raise_set_params_error(): 'Check the list of available parameters ' 'with `estimator.get_params().keys()`.') - assert_raise_message( - ValueError, - error_msg % ('fake', 'Pipeline'), - pipe.set_params, - fake='nope') + assert_raise_message(ValueError, + error_msg % ('fake', 'Pipeline'), + pipe.set_params, + fake='nope') # nested model check - assert_raise_message( - ValueError, - error_msg % ("fake", pipe), - pipe.set_params, - fake__estimator='nope') + assert_raise_message(ValueError, + error_msg % ("fake", pipe), + pipe.set_params, + fake__estimator='nope') def test_pipeline_methods_pca_svm(): @@ -206,7 +325,7 @@ def test_pipeline_methods_pca_svm(): y = iris.target # Test with PCA + SVC clf = SVC(probability=True, random_state=0) - pca = PCA() + pca = PCA(svd_solver='full', n_components='mle', whiten=True) pipe = Pipeline([('pca', pca), ('svc', clf)]) pipe.fit(X, y) pipe.predict(X) @@ -223,7 +342,7 @@ def test_pipeline_methods_preprocessing_svm(): n_samples = X.shape[0] n_classes = len(np.unique(y)) scaler = StandardScaler() - pca = PCA(n_components=2) + pca = PCA(n_components=2, svd_solver='randomized', whiten=True) clf = SVC(probability=True, random_state=0, decision_function_shape='ovr') for preprocessing in [scaler, pca]: @@ -232,7 +351,7 @@ def test_pipeline_methods_preprocessing_svm(): # check shapes of various prediction functions predict = pipe.predict(X) - assert_equal(predict.shape, (n_samples, )) + assert_equal(predict.shape, (n_samples,)) proba = pipe.predict_proba(X) assert_equal(proba.shape, (n_samples, n_classes)) @@ -253,27 +372,47 @@ def test_fit_predict_on_pipeline(): iris = load_iris() scaler = StandardScaler() km = KMeans(random_state=0) + # As pipeline doesn't clone estimators on construction, + # it must have its own estimators + scaler_for_pipeline = StandardScaler() + km_for_pipeline = KMeans(random_state=0) # first compute the transform and clustering step separately scaled = scaler.fit_transform(iris.data) separate_pred = km.fit_predict(scaled) # use a pipeline to do the transform and clustering in one step - pipe = Pipeline([('scaler', scaler), ('Kmeans', km)]) + pipe = Pipeline([ + ('scaler', scaler_for_pipeline), + ('Kmeans', km_for_pipeline) + ]) pipeline_pred = pipe.fit_predict(iris.data) - assert_allclose(pipeline_pred, separate_pred, rtol=R_TOL) + assert_array_almost_equal(pipeline_pred, separate_pred) def test_fit_predict_on_pipeline_without_fit_predict(): # tests that a pipeline does not have fit_predict method when final # step of pipeline does not have fit_predict defined scaler = StandardScaler() - pca = PCA() + pca = PCA(svd_solver='full') pipe = Pipeline([('scaler', scaler), ('pca', pca)]) assert_raises_regex(AttributeError, - "'PCA' object has no attribute 'fit_predict'", getattr, - pipe, 'fit_predict') + "'PCA' object has no attribute 'fit_predict'", + getattr, pipe, 'fit_predict') + + +def test_fit_predict_with_intermediate_fit_params(): + # tests that Pipeline passes fit_params to intermediate steps + # when fit_predict is invoked + pipe = Pipeline([('transf', TransfFitParams()), ('clf', FitParamT())]) + pipe.fit_predict(X=None, + y=None, + transf__should_get_this=True, + clf__should_succeed=True) + assert_true(pipe.named_steps['transf'].fit_params['should_get_this']) + assert_true(pipe.named_steps['clf'].successful) + assert_false('should_succeed' in pipe.named_steps['transf'].fit_params) def test_pipeline_transform(): @@ -281,19 +420,19 @@ def test_pipeline_transform(): # Also test pipeline.transform and pipeline.inverse_transform iris = load_iris() X = iris.data - pca = PCA(n_components=2) + pca = PCA(n_components=2, svd_solver='full') pipeline = Pipeline([('pca', pca)]) # test transform and fit_transform: X_trans = pipeline.fit(X).transform(X) X_trans2 = pipeline.fit_transform(X) X_trans3 = pca.fit_transform(X) - assert_allclose(X_trans, X_trans2, rtol=R_TOL) - assert_allclose(X_trans, X_trans3, rtol=R_TOL) + assert_array_almost_equal(X_trans, X_trans2) + assert_array_almost_equal(X_trans, X_trans3) X_back = pipeline.inverse_transform(X_trans) X_back2 = pca.inverse_transform(X_trans) - assert_allclose(X_back, X_back2, rtol=R_TOL) + assert_array_almost_equal(X_back, X_back2) def test_pipeline_fit_transform(): @@ -301,28 +440,152 @@ def test_pipeline_fit_transform(): iris = load_iris() X = iris.data y = iris.target - transft = TransfT() - pipeline = Pipeline([('mock', transft)]) + transf = Transf() + pipeline = Pipeline([('mock', transf)]) # test fit_transform: X_trans = pipeline.fit_transform(X, y) - X_trans2 = transft.fit(X, y).transform(X) - assert_allclose(X_trans, X_trans2, rtol=R_TOL) + X_trans2 = transf.fit(X, y).transform(X) + assert_array_almost_equal(X_trans, X_trans2) + + +def test_set_pipeline_steps(): + transf1 = Transf() + transf2 = Transf() + pipeline = Pipeline([('mock', transf1)]) + assert_true(pipeline.named_steps['mock'] is transf1) + + # Directly setting attr + pipeline.steps = [('mock2', transf2)] + assert_true('mock' not in pipeline.named_steps) + assert_true(pipeline.named_steps['mock2'] is transf2) + assert_equal([('mock2', transf2)], pipeline.steps) + + # Using set_params + pipeline.set_params(steps=[('mock', transf1)]) + assert_equal([('mock', transf1)], pipeline.steps) + + # Using set_params to replace single step + pipeline.set_params(mock=transf2) + assert_equal([('mock', transf2)], pipeline.steps) + + # With invalid data + pipeline.set_params(steps=[('junk', ())]) + assert_raises(TypeError, pipeline.fit, [[1]], [1]) + assert_raises(TypeError, pipeline.fit_transform, [[1]], [1]) + + +def test_set_pipeline_step_none(): + # Test setting Pipeline steps to None + X = np.array([[1]]) + y = np.array([1]) + mult2 = Mult(mult=2) + mult3 = Mult(mult=3) + mult5 = Mult(mult=5) + + def make(): + return Pipeline([('m2', mult2), ('m3', mult3), ('last', mult5)]) + + pipeline = make() + + exp = 2 * 3 * 5 + assert_array_equal([[exp]], pipeline.fit_transform(X, y)) + assert_array_equal([exp], pipeline.fit(X).predict(X)) + assert_array_equal(X, pipeline.inverse_transform([[exp]])) + + pipeline.set_params(m3=None) + exp = 2 * 5 + assert_array_equal([[exp]], pipeline.fit_transform(X, y)) + assert_array_equal([exp], pipeline.fit(X).predict(X)) + assert_array_equal(X, pipeline.inverse_transform([[exp]])) + assert_dict_equal(pipeline.get_params(deep=True), + {'steps': pipeline.steps, + 'm2': mult2, + 'm3': None, + 'last': mult5, + 'memory': None, + 'm2__mult': 2, + 'last__mult': 5, + }) + + pipeline.set_params(m2=None) + exp = 5 + assert_array_equal([[exp]], pipeline.fit_transform(X, y)) + assert_array_equal([exp], pipeline.fit(X).predict(X)) + assert_array_equal(X, pipeline.inverse_transform([[exp]])) + + # for other methods, ensure no AttributeErrors on None: + other_methods = ['predict_proba', 'predict_log_proba', + 'decision_function', 'transform', 'score'] + for method in other_methods: + getattr(pipeline, method)(X) + + pipeline.set_params(m2=mult2) + exp = 2 * 5 + assert_array_equal([[exp]], pipeline.fit_transform(X, y)) + assert_array_equal([exp], pipeline.fit(X).predict(X)) + assert_array_equal(X, pipeline.inverse_transform([[exp]])) + + pipeline = make() + pipeline.set_params(last=None) + # mult2 and mult3 are active + exp = 6 + pipeline.fit(X, y) + pipeline.transform(X) + assert_array_equal([[exp]], pipeline.fit(X, y).transform(X)) + assert_array_equal([[exp]], pipeline.fit_transform(X, y)) + assert_array_equal(X, pipeline.inverse_transform([[exp]])) + assert_raise_message(AttributeError, + "'NoneType' object has no attribute 'predict'", + getattr, pipeline, 'predict') + + # Check None step at construction time + exp = 2 * 5 + pipeline = Pipeline([('m2', mult2), ('m3', None), ('last', mult5)]) + assert_array_equal([[exp]], pipeline.fit_transform(X, y)) + assert_array_equal([exp], pipeline.fit(X).predict(X)) + assert_array_equal(X, pipeline.inverse_transform([[exp]])) + + +def test_pipeline_ducktyping(): + pipeline = make_pipeline(Mult(5)) + pipeline.predict + pipeline.transform + pipeline.inverse_transform + + pipeline = make_pipeline(Transf()) + assert_false(hasattr(pipeline, 'predict')) + pipeline.transform + pipeline.inverse_transform + + pipeline = make_pipeline(None) + assert_false(hasattr(pipeline, 'predict')) + pipeline.transform + pipeline.inverse_transform + + pipeline = make_pipeline(Transf(), NoInvTransf()) + assert_false(hasattr(pipeline, 'predict')) + pipeline.transform + assert_false(hasattr(pipeline, 'inverse_transform')) + + pipeline = make_pipeline(NoInvTransf(), Transf()) + assert_false(hasattr(pipeline, 'predict')) + pipeline.transform + assert_false(hasattr(pipeline, 'inverse_transform')) def test_make_pipeline(): - t1 = TransfT() - t2 = TransfT() - + t1 = Transf() + t2 = Transf() pipe = make_pipeline(t1, t2) assert_true(isinstance(pipe, Pipeline)) - assert_equal(pipe.steps[0][0], "transft-1") - assert_equal(pipe.steps[1][0], "transft-2") + assert_equal(pipe.steps[0][0], "transf-1") + assert_equal(pipe.steps[1][0], "transf-2") pipe = make_pipeline(t1, t2, FitParamT()) assert_true(isinstance(pipe, Pipeline)) - assert_equal(pipe.steps[0][0], "transft-1") - assert_equal(pipe.steps[1][0], "transft-2") + assert_equal(pipe.steps[0][0], "transf-1") + assert_equal(pipe.steps[1][0], "transf-2") assert_equal(pipe.steps[2][0], "fitparamt") @@ -341,12 +604,151 @@ def test_classes_property(): assert_array_equal(clf.classes_, np.unique(y)) -def test_X1d_inverse_transform(): - transformer = TransfT() - pipeline = make_pipeline(transformer) - X = np.ones(10) - msg = "1d X will not be reshaped in pipeline.inverse_transform" - assert_warns_message(FutureWarning, msg, pipeline.inverse_transform, X) +def test_pipeline_wrong_memory(): + # Test that an error is raised when memory is not a string or a Memory + # instance + iris = load_iris() + X = iris.data + y = iris.target + # Define memory as an integer + memory = 1 + cached_pipe = Pipeline([('transf', DummyTransf()), ('svc', SVC())], + memory=memory) + assert_raises_regex(ValueError, "'memory' should either be a string or a" + " joblib.Memory instance, got 'memory=1' instead.", + cached_pipe.fit, X, y) + + +def test_pipeline_memory_transformer(): + iris = load_iris() + X = iris.data + y = iris.target + cachedir = mkdtemp() + try: + memory = Memory(cachedir=cachedir, verbose=10) + # Test with Transformer + SVC + clf = SVC(probability=True, random_state=0) + transf = DummyTransf() + pipe = Pipeline([('transf', clone(transf)), ('svc', clf)]) + cached_pipe = Pipeline([('transf', transf), ('svc', clf)], + memory=memory) + + # Memoize the transformer at the first fit + cached_pipe.fit(X, y) + pipe.fit(X, y) + # Get the time stamp of the tranformer in the cached pipeline + ts = cached_pipe.named_steps['transf'].timestamp_ + # Check that cached_pipe and pipe yield identical results + assert_array_equal(pipe.predict(X), cached_pipe.predict(X)) + assert_array_equal(pipe.predict_proba(X), cached_pipe.predict_proba(X)) + assert_array_equal(pipe.predict_log_proba(X), + cached_pipe.predict_log_proba(X)) + assert_array_equal(pipe.score(X, y), cached_pipe.score(X, y)) + assert_array_equal(pipe.named_steps['transf'].means_, + cached_pipe.named_steps['transf'].means_) + assert_false(hasattr(transf, 'means_')) + # Check that we are reading the cache while fitting + # a second time + cached_pipe.fit(X, y) + # Check that cached_pipe and pipe yield identical results + assert_array_equal(pipe.predict(X), cached_pipe.predict(X)) + assert_array_equal(pipe.predict_proba(X), cached_pipe.predict_proba(X)) + assert_array_equal(pipe.predict_log_proba(X), + cached_pipe.predict_log_proba(X)) + assert_array_equal(pipe.score(X, y), cached_pipe.score(X, y)) + assert_array_equal(pipe.named_steps['transf'].means_, + cached_pipe.named_steps['transf'].means_) + assert_equal(ts, cached_pipe.named_steps['transf'].timestamp_) + # Create a new pipeline with cloned estimators + # Check that even changing the name step does not affect the cache hit + clf_2 = SVC(probability=True, random_state=0) + transf_2 = DummyTransf() + cached_pipe_2 = Pipeline([('transf_2', transf_2), ('svc', clf_2)], + memory=memory) + cached_pipe_2.fit(X, y) + + # Check that cached_pipe and pipe yield identical results + assert_array_equal(pipe.predict(X), cached_pipe_2.predict(X)) + assert_array_equal(pipe.predict_proba(X), + cached_pipe_2.predict_proba(X)) + assert_array_equal(pipe.predict_log_proba(X), + cached_pipe_2.predict_log_proba(X)) + assert_array_equal(pipe.score(X, y), cached_pipe_2.score(X, y)) + assert_array_equal(pipe.named_steps['transf'].means_, + cached_pipe_2.named_steps['transf_2'].means_) + assert_equal(ts, cached_pipe_2.named_steps['transf_2'].timestamp_) + finally: + shutil.rmtree(cachedir) + + +def test_pipeline_memory_sampler(): + X, y = make_classification( + n_classes=2, + class_sep=2, + weights=[0.1, 0.9], + n_informative=3, + n_redundant=1, + flip_y=0, + n_features=20, + n_clusters_per_class=1, + n_samples=5000, + random_state=0) + cachedir = mkdtemp() + try: + memory = Memory(cachedir=cachedir, verbose=10) + # Test with Transformer + SVC + clf = SVC(probability=True, random_state=0) + transf = DummySampler() + pipe = Pipeline([('transf', clone(transf)), ('svc', clf)]) + cached_pipe = Pipeline([('transf', transf), ('svc', clf)], + memory=memory) + + # Memoize the transformer at the first fit + cached_pipe.fit(X, y) + pipe.fit(X, y) + # Get the time stamp of the tranformer in the cached pipeline + ts = cached_pipe.named_steps['transf'].timestamp_ + # Check that cached_pipe and pipe yield identical results + assert_array_equal(pipe.predict(X), cached_pipe.predict(X)) + assert_array_equal(pipe.predict_proba(X), cached_pipe.predict_proba(X)) + assert_array_equal(pipe.predict_log_proba(X), + cached_pipe.predict_log_proba(X)) + assert_array_equal(pipe.score(X, y), cached_pipe.score(X, y)) + assert_array_equal(pipe.named_steps['transf'].means_, + cached_pipe.named_steps['transf'].means_) + assert_false(hasattr(transf, 'means_')) + # Check that we are reading the cache while fitting + # a second time + cached_pipe.fit(X, y) + # Check that cached_pipe and pipe yield identical results + assert_array_equal(pipe.predict(X), cached_pipe.predict(X)) + assert_array_equal(pipe.predict_proba(X), cached_pipe.predict_proba(X)) + assert_array_equal(pipe.predict_log_proba(X), + cached_pipe.predict_log_proba(X)) + assert_array_equal(pipe.score(X, y), cached_pipe.score(X, y)) + assert_array_equal(pipe.named_steps['transf'].means_, + cached_pipe.named_steps['transf'].means_) + assert_equal(ts, cached_pipe.named_steps['transf'].timestamp_) + # Create a new pipeline with cloned estimators + # Check that even changing the name step does not affect the cache hit + clf_2 = SVC(probability=True, random_state=0) + transf_2 = DummySampler() + cached_pipe_2 = Pipeline([('transf_2', transf_2), ('svc', clf_2)], + memory=memory) + cached_pipe_2.fit(X, y) + + # Check that cached_pipe and pipe yield identical results + assert_array_equal(pipe.predict(X), cached_pipe_2.predict(X)) + assert_array_equal(pipe.predict_proba(X), + cached_pipe_2.predict_proba(X)) + assert_array_equal(pipe.predict_log_proba(X), + cached_pipe_2.predict_log_proba(X)) + assert_array_equal(pipe.score(X, y), cached_pipe_2.score(X, y)) + assert_array_equal(pipe.named_steps['transf'].means_, + cached_pipe_2.named_steps['transf_2'].means_) + assert_equal(ts, cached_pipe_2.named_steps['transf_2'].timestamp_) + finally: + shutil.rmtree(cachedir) def test_pipeline_methods_pca_rus_svm():