5
5
# License: MIT
6
6
7
7
import numbers
8
+ import warnings
8
9
from collections import Counter
9
10
10
11
import numpy as np
12
+ from sklearn .base import clone
13
+ from sklearn .neighbors import KNeighborsClassifier , NearestNeighbors
11
14
from sklearn .utils import _safe_indexing
12
15
13
- from ...utils import Substitution , check_neighbors_object
16
+ from ...utils import Substitution
14
17
from ...utils ._docstring import _n_jobs_docstring
15
- from ...utils ._param_validation import HasMethods , Interval , StrOptions
16
- from ...utils .fixes import _mode
18
+ from ...utils ._param_validation import HasMethods , Hidden , Interval , StrOptions
17
19
from ..base import BaseCleaningSampler
18
20
from ._edited_nearest_neighbours import EditedNearestNeighbours
19
21
@@ -35,9 +37,14 @@ class NeighbourhoodCleaningRule(BaseCleaningSampler):
35
37
----------
36
38
{sampling_strategy}
37
39
40
+ edited_nearest_neighbours : estimator object, default=None
41
+ The :class:`~imblearn.under_sampling.EditedNearestNeighbours` (ENN)
42
+ object to clean the dataset. If `None`, a default ENN is created with
43
+ `kind_sel="mode"` and `n_neighbors=n_neighbors`.
44
+
38
45
n_neighbors : int or estimator object, default=3
39
46
If ``int``, size of the neighbourhood to consider to compute the
40
- nearest neighbors. If object, an estimator that inherits from
47
+ K-nearest neighbors. If object, an estimator that inherits from
41
48
:class:`~sklearn.neighbors.base.KNeighborsMixin` that will be used to
42
49
find the nearest-neighbors. By default, it will be a 3-NN.
43
50
@@ -52,6 +59,11 @@ class NeighbourhoodCleaningRule(BaseCleaningSampler):
52
59
The strategy `"all"` will be less conservative than `'mode'`. Thus,
53
60
more samples will be removed when `kind_sel="all"` generally.
54
61
62
+ .. deprecated:: 0.12
63
+ `kind_sel` is deprecated in 0.12 and will be removed in 0.14.
64
+ Currently the parameter has no effect and always corresponds to the
65
+ `"all"` strategy.
66
+
55
67
threshold_cleaning : float, default=0.5
56
68
Threshold used to decide whether or not to consider a class during the cleaning
57
69
after applying ENN. A class will be considered during cleaning when:
@@ -70,9 +82,16 @@ class NeighbourhoodCleaningRule(BaseCleaningSampler):
70
82
corresponds to the class labels from which to sample and the values
71
83
are the number of samples to sample.
72
84
85
+ edited_nearest_neighbours_ : estimator object
86
+ The edited nearest neighbour object used to make the first resampling.
87
+
73
88
nn_ : estimator object
74
89
Validated K-nearest Neighbours object created from `n_neighbors` parameter.
75
90
91
+ classes_to_clean_ : list
92
+ The classes considered with under-sampling by `nn_` in the second cleaning
93
+ phase.
94
+
76
95
sample_indices_ : ndarray of shape (n_new_samples,)
77
96
Indices of the samples selected.
78
97
@@ -118,52 +137,75 @@ class NeighbourhoodCleaningRule(BaseCleaningSampler):
118
137
>>> ncr = NeighbourhoodCleaningRule()
119
138
>>> X_res, y_res = ncr.fit_resample(X, y)
120
139
>>> print('Resampled dataset shape %s' % Counter(y_res))
121
- Resampled dataset shape Counter({{1: 877 , 0: 100}})
140
+ Resampled dataset shape Counter({{1: 888 , 0: 100}})
122
141
"""
123
142
124
143
_parameter_constraints : dict = {
125
144
** BaseCleaningSampler ._parameter_constraints ,
145
+ "edited_nearest_neighbours" : [
146
+ HasMethods (["fit_resample" ]),
147
+ None ,
148
+ ],
126
149
"n_neighbors" : [
127
150
Interval (numbers .Integral , 1 , None , closed = "left" ),
128
151
HasMethods (["kneighbors" , "kneighbors_graph" ]),
129
152
],
130
- "kind_sel" : [StrOptions ({"all" , "mode" })],
131
- "threshold_cleaning" : [Interval (numbers .Real , 0 , 1 , closed = "neither" )],
153
+ "kind_sel" : [StrOptions ({"all" , "mode" }), Hidden ( StrOptions ({ "deprecated" })) ],
154
+ "threshold_cleaning" : [Interval (numbers .Real , 0 , None , closed = "neither" )],
132
155
"n_jobs" : [numbers .Integral , None ],
133
156
}
134
157
135
158
def __init__ (
136
159
self ,
137
160
* ,
138
161
sampling_strategy = "auto" ,
162
+ edited_nearest_neighbours = None ,
139
163
n_neighbors = 3 ,
140
- kind_sel = "all " ,
164
+ kind_sel = "deprecated " ,
141
165
threshold_cleaning = 0.5 ,
142
166
n_jobs = None ,
143
167
):
144
168
super ().__init__ (sampling_strategy = sampling_strategy )
169
+ self .edited_nearest_neighbours = edited_nearest_neighbours
145
170
self .n_neighbors = n_neighbors
146
171
self .kind_sel = kind_sel
147
172
self .threshold_cleaning = threshold_cleaning
148
173
self .n_jobs = n_jobs
149
174
150
175
def _validate_estimator (self ):
151
176
"""Create the objects required by NCR."""
152
- self .nn_ = check_neighbors_object (
153
- "n_neighbors" , self .n_neighbors , additional_neighbor = 1
154
- )
155
- self .nn_ .set_params (** {"n_jobs" : self .n_jobs })
177
+ if isinstance (self .n_neighbors , numbers .Integral ):
178
+ self .nn_ = KNeighborsClassifier (
179
+ n_neighbors = self .n_neighbors , n_jobs = self .n_jobs
180
+ )
181
+ elif isinstance (self .n_neighbors , NearestNeighbors ):
182
+ # backward compatibility when passing a NearestNeighbors object
183
+ self .nn_ = KNeighborsClassifier (
184
+ n_neighbors = self .n_neighbors .n_neighbors - 1 , n_jobs = self .n_jobs
185
+ )
186
+ else :
187
+ self .nn_ = clone (self .n_neighbors )
188
+
189
+ if self .edited_nearest_neighbours is None :
190
+ self .edited_nearest_neighbours_ = EditedNearestNeighbours (
191
+ sampling_strategy = self .sampling_strategy ,
192
+ n_neighbors = self .n_neighbors ,
193
+ kind_sel = "mode" ,
194
+ n_jobs = self .n_jobs ,
195
+ )
196
+ else :
197
+ self .edited_nearest_neighbours_ = clone (self .edited_nearest_neighbours )
156
198
157
199
def _fit_resample (self , X , y ):
200
+ if self .kind_sel != "deprecated" :
201
+ warnings .warn (
202
+ "`kind_sel` is deprecated in 0.12 and will be removed in 0.14. "
203
+ "It already has not effect and corresponds to the `'all'` option." ,
204
+ FutureWarning ,
205
+ )
158
206
self ._validate_estimator ()
159
- enn = EditedNearestNeighbours (
160
- sampling_strategy = self .sampling_strategy ,
161
- n_neighbors = self .n_neighbors ,
162
- kind_sel = "mode" ,
163
- n_jobs = self .n_jobs ,
164
- )
165
- enn .fit_resample (X , y )
166
- index_not_a1 = enn .sample_indices_
207
+ self .edited_nearest_neighbours_ .fit_resample (X , y )
208
+ index_not_a1 = self .edited_nearest_neighbours_ .sample_indices_
167
209
index_a1 = np .ones (y .shape , dtype = bool )
168
210
index_a1 [index_not_a1 ] = False
169
211
index_a1 = np .flatnonzero (index_a1 )
@@ -172,30 +214,34 @@ def _fit_resample(self, X, y):
172
214
target_stats = Counter (y )
173
215
class_minority = min (target_stats , key = target_stats .get )
174
216
# compute which classes to consider for cleaning for the A2 group
175
- classes_under_sample = [
217
+ self . classes_to_clean_ = [
176
218
c
177
219
for c , n_samples in target_stats .items ()
178
220
if (
179
221
c in self .sampling_strategy_ .keys ()
180
- and (n_samples > X . shape [ 0 ] * self .threshold_cleaning )
222
+ and (n_samples > target_stats [ class_minority ] * self .threshold_cleaning )
181
223
)
182
224
]
183
- self .nn_ .fit (X )
225
+ self .nn_ .fit (X , y )
226
+
184
227
class_minority_indices = np .flatnonzero (y == class_minority )
185
- X_class = _safe_indexing (X , class_minority_indices )
186
- y_class = _safe_indexing (y , class_minority_indices )
187
- nnhood_idx = self .nn_ .kneighbors (X_class , return_distance = False )[:, 1 :]
188
- nnhood_label = y [nnhood_idx ]
189
- if self .kind_sel == "mode" :
190
- nnhood_label_majority , _ = _mode (nnhood_label , axis = 1 )
191
- nnhood_bool = np .ravel (nnhood_label_majority ) == y_class
192
- else : # self.kind_sel == "all":
193
- nnhood_label_majority = nnhood_label == class_minority
194
- nnhood_bool = np .all (nnhood_label , axis = 1 )
195
- # compute a2 group
196
- index_a2 = np .ravel (nnhood_idx [~ nnhood_bool ])
197
- index_a2 = np .unique (
198
- [index for index in index_a2 if y [index ] in classes_under_sample ]
228
+ X_minority = _safe_indexing (X , class_minority_indices )
229
+ y_minority = _safe_indexing (y , class_minority_indices )
230
+
231
+ y_pred_minority = self .nn_ .predict (X_minority )
232
+ # query one extra neighbor: the query points are part of the fitted dataset, so the first neighbor column is dropped
233
+ neighbors_to_minority_indices = self .nn_ .kneighbors (
234
+ X_minority , n_neighbors = self .nn_ .n_neighbors + 1 , return_distance = False
235
+ )[:, 1 :]
236
+
237
+ mask_misclassified_minority = y_pred_minority != y_minority
238
+ index_a2 = np .ravel (neighbors_to_minority_indices [mask_misclassified_minority ])
239
+ index_a2 = np .array (
240
+ [
241
+ index
242
+ for index in np .unique (index_a2 )
243
+ if y [index ] in self .classes_to_clean_
244
+ ]
199
245
)
200
246
201
247
union_a1_a2 = np .union1d (index_a1 , index_a2 ).astype (int )
0 commit comments