Skip to content

Commit 97dd664

Browse files
committed
MNT remove ensemble
1 parent b1d2d56 commit 97dd664

File tree

3 files changed

+11
-98
lines changed

3 files changed

+11
-98
lines changed

imblearn/base.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ def fit(self, X, y):
3232
3333
Parameters
3434
----------
35-
X : {array-like, sparse matrix} of shape (n_samples, n_features)
35+
X : {array-like, dataframe, sparse matrix} of shape \
36+
(n_samples, n_features)
3637
Data array.
3738
3839
y : array-like of shape (n_samples,)
@@ -54,15 +55,16 @@ def fit_resample(self, X, y):
5455
5556
Parameters
5657
----------
57-
X : {array-like, sparse matrix} of shape (n_samples, n_features)
58+
X : {array-like, dataframe, sparse matrix} of shape \
59+
(n_samples, n_features)
5860
Matrix containing the data which have to be sampled.
5961
6062
y : array-like of shape (n_samples,)
6163
Corresponding label for each sample in X.
6264
6365
Returns
6466
-------
65-
X_resampled : {array-like, sparse matrix} of shape \
67+
X_resampled : {array-like, dataframe, sparse matrix} of shape \
6668
(n_samples_new, n_features)
6769
The array containing the resampled data.
6870

imblearn/ensemble/base.py

Lines changed: 0 additions & 64 deletions
This file was deleted.

imblearn/utils/estimator_checks.py

Lines changed: 6 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030

3131
from imblearn.over_sampling.base import BaseOverSampler
3232
from imblearn.under_sampling.base import BaseCleaningSampler, BaseUnderSampler
33-
from imblearn.ensemble.base import BaseEnsembleSampler
3433
from imblearn.under_sampling import NearMiss, ClusterCentroids
3534

3635

@@ -168,12 +167,6 @@ def check_samplers_fit_resample(name, Sampler):
168167
for class_sample in target_stats.keys()
169168
if class_sample != class_minority
170169
)
171-
elif isinstance(sampler, BaseEnsembleSampler):
172-
y_ensemble = y_res[0]
173-
n_samples = min(target_stats.values())
174-
assert all(
175-
value == n_samples for value in Counter(y_ensemble).values()
176-
)
177170

178171

179172
def check_samplers_sampling_strategy_fit_resample(name, Sampler):
@@ -202,12 +195,6 @@ def check_samplers_sampling_strategy_fit_resample(name, Sampler):
202195
sampler.set_params(sampling_strategy=sampling_strategy)
203196
X_res, y_res = sampler.fit_resample(X, y)
204197
assert Counter(y_res)[1] == expected_stat
205-
if isinstance(sampler, BaseEnsembleSampler):
206-
sampling_strategy = {2: 201, 0: 201}
207-
sampler.set_params(sampling_strategy=sampling_strategy)
208-
X_res, y_res = sampler.fit_resample(X, y)
209-
y_ensemble = y_res[0]
210-
assert Counter(y_ensemble)[1] == expected_stat
211198

212199

213200
def check_samplers_sparse(name, Sampler):
@@ -239,17 +226,10 @@ def check_samplers_sparse(name, Sampler):
239226
set_random_state(sampler)
240227
X_res_sparse, y_res_sparse = sampler.fit_resample(X_sparse, y)
241228
X_res, y_res = sampler.fit_resample(X, y)
242-
if not isinstance(sampler, BaseEnsembleSampler):
243-
assert sparse.issparse(X_res_sparse)
244-
assert_allclose(X_res_sparse.A, X_res)
245-
assert_allclose(y_res_sparse, y_res)
246-
else:
247-
for x_sp, x, y_sp, y in zip(
248-
X_res_sparse, X_res, y_res_sparse, y_res
249-
):
250-
assert sparse.issparse(x_sp)
251-
assert_allclose(x_sp.A, x)
252-
assert_allclose(y_sp, y)
229+
for x_sp, x, y_sp, y in zip(X_res_sparse, X_res, y_res_sparse, y_res):
230+
assert sparse.issparse(x_sp)
231+
assert_allclose(x_sp.A, x)
232+
assert_allclose(y_sp, y)
253233

254234

255235
def check_samplers_pandas(name, Sampler):
@@ -297,13 +277,8 @@ def check_samplers_multiclass_ova(name, Sampler):
297277
X_res, y_res = sampler.fit_resample(X, y)
298278
X_res_ova, y_res_ova = sampler.fit_resample(X, y_ova)
299279
assert_allclose(X_res, X_res_ova)
300-
if issubclass(Sampler, BaseEnsembleSampler):
301-
for batch_y, batch_y_ova in zip(y_res, y_res_ova):
302-
assert type_of_target(batch_y_ova) == type_of_target(y_ova)
303-
assert_allclose(batch_y, batch_y_ova.argmax(axis=1))
304-
else:
305-
assert type_of_target(y_res_ova) == type_of_target(y_ova)
306-
assert_allclose(y_res, y_res_ova.argmax(axis=1))
280+
assert type_of_target(y_res_ova) == type_of_target(y_ova)
281+
assert_allclose(y_res, y_res_ova.argmax(axis=1))
307282

308283

309284
def check_samplers_preserve_dtype(name, Sampler):

0 commit comments

Comments
 (0)