Skip to content

Commit c0fe84e

Browse files
committed
spider dense & sparse now resample in same order; add spider formulation illustration to docs
1 parent e82fe2c commit c0fe84e

File tree

4 files changed

+302
-62
lines changed

4 files changed

+302
-62
lines changed
Lines changed: 254 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,254 @@
1+
"""
2+
==========================================================================
3+
Illustration of the sample selection for the different SPIDER algorithms
4+
==========================================================================
5+
6+
This example illustrates the different ways of resampling with SPIDER.
7+
8+
"""
9+
10+
# Authors: Matthew Eding
11+
# License: MIT
12+
13+
from collections import namedtuple
14+
from functools import partial
15+
16+
import matplotlib.pyplot as plt
17+
import numpy as np
18+
19+
from imblearn.combine import SPIDER
20+
from matplotlib.patches import Circle
21+
from scipy.stats import mode
22+
from sklearn.neighbors import NearestNeighbors
23+
24+
print(__doc__)
25+
26+
###############################################################################
27+
# These are helper functions for plotting aspects of the algorithm
28+
29+
Neighborhood = namedtuple('Neighborhood', 'radius, neighbors')
30+
31+
def plot_X(X, ax, **kwargs):
32+
ax.scatter(X[:, 0], X[:, 1], **kwargs)
33+
34+
def correct(nn, y_fit, X, y, additional=False):
35+
n_neighbors = nn.n_neighbors
36+
if additional:
37+
n_neighbors += 2
38+
nn_idxs = nn.kneighbors(X, n_neighbors, return_distance=False)[:, 1:]
39+
y_pred, _ = mode(y_fit[nn_idxs], axis=1)
40+
return (y == y_pred.ravel())
41+
42+
def get_neighborhoods(spider, X_fit, y_fit, X_flagged, y_flagged, idx):
43+
point = X_flagged[idx]
44+
45+
additional = (spider.kind == 'strong')
46+
if correct(spider.nn_, y_fit, point[np.newaxis], y_flagged[idx][np.newaxis],
47+
additional=additional):
48+
additional = False
49+
50+
idxs_k = spider._locate_neighbors(point[np.newaxis])
51+
neighbors_k = X_fit[idxs_k].squeeze()
52+
farthest_k = neighbors_k[-1]
53+
radius_k = np.linalg.norm(point - farthest_k)
54+
neighborhood_k = Neighborhood(radius_k, neighbors_k)
55+
56+
idxs_k2 = spider._locate_neighbors(point[np.newaxis], additional=True)
57+
neighbors_k2 = X_fit[idxs_k2].squeeze()
58+
farthest_k2 = neighbors_k2[-1]
59+
radius_k2 = np.linalg.norm(point - farthest_k2)
60+
neighborhood_k2 = Neighborhood(radius_k2, neighbors_k2)
61+
62+
return neighborhood_k, neighborhood_k2, point, additional
63+
64+
def draw_neighborhoods(spider, neighborhood_k, neighborhood_k2, point,
65+
additional, ax, outer=True, alpha=0.5):
66+
PartialCircle = partial(Circle, facecolor='none', edgecolor='black',
67+
alpha=alpha)
68+
69+
circle_k = PartialCircle(point, neighborhood_k.radius, linestyle='-')
70+
71+
circle_k2 = PartialCircle(point, neighborhood_k2.radius,
72+
linestyle=('-' if additional else '--'))
73+
74+
if additional:
75+
neighbors = neighborhood_k2.neighbors
76+
else:
77+
neighbors = neighborhood_k.neighbors
78+
ax.add_patch(circle_k)
79+
80+
if (spider.kind == 'strong') and outer:
81+
ax.add_patch(circle_k2)
82+
83+
def draw_amplification(X_flagged, point, neighbors, ax):
84+
for neigh in neighbors:
85+
arr = np.vstack([point, neigh])
86+
xs, ys = np.split(arr, 2, axis=1)
87+
linestyle = 'solid' if neigh in X_flagged else 'dotted'
88+
ax.plot(xs, ys, color='black', linestyle=linestyle)
89+
90+
def plot_spider(kind, X, y):
91+
if kind == 'strong':
92+
_, axes = plt.subplots(2, 1, figsize=(12, 16))
93+
else:
94+
_, axes = plt.subplots(1, 1, figsize=(12, 8))
95+
axes = np.atleast_1d(axes)
96+
97+
spider = SPIDER(kind=kind)
98+
spider.fit_resample(X, y)
99+
100+
is_safe = correct(spider.nn_, y, X, y)
101+
is_minor = (y == 1)
102+
103+
X_major = X[~is_minor]
104+
X_minor = X[is_minor]
105+
X_noise = X[~is_safe]
106+
107+
X_minor_noise = X[is_minor & ~is_safe]
108+
y_minor_noise = y[is_minor & ~is_safe]
109+
X_major_safe = X[~is_minor & is_safe]
110+
X_minor_safe = X[is_minor & is_safe]
111+
y_minor_safe = y[is_minor & is_safe]
112+
113+
partial_neighborhoods = partial(get_neighborhoods, spider, X, y)
114+
partial_amplification = partial(draw_amplification, X_major_safe)
115+
partial_draw_neighborhoods = partial(draw_neighborhoods, spider)
116+
117+
size = 500
118+
for axis in axes:
119+
plot_X(X_minor, ax=axis, label='Minority class', s=size, marker='_')
120+
plot_X(X_major, ax=axis, label='Minority class', s=size, marker='+')
121+
122+
#: Overlay ring around noisy samples for both classes
123+
plot_X(X_noise, ax=axis, label='Noisy Sample', s=size, marker='o',
124+
facecolors='none', edgecolors='black')
125+
126+
#: Neighborhoods for Noisy Minority Samples
127+
for idx in range(len(X_minor_noise)):
128+
neighborhoods = partial_neighborhoods(X_minor_noise, y_minor_noise,
129+
idx=idx)
130+
partial_draw_neighborhoods(*neighborhoods, ax=axes[0],
131+
outer=(spider.kind == 'strong'))
132+
neigh_k, neigh_k2, point, additional = neighborhoods
133+
neighbors = neigh_k2.neighbors if additional else neigh_k.neighbors
134+
partial_amplification(point, neighbors, ax=axes[0])
135+
136+
axes[0].axis('equal')
137+
axes[0].legend(markerscale=0.5)
138+
axes[0].set_title(f'SPIDER-{spider.kind.title()}')
139+
140+
#: Neighborhoods for Safe Minority Samples (kind='strong' only)
141+
if spider.kind == 'strong':
142+
for idx in range(len(X_minor_safe)):
143+
neighborhoods = partial_neighborhoods(X_minor_safe, y_minor_safe,
144+
idx=idx)
145+
neigh_k, _, point, additional = neighborhoods
146+
neighbors = neigh_k.neighbors
147+
draw_flag = np.any(np.isin(neighbors, X_major_safe))
148+
149+
alpha = 0.5 if draw_flag else 0.1
150+
partial_draw_neighborhoods(*neighborhoods[:-1], additional=False,
151+
ax=axes[1], outer=False, alpha=alpha)
152+
153+
if draw_flag:
154+
partial_amplification(point, neighbors, ax=axes[1])
155+
156+
axes[1].axis('equal')
157+
axes[1].legend(markerscale=0.5)
158+
axes[1].set_title(f'SPIDER-{spider.kind.title()}')
159+
160+
161+
###############################################################################
162+
# We can start by generating some data to later illustrate the principle of
163+
# each SPIDER heuritic rules.
164+
165+
X = np.array([
166+
[-11.83, -6.81],
167+
[-11.72, -2.34],
168+
[-11.43, -5.85],
169+
[-10.66, -4.33],
170+
[-9.64, -7.05],
171+
[-8.39, -4.41],
172+
[-8.07, -5.66],
173+
[-7.28, 0.91],
174+
[-7.24, -2.41],
175+
[-6.13, -4.81],
176+
[-5.92, -6.81],
177+
[-4., -1.81],
178+
[-3.96, 2.67],
179+
[-3.74, -7.31],
180+
[-2.96, 4.69],
181+
[-1.56, -2.33],
182+
[-1.02, -4.57],
183+
[0.46, 4.07],
184+
[1.2, -1.53],
185+
[1.32, 0.41],
186+
[1.56, -5.19],
187+
[2.52, 5.89],
188+
[3.03, -4.15],
189+
[4., -0.59],
190+
[4.4, 2.07],
191+
[4.41, -7.45],
192+
[4.45, -4.12],
193+
[5.13, -6.28],
194+
[5.4, -5],
195+
[6.26, 4.65],
196+
[7.02, -6.22],
197+
[7.5, -0.11],
198+
[8.1, -2.05],
199+
[8.42, 2.47],
200+
[9.62, 3.87],
201+
[10.54, -4.47],
202+
[11.42, 0.01]
203+
])
204+
205+
y = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0,
206+
0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0])
207+
208+
209+
###############################################################################
210+
# SPIDER-Weak / SPIDER-Relabel
211+
###############################################################################
212+
213+
###############################################################################
214+
# Both SPIDER-Weak and SPIDER-Relabel start by labeling whether samples are
215+
# 'safe' or 'noisy' by looking at each point's 3-NN and seeing if it would be
216+
# classified correctly using KNN classification. For each minority-noisy sample,
217+
# we amplify it by the number of majority-safe samples in its 3-NN. In the
218+
# diagram below, the amplification amount is indicated by the number of solid
219+
# lines for a given minority-noisy sample's neighborhood.
220+
#
221+
# We can observe that the leftmost minority-noisy sample will be duplicated 3
222+
# times, the middle one 1 time, and the rightmost one will not be amplified.
223+
#
224+
# Then if SPIDER-Weak, every majority-noisy sample is removed from the dataset.
225+
# Othewise if SPIDER-Relabel, we relabel their class to be the minority class
226+
# instead. These would be the samples indicated by a circled plus-sign.
227+
228+
plot_spider('weak', X, y)
229+
230+
###############################################################################
231+
# SPIDER-Strong
232+
###############################################################################
233+
234+
###############################################################################
235+
# SPIDER-Strong still uses 3-NN to classify samples as 'safe' or 'noisy' as the
236+
# first step. However for the amplification step, each minority-noisy sample
237+
# looks at its 5-NN, and if the larger neighborhood still misclassifies the
238+
# sample, the 5-NN is used to amplify. Otherwise if the sample is correctly
239+
# classified with 5-NN, the regular 3-NN is used to amplify.
240+
#
241+
# In the diagram below, we can see that the left/rightmost minority-noisy
242+
# samples are misclassified using 5-NN and will be amplified by 5 and 1
243+
# respectively. The middle minority-noisy sample is classified correctly by
244+
# using 5-NN, so amplification will be done using 3-NN.
245+
#
246+
# Next for each minority-safe sample, the amplification process is applied using
247+
# 3-NN. In the lower subplot, all but one of these samples will not be amplified
248+
# since they do not have majority-safe samples in their neighborhoods. The one
249+
# minority-safe sample to be amplified is indicated in a darker neighborhood
250+
# with lines.
251+
252+
plot_spider('strong', X, y)
253+
254+
plt.show()

imblearn/combine/_preprocess/_spider.py

Lines changed: 35 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Class to perform cleaning and selective pre-processing using SPIDER"""
22

3-
# Author: Matthew Eding
3+
# Authors: Matthew Eding
44
# License: MIT
55

66

@@ -25,7 +25,7 @@ class SPIDER(BasePreprocessSampler):
2525
"""Perform filtering and over-sampling using Selective Pre-processing of
2626
Imbalanced Data (SPIDER) sampling approach for imbalanced datasets.
2727
28-
TODO Read more in the :ref:`User Guide <spider>`.
28+
Read more in the :ref:`User Guide <combine>`.
2929
3030
Parameters
3131
----------
@@ -213,24 +213,40 @@ def _amplify(self, X, y, additional=False):
213213
amplify_amounts = np.isin(
214214
nn_indices, self._amplify_indices).sum(axis=1)
215215

216-
if sparse.issparse(X):
217-
X_parts = []
218-
y_parts = []
219-
for amount in filter(bool, np.unique(amplify_amounts)):
220-
mask = safe_mask(X, amplify_amounts == amount)
221-
X_part = X[mask]
222-
y_part = y[mask]
223-
X_parts.extend([X_part] * amount)
224-
y_parts.extend([y_part] * amount)
225-
try:
216+
# if sparse.issparse(X):
217+
# X_parts = []
218+
# y_parts = []
219+
# for amount in filter(bool, np.unique(amplify_amounts)):
220+
# mask = safe_mask(X, amplify_amounts == amount)
221+
# # breakpoint()
222+
# X_part = X[mask]
223+
# y_part = y[mask]
224+
# X_parts.extend([X_part] * amount)
225+
# y_parts.extend([y_part] * amount)
226+
# # try:
227+
# X_new = sparse.vstack(X_parts)
228+
# y_new = np.hstack(y_parts)
229+
# # except ValueError: # -- bool filter makes this unnecessary
230+
# # X_new = np.empty(0, dtype=X.dtype)
231+
# # y_new = np.empty(0, dtype=y.dtype)
232+
# else:
233+
# X_new = np.repeat(X, amplify_amounts, axis=0)
234+
# y_new = np.repeat(y, amplify_amounts)
235+
236+
X_parts = []
237+
y_parts = []
238+
for amount in filter(bool, np.unique(amplify_amounts)):
239+
mask = safe_mask(X, amplify_amounts == amount)
240+
X_part = X[mask]
241+
y_part = y[mask]
242+
X_parts.extend([X_part] * amount)
243+
y_parts.extend([y_part] * amount)
244+
245+
if sparse.issparse(X):
226246
X_new = sparse.vstack(X_parts)
227-
y_new = np.hstack(y_parts)
228-
except ValueError:
229-
X_new = np.empty(0, dtype=X.dtype)
230-
y_new = np.empty(0, dtype=y.dtype)
231-
else:
232-
X_new = np.repeat(X, amplify_amounts, axis=0)
233-
y_new = np.repeat(y, amplify_amounts)
247+
else:
248+
X_new = np.vstack(X_parts)
249+
y_new = np.hstack(y_parts)
234250

235251
self._X_resampled.append(X_new)
236252
self._y_resampled.append(y_new)

0 commit comments

Comments
 (0)