Skip to content

Commit 267dd32

Browse files
authored
BUG: fix non deterministic result by always sorting sampling_strategy (#449)
2 parents cc78dde + 5064c7e commit 267dd32

File tree

3 files changed

+36
-5
lines changed

3 files changed

+36
-5
lines changed

doc/whats_new/v0.0.4.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,10 @@ Bug fixes
5656
generating new samples. :issue:`354` by :user:`Guillaume Lemaitre
5757
<glemaitre>`.
5858

59+
- Fix a bug by enforcing a sorted ``sampling_strategy`` dictionary
60+
and thus obtaining deterministic results when using the same random state.
61+
:issue:`447` by :user:`Guillaume Lemaitre <glemaitre>`.
62+
5963
- Force to clone scikit-learn estimator passed as attributes to samplers.
6064
:issue:`446` by :user:`Guillaume Lemaitre <glemaitre>`.
6165

imblearn/utils/tests/test_validation.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# License: MIT
55

66
from collections import Counter
7+
from collections import OrderedDict
78

89
import pytest
910
import numpy as np
@@ -372,3 +373,19 @@ def test_hash_X_y():
372373
y = np.array([0] * 2 + [1] * 3)
373374
# all data will be used in this case
374375
assert hash_X_y(X, y) == (joblib.hash(X), joblib.hash(y))
376+
377+
378+
@pytest.mark.parametrize(
379+
"sampling_strategy, sampling_type, expected_result",
380+
[({3: 25, 1: 25, 2: 25}, 'under-sampling',
381+
OrderedDict({1: 25, 2: 25, 3: 25})),
382+
({3: 100, 1: 100, 2: 100}, 'over-sampling',
383+
OrderedDict({1: 50, 2: 0, 3: 75}))])
384+
def test_sampling_strategy_check_order(sampling_strategy, sampling_type,
385+
expected_result):
386+
# We purposely pass an unsorted dictionary and check that the resulting
387+
# dictionary is sorted. Refer to issue #428.
388+
y = np.array([1] * 50 + [2] * 100 + [3] * 25)
389+
sampling_strategy_ = check_sampling_strategy(
390+
sampling_strategy, y, sampling_type)
391+
assert sampling_strategy_ == expected_result

imblearn/utils/validation.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import warnings
88
from collections import Counter
9+
from collections import OrderedDict
910
from numbers import Integral, Real
1011

1112
import numpy as np
@@ -463,21 +464,30 @@ def check_sampling_strategy(sampling_strategy, y, sampling_type, **kwargs):
463464
raise ValueError("When 'sampling_strategy' is a string, it needs"
464465
" to be one of {}. Got '{}' instead.".format(
465466
SAMPLING_TARGET_KIND, sampling_strategy))
466-
return SAMPLING_TARGET_KIND[sampling_strategy](y, sampling_type)
467+
return OrderedDict(sorted(
468+
SAMPLING_TARGET_KIND[sampling_strategy](y, sampling_type).items()))
467469
elif isinstance(sampling_strategy, dict):
468-
return _sampling_strategy_dict(sampling_strategy, y, sampling_type)
470+
return OrderedDict(sorted(
471+
_sampling_strategy_dict(sampling_strategy, y, sampling_type)
472+
.items()))
469473
elif isinstance(sampling_strategy, list):
470-
return _sampling_strategy_list(sampling_strategy, y, sampling_type)
474+
return OrderedDict(sorted(
475+
_sampling_strategy_list(sampling_strategy, y, sampling_type)
476+
.items()))
471477
elif isinstance(sampling_strategy, Real):
472478
if sampling_strategy <= 0 or sampling_strategy > 1:
473479
raise ValueError(
474480
"When 'sampling_strategy' is a float, it should be "
475481
"in the range (0, 1]. Got {} instead."
476482
.format(sampling_strategy))
477-
return _sampling_strategy_float(sampling_strategy, y, sampling_type)
483+
return OrderedDict(sorted(
484+
_sampling_strategy_float(sampling_strategy, y, sampling_type)
485+
.items()))
478486
elif callable(sampling_strategy):
479487
sampling_strategy_ = sampling_strategy(y, **kwargs)
480-
return _sampling_strategy_dict(sampling_strategy_, y, sampling_type)
488+
return OrderedDict(sorted(
489+
_sampling_strategy_dict(sampling_strategy_, y, sampling_type)
490+
.items()))
481491

482492

483493
SAMPLING_TARGET_KIND = {

0 commit comments

Comments
 (0)