Commit b606cb9

MattEding authored and glemaitre committed
ENH Vectorized samples generation for SMOTE-based algorithms (#596)
1 parent 9b31677 commit b606cb9

File tree

3 files changed, +73 -72 lines


doc/over_sampling.rst

Lines changed: 1 addition & 1 deletion
@@ -190,7 +190,7 @@ features or a boolean mask marking these features::
     >>> print(X_resampled[-5:])
     [['A' 0.5246469549655818 2]
      ['B' -0.3657680728116921 2]
-     ['A' 0.9344237230779993 2]
+     ['B' 0.9344237230779993 2]
      ['B' 0.3710891618824609 2]
      ['B' 0.3327240726719727 2]]

doc/whats_new/v0.6.rst

Lines changed: 14 additions & 0 deletions
@@ -15,6 +15,11 @@ scikit-learn:
 - :class:`imblearn.under_sampling.ClusterCentroids`
 - :class:`imblearn.under_sampling.InstanceHardnessThreshold`
 
+The following samplers will give different results due to a change in the
+internal usage of the random state:
+
+- :class:`imblearn.over_sampling.SMOTENC`
+
 Bug fixes
 .........
 

@@ -64,6 +69,15 @@ Enhancement
   finite values are present in ``X``.
   :pr:`643` by :user:`Guillaume Lemaitre <glemaitre>`.
 
+- The sample generation in
+  :class:`imblearn.over_sampling.SMOTE`,
+  :class:`imblearn.over_sampling.BorderlineSMOTE`,
+  :class:`imblearn.over_sampling.SVMSMOTE`,
+  :class:`imblearn.over_sampling.KMeansSMOTE`, and
+  :class:`imblearn.over_sampling.SMOTENC` is now vectorized, giving an
+  additional speed-up when ``X`` is sparse.
+  :pr:`596` by :user:`Matt Eding <MattEding>`.
+
 Deprecation
 ...........
 
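To make the changelog entry concrete: SMOTE's interpolation rule x_new = x + step * (neighbor - x) can be evaluated for a whole batch of samples with one broadcasted NumPy expression instead of a per-sample Python loop. A minimal sketch on made-up toy data (hypothetical shapes and values, not imblearn's API):

    import numpy as np

    rng = np.random.RandomState(0)

    # toy minority-class data: 6 samples, 2 features (hypothetical)
    X = rng.uniform(size=(6, 2))
    # 2 precomputed nearest-neighbour indices per sample (hypothetical)
    nn_num = np.array([[1, 2], [0, 2], [0, 1], [4, 5], [3, 5], [3, 4]])

    n_new = 4
    rows = rng.randint(0, X.shape[0], size=n_new)       # base samples
    cols = rng.randint(0, nn_num.shape[1], size=n_new)  # neighbour slot
    steps = rng.uniform(size=n_new)[:, np.newaxis]      # one step per sample

    # all n_new synthetic samples in a single broadcasted expression
    X_new = X[rows] + steps * (X[nn_num[rows, cols]] - X[rows])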

imblearn/over_sampling/_smote.py

Lines changed: 58 additions & 71 deletions
@@ -100,39 +100,17 @@ def _make_samples(
         samples_indices = random_state.randint(
             low=0, high=len(nn_num.flatten()), size=n_samples
         )
-        steps = step_size * random_state.uniform(size=n_samples)
+
+        # np.newaxis for backwards compatibility with random_state
+        steps = step_size * random_state.uniform(size=n_samples)[:, np.newaxis]
         rows = np.floor_divide(samples_indices, nn_num.shape[1])
         cols = np.mod(samples_indices, nn_num.shape[1])
 
-        y_new = np.array([y_type] * len(samples_indices), dtype=y_dtype)
+        X_new = self._generate_samples(X, nn_data, nn_num, rows, cols, steps)
+        y_new = np.full(n_samples, fill_value=y_type, dtype=y_dtype)
+        return X_new, y_new
 
-        if sparse.issparse(X):
-            row_indices, col_indices, samples = [], [], []
-            for i, (row, col, step) in enumerate(zip(rows, cols, steps)):
-                if X[row].nnz:
-                    sample = self._generate_sample(
-                        X, nn_data, nn_num, row, col, step
-                    )
-                    row_indices += [i] * len(sample.indices)
-                    col_indices += sample.indices.tolist()
-                    samples += sample.data.tolist()
-            return (
-                sparse.csr_matrix(
-                    (samples, (row_indices, col_indices)),
-                    [len(samples_indices), X.shape[1]],
-                    dtype=X.dtype,
-                ),
-                y_new,
-            )
-        else:
-            X_new = np.zeros((n_samples, X.shape[1]), dtype=X.dtype)
-            for i, (row, col, step) in enumerate(zip(rows, cols, steps)):
-                X_new[i] = self._generate_sample(
-                    X, nn_data, nn_num, row, col, step
-                )
-            return X_new, y_new
-
-    def _generate_sample(self, X, nn_data, nn_num, row, col, step):
+    def _generate_samples(self, X, nn_data, nn_num, rows, cols, steps):
         r"""Generate a synthetic sample.
 
         The rule for the generation is:
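The indexing in the hunk above replaces the old per-sample loop: a single flat draw over the (n_base * k) neighbour table is unravelled into a base-sample index and a neighbour slot per new sample, and `[:, np.newaxis]` keeps `steps` two-dimensional so it later broadcasts against (n_samples, n_features) arrays. A standalone sketch with hypothetical sizes:

    import numpy as np

    rng = np.random.RandomState(42)
    k = 3                                  # neighbours per base sample
    nn_num = np.arange(12).reshape(4, k)   # toy neighbour-index table
    n_samples = 5

    # one flat draw over all (base, neighbour) pairs, then unravel it
    samples_indices = rng.randint(low=0, high=nn_num.size, size=n_samples)
    rows = samples_indices // k            # floor_divide: base sample
    cols = samples_indices % k             # mod: which neighbour to use

    # column vector so it broadcasts against (n_samples, n_features)
    steps = rng.uniform(size=n_samples)[:, np.newaxis]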
@@ -156,23 +134,32 @@ def _generate_sample(self, X, nn_data, nn_num, row, col, step):
         nn_num : ndarray of shape (n_samples_all, k_nearest_neighbours)
             The nearest neighbours of each sample in `nn_data`.
 
-        row : int
-            Index pointing at feature vector in X which will be used
-            as a base for creating new sample.
+        rows : ndarray of shape (n_samples,), dtype=int
+            Indices pointing at feature vector in X which will be used
+            as a base for creating new samples.
 
-        col : int
-            Index pointing at which nearest neighbor of base feature vector
-            will be used when creating new sample.
+        cols : ndarray of shape (n_samples,), dtype=int
+            Indices pointing at which nearest neighbor of base feature vector
+            will be used when creating new samples.
 
-        step : float
-            Step size for new sample.
+        steps : ndarray of shape (n_samples,), dtype=float
+            Step sizes for new samples.
 
         Returns
         -------
-        X_new : {ndarray, sparse matrix} of shape (n_features,)
-            Single synthetically generated sample.
+        X_new : {ndarray, sparse matrix} of shape (n_samples, n_features)
+            Synthetically generated samples.
         """
-        return X[row] - step * (X[row] - nn_data[nn_num[row, col]])
+        diffs = nn_data[nn_num[rows, cols]] - X[rows]
+
+        if sparse.issparse(X):
+            sparse_func = type(X).__name__
+            steps = getattr(sparse, sparse_func)(steps)
+            X_new = X[rows] + steps.multiply(diffs)
+        else:
+            X_new = X[rows] + steps * diffs
+
+        return X_new.astype(X.dtype)
 
     def _in_danger_noise(
         self, nn_estimator, samples, target_class, y, kind="danger"
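In the sparse branch of the new `_generate_samples`, `type(X).__name__` recovers the matrix class name (for instance "csr_matrix") so `steps` can be wrapped in the same sparse class, and `.multiply` then broadcasts the (n_samples, 1) step column across the (n_samples, n_features) difference matrix. A minimal sketch of that pattern on hypothetical random data:

    import numpy as np
    from scipy import sparse

    rng = np.random.RandomState(0)
    X_rows = sparse.random(5, 8, density=0.4, format="csr", random_state=rng)
    diffs = sparse.random(5, 8, density=0.4, format="csr", random_state=rng)
    steps = rng.uniform(size=5)[:, np.newaxis]

    # wrap steps in the same sparse class as X_rows, then let .multiply
    # broadcast the (5, 1) column across the (5, 8) matrix
    steps_sp = getattr(sparse, type(X_rows).__name__)(steps)
    X_new = X_rows + steps_sp.multiply(diffs)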
@@ -727,8 +714,8 @@ def __init__(
     def _fit_resample(self, X, y):
         self._validate_estimator()
 
-        X_resampled = X.copy()
-        y_resampled = y.copy()
+        X_resampled = [X.copy()]
+        y_resampled = [y.copy()]
 
         for class_sample, n_samples in self.sampling_strategy_.items():
             if n_samples == 0:
@@ -741,14 +728,15 @@ def _fit_resample(self, X, y):
             X_new, y_new = self._make_samples(
                 X_class, y.dtype, class_sample, X_class, nns, n_samples, 1.0
             )
+            X_resampled.append(X_new)
+            y_resampled.append(y_new)
+
+        if sparse.issparse(X_new):
+            X_resampled = sparse.vstack(X_resampled, format=X.format)
+        else:
+            X_resampled = np.vstack(X_resampled)
+        y_resampled = np.hstack(y_resampled)
 
-            if sparse.issparse(X_new):
-                X_resampled = sparse.vstack([X_resampled, X_new])
-                sparse_func = "tocsc" if X.format == "csc" else "tocsr"
-                X_resampled = getattr(X_resampled, sparse_func)()
-            else:
-                X_resampled = np.vstack((X_resampled, X_new))
-            y_resampled = np.hstack((y_resampled, y_new))
 
         return X_resampled, y_resampled
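The accumulate-then-stack pattern is the other part of the speed-up: stacking inside the loop copies the entire partial result on every iteration, so total copying grows quadratically with the number of classes, whereas appending blocks to a list and stacking once copies each row a single time. A schematic sketch with made-up block shapes:

    import numpy as np

    # start from the original data, as _fit_resample now does
    blocks = [np.ones((100, 4))]

    # e.g. one block of synthetic samples per resampled class
    for _ in range(10):
        blocks.append(np.zeros((50, 4)))

    # one concatenation at the end: each row is copied exactly once
    X_resampled = np.vstack(blocks)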

@@ -1015,7 +1003,7 @@ def _fit_resample(self, X, y):
 
         return X_resampled, y_resampled
 
-    def _generate_sample(self, X, nn_data, nn_num, row, col, step):
+    def _generate_samples(self, X, nn_data, nn_num, rows, cols, steps):
         """Generate a synthetic sample with an additional steps for the
         categorical features.
 
@@ -1024,35 +1012,34 @@ def _generate_sample(self, X, nn_data, nn_num, row, col, step):
         of the majority class.
         """
         rng = check_random_state(self.random_state)
-        sample = super()._generate_sample(X, nn_data, nn_num, row, col, step)
-        # To avoid conversion and since there is only few samples used, we
-        # convert those samples to dense array.
-        sample = (
-            sample.toarray().squeeze() if sparse.issparse(sample) else sample
-        )
-        all_neighbors = nn_data[nn_num[row]]
-        all_neighbors = (
-            all_neighbors.toarray()
-            if sparse.issparse(all_neighbors)
-            else all_neighbors
+        X_new = super()._generate_samples(
+            X, nn_data, nn_num, rows, cols, steps
         )
+        # change in sparsity structure more efficient with LIL than CSR
+        X_new = (X_new.tolil() if sparse.issparse(X_new) else X_new)
+
+        # convert to dense array since scipy.sparse doesn't handle 3D
+        nn_data = (nn_data.toarray() if sparse.issparse(nn_data) else nn_data)
+        all_neighbors = nn_data[nn_num[rows]]
 
         categories_size = [self.continuous_features_.size] + [
             cat.size for cat in self.ohe_.categories_
         ]
 
-        for start_idx, end_idx in zip(
-            np.cumsum(categories_size)[:-1], np.cumsum(categories_size)[1:]
-        ):
-            col_max = all_neighbors[:, start_idx:end_idx].sum(axis=0)
+        for start_idx, end_idx in zip(np.cumsum(categories_size)[:-1],
+                                      np.cumsum(categories_size)[1:]):
+            col_maxs = all_neighbors[:, :, start_idx:end_idx].sum(axis=1)
             # tie breaking argmax
-            col_sel = rng.choice(
-                np.flatnonzero(np.isclose(col_max, col_max.max()))
-            )
-            sample[start_idx:end_idx] = 0
-            sample[start_idx + col_sel] = 1
+            is_max = np.isclose(col_maxs, col_maxs.max(axis=1, keepdims=True))
+            max_idxs = rng.permutation(np.argwhere(is_max))
+            xs, idx_sels = np.unique(max_idxs[:, 0], return_index=True)
+            col_sels = max_idxs[idx_sels, 1]
+
+            ys = start_idx + col_sels
+            X_new[:, start_idx:end_idx] = 0
+            X_new[xs, ys] = 1
 
-        return sparse.csr_matrix(sample) if sparse.issparse(X) else sample
+        return X_new
 
 
 @Substitution(
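The loop above also vectorizes the tie-breaking argmax that picks the most frequent category among a sample's neighbours: instead of one `rng.choice` call per sample, every (row, column) pair that reaches its row maximum is shuffled and the first hit per row is kept, so each tied column is equally likely to win. A small sketch on a hypothetical count matrix:

    import numpy as np

    rng = np.random.RandomState(0)

    # hypothetical per-sample category counts; rows 0 and 2 contain ties
    col_maxs = np.array([[2., 2., 1.],
                         [0., 3., 1.],
                         [1., 1., 1.]])

    # mark every column reaching its row maximum (ties included)
    is_max = np.isclose(col_maxs, col_maxs.max(axis=1, keepdims=True))

    # shuffle all (row, col) candidates, keep the first hit per row
    max_idxs = rng.permutation(np.argwhere(is_max))
    xs, idx_sels = np.unique(max_idxs[:, 0], return_index=True)
    col_sels = max_idxs[idx_sels, 1]   # one random argmax per row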
