Skip to content

Commit 50dd208

Browse files
authored
ENH allow majority classes to become the minority in AllKNN (#313)
1 parent 2c0628f commit 50dd208

File tree

2 files changed

+26
-1
lines changed

2 files changed

+26
-1
lines changed

imblearn/under_sampling/prototype_selection/edited_nearest_neighbours.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,12 @@ class AllKNN(BaseCleaningSampler):
477477
- If ``'mode'``, the majority vote of the neighbours will be used in
478478
order to exclude a sample.
479479
480+
allow_minority : bool, optional (default=False)
481+
If ``True``, it allows the majority classes to become the minority
482+
class without early stopping.
483+
484+
.. versionadded:: 0.3
485+
480486
n_jobs : int, optional (default=-1)
481487
The number of threads to open when possible.
482488
@@ -519,12 +525,14 @@ def __init__(self,
519525
size_ngh=None,
520526
n_neighbors=3,
521527
kind_sel='all',
528+
allow_minority=False,
522529
n_jobs=-1):
523530
super(AllKNN, self).__init__(ratio=ratio, random_state=random_state)
524531
self.return_indices = return_indices
525532
self.size_ngh = size_ngh
526533
self.n_neighbors = n_neighbors
527534
self.kind_sel = kind_sel
535+
self.allow_minority = allow_minority
528536
self.n_jobs = n_jobs
529537

530538
def _validate_estimator(self):
@@ -595,6 +603,9 @@ def _sample(self, X, y):
595603
])
596604
b_min_bec_maj = np.any(count_non_min <
597605
target_stats[class_minority])
606+
if self.allow_minority:
607+
# overwrite b_min_bec_maj
608+
b_min_bec_maj = False
598609

599610
# Case 2
600611
b_remove_maj_class = (len(stats_enn) < len(target_stats))

imblearn/under_sampling/prototype_selection/tests/test_allknn.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@
77

88
import numpy as np
99
from sklearn.utils.testing import (assert_allclose, assert_array_equal,
10-
assert_raises)
10+
assert_raises, assert_true)
1111
from sklearn.neighbors import NearestNeighbors
12+
from sklearn.datasets import make_classification
1213

1314
from imblearn.under_sampling import AllKNN
1415

@@ -66,6 +67,19 @@ def test_allknn_fit_sample():
6667
assert_allclose(y_resampled, y_gt, rtol=R_TOL)
6768

6869

70+
def test_all_knn_allow_minority():
71+
X, y = make_classification(n_samples=10000, n_features=2, n_informative=2,
72+
n_redundant=0, n_repeated=0, n_classes=3,
73+
n_clusters_per_class=1, weights=[0.2, 0.3, 0.5],
74+
class_sep=0.4, random_state=0)
75+
76+
allknn = AllKNN(random_state=RND_SEED, allow_minority=True)
77+
X_res_1, y_res_1 = allknn.fit_sample(X, y)
78+
allknn = AllKNN(random_state=RND_SEED)
79+
X_res_2, y_res_2 = allknn.fit_sample(X, y)
80+
assert_true(len(y_res_1) < len(y_res_2))
81+
82+
6983
def test_allknn_fit_sample_with_indices():
7084
allknn = AllKNN(return_indices=True, random_state=RND_SEED)
7185
X_resampled, y_resampled, idx_under = allknn.fit_sample(X, Y)

0 commit comments

Comments
 (0)