Skip to content

Commit 6c592ff

Browse files
committed
iter
1 parent 32eda46 commit 6c592ff

File tree

1 file changed

+64
-7
lines changed

1 file changed

+64
-7
lines changed

imblearn/utils/estimator_checks.py

Lines changed: 64 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -65,6 +65,10 @@ def _yield_sampler_checks(sampler):
6565
yield check_samplers_dask_dataframe
6666
yield check_samplers_list
6767
yield check_samplers_multiclass_ova
68+
if "dask-array" in tags["X_types"]:
69+
yield check_samplers_multiclass_ova_dask_array
70+
if "dask-dataframe" in tags["X_types"]:
71+
yield check_samplers_multiclass_ova_dask_dataframe
6872
yield check_samplers_preserve_dtype
6973
yield check_samplers_sample_indices
7074
yield check_samplers_2d_target
@@ -343,27 +347,20 @@ def check_samplers_dask_dataframe(name, sampler_orig):
343347
)
344348
y_s = dataframe.from_array(y)
345349
y_s = y_s.rename("target")
346-
y_s_ohe = dataframe.get_dummies(
347-
y_s.astype(pd.CategoricalDtype(categories=[0, 1, 2]))
348-
)
349350

350351
for validate_if_dask_collection in (True, False):
351352
sampler.set_params(
352353
validate_if_dask_collection=validate_if_dask_collection
353354
)
354355
X_res_df, y_res_s = sampler.fit_resample(X_df, y_s)
355-
# FIXME: not supported with validate=False
356-
X_res, y_res_s_ohe = sampler.fit_resample(X, y_s_ohe)
357356
X_res, y_res = sampler.fit_resample(X, y)
358357

359358
# check that we return the same type for dataframes or series types
360359
assert isinstance(X_res_df, dataframe.DataFrame)
361360
assert isinstance(y_res_s, dataframe.Series)
362-
assert isinstance(y_res_s_ohe, dataframe.DataFrame)
363361

364362
assert X_df.columns.to_list() == X_res_df.columns.to_list()
365363
assert y_s.name == y_res_s.name
366-
assert y_s_ohe.columns.to_list() == y_res_s_ohe.columns.to_list()
367364

368365
assert_allclose(np.array(X_res_df), X_res)
369366
assert_allclose(np.array(y_res_s), y_res)
@@ -408,6 +405,66 @@ def check_samplers_multiclass_ova(name, sampler):
408405
assert_allclose(y_res, y_res_ova.argmax(axis=1))
409406

410407

408+
def check_samplers_multiclass_ova_dask_array(name, sampler_orig):
    """Check multiclass vs. one-vs-all targets when inputs are dask arrays.

    The sampler is fitted once on integer class labels and once on the
    binarized (one column per class) encoding of the same labels. Both runs
    must produce the same resampled ``X``, the binarized target must keep its
    target type, and taking the argmax of the resampled binarized target must
    recover the resampled integer labels.

    The check is skipped when ``dask`` cannot be imported.

    Parameters
    ----------
    name : str
        Not used in the body; present to match the ``(name, sampler)``
        signature shared by the other checks in this file.
    sampler_orig : sampler
        Sampler under test; the check runs on a clone of it.
    """
    pytest.importorskip("dask")
    from dask import array

    sampler = clone(sampler_orig)
    X, y = make_classification(
        n_samples=1000,
        n_classes=3,
        n_informative=4,
        weights=[0.2, 0.3, 0.5],
        random_state=0,
    )
    # `classes` must be given by keyword: scikit-learn made it keyword-only,
    # and the keyword form also works on older releases.
    y_ova = label_binarize(y, classes=np.unique(y))

    # Wrap the in-memory data into dask arrays before resampling.
    X = array.from_array(X)
    y = array.from_array(y)
    y_ova = array.from_array(y_ova)

    sampler.set_params(validate_if_dask_collection=True)
    X_res, y_res = sampler.fit_resample(X, y)
    X_res_ova, y_res_ova = sampler.fit_resample(X, y_ova)

    # Both target encodings must lead to the same resampled data.
    assert_allclose(X_res, X_res_ova)
    assert type_of_target(y_res_ova) == type_of_target(y_ova)
    assert_allclose(y_res, y_res_ova.argmax(axis=1))

    # The container type must be preserved on output.
    assert isinstance(X_res_ova, array.Array)
    assert isinstance(y_res, array.Array)
    assert isinstance(y_res_ova, array.Array)
436+
437+
438+
def check_samplers_multiclass_ova_dask_dataframe(name, sampler_orig):
    """Check multiclass vs. one-vs-all targets when inputs are dask dataframes.

    The sampler is fitted once on integer class labels and once on the
    binarized (one column per class) encoding of the same labels. Both runs
    must produce the same resampled ``X``, the binarized target must keep its
    target type, and taking the argmax of the resampled binarized target must
    recover the resampled integer labels.

    The check is skipped when ``dask`` cannot be imported.

    Parameters
    ----------
    name : str
        Not used in the body; present to match the ``(name, sampler)``
        signature shared by the other checks in this file.
    sampler_orig : sampler
        Sampler under test; the check runs on a clone of it.
    """
    pytest.importorskip("dask")
    from dask import dataframe

    sampler = clone(sampler_orig)
    X, y = make_classification(
        n_samples=1000,
        n_classes=3,
        n_informative=4,
        weights=[0.2, 0.3, 0.5],
        random_state=0,
    )
    # `classes` must be given by keyword: scikit-learn made it keyword-only,
    # and the keyword form also works on older releases.
    y_ova = label_binarize(y, classes=np.unique(y))

    # Wrap the in-memory data into dask dataframe collections.
    X = dataframe.from_array(X)
    y = dataframe.from_array(y)
    y_ova = dataframe.from_array(y_ova)

    sampler.set_params(validate_if_dask_collection=True)
    X_res, y_res = sampler.fit_resample(X, y)
    X_res_ova, y_res_ova = sampler.fit_resample(X, y_ova)

    # Both target encodings must lead to the same resampled data.
    assert_allclose(X_res, X_res_ova)
    assert type_of_target(y_res_ova) == type_of_target(y_ova)
    # Convert the resampled binarized target to a dask array so the
    # per-row argmax can be taken.
    assert_allclose(y_res, y_res_ova.to_dask_array().argmax(axis=1))

    # The container type must be preserved on output.
    assert isinstance(X_res_ova, dataframe.DataFrame)
    assert isinstance(y_res, dataframe.Series)
    assert isinstance(y_res_ova, dataframe.DataFrame)
466+
467+
411468
def check_samplers_2d_target(name, sampler):
412469
X, y = make_classification(
413470
n_samples=100,

0 commit comments

Comments
 (0)