@@ -1168,34 +1168,45 @@ def get_metadata_routing(self):
1168
1168
router = MetadataRouter (owner = self .__class__ .__name__ )
1169
1169
1170
1170
# first we add all steps except the last one
1171
- for _ , name , trans in self ._iter (with_final = False , filter_passthrough = True ):
1171
+ for _ , name , trans in self ._iter (
1172
+ with_final = False , filter_passthrough = True , filter_resample = False
1173
+ ):
1172
1174
method_mapping = MethodMapping ()
1173
1175
# fit, fit_predict, and fit_transform call fit_transform if it
1174
1176
# exists, or else fit and transform
1175
1177
if hasattr (trans , "fit_transform" ):
1176
- method_mapping .add (caller = "fit" , callee = "fit_transform" )
1177
- method_mapping .add (caller = "fit_transform" , callee = "fit_transform" )
1178
- method_mapping .add (caller = "fit_predict" , callee = "fit_transform" )
1179
- method_mapping .add (caller = "fit_resample" , callee = "fit_transform" )
1178
+ (
1179
+ method_mapping .add (caller = "fit" , callee = "fit_transform" )
1180
+ .add (caller = "fit_transform" , callee = "fit_transform" )
1181
+ .add (caller = "fit_predict" , callee = "fit_transform" )
1182
+ )
1180
1183
else :
1181
- method_mapping .add (caller = "fit" , callee = "fit" )
1182
- method_mapping .add (caller = "fit" , callee = "transform" )
1183
- method_mapping .add (caller = "fit_transform" , callee = "fit" )
1184
- method_mapping .add (caller = "fit_transform" , callee = "transform" )
1185
- method_mapping .add (caller = "fit_predict" , callee = "fit" )
1186
- method_mapping .add (caller = "fit_predict" , callee = "transform" )
1187
- method_mapping .add (caller = "fit_resample" , callee = "fit" )
1188
- method_mapping .add (caller = "fit_resample" , callee = "transform" )
1189
-
1190
- method_mapping .add (caller = "predict" , callee = "transform" )
1191
- method_mapping .add (caller = "predict" , callee = "transform" )
1192
- method_mapping .add (caller = "predict_proba" , callee = "transform" )
1193
- method_mapping .add (caller = "decision_function" , callee = "transform" )
1194
- method_mapping .add (caller = "predict_log_proba" , callee = "transform" )
1195
- method_mapping .add (caller = "transform" , callee = "transform" )
1196
- method_mapping .add (caller = "inverse_transform" , callee = "inverse_transform" )
1197
- method_mapping .add (caller = "score" , callee = "transform" )
1198
- method_mapping .add (caller = "fit_resample" , callee = "transform" )
1184
+ (
1185
+ method_mapping .add (caller = "fit" , callee = "fit" )
1186
+ .add (caller = "fit" , callee = "transform" )
1187
+ .add (caller = "fit_transform" , callee = "fit" )
1188
+ .add (caller = "fit_transform" , callee = "transform" )
1189
+ .add (caller = "fit_predict" , callee = "fit" )
1190
+ .add (caller = "fit_predict" , callee = "transform" )
1191
+ )
1192
+
1193
+ (
1194
+ # handling sampler if the fit_* stage
1195
+ method_mapping .add (caller = "fit" , callee = "fit_resample" )
1196
+ .add (caller = "fit_transform" , callee = "fit_resample" )
1197
+ .add (caller = "fit_predict" , callee = "fit_resample" )
1198
+ )
1199
+ (
1200
+ method_mapping .add (caller = "predict" , callee = "transform" )
1201
+ .add (caller = "predict" , callee = "transform" )
1202
+ .add (caller = "predict_proba" , callee = "transform" )
1203
+ .add (caller = "decision_function" , callee = "transform" )
1204
+ .add (caller = "predict_log_proba" , callee = "transform" )
1205
+ .add (caller = "transform" , callee = "transform" )
1206
+ .add (caller = "inverse_transform" , callee = "inverse_transform" )
1207
+ .add (caller = "score" , callee = "transform" )
1208
+ .add (caller = "fit_resample" , callee = "transform" )
1209
+ )
1199
1210
1200
1211
router .add (method_mapping = method_mapping , ** {name : trans })
1201
1212
@@ -1207,23 +1218,24 @@ def get_metadata_routing(self):
1207
1218
method_mapping = MethodMapping ()
1208
1219
if hasattr (final_est , "fit_transform" ):
1209
1220
method_mapping .add (caller = "fit_transform" , callee = "fit_transform" )
1210
- method_mapping .add (caller = "fit_resample" , callee = "fit_transform" )
1211
1221
else :
1222
+ (
1223
+ method_mapping .add (caller = "fit" , callee = "fit" ).add (
1224
+ caller = "fit" , callee = "transform"
1225
+ )
1226
+ )
1227
+ (
1212
1228
method_mapping .add (caller = "fit" , callee = "fit" )
1213
- method_mapping .add (caller = "fit" , callee = "transform" )
1214
- method_mapping .add (caller = "fit_resample" , callee = "fit" )
1215
- method_mapping .add (caller = "fit_resample" , callee = "transform" )
1216
-
1217
- method_mapping .add (caller = "fit" , callee = "fit" )
1218
- method_mapping .add (caller = "predict" , callee = "predict" )
1219
- method_mapping .add (caller = "fit_predict" , callee = "fit_predict" )
1220
- method_mapping .add (caller = "predict_proba" , callee = "predict_proba" )
1221
- method_mapping .add (caller = "decision_function" , callee = "decision_function" )
1222
- method_mapping .add (caller = "predict_log_proba" , callee = "predict_log_proba" )
1223
- method_mapping .add (caller = "transform" , callee = "transform" )
1224
- method_mapping .add (caller = "inverse_transform" , callee = "inverse_transform" )
1225
- method_mapping .add (caller = "score" , callee = "score" )
1226
- method_mapping .add (caller = "fit_resample" , callee = "fit_resample" )
1229
+ .add (caller = "predict" , callee = "predict" )
1230
+ .add (caller = "fit_predict" , callee = "fit_predict" )
1231
+ .add (caller = "predict_proba" , callee = "predict_proba" )
1232
+ .add (caller = "decision_function" , callee = "decision_function" )
1233
+ .add (caller = "predict_log_proba" , callee = "predict_log_proba" )
1234
+ .add (caller = "transform" , callee = "transform" )
1235
+ .add (caller = "inverse_transform" , callee = "inverse_transform" )
1236
+ .add (caller = "score" , callee = "score" )
1237
+ .add (caller = "fit_resample" , callee = "fit_resample" )
1238
+ )
1227
1239
1228
1240
router .add (method_mapping = method_mapping , ** {final_name : final_est })
1229
1241
return router
0 commit comments