Commit 1e3fdfe

[MRG] EHN make_imbalance handle multi-class (#312)

* EHN make_imbalance handle multi-class
* TST make the test for make_imbalance
1 parent f695fb7 commit 1e3fdfe

8 files changed: +161 -163 lines changed

doc/whats_new.rst

Lines changed: 22 additions & 1 deletion
@@ -12,48 +12,69 @@ Changelog
 Bug fixes
 ---------

+- Fixed a bug in :func:`utils.check_ratio` such that an error is raised when
+  the number of samples required is negative. By `Guillaume Lemaitre`_.
+
 - Fixed a bug in :class:`under_sampling.NearMiss` version 3. The
   indices returned were wrong. By `Guillaume Lemaitre`_.
-- fixed bug for :class:`ensemble.BalanceCascade` and :class:`combine.SMOTEENN`
+
+- Fixed bug for :class:`ensemble.BalanceCascade` and :class:`combine.SMOTEENN`
   and :class:`SMOTETomek. By `Guillaume Lemaitre`_.`

 New features
 ~~~~~~~~~~~~

 - Turn off steps in :class:`pipeline.Pipeline` using the `None`
   object. By `Christos Aridas`_.
+
 - Add a fetching function `datasets.fetch_datasets` in order to get some
   imbalanced datasets useful for benchmarking. By `Guillaume Lemaitre`_.

 Enhancement
 ~~~~~~~~~~~

+- :func:`datasets.make_imbalance` take a ratio similarly to other samplers. It
+  supports multiclass. By `Guillaume Lemaitre`_.
+
 - All the unit tests have been factorized and a `check_estimators` has
   been derived from scikit-learn. By `Guillaume Lemaitre`_.
+
 - Script for automatic build of conda packages and uploading. By
   `Guillaume Lemaitre`_
+
 - Remove seaborn dependence and improve the examples. By `Guillaume
   Lemaitre`_.
+
 - adapt all classes to multi-class resampling. By `Guillaume Lemaitre`_

 API changes summary
 ~~~~~~~~~~~~~~~~~~~

 - `__init__` has been removed from the :class:`base.SamplerMixin` to
   create a real mixin class. By `Guillaume Lemaitre`_.
+
 - creation of a module `exceptions` to handle consistant raising of
   errors. By `Guillaume Lemaitre`_.
+
 - creation of a module `utils.validation` to make checking of
   recurrent patterns. By `Guillaume Lemaitre`_.
+
 - move the under-sampling methods in `prototype_selection` and
   `prototype_generation` submodule to make a clearer dinstinction. By
   `Guillaume Lemaitre`_.
+
 - change `ratio` such that it can adapt to multiple class problems. By
   `Guillaume Lemaitre`_.

 Deprecation
 ~~~~~~~~~~~

+- Deprecation of the use of ``min_c_`` in :func:`datasets.make_imbalance`. By
+  `Guillaume Lemaitre`_
+
+- Deprecation of the use of float in :func:`datasets.make_imbalance` for the
+  ratio parameter. By `Guillaume Lemaitre`_.
+
 - deprecate the use of float as ratio in favor of dictionary, string, or
   callable. By `Guillaume Lemaitre`_.


imblearn/base.py

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@

 from sklearn.base import BaseEstimator
 from sklearn.externals import six
-from sklearn.utils import check_X_y, check_random_state
+from sklearn.utils import check_X_y
 from sklearn.utils.validation import check_is_fitted

 from .utils import check_ratio, check_target_type, hash_X_y

imblearn/datasets/imbalance.py

Lines changed: 67 additions & 45 deletions
@@ -7,10 +7,14 @@
 # License: MIT

 import logging
+import warnings
 from collections import Counter
+from numbers import Real

-import numpy as np
-from sklearn.utils import check_random_state, check_X_y
+from sklearn.utils import check_X_y
+
+from ..under_sampling.prototype_selection import RandomUnderSampler
+from ..utils import check_ratio

 LOGGER = logging.getLogger(__name__)

@@ -28,14 +32,23 @@ def make_imbalance(X, y, ratio, min_c_=None, random_state=None):
     y : ndarray, shape (n_samples, )
         Corresponding label for each sample in X.

-    ratio : float,
-        The desired ratio given by the number of samples in
-        the minority class over the the number of samples in
-        the majority class. Thus the ratio should be in the interval [0., 1.]
+    ratio : str, dict, or callable, optional (default='auto')
+        Ratio to use for resampling the data set.
+
+        - If ``dict``, the keys correspond to the targeted classes. The values
+          correspond to the desired number of samples.
+        - If callable, function taking ``y`` and returns a ``dict``. The keys
+          correspond to the targeted classes. The values correspond to the
+          desired number of samples.

     min_c_ : str or int, optional (default=None)
         The identifier of the class to be the minority class.
-        If None, min_c_ is set to be the current minority class.
+        If ``None``, ``min_c_`` is set to be the current minority class.
+        Only used when ``ratio`` is a float for back-compatibility.
+
+        .. deprecated:: 0.2
+           ``min_c_`` is deprecated in 0.2 and will be removed in 0.4. Use
+           ``ratio`` by passing a ``dict`` instead.

     random_state : int, RandomState instance or None, optional (default=None)
         If int, random_state is the seed used by the random number generator;
@@ -51,48 +64,57 @@ def make_imbalance(X, y, ratio, min_c_=None, random_state=None):
     y_resampled : ndarray, shape (n_samples_new)
         The corresponding label of `X_resampled`

-    """
-    if isinstance(ratio, float):
-        if ratio > 1:
-            raise ValueError('Ratio cannot be greater than one.'
-                             ' Got {}.'.format(ratio))
-        elif ratio <= 0:
-            raise ValueError('Ratio have to be strictly positive.'
-                             ' Got {}.'.format(ratio))
-    else:
-        raise ValueError('Ratio must be a float between 0.0 < ratio < 1.0'
-                         ' Got {} instead.'.format(ratio))
+    Examples
+    --------
+    >>> from collections import Counter
+    >>> from sklearn.datasets import load_iris
+    >>> from imblearn.datasets import make_imbalance
+
+    >>> data = load_iris()
+    >>> X, y = data.data, data.target
+    >>> print('Distribution before imbalancing: {}'.format(Counter(y)))
+    Distribution before imbalancing: Counter({0: 50, 1: 50, 2: 50})
+    >>> X_res, y_res = make_imbalance(X, y, ratio={0: 10, 1: 20, 2: 30},
+    ...                               random_state=42)
+    >>> print('Distribution after imbalancing: {}'.format(Counter(y_res)))
+    Distribution after imbalancing: Counter({2: 30, 1: 20, 0: 10})

+    """
     X, y = check_X_y(X, y)
-
-    random_state = check_random_state(random_state)
-
-    stats_c_ = Counter(y)
+    target_stats = Counter(y)
+    # restrict ratio to be a dict or a callable
+    if isinstance(ratio, dict) or callable(ratio):
+        ratio_ = check_ratio(ratio, y, 'under-sampling')
+    # FIXME: deprecated in 0.2 to be removed in 0.4
+    elif isinstance(ratio, Real):
+        if min_c_ is None:
+            min_c_ = min(target_stats, key=target_stats.get)
+        else:
+            warnings.warn("'min_c_' is deprecated in 0.2 and will be removed"
+                          " in 0.4. Use 'ratio' as dictionary instead.",
+                          DeprecationWarning)
+        warnings.warn("'ratio' being a float is deprecated in 0.2 and will not"
+                      " be supported in 0.4. Use a dictionary instead.",
+                      DeprecationWarning)
+        class_majority = max(target_stats, key=target_stats.get)
+        ratio_ = {}
+        for class_sample, n_sample in target_stats.items():
+            if class_sample == min_c_:
+                n_min_samples = int(target_stats[class_majority] * ratio)
+                ratio_[class_sample] = n_min_samples
+            else:
+                ratio_[class_sample] = n_sample
+        ratio_ = check_ratio(ratio_, y, 'under-sampling')
+    else:
+        raise ValueError("'ratio' has to be a dictionary or a function"
+                         " returning a dictionary. Got {} instead.".format(
+                             type(ratio)))

     LOGGER.info('The original target distribution in the dataset is: %s',
-                stats_c_)
-
-    if min_c_ is None:
-        min_c_ = min(stats_c_, key=stats_c_.get)
-
-    n_min_samples = int(np.count_nonzero(y != min_c_) * ratio)
-    if n_min_samples > stats_c_[min_c_]:
-        raise ValueError('Current imbalance ratio of data is lower than'
-                         ' desired ratio! Got {} > {}.'.format(
-                             n_min_samples, stats_c_[min_c_]))
-    if n_min_samples == 0:
-        raise ValueError('Not enough samples for desired ratio!'
-                         ' Got {}.'.format(n_min_samples))
-
-    mask = y == min_c_
-
-    idx_maj = np.where(~mask)[0]
-    idx_min = np.where(mask)[0]
-    idx_min = random_state.choice(idx_min, size=n_min_samples, replace=False)
-    idx = np.concatenate((idx_min, idx_maj), axis=0)
-
-    X_resampled, y_resampled = X[idx, :], y[idx]
-
+                target_stats)
+    rus = RandomUnderSampler(ratio=ratio_, replacement=False,
+                             random_state=random_state)
+    X_resampled, y_resampled = rus.fit_sample(X, y)
     LOGGER.info('Make the dataset imbalanced: %s', Counter(y_resampled))

     return X_resampled, y_resampled
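
The back-compatibility branch in this file turns a float `ratio` into a per-class target dictionary before passing it to `check_ratio`. Below is a minimal standalone sketch of that conversion using only the standard library. The helper name `float_ratio_to_dict` is hypothetical (it is not part of imblearn), and the deprecation warnings and `check_ratio` validation are left out.

```python
from collections import Counter


def float_ratio_to_dict(y, ratio, min_c=None):
    """Hypothetical helper mirroring the float branch of make_imbalance."""
    target_stats = Counter(y)
    if min_c is None:
        # Default to the least populated class (first one found on ties).
        min_c = min(target_stats, key=target_stats.get)
    class_majority = max(target_stats, key=target_stats.get)
    targets = {}
    for class_sample, n_sample in target_stats.items():
        if class_sample == min_c:
            # The minority target is a fraction of the majority class count.
            targets[class_sample] = int(target_stats[class_majority] * ratio)
        else:
            # Every other class keeps its original number of samples.
            targets[class_sample] = n_sample
    return targets


y = [0] * 50 + [1] * 50 + [2] * 50
print(float_ratio_to_dict(y, 0.5, min_c=1))  # {0: 50, 1: 25, 2: 50}
```

Note that the target is computed from the majority class count alone, whereas the removed code used the count of all non-minority samples (`np.count_nonzero(y != min_c_)`), so the two give different results as soon as there are more than two classes.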
Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
+"""Test the module easy ensemble."""
+# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
+#          Christos Aridas
+# License: MIT
+
+
+from __future__ import print_function
+
+from collections import Counter
+
+import numpy as np
+
+from sklearn.datasets import load_iris
+from sklearn.utils.testing import (assert_equal, assert_raises_regex,
+                                   assert_warns_message)
+
+from imblearn.datasets import make_imbalance
+
+data = load_iris()
+X, Y = data.data, data.target
+
+
+def test_make_imbalance_error():
+    # we are reusing part of utils.check_ratio, however this is not cover in
+    # the common tests so we will repeat it here
+    ratio = {0: -100, 1: 50, 2: 50}
+    assert_raises_regex(ValueError, "in a class cannot be negative",
+                        make_imbalance, X, Y, ratio)
+    ratio = {0: 10, 1: 70}
+    assert_raises_regex(ValueError, "should be less or equal to the original",
+                        make_imbalance, X, Y, ratio)
+    y_ = np.zeros((X.shape[0], ))
+    ratio = {0: 10}
+    assert_raises_regex(ValueError, "needs to have more than 1 class.",
+                        make_imbalance, X, y_, ratio)
+    ratio = 'random-string'
+    assert_raises_regex(ValueError, "has to be a dictionary or a function",
+                        make_imbalance, X, Y, ratio)
+
+
+# FIXME: to be removed in 0.4 due to deprecation
+def test_make_imbalance_float():
+    X_, y_ = assert_warns_message(DeprecationWarning,
+                                  "'min_c_' is deprecated in 0.2",
+                                  make_imbalance, X, Y, ratio=0.5, min_c_=1)
+    X_, y_ = assert_warns_message(DeprecationWarning,
+                                  "'ratio' being a float is deprecated",
+                                  make_imbalance, X, Y, ratio=0.5, min_c_=1)
+    assert_equal(Counter(y_), {0: 50, 1: 25, 2: 50})
+    # resample without using min_c_
+    X_, y_ = make_imbalance(X_, y_, ratio=0.25, min_c_=None)
+    assert_equal(Counter(y_), {0: 50, 1: 12, 2: 50})
+
+
+def test_make_imbalance_dict():
+    ratio = {0: 10, 1: 20, 2: 30}
+    X_, y_ = make_imbalance(X, Y, ratio=ratio)
+    assert_equal(Counter(y_), ratio)
+
+    ratio = {0: 10, 1: 20}
+    X_, y_ = make_imbalance(X, Y, ratio=ratio)
+    assert_equal(Counter(y_), {0: 10, 1: 20, 2: 50})
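
The dictionary tests above expect that classes listed in `ratio` are cut down to the requested counts while unlisted classes are left untouched. A rough standard-library sketch of that behaviour follows. It is an illustration only, not imblearn's `RandomUnderSampler`: the function name `undersample_indices` is hypothetical and the error messages are placeholders rather than the library's wording.

```python
import random
from collections import Counter


def undersample_indices(y, targets, seed=None):
    """Hypothetical sketch: choose indices so that each class in `targets`
    keeps exactly the requested number of samples, without replacement."""
    counts = Counter(y)
    rng = random.Random(seed)
    kept = []
    for label, n_available in counts.items():
        # Classes missing from `targets` keep all of their samples.
        n_wanted = targets.get(label, n_available)
        if n_wanted < 0:
            raise ValueError("requested sample count cannot be negative")
        if n_wanted > n_available:
            raise ValueError("cannot request more samples than the class has")
        candidates = [i for i, value in enumerate(y) if value == label]
        kept.extend(rng.sample(candidates, n_wanted))
    return sorted(kept)


y = [0] * 50 + [1] * 50 + [2] * 50
idx = undersample_indices(y, {0: 10, 1: 20}, seed=42)
resampled = Counter(y[i] for i in idx)
print(resampled)
```

With the iris-like label counts used in the tests, the resulting distribution has 10, 20, and 50 samples for classes 0, 1, and 2 respectively, matching the second assertion in `test_make_imbalance_dict`.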
