@@ -327,6 +327,7 @@ def check_samplers_dask_array(name, sampler_orig):
327
327
328
328
def check_samplers_dask_dataframe (name , sampler_orig ):
329
329
pytest .importorskip ("dask" )
330
+ pd = pytest .importorskip ("pandas" )
330
331
from dask import dataframe
331
332
sampler = clone (sampler_orig )
332
333
# Check that the samplers handle dask dataframe and dask series
@@ -342,20 +343,27 @@ def check_samplers_dask_dataframe(name, sampler_orig):
342
343
)
343
344
y_s = dataframe .from_array (y )
344
345
y_s = y_s .rename ("target" )
346
+ y_s_ohe = dataframe .get_dummies (
347
+ y_s .astype (pd .CategoricalDtype (categories = [0 , 1 , 2 ]))
348
+ )
345
349
346
350
for validate_if_dask_collection in (True , False ):
347
351
sampler .set_params (
348
352
validate_if_dask_collection = validate_if_dask_collection
349
353
)
350
354
X_res_df , y_res_s = sampler .fit_resample (X_df , y_s )
355
+ # FIXME: not supported with validate=False
356
+ X_res , y_res_s_ohe = sampler .fit_resample (X , y_s_ohe )
351
357
X_res , y_res = sampler .fit_resample (X , y )
352
358
353
359
# check that we return the same type for dataframes or series types
354
360
assert isinstance (X_res_df , dataframe .DataFrame )
355
361
assert isinstance (y_res_s , dataframe .Series )
362
+ assert isinstance (y_res_s_ohe , dataframe .DataFrame )
356
363
357
364
assert X_df .columns .to_list () == X_res_df .columns .to_list ()
358
365
assert y_s .name == y_res_s .name
366
+ assert y_s_ohe .columns .to_list () == y_res_s_ohe .columns .to_list ()
359
367
360
368
assert_allclose (np .array (X_res_df ), X_res )
361
369
assert_allclose (np .array (y_res_s ), y_res )
0 commit comments