12
12
# Christos Aridas
13
13
# Guillaume Lemaitre <g.lemaitre58@gmail.com>
14
14
# License: BSD
15
- from itertools import filterfalse
16
-
17
15
from sklearn import pipeline
18
16
from sklearn .base import clone
19
- from sklearn .utils import Bunch , _print_elapsed_time
17
+ from sklearn .utils import _print_elapsed_time
20
18
from sklearn .utils .metaestimators import if_delegate_has_method
21
19
from sklearn .utils .validation import check_memory
22
20
@@ -170,13 +168,13 @@ def _validate_steps(self):
170
168
)
171
169
172
170
def _iter (
173
- self , with_final = True , filter_passthrough = True , with_resample = False
171
+ self , with_final = True , filter_passthrough = True , filter_resample = True
174
172
):
175
173
it = super ()._iter (with_final , filter_passthrough )
176
- if with_resample :
177
- return it
174
+ if filter_resample :
175
+ return filter ( lambda x : not hasattr ( x [ - 1 ], "fit_resample" ), it )
178
176
else :
179
- return filterfalse ( lambda x : hasattr ( x [ - 1 ], "fit_resample" ), it )
177
+ return it
180
178
181
179
# Estimator interface
182
180
@@ -206,7 +204,7 @@ def _fit(self, X, y=None, **fit_params):
206
204
name ,
207
205
transformer ) in self ._iter (with_final = False ,
208
206
filter_passthrough = False ,
209
- with_resample = True ):
207
+ filter_resample = False ):
210
208
if (transformer is None or transformer == 'passthrough' ):
211
209
with _print_elapsed_time ('Pipeline' ,
212
210
self ._log_message (step_idx )):
@@ -220,7 +218,7 @@ def _fit(self, X, y=None, **fit_params):
220
218
else :
221
219
cloned_transformer = clone (transformer )
222
220
elif hasattr (memory , "cachedir" ):
223
- # joblib < 0.11
221
+ # joblib <= 0.11
224
222
if memory .cachedir is None :
225
223
# we do not clone when caching is disabled to
226
224
# preserve backward compatibility
0 commit comments