From 5618e9f2dee2a20acde4d06d341852fcd7ccf73a Mon Sep 17 00:00:00 2001
From: Matt Eding
Date: Wed, 4 Sep 2019 15:02:15 -0700
Subject: [PATCH 1/6] vectorized BaseSMOTE, SMOTE, and SMOTENC

---
 imblearn/over_sampling/_smote.py | 112 +++++++++++++++----------------
 1 file changed, 53 insertions(+), 59 deletions(-)

diff --git a/imblearn/over_sampling/_smote.py b/imblearn/over_sampling/_smote.py
index 905629de4..04527eb80 100644
--- a/imblearn/over_sampling/_smote.py
+++ b/imblearn/over_sampling/_smote.py
@@ -107,33 +107,17 @@ def _make_samples(self,
         random_state = check_random_state(self.random_state)
         samples_indices = random_state.randint(
             low=0, high=len(nn_num.flatten()), size=n_samples)
-        steps = step_size * random_state.uniform(size=n_samples)
+
+        # np.newaxis for backwards compatibility with random_state
+        steps = step_size * random_state.uniform(size=n_samples)[:, np.newaxis]
         rows = np.floor_divide(samples_indices, nn_num.shape[1])
         cols = np.mod(samples_indices, nn_num.shape[1])
-        y_new = np.array([y_type] * len(samples_indices), dtype=y_dtype)
-
-        if sparse.issparse(X):
-            row_indices, col_indices, samples = [], [], []
-            for i, (row, col, step) in enumerate(zip(rows, cols, steps)):
-                if X[row].nnz:
-                    sample = self._generate_sample(X, nn_data, nn_num,
-                                                   row, col, step)
-                    row_indices += [i] * len(sample.indices)
-                    col_indices += sample.indices.tolist()
-                    samples += sample.data.tolist()
-            return (sparse.csr_matrix((samples, (row_indices, col_indices)),
-                                      [len(samples_indices), X.shape[1]],
-                                      dtype=X.dtype),
-                    y_new)
-        else:
-            X_new = np.zeros((n_samples, X.shape[1]), dtype=X.dtype)
-            for i, (row, col, step) in enumerate(zip(rows, cols, steps)):
-                X_new[i] = self._generate_sample(X, nn_data, nn_num,
-                                                 row, col, step)
-            return X_new, y_new
+        X_new = self._generate_samples(X, nn_data, nn_num, rows, cols, steps)
+        y_new = np.full(len(samples_indices), fill_value=y_type, dtype=y_dtype)
+        return X_new, y_new

-    def _generate_sample(self, X, nn_data, nn_num, row, col, step):
+    def _generate_samples(self, X, nn_data, nn_num, rows, cols, steps):
         r"""Generate a synthetic sample.

         The rule for the generation is:
@@ -157,24 +141,33 @@ def _generate_sample(self, X, nn_data, nn_num, row, col, step):
         nn_num : ndarray, shape (n_samples_all, k_nearest_neighbours)
             The nearest neighbours of each sample in `nn_data`.

-        row : int
-            Index pointing at feature vector in X which will be used
-            as a base for creating new sample.
+        rows : ndarray[int], shape (n_samples,)
+            Indices pointing at feature vector in X which will be used
+            as a base for creating new samples.

-        col : int
-            Index pointing at which nearest neighbor of base feature vector
-            will be used when creating new sample.
+        cols : ndarray[int], shape (n_samples,)
+            Indices pointing at which nearest neighbor of base feature vector
+            will be used when creating new samples.

-        step : float
-            Step size for new sample.
+        steps : ndarray[float], shape (n_samples,)
+            Step sizes for new samples.

         Returns
         -------
-        X_new : {ndarray, sparse matrix}, shape (n_features,)
-            Single synthetically generated sample.
+        X_new : {ndarray, sparse matrix}, shape (n_samples, n_features)
+            Synthetically generated samples.
""" - return X[row] - step * (X[row] - nn_data[nn_num[row, col]]) + diffs = nn_data[nn_num[rows, cols]] - X[rows] + + if sparse.issparse(X): + sparse_func = type(X).__name__ + steps = getattr(sparse, sparse_func)(steps) + X_new = X[rows] + steps.multiply(diffs) + else: + X_new = X[rows] + steps * diffs + + return X_new.astype(X.dtype) def _in_danger_noise(self, nn_estimator, samples, target_class, y, kind='danger'): @@ -800,8 +793,8 @@ def _sample(self, X, y): # FIXME: uncomment in version 0.6 # self._validate_estimator() - X_resampled = X.copy() - y_resampled = y.copy() + X_resampled = [X.copy()] + y_resampled = [y.copy()] for class_sample, n_samples in self.sampling_strategy_.items(): if n_samples == 0: @@ -813,14 +806,14 @@ def _sample(self, X, y): nns = self.nn_k_.kneighbors(X_class, return_distance=False)[:, 1:] X_new, y_new = self._make_samples(X_class, y.dtype, class_sample, X_class, nns, n_samples, 1.0) + X_resampled.append(X_new) + y_resampled.append(y_new) - if sparse.issparse(X_new): - X_resampled = sparse.vstack([X_resampled, X_new]) - sparse_func = 'tocsc' if X.format == 'csc' else 'tocsr' - X_resampled = getattr(X_resampled, sparse_func)() - else: - X_resampled = np.vstack((X_resampled, X_new)) - y_resampled = np.hstack((y_resampled, y_new)) + if sparse.issparse(X_new): + X_resampled = sparse.vstack(X_resampled, format=X.format) + else: + X_resampled = np.vstack(X_resampled) + y_resampled = np.hstack(y_resampled) return X_resampled, y_resampled @@ -1068,29 +1061,30 @@ def _generate_sample(self, X, nn_data, nn_num, row, col, step): of the majority class. """ rng = check_random_state(self.random_state) - sample = super()._generate_sample(X, nn_data, nn_num, - row, col, step) - # To avoid conversion and since there is only few samples used, we - # convert those samples to dense array. 
-        sample = (sample.toarray().squeeze()
-                  if sparse.issparse(sample) else sample)
-        all_neighbors = nn_data[nn_num[row]]
-        all_neighbors = (all_neighbors.toarray()
-                         if sparse.issparse(all_neighbors) else all_neighbors)
-
+        X_new = super()._generate_samples(X, nn_data, nn_num, rows, cols, steps)
+        # change in sparsity structure more efficient with LIL than CSR
+        X_new = (X_new.tolil() if sparse.issparse(X_new) else X_new)
+        # convert to dense array since scipy.sparse doesn't handle 3D
+        nn_data = (nn_data.toarray() if sparse.issparse(nn_data) else nn_data)
+        all_neighbors = nn_data[nn_num[rows]]
+
         categories_size = ([self.continuous_features_.size] +
                            [cat.size for cat in self.ohe_.categories_])

         for start_idx, end_idx in zip(np.cumsum(categories_size)[:-1],
                                       np.cumsum(categories_size)[1:]):
-            col_max = all_neighbors[:, start_idx:end_idx].sum(axis=0)
+            col_maxs = all_neighbors[:, :, start_idx:end_idx].sum(axis=1)
             # tie breaking argmax
-            col_sel = rng.choice(np.flatnonzero(
-                np.isclose(col_max, col_max.max())))
-            sample[start_idx:end_idx] = 0
-            sample[start_idx + col_sel] = 1
+            is_max = np.isclose(col_maxs, col_maxs.max(axis=1, keepdims=True))
+            max_idxs = rng.permutation(np.argwhere(is_max))
+            xs, idx_sels = np.unique(max_idxs[:, 0], return_index=True)
+            col_sels = max_idxs[idx_sels, 1]
+
+            ys = start_idx + col_sels
+            X_new[:, start_idx:end_idx] = 0
+            X_new[xs, ys] = 1

-        return sparse.csr_matrix(sample) if sparse.issparse(X) else sample
+        return X_new


 @Substitution(

From fa8fdf0af8ac969e9a27f4c0de729e6258e4c7eb Mon Sep 17 00:00:00 2001
From: Matt Eding
Date: Thu, 5 Sep 2019 08:12:07 -0700
Subject: [PATCH 2/6] fix PEP8; fix _generate_samples params for SMOTENC

---
 imblearn/over_sampling/_smote.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/imblearn/over_sampling/_smote.py b/imblearn/over_sampling/_smote.py
index 04527eb80..ed27aca07 100644
--- a/imblearn/over_sampling/_smote.py
+++ b/imblearn/over_sampling/_smote.py
@@ -1052,7 +1052,7 @@ def _fit_resample(self, X, y):

         return X_resampled, y_resampled

-    def _generate_sample(self, X, nn_data, nn_num, row, col, step):
+    def _generate_samples(self, X, nn_data, nn_num, rows, cols, steps):
         """Generate a synthetic sample with an additional step for the
         categorical features.

@@ -1061,13 +1061,15 @@ def _generate_samples(self, X, nn_data, nn_num, rows, cols, steps):
         of the majority class.
""" rng = check_random_state(self.random_state) - X_new = super()._generate_samples(X, nn_data, nn_num, rows, cols, steps) + X_new = super()._generate_samples( + X, nn_data, nn_num, rows, cols, steps) # change in sparsity structure more efficient with LIL than CSR X_new = (X_new.tolil() if sparse.issparse(X_new) else X_new) + # convert to dense array since scipy.sparse doesn't handle 3D nn_data = (nn_data.toarray() if sparse.issparse(nn_data) else nn_data) all_neighbors = nn_data[nn_num[rows]] - + categories_size = ([self.continuous_features_.size] + [cat.size for cat in self.ohe_.categories_]) From 26a8c4d96353e6c44a8750ce21f9b21954ced9c9 Mon Sep 17 00:00:00 2001 From: Matt Eding Date: Sun, 8 Sep 2019 13:48:40 -0700 Subject: [PATCH 3/6] BaseSMOTE len(sample_indices) -> n_samples; doctest SMOTENC A to B due to random state change --- doc/over_sampling.rst | 2 +- imblearn/over_sampling/_smote.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/over_sampling.rst b/doc/over_sampling.rst index 6159e925b..0048f54db 100644 --- a/doc/over_sampling.rst +++ b/doc/over_sampling.rst @@ -190,7 +190,7 @@ features or a boolean mask marking these features:: >>> print(X_resampled[-5:]) [['A' 0.5246469549655818 2] ['B' -0.3657680728116921 2] - ['A' 0.9344237230779993 2] + ['B' 0.9344237230779993 2] ['B' 0.3710891618824609 2] ['B' 0.3327240726719727 2]] diff --git a/imblearn/over_sampling/_smote.py b/imblearn/over_sampling/_smote.py index ed27aca07..f5e00d4a4 100644 --- a/imblearn/over_sampling/_smote.py +++ b/imblearn/over_sampling/_smote.py @@ -114,7 +114,7 @@ def _make_samples(self, cols = np.mod(samples_indices, nn_num.shape[1]) X_new = self._generate_samples(X, nn_data, nn_num, rows, cols, steps) - y_new = np.full(len(samples_indices), fill_value=y_type, dtype=y_dtype) + y_new = np.full(n_samples, fill_value=y_type, dtype=y_dtype) return X_new, y_new def _generate_samples(self, X, nn_data, nn_num, rows, cols, steps): From dbce8f0af097ed24092bdbe3ce6b4fb749bb64e7 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Sun, 17 Nov 2019 20:28:26 +0100 Subject: [PATCH 4/6] nitpicks style min diff --- imblearn/over_sampling/_smote.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/imblearn/over_sampling/_smote.py b/imblearn/over_sampling/_smote.py index 1670e734f..23b59ec42 100644 --- a/imblearn/over_sampling/_smote.py +++ b/imblearn/over_sampling/_smote.py @@ -98,7 +98,8 @@ def _make_samples( """ random_state = check_random_state(self.random_state) samples_indices = random_state.randint( - low=0, high=len(nn_num.flatten()), size=n_samples) + low=0, high=len(nn_num.flatten()), size=n_samples + ) # np.newaxis for backwards compatability with random_state steps = step_size * random_state.uniform(size=n_samples)[:, np.newaxis] @@ -133,20 +134,20 @@ def _generate_samples(self, X, nn_data, nn_num, rows, cols, steps): nn_num : ndarray of shape (n_samples_all, k_nearest_neighbours) The nearest neighbours of each sample in `nn_data`. - rows : ndarray[int], shape (n_samples,) + rows : ndarray of shape (n_samples,), dtype=int Indices pointing at feature vector in X which will be used as a base for creating new samples. - cols : ndarray[int], shape (n_samples,) + cols : ndarray of shape (n_samples,), dtype=int Indices pointing at which nearest neighbor of base feature vector will be used when creating new samples. - steps : ndarray[float], shape (n_samples,) + steps : ndarray of shape (n_samples,), dtype=float Step sizes for new samples. 
         Returns
         -------
-        X_new : {ndarray, sparse matrix}, shape (n_samples, n_features)
+        X_new : {ndarray, sparse matrix} of shape (n_samples, n_features)
             Synthetically generated samples.
         """
         diffs = nn_data[nn_num[rows, cols]] - X[rows]
@@ -724,8 +725,9 @@ def _fit_resample(self, X, y):
             self.nn_k_.fit(X_class)
             nns = self.nn_k_.kneighbors(X_class,
                                         return_distance=False)[:, 1:]
-            X_new, y_new = self._make_samples(X_class, y.dtype, class_sample,
-                                              X_class, nns, n_samples, 1.0)
+            X_new, y_new = self._make_samples(
+                X_class, y.dtype, class_sample, X_class, nns, n_samples, 1.0
+            )
             X_resampled.append(X_new)
             y_resampled.append(y_new)
@@ -1011,7 +1013,8 @@ def _generate_samples(self, X, nn_data, nn_num, rows, cols, steps):
         """
         rng = check_random_state(self.random_state)
         X_new = super()._generate_samples(
-            X, nn_data, nn_num, rows, cols, steps)
+            X, nn_data, nn_num, rows, cols, steps
+        )
         # change in sparsity structure more efficient with LIL than CSR
         X_new = (X_new.tolil() if sparse.issparse(X_new) else X_new)

@@ -1019,7 +1022,6 @@ def _generate_samples(self, X, nn_data, nn_num, rows, cols, steps):
         nn_data = (nn_data.toarray() if sparse.issparse(nn_data) else nn_data)
         all_neighbors = nn_data[nn_num[rows]]

-
         categories_size = [self.continuous_features_.size] + [
             cat.size for cat in self.ohe_.categories_
         ]

From 96c52aefddb7ca6ff8867c37912e9a008f1b64fa Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Sun, 17 Nov 2019 20:35:39 +0100
Subject: [PATCH 5/6] add whats new

---
 doc/whats_new/v0.6.rst | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/doc/whats_new/v0.6.rst b/doc/whats_new/v0.6.rst
index 612845e03..7f754690a 100644
--- a/doc/whats_new/v0.6.rst
+++ b/doc/whats_new/v0.6.rst
@@ -64,6 +64,15 @@ Enhancement
    finite values are present in ``X``.
    :pr:`643` by `Guillaume Lemaitre `.

+- The sample generation in
+  :class:`imblearn.over_sampling.SMOTE`,
+  :class:`imblearn.over_sampling.BorderlineSMOTE`,
+  :class:`imblearn.over_sampling.SVMSMOTE`,
+  :class:`imblearn.over_sampling.KMeansSMOTE`, and
+  :class:`imblearn.over_sampling.SMOTENC` is now vectorized, giving
+  an additional speed-up when `X` is sparse.
+  :pr:`596` by :user:`Matt Eding `.
+
 Deprecation
 ...........

From 98a0b2f20485a4bc535a4f5578b17c1bce63b8ed Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Sun, 17 Nov 2019 20:57:40 +0100
Subject: [PATCH 6/6] DOC add info regarding change in SMOTENC

---
 doc/whats_new/v0.6.rst | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/doc/whats_new/v0.6.rst b/doc/whats_new/v0.6.rst
index 7f754690a..96451c225 100644
--- a/doc/whats_new/v0.6.rst
+++ b/doc/whats_new/v0.6.rst
@@ -15,6 +15,11 @@ scikit-learn:
 - :class:`imblearn.under_sampling.ClusterCentroids`
 - :class:`imblearn.under_sampling.InstanceHardnessThreshold`

+The following samplers will give different results due to changes in the
+internal usage of the random state:
+
+- :class:`imblearn.over_sampling.SMOTENC`
+
 Bug fixes
 .........
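
Note on the core change in PATCH 1/6: the per-sample Python loop in
`BaseSMOTE._make_samples` is replaced by a single broadcasted interpolation.
A minimal standalone sketch of the dense path follows; the toy data, the
hand-built `nn_num` table, and the variable names mirroring the patch are
illustrative assumptions, not imblearn internals:

import numpy as np

rng = np.random.RandomState(42)

# toy minority-class samples and a hand-built k=2 nearest-neighbor table
X = rng.uniform(size=(6, 3))
nn_data = X  # in plain SMOTE the neighbors come from the same class
nn_num = np.array([[1, 2], [0, 2], [0, 1],
                   [4, 5], [3, 5], [3, 4]])

n_samples = 4
samples_indices = rng.randint(low=0, high=nn_num.size, size=n_samples)
rows = np.floor_divide(samples_indices, nn_num.shape[1])  # which base sample
cols = np.mod(samples_indices, nn_num.shape[1])           # which neighbor
steps = rng.uniform(size=n_samples)[:, np.newaxis]        # column vector

# one broadcasted interpolation replaces the per-sample loop
diffs = nn_data[nn_num[rows, cols]] - X[rows]
X_new = X[rows] + steps * diffs
print(X_new.shape)  # (4, 3)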
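
For the sparse branch, the patch wraps the dense `steps` column in the same
scipy.sparse class as `X` so that `.multiply` keeps the result sparse. A
sketch of just that trick, with random toy matrices standing in for the
class internals:

import numpy as np
from scipy import sparse

rng = np.random.RandomState(0)
X = sparse.random(6, 3, density=0.5, format="csr", random_state=rng)
rows = np.array([0, 2, 4, 5])
diffs = sparse.csr_matrix(rng.uniform(size=(4, 3)))  # stand-in for nn diffs
steps = rng.uniform(size=4)[:, np.newaxis]

# wrap the step column in the same sparse type as X (here csr_matrix),
# so the (4, 1) column broadcasts elementwise against the (4, 3) diffs
steps = getattr(sparse, type(X).__name__)(steps)
X_new = X[rows] + steps.multiply(diffs)
print(X_new.shape)  # (4, 3)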
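
The subtlest part of the SMOTENC change is the vectorized tie-breaking
argmax that picks one winning category level per generated sample, uniformly
at random among tied maxima. A self-contained sketch of that idiom (the vote
matrix is made up):

import numpy as np

rng = np.random.RandomState(0)

# neighbor votes for 3 levels of one categorical feature, 4 new samples
col_maxs = np.array([[2., 2., 1.],
                     [0., 3., 3.],
                     [1., 1., 1.],
                     [4., 0., 0.]])

# mark every per-row maximum ...
is_max = np.isclose(col_maxs, col_maxs.max(axis=1, keepdims=True))
# ... shuffle the (row, col) candidates so ties break at random ...
max_idxs = rng.permutation(np.argwhere(is_max))
# ... and keep the first candidate seen for each row
xs, idx_sels = np.unique(max_idxs[:, 0], return_index=True)
col_sels = max_idxs[idx_sels, 1]
print(dict(zip(xs.tolist(), col_sels.tolist())))  # one winner per row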