6
6
# License: MIT
7
7
8
8
from collections import Counter
9
+ from numbers import Integral
9
10
10
11
import numpy as np
11
12
12
13
from scipy .sparse import issparse
13
14
14
- from sklearn .base import clone
15
+ from sklearn .base import clone , is_classifier
15
16
from sklearn .neighbors import KNeighborsClassifier
16
17
from sklearn .utils import check_random_state , _safe_indexing
18
+ from sklearn .utils .deprecation import deprecated
17
19
18
20
from ..base import BaseCleaningSampler
19
21
from ...utils import Substitution
@@ -58,9 +60,16 @@ class CondensedNearestNeighbour(BaseCleaningSampler):
58
60
corresponds to the class labels from which to sample and the values
59
61
are the number of samples to sample.
60
62
63
+ n_neighbors_ : estimator object
64
+ The validated K-nearest neighbor estimator created from `n_neighbors` parameter.
65
+
61
66
estimator_ : estimator object
62
67
The validated K-nearest neighbor estimator created from `n_neighbors` parameter.
63
68
69
+ .. deprecated:: 0.10
70
+ `estimator_` is deprecated in 0.10 and will be removed in 0.12.
71
+ Use `n_neighbors_` instead.
72
+
64
73
sample_indices_ : ndarray of shape (n_new_samples,)
65
74
Indices of the samples selected.
66
75
@@ -94,18 +103,17 @@ class CondensedNearestNeighbour(BaseCleaningSampler):
94
103
95
104
Examples
96
105
--------
97
- >>> from collections import Counter # doctest: +SKIP
98
- >>> from sklearn.datasets import fetch_mldata # doctest: +SKIP
106
+ >>> from collections import Counter
107
+ >>> from sklearn.datasets import load_breast_cancer
99
108
>>> from imblearn.under_sampling import \
100
- CondensedNearestNeighbour # doctest: +SKIP
101
- >>> pima = fetch_mldata('diabetes_scale') # doctest: +SKIP
102
- >>> X, y = pima['data'], pima['target'] # doctest: +SKIP
103
- >>> print('Original dataset shape %s' % Counter(y)) # doctest: +SKIP
104
- Original dataset shape Counter({{1: 500, -1: 268}}) # doctest: +SKIP
105
- >>> cnn = CondensedNearestNeighbour(random_state=42) # doctest: +SKIP
106
- >>> X_res, y_res = cnn.fit_resample(X, y) #doctest: +SKIP
107
- >>> print('Resampled dataset shape %s' % Counter(y_res)) # doctest: +SKIP
108
- Resampled dataset shape Counter({{-1: 268, 1: 227}}) # doctest: +SKIP
109
+ CondensedNearestNeighbour
110
+ >>> X, y = load_breast_cancer(return_X_y=True)
111
+ >>> print('Original dataset shape %s' % Counter(y))
112
+ Original dataset shape Counter({{1: 357, 0: 212}})
113
+ >>> cnn = CondensedNearestNeighbour(random_state=42)
114
+ >>> X_res, y_res = cnn.fit_resample(X, y)
115
+ >>> print('Resampled dataset shape %s' % Counter(y_res))
116
+ Resampled dataset shape Counter({{0: 212, 1: 50}})
109
117
"""
110
118
111
119
@_deprecate_positional_args
@@ -125,20 +133,20 @@ def __init__(
125
133
self .n_jobs = n_jobs
126
134
127
135
def _validate_estimator (self ):
128
- """Private function to create the NN estimator"""
129
136
if self .n_neighbors is None :
130
- self .estimator_ = KNeighborsClassifier (n_neighbors = 1 , n_jobs = self .n_jobs )
131
- elif isinstance (self .n_neighbors , int ):
132
- self .estimator_ = KNeighborsClassifier (
137
+ self .n_neighbors_ = KNeighborsClassifier (n_neighbors = 1 , n_jobs = self .n_jobs )
138
+ elif isinstance (self .n_neighbors , Integral ):
139
+ self .n_neighbors_ = KNeighborsClassifier (
133
140
n_neighbors = self .n_neighbors , n_jobs = self .n_jobs
134
141
)
135
- elif isinstance (self .n_neighbors , KNeighborsClassifier ):
136
- self .estimator_ = clone (self .n_neighbors )
142
+ elif is_classifier (self .n_neighbors ) and hasattr (
143
+ self .n_neighbors , "n_neighbors"
144
+ ):
145
+ self .n_neighbors_ = clone (self .n_neighbors )
137
146
else :
138
147
raise ValueError (
139
- f"`n_neighbors` has to be a int or an object"
140
- f" inhereited from KNeighborsClassifier."
141
- f" Got { type (self .n_neighbors )} instead."
148
+ "`n_neighbors` must be an integer or a KNN classifier having an "
149
+ f"attribute `n_neighbors`. Got { self .n_neighbors !r} instead."
142
150
)
143
151
144
152
def _fit_resample (self , X , y ):
@@ -175,7 +183,7 @@ def _fit_resample(self, X, y):
175
183
S_y = _safe_indexing (y , S_indices )
176
184
177
185
# fit knn on C
178
- self .estimator_ .fit (C_x , C_y )
186
+ self .n_neighbors_ .fit (C_x , C_y )
179
187
180
188
good_classif_label = idx_maj_sample .copy ()
181
189
# Check each sample in S if we keep it or drop it
@@ -188,7 +196,7 @@ def _fit_resample(self, X, y):
188
196
# Classify on S
189
197
if not issparse (x_sam ):
190
198
x_sam = x_sam .reshape (1 , - 1 )
191
- pred_y = self .estimator_ .predict (x_sam )
199
+ pred_y = self .n_neighbors_ .predict (x_sam )
192
200
193
201
# If the prediction do not agree with the true label
194
202
# append it in C_x
@@ -202,12 +210,12 @@ def _fit_resample(self, X, y):
202
210
C_y = _safe_indexing (y , C_indices )
203
211
204
212
# fit a knn on C
205
- self .estimator_ .fit (C_x , C_y )
213
+ self .n_neighbors_ .fit (C_x , C_y )
206
214
207
215
# This experimental to speed up the search
208
216
# Classify all the element in S and avoid to test the
209
217
# well classified elements
210
- pred_S_y = self .estimator_ .predict (S_x )
218
+ pred_S_y = self .n_neighbors_ .predict (S_x )
211
219
good_classif_label = np .unique (
212
220
np .append (idx_maj_sample , np .flatnonzero (pred_S_y == S_y ))
213
221
)
@@ -224,3 +232,11 @@ def _fit_resample(self, X, y):
224
232
225
233
def _more_tags (self ):
226
234
return {"sample_indices" : True }
235
+
236
+ @deprecated (
237
+ "`estimator_` is deprecated in version 0.10 and will be removed in version "
238
+ "0.12. Use `n_neighbors_` instead."
239
+ )
240
+ @property
241
+ def estimator_ (self ):
242
+ return self .n_neighbors_
0 commit comments