Skip to content

Commit aa7fbdd

Browse files
hgasconglemaitre
authored andcommitted
EHN: Add option to return indices in RandomOverSampler (#439)
1 parent 7c48491 commit aa7fbdd

File tree

3 files changed

+39
-5
lines changed

3 files changed

+39
-5
lines changed

doc/whats_new/v0.0.4.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,11 @@ Enhancement
4040
:class:`BorderlineSMOTE` and :class:`SVMSMOTE`.
4141
:issue:`440` by :user:`Guillaume Lemaitre <glemaitre>`.
4242

43+
- Allow :class:`imblearn.over_sampling.RandomOverSampler` can return indices
44+
using the attributes ``return_indices``.
45+
:issue:`439` by :user:`Hugo Gascon<hgascon>` and
46+
:user:`Guillaume Lemaitre <glemaitre>`.
47+
4348
Bug fixes
4449
.........
4550

imblearn/over_sampling/random_over_sampler.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ class RandomOverSampler(BaseOverSampler):
3232
3333
{random_state}
3434
35+
return_indices : bool, optional (default=False)
36+
Whether or not to return the indices of the samples randomly selected
37+
in the corresponding classes.
38+
3539
ratio : str, dict, or callable
3640
.. deprecated:: 0.4
3741
Use the parameter ``sampling_strategy`` instead. It will be removed
@@ -66,10 +70,13 @@ class RandomOverSampler(BaseOverSampler):
6670
6771
"""
6872

69-
def __init__(self, sampling_strategy='auto', random_state=None,
73+
def __init__(self, sampling_strategy='auto',
74+
return_indices=False,
75+
random_state=None,
7076
ratio=None):
7177
super(RandomOverSampler, self).__init__(
7278
sampling_strategy=sampling_strategy, ratio=ratio)
79+
self.return_indices = return_indices
7380
self.random_state = random_state
7481

7582
def _sample(self, X, y):
@@ -106,5 +113,9 @@ def _sample(self, X, y):
106113
sample_indices = np.append(sample_indices,
107114
target_class_indices[indices])
108115

109-
return (safe_indexing(X, sample_indices), safe_indexing(
110-
y, sample_indices))
116+
if self.return_indices:
117+
return (safe_indexing(X, sample_indices), safe_indexing(
118+
y, sample_indices), sample_indices)
119+
else:
120+
return (safe_indexing(X, sample_indices), safe_indexing(
121+
y, sample_indices))

imblearn/over_sampling/tests/test_random_over_sampler.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from collections import Counter
99

1010
import numpy as np
11+
from sklearn.utils.testing import assert_allclose
1112
from sklearn.utils.testing import assert_array_equal
1213

1314
from imblearn.over_sampling import RandomOverSampler
@@ -40,7 +41,7 @@ def test_ros_fit_sample():
4041
[0.92923648, 0.76103773], [0.47104475, 0.44386323],
4142
[0.92923648, 0.76103773], [0.47104475, 0.44386323]])
4243
y_gt = np.array([1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0])
43-
assert_array_equal(X_resampled, X_gt)
44+
assert_allclose(X_resampled, X_gt)
4445
assert_array_equal(y_resampled, y_gt)
4546

4647

@@ -56,10 +57,27 @@ def test_ros_fit_sample_half():
5657
[0.09125309, -0.85409574], [0.12372842, 0.6536186],
5758
[0.13347175, 0.12167502], [0.094035, -2.55298982]])
5859
y_gt = np.array([1, 0, 1, 0, 1, 1, 1, 1, 0, 1])
59-
assert_array_equal(X_resampled, X_gt)
60+
assert_allclose(X_resampled, X_gt)
6061
assert_array_equal(y_resampled, y_gt)
6162

6263

64+
def test_random_over_sampling_return_indices():
65+
ros = RandomOverSampler(return_indices=True, random_state=RND_SEED)
66+
X_resampled, y_resampled, sample_indices = ros.fit_sample(X, Y)
67+
X_gt = np.array([[0.04352327, -0.20515826], [0.92923648, 0.76103773], [
68+
0.20792588, 1.49407907
69+
], [0.47104475, 0.44386323], [0.22950086, 0.33367433], [
70+
0.15490546, 0.3130677
71+
], [0.09125309, -0.85409574], [0.12372842, 0.6536186],
72+
[0.13347175, 0.12167502], [0.094035, -2.55298982],
73+
[0.92923648, 0.76103773], [0.47104475, 0.44386323],
74+
[0.92923648, 0.76103773], [0.47104475, 0.44386323]])
75+
y_gt = np.array([1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0])
76+
assert_allclose(X_resampled, X_gt)
77+
assert_array_equal(y_resampled, y_gt)
78+
assert_array_equal(np.sort(np.unique(sample_indices)), np.arange(len(X)))
79+
80+
6381
def test_multiclass_fit_sample():
6482
y = Y.copy()
6583
y[5] = 2

0 commit comments

Comments
 (0)