@@ -65,6 +65,10 @@ def _yield_sampler_checks(sampler):
65
65
yield check_samplers_dask_dataframe
66
66
yield check_samplers_list
67
67
yield check_samplers_multiclass_ova
68
+ if "dask-array" in tags ["X_types" ]:
69
+ yield check_samplers_multiclass_ova_dask_array
70
+ if "dask-dataframe" in tags ["X_types" ]:
71
+ yield check_samplers_multiclass_ova_dask_dataframe
68
72
yield check_samplers_preserve_dtype
69
73
yield check_samplers_sample_indices
70
74
yield check_samplers_2d_target
@@ -343,27 +347,20 @@ def check_samplers_dask_dataframe(name, sampler_orig):
343
347
)
344
348
y_s = dataframe .from_array (y )
345
349
y_s = y_s .rename ("target" )
346
- y_s_ohe = dataframe .get_dummies (
347
- y_s .astype (pd .CategoricalDtype (categories = [0 , 1 , 2 ]))
348
- )
349
350
350
351
for validate_if_dask_collection in (True , False ):
351
352
sampler .set_params (
352
353
validate_if_dask_collection = validate_if_dask_collection
353
354
)
354
355
X_res_df , y_res_s = sampler .fit_resample (X_df , y_s )
355
- # FIXME: not supported with validate=False
356
- X_res , y_res_s_ohe = sampler .fit_resample (X , y_s_ohe )
357
356
X_res , y_res = sampler .fit_resample (X , y )
358
357
359
358
# check that we return the same type for dataframes or series types
360
359
assert isinstance (X_res_df , dataframe .DataFrame )
361
360
assert isinstance (y_res_s , dataframe .Series )
362
- assert isinstance (y_res_s_ohe , dataframe .DataFrame )
363
361
364
362
assert X_df .columns .to_list () == X_res_df .columns .to_list ()
365
363
assert y_s .name == y_res_s .name
366
- assert y_s_ohe .columns .to_list () == y_res_s_ohe .columns .to_list ()
367
364
368
365
assert_allclose (np .array (X_res_df ), X_res )
369
366
assert_allclose (np .array (y_res_s ), y_res )
@@ -408,6 +405,66 @@ def check_samplers_multiclass_ova(name, sampler):
408
405
assert_allclose (y_res , y_res_ova .argmax (axis = 1 ))
409
406
410
407
408
+ def check_samplers_multiclass_ova_dask_array (name , sampler_orig ):
409
+ pytest .importorskip ("dask" )
410
+ from dask import array
411
+ sampler = clone (sampler_orig )
412
+ X , y = make_classification (
413
+ n_samples = 1000 ,
414
+ n_classes = 3 ,
415
+ n_informative = 4 ,
416
+ weights = [0.2 , 0.3 , 0.5 ],
417
+ random_state = 0 ,
418
+ )
419
+ y_ova = label_binarize (y , np .unique (y ))
420
+
421
+ X = array .from_array (X )
422
+ y = array .from_array (y )
423
+ y_ova = array .from_array (y_ova )
424
+
425
+ sampler .set_params (validate_if_dask_collection = True )
426
+ X_res , y_res = sampler .fit_resample (X , y )
427
+ X_res_ova , y_res_ova = sampler .fit_resample (X , y_ova )
428
+
429
+ assert_allclose (X_res , X_res_ova )
430
+ assert type_of_target (y_res_ova ) == type_of_target (y_ova )
431
+ assert_allclose (y_res , y_res_ova .argmax (axis = 1 ))
432
+
433
+ assert isinstance (X_res_ova , array .Array )
434
+ assert isinstance (y_res , array .Array )
435
+ assert isinstance (y_res_ova , array .Array )
436
+
437
+
438
+ def check_samplers_multiclass_ova_dask_dataframe (name , sampler_orig ):
439
+ pytest .importorskip ("dask" )
440
+ from dask import dataframe
441
+ sampler = clone (sampler_orig )
442
+ X , y = make_classification (
443
+ n_samples = 1000 ,
444
+ n_classes = 3 ,
445
+ n_informative = 4 ,
446
+ weights = [0.2 , 0.3 , 0.5 ],
447
+ random_state = 0 ,
448
+ )
449
+ y_ova = label_binarize (y , np .unique (y ))
450
+
451
+ X = dataframe .from_array (X )
452
+ y = dataframe .from_array (y )
453
+ y_ova = dataframe .from_array (y_ova )
454
+
455
+ sampler .set_params (validate_if_dask_collection = True )
456
+ X_res , y_res = sampler .fit_resample (X , y )
457
+ X_res_ova , y_res_ova = sampler .fit_resample (X , y_ova )
458
+
459
+ assert_allclose (X_res , X_res_ova )
460
+ assert type_of_target (y_res_ova ) == type_of_target (y_ova )
461
+ assert_allclose (y_res , y_res_ova .to_dask_array ().argmax (axis = 1 ))
462
+
463
+ assert isinstance (X_res_ova , dataframe .DataFrame )
464
+ assert isinstance (y_res , dataframe .Series )
465
+ assert isinstance (y_res_ova , dataframe .DataFrame )
466
+
467
+
411
468
def check_samplers_2d_target (name , sampler ):
412
469
X , y = make_classification (
413
470
n_samples = 100 ,
0 commit comments