Skip to content

Commit 88f057d

Browse files
authored
FIX forward properly the metadata with the pipeline (#1115)
* FIX forward properly the metadata with the pipeline * older sklearn
1 parent 2d65471 commit 88f057d

File tree

3 files changed

+98
-48
lines changed

3 files changed

+98
-48
lines changed

imblearn/base.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,14 @@
2222
METHODS.append("fit_transform")
2323
METHODS.append("fit_resample")
2424

25+
try:
26+
from sklearn.utils._metadata_requests import SIMPLE_METHODS
27+
28+
SIMPLE_METHODS.append("fit_resample")
29+
except ImportError:
30+
# in older versions of scikit-learn, only METHODS is used
31+
pass
32+
2533

2634
class SamplerMixin(metaclass=ABCMeta):
2735
"""Mixin class for samplers with abstract method.
@@ -33,7 +41,7 @@ class SamplerMixin(metaclass=ABCMeta):
3341
_estimator_type = "sampler"
3442

3543
@_fit_context(prefer_skip_nested_validation=True)
36-
def fit(self, X, y):
44+
def fit(self, X, y, **params):
3745
"""Check inputs and statistics of the sampler.
3846
3947
You should use ``fit_resample`` in all cases.
@@ -47,6 +55,9 @@ def fit(self, X, y):
4755
y : array-like of shape (n_samples,)
4856
Target array.
4957
58+
**params : dict
59+
Extra parameters to use by the sampler.
60+
5061
Returns
5162
-------
5263
self : object
@@ -58,7 +69,8 @@ def fit(self, X, y):
5869
)
5970
return self
6071

61-
def fit_resample(self, X, y):
72+
@_fit_context(prefer_skip_nested_validation=True)
73+
def fit_resample(self, X, y, **params):
6274
"""Resample the dataset.
6375
6476
Parameters
@@ -70,6 +82,9 @@ def fit_resample(self, X, y):
7082
y : array-like of shape (n_samples,)
7183
Corresponding label for each sample in X.
7284
85+
**params : dict
86+
Extra parameters to use by the sampler.
87+
7388
Returns
7489
-------
7590
X_resampled : {array-like, dataframe, sparse matrix} of shape \
@@ -87,7 +102,7 @@ def fit_resample(self, X, y):
87102
self.sampling_strategy, y, self._sampling_type
88103
)
89104

90-
output = self._fit_resample(X, y)
105+
output = self._fit_resample(X, y, **params)
91106

92107
y_ = (
93108
label_binarize(output[1], classes=np.unique(y)) if binarize_y else output[1]
@@ -97,7 +112,7 @@ def fit_resample(self, X, y):
97112
return (X_, y_) if len(output) == 2 else (X_, y_, output[2])
98113

99114
@abstractmethod
100-
def _fit_resample(self, X, y):
115+
def _fit_resample(self, X, y, **params):
101116
"""Base method defined in each sampler to defined the sampling
102117
strategy.
103118
@@ -109,6 +124,9 @@ def _fit_resample(self, X, y):
109124
y : array-like of shape (n_samples,)
110125
Corresponding label for each sample in X.
111126
127+
**params : dict
128+
Extra parameters to use by the sampler.
129+
112130
Returns
113131
-------
114132
X_resampled : {ndarray, sparse matrix} of shape \
@@ -139,7 +157,7 @@ def _check_X_y(self, X, y, accept_sparse=None):
139157
X, y = validate_data(self, X=X, y=y, reset=True, accept_sparse=accept_sparse)
140158
return X, y, binarize_y
141159

142-
def fit(self, X, y):
160+
def fit(self, X, y, **params):
143161
"""Check inputs and statistics of the sampler.
144162
145163
You should use ``fit_resample`` in all cases.
@@ -158,10 +176,9 @@ def fit(self, X, y):
158176
self : object
159177
Return the instance itself.
160178
"""
161-
self._validate_params()
162-
return super().fit(X, y)
179+
return super().fit(X, y, **params)
163180

164-
def fit_resample(self, X, y):
181+
def fit_resample(self, X, y, **params):
165182
"""Resample the dataset.
166183
167184
Parameters
@@ -182,8 +199,7 @@ def fit_resample(self, X, y):
182199
y_resampled : array-like of shape (n_samples_new,)
183200
The corresponding label of `X_resampled`.
184201
"""
185-
self._validate_params()
186-
return super().fit_resample(X, y)
202+
return super().fit_resample(X, y, **params)
187203

188204
def _more_tags(self):
189205
return {"X_types": ["2darray", "sparse", "dataframe"]}

imblearn/pipeline.py

Lines changed: 50 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1168,34 +1168,45 @@ def get_metadata_routing(self):
11681168
router = MetadataRouter(owner=self.__class__.__name__)
11691169

11701170
# first we add all steps except the last one
1171-
for _, name, trans in self._iter(with_final=False, filter_passthrough=True):
1171+
for _, name, trans in self._iter(
1172+
with_final=False, filter_passthrough=True, filter_resample=False
1173+
):
11721174
method_mapping = MethodMapping()
11731175
# fit, fit_predict, and fit_transform call fit_transform if it
11741176
# exists, or else fit and transform
11751177
if hasattr(trans, "fit_transform"):
1176-
method_mapping.add(caller="fit", callee="fit_transform")
1177-
method_mapping.add(caller="fit_transform", callee="fit_transform")
1178-
method_mapping.add(caller="fit_predict", callee="fit_transform")
1179-
method_mapping.add(caller="fit_resample", callee="fit_transform")
1178+
(
1179+
method_mapping.add(caller="fit", callee="fit_transform")
1180+
.add(caller="fit_transform", callee="fit_transform")
1181+
.add(caller="fit_predict", callee="fit_transform")
1182+
)
11801183
else:
1181-
method_mapping.add(caller="fit", callee="fit")
1182-
method_mapping.add(caller="fit", callee="transform")
1183-
method_mapping.add(caller="fit_transform", callee="fit")
1184-
method_mapping.add(caller="fit_transform", callee="transform")
1185-
method_mapping.add(caller="fit_predict", callee="fit")
1186-
method_mapping.add(caller="fit_predict", callee="transform")
1187-
method_mapping.add(caller="fit_resample", callee="fit")
1188-
method_mapping.add(caller="fit_resample", callee="transform")
1189-
1190-
method_mapping.add(caller="predict", callee="transform")
1191-
method_mapping.add(caller="predict", callee="transform")
1192-
method_mapping.add(caller="predict_proba", callee="transform")
1193-
method_mapping.add(caller="decision_function", callee="transform")
1194-
method_mapping.add(caller="predict_log_proba", callee="transform")
1195-
method_mapping.add(caller="transform", callee="transform")
1196-
method_mapping.add(caller="inverse_transform", callee="inverse_transform")
1197-
method_mapping.add(caller="score", callee="transform")
1198-
method_mapping.add(caller="fit_resample", callee="transform")
1184+
(
1185+
method_mapping.add(caller="fit", callee="fit")
1186+
.add(caller="fit", callee="transform")
1187+
.add(caller="fit_transform", callee="fit")
1188+
.add(caller="fit_transform", callee="transform")
1189+
.add(caller="fit_predict", callee="fit")
1190+
.add(caller="fit_predict", callee="transform")
1191+
)
1192+
1193+
(
1194+
# handling sampler if the fit_* stage
1195+
method_mapping.add(caller="fit", callee="fit_resample")
1196+
.add(caller="fit_transform", callee="fit_resample")
1197+
.add(caller="fit_predict", callee="fit_resample")
1198+
)
1199+
(
1200+
method_mapping.add(caller="predict", callee="transform")
1201+
.add(caller="predict", callee="transform")
1202+
.add(caller="predict_proba", callee="transform")
1203+
.add(caller="decision_function", callee="transform")
1204+
.add(caller="predict_log_proba", callee="transform")
1205+
.add(caller="transform", callee="transform")
1206+
.add(caller="inverse_transform", callee="inverse_transform")
1207+
.add(caller="score", callee="transform")
1208+
.add(caller="fit_resample", callee="transform")
1209+
)
11991210

12001211
router.add(method_mapping=method_mapping, **{name: trans})
12011212

@@ -1207,23 +1218,24 @@ def get_metadata_routing(self):
12071218
method_mapping = MethodMapping()
12081219
if hasattr(final_est, "fit_transform"):
12091220
method_mapping.add(caller="fit_transform", callee="fit_transform")
1210-
method_mapping.add(caller="fit_resample", callee="fit_transform")
12111221
else:
1222+
(
1223+
method_mapping.add(caller="fit", callee="fit").add(
1224+
caller="fit", callee="transform"
1225+
)
1226+
)
1227+
(
12121228
method_mapping.add(caller="fit", callee="fit")
1213-
method_mapping.add(caller="fit", callee="transform")
1214-
method_mapping.add(caller="fit_resample", callee="fit")
1215-
method_mapping.add(caller="fit_resample", callee="transform")
1216-
1217-
method_mapping.add(caller="fit", callee="fit")
1218-
method_mapping.add(caller="predict", callee="predict")
1219-
method_mapping.add(caller="fit_predict", callee="fit_predict")
1220-
method_mapping.add(caller="predict_proba", callee="predict_proba")
1221-
method_mapping.add(caller="decision_function", callee="decision_function")
1222-
method_mapping.add(caller="predict_log_proba", callee="predict_log_proba")
1223-
method_mapping.add(caller="transform", callee="transform")
1224-
method_mapping.add(caller="inverse_transform", callee="inverse_transform")
1225-
method_mapping.add(caller="score", callee="score")
1226-
method_mapping.add(caller="fit_resample", callee="fit_resample")
1229+
.add(caller="predict", callee="predict")
1230+
.add(caller="fit_predict", callee="fit_predict")
1231+
.add(caller="predict_proba", callee="predict_proba")
1232+
.add(caller="decision_function", callee="decision_function")
1233+
.add(caller="predict_log_proba", callee="predict_log_proba")
1234+
.add(caller="transform", callee="transform")
1235+
.add(caller="inverse_transform", callee="inverse_transform")
1236+
.add(caller="score", callee="score")
1237+
.add(caller="fit_resample", callee="fit_resample")
1238+
)
12271239

12281240
router.add(method_mapping=method_mapping, **{final_name: final_est})
12291241
return router

imblearn/tests/test_pipeline.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
)
3535
from sklearn.utils.fixes import parse_version
3636

37+
from imblearn.base import BaseSampler
3738
from imblearn.datasets import make_imbalance
3839
from imblearn.pipeline import Pipeline, make_pipeline
3940
from imblearn.under_sampling import EditedNearestNeighbours as ENN
@@ -1495,3 +1496,24 @@ def test_transform_input_sklearn_version():
14951496

14961497
# end of transform_input tests
14971498
# =============================
1499+
1500+
1501+
def test_metadata_routing_with_sampler():
1502+
"""Check that we can use a sampler with metadata routing."""
1503+
X, y = make_classification()
1504+
cost_matrix = np.random.rand(X.shape[0], 2, 2)
1505+
1506+
class CostSensitiveSampler(BaseSampler):
1507+
def fit_resample(self, X, y, cost_matrix=None):
1508+
return self._fit_resample(X, y, cost_matrix=cost_matrix)
1509+
1510+
def _fit_resample(self, X, y, cost_matrix=None):
1511+
self.cost_matrix_ = cost_matrix
1512+
return X, y
1513+
1514+
with config_context(enable_metadata_routing=True):
1515+
sampler = CostSensitiveSampler().set_fit_resample_request(cost_matrix=True)
1516+
pipeline = Pipeline([("sampler", sampler), ("model", LogisticRegression())])
1517+
pipeline.fit(X, y, cost_matrix=cost_matrix)
1518+
1519+
assert_allclose(pipeline[0].cost_matrix_, cost_matrix)

0 commit comments

Comments
 (0)