16
16
from ...utils import check_neighbors_object
17
17
from ...utils import Substitution
18
18
19
+ SEL_KIND = ('weak' , 'relabel' , 'strong' )
20
+
19
21
20
22
@Substitution (sampling_strategy = BasePreprocessSampler ._sampling_strategy_docstring )
21
23
class SPIDER (BasePreprocessSampler ):
@@ -27,9 +29,6 @@ class SPIDER(BasePreprocessSampler):
27
29
Parameters
28
30
----------
29
31
{sampling_strategy}
30
- #TODO see dict vs list sampling_strategy of other samplers
31
- # to see if applicable to this
32
- # NCR would be good to check
33
32
34
33
kind : str (default='weak')
35
34
Possible choices are:
@@ -62,14 +61,11 @@ class SPIDER(BasePreprocessSampler):
62
61
-----
63
62
The implementation is based on [1]_ and [2]_.
64
63
65
- # TODO verify this will work
66
64
Supports multi-class resampling. A one-vs.-rest scheme is used.
67
65
68
66
See also
69
67
--------
70
- NCR : Clean-sample using NeighborhoodClearingRule.
71
-
72
- ROS : Over-sample using RandomOverSampling
68
+ NeighborhoodClearingRule and RandomOverSampler
73
69
74
70
References
75
71
----------
@@ -85,7 +81,21 @@ class SPIDER(BasePreprocessSampler):
85
81
86
82
Examples
87
83
--------
88
- TODO
84
+
85
+ >>> from collections import Counter
86
+ >>> from sklearn.datasets import make_classification
87
+ >>> from imblearn.combine import \
88
+ SPIDER # doctest: +NORMALIZE_WHITESPACE
89
+ >>> X, y = make_classification(n_classes=2, class_sep=2,
90
+ ... weights=[0.1, 0.9], n_informative=3, n_redundant=1, flip_y=0,
91
+ ... n_features=20, n_clusters_per_class=1, n_samples=1000,
92
+ ... random_state=10)
93
+ >>> print('Original dataset shape %s' % Counter(y))
94
+ Original dataset shape Counter({{1: 900, 0: 100}})
95
+ >>> spider = SPIDER()
96
+ >>> X_res, y_res = spider.fit_resample(X, y)
97
+ >>> print('Resampled dataset shape %s' % Counter(y_res))
98
+ Resampled dataset shape Counter({{1: 897, 0: 115}})
89
99
"""
90
100
91
101
def __init__ (
@@ -108,7 +118,7 @@ def _validate_estimator(self):
108
118
'n_neighbors' , self .n_neighbors , additional_neighbor = 1 )
109
119
self .nn_ .set_params (** {'n_jobs' : self .n_jobs })
110
120
111
- if self .kind not in ( 'weak' , 'relabel' , 'strong' ) :
121
+ if self .kind not in SEL_KIND :
112
122
raise ValueError ('The possible "kind" of algorithm are '
113
123
'"weak", "relabel", and "strong".'
114
124
'Got {} instead.' .format (self .kind ))
@@ -124,16 +134,16 @@ def _locate_neighbors(self, X, additional=False):
124
134
125
135
Parameters
126
136
----------
127
- X : ndarray, size(m_samples , n_features)
137
+ X : ndarray, shape (n_samples , n_features)
128
138
The feature samples to find neighbors for.
129
139
130
- additional : bool, optional (defaul =False)
140
+ additional : bool, optional (default =False)
131
141
Flag to indicate whether to increase ``n_neighbors`` by
132
142
``additional_neighbors``.
133
143
134
144
Returns
135
145
-------
136
- nn_indices : ndarray, size(TODO )
146
+ nn_indices : ndarray, shape (n_samples, n_neighbors )
137
147
Indices of the nearest neighbors for the subset.
138
148
"""
139
149
n_neighbors = self .nn_ .n_neighbors
@@ -149,25 +159,26 @@ def _knn_correct(self, X, y, additional=False):
149
159
150
160
Parameters
151
161
----------
152
- X : ndarray, size(m_samples , n_features)
162
+ X : ndarray, shape (n_samples , n_features)
153
163
The feature samples to classify.
154
164
155
- y : ndarray, size(m_samples ,)
165
+ y : ndarray, shape (n_samples ,)
156
166
The label samples to classify.
157
167
158
- additional : bool, optional (defaul =False)
168
+ additional : bool, optional (default =False)
159
169
Flag to indicate whether to increase ``n_neighbors`` by
160
170
additional_neighbors``.
161
171
162
172
Returns
163
173
-------
164
- is_correct : ndarray[bool], size(m_samples ,)
174
+ is_correct : ndarray[bool], shape (n_samples ,)
165
175
Mask that indicates if KNN classifed samples correctly.
166
176
"""
167
177
try :
168
178
nn_indices = self ._locate_neighbors (X , additional )
169
179
except ValueError :
170
180
return np .empty (0 , dtype = bool )
181
+
171
182
mode , _ = stats .mode (self ._y [nn_indices ], axis = 1 )
172
183
is_correct = (y == mode .ravel ())
173
184
return is_correct
@@ -179,19 +190,19 @@ def _amplify(self, X, y, additional=False):
179
190
180
191
Parameters
181
192
----------
182
- X : ndarray, size(m_samples , n_features)
193
+ X : ndarray, shape (n_samples , n_features)
183
194
The feature samples to amplify.
184
195
185
- y : ndarray, size(m_samples ,)
196
+ y : ndarray, shape (n_samples ,)
186
197
The label samples to amplify.
187
198
188
- additional : bool, optional (defaul =False)
199
+ additional : bool, optional (default =False)
189
200
Flag to indicate whether to amplify with ``additional_neighbors``.
190
201
191
202
Returns
192
203
-------
193
- nn_indices : TODO
194
- TODO
204
+ nn_indices : ndarray, shape (n_samples, n_neighbors)
205
+ Indices of the nearest neighbors for the subset.
195
206
"""
196
207
try :
197
208
nn_indices = self ._locate_neighbors (X , additional )
@@ -276,7 +287,10 @@ def _fit_resample(self, X, y):
276
287
raise NotImplementedError (self .kind )
277
288
278
289
discard_mask = np .ones_like (y , dtype = bool )
279
- discard_mask [discard_indices ] = False
290
+ try :
291
+ discard_mask [discard_indices ] = False
292
+ except UnboundLocalError :
293
+ pass
280
294
281
295
X_resampled = self ._X_resampled
282
296
y_resampled = self ._y_resampled
@@ -290,6 +304,6 @@ def _fit_resample(self, X, y):
290
304
X_resampled = np .vstack (X_resampled )
291
305
y_resampled = np .hstack (y_resampled )
292
306
293
- del self ._X_resampled , self ._y_resampled
294
- del self ._y , self . _amplify_indices
307
+ del self ._X_resampled , self ._y_resampled , self . _y
308
+ self ._amplify_indices = None
295
309
return X_resampled , y_resampled
0 commit comments