Skip to content

Commit feb5a41

Browse files
authored
FIX/TST: Synchronize Pipeline with scikit-learn (#514)
1 parent ca7d301 commit feb5a41

File tree

2 files changed

+169
-90
lines changed

2 files changed

+169
-90
lines changed

imblearn/pipeline.py

Lines changed: 82 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515

1616
from __future__ import division
1717

18+
from collections import defaultdict
19+
from itertools import islice
20+
1821
from sklearn import pipeline
1922
from sklearn.base import clone
2023
from sklearn.utils.metaestimators import if_delegate_has_method
@@ -38,6 +41,9 @@ class Pipeline(pipeline.Pipeline):
3841
cross-validated together while setting different parameters.
3942
For this, it enables setting parameters of the various steps using their
4043
names and the parameter name separated by a '__', as in the example below.
44+
A step's estimator may be replaced entirely by setting the parameter
45+
with its name to another estimator, or a transformer removed by setting
46+
it to 'passthrough' or ``None``.
4147
4248
Parameters
4349
----------
@@ -121,7 +127,7 @@ def _validate_steps(self):
121127
estimator = estimators[-1]
122128

123129
for t in transformers:
124-
if t is None:
130+
if t is None or t == 'passthrough':
125131
continue
126132
if (not (hasattr(t, "fit") or
127133
hasattr(t, "fit_transform") or
@@ -130,8 +136,9 @@ def _validate_steps(self):
130136
hasattr(t, "fit_resample"))):
131137
raise TypeError(
132138
"All intermediate steps of the chain should "
133-
"be estimators that implement fit and transform or sample "
134-
"(but not both) '%s' (type %s) doesn't)" % (t, type(t)))
139+
"be estimators that implement fit and transform or "
140+
"fit_resample (but not both) or be a string 'passthrough' "
141+
"'%s' (type %s) doesn't)" % (t, type(t)))
135142

136143
if (hasattr(t, "fit_resample") and (hasattr(t, "fit_transform") or
137144
hasattr(t, "transform"))):
@@ -146,14 +153,16 @@ def _validate_steps(self):
146153
" Pipelines")
147154

148155
# We allow last estimator to be None as an identity transformation
149-
if estimator is not None and not hasattr(estimator, "fit"):
150-
raise TypeError("Last step of Pipeline should implement fit. "
151-
"'%s' (type %s) doesn't" % (estimator,
152-
type(estimator)))
156+
if (estimator is not None and estimator != 'passthrough'
157+
and not hasattr(estimator, "fit")):
158+
raise TypeError("Last step of Pipeline should implement fit or be "
159+
"the string 'passthrough'. '%s' (type %s) doesn't"
160+
% (estimator, type(estimator)))
153161

154162
# Estimator interface
155163

156164
def _fit(self, X, y=None, **fit_params):
165+
self.steps = list(self.steps)
157166
self._validate_steps()
158167
# Setup the memory
159168
memory = check_memory(self.memory)
@@ -166,44 +175,39 @@ def _fit(self, X, y=None, **fit_params):
166175
for pname, pval in fit_params.items():
167176
step, param = pname.split('__', 1)
168177
fit_params_steps[step][param] = pval
169-
Xt = X
170-
yt = y
171-
for step_idx, (name, transformer) in enumerate(self.steps[:-1]):
172-
if transformer is None:
173-
pass
174-
else:
175-
if hasattr(memory, 'location'):
176-
# joblib >= 0.12
177-
if memory.location is None:
178-
# we do not clone when caching is disabled to
179-
# preserve backward compatibility
180-
cloned_transformer = transformer
181-
else:
182-
cloned_transformer = clone(transformer)
183-
elif hasattr(memory, 'cachedir'):
184-
# joblib < 0.11
185-
if memory.cachedir is None:
186-
# we do not clone when caching is disabled to
187-
# preserve backward compatibility
188-
cloned_transformer = transformer
178+
for step_idx, name, transformer in self._iter(with_final=False):
179+
if hasattr(memory, 'location'):
180+
# joblib >= 0.12
181+
if memory.location is None:
182+
# we do not clone when caching is disabled to
183+
# preserve backward compatibility
184+
cloned_transformer = transformer
189185
else:
190186
cloned_transformer = clone(transformer)
191-
# Fit or load from cache the current transformer
192-
if (hasattr(cloned_transformer, "transform") or
193-
hasattr(cloned_transformer, "fit_transform")):
194-
Xt, fitted_transformer = fit_transform_one_cached(
195-
cloned_transformer, None, Xt, yt,
196-
**fit_params_steps[name])
197-
elif hasattr(cloned_transformer, "fit_resample"):
198-
Xt, yt, fitted_transformer = fit_resample_one_cached(
199-
cloned_transformer, Xt, yt, **fit_params_steps[name])
200-
# Replace the transformer of the step with the fitted
201-
# transformer. This is necessary when loading the transformer
202-
# from the cache.
203-
self.steps[step_idx] = (name, fitted_transformer)
204-
if self._final_estimator is None:
205-
return Xt, yt, {}
206-
return Xt, yt, fit_params_steps[self.steps[-1][0]]
187+
elif hasattr(memory, 'cachedir'):
188+
# joblib < 0.11
189+
if memory.cachedir is None:
190+
# we do not clone when caching is disabled to
191+
# preserve backward compatibility
192+
cloned_transformer = transformer
193+
else:
194+
cloned_transformer = clone(transformer)
195+
# Fit or load from cache the current transformer
196+
if (hasattr(cloned_transformer, "transform") or
197+
hasattr(cloned_transformer, "fit_transform")):
198+
X, fitted_transformer = fit_transform_one_cached(
199+
cloned_transformer, None, X, y,
200+
**fit_params_steps[name])
201+
elif hasattr(cloned_transformer, "fit_resample"):
202+
X, y, fitted_transformer = fit_resample_one_cached(
203+
cloned_transformer, X, y, **fit_params_steps[name])
204+
# Replace the transformer of the step with the fitted
205+
# transformer. This is necessary when loading the transformer
206+
# from the cache.
207+
self.steps[step_idx] = (name, fitted_transformer)
208+
if self._final_estimator == 'passthrough':
209+
return X, y, {}
210+
return X, y, fit_params_steps[self.steps[-1][0]]
207211

208212
def fit(self, X, y=None, **fit_params):
209213
"""Fit the model
@@ -234,7 +238,7 @@ def fit(self, X, y=None, **fit_params):
234238
235239
"""
236240
Xt, yt, fit_params = self._fit(X, y, **fit_params)
237-
if self._final_estimator is not None:
241+
if self._final_estimator != 'passthrough':
238242
self._final_estimator.fit(Xt, yt, **fit_params)
239243
return self
240244

@@ -268,7 +272,7 @@ def fit_transform(self, X, y=None, **fit_params):
268272
"""
269273
last_step = self._final_estimator
270274
Xt, yt, fit_params = self._fit(X, y, **fit_params)
271-
if last_step is None:
275+
if last_step == 'passthrough':
272276
return Xt
273277
elif hasattr(last_step, 'fit_transform'):
274278
return last_step.fit_transform(Xt, yt, **fit_params)
@@ -308,7 +312,7 @@ def fit_resample(self, X, y=None, **fit_params):
308312
"""
309313
last_step = self._final_estimator
310314
Xt, yt, fit_params = self._fit(X, y, **fit_params)
311-
if last_step is None:
315+
if last_step == 'passthrough':
312316
return Xt
313317
elif hasattr(last_step, 'fit_resample'):
314318
return last_step.fit_resample(Xt, yt, **fit_params)
@@ -338,9 +342,7 @@ def predict(self, X, **predict_params):
338342
339343
"""
340344
Xt = X
341-
for _, transform in self.steps[:-1]:
342-
if transform is None:
343-
continue
345+
for _, _, transform in self._iter(with_final=False):
344346
if hasattr(transform, "fit_resample"):
345347
pass
346348
else:
@@ -394,15 +396,33 @@ def predict_proba(self, X):
394396
395397
"""
396398
Xt = X
397-
for _, transform in self.steps[:-1]:
398-
if transform is None:
399-
continue
399+
for _, _, transform in self._iter(with_final=False):
400400
if hasattr(transform, "fit_resample"):
401401
pass
402402
else:
403403
Xt = transform.transform(Xt)
404404
return self.steps[-1][-1].predict_proba(Xt)
405405

406+
@if_delegate_has_method(delegate='_final_estimator')
407+
def score_samples(self, X):
408+
"""Apply transforms, and score_samples of the final estimator.
409+
Parameters
410+
----------
411+
X : iterable
412+
Data to predict on. Must fulfill input requirements of first step
413+
of the pipeline.
414+
Returns
415+
-------
416+
y_score : ndarray, shape (n_samples,)
417+
"""
418+
Xt = X
419+
for _, _, transformer in self._iter(with_final=False):
420+
if hasattr(transformer, "fit_resample"):
421+
pass
422+
else:
423+
Xt = transformer.transform(Xt)
424+
return self.steps[-1][-1].score_samples(Xt)
425+
406426
@if_delegate_has_method(delegate='_final_estimator')
407427
def decision_function(self, X):
408428
"""Apply transformers/samplers, and decision_function of the final
@@ -420,9 +440,7 @@ def decision_function(self, X):
420440
421441
"""
422442
Xt = X
423-
for _, transform in self.steps[:-1]:
424-
if transform is None:
425-
continue
443+
for _, _, transform in self._iter(with_final=False):
426444
if hasattr(transform, "fit_resample"):
427445
pass
428446
else:
@@ -446,9 +464,7 @@ def predict_log_proba(self, X):
446464
447465
"""
448466
Xt = X
449-
for _, transform in self.steps[:-1]:
450-
if transform is None:
451-
continue
467+
for _, _, transform in self._iter(with_final=False):
452468
if hasattr(transform, "fit_resample"):
453469
pass
454470
else:
@@ -473,15 +489,13 @@ def transform(self):
473489
Xt : array-like, shape = [n_samples, n_transformed_features]
474490
"""
475491
# _final_estimator is None or has transform, otherwise attribute error
476-
if self._final_estimator is not None:
492+
if self._final_estimator != 'passthrough':
477493
self._final_estimator.transform
478494
return self._transform
479495

480496
def _transform(self, X):
481497
Xt = X
482-
for name, transform in self.steps:
483-
if transform is None:
484-
continue
498+
for _, _, transform in self._iter():
485499
if hasattr(transform, "fit_resample"):
486500
pass
487501
else:
@@ -507,29 +521,20 @@ def inverse_transform(self):
507521
Xt : array-like, shape = [n_samples, n_features]
508522
"""
509523
# raise AttributeError if necessary for hasattr behaviour
510-
for name, transform in self.steps:
511-
if transform is not None:
512-
transform.inverse_transform
524+
for _, _, transform in self._iter():
525+
transform.inverse_transform
513526
return self._inverse_transform
514527

515528
def _inverse_transform(self, X):
516529
Xt = X
517-
for name, transform in self.steps[::-1]:
518-
if transform is None:
519-
continue
530+
reverse_iter = reversed(list(self._iter()))
531+
for _, _, transform in reverse_iter:
520532
if hasattr(transform, "fit_resample"):
521533
pass
522534
else:
523535
Xt = transform.inverse_transform(Xt)
524536
return Xt
525537

526-
# need to overwrite sklearn's _final_estimator since sklearn supports
527-
# 'passthrough', but imblearn does not.
528-
@property
529-
def _final_estimator(self):
530-
estimator = self.steps[-1][1]
531-
return estimator
532-
533538
@if_delegate_has_method(delegate='_final_estimator')
534539
def score(self, X, y=None, sample_weight=None):
535540
"""Apply transformers/samplers, and score with the final estimator
@@ -553,9 +558,7 @@ def score(self, X, y=None, sample_weight=None):
553558
score : float
554559
"""
555560
Xt = X
556-
for _, transform in self.steps[:-1]:
557-
if transform is None:
558-
continue
561+
for _, _, transform in self._iter(with_final=False):
559562
if hasattr(transform, "fit_resample"):
560563
pass
561564
else:
@@ -618,7 +621,7 @@ def make_pipeline(*steps, **kwargs):
618621
>>> from sklearn.naive_bayes import GaussianNB
619622
>>> from sklearn.preprocessing import StandardScaler
620623
>>> make_pipeline(StandardScaler(), GaussianNB(priors=None))
621-
... # doctest: +NORMALIZE_WHITESPACE
624+
... # doctest: +NORMALIZE_WHITESPACE
622625
Pipeline(memory=None,
623626
steps=[('standardscaler',
624627
StandardScaler(copy=True, with_mean=True, with_std=True)),

0 commit comments

Comments
 (0)