diff --git a/doc/over_sampling.rst b/doc/over_sampling.rst
index 06242772c..b52de3436 100644
--- a/doc/over_sampling.rst
+++ b/doc/over_sampling.rst
@@ -190,7 +190,7 @@ features or a boolean mask marking these features::
   >>> print(X_resampled[-5:])
   [['A' 0.5246469549655818 2]
    ['B' -0.3657680728116921 2]
-   ['A' 0.9344237230779993 2]
+   ['B' 0.9344237230779993 2]
    ['B' 0.3710891618824609 2]
    ['B' 0.3327240726719727 2]]

diff --git a/doc/whats_new/v0.6.rst b/doc/whats_new/v0.6.rst
index 612845e03..96451c225 100644
--- a/doc/whats_new/v0.6.rst
+++ b/doc/whats_new/v0.6.rst
@@ -15,6 +15,11 @@ scikit-learn:
 - :class:`imblearn.under_sampling.ClusterCentroids`
 - :class:`imblearn.under_sampling.InstanceHardnessThreshold`

+The following samplers will give different results due to a change in the
+internal usage of the random state:
+
+- :class:`imblearn.over_sampling.SMOTENC`
+
 Bug fixes
 .........

@@ -64,6 +69,15 @@ Enhancement
   finite values are present in ``X``.
   :pr:`643` by `Guillaume Lemaitre `.

+- The sample generation in
+  :class:`imblearn.over_sampling.SMOTE`,
+  :class:`imblearn.over_sampling.BorderlineSMOTE`,
+  :class:`imblearn.over_sampling.SVMSMOTE`,
+  :class:`imblearn.over_sampling.KMeansSMOTE`, and
+  :class:`imblearn.over_sampling.SMOTENC` is now vectorized, giving an
+  additional speed-up when `X` is sparse.
+  :pr:`596` by :user:`Matt Edding `.
+
 Deprecation
 ...........

diff --git a/imblearn/over_sampling/_smote.py b/imblearn/over_sampling/_smote.py
index c583abb20..23b59ec42 100644
--- a/imblearn/over_sampling/_smote.py
+++ b/imblearn/over_sampling/_smote.py
@@ -100,39 +100,17 @@ def _make_samples(
         samples_indices = random_state.randint(
             low=0, high=len(nn_num.flatten()), size=n_samples
         )
-        steps = step_size * random_state.uniform(size=n_samples)
+
+        # np.newaxis for backwards compatibility with random_state
+        steps = step_size * random_state.uniform(size=n_samples)[:, np.newaxis]
         rows = np.floor_divide(samples_indices, nn_num.shape[1])
         cols = np.mod(samples_indices, nn_num.shape[1])

-        y_new = np.array([y_type] * len(samples_indices), dtype=y_dtype)
+        X_new = self._generate_samples(X, nn_data, nn_num, rows, cols, steps)
+        y_new = np.full(n_samples, fill_value=y_type, dtype=y_dtype)
+        return X_new, y_new

-        if sparse.issparse(X):
-            row_indices, col_indices, samples = [], [], []
-            for i, (row, col, step) in enumerate(zip(rows, cols, steps)):
-                if X[row].nnz:
-                    sample = self._generate_sample(
-                        X, nn_data, nn_num, row, col, step
-                    )
-                    row_indices += [i] * len(sample.indices)
-                    col_indices += sample.indices.tolist()
-                    samples += sample.data.tolist()
-            return (
-                sparse.csr_matrix(
-                    (samples, (row_indices, col_indices)),
-                    [len(samples_indices), X.shape[1]],
-                    dtype=X.dtype,
-                ),
-                y_new,
-            )
-        else:
-            X_new = np.zeros((n_samples, X.shape[1]), dtype=X.dtype)
-            for i, (row, col, step) in enumerate(zip(rows, cols, steps)):
-                X_new[i] = self._generate_sample(
-                    X, nn_data, nn_num, row, col, step
-                )
-            return X_new, y_new
-
-    def _generate_sample(self, X, nn_data, nn_num, row, col, step):
+    def _generate_samples(self, X, nn_data, nn_num, rows, cols, steps):
         r"""Generate a synthetic sample.

         The rule for the generation is:
@@ -156,23 +134,32 @@
         nn_num : ndarray of shape (n_samples_all, k_nearest_neighbours)
             The nearest neighbours of each sample in `nn_data`.

-        row : int
-            Index pointing at feature vector in X which will be used
-            as a base for creating new sample.
+        rows : ndarray of shape (n_samples,), dtype=int
+            Indices pointing at the feature vectors in X which will be used
+            as a base for creating new samples.

-        col : int
-            Index pointing at which nearest neighbor of base feature vector
-            will be used when creating new sample.
+        cols : ndarray of shape (n_samples,), dtype=int
+            Indices pointing at which nearest neighbor of the base feature
+            vector will be used when creating each new sample.

-        step : float
-            Step size for new sample.
+        steps : ndarray of shape (n_samples,), dtype=float
+            Step sizes for new samples.

         Returns
         -------
-        X_new : {ndarray, sparse matrix} of shape (n_features,)
-            Single synthetically generated sample.
+        X_new : {ndarray, sparse matrix} of shape (n_samples, n_features)
+            Synthetically generated samples.
         """
-        return X[row] - step * (X[row] - nn_data[nn_num[row, col]])
+        diffs = nn_data[nn_num[rows, cols]] - X[rows]
+
+        if sparse.issparse(X):
+            sparse_func = type(X).__name__
+            steps = getattr(sparse, sparse_func)(steps)
+            X_new = X[rows] + steps.multiply(diffs)
+        else:
+            X_new = X[rows] + steps * diffs
+
+        return X_new.astype(X.dtype)

     def _in_danger_noise(
         self, nn_estimator, samples, target_class, y, kind="danger"
@@ -727,8 +714,8 @@ def __init__(
     def _fit_resample(self, X, y):
         self._validate_estimator()

-        X_resampled = X.copy()
-        y_resampled = y.copy()
+        X_resampled = [X.copy()]
+        y_resampled = [y.copy()]

         for class_sample, n_samples in self.sampling_strategy_.items():
             if n_samples == 0:
@@ -741,14 +728,15 @@
             X_new, y_new = self._make_samples(
                 X_class, y.dtype, class_sample, X_class, nns, n_samples, 1.0
             )
+            X_resampled.append(X_new)
+            y_resampled.append(y_new)
+
+        if sparse.issparse(X_new):
+            X_resampled = sparse.vstack(X_resampled, format=X.format)
+        else:
+            X_resampled = np.vstack(X_resampled)
+        y_resampled = np.hstack(y_resampled)

-            if sparse.issparse(X_new):
-                X_resampled = sparse.vstack([X_resampled, X_new])
-                sparse_func = "tocsc" if X.format == "csc" else "tocsr"
-                X_resampled = getattr(X_resampled, sparse_func)()
-            else:
-                X_resampled = np.vstack((X_resampled, X_new))
-            y_resampled = np.hstack((y_resampled, y_new))

         return X_resampled, y_resampled

@@ -1015,7 +1003,7 @@ def _fit_resample(self, X, y):

         return X_resampled, y_resampled

-    def _generate_sample(self, X, nn_data, nn_num, row, col, step):
+    def _generate_samples(self, X, nn_data, nn_num, rows, cols, steps):
         """Generate a synthetic sample with an additional steps for the
         categorical features.

@@ -1024,35 +1012,34 @@
         of the majority class.
         """
         rng = check_random_state(self.random_state)
-        sample = super()._generate_sample(X, nn_data, nn_num, row, col, step)
-        # To avoid conversion and since there is only few samples used, we
-        # convert those samples to dense array.
-        sample = (
-            sample.toarray().squeeze() if sparse.issparse(sample) else sample
-        )
-        all_neighbors = nn_data[nn_num[row]]
-        all_neighbors = (
-            all_neighbors.toarray()
-            if sparse.issparse(all_neighbors)
-            else all_neighbors
+        X_new = super()._generate_samples(
+            X, nn_data, nn_num, rows, cols, steps
         )
+        # changing the sparsity structure is more efficient with LIL than CSR
+        X_new = (X_new.tolil() if sparse.issparse(X_new) else X_new)
+
+        # convert to a dense array since scipy.sparse doesn't handle 3D arrays
+        nn_data = (nn_data.toarray() if sparse.issparse(nn_data) else nn_data)
+        all_neighbors = nn_data[nn_num[rows]]

         categories_size = [self.continuous_features_.size] + [
             cat.size for cat in self.ohe_.categories_
         ]

-        for start_idx, end_idx in zip(
-            np.cumsum(categories_size)[:-1], np.cumsum(categories_size)[1:]
-        ):
-            col_max = all_neighbors[:, start_idx:end_idx].sum(axis=0)
+        for start_idx, end_idx in zip(np.cumsum(categories_size)[:-1],
+                                      np.cumsum(categories_size)[1:]):
+            col_maxs = all_neighbors[:, :, start_idx:end_idx].sum(axis=1)
             # tie breaking argmax
-            col_sel = rng.choice(
-                np.flatnonzero(np.isclose(col_max, col_max.max()))
-            )
-            sample[start_idx:end_idx] = 0
-            sample[start_idx + col_sel] = 1
+            is_max = np.isclose(col_maxs, col_maxs.max(axis=1, keepdims=True))
+            max_idxs = rng.permutation(np.argwhere(is_max))
+            xs, idx_sels = np.unique(max_idxs[:, 0], return_index=True)
+            col_sels = max_idxs[idx_sels, 1]
+
+            ys = start_idx + col_sels
+            X_new[:, start_idx:end_idx] = 0
+            X_new[xs, ys] = 1

-        return sparse.csr_matrix(sample) if sparse.issparse(X) else sample
+        return X_new


 @Substitution(
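
The core of the `_generate_samples` change is computing the SMOTE rule
x_new = x + step * (x_nn - x) for all samples at once, with the trailing
`np.newaxis` axis on `steps` letting each scalar step broadcast over a whole
feature row. The following standalone sketch (not code from the diff; `X`,
`nn_num`, `rows`, `cols`, and `steps` are toy stand-ins mirroring the names
above, and `X` also plays the role of `nn_data`) shows the vectorized path
and checks it against the per-sample loop it replaces:

    import numpy as np
    from scipy import sparse

    rng = np.random.RandomState(0)

    # toy minority-class data and fake k-NN indices (3 neighbours per row)
    X = sparse.random(20, 5, density=0.5, format="csr", random_state=rng)
    nn_num = rng.randint(0, X.shape[0], size=(X.shape[0], 3))

    n_samples = 7
    samples_indices = rng.randint(0, nn_num.size, size=n_samples)
    rows = samples_indices // nn_num.shape[1]  # base sample of each new point
    cols = samples_indices % nn_num.shape[1]   # which neighbour to walk towards
    # trailing axis lets one scalar step broadcast over all features of a row
    steps = rng.uniform(size=n_samples)[:, np.newaxis]

    # all n_samples interpolations in one shot: x + step * (x_nn - x)
    diffs = X[nn_num[rows, cols]] - X[rows]
    X_new = X[rows] + sparse.csr_matrix(steps).multiply(diffs)

    # the per-sample loop it replaces gives the same result
    expected = np.vstack([
        X[r].toarray() + s * (X[nn_num[r, c]].toarray() - X[r].toarray())
        for r, c, s in zip(rows, cols, steps.ravel())
    ])
    assert np.allclose(X_new.toarray(), expected)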
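The `_fit_resample` hunk applies a standard accumulate-then-stack pattern:
stacking inside the loop copies everything accumulated so far on each
iteration, while appending to a list and stacking once at the end copies the
data a single time. A minimal sketch with invented shapes, again not library
code:

    from scipy import sparse

    X = sparse.random(100, 10, density=0.3, format="csr", random_state=0)

    # collect pieces in a list instead of repeated vstack inside the loop
    X_resampled = [X.copy()]
    for n_new in (30, 20):  # pretend two minority classes get oversampled
        X_resampled.append(
            sparse.random(n_new, 10, density=0.3, format="csr", random_state=0)
        )

    # a single stack at the end, kept in the input's sparse format
    X_resampled = sparse.vstack(X_resampled, format=X.format)
    print(X_resampled.shape)  # (150, 10)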
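The trickiest part of the SMOTENC hunk is the vectorized tie-breaking argmax:
shuffling the coordinates of all row maxima and then keeping the first
occurrence of each row index (via `np.unique` with `return_index=True`) picks
one maximal column per row uniformly at random, replacing the per-sample
`rng.choice` call. A small self-contained illustration with made-up neighbour
counts:

    import numpy as np

    rng = np.random.RandomState(42)

    # counts of each of 4 categories among the neighbours of 5 samples;
    # small integer counts make ties for the row maximum likely
    col_maxs = rng.randint(0, 3, size=(5, 4)).astype(float)

    # mask of every position attaining its row maximum
    is_max = np.isclose(col_maxs, col_maxs.max(axis=1, keepdims=True))

    # shuffle the (row, col) coordinates of the maxima; np.unique then keeps
    # the first occurrence of each row index, which is a uniformly random
    # one because the order was just permuted
    max_idxs = rng.permutation(np.argwhere(is_max))
    xs, idx_sels = np.unique(max_idxs[:, 0], return_index=True)
    col_sels = max_idxs[idx_sels, 1]

    # every selected column is indeed a maximum of its row
    assert is_max[xs, col_sels].all()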