diff --git a/doc/whats_new/v0.0.4.rst b/doc/whats_new/v0.0.4.rst index 8cd326d2f..83529b246 100644 --- a/doc/whats_new/v0.0.4.rst +++ b/doc/whats_new/v0.0.4.rst @@ -40,6 +40,11 @@ Enhancement :class:`BorderlineSMOTE` and :class:`SVMSMOTE`. :issue:`440` by :user:`Guillaume Lemaitre <glemaitre>`. +- Allow :class:`imblearn.over_sampling.RandomOverSampler` to return indices + using the parameter ``return_indices``. + :issue:`439` by :user:`Hugo Gascon` and + :user:`Guillaume Lemaitre <glemaitre>`. + Bug fixes ......... diff --git a/imblearn/over_sampling/random_over_sampler.py b/imblearn/over_sampling/random_over_sampler.py index e7726047d..35181e387 100644 --- a/imblearn/over_sampling/random_over_sampler.py +++ b/imblearn/over_sampling/random_over_sampler.py @@ -32,6 +32,10 @@ class RandomOverSampler(BaseOverSampler): {random_state} + return_indices : bool, optional (default=False) + Whether or not to return the indices of the samples randomly selected + in the corresponding classes. + ratio : str, dict, or callable .. deprecated:: 0.4 Use the parameter ``sampling_strategy`` instead. 
It will be removed @@ -66,10 +70,13 @@ class RandomOverSampler(BaseOverSampler): """ - def __init__(self, sampling_strategy='auto', random_state=None, + def __init__(self, sampling_strategy='auto', + return_indices=False, + random_state=None, ratio=None): super(RandomOverSampler, self).__init__( sampling_strategy=sampling_strategy, ratio=ratio) + self.return_indices = return_indices self.random_state = random_state def _sample(self, X, y): @@ -106,5 +113,9 @@ def _sample(self, X, y): sample_indices = np.append(sample_indices, target_class_indices[indices]) - return (safe_indexing(X, sample_indices), safe_indexing( - y, sample_indices)) + if self.return_indices: + return (safe_indexing(X, sample_indices), safe_indexing( + y, sample_indices), sample_indices) + else: + return (safe_indexing(X, sample_indices), safe_indexing( + y, sample_indices)) diff --git a/imblearn/over_sampling/tests/test_random_over_sampler.py b/imblearn/over_sampling/tests/test_random_over_sampler.py index 13d0067c8..6b7ed686c 100644 --- a/imblearn/over_sampling/tests/test_random_over_sampler.py +++ b/imblearn/over_sampling/tests/test_random_over_sampler.py @@ -8,6 +8,7 @@ from collections import Counter import numpy as np +from sklearn.utils.testing import assert_allclose from sklearn.utils.testing import assert_array_equal from imblearn.over_sampling import RandomOverSampler @@ -40,7 +41,7 @@ def test_ros_fit_sample(): [0.92923648, 0.76103773], [0.47104475, 0.44386323], [0.92923648, 0.76103773], [0.47104475, 0.44386323]]) y_gt = np.array([1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0]) - assert_array_equal(X_resampled, X_gt) + assert_allclose(X_resampled, X_gt) assert_array_equal(y_resampled, y_gt) @@ -56,10 +57,27 @@ def test_ros_fit_sample_half(): [0.09125309, -0.85409574], [0.12372842, 0.6536186], [0.13347175, 0.12167502], [0.094035, -2.55298982]]) y_gt = np.array([1, 0, 1, 0, 1, 1, 1, 1, 0, 1]) - assert_array_equal(X_resampled, X_gt) + assert_allclose(X_resampled, X_gt) 
assert_array_equal(y_resampled, y_gt) +def test_random_over_sampling_return_indices(): + ros = RandomOverSampler(return_indices=True, random_state=RND_SEED) + X_resampled, y_resampled, sample_indices = ros.fit_sample(X, Y) + X_gt = np.array([[0.04352327, -0.20515826], [0.92923648, 0.76103773], [ + 0.20792588, 1.49407907 + ], [0.47104475, 0.44386323], [0.22950086, 0.33367433], [ + 0.15490546, 0.3130677 + ], [0.09125309, -0.85409574], [0.12372842, 0.6536186], + [0.13347175, 0.12167502], [0.094035, -2.55298982], + [0.92923648, 0.76103773], [0.47104475, 0.44386323], + [0.92923648, 0.76103773], [0.47104475, 0.44386323]]) + y_gt = np.array([1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0]) + assert_allclose(X_resampled, X_gt) + assert_array_equal(y_resampled, y_gt) + assert_array_equal(np.sort(np.unique(sample_indices)), np.arange(len(X))) + + def test_multiclass_fit_sample(): y = Y.copy() y[5] = 2