28
28
29
29
# Immutable record with two positional fields: ``radius`` and ``neighbors``.
Neighborhood = namedtuple(
    typename='Neighborhood',
    field_names=('radius', 'neighbors'),
)
30
30
31
+
31
32
def plot_X(X, ax, **kwargs):
    """Scatter the first two columns of ``X`` on ``ax``.

    Column 0 of ``X`` supplies the x coordinates and column 1 the y
    coordinates. Any extra keyword arguments are forwarded unchanged to
    ``ax.scatter``. Nothing is returned.
    """
    x_coords = X[:, 0]
    y_coords = X[:, 1]
    ax.scatter(x_coords, y_coords, **kwargs)
33
34
35
+
34
36
def correct (nn , y_fit , X , y , additional = False ):
35
37
n_neighbors = nn .n_neighbors
36
38
if additional :
@@ -39,12 +41,13 @@ def correct(nn, y_fit, X, y, additional=False):
39
41
y_pred , _ = mode (y_fit [nn_idxs ], axis = 1 )
40
42
return (y == y_pred .ravel ())
41
43
44
+
42
45
def get_neighborhoods (spider , X_fit , y_fit , X_flagged , y_flagged , idx ):
43
46
point = X_flagged [idx ]
44
47
45
48
additional = (spider .kind == 'strong' )
46
- if correct (spider .nn_ , y_fit , point [np .newaxis ], y_flagged [ idx ][ np . newaxis ],
47
- additional = additional ):
49
+ if correct (spider .nn_ , y_fit , point [np .newaxis ],
50
+ y_flagged [ idx ][ np . newaxis ], additional = additional ):
48
51
additional = False
49
52
50
53
idxs_k = spider ._locate_neighbors (point [np .newaxis ])
@@ -61,6 +64,7 @@ def get_neighborhoods(spider, X_fit, y_fit, X_flagged, y_flagged, idx):
61
64
62
65
return neighborhood_k , neighborhood_k2 , point , additional
63
66
67
+
64
68
def draw_neighborhoods (spider , neighborhood_k , neighborhood_k2 , point ,
65
69
additional , ax , outer = True , alpha = 0.5 ):
66
70
PartialCircle = partial (Circle , facecolor = 'none' , edgecolor = 'black' ,
@@ -80,13 +84,15 @@ def draw_neighborhoods(spider, neighborhood_k, neighborhood_k2, point,
80
84
if (spider .kind == 'strong' ) and outer :
81
85
ax .add_patch (circle_k2 )
82
86
87
+
83
88
def draw_amplification(X_flagged, point, neighbors, ax):
    """Draw a black line from ``point`` to every entry of ``neighbors``.

    A line is drawn solid when the neighbor equals a complete row of
    ``X_flagged`` and dotted otherwise.

    Parameters
    ----------
    X_flagged : array-like
        Rows are compared element-wise against each neighbor, so it is
        expected to be 2-D with one sample per row.
    point : array-like
        Start of every line segment.
    neighbors : iterable of array-like
        End points of the line segments, one segment per entry.
    ax : object
        Target on which ``ax.plot`` is called once per neighbor.
    """
    flagged = np.asarray(X_flagged)
    for neigh in neighbors:
        # Stack the two end points and split column-wise into x and y values.
        arr = np.vstack([point, neigh])
        xs, ys = np.split(arr, 2, axis=1)
        # ``neigh in X_flagged`` on an ndarray reduces to
        # ``(X_flagged == neigh).any()``, which is already true when a single
        # coordinate coincides. Membership must hold for the whole row.
        is_flagged = bool((flagged == neigh).all(axis=1).any())
        linestyle = 'solid' if is_flagged else 'dotted'
        ax.plot(xs, ys, color='black', linestyle=linestyle)
89
94
95
+
90
96
def plot_spider (kind , X , y ):
91
97
if kind == 'strong' :
92
98
_ , axes = plt .subplots (2 , 1 , figsize = (12 , 16 ))
@@ -203,7 +209,7 @@ def plot_spider(kind, X, y):
203
209
])
204
210
205
211
# Binary labels (values 0 and 1) for the example data set.
y = np.array([
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0,
    0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0,
])
207
213
208
214
209
215
###############################################################################
@@ -213,10 +219,10 @@ def plot_spider(kind, X, y):
213
219
###############################################################################
214
220
# Both SPIDER-Weak and SPIDER-Relabel start by labeling whether samples are
215
221
# 'safe' or 'noisy' by looking at each point's 3-NN and seeing if it would be
216
- # classified correctly using KNN classification. For each minority-noisy sample,
217
- # we amplify it by the number of majority-safe samples in its 3-NN. In the
218
- # diagram below, the amplification amount is indicated by the number of solid
219
- # lines for a given minority-noisy sample's neighborhood.
222
+ # classified correctly using KNN classification. For each minority-noisy
223
+ # sample, we amplify it by the number of majority-safe samples in its 3-NN. In
224
+ # the diagram below, the amplification amount is indicated by the number of
225
+ # solid lines for a given minority-noisy sample's neighborhood.
220
226
#
221
227
# We can observe that the leftmost minority-noisy sample will be duplicated 3
222
228
# times, the middle one 1 time, and the rightmost one will not be amplified.
@@ -243,11 +249,11 @@ def plot_spider(kind, X, y):
243
249
# respectively. The middle minority-noisy sample is classified correctly by
244
250
# using 5-NN, so amplification will be done using 3-NN.
245
251
#
246
- # Next for each minority-safe sample, the amplification process is applied using
247
- # 3-NN. In the lower subplot, all but one of these samples will not be amplified
248
- # since they do not have majority-safe samples in their neighborhoods. The one
249
- # minority-safe sample to be amplified is indicated in a darker neighborhood
250
- # with lines.
252
# Next, for each minority-safe sample, the amplification process is applied
# using 3-NN. In the lower subplot, only one of these samples is amplified;
# the others have no majority-safe samples in their neighborhoods. The one
# minority-safe sample to be amplified is indicated in a darker neighborhood
# with lines.
251
257
252
258
plot_spider ('strong' , X , y )
253
259
0 commit comments