10
10
from scipy import sparse
11
11
from scipy import stats
12
12
13
- from sklearn .utils import safe_indexing , safe_mask
13
+ from sklearn .utils import safe_mask
14
+ from sklearn .utils import _safe_indexing
14
15
15
16
from .base import BasePreprocessSampler
16
17
from ...utils import check_neighbors_object
17
18
from ...utils import Substitution
19
+ from ..utils ._docstring import _n_jobs_docstring
18
20
19
- SEL_KIND = (' weak' , ' relabel' , ' strong' )
21
+ SEL_KIND = (" weak" , " relabel" , " strong" )
20
22
21
23
22
24
@Substitution (
23
- sampling_strategy = BasePreprocessSampler ._sampling_strategy_docstring )
25
+ sampling_strategy = BasePreprocessSampler ._sampling_strategy_docstring ,
26
+ n_jobs = _n_jobs_docstring ,
27
+ )
24
28
class SPIDER (BasePreprocessSampler ):
25
29
"""Perform filtering and over-sampling using Selective Pre-processing of
26
30
Imbalanced Data (SPIDER) sampling approach for imbalanced datasets.
@@ -55,8 +59,7 @@ class SPIDER(BasePreprocessSampler):
55
59
The number to add to amplified samples during if ``kind`` is
56
60
``'strong'``. This has no effect otherwise.
57
61
58
- n_jobs : int, optional (default=1)
59
- Number of threads to run the algorithm when it is possible.
62
+ {n_jobs}
60
63
61
64
Notes
62
65
-----
@@ -101,11 +104,11 @@ class SPIDER(BasePreprocessSampler):
101
104
102
105
def __init__ (
103
106
self ,
104
- sampling_strategy = ' auto' ,
105
- kind = ' weak' ,
107
+ sampling_strategy = " auto" ,
108
+ kind = " weak" ,
106
109
n_neighbors = 3 ,
107
110
additional_neighbors = 2 ,
108
- n_jobs = 1 ,
111
+ n_jobs = None ,
109
112
):
110
113
super ().__init__ (sampling_strategy = sampling_strategy )
111
114
self .kind = kind
@@ -116,19 +119,20 @@ def __init__(
116
119
def _validate_estimator (self ):
117
120
"""Create the necessary objects for SPIDER"""
118
121
self .nn_ = check_neighbors_object (
119
- ' n_neighbors' , self .n_neighbors , additional_neighbor = 1 )
120
- self .nn_ .set_params (** {' n_jobs' : self .n_jobs })
122
+ " n_neighbors" , self .n_neighbors , additional_neighbor = 1 )
123
+ self .nn_ .set_params (** {" n_jobs" : self .n_jobs })
121
124
122
125
if self .kind not in SEL_KIND :
123
- raise ValueError ('The possible "kind" of algorithm are '
124
- '"weak", "relabel", and "strong".'
125
- 'Got {} instead.' .format (self .kind ))
126
+ raise ValueError (
127
+ 'The possible "kind" of algorithm are "weak", "relabel",'
128
+ ' and "strong". Got {} instead.' .format (self .kind )
129
+ )
126
130
127
131
if self .additional_neighbors < 1 :
128
- raise ValueError (' additional_neighbors must be at least 1.' )
132
+ raise ValueError (" additional_neighbors must be at least 1." )
129
133
130
134
if not isinstance (self .additional_neighbors , Integral ):
131
- raise TypeError (' additional_neighbors must be an integer.' )
135
+ raise TypeError (" additional_neighbors must be an integer." )
132
136
133
137
def _locate_neighbors (self , X , additional = False ):
134
138
"""Find nearest neighbors for samples.
@@ -249,22 +253,22 @@ def _fit_resample(self, X, y):
249
253
discard_indices = np .flatnonzero (~ is_class & ~ is_safe )
250
254
251
255
class_noisy_indices = np .flatnonzero (is_class & ~ is_safe )
252
- X_class_noisy = safe_indexing (X , class_noisy_indices )
256
+ X_class_noisy = _safe_indexing (X , class_noisy_indices )
253
257
y_class_noisy = y [class_noisy_indices ]
254
258
255
- if self .kind in (' weak' , ' relabel' ):
259
+ if self .kind in (" weak" , " relabel" ):
256
260
nn_indices = self ._amplify (X_class_noisy , y_class_noisy )
257
261
258
- if self .kind == ' relabel' :
262
+ if self .kind == " relabel" :
259
263
relabel_mask = np .isin (nn_indices , discard_indices )
260
264
relabel_indices = np .unique (nn_indices [relabel_mask ])
261
265
self ._y [relabel_indices ] = class_sample
262
266
discard_indices = np .setdiff1d (
263
267
discard_indices , relabel_indices )
264
268
265
- elif self .kind == ' strong' :
269
+ elif self .kind == " strong" :
266
270
class_safe_indices = np .flatnonzero (is_class & is_safe )
267
- X_class_safe = safe_indexing (X , class_safe_indices )
271
+ X_class_safe = _safe_indexing (X , class_safe_indices )
268
272
y_class_safe = y [class_safe_indices ]
269
273
self ._amplify (X_class_safe , y_class_safe )
270
274
0 commit comments