|
| 1 | +""" |
| 2 | +========================================================================== |
| 3 | +Illustration of the sample selection for the different SPIDER algorithms |
| 4 | +========================================================================== |
| 5 | +
|
| 6 | +This example illustrates the different ways of resampling with SPIDER. |
| 7 | +
|
| 8 | +""" |
| 9 | + |
| 10 | +# Authors: Matthew Eding |
| 11 | +# License: MIT |
| 12 | + |
| 13 | +from collections import namedtuple |
| 14 | +from functools import partial |
| 15 | + |
| 16 | +import matplotlib.pyplot as plt |
| 17 | +import numpy as np |
| 18 | + |
| 19 | +from imblearn.combine import SPIDER |
| 20 | +from matplotlib.patches import Circle |
| 21 | +from scipy.stats import mode |
| 22 | +from sklearn.neighbors import NearestNeighbors |
| 23 | + |
| 24 | +print(__doc__) |
| 25 | + |
| 26 | +############################################################################### |
| 27 | +# These are helper functions for plotting aspects of the algorithm |
| 28 | + |
| 29 | +Neighborhood = namedtuple('Neighborhood', 'radius, neighbors') |
| 30 | + |
| 31 | +def plot_X(X, ax, **kwargs): |
| 32 | + ax.scatter(X[:, 0], X[:, 1], **kwargs) |
| 33 | + |
| 34 | +def correct(nn, y_fit, X, y, additional=False): |
| 35 | + n_neighbors = nn.n_neighbors |
| 36 | + if additional: |
| 37 | + n_neighbors += 2 |
| 38 | + nn_idxs = nn.kneighbors(X, n_neighbors, return_distance=False)[:, 1:] |
| 39 | + y_pred, _ = mode(y_fit[nn_idxs], axis=1) |
| 40 | + return (y == y_pred.ravel()) |
| 41 | + |
| 42 | +def get_neighborhoods(spider, X_fit, y_fit, X_flagged, y_flagged, idx): |
| 43 | + point = X_flagged[idx] |
| 44 | + |
| 45 | + additional = (spider.kind == 'strong') |
| 46 | + if correct(spider.nn_, y_fit, point[np.newaxis], y_flagged[idx][np.newaxis], |
| 47 | + additional=additional): |
| 48 | + additional = False |
| 49 | + |
| 50 | + idxs_k = spider._locate_neighbors(point[np.newaxis]) |
| 51 | + neighbors_k = X_fit[idxs_k].squeeze() |
| 52 | + farthest_k = neighbors_k[-1] |
| 53 | + radius_k = np.linalg.norm(point - farthest_k) |
| 54 | + neighborhood_k = Neighborhood(radius_k, neighbors_k) |
| 55 | + |
| 56 | + idxs_k2 = spider._locate_neighbors(point[np.newaxis], additional=True) |
| 57 | + neighbors_k2 = X_fit[idxs_k2].squeeze() |
| 58 | + farthest_k2 = neighbors_k2[-1] |
| 59 | + radius_k2 = np.linalg.norm(point - farthest_k2) |
| 60 | + neighborhood_k2 = Neighborhood(radius_k2, neighbors_k2) |
| 61 | + |
| 62 | + return neighborhood_k, neighborhood_k2, point, additional |
| 63 | + |
| 64 | +def draw_neighborhoods(spider, neighborhood_k, neighborhood_k2, point, |
| 65 | + additional, ax, outer=True, alpha=0.5): |
| 66 | + PartialCircle = partial(Circle, facecolor='none', edgecolor='black', |
| 67 | + alpha=alpha) |
| 68 | + |
| 69 | + circle_k = PartialCircle(point, neighborhood_k.radius, linestyle='-') |
| 70 | + |
| 71 | + circle_k2 = PartialCircle(point, neighborhood_k2.radius, |
| 72 | + linestyle=('-' if additional else '--')) |
| 73 | + |
| 74 | + if additional: |
| 75 | + neighbors = neighborhood_k2.neighbors |
| 76 | + else: |
| 77 | + neighbors = neighborhood_k.neighbors |
| 78 | + ax.add_patch(circle_k) |
| 79 | + |
| 80 | + if (spider.kind == 'strong') and outer: |
| 81 | + ax.add_patch(circle_k2) |
| 82 | + |
| 83 | +def draw_amplification(X_flagged, point, neighbors, ax): |
| 84 | + for neigh in neighbors: |
| 85 | + arr = np.vstack([point, neigh]) |
| 86 | + xs, ys = np.split(arr, 2, axis=1) |
| 87 | + linestyle = 'solid' if neigh in X_flagged else 'dotted' |
| 88 | + ax.plot(xs, ys, color='black', linestyle=linestyle) |
| 89 | + |
| 90 | +def plot_spider(kind, X, y): |
| 91 | + if kind == 'strong': |
| 92 | + _, axes = plt.subplots(2, 1, figsize=(12, 16)) |
| 93 | + else: |
| 94 | + _, axes = plt.subplots(1, 1, figsize=(12, 8)) |
| 95 | + axes = np.atleast_1d(axes) |
| 96 | + |
| 97 | + spider = SPIDER(kind=kind) |
| 98 | + spider.fit_resample(X, y) |
| 99 | + |
| 100 | + is_safe = correct(spider.nn_, y, X, y) |
| 101 | + is_minor = (y == 1) |
| 102 | + |
| 103 | + X_major = X[~is_minor] |
| 104 | + X_minor = X[is_minor] |
| 105 | + X_noise = X[~is_safe] |
| 106 | + |
| 107 | + X_minor_noise = X[is_minor & ~is_safe] |
| 108 | + y_minor_noise = y[is_minor & ~is_safe] |
| 109 | + X_major_safe = X[~is_minor & is_safe] |
| 110 | + X_minor_safe = X[is_minor & is_safe] |
| 111 | + y_minor_safe = y[is_minor & is_safe] |
| 112 | + |
| 113 | + partial_neighborhoods = partial(get_neighborhoods, spider, X, y) |
| 114 | + partial_amplification = partial(draw_amplification, X_major_safe) |
| 115 | + partial_draw_neighborhoods = partial(draw_neighborhoods, spider) |
| 116 | + |
| 117 | + size = 500 |
| 118 | + for axis in axes: |
| 119 | + plot_X(X_minor, ax=axis, label='Minority class', s=size, marker='_') |
| 120 | + plot_X(X_major, ax=axis, label='Minority class', s=size, marker='+') |
| 121 | + |
| 122 | + #: Overlay ring around noisy samples for both classes |
| 123 | + plot_X(X_noise, ax=axis, label='Noisy Sample', s=size, marker='o', |
| 124 | + facecolors='none', edgecolors='black') |
| 125 | + |
| 126 | + #: Neighborhoods for Noisy Minority Samples |
| 127 | + for idx in range(len(X_minor_noise)): |
| 128 | + neighborhoods = partial_neighborhoods(X_minor_noise, y_minor_noise, |
| 129 | + idx=idx) |
| 130 | + partial_draw_neighborhoods(*neighborhoods, ax=axes[0], |
| 131 | + outer=(spider.kind == 'strong')) |
| 132 | + neigh_k, neigh_k2, point, additional = neighborhoods |
| 133 | + neighbors = neigh_k2.neighbors if additional else neigh_k.neighbors |
| 134 | + partial_amplification(point, neighbors, ax=axes[0]) |
| 135 | + |
| 136 | + axes[0].axis('equal') |
| 137 | + axes[0].legend(markerscale=0.5) |
| 138 | + axes[0].set_title(f'SPIDER-{spider.kind.title()}') |
| 139 | + |
| 140 | + #: Neighborhoods for Safe Minority Samples (kind='strong' only) |
| 141 | + if spider.kind == 'strong': |
| 142 | + for idx in range(len(X_minor_safe)): |
| 143 | + neighborhoods = partial_neighborhoods(X_minor_safe, y_minor_safe, |
| 144 | + idx=idx) |
| 145 | + neigh_k, _, point, additional = neighborhoods |
| 146 | + neighbors = neigh_k.neighbors |
| 147 | + draw_flag = np.any(np.isin(neighbors, X_major_safe)) |
| 148 | + |
| 149 | + alpha = 0.5 if draw_flag else 0.1 |
| 150 | + partial_draw_neighborhoods(*neighborhoods[:-1], additional=False, |
| 151 | + ax=axes[1], outer=False, alpha=alpha) |
| 152 | + |
| 153 | + if draw_flag: |
| 154 | + partial_amplification(point, neighbors, ax=axes[1]) |
| 155 | + |
| 156 | + axes[1].axis('equal') |
| 157 | + axes[1].legend(markerscale=0.5) |
| 158 | + axes[1].set_title(f'SPIDER-{spider.kind.title()}') |
| 159 | + |
| 160 | + |
| 161 | +############################################################################### |
| 162 | +# We can start by generating some data to later illustrate the principle of |
| 163 | +# each SPIDER heuritic rules. |
| 164 | + |
| 165 | +X = np.array([ |
| 166 | + [-11.83, -6.81], |
| 167 | + [-11.72, -2.34], |
| 168 | + [-11.43, -5.85], |
| 169 | + [-10.66, -4.33], |
| 170 | + [-9.64, -7.05], |
| 171 | + [-8.39, -4.41], |
| 172 | + [-8.07, -5.66], |
| 173 | + [-7.28, 0.91], |
| 174 | + [-7.24, -2.41], |
| 175 | + [-6.13, -4.81], |
| 176 | + [-5.92, -6.81], |
| 177 | + [-4., -1.81], |
| 178 | + [-3.96, 2.67], |
| 179 | + [-3.74, -7.31], |
| 180 | + [-2.96, 4.69], |
| 181 | + [-1.56, -2.33], |
| 182 | + [-1.02, -4.57], |
| 183 | + [0.46, 4.07], |
| 184 | + [1.2, -1.53], |
| 185 | + [1.32, 0.41], |
| 186 | + [1.56, -5.19], |
| 187 | + [2.52, 5.89], |
| 188 | + [3.03, -4.15], |
| 189 | + [4., -0.59], |
| 190 | + [4.4, 2.07], |
| 191 | + [4.41, -7.45], |
| 192 | + [4.45, -4.12], |
| 193 | + [5.13, -6.28], |
| 194 | + [5.4, -5], |
| 195 | + [6.26, 4.65], |
| 196 | + [7.02, -6.22], |
| 197 | + [7.5, -0.11], |
| 198 | + [8.1, -2.05], |
| 199 | + [8.42, 2.47], |
| 200 | + [9.62, 3.87], |
| 201 | + [10.54, -4.47], |
| 202 | + [11.42, 0.01] |
| 203 | +]) |
| 204 | + |
| 205 | +y = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, |
| 206 | + 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0]) |
| 207 | + |
| 208 | + |
| 209 | +############################################################################### |
| 210 | +# SPIDER-Weak / SPIDER-Relabel |
| 211 | +############################################################################### |
| 212 | + |
| 213 | +############################################################################### |
| 214 | +# Both SPIDER-Weak and SPIDER-Relabel start by labeling whether samples are |
| 215 | +# 'safe' or 'noisy' by looking at each point's 3-NN and seeing if it would be |
| 216 | +# classified correctly using KNN classification. For each minority-noisy sample, |
| 217 | +# we amplify it by the number of majority-safe samples in its 3-NN. In the |
| 218 | +# diagram below, the amplification amount is indicated by the number of solid |
| 219 | +# lines for a given minority-noisy sample's neighborhood. |
| 220 | +# |
| 221 | +# We can observe that the leftmost minority-noisy sample will be duplicated 3 |
| 222 | +# times, the middle one 1 time, and the rightmost one will not be amplified. |
| 223 | +# |
| 224 | +# Then if SPIDER-Weak, every majority-noisy sample is removed from the dataset. |
| 225 | +# Othewise if SPIDER-Relabel, we relabel their class to be the minority class |
| 226 | +# instead. These would be the samples indicated by a circled plus-sign. |
| 227 | + |
| 228 | +plot_spider('weak', X, y) |
| 229 | + |
| 230 | +############################################################################### |
| 231 | +# SPIDER-Strong |
| 232 | +############################################################################### |
| 233 | + |
| 234 | +############################################################################### |
| 235 | +# SPIDER-Strong still uses 3-NN to classify samples as 'safe' or 'noisy' as the |
| 236 | +# first step. However for the amplification step, each minority-noisy sample |
| 237 | +# looks at its 5-NN, and if the larger neighborhood still misclassifies the |
| 238 | +# sample, the 5-NN is used to amplify. Otherwise if the sample is correctly |
| 239 | +# classified with 5-NN, the regular 3-NN is used to amplify. |
| 240 | +# |
| 241 | +# In the diagram below, we can see that the left/rightmost minority-noisy |
| 242 | +# samples are misclassified using 5-NN and will be amplified by 5 and 1 |
| 243 | +# respectively. The middle minority-noisy sample is classified correctly by |
| 244 | +# using 5-NN, so amplification will be done using 3-NN. |
| 245 | +# |
| 246 | +# Next for each minority-safe sample, the amplification process is applied using |
| 247 | +# 3-NN. In the lower subplot, all but one of these samples will not be amplified |
| 248 | +# since they do not have majority-safe samples in their neighborhoods. The one |
| 249 | +# minority-safe sample to be amplified is indicated in a darker neighborhood |
| 250 | +# with lines. |
| 251 | + |
| 252 | +plot_spider('strong', X, y) |
| 253 | + |
| 254 | +plt.show() |
0 commit comments