|
8 | 8 |
|
9 | 9 | from __future__ import division
|
10 | 10 |
|
| 11 | +import math |
11 | 12 | import types
|
12 | 13 | import warnings
|
13 | 14 | from collections import Counter
|
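14 | 15 |
|
15 | 16 | import numpy as np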
|
16 | 17 | from scipy import sparse
|
17 | 18 |
|
18 | 19 | from sklearn.base import clone
|
| 20 | +from sklearn.cluster import MiniBatchKMeans |
| 21 | +from sklearn.metrics import pairwise_distances |
19 | 22 | from sklearn.preprocessing import OneHotEncoder
|
20 | 23 | from sklearn.svm import SVC
|
21 | 24 | from sklearn.utils import check_random_state
|
@@ -1090,3 +1093,236 @@ def _generate_sample(self, X, nn_data, nn_num, row, col, step):
|
1090 | 1093 | sample[start_idx + col_sel] = 1
|
1091 | 1094 |
|
1092 | 1095 | return sparse.csr_matrix(sample) if sparse.issparse(X) else sample
|
| 1096 | + |
| 1097 | + |
| 1098 | +@Substitution( |
| 1099 | + sampling_strategy=BaseOverSampler._sampling_strategy_docstring, |
| 1100 | + random_state=_random_state_docstring) |
| 1101 | +class KMeansSMOTE(BaseSMOTE): |
| 1102 | + """Apply a KMeans clustering before to over-sample using SMOTE. |
| 1103 | +
|
| 1104 | + This is an implementation of the algorithm described in [1]_. |
| 1105 | +
|
| 1106 | + Read more in the :ref:`User Guide <smote_adasyn>`. |
| 1107 | +
|
| 1108 | + Parameters |
| 1109 | + ---------- |
| 1110 | + {sampling_strategy} |
| 1111 | +
|
| 1112 | + {random_state} |
| 1113 | +
|
| 1114 | + k_neighbors : int or object, optional (default=2) |
| 1115 | + If ``int``, number of nearest neighbours used to construct synthetic |
| 1116 | + samples. If object, an estimator that inherits from |
| 1117 | + :class:`sklearn.neighbors.base.KNeighborsMixin` that will be used to |
| 1118 | + find the k nearest neighbours. |
| 1119 | +
|
| 1120 | + n_jobs : int, optional (default=1) |
| 1121 | + The number of threads to use when parallelization is possible. |
| 1122 | +
|
| 1123 | + kmeans_estimator : int or object, optional (default=MiniBatchKMeans()) |
| 1124 | + A KMeans instance or the number of clusters to be used. By default, |
| 1125 | + a :class:`sklearn.cluster.MiniBatchKMeans` is used, which tends to |
| 1126 | + scale better with a large number of samples. |
| 1127 | +
|
| 1128 | + cluster_balance_threshold : str or float, optional (default="auto") |
| 1129 | + The threshold at which a cluster is considered balanced, such that |
| 1130 | + samples of the class selected for SMOTE will be generated in it. If |
| 1131 | + "auto", the threshold is derived from the class ratios; otherwise |
| 1132 | + it can be set manually. |
| 1133 | +
|
| 1134 | + density_exponent : str or float, optional (default="auto") |
| 1135 | + This exponent is used to determine the density of a cluster. Leaving |
| 1136 | + this to "auto" will use an exponent based on the cluster size. |
| 1137 | +
|
| 1138 | + Attributes |
| 1139 | + ---------- |
| 1140 | + kmeans_estimator_ : estimator |
| 1141 | + The fitted clustering estimator used before applying SMOTE. |
| 1142 | +
|
| 1143 | + nn_k_ : estimator |
| 1144 | + The fitted k-NN estimator used in SMOTE. |
| 1145 | +
|
| 1146 | + cluster_balance_threshold_ : float |
| 1147 | + The threshold used during ``fit`` to consider a cluster balanced. |
| 1148 | +
|
| 1149 | + References |
| 1150 | + ---------- |
| 1151 | + .. [1] Felix Last, Georgios Douzas, Fernando Bacao, "Oversampling for |
| 1152 | + Imbalanced Learning Based on K-Means and SMOTE" |
| 1153 | + https://arxiv.org/abs/1711.00837 |
| 1154 | +
|
| 1155 | + Examples |
| 1156 | + -------- |
| 1157 | +
|
| 1158 | + >>> import numpy as np |
| 1159 | + >>> from imblearn.over_sampling import KMeansSMOTE |
| 1160 | + >>> from sklearn.datasets import make_blobs |
| 1161 | + >>> blobs = [100, 800, 100] |
| 1162 | + >>> X, y = make_blobs(blobs, centers=[(-10, 0), (0,0), (10, 0)]) |
| 1163 | + >>> # Add a single 0 sample in the middle blob |
| 1164 | + >>> X = np.concatenate([X, [[0, 0]]]) |
| 1165 | + >>> y = np.append(y, 0) |
| 1166 | + >>> # Make this a binary classification problem |
| 1167 | + >>> y = y == 1 |
| 1168 | + >>> sm = KMeansSMOTE(random_state=42) |
| 1169 | + >>> X_res, y_res = sm.fit_resample(X, y) |
| 1170 | + >>> # Find the number of new samples in the middle blob |
| 1171 | + >>> n_res_in_middle = ((X_res[:, 0] > -5) & (X_res[:, 0] < 5)).sum() |
| 1172 | + >>> print("Samples in the middle blob: %s" % n_res_in_middle) |
| 1173 | + Samples in the middle blob: 801 |
| 1174 | + >>> print("Middle blob unchanged: %s" % (n_res_in_middle == blobs[1] + 1)) |
| 1175 | + Middle blob unchanged: True |
| 1176 | + >>> print("More 0 samples: %s" % ((y_res == 0).sum() > (y == 0).sum())) |
| 1177 | + More 0 samples: True |
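| | + >>> # kmeans_estimator also accepts a clustering instance; here an |
| | + >>> # illustrative MiniBatchKMeans with a hand-picked cluster count |
| | + >>> from sklearn.cluster import MiniBatchKMeans |
| | + >>> sm = KMeansSMOTE(kmeans_estimator=MiniBatchKMeans(n_clusters=3), |
| | + ...                  random_state=42) |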
| 1178 | +
|
| 1179 | + """ |
| 1180 | + def __init__(self, |
| 1181 | + sampling_strategy='auto', |
| 1182 | + random_state=None, |
| 1183 | + k_neighbors=2, |
| 1184 | + n_jobs=1, |
| 1185 | + kmeans_estimator=None, |
| 1186 | + cluster_balance_threshold="auto", |
| 1187 | + density_exponent="auto"): |
| 1188 | + super().__init__( |
| 1189 | + sampling_strategy=sampling_strategy, random_state=random_state, |
| 1190 | + k_neighbors=k_neighbors, n_jobs=n_jobs) |
| 1191 | + self.kmeans_estimator = kmeans_estimator |
| 1192 | + self.cluster_balance_threshold = cluster_balance_threshold |
| 1193 | + self.density_exponent = density_exponent |
| 1194 | + |
| 1195 | + def _validate_estimator(self): |
| 1196 | + super()._validate_estimator() |
| 1197 | + if self.kmeans_estimator is None: |
| 1198 | + self.kmeans_estimator_ = MiniBatchKMeans( |
| 1199 | + random_state=self.random_state) |
| 1200 | + elif isinstance(self.kmeans_estimator, int): |
| 1201 | + self.kmeans_estimator_ = MiniBatchKMeans( |
| 1202 | + n_clusters=self.kmeans_estimator, |
| 1203 | + random_state=self.random_state) |
| 1204 | + else: |
| 1205 | + self.kmeans_estimator_ = clone(self.kmeans_estimator) |
| 1206 | + |
| 1207 | + # validate the parameters |
| 1208 | + for param_name in ('cluster_balance_threshold', 'density_exponent'): |
| 1209 | + param = getattr(self, param_name) |
| 1210 | + if isinstance(param, str) and param != 'auto': |
| 1211 | + raise ValueError( |
| 1212 | + "'{}' should be 'auto' when a string is passed. " |
| 1213 | + "Got {} instead.".format(param_name, repr(param)) |
| 1214 | + ) |
| 1215 | + |
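| | + # with a single cluster, the balance check cannot discriminate |
| | + # anything, so disable it by using -inf as the threshold |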
| 1216 | + self.cluster_balance_threshold_ = ( |
| 1217 | + self.cluster_balance_threshold |
| 1218 | + if self.kmeans_estimator_.n_clusters != 1 else -np.inf |
| 1219 | + ) |
| 1220 | + |
| 1222 | + def _find_cluster_sparsity(self, X): |
| 1223 | + """Compute the cluster sparsity.""" |
| 1224 | + euclidean_distances = pairwise_distances(X, metric="euclidean", |
| 1225 | + n_jobs=self.n_jobs) |
| 1226 | + # zero out the diagonal so self-distances do not bias the mean |
| 1227 | + for ind in range(X.shape[0]): |
| 1228 | + euclidean_distances[ind, ind] = 0 |
| 1229 | + |
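| | + # average the pairwise distances over the off-diagonal entries only |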
| 1230 | + non_diag_elements = (X.shape[0] ** 2) - X.shape[0] |
| 1231 | + mean_distance = euclidean_distances.sum() / non_diag_elements |
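| | + # the 'auto' exponent is a heuristic that grows with the cluster size |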
| 1232 | + exponent = (math.log(X.shape[0], 1.6) ** 1.8 * 0.16 |
| 1233 | + if self.density_exponent == 'auto' |
| 1234 | + else self.density_exponent) |
| 1235 | + return (mean_distance ** exponent) / X.shape[0] |
| 1236 | + |
| 1237 | + # FIXME: rename _sample -> _fit_resample in 0.6 |
| 1238 | + def _fit_resample(self, X, y): |
| 1239 | + return self._sample(X, y) |
| 1240 | + |
| 1241 | + def _sample(self, X, y): |
| 1242 | + self._validate_estimator() |
| 1243 | + X_resampled = X.copy() |
| 1244 | + y_resampled = y.copy() |
| 1245 | + total_inp_samples = sum(self.sampling_strategy_.values()) |
| 1246 | + |
| 1247 | + for class_sample, n_samples in self.sampling_strategy_.items(): |
| 1248 | + if n_samples == 0: |
| 1249 | + continue |
| 1250 | + |
| 1254 | + X_clusters = self.kmeans_estimator_.fit_predict(X) |
| 1255 | + valid_clusters = [] |
| 1256 | + cluster_sparsities = [] |
| 1257 | + |
| 1258 | + # identify the clusters that satisfy the balance and size requirements |
| 1259 | + for cluster_idx in range(self.kmeans_estimator_.n_clusters): |
| 1260 | + |
| 1261 | + cluster_mask = np.flatnonzero(X_clusters == cluster_idx) |
| 1262 | + X_cluster = safe_indexing(X, cluster_mask) |
| 1263 | + y_cluster = safe_indexing(y, cluster_mask) |
| 1264 | + |
| 1265 | + cluster_class_mean = (y_cluster == class_sample).mean() |
| 1266 | + |
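| | + # the 'auto' threshold is half the share of this class among |
| | + # all the samples requested by the sampling strategy |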
| 1267 | + if self.cluster_balance_threshold_ == "auto": |
| 1268 | + balance_threshold = n_samples / total_inp_samples / 2 |
| 1269 | + else: |
| 1270 | + balance_threshold = self.cluster_balance_threshold_ |
| 1271 | + |
| 1272 | + # skip the cluster if the class to oversample is under-represented |
| 1273 | + if cluster_class_mean < balance_threshold: |
| 1274 | + continue |
| 1275 | + |
| 1276 | + # not enough samples to apply SMOTE |
| 1277 | + anticipated_samples = cluster_class_mean * X_cluster.shape[0] |
| 1278 | + if anticipated_samples < self.nn_k_.n_neighbors: |
| 1279 | + continue |
| 1280 | + |
| 1281 | + X_cluster_class = safe_indexing( |
| 1282 | + X_cluster, np.flatnonzero(y_cluster == class_sample) |
| 1283 | + ) |
| 1284 | + |
| 1285 | + valid_clusters.append(cluster_mask) |
| 1286 | + cluster_sparsities.append( |
| 1287 | + self._find_cluster_sparsity(X_cluster_class) |
| 1288 | + ) |
| 1289 | + |
| 1290 | + if not valid_clusters: |
| 1291 | + raise RuntimeError( |
| 1292 | + "No clusters found with sufficient samples of " |
| 1293 | + "class {}. Try lowering the cluster_balance_threshold " |
| 1294 | + "or increasing the number of " |
| 1295 | + "clusters.".format(class_sample)) |
| 1296 | + |
| | + # weight the valid clusters by their sparsity: sparser clusters |
| | + # receive proportionally more synthetic samples |
| 1297 | + cluster_sparsities = np.array(cluster_sparsities) |
| 1298 | + cluster_weights = cluster_sparsities / cluster_sparsities.sum() |
| 1299 | + |
| 1300 | + for valid_cluster_idx, valid_cluster in enumerate(valid_clusters): |
| 1301 | + X_cluster = safe_indexing(X, valid_cluster) |
| 1302 | + y_cluster = safe_indexing(y, valid_cluster) |
| 1303 | + |
| 1304 | + X_cluster_class = safe_indexing( |
| 1305 | + X_cluster, np.flatnonzero(y_cluster == class_sample) |
| 1306 | + ) |
| 1307 | + |
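| | + # find the nearest neighbours within the cluster's class samples; |
| | + # the first column is each sample itself, hence the [:, 1:] |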
| 1308 | + self.nn_k_.fit(X_cluster_class) |
| 1309 | + nns = self.nn_k_.kneighbors(X_cluster_class, |
| 1310 | + return_distance=False)[:, 1:] |
| 1311 | + |
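| | + # allocate the synthetic samples proportionally to the cluster |
| | + # weight; math.ceil may slightly overshoot the exact total |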
| 1312 | + cluster_n_samples = int(math.ceil( |
| 1313 | + n_samples * cluster_weights[valid_cluster_idx]) |
| 1314 | + ) |
| 1315 | + |
| 1316 | + X_new, y_new = self._make_samples(X_cluster_class, |
| 1317 | + y.dtype, |
| 1318 | + class_sample, |
| 1319 | + X_cluster_class, |
| 1320 | + nns, |
| 1321 | + cluster_n_samples, |
| 1322 | + 1.0) |
| 1323 | + |
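| | + # use the sparse vstack when the new samples are sparse |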
| 1324 | + stack = [np.vstack, sparse.vstack][int(sparse.issparse(X_new))] |
| 1325 | + X_resampled = stack((X_resampled, X_new)) |
| 1326 | + y_resampled = np.hstack((y_resampled, y_new)) |
| 1327 | + |
| 1328 | + return X_resampled, y_resampled |