Skip to content

Commit 4ec89cf

Browse files
committed
pep8; remove old commented-out code in spider amplify
1 parent c0fe84e commit 4ec89cf

File tree

3 files changed

+19
-32
lines changed

3 files changed

+19
-32
lines changed

examples/combine/plot_illustration_spider.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,11 @@
2828

2929
Neighborhood = namedtuple('Neighborhood', 'radius, neighbors')
3030

31+
3132
def plot_X(X, ax, **kwargs):
    """Scatter the first two columns of ``X`` onto ``ax``.

    Parameters
    ----------
    X : array-like supporting ``X[:, j]`` indexing
        Samples to draw; column 0 is used as the x coordinate and column 1
        as the y coordinate.
    ax : object exposing ``scatter``
        Target axes (presumably a matplotlib ``Axes`` -- confirm at caller).
    **kwargs
        Forwarded unchanged to ``ax.scatter``.
    """
    x_coords = X[:, 0]
    y_coords = X[:, 1]
    ax.scatter(x_coords, y_coords, **kwargs)
3334

35+
3436
def correct(nn, y_fit, X, y, additional=False):
3537
n_neighbors = nn.n_neighbors
3638
if additional:
@@ -39,12 +41,13 @@ def correct(nn, y_fit, X, y, additional=False):
3941
y_pred, _ = mode(y_fit[nn_idxs], axis=1)
4042
return (y == y_pred.ravel())
4143

44+
4245
def get_neighborhoods(spider, X_fit, y_fit, X_flagged, y_flagged, idx):
4346
point = X_flagged[idx]
4447

4548
additional = (spider.kind == 'strong')
46-
if correct(spider.nn_, y_fit, point[np.newaxis], y_flagged[idx][np.newaxis],
47-
additional=additional):
49+
if correct(spider.nn_, y_fit, point[np.newaxis],
50+
y_flagged[idx][np.newaxis], additional=additional):
4851
additional = False
4952

5053
idxs_k = spider._locate_neighbors(point[np.newaxis])
@@ -61,6 +64,7 @@ def get_neighborhoods(spider, X_fit, y_fit, X_flagged, y_flagged, idx):
6164

6265
return neighborhood_k, neighborhood_k2, point, additional
6366

67+
6468
def draw_neighborhoods(spider, neighborhood_k, neighborhood_k2, point,
6569
additional, ax, outer=True, alpha=0.5):
6670
PartialCircle = partial(Circle, facecolor='none', edgecolor='black',
@@ -80,13 +84,15 @@ def draw_neighborhoods(spider, neighborhood_k, neighborhood_k2, point,
8084
if (spider.kind == 'strong') and outer:
8185
ax.add_patch(circle_k2)
8286

87+
8388
def draw_amplification(X_flagged, point, neighbors, ax):
    """Draw a black line from ``point`` to each of its neighbors.

    A neighbor that is one of the rows of ``X_flagged`` is connected with a
    solid line; every other neighbor is connected with a dotted line.

    Parameters
    ----------
    X_flagged : ndarray of shape (n_flagged, n_features)
        Flagged samples, one per row.
    point : ndarray of shape (n_features,)
        Sample the lines start from.
    neighbors : iterable of ndarray of shape (n_features,)
        Neighbors of ``point`` to connect.
    ax : object exposing ``plot``
        Target axes (presumably a matplotlib ``Axes`` -- confirm at caller).
    """
    flagged = np.asarray(X_flagged)
    for neigh in neighbors:
        arr = np.vstack([point, neigh])
        xs, ys = np.split(arr, 2, axis=1)
        # ``neigh in X_flagged`` would evaluate ``(X_flagged == neigh).any()``
        # and match on a single shared coordinate; require that a whole row
        # equals the neighbor instead.
        is_flagged = bool((flagged == neigh).all(axis=1).any())
        linestyle = 'solid' if is_flagged else 'dotted'
        ax.plot(xs, ys, color='black', linestyle=linestyle)
8994

95+
9096
def plot_spider(kind, X, y):
9197
if kind == 'strong':
9298
_, axes = plt.subplots(2, 1, figsize=(12, 16))
@@ -203,7 +209,7 @@ def plot_spider(kind, X, y):
203209
])
204210

205211
y = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0,
206-
0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0])
212+
0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0])
207213

208214

209215
###############################################################################
@@ -213,10 +219,10 @@ def plot_spider(kind, X, y):
213219
###############################################################################
214220
# Both SPIDER-Weak and SPIDER-Relabel start by labeling whether samples are
215221
# 'safe' or 'noisy' by looking at each point's 3-NN and seeing if it would be
216-
# classified correctly using KNN classification. For each minority-noisy sample,
217-
# we amplify it by the number of majority-safe samples in its 3-NN. In the
218-
# diagram below, the amplification amount is indicated by the number of solid
219-
# lines for a given minority-noisy sample's neighborhood.
222+
# classified correctly using KNN classification. For each minority-noisy
223+
# sample, we amplify it by the number of majority-safe samples in its 3-NN. In
224+
# the diagram below, the amplification amount is indicated by the number of
225+
# solid lines for a given minority-noisy sample's neighborhood.
220226
#
221227
# We can observe that the leftmost minority-noisy sample will be duplicated 3
222228
# times, the middle one 1 time, and the rightmost one will not be amplified.
@@ -243,11 +249,11 @@ def plot_spider(kind, X, y):
243249
# respectively. The middle minority-noisy sample is classified correctly by
244250
# using 5-NN, so amplification will be done using 3-NN.
245251
#
246-
# Next for each minority-safe sample, the amplification process is applied using
247-
# 3-NN. In the lower subplot, all but one of these samples will not be amplified
248-
# since they do not have majority-safe samples in their neighborhoods. The one
249-
# minority-safe sample to be amplified is indicated in a darker neighborhood
250-
# with lines.
252+
# Next for each minority-safe sample, the amplification process is applied
253+
# using 3-NN. In the lower subplot, all but one of these samples will not be
254+
# amplified since they do not have majority-safe samples in their
255+
# neighborhoods. The one minority-safe sample to be amplified is indicated in a
256+
# darker neighborhood with lines.
251257

252258
plot_spider('strong', X, y)
253259

imblearn/combine/_preprocess/_spider.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -213,26 +213,6 @@ def _amplify(self, X, y, additional=False):
213213
amplify_amounts = np.isin(
214214
nn_indices, self._amplify_indices).sum(axis=1)
215215

216-
# if sparse.issparse(X):
217-
# X_parts = []
218-
# y_parts = []
219-
# for amount in filter(bool, np.unique(amplify_amounts)):
220-
# mask = safe_mask(X, amplify_amounts == amount)
221-
# # breakpoint()
222-
# X_part = X[mask]
223-
# y_part = y[mask]
224-
# X_parts.extend([X_part] * amount)
225-
# y_parts.extend([y_part] * amount)
226-
# # try:
227-
# X_new = sparse.vstack(X_parts)
228-
# y_new = np.hstack(y_parts)
229-
# # except ValueError: # -- bool filter makes this unnecessary
230-
# # X_new = np.empty(0, dtype=X.dtype)
231-
# # y_new = np.empty(0, dtype=y.dtype)
232-
# else:
233-
# X_new = np.repeat(X, amplify_amounts, axis=0)
234-
# y_new = np.repeat(y, amplify_amounts)
235-
236216
X_parts = []
237217
y_parts = []
238218
for amount in filter(bool, np.unique(amplify_amounts)):

imblearn/combine/tests/test_spider.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
RND_SEED = 0
5858
R_TOL = 1e-4
5959

60+
6061
def test_weak():
6162
X_expected = np.array([
6263
[3.03, -4.15],

0 commit comments

Comments
 (0)