Skip to content

Commit 6622afb

Browse files
authored
FIX/DEPR follow literature for the implementation of NCR (#1012)
1 parent 95e21e1 commit 6622afb

File tree

5 files changed

+177
-107
lines changed

5 files changed

+177
-107
lines changed

doc/under_sampling.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -353,10 +353,10 @@ union of samples to be rejected between the :class:`EditedNearestNeighbours`
353353
and the output of a 3 nearest neighbors classifier. The class can be used as::
354354

355355
>>> from imblearn.under_sampling import NeighbourhoodCleaningRule
356-
>>> ncr = NeighbourhoodCleaningRule()
356+
>>> ncr = NeighbourhoodCleaningRule(n_neighbors=11)
357357
>>> X_resampled, y_resampled = ncr.fit_resample(X, y)
358358
>>> print(sorted(Counter(y_resampled).items()))
359-
[(0, 64), (1, 234), (2, 4666)]
359+
[(0, 64), (1, 193), (2, 4535)]
360360

361361
.. image:: ./auto_examples/under-sampling/images/sphx_glr_plot_comparison_under_sampling_005.png
362362
:target: ./auto_examples/under-sampling/plot_comparison_under_sampling.html

doc/whats_new/v0.12.rst

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,27 @@ Version 0.12.0 (Under development)
66
Changelog
77
---------
88

9+
Bug fixes
10+
.........
11+
12+
- Fix a bug in :class:`~imblearn.under_sampling.NeighbourhoodCleaningRule` where the
13+
`kind_sel="all"` was not working as explained in the literature.
14+
:pr:`1012` by :user:`Guillaume Lemaitre <glemaitre>`.
15+
16+
- Fix a bug in :class:`~imblearn.under_sampling.NeighbourhoodCleaningRule` where the
17+
`threshold_cleaning` ratio was multiplied by the total number of samples instead of
18+
the number of samples in the minority class.
19+
:pr:`1012` by :user:`Guillaume Lemaitre <glemaitre>`.
20+
921
Deprecations
1022
............
1123

1224
- Deprecate `estimator_` argument in favor of `estimators_` for the classes
1325
:class:`~imblearn.under_sampling.CondensedNearestNeighbour` and
1426
:class:`~imblearn.under_sampling.OneSidedSelection`. `estimator_` will be removed
1527
in 0.14.
16-
:pr:`xxx` by :user:`Guillaume Lemaitre <glemaitre>`.
28+
:pr:`1011` by :user:`Guillaume Lemaitre <glemaitre>`.
29+
30+
- Deprecate `kind_sel` in :class:`~imblearn.under_sampling.NeighbourhoodCleaningRule`.
31+
It will be removed in 0.14. The parameter does not have any effect.
32+
:pr:`1012` by :user:`Guillaume Lemaitre <glemaitre>`.

examples/under-sampling/plot_comparison_under_sampling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ def plot_decision_function(X, y, clf, ax, title=None):
264264
samplers = [
265265
CondensedNearestNeighbour(random_state=0),
266266
OneSidedSelection(random_state=0),
267-
NeighbourhoodCleaningRule(),
267+
NeighbourhoodCleaningRule(n_neighbors=11),
268268
]
269269

270270
for ax, sampler in zip(axs, samplers):

imblearn/under_sampling/_prototype_selection/_neighbourhood_cleaning_rule.py

Lines changed: 83 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,17 @@
55
# License: MIT
66

77
import numbers
8+
import warnings
89
from collections import Counter
910

1011
import numpy as np
12+
from sklearn.base import clone
13+
from sklearn.neighbors import KNeighborsClassifier, NearestNeighbors
1114
from sklearn.utils import _safe_indexing
1215

13-
from ...utils import Substitution, check_neighbors_object
16+
from ...utils import Substitution
1417
from ...utils._docstring import _n_jobs_docstring
15-
from ...utils._param_validation import HasMethods, Interval, StrOptions
16-
from ...utils.fixes import _mode
18+
from ...utils._param_validation import HasMethods, Hidden, Interval, StrOptions
1719
from ..base import BaseCleaningSampler
1820
from ._edited_nearest_neighbours import EditedNearestNeighbours
1921

@@ -35,9 +37,14 @@ class NeighbourhoodCleaningRule(BaseCleaningSampler):
3537
----------
3638
{sampling_strategy}
3739
40+
edited_nearest_neighbours : estimator object, default=None
41+
The :class:`~imblearn.under_sampling.EditedNearestNeighbours` (ENN)
42+
object to clean the dataset. If `None`, a default ENN is created with
43+
`kind_sel="mode"` and `n_neighbors=n_neighbors`.
44+
3845
n_neighbors : int or estimator object, default=3
3946
If ``int``, size of the neighbourhood to consider to compute the
40-
nearest neighbors. If object, an estimator that inherits from
47+
K-nearest neighbors. If object, an estimator that inherits from
4148
:class:`~sklearn.neighbors.base.KNeighborsMixin` that will be used to
4249
find the nearest-neighbors. By default, it will be a 3-NN.
4350
@@ -52,6 +59,11 @@ class NeighbourhoodCleaningRule(BaseCleaningSampler):
5259
The strategy `"all"` will be less conservative than `'mode'`. Thus,
5360
more samples will be removed when `kind_sel="all"` generally.
5461
62+
.. deprecated:: 0.12
63+
`kind_sel` is deprecated in 0.12 and will be removed in 0.14.
64+
Currently the parameter has no effect and corresponds always to the
65+
`"all"` strategy.
66+
5567
threshold_cleaning : float, default=0.5
5668
Threshold used to decide whether to consider a class or not during the cleaning
5769
after applying ENN. A class will be considered during cleaning when:
@@ -70,9 +82,16 @@ class NeighbourhoodCleaningRule(BaseCleaningSampler):
7082
corresponds to the class labels from which to sample and the values
7183
are the number of samples to sample.
7284
85+
edited_nearest_neighbours_ : estimator object
86+
The edited nearest neighbour object used to make the first resampling.
87+
7388
nn_ : estimator object
7489
Validated K-nearest Neighbours object created from `n_neighbors` parameter.
7590
91+
classes_to_clean_ : list
92+
The classes considered with under-sampling by `nn_` in the second cleaning
93+
phase.
94+
7695
sample_indices_ : ndarray of shape (n_new_samples,)
7796
Indices of the samples selected.
7897
@@ -118,52 +137,75 @@ class NeighbourhoodCleaningRule(BaseCleaningSampler):
118137
>>> ncr = NeighbourhoodCleaningRule()
119138
>>> X_res, y_res = ncr.fit_resample(X, y)
120139
>>> print('Resampled dataset shape %s' % Counter(y_res))
121-
Resampled dataset shape Counter({{1: 877, 0: 100}})
140+
Resampled dataset shape Counter({{1: 888, 0: 100}})
122141
"""
123142

124143
_parameter_constraints: dict = {
125144
**BaseCleaningSampler._parameter_constraints,
145+
"edited_nearest_neighbours": [
146+
HasMethods(["fit_resample"]),
147+
None,
148+
],
126149
"n_neighbors": [
127150
Interval(numbers.Integral, 1, None, closed="left"),
128151
HasMethods(["kneighbors", "kneighbors_graph"]),
129152
],
130-
"kind_sel": [StrOptions({"all", "mode"})],
131-
"threshold_cleaning": [Interval(numbers.Real, 0, 1, closed="neither")],
153+
"kind_sel": [StrOptions({"all", "mode"}), Hidden(StrOptions({"deprecated"}))],
154+
"threshold_cleaning": [Interval(numbers.Real, 0, None, closed="neither")],
132155
"n_jobs": [numbers.Integral, None],
133156
}
134157

135158
def __init__(
136159
self,
137160
*,
138161
sampling_strategy="auto",
162+
edited_nearest_neighbours=None,
139163
n_neighbors=3,
140-
kind_sel="all",
164+
kind_sel="deprecated",
141165
threshold_cleaning=0.5,
142166
n_jobs=None,
143167
):
144168
super().__init__(sampling_strategy=sampling_strategy)
169+
self.edited_nearest_neighbours = edited_nearest_neighbours
145170
self.n_neighbors = n_neighbors
146171
self.kind_sel = kind_sel
147172
self.threshold_cleaning = threshold_cleaning
148173
self.n_jobs = n_jobs
149174

150175
def _validate_estimator(self):
151176
"""Create the objects required by NCR."""
152-
self.nn_ = check_neighbors_object(
153-
"n_neighbors", self.n_neighbors, additional_neighbor=1
154-
)
155-
self.nn_.set_params(**{"n_jobs": self.n_jobs})
177+
if isinstance(self.n_neighbors, numbers.Integral):
178+
self.nn_ = KNeighborsClassifier(
179+
n_neighbors=self.n_neighbors, n_jobs=self.n_jobs
180+
)
181+
elif isinstance(self.n_neighbors, NearestNeighbors):
182+
# backward compatibility when passing a NearestNeighbors object
183+
self.nn_ = KNeighborsClassifier(
184+
n_neighbors=self.n_neighbors.n_neighbors - 1, n_jobs=self.n_jobs
185+
)
186+
else:
187+
self.nn_ = clone(self.n_neighbors)
188+
189+
if self.edited_nearest_neighbours is None:
190+
self.edited_nearest_neighbours_ = EditedNearestNeighbours(
191+
sampling_strategy=self.sampling_strategy,
192+
n_neighbors=self.n_neighbors,
193+
kind_sel="mode",
194+
n_jobs=self.n_jobs,
195+
)
196+
else:
197+
self.edited_nearest_neighbours_ = clone(self.edited_nearest_neighbours)
156198

157199
def _fit_resample(self, X, y):
200+
if self.kind_sel != "deprecated":
201+
warnings.warn(
202+
"`kind_sel` is deprecated in 0.12 and will be removed in 0.14. "
203+
"It already has not effect and corresponds to the `'all'` option.",
204+
FutureWarning,
205+
)
158206
self._validate_estimator()
159-
enn = EditedNearestNeighbours(
160-
sampling_strategy=self.sampling_strategy,
161-
n_neighbors=self.n_neighbors,
162-
kind_sel="mode",
163-
n_jobs=self.n_jobs,
164-
)
165-
enn.fit_resample(X, y)
166-
index_not_a1 = enn.sample_indices_
207+
self.edited_nearest_neighbours_.fit_resample(X, y)
208+
index_not_a1 = self.edited_nearest_neighbours_.sample_indices_
167209
index_a1 = np.ones(y.shape, dtype=bool)
168210
index_a1[index_not_a1] = False
169211
index_a1 = np.flatnonzero(index_a1)
@@ -172,30 +214,34 @@ def _fit_resample(self, X, y):
172214
target_stats = Counter(y)
173215
class_minority = min(target_stats, key=target_stats.get)
174216
# compute which classes to consider for cleaning for the A2 group
175-
classes_under_sample = [
217+
self.classes_to_clean_ = [
176218
c
177219
for c, n_samples in target_stats.items()
178220
if (
179221
c in self.sampling_strategy_.keys()
180-
and (n_samples > X.shape[0] * self.threshold_cleaning)
222+
and (n_samples > target_stats[class_minority] * self.threshold_cleaning)
181223
)
182224
]
183-
self.nn_.fit(X)
225+
self.nn_.fit(X, y)
226+
184227
class_minority_indices = np.flatnonzero(y == class_minority)
185-
X_class = _safe_indexing(X, class_minority_indices)
186-
y_class = _safe_indexing(y, class_minority_indices)
187-
nnhood_idx = self.nn_.kneighbors(X_class, return_distance=False)[:, 1:]
188-
nnhood_label = y[nnhood_idx]
189-
if self.kind_sel == "mode":
190-
nnhood_label_majority, _ = _mode(nnhood_label, axis=1)
191-
nnhood_bool = np.ravel(nnhood_label_majority) == y_class
192-
else: # self.kind_sel == "all":
193-
nnhood_label_majority = nnhood_label == class_minority
194-
nnhood_bool = np.all(nnhood_label, axis=1)
195-
# compute a2 group
196-
index_a2 = np.ravel(nnhood_idx[~nnhood_bool])
197-
index_a2 = np.unique(
198-
[index for index in index_a2 if y[index] in classes_under_sample]
228+
X_minority = _safe_indexing(X, class_minority_indices)
229+
y_minority = _safe_indexing(y, class_minority_indices)
230+
231+
y_pred_minority = self.nn_.predict(X_minority)
232+
# add an additional sample since the query points are part of the original dataset
233+
neighbors_to_minority_indices = self.nn_.kneighbors(
234+
X_minority, n_neighbors=self.nn_.n_neighbors + 1, return_distance=False
235+
)[:, 1:]
236+
237+
mask_misclassified_minority = y_pred_minority != y_minority
238+
index_a2 = np.ravel(neighbors_to_minority_indices[mask_misclassified_minority])
239+
index_a2 = np.array(
240+
[
241+
index
242+
for index in np.unique(index_a2)
243+
if y[index] in self.classes_to_clean_
244+
]
199245
)
200246

201247
union_a1_a2 = np.union1d(index_a1, index_a2).astype(int)

0 commit comments

Comments
 (0)