@@ -61,6 +61,8 @@ def _yield_sampler_checks(sampler):
61
61
yield check_samplers_pandas
62
62
if "dask-array" in tags ["X_types" ]:
63
63
yield check_samplers_dask_array
64
+ if "dask-dataframe" in tags ["X_types" ]:
65
+ yield check_samplers_dask_dataframe
64
66
yield check_samplers_list
65
67
yield check_samplers_multiclass_ova
66
68
yield check_samplers_preserve_dtype
@@ -295,7 +297,7 @@ def check_samplers_pandas(name, sampler):
295
297
296
298
def check_samplers_dask_array (name , sampler ):
297
299
dask = pytest .importorskip ("dask" )
298
- # Check that the samplers handle pandas dataframe and pandas series
300
+ # Check that the samplers handle dask array
299
301
X , y = make_classification (
300
302
n_samples = 1000 ,
301
303
n_classes = 3 ,
@@ -317,6 +319,37 @@ def check_samplers_dask_array(name, sampler):
317
319
assert_allclose (y_res_dask , y_res )
318
320
319
321
322
+ def check_samplers_dask_dataframe (name , sampler ):
323
+ dask = pytest .importorskip ("dask" )
324
+ # Check that the samplers handle dask dataframe and dask series
325
+ X , y = make_classification (
326
+ n_samples = 1000 ,
327
+ n_classes = 3 ,
328
+ n_informative = 4 ,
329
+ weights = [0.2 , 0.3 , 0.5 ],
330
+ random_state = 0 ,
331
+ )
332
+ X_df = dask .dataframe .from_array (
333
+ X , columns = [str (i ) for i in range (X .shape [1 ])]
334
+ )
335
+ y_s = dask .dataframe .from_array (y )
336
+
337
+ X_res_df , y_res_s = sampler .fit_resample (X_df , y_s )
338
+ X_res , y_res = sampler .fit_resample (X , y )
339
+
340
+ # check that we return the same type for dataframes or series types
341
+ assert isinstance (X_res_df , dask .dataframe .DataFrame )
342
+ assert isinstance (y_res_s , dask .dataframe .Series )
343
+
344
+ # assert X_df.columns.to_list() == X_res_df.columns.to_list()
345
+ # assert y_df.columns.to_list() == y_res_df.columns.to_list()
346
+ # assert y_s.name == y_res_s.name
347
+
348
+ # assert_allclose(X_res_df.to_numpy(), X_res)
349
+ # assert_allclose(y_res_df.to_numpy().ravel(), y_res)
350
+ # assert_allclose(y_res_s.to_numpy(), y_res)
351
+
352
+
320
353
def check_samplers_list (name , sampler ):
321
354
# Check that the can samplers handle simple lists
322
355
X , y = make_classification (
0 commit comments