7
7
# License: MIT
8
8
9
9
import logging
10
+ import warnings
10
11
from collections import Counter
12
+ from numbers import Real
11
13
12
- import numpy as np
13
- from sklearn .utils import check_random_state , check_X_y
14
+ from sklearn .utils import check_X_y
15
+
16
+ from ..under_sampling .prototype_selection import RandomUnderSampler
17
+ from ..utils import check_ratio
14
18
15
19
LOGGER = logging .getLogger (__name__ )
16
20
@@ -28,14 +32,23 @@ def make_imbalance(X, y, ratio, min_c_=None, random_state=None):
28
32
y : ndarray, shape (n_samples, )
29
33
Corresponding label for each sample in X.
30
34
31
- ratio : float,
32
- The desired ratio given by the number of samples in
33
- the minority class over the the number of samples in
34
- the majority class. Thus the ratio should be in the interval [0., 1.]
35
+ ratio : str, dict, or callable, optional (default='auto')
36
+ Ratio to use for resampling the data set.
37
+
38
+ - If ``dict``, the keys correspond to the targeted classes. The values
39
+ correspond to the desired number of samples.
40
+ - If callable, function taking ``y`` and returns a ``dict``. The keys
41
+ correspond to the targeted classes. The values correspond to the
42
+ desired number of samples.
35
43
36
44
min_c_ : str or int, optional (default=None)
37
45
The identifier of the class to be the minority class.
38
- If None, min_c_ is set to be the current minority class.
46
+ If ``None``, ``min_c_`` is set to be the current minority class.
47
+ Only used when ``ratio`` is a float for back-compatibility.
48
+
49
+ .. deprecated:: 0.2
50
+ ``min_c_`` is deprecated in 0.2 and will be removed in 0.4. Use
51
+ ``ratio`` by passing a ``dict`` instead.
39
52
40
53
random_state : int, RandomState instance or None, optional (default=None)
41
54
If int, random_state is the seed used by the random number generator;
@@ -51,48 +64,57 @@ def make_imbalance(X, y, ratio, min_c_=None, random_state=None):
51
64
y_resampled : ndarray, shape (n_samples_new)
52
65
The corresponding label of `X_resampled`
53
66
54
- """
55
- if isinstance (ratio , float ):
56
- if ratio > 1 :
57
- raise ValueError ('Ratio cannot be greater than one.'
58
- ' Got {}.' .format (ratio ))
59
- elif ratio <= 0 :
60
- raise ValueError ('Ratio have to be strictly positive.'
61
- ' Got {}.' .format (ratio ))
62
- else :
63
- raise ValueError ('Ratio must be a float between 0.0 < ratio < 1.0'
64
- ' Got {} instead.' .format (ratio ))
67
+ Examples
68
+ --------
69
+ >>> from collections import Counter
70
+ >>> from sklearn.datasets import load_iris
71
+ >>> from imblearn.datasets import make_imbalance
72
+
73
+ >>> data = load_iris()
74
+ >>> X, y = data.data, data.target
75
+ >>> print('Distribution before imbalancing: {}'.format(Counter(y)))
76
+ Distribution before imbalancing: Counter({0: 50, 1: 50, 2: 50})
77
+ >>> X_res, y_res = make_imbalance(X, y, ratio={0: 10, 1: 20, 2: 30},
78
+ ... random_state=42)
79
+ >>> print('Distribution after imbalancing: {}'.format(Counter(y_res)))
80
+ Distribution after imbalancing: Counter({2: 30, 1: 20, 0: 10})
65
81
82
+ """
66
83
X , y = check_X_y (X , y )
67
-
68
- random_state = check_random_state (random_state )
69
-
70
- stats_c_ = Counter (y )
84
+ target_stats = Counter (y )
85
+ # restrict ratio to be a dict or a callable
86
+ if isinstance (ratio , dict ) or callable (ratio ):
87
+ ratio_ = check_ratio (ratio , y , 'under-sampling' )
88
+ # FIXME: deprecated in 0.2 to be removed in 0.4
89
+ elif isinstance (ratio , Real ):
90
+ if min_c_ is None :
91
+ min_c_ = min (target_stats , key = target_stats .get )
92
+ else :
93
+ warnings .warn ("'min_c_' is deprecated in 0.2 and will be removed"
94
+ " in 0.4. Use 'ratio' as dictionary instead." ,
95
+ DeprecationWarning )
96
+ warnings .warn ("'ratio' being a float is deprecated in 0.2 and will not"
97
+ " be supported in 0.4. Use a dictionary instead." ,
98
+ DeprecationWarning )
99
+ class_majority = max (target_stats , key = target_stats .get )
100
+ ratio_ = {}
101
+ for class_sample , n_sample in target_stats .items ():
102
+ if class_sample == min_c_ :
103
+ n_min_samples = int (target_stats [class_majority ] * ratio )
104
+ ratio_ [class_sample ] = n_min_samples
105
+ else :
106
+ ratio_ [class_sample ] = n_sample
107
+ ratio_ = check_ratio (ratio_ , y , 'under-sampling' )
108
+ else :
109
+ raise ValueError ("'ratio' has to be a dictionary or a function"
110
+ " returning a dictionary. Got {} instead." .format (
111
+ type (ratio )))
71
112
72
113
LOGGER .info ('The original target distribution in the dataset is: %s' ,
73
- stats_c_ )
74
-
75
- if min_c_ is None :
76
- min_c_ = min (stats_c_ , key = stats_c_ .get )
77
-
78
- n_min_samples = int (np .count_nonzero (y != min_c_ ) * ratio )
79
- if n_min_samples > stats_c_ [min_c_ ]:
80
- raise ValueError ('Current imbalance ratio of data is lower than'
81
- ' desired ratio! Got {} > {}.' .format (
82
- n_min_samples , stats_c_ [min_c_ ]))
83
- if n_min_samples == 0 :
84
- raise ValueError ('Not enough samples for desired ratio!'
85
- ' Got {}.' .format (n_min_samples ))
86
-
87
- mask = y == min_c_
88
-
89
- idx_maj = np .where (~ mask )[0 ]
90
- idx_min = np .where (mask )[0 ]
91
- idx_min = random_state .choice (idx_min , size = n_min_samples , replace = False )
92
- idx = np .concatenate ((idx_min , idx_maj ), axis = 0 )
93
-
94
- X_resampled , y_resampled = X [idx , :], y [idx ]
95
-
114
+ target_stats )
115
+ rus = RandomUnderSampler (ratio = ratio_ , replacement = False ,
116
+ random_state = random_state )
117
+ X_resampled , y_resampled = rus .fit_sample (X , y )
96
118
LOGGER .info ('Make the dataset imbalanced: %s' , Counter (y_resampled ))
97
119
98
120
return X_resampled , y_resampled
0 commit comments