Skip to content

Commit 0073f9d

Browse files
committed
BUG: fix non deterministic result by always sorting sampling_strategy
1 parent 7dfa6bd commit 0073f9d

File tree

3 files changed

+36
-5
lines changed

3 files changed

+36
-5
lines changed

doc/whats_new/v0.0.4.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,10 @@ Bug fixes
5656
generating new samples. :issue:`354` by :user:`Guillaume Lemaitre
5757
<glemaitre>`.
5858

59+
- Fix bug to enforce sorted behavior of the ``sampling_strategy`` dictionary
60+
and thus to obtain deterministic results when using the same random state.
61+
:issue:`447` by :user:`Guillaume Lemaitre <glemaitre>`.
62+
5963
Maintenance
6064
...........
6165

imblearn/utils/tests/test_validation.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# License: MIT
55

66
from collections import Counter
7+
from collections import OrderedDict
78

89
import pytest
910
import numpy as np
@@ -371,3 +372,19 @@ def test_hash_X_y():
371372
y = np.array([0] * 2 + [1] * 3)
372373
# all data will be used in this case
373374
assert hash_X_y(X, y) == (joblib.hash(X), joblib.hash(y))
375+
376+
377+
@pytest.mark.parametrize(
    "sampling_strategy, sampling_type, expected_result",
    [({3: 25, 1: 25, 2: 25}, 'under-sampling',
      OrderedDict([(1, 25), (2, 25), (3, 25)])),
     ({3: 100, 1: 100, 2: 100}, 'over-sampling',
      OrderedDict([(1, 50), (2, 0), (3, 75)]))])
def test_sampling_strategy_check_order(sampling_strategy, sampling_type,
                                       expected_result):
    """Check that ``check_sampling_strategy`` yields its items in key order.

    Parameters
    ----------
    sampling_strategy : dict
        Deliberately unsorted mapping passed to ``check_sampling_strategy``.
    sampling_type : str
        Sampling type forwarded to ``check_sampling_strategy``.
    expected_result : OrderedDict
        Expected mapping, built from explicit (key, value) pairs so that its
        order does not depend on the iteration order of a dict literal.
    """
    # We purposely pass an unsorted dictionary and check that the resulting
    # dictionary is sorted. Refer to issue #428.
    y = np.array([1] * 50 + [2] * 100 + [3] * 25)
    sampling_strategy_ = check_sampling_strategy(
        sampling_strategy, y, sampling_type)
    # Compare the item sequences rather than the mappings: an OrderedDict
    # compared with a plain dict ignores order, which would let an unsorted
    # result slip through unnoticed.
    assert list(sampling_strategy_.items()) == list(expected_result.items())

imblearn/utils/validation.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import warnings
88
from collections import Counter
9+
from collections import OrderedDict
910
from numbers import Integral, Real
1011

1112
import numpy as np
@@ -462,21 +463,30 @@ def check_sampling_strategy(sampling_strategy, y, sampling_type, **kwargs):
462463
raise ValueError("When 'sampling_strategy' is a string, it needs"
463464
" to be one of {}. Got '{}' instead.".format(
464465
SAMPLING_TARGET_KIND, sampling_strategy))
465-
return SAMPLING_TARGET_KIND[sampling_strategy](y, sampling_type)
466+
return OrderedDict(sorted(
467+
SAMPLING_TARGET_KIND[sampling_strategy](y, sampling_type).items()))
466468
elif isinstance(sampling_strategy, dict):
467-
return _sampling_strategy_dict(sampling_strategy, y, sampling_type)
469+
return OrderedDict(sorted(
470+
_sampling_strategy_dict(sampling_strategy, y, sampling_type)
471+
.items()))
468472
elif isinstance(sampling_strategy, list):
469-
return _sampling_strategy_list(sampling_strategy, y, sampling_type)
473+
return OrderedDict(sorted(
474+
_sampling_strategy_list(sampling_strategy, y, sampling_type)
475+
.items()))
470476
elif isinstance(sampling_strategy, Real):
471477
if sampling_strategy <= 0 or sampling_strategy > 1:
472478
raise ValueError(
473479
"When 'sampling_strategy' is a float, it should be "
474480
"in the range (0, 1]. Got {} instead."
475481
.format(sampling_strategy))
476-
return _sampling_strategy_float(sampling_strategy, y, sampling_type)
482+
return OrderedDict(sorted(
483+
_sampling_strategy_float(sampling_strategy, y, sampling_type)
484+
.items()))
477485
elif callable(sampling_strategy):
478486
sampling_strategy_ = sampling_strategy(y, **kwargs)
479-
return _sampling_strategy_dict(sampling_strategy_, y, sampling_type)
487+
return OrderedDict(sorted(
488+
_sampling_strategy_dict(sampling_strategy_, y, sampling_type)
489+
.items()))
480490

481491

482492
SAMPLING_TARGET_KIND = {

0 commit comments

Comments
 (0)