Skip to content

Commit f4753f2

Browse files
glemaitrechkoar
authored andcommitted
BUG: ADASYN generate from minority class only (#299)
1 parent e7ccf10 commit f4753f2

File tree

2 files changed

+73
-34
lines changed

2 files changed

+73
-34
lines changed

imblearn/over_sampling/adasyn.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,11 @@ def _sample(self, X, y):
161161
ratio_nn /= np.sum(ratio_nn)
162162
n_samples_generate = np.rint(ratio_nn * n_samples).astype(int)
163163

164+
# the nearest neighbors need to be fitted only on the current class
165+
# to find the class NN to generate new samples
166+
self.nn_.fit(X_class)
167+
_, nn_index = self.nn_.kneighbors(X_class)
168+
164169
x_class_gen = []
165170
for x_i, x_i_nn, num_sample_i in zip(X_class, nn_index,
166171
n_samples_generate):

imblearn/over_sampling/tests/test_adasyn.py

Lines changed: 68 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -42,18 +42,30 @@ def test_ada_fit():
4242
def test_ada_fit_sample():
4343
ada = ADASYN(random_state=RND_SEED)
4444
X_resampled, y_resampled = ada.fit_sample(X, Y)
45-
X_gt = np.array([[0.11622591, -0.0317206], [0.77481731, 0.60935141],
46-
[1.25192108, -0.22367336], [0.53366841, -0.30312976],
47-
[1.52091956, -0.49283504], [-0.28162401, -2.10400981],
48-
[0.83680821, 1.72827342], [0.3084254, 0.33299982],
49-
[0.70472253, -0.73309052], [0.28893132, -0.38761769],
50-
[1.15514042, 0.0129463], [0.88407872, 0.35454207],
51-
[1.31301027, -0.92648734], [-1.11515198, -0.93689695],
52-
[-0.18410027, -0.45194484], [0.9281014, 0.53085498],
53-
[-0.14374509, 0.27370049], [-0.41635887, -0.38299653],
54-
[0.08711622, 0.93259929], [1.70580611, -0.11219234],
55-
[-0.06182085, -0.28084828], [0.38614986, -0.35405599],
56-
[0.39635544, 0.33629036], [-0.24027923, 0.04116021]])
45+
X_gt = np.array([[0.11622591, -0.0317206],
46+
[0.77481731, 0.60935141],
47+
[1.25192108, -0.22367336],
48+
[0.53366841, -0.30312976],
49+
[1.52091956, -0.49283504],
50+
[-0.28162401, -2.10400981],
51+
[0.83680821, 1.72827342],
52+
[0.3084254, 0.33299982],
53+
[0.70472253, -0.73309052],
54+
[0.28893132, -0.38761769],
55+
[1.15514042, 0.0129463],
56+
[0.88407872, 0.35454207],
57+
[1.31301027, -0.92648734],
58+
[-1.11515198, -0.93689695],
59+
[-0.18410027, -0.45194484],
60+
[0.9281014, 0.53085498],
61+
[-0.14374509, 0.27370049],
62+
[-0.41635887, -0.38299653],
63+
[0.08711622, 0.93259929],
64+
[1.70580611, -0.11219234],
65+
[0.36370445, -0.19262406],
66+
[0.28204936, -0.13953426],
67+
[0.39635544, 0.33629036],
68+
[0.35301481, 0.25795516]])
5769
y_gt = np.array([
5870
0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0
5971
])
@@ -65,16 +77,26 @@ def test_ada_fit_sample_half():
6577
ratio = 0.8
6678
ada = ADASYN(ratio=ratio, random_state=RND_SEED)
6779
X_resampled, y_resampled = ada.fit_sample(X, Y)
68-
X_gt = np.array([[0.11622591, -0.0317206], [0.77481731, 0.60935141],
69-
[1.25192108, -0.22367336], [0.53366841, -0.30312976],
70-
[1.52091956, -0.49283504], [-0.28162401, -2.10400981],
71-
[0.83680821, 1.72827342], [0.3084254, 0.33299982],
72-
[0.70472253, -0.73309052], [0.28893132, -0.38761769],
73-
[1.15514042, 0.0129463], [0.88407872, 0.35454207],
74-
[1.31301027, -0.92648734], [-1.11515198, -0.93689695],
75-
[-0.18410027, -0.45194484], [0.9281014, 0.53085498],
76-
[-0.14374509, 0.27370049], [-0.41635887, -0.38299653],
77-
[0.08711622, 0.93259929], [1.70580611, -0.11219234]])
80+
X_gt = np.array([[0.11622591, -0.0317206],
81+
[0.77481731, 0.60935141],
82+
[1.25192108, -0.22367336],
83+
[0.53366841, -0.30312976],
84+
[1.52091956, -0.49283504],
85+
[-0.28162401, -2.10400981],
86+
[0.83680821, 1.72827342],
87+
[0.3084254, 0.33299982],
88+
[0.70472253, -0.73309052],
89+
[0.28893132, -0.38761769],
90+
[1.15514042, 0.0129463],
91+
[0.88407872, 0.35454207],
92+
[1.31301027, -0.92648734],
93+
[-1.11515198, -0.93689695],
94+
[-0.18410027, -0.45194484],
95+
[0.9281014, 0.53085498],
96+
[-0.14374509, 0.27370049],
97+
[-0.41635887, -0.38299653],
98+
[0.08711622, 0.93259929],
99+
[1.70580611, -0.11219234]])
78100
y_gt = np.array(
79101
[0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0])
80102
assert_allclose(X_resampled, X_gt, rtol=R_TOL)
@@ -85,18 +107,30 @@ def test_ada_fit_sample_nn_obj():
85107
nn = NearestNeighbors(n_neighbors=6)
86108
ada = ADASYN(random_state=RND_SEED, n_neighbors=nn)
87109
X_resampled, y_resampled = ada.fit_sample(X, Y)
88-
X_gt = np.array([[0.11622591, -0.0317206], [0.77481731, 0.60935141],
89-
[1.25192108, -0.22367336], [0.53366841, -0.30312976],
90-
[1.52091956, -0.49283504], [-0.28162401, -2.10400981],
91-
[0.83680821, 1.72827342], [0.3084254, 0.33299982],
92-
[0.70472253, -0.73309052], [0.28893132, -0.38761769],
93-
[1.15514042, 0.0129463], [0.88407872, 0.35454207],
94-
[1.31301027, -0.92648734], [-1.11515198, -0.93689695],
95-
[-0.18410027, -0.45194484], [0.9281014, 0.53085498],
96-
[-0.14374509, 0.27370049], [-0.41635887, -0.38299653],
97-
[0.08711622, 0.93259929], [1.70580611, -0.11219234],
98-
[-0.06182085, -0.28084828], [0.38614986, -0.35405599],
99-
[0.39635544, 0.33629036], [-0.24027923, 0.04116021]])
110+
X_gt = np.array([[0.11622591, -0.0317206],
111+
[0.77481731, 0.60935141],
112+
[1.25192108, -0.22367336],
113+
[0.53366841, -0.30312976],
114+
[1.52091956, -0.49283504],
115+
[-0.28162401, -2.10400981],
116+
[0.83680821, 1.72827342],
117+
[0.3084254, 0.33299982],
118+
[0.70472253, -0.73309052],
119+
[0.28893132, -0.38761769],
120+
[1.15514042, 0.0129463],
121+
[0.88407872, 0.35454207],
122+
[1.31301027, -0.92648734],
123+
[-1.11515198, -0.93689695],
124+
[-0.18410027, -0.45194484],
125+
[0.9281014, 0.53085498],
126+
[-0.14374509, 0.27370049],
127+
[-0.41635887, -0.38299653],
128+
[0.08711622, 0.93259929],
129+
[1.70580611, -0.11219234],
130+
[0.36370445, -0.19262406],
131+
[0.28204936, -0.13953426],
132+
[0.39635544, 0.33629036],
133+
[0.35301481, 0.25795516]])
100134
y_gt = np.array([
101135
0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0
102136
])

0 commit comments

Comments
 (0)