diff --git a/imblearn/over_sampling/_adasyn.py b/imblearn/over_sampling/_adasyn.py index 85f314c78..f014243e7 100644 --- a/imblearn/over_sampling/_adasyn.py +++ b/imblearn/over_sampling/_adasyn.py @@ -1,4 +1,4 @@ -"""Class to perform random over-sampling.""" +"""Class to perform over-sampling using ADASYN.""" # Authors: Guillaume Lemaitre # Christos Aridas @@ -104,8 +104,8 @@ def _fit_resample(self, X, y): self._validate_estimator() random_state = check_random_state(self.random_state) - X_resampled = X.copy() - y_resampled = y.copy() + X_resampled = [X.copy()] + y_resampled = [y.copy()] for class_sample, n_samples in self.sampling_strategy_.items(): if n_samples == 0: @@ -114,13 +114,12 @@ def _fit_resample(self, X, y): X_class = _safe_indexing(X, target_class_indices) self.nn_.fit(X) - _, nn_index = self.nn_.kneighbors(X_class) + nns = self.nn_.kneighbors(X_class, return_distance=False)[:, 1:] # The ratio is computed using a one-vs-rest manner. Using majority # in multi-class would lead to slightly different results at the # cost of introducing a new parameter. - ratio_nn = np.sum(y[nn_index[:, 1:]] != class_sample, axis=1) / ( - self.nn_.n_neighbors - 1 - ) + n_neighbors = self.nn_.n_neighbors - 1 + ratio_nn = np.sum(y[nns] != class_sample, axis=1) / n_neighbors if not np.sum(ratio_nn): raise RuntimeError( "Not any neigbours belong to the majority" @@ -131,7 +130,9 @@ def _fit_resample(self, X, y): ) ratio_nn /= np.sum(ratio_nn) n_samples_generate = np.rint(ratio_nn * n_samples).astype(int) - if not np.sum(n_samples_generate): + # rounding may cause new amount for n_samples + n_samples = np.sum(n_samples_generate) + if not n_samples: raise ValueError( "No samples will be generated with the" " provided ratio settings." @@ -140,66 +141,30 @@ def _fit_resample(self, X, y): # the nearest neighbors need to be fitted only on the current class # to find the class NN to generate new samples self.nn_.fit(X_class) - _, nn_index = self.nn_.kneighbors(X_class) + nns = self.nn_.kneighbors(X_class, return_distance=False)[:, 1:] - if sparse.issparse(X): - row_indices, col_indices, samples = [], [], [] - n_samples_generated = 0 - for x_i, x_i_nn, num_sample_i in zip( - X_class, nn_index, n_samples_generate - ): - if num_sample_i == 0: - continue - nn_zs = random_state.randint( - 1, high=self.nn_.n_neighbors, size=num_sample_i - ) - steps = random_state.uniform(size=len(nn_zs)) - if x_i.nnz: - for step, nn_z in zip(steps, nn_zs): - sample = x_i + step * ( - X_class[x_i_nn[nn_z], :] - x_i - ) - row_indices += [n_samples_generated] * len( - sample.indices - ) - col_indices += sample.indices.tolist() - samples += sample.data.tolist() - n_samples_generated += 1 - X_new = sparse.csr_matrix( - (samples, (row_indices, col_indices)), - [np.sum(n_samples_generate), X.shape[1]], - dtype=X.dtype, - ) - y_new = np.array( - [class_sample] * np.sum(n_samples_generate), dtype=y.dtype - ) - else: - x_class_gen = [] - for x_i, x_i_nn, num_sample_i in zip( - X_class, nn_index, n_samples_generate - ): - if num_sample_i == 0: - continue - nn_zs = random_state.randint( - 1, high=self.nn_.n_neighbors, size=num_sample_i - ) - steps = random_state.uniform(size=len(nn_zs)) - x_class_gen.append( - [ - x_i + step * (X_class[x_i_nn[nn_z], :] - x_i) - for step, nn_z in zip(steps, nn_zs) - ] - ) - - X_new = np.concatenate(x_class_gen).astype(X.dtype) - y_new = np.array( - [class_sample] * np.sum(n_samples_generate), dtype=y.dtype - ) + enumerated_class_indices = np.arange(len(target_class_indices)) + rows = np.repeat(enumerated_class_indices, n_samples_generate) + cols = random_state.choice(n_neighbors, size=n_samples) + diffs = X_class[nns[rows, cols]] - X_class[rows] + steps = random_state.uniform(size=(n_samples, 1)) - if sparse.issparse(X_new): - X_resampled = sparse.vstack([X_resampled, X_new]) + if sparse.issparse(X): + sparse_func = type(X).__name__ + steps = getattr(sparse, sparse_func)(steps) + X_new = X_class[rows] + steps.multiply(diffs) else: - X_resampled = np.vstack((X_resampled, X_new)) - y_resampled = np.hstack((y_resampled, y_new)) + X_new = X_class[rows] + steps * diffs + + X_new = X_new.astype(X.dtype) + y_new = np.full(n_samples, fill_value=class_sample, dtype=y.dtype) + X_resampled.append(X_new) + y_resampled.append(y_new) + + if sparse.issparse(X): + X_resampled = sparse.vstack(X_resampled, format=X.format) + else: + X_resampled = np.vstack(X_resampled) + y_resampled = np.hstack(y_resampled) return X_resampled, y_resampled diff --git a/imblearn/over_sampling/_smote.py b/imblearn/over_sampling/_smote.py index b764da6b6..cea14cfd2 100644 --- a/imblearn/over_sampling/_smote.py +++ b/imblearn/over_sampling/_smote.py @@ -98,7 +98,7 @@ def _make_samples( """ random_state = check_random_state(self.random_state) samples_indices = random_state.randint( - low=0, high=len(nn_num.flatten()), size=n_samples + low=0, high=nn_num.size, size=n_samples ) # np.newaxis for backwards compatability with random_state @@ -731,13 +731,12 @@ def _fit_resample(self, X, y): X_resampled.append(X_new) y_resampled.append(y_new) - if sparse.issparse(X_new): + if sparse.issparse(X): X_resampled = sparse.vstack(X_resampled, format=X.format) else: X_resampled = np.vstack(X_resampled) y_resampled = np.hstack(y_resampled) - return X_resampled, y_resampled diff --git a/imblearn/over_sampling/tests/test_adasyn.py b/imblearn/over_sampling/tests/test_adasyn.py index f7fcb07c7..87769f08e 100644 --- a/imblearn/over_sampling/tests/test_adasyn.py +++ b/imblearn/over_sampling/tests/test_adasyn.py @@ -72,10 +72,10 @@ def test_ada_fit_resample(): [-0.41635887, -0.38299653], [0.08711622, 0.93259929], [1.70580611, -0.11219234], - [0.94899098, -0.30508981], - [0.28204936, -0.13953426], - [1.58028868, -0.04089947], - [0.66117333, -0.28009063], + [0.88161986, -0.2829741], + [0.35681689, -0.18814597], + [1.4148276, 0.05308106], + [0.3136591, -0.31327875], ] ) y_gt = np.array( @@ -136,10 +136,10 @@ def test_ada_fit_resample_nn_obj(): [-0.41635887, -0.38299653], [0.08711622, 0.93259929], [1.70580611, -0.11219234], - [0.94899098, -0.30508981], - [0.28204936, -0.13953426], - [1.58028868, -0.04089947], - [0.66117333, -0.28009063], + [0.88161986, -0.2829741], + [0.35681689, -0.18814597], + [1.4148276, 0.05308106], + [0.3136591, -0.31327875], ] ) y_gt = np.array(