Skip to content

Commit 80e6826

Browse files
committed
API duck-typing for n_neighbors in CNN and deprecate estimator_
1 parent e802a19 commit 80e6826

File tree

2 files changed

+52
-26
lines changed

2 files changed

+52
-26
lines changed

imblearn/under_sampling/_prototype_selection/_condensed_nearest_neighbour.py

Lines changed: 41 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,16 @@
66
# License: MIT
77

88
from collections import Counter
9+
from numbers import Integral
910

1011
import numpy as np
1112

1213
from scipy.sparse import issparse
1314

14-
from sklearn.base import clone
15+
from sklearn.base import clone, is_classifier
1516
from sklearn.neighbors import KNeighborsClassifier
1617
from sklearn.utils import check_random_state, _safe_indexing
18+
from sklearn.utils.deprecation import deprecated
1719

1820
from ..base import BaseCleaningSampler
1921
from ...utils import Substitution
@@ -58,9 +60,16 @@ class CondensedNearestNeighbour(BaseCleaningSampler):
5860
corresponds to the class labels from which to sample and the values
5961
are the number of samples to sample.
6062
63+
n_neighbors_ : estimator object
64+
The validated K-nearest neighbor estimator created from `n_neighbors` parameter.
65+
6166
estimator_ : estimator object
6267
The validated K-nearest neighbor estimator created from `n_neighbors` parameter.
6368
69+
.. deprecated:: 0.10
70+
`estimator_` is deprecated in 0.10 and will be removed in 0.12.
71+
Use `n_neighbors_` instead.
72+
6473
sample_indices_ : ndarray of shape (n_new_samples,)
6574
Indices of the samples selected.
6675
@@ -94,18 +103,17 @@ class CondensedNearestNeighbour(BaseCleaningSampler):
94103
95104
Examples
96105
--------
97-
>>> from collections import Counter # doctest: +SKIP
98-
>>> from sklearn.datasets import fetch_mldata # doctest: +SKIP
106+
>>> from collections import Counter
107+
>>> from sklearn.datasets import load_breast_cancer
99108
>>> from imblearn.under_sampling import \
100-
CondensedNearestNeighbour # doctest: +SKIP
101-
>>> pima = fetch_mldata('diabetes_scale') # doctest: +SKIP
102-
>>> X, y = pima['data'], pima['target'] # doctest: +SKIP
103-
>>> print('Original dataset shape %s' % Counter(y)) # doctest: +SKIP
104-
Original dataset shape Counter({{1: 500, -1: 268}}) # doctest: +SKIP
105-
>>> cnn = CondensedNearestNeighbour(random_state=42) # doctest: +SKIP
106-
>>> X_res, y_res = cnn.fit_resample(X, y) #doctest: +SKIP
107-
>>> print('Resampled dataset shape %s' % Counter(y_res)) # doctest: +SKIP
108-
Resampled dataset shape Counter({{-1: 268, 1: 227}}) # doctest: +SKIP
109+
CondensedNearestNeighbour
110+
>>> X, y = load_breast_cancer(return_X_y=True)
111+
>>> print('Original dataset shape %s' % Counter(y))
112+
Original dataset shape Counter({{1: 357, 0: 212}})
113+
>>> cnn = CondensedNearestNeighbour(random_state=42)
114+
>>> X_res, y_res = cnn.fit_resample(X, y)
115+
>>> print('Resampled dataset shape %s' % Counter(y_res))
116+
Resampled dataset shape Counter({{0: 212, 1: 50}})
109117
"""
110118

111119
@_deprecate_positional_args
@@ -125,20 +133,20 @@ def __init__(
125133
self.n_jobs = n_jobs
126134

127135
def _validate_estimator(self):
128-
"""Private function to create the NN estimator"""
129136
if self.n_neighbors is None:
130-
self.estimator_ = KNeighborsClassifier(n_neighbors=1, n_jobs=self.n_jobs)
131-
elif isinstance(self.n_neighbors, int):
132-
self.estimator_ = KNeighborsClassifier(
137+
self.n_neighbors_ = KNeighborsClassifier(n_neighbors=1, n_jobs=self.n_jobs)
138+
elif isinstance(self.n_neighbors, Integral):
139+
self.n_neighbors_ = KNeighborsClassifier(
133140
n_neighbors=self.n_neighbors, n_jobs=self.n_jobs
134141
)
135-
elif isinstance(self.n_neighbors, KNeighborsClassifier):
136-
self.estimator_ = clone(self.n_neighbors)
142+
elif is_classifier(self.n_neighbors) and hasattr(
143+
self.n_neighbors, "n_neighbors"
144+
):
145+
self.n_neighbors_ = clone(self.n_neighbors)
137146
else:
138147
raise ValueError(
139-
f"`n_neighbors` has to be a int or an object"
140-
f" inhereited from KNeighborsClassifier."
141-
f" Got {type(self.n_neighbors)} instead."
148+
"`n_neighbors` must be an integer or a KNN classifier having an "
149+
f"attribute `n_neighbors`. Got {self.n_neighbors!r} instead."
142150
)
143151

144152
def _fit_resample(self, X, y):
@@ -175,7 +183,7 @@ def _fit_resample(self, X, y):
175183
S_y = _safe_indexing(y, S_indices)
176184

177185
# fit knn on C
178-
self.estimator_.fit(C_x, C_y)
186+
self.n_neighbors_.fit(C_x, C_y)
179187

180188
good_classif_label = idx_maj_sample.copy()
181189
# Check each sample in S if we keep it or drop it
@@ -188,7 +196,7 @@ def _fit_resample(self, X, y):
188196
# Classify on S
189197
if not issparse(x_sam):
190198
x_sam = x_sam.reshape(1, -1)
191-
pred_y = self.estimator_.predict(x_sam)
199+
pred_y = self.n_neighbors_.predict(x_sam)
192200

193201
# If the prediction do not agree with the true label
194202
# append it in C_x
@@ -202,12 +210,12 @@ def _fit_resample(self, X, y):
202210
C_y = _safe_indexing(y, C_indices)
203211

204212
# fit a knn on C
205-
self.estimator_.fit(C_x, C_y)
213+
self.n_neighbors_.fit(C_x, C_y)
206214

207215
# This experimental to speed up the search
208216
# Classify all the element in S and avoid to test the
209217
# well classified elements
210-
pred_S_y = self.estimator_.predict(S_x)
218+
pred_S_y = self.n_neighbors_.predict(S_x)
211219
good_classif_label = np.unique(
212220
np.append(idx_maj_sample, np.flatnonzero(pred_S_y == S_y))
213221
)
@@ -224,3 +232,11 @@ def _fit_resample(self, X, y):
224232

225233
def _more_tags(self):
226234
return {"sample_indices": True}
235+
236+
@deprecated(
237+
"`estimator_` is deprecated in version 0.10 and will be removed in version "
238+
"0.12. Use `n_neighbors_` instead."
239+
)
240+
@property
241+
def estimator_(self):
242+
return self.n_neighbors_

imblearn/under_sampling/_prototype_selection/tests/test_condensed_nearest_neighbour.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,5 +101,15 @@ def test_cnn_fit_resample_with_object():
101101
def test_cnn_fit_resample_with_wrong_object():
102102
knn = "rnd"
103103
cnn = CondensedNearestNeighbour(random_state=RND_SEED, n_neighbors=knn)
104-
with pytest.raises(ValueError, match="has to be a int or an "):
104+
msg = "`n_neighbors` must be an integer or a KNN classifier"
105+
with pytest.raises(ValueError, match=msg):
105106
cnn.fit_resample(X, Y)
107+
108+
109+
def test_cnn_estimator_deprecation():
110+
cnn = CondensedNearestNeighbour(random_state=RND_SEED)
111+
cnn.fit_resample(X, Y)
112+
113+
msg = "`estimator_` is deprecated in version 0.10"
114+
with pytest.warns(FutureWarning, match=msg):
115+
assert cnn.estimator_ == cnn.n_neighbors_

0 commit comments

Comments
 (0)