Skip to content

Commit e7e0baa

Browse files
glemaitre authored and chkoar committed
FIX: bug fix in Nearmiss-3 for not returning the right indices (#282)
1 parent 5ca3037 commit e7e0baa

File tree

3 files changed

+29
-12
lines changed

3 files changed

+29
-12
lines changed

doc/whats_new.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@ Release history
99
Changelog
1010
---------
1111

12+
Bug fixes
13+
---------
14+
15+
- Fixed a bug in :class:`under_sampling.NearMiss` version 3. The
16+
indices returned were wrong. By `Guillaume Lemaitre`_.
17+
1218
New features
1319
~~~~~~~~~~~~
1420

imblearn/under_sampling/nearmiss.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -355,11 +355,11 @@ def _sample(self, X, y):
355355

356356
# Create the subset containing the samples found during the NN
357357
# search. Linearize the indexes and remove the double values
358-
idx_vec = np.unique(idx_vec.reshape(-1))
358+
idx_vec_farthest = np.unique(idx_vec.reshape(-1))
359359

360360
# Create the subset
361-
sub_samples_x = sub_samples_x[idx_vec, :]
362-
sub_samples_y = sub_samples_y[idx_vec]
361+
sub_samples_x = sub_samples_x[idx_vec_farthest, :]
362+
sub_samples_y = sub_samples_y[idx_vec_farthest]
363363

364364
# Compute the NN considering the current class
365365
dist_vec, idx_vec = self.nn_.kneighbors(
@@ -372,6 +372,10 @@ def _sample(self, X, y):
372372
num_samples,
373373
key,
374374
sel_strategy='farthest')
375+
376+
# idx_tmp is relative to the samples selected in the
377+
# previous step and we need to find the indirection
378+
idx_tmp = np.flatnonzero(y == key)[idx_vec_farthest[idx_tmp]]
375379
else:
376380
raise NotImplementedError
377381

imblearn/under_sampling/tests/test_nearmiss.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,24 @@
1010

1111
# Generate a global dataset to use for the 3 versions of NearMiss
1212
RND_SEED = 0
13-
X = np.array([[1.17737838, -0.2002118], [0.4960075, 0.86130762],
14-
[-0.05903827, 0.10947647], [0.91464286, 1.61369212],
15-
[-0.54619583, 1.73009918], [-0.60413357, 0.24628718],
16-
[0.45713638, 1.31069295], [-0.04032409, 3.01186964],
17-
[0.03142011, 0.12323596], [0.50701028, -0.17636928],
18-
[-0.80809175, -1.09917302], [-0.20497017, -0.26630228],
19-
[0.99272351, -0.11631728], [-1.95581933, 0.69609604],
13+
X = np.array([[1.17737838, -0.2002118],
14+
[0.4960075, 0.86130762],
15+
[-0.05903827, 0.10947647],
16+
[0.91464286, 1.61369212],
17+
[-0.54619583, 1.73009918],
18+
[-0.60413357, 0.24628718],
19+
[0.45713638, 1.31069295],
20+
[-0.04032409, 3.01186964],
21+
[0.03142011, 0.12323596],
22+
[0.50701028, -0.17636928],
23+
[-0.80809175, -1.09917302],
24+
[-0.20497017, -0.26630228],
25+
[0.99272351, -0.11631728],
26+
[-1.95581933, 0.69609604],
2027
[1.15157493, -1.2981518]])
2128
Y = np.array([1, 2, 1, 0, 2, 1, 2, 2, 1, 2, 0, 0, 2, 1, 2])
2229

23-
VERSION_NEARMISS = [1, 2, 3]
30+
VERSION_NEARMISS = (1, 2, 3)
2431

2532

2633
# FIXME remove at the end of the deprecation 0.4
@@ -134,7 +141,7 @@ def test_nm_fit_sample_auto_indices():
134141

135142
idx_gt = [np.array([3, 10, 11, 2, 8, 5, 9, 1, 6]),
136143
np.array([3, 10, 11, 2, 8, 5, 9, 1, 6]),
137-
np.array([3, 10, 11, 0, 2, 3, 5, 1, 4])]
144+
np.array([3, 10, 11, 0, 5, 8, 14, 4, 12])]
138145

139146
for version_idx, version in enumerate(VERSION_NEARMISS):
140147
nm = NearMiss(ratio=ratio, random_state=RND_SEED,

0 commit comments

Comments
 (0)