Skip to content

Commit dbb0564

Browse files
committed
docstring SPIDER; spider sample strategy list
1 parent 58f4d78 commit dbb0564

File tree

3 files changed

+46
-29
lines changed

3 files changed

+46
-29
lines changed

imblearn/combine/_preprocess/_spider.py

Lines changed: 38 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
from ...utils import check_neighbors_object
1717
from ...utils import Substitution
1818

19+
SEL_KIND = ('weak', 'relabel', 'strong')
20+
1921

2022
@Substitution(sampling_strategy=BasePreprocessSampler._sampling_strategy_docstring)
2123
class SPIDER(BasePreprocessSampler):
@@ -27,9 +29,6 @@ class SPIDER(BasePreprocessSampler):
2729
Parameters
2830
----------
2931
{sampling_strategy}
30-
#TODO see dict vs list sampling_strategy of other samplers
31-
# to see if applicable to this
32-
# NCR would be good to check
3332
3433
kind : str (default='weak')
3534
Possible choices are:
@@ -62,14 +61,11 @@ class SPIDER(BasePreprocessSampler):
6261
-----
6362
The implementation is based on [1]_ and [2]_.
6463
65-
# TODO verify this will work
6664
Supports multi-class resampling. A one-vs.-rest scheme is used.
6765
6866
See also
6967
--------
70-
NCR : Clean-sample using NeighborhoodClearingRule.
71-
72-
ROS : Over-sample using RandomOverSampling
68+
NeighborhoodClearingRule and RandomOverSampler
7369
7470
References
7571
----------
@@ -85,7 +81,21 @@ class SPIDER(BasePreprocessSampler):
8581
8682
Examples
8783
--------
88-
TODO
84+
85+
>>> from collections import Counter
86+
>>> from sklearn.datasets import make_classification
87+
>>> from imblearn.combine import \
88+
SPIDER # doctest: +NORMALIZE_WHITESPACE
89+
>>> X, y = make_classification(n_classes=2, class_sep=2,
90+
... weights=[0.1, 0.9], n_informative=3, n_redundant=1, flip_y=0,
91+
... n_features=20, n_clusters_per_class=1, n_samples=1000,
92+
... random_state=10)
93+
>>> print('Original dataset shape %s' % Counter(y))
94+
Original dataset shape Counter({{1: 900, 0: 100}})
95+
>>> spider = SPIDER()
96+
>>> X_res, y_res = spider.fit_resample(X, y)
97+
>>> print('Resampled dataset shape %s' % Counter(y_res))
98+
Resampled dataset shape Counter({{1: 897, 0: 115}})
8999
"""
90100

91101
def __init__(
@@ -108,7 +118,7 @@ def _validate_estimator(self):
108118
'n_neighbors', self.n_neighbors, additional_neighbor=1)
109119
self.nn_.set_params(**{'n_jobs': self.n_jobs})
110120

111-
if self.kind not in ('weak', 'relabel', 'strong'):
121+
if self.kind not in SEL_KIND:
112122
raise ValueError('The possible "kind" of algorithm are '
113123
'"weak", "relabel", and "strong".'
114124
'Got {} instead.'.format(self.kind))
@@ -124,16 +134,16 @@ def _locate_neighbors(self, X, additional=False):
124134
125135
Parameters
126136
----------
127-
X : ndarray, size(m_samples, n_features)
137+
X : ndarray, shape (n_samples, n_features)
128138
The feature samples to find neighbors for.
129139
130-
additional : bool, optional (defaul=False)
140+
additional : bool, optional (default=False)
131141
Flag to indicate whether to increase ``n_neighbors`` by
132142
``additional_neighbors``.
133143
134144
Returns
135145
-------
136-
nn_indices : ndarray, size(TODO)
146+
nn_indices : ndarray, shape (n_samples, n_neighbors)
137147
Indices of the nearest neighbors for the subset.
138148
"""
139149
n_neighbors = self.nn_.n_neighbors
@@ -149,25 +159,26 @@ def _knn_correct(self, X, y, additional=False):
149159
150160
Parameters
151161
----------
152-
X : ndarray, size(m_samples, n_features)
162+
X : ndarray, shape (n_samples, n_features)
153163
The feature samples to classify.
154164
155-
y : ndarray, size(m_samples,)
165+
y : ndarray, shape (n_samples,)
156166
The label samples to classify.
157167
158-
additional : bool, optional (defaul=False)
168+
additional : bool, optional (default=False)
159169
Flag to indicate whether to increase ``n_neighbors`` by
160170
additional_neighbors``.
161171
162172
Returns
163173
-------
164-
is_correct : ndarray[bool], size(m_samples,)
174+
is_correct : ndarray[bool], shape (n_samples,)
165175
Mask that indicates if KNN classifed samples correctly.
166176
"""
167177
try:
168178
nn_indices = self._locate_neighbors(X, additional)
169179
except ValueError:
170180
return np.empty(0, dtype=bool)
181+
171182
mode, _ = stats.mode(self._y[nn_indices], axis=1)
172183
is_correct = (y == mode.ravel())
173184
return is_correct
@@ -179,19 +190,19 @@ def _amplify(self, X, y, additional=False):
179190
180191
Parameters
181192
----------
182-
X : ndarray, size(m_samples, n_features)
193+
X : ndarray, shape (n_samples, n_features)
183194
The feature samples to amplify.
184195
185-
y : ndarray, size(m_samples,)
196+
y : ndarray, shape (n_samples,)
186197
The label samples to amplify.
187198
188-
additional : bool, optional (defaul=False)
199+
additional : bool, optional (default=False)
189200
Flag to indicate whether to amplify with ``additional_neighbors``.
190201
191202
Returns
192203
-------
193-
nn_indices : TODO
194-
TODO
204+
nn_indices : ndarray, shape (n_samples, n_neighbors)
205+
Indices of the nearest neighbors for the subset.
195206
"""
196207
try:
197208
nn_indices = self._locate_neighbors(X, additional)
@@ -276,7 +287,10 @@ def _fit_resample(self, X, y):
276287
raise NotImplementedError(self.kind)
277288

278289
discard_mask = np.ones_like(y, dtype=bool)
279-
discard_mask[discard_indices] = False
290+
try:
291+
discard_mask[discard_indices] = False
292+
except UnboundLocalError:
293+
pass
280294

281295
X_resampled = self._X_resampled
282296
y_resampled = self._y_resampled
@@ -290,6 +304,6 @@ def _fit_resample(self, X, y):
290304
X_resampled = np.vstack(X_resampled)
291305
y_resampled = np.hstack(y_resampled)
292306

293-
del self._X_resampled, self._y_resampled
294-
del self._y, self._amplify_indices
307+
del self._X_resampled, self._y_resampled, self._y
308+
self._amplify_indices = None
295309
return X_resampled, y_resampled

imblearn/combine/_preprocess/base.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ class BasePreprocessSampler(BaseSampler):
3232
3333
``'auto'``: equivalent to ``'not majority'``.
3434
35+
- When ``list``, the list contains the classes targeted by the
36+
resampling.
37+
3538
- When callable, function taking ``y`` and returns a ``dict``. The keys
3639
correspond to the targeted classes. The values correspond to the
3740
desired number of samples for each class.

imblearn/utils/_validation.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def _sampling_strategy_minority(y, sampling_type):
193193
key: n_sample_majority - value
194194
for (key, value) in target_stats.items() if key == class_minority
195195
}
196-
elif sampling_strategy in ('under-sampling', 'clean-sampling'):
196+
elif sampling_type in ('under-sampling', 'clean-sampling'):
197197
raise ValueError("'sampling_strategy'='minority' cannot be used with"
198198
" under-sampler and clean-sampler.")
199199
else:
@@ -279,9 +279,9 @@ def _sampling_strategy_dict(sampling_strategy, y, sampling_type):
279279
def _sampling_strategy_list(sampling_strategy, y, sampling_type):
280280
"""With cleaning methods, sampling_strategy can be a list to target the
281281
class of interest."""
282-
if sampling_type != 'clean-sampling':
282+
if sampling_type not in ('clean-sampling', 'preprocess-sampling'):
283283
raise ValueError("'sampling_strategy' cannot be a list for samplers "
284-
"which are not cleaning methods.")
284+
"which are not cleaning or preprocess methods.")
285285

286286
target_stats = _count_class_sample(y)
287287
# check that all keys in sampling_strategy are also in y
@@ -400,8 +400,8 @@ def check_sampling_strategy(sampling_strategy, y, sampling_type, **kwargs):
400400
for **cleaning methods**.
401401
402402
.. warning::
403-
``list`` is available for **cleaning methods**. An error is raised
404-
with **under-, over-, and preprocess-sampling methods**.
403+
``list`` is available for **cleaning and preprocess methods**. An
404+
error is raised with **under- and over-sampling methods**.
405405
406406
- When callable, function taking ``y`` and returns a ``dict``. The keys
407407
correspond to the targeted classes. The values correspond to the

0 commit comments

Comments
 (0)