Skip to content

Commit 6916fe9

Browse files
authored
EHN: random sampler can sample from heterogeneous data (#451)
1 parent 41cd9a6 commit 6916fe9

File tree

14 files changed

+156
-39
lines changed

14 files changed

+156
-39
lines changed

doc/over_sampling.rst

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,22 @@ As a result, the majority class does not take over the other classes during the
5252
training process. Consequently, all classes are represented by the decision
5353
function.
5454

55+
In addition, :class:`RandomOverSampler` allows sampling heterogeneous data
56+
(e.g. containing some strings)::
57+
58+
>>> import numpy as np
59+
>>> X_hetero = np.array([['xxx', 1, 1.0], ['yyy', 2, 2.0], ['zzz', 3, 3.0]],
60+
... dtype=np.object)
61+
>>> y_hetero = np.array([0, 0, 1])
62+
>>> X_resampled, y_resampled = ros.fit_sample(X_hetero, y_hetero)
63+
>>> print(X_resampled)
64+
[['xxx' 1 1.0]
65+
['yyy' 2 2.0]
66+
['zzz' 3 3.0]
67+
['zzz' 3 3.0]]
68+
>>> print(y_resampled)
69+
[0 0 1 1]
70+
5571
See :ref:`sphx_glr_auto_examples_over-sampling_plot_random_over_sampling.py`
5672
for usage example.
5773

doc/under_sampling.rst

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,19 @@ by considering independently each targeted class::
103103
>>> print(np.vstack({tuple(row) for row in X_resampled}).shape)
104104
(181, 2)
105105

106+
In addition, :class:`RandomUnderSampler` allows sampling heterogeneous data
107+
(e.g. containing some strings)::
108+
109+
>>> X_hetero = np.array([['xxx', 1, 1.0], ['yyy', 2, 2.0], ['zzz', 3, 3.0]],
110+
... dtype=np.object)
111+
>>> y_hetero = np.array([0, 0, 1])
112+
>>> X_resampled, y_resampled = rus.fit_sample(X_hetero, y_hetero)
113+
>>> print(X_resampled)
114+
[['xxx' 1 1.0]
115+
['zzz' 3 3.0]]
116+
>>> print(y_resampled)
117+
[0 1]
118+
106119
See :ref:`sphx_glr_auto_examples_plot_sampling_strategy_usage.py`,
107120
:ref:`sphx_glr_auto_examples_under-sampling_plot_comparison_under_sampling.py`,
108121
and :ref:`sphx_glr_auto_examples_under-sampling_plot_random_under_sampler.py`.

doc/whats_new/v0.0.4.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,11 @@ Enhancement
4545
:issue:`439` by :user:`Hugo Gascon<hgascon>` and
4646
:user:`Guillaume Lemaitre <glemaitre>`.
4747

48+
- Allow :class:`imblearn.under_sampling.RandomUnderSampler` and
49+
:class:`imblearn.over_sampling.RandomOverSampler` to sample object arrays
50+
containing strings.
51+
:issue:`448` by :user:`Guillaume Lemaitre <glemaitre>`.
52+
4853
Bug fixes
4954
.........
5055

imblearn/base.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,6 @@ class SamplerMixin(six.with_metaclass(ABCMeta, BaseEstimator)):
3131

3232
_estimator_type = 'sampler'
3333

34-
def _check_X_y(self, X, y):
35-
"""Private function to check that the X and y in fitting are the same
36-
than in sampling."""
37-
X_hash, y_hash = hash_X_y(X, y)
38-
if self.X_hash_ != X_hash or self.y_hash_ != y_hash:
39-
raise RuntimeError("X and y need to be same array earlier fitted.")
40-
4134
def sample(self, X, y):
4235
"""Resample the dataset.
4336
@@ -60,11 +53,10 @@ def sample(self, X, y):
6053
6154
"""
6255
# Check the consistency of X and y
63-
y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
64-
X, y = check_X_y(X, y, accept_sparse=['csr', 'csc'])
56+
X, y, binarize_y = self._check_X_y(X, y)
6557

6658
check_is_fitted(self, 'sampling_strategy_')
67-
self._check_X_y(X, y)
59+
self._check_X_y_hash(X, y)
6860

6961
output = self._sample(X, y)
7062

@@ -151,6 +143,19 @@ def __init__(self, sampling_strategy='auto', ratio=None):
151143
self.ratio = ratio
152144
self.logger = logging.getLogger(self.__module__)
153145

146+
@staticmethod
147+
def _check_X_y(X, y):
148+
y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
149+
X, y = check_X_y(X, y, accept_sparse=['csr', 'csc'])
150+
return X, y, binarize_y
151+
152+
def _check_X_y_hash(self, X, y):
153+
"""Private function to check that the X and y in fitting are the same
154+
as in sampling."""
155+
X_hash, y_hash = hash_X_y(X, y)
156+
if self.X_hash_ != X_hash or self.y_hash_ != y_hash:
157+
raise RuntimeError("X and y need to be same array earlier fitted.")
158+
154159
@property
155160
def ratio_(self):
156161
# FIXME: remove in 0.6
@@ -183,9 +188,9 @@ def fit(self, X, y):
183188
184189
"""
185190
self._deprecate_ratio()
186-
y = check_target_type(y)
187-
X, y = check_X_y(X, y, accept_sparse=['csr', 'csc'])
191+
X, y, _ = self._check_X_y(X, y)
188192
self.X_hash_, self.y_hash_ = hash_X_y(X, y)
193+
# _sampling_type is defined in the children base class
189194
self.sampling_strategy_ = check_sampling_strategy(
190195
self.sampling_strategy, y, self._sampling_type)
191196

imblearn/combine/smote_enn.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from sklearn.base import clone
1313
from sklearn.utils import check_X_y
1414

15-
from ..base import SamplerMixin
15+
from ..base import BaseSampler
1616
from ..over_sampling import SMOTE
1717
from ..over_sampling.base import BaseOverSampler
1818
from ..under_sampling import EditedNearestNeighbours
@@ -24,7 +24,7 @@
2424
@Substitution(
2525
sampling_strategy=BaseOverSampler._sampling_strategy_docstring,
2626
random_state=_random_state_docstring)
27-
class SMOTEENN(SamplerMixin):
27+
class SMOTEENN(BaseSampler):
2828
"""Class to perform over-sampling using SMOTE and cleaning using ENN.
2929
3030
Combine over- and under-sampling using SMOTE and Edited Nearest Neighbours.
@@ -125,14 +125,6 @@ def _validate_estimator(self):
125125
else:
126126
self.enn_ = EditedNearestNeighbours(sampling_strategy='all')
127127

128-
@property
129-
def ratio_(self):
130-
# FIXME: remove in 0.6
131-
warnings.warn("'ratio' and 'ratio_' are deprecated. Use "
132-
"'sampling_strategy' and 'sampling_strategy_' instead.",
133-
DeprecationWarning)
134-
return self.sampling_strategy_
135-
136128
def fit(self, X, y):
137129
"""Find the class statistics before performing sampling.
138130

imblearn/combine/smote_tomek.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from sklearn.base import clone
1414
from sklearn.utils import check_X_y
1515

16-
from ..base import SamplerMixin
16+
from ..base import BaseSampler
1717
from ..over_sampling import SMOTE
1818
from ..over_sampling.base import BaseOverSampler
1919
from ..under_sampling import TomekLinks
@@ -25,7 +25,7 @@
2525
@Substitution(
2626
sampling_strategy=BaseOverSampler._sampling_strategy_docstring,
2727
random_state=_random_state_docstring)
28-
class SMOTETomek(SamplerMixin):
28+
class SMOTETomek(BaseSampler):
2929
"""Class to perform over-sampling using SMOTE and cleaning using
3030
Tomek links.
3131
@@ -133,14 +133,6 @@ def _validate_estimator(self):
133133
else:
134134
self.tomek_ = TomekLinks(sampling_strategy='all')
135135

136-
@property
137-
def ratio_(self):
138-
# FIXME: remove in 0.6
139-
warnings.warn("'ratio' and 'ratio_' are deprecated. Use "
140-
"'sampling_strategy' and 'sampling_strategy_' instead.",
141-
DeprecationWarning)
142-
return self.sampling_strategy_
143-
144136
def fit(self, X, y):
145137
"""Find the class statistics before performing sampling.
146138

imblearn/ensemble/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def sample(self, X, y):
6060
X, y = check_X_y(X, y, accept_sparse=['csr', 'csc'])
6161

6262
check_is_fitted(self, 'sampling_strategy_')
63-
self._check_X_y(X, y)
63+
self._check_X_y_hash(X, y)
6464

6565
output = self._sample(X, y)
6666

imblearn/over_sampling/random_over_sampler.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,10 @@
88
from collections import Counter
99

1010
import numpy as np
11-
from sklearn.utils import check_random_state, safe_indexing
11+
from sklearn.utils import check_X_y, check_random_state, safe_indexing
1212

1313
from .base import BaseOverSampler
14+
from ..utils import check_target_type
1415
from ..utils import Substitution
1516
from ..utils._docstring import _random_state_docstring
1617

@@ -44,6 +45,8 @@ class RandomOverSampler(BaseOverSampler):
4445
Notes
4546
-----
4647
Supports multi-class resampling by sampling each class independently.
48+
Supports heterogeneous data as object arrays containing string and numeric
49+
data.
4750
4851
See
4952
:ref:`sphx_glr_auto_examples_over-sampling_plot_comparison_over_sampling.py`,
@@ -79,6 +82,12 @@ def __init__(self, sampling_strategy='auto',
7982
self.return_indices = return_indices
8083
self.random_state = random_state
8184

85+
@staticmethod
86+
def _check_X_y(X, y):
87+
y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
88+
X, y = check_X_y(X, y, accept_sparse=['csr', 'csc'], dtype=None)
89+
return X, y, binarize_y
90+
8291
def _sample(self, X, y):
8392
"""Resample the dataset.
8493

imblearn/over_sampling/tests/test_random_over_sampler.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,3 +88,16 @@ def test_multiclass_fit_sample():
8888
assert count_y_res[0] == 5
8989
assert count_y_res[1] == 5
9090
assert count_y_res[2] == 5
91+
92+
93+
def test_random_over_sampling_heterogeneous_data():
94+
X_hetero = np.array([['xxx', 1, 1.0], ['yyy', 2, 2.0], ['zzz', 3, 3.0]],
95+
dtype=np.object)
96+
y = np.array([0, 0, 1])
97+
ros = RandomOverSampler(random_state=RND_SEED)
98+
X_res, y_res = ros.fit_sample(X_hetero, y)
99+
100+
assert X_res.shape[0] == 4
101+
assert y_res.shape[0] == 4
102+
assert X_res.dtype == object
103+
assert X_res[-1, 0] in X_hetero[:, 0]

imblearn/under_sampling/prototype_selection/random_under_sampler.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@
77
from __future__ import division
88

99
import numpy as np
10-
from sklearn.utils import check_random_state, safe_indexing
10+
11+
from sklearn.utils import check_X_y, check_random_state, safe_indexing
1112

1213
from ..base import BaseUnderSampler
14+
from ...utils import check_target_type
1315
from ...utils import Substitution
1416
from ...utils._docstring import _random_state_docstring
1517

@@ -46,6 +48,8 @@ class RandomUnderSampler(BaseUnderSampler):
4648
Notes
4749
-----
4850
Supports multi-class resampling by sampling each class independently.
51+
Supports heterogeneous data as object arrays containing string and numeric
52+
data.
4953
5054
See
5155
:ref:`sphx_glr_auto_examples_plot_sampling_strategy_usage.py` and
@@ -82,6 +86,12 @@ def __init__(self,
8286
self.return_indices = return_indices
8387
self.replacement = replacement
8488

89+
@staticmethod
90+
def _check_X_y(X, y):
91+
y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
92+
X, y = check_X_y(X, y, accept_sparse=['csr', 'csc'], dtype=None)
93+
return X, y, binarize_y
94+
8595
def _sample(self, X, y):
8696
"""Resample the dataset.
8797

imblearn/under_sampling/prototype_selection/tests/test_random_under_sampler.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,6 @@ def test_rus_fit_sample_half():
6363
[0.15490546, 0.3130677], [0.20792588, 1.49407907],
6464
[0.15490546, 0.3130677], [0.12372842, 0.6536186]])
6565
y_gt = np.array([0, 0, 0, 1, 1, 1, 1, 1, 1])
66-
print(X_resampled)
6766
assert_array_equal(X_resampled, X_gt)
6867
assert_array_equal(y_resampled, y_gt)
6968

@@ -78,3 +77,15 @@ def test_multiclass_fit_sample():
7877
assert count_y_res[0] == 2
7978
assert count_y_res[1] == 2
8079
assert count_y_res[2] == 2
80+
81+
82+
def test_random_under_sampling_heterogeneous_data():
83+
X_hetero = np.array([['xxx', 1, 1.0], ['yyy', 2, 2.0], ['zzz', 3, 3.0]],
84+
dtype=np.object)
85+
y = np.array([0, 0, 1])
86+
rus = RandomUnderSampler(random_state=RND_SEED)
87+
X_res, y_res = rus.fit_sample(X_hetero, y)
88+
89+
assert X_res.shape[0] == 2
90+
assert y_res.shape[0] == 2
91+
assert X_res.dtype == object

imblearn/utils/estimator_checks.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,15 @@
1616
import numpy as np
1717
from scipy import sparse
1818

19+
from sklearn.base import clone
1920
from sklearn.datasets import make_classification
2021
from sklearn.cluster import KMeans
2122
from sklearn.preprocessing import label_binarize
2223
from sklearn.utils.estimator_checks import check_estimator \
2324
as sklearn_check_estimator, check_parameters_default_constructible
2425
from sklearn.exceptions import NotFittedError
2526
from sklearn.utils.testing import assert_allclose
27+
from sklearn.utils.testing import assert_raises_regex
2628
from sklearn.utils.testing import set_random_state
2729
from sklearn.utils.multiclass import type_of_target
2830

@@ -35,6 +37,32 @@
3537
from imblearn.utils.testing import warns
3638

3739
DONT_SUPPORT_RATIO = ['SVMSMOTE', 'BorderlineSMOTE']
40+
SUPPORT_STRING = ['RandomUnderSampler', 'RandomOverSampler']
41+
42+
43+
def monkey_patch_check_dtype_object(name, estimator_orig):
44+
# check that estimators treat dtype object as numeric if possible
45+
rng = np.random.RandomState(0)
46+
X = rng.rand(40, 10).astype(object)
47+
y = np.array([0] * 10 + [1] * 30, dtype=np.int)
48+
estimator = clone(estimator_orig)
49+
50+
estimator.fit(X, y)
51+
if hasattr(estimator, "sample"):
52+
estimator.sample(X, y)
53+
54+
try:
55+
estimator.fit(X, y.astype(object))
56+
except Exception as e:
57+
if "Unknown label type" not in str(e):
58+
raise
59+
60+
if name not in SUPPORT_STRING:
61+
X[0, 0] = {'foo': 'bar'}
62+
msg = "argument must be a string or a number"
63+
assert_raises_regex(TypeError, msg, estimator.fit, X, y)
64+
else:
65+
estimator.fit(X, y)
3866

3967

4068
def _yield_sampler_checks(name, Estimator):
@@ -74,7 +102,11 @@ def check_estimator(Estimator):
74102
Class to check. Estimator is a class object (not an instance).
75103
"""
76104
name = Estimator.__name__
77-
# test scikit-learn compatibility
105+
# monkey patch check_dtype_object for the sampler allowing strings
106+
import sklearn.utils.estimator_checks
107+
sklearn.utils.estimator_checks.check_dtype_object = \
108+
monkey_patch_check_dtype_object
109+
# scikit-learn common tests
78110
sklearn_check_estimator(Estimator)
79111
check_parameters_default_constructible(name, Estimator)
80112
for check in _yield_all_checks(name, Estimator):

imblearn/utils/tests/test_validation.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,20 @@ def test_hash_X_y():
375375
assert hash_X_y(X, y) == (joblib.hash(X), joblib.hash(y))
376376

377377

378+
def test_hash_X_y_pandas():
379+
pd = pytest.importorskip("pandas")
380+
rng = check_random_state(0)
381+
X = pd.DataFrame(rng.randn(2000, 20))
382+
y = pd.Series([0] * 500 + [1] * 1500)
383+
assert hash_X_y(X, y, 10, 10) == (joblib.hash(X.iloc[::200, ::2]),
384+
joblib.hash(y.iloc[::200]))
385+
386+
X = pd.DataFrame(rng.randn(5, 2))
387+
y = pd.Series([0] * 2 + [1] * 3)
388+
# all data will be used in this case
389+
assert hash_X_y(X, y) == (joblib.hash(X), joblib.hash(y))
390+
391+
378392
@pytest.mark.parametrize(
379393
"sampling_strategy, sampling_type, expected_result",
380394
[({3: 25, 1: 25, 2: 25}, 'under-sampling',

0 commit comments

Comments
 (0)