Skip to content

Commit 41cd9a6

Browse files
authored
BUG: Preserve dtype of X and y when generating samples (#450)
1 parent 267dd32 commit 41cd9a6

File tree

5 files changed

+47
-14
lines changed

5 files changed

+47
-14
lines changed

doc/whats_new/v0.0.4.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,10 @@ Bug fixes
6363
- Force to clone scikit-learn estimator passed as attributes to samplers.
6464
:issue:`446` by :user:`Guillaume Lemaitre <glemaitre>`.
6565

66+
- Fix a bug where the dtype of X and y was not preserved when generating
67+
samples.
68+
:issue:`448` by :user:`Guillaume Lemaitre <glemaitre>`.
69+
6670
Maintenance
6771
...........
6872

imblearn/over_sampling/adasyn.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -185,8 +185,9 @@ def _sample(self, X, y):
185185
n_samples_generated += 1
186186
X_new = (sparse.csr_matrix(
187187
(samples, (row_indices, col_indices)),
188-
[np.sum(n_samples_generate), X.shape[1]]))
189-
y_new = np.array([class_sample] * np.sum(n_samples_generate))
188+
[np.sum(n_samples_generate), X.shape[1]], dtype=X.dtype))
189+
y_new = np.array([class_sample] * np.sum(n_samples_generate),
190+
dtype=y.dtype)
190191
else:
191192
x_class_gen = []
192193
for x_i, x_i_nn, num_sample_i in zip(X_class, nn_index,
@@ -201,8 +202,9 @@ def _sample(self, X, y):
201202
for step, nn_z in zip(steps, nn_zs)
202203
])
203204

204-
X_new = np.concatenate(x_class_gen)
205-
y_new = np.array([class_sample] * np.sum(n_samples_generate))
205+
X_new = np.concatenate(x_class_gen).astype(X.dtype)
206+
y_new = np.array([class_sample] * np.sum(n_samples_generate),
207+
dtype=y.dtype)
206208

207209
if sparse.issparse(X_new):
208210
X_resampled = sparse.vstack([X_resampled, X_new])

imblearn/over_sampling/smote.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ def _validate_estimator(self):
5252

5353
def _make_samples(self,
5454
X,
55+
y_dtype,
5556
y_type,
5657
nn_data,
5758
nn_num,
@@ -65,6 +66,9 @@ def _make_samples(self,
6566
X : {array-like, sparse matrix}, shape (n_samples, n_features)
6667
Points from which the points will be created.
6768
69+
y_dtype : dtype
70+
The data type of the targets.
71+
6872
y_type : str or int
6973
The minority target value, just so the function can return the
7074
target values for the synthetic variables with correct length in
@@ -108,15 +112,16 @@ def _make_samples(self,
108112
col_indices += sample.indices.tolist()
109113
samples += sample.data.tolist()
110114
else:
111-
X_new = np.zeros((n_samples, X.shape[1]))
115+
X_new = np.zeros((n_samples, X.shape[1]), dtype=X.dtype)
112116
for i, (row, col, step) in enumerate(zip(rows, cols, steps)):
113117
X_new[i] = X[row] - step * (X[row] - nn_data[nn_num[row, col]])
114118

115-
y_new = np.array([y_type] * len(samples_indices))
119+
y_new = np.array([y_type] * len(samples_indices), dtype=y_dtype)
116120

117121
if sparse.issparse(X):
118122
return (sparse.csr_matrix((samples, (row_indices, col_indices)),
119-
[len(samples_indices), X.shape[1]]),
123+
[len(samples_indices), X.shape[1]],
124+
dtype=X.dtype),
120125
y_new)
121126
else:
122127
return X_new, y_new
@@ -301,8 +306,8 @@ def _sample(self, X, y):
301306
if self.kind == 'borderline-1':
302307
# Create synthetic samples for borderline points.
303308
X_new, y_new = self._make_samples(
304-
safe_indexing(X_class, danger_index), class_sample,
305-
X_class, nns, n_samples)
309+
safe_indexing(X_class, danger_index), y.dtype,
310+
class_sample, X_class, nns, n_samples)
306311
if sparse.issparse(X_new):
307312
X_resampled = sparse.vstack([X_resampled, X_new])
308313
else:
@@ -316,6 +321,7 @@ def _sample(self, X, y):
316321
# only minority
317322
X_new_1, y_new_1 = self._make_samples(
318323
safe_indexing(X_class, danger_index),
324+
y.dtype,
319325
class_sample,
320326
X_class,
321327
nns,
@@ -327,6 +333,7 @@ def _sample(self, X, y):
327333
# class but all over classes.
328334
X_new_2, y_new_2 = self._make_samples(
329335
safe_indexing(X_class, danger_index),
336+
y.dtype,
330337
class_sample,
331338
safe_indexing(X, np.flatnonzero(y != class_sample)),
332339
nns,
@@ -490,6 +497,7 @@ def _sample(self, X, y):
490497

491498
X_new_1, y_new_1 = self._make_samples(
492499
safe_indexing(support_vector, np.flatnonzero(danger_bool)),
500+
y.dtype,
493501
class_sample,
494502
X_class,
495503
nns,
@@ -503,6 +511,7 @@ def _sample(self, X, y):
503511

504512
X_new_2, y_new_2 = self._make_samples(
505513
safe_indexing(support_vector, np.flatnonzero(safety_bool)),
514+
y.dtype,
506515
class_sample,
507516
X_class,
508517
nns,
@@ -738,8 +747,8 @@ def _sample(self, X, y):
738747

739748
self.nn_k_.fit(X_class)
740749
nns = self.nn_k_.kneighbors(X_class, return_distance=False)[:, 1:]
741-
X_new, y_new = self._make_samples(X_class, class_sample, X_class,
742-
nns, n_samples, 1.0)
750+
X_new, y_new = self._make_samples(X_class, y.dtype, class_sample,
751+
X_class, nns, n_samples, 1.0)
743752

744753
if sparse.issparse(X_new):
745754
X_resampled = sparse.vstack([X_resampled, X_new])

imblearn/under_sampling/prototype_generation/cluster_centroids.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,10 +128,10 @@ def _generate_sample(self, X, y, centroids, target_class):
128128
X_new = safe_indexing(X, np.squeeze(indices))
129129
else:
130130
if sparse.issparse(X):
131-
X_new = sparse.csr_matrix(centroids)
131+
X_new = sparse.csr_matrix(centroids, dtype=X.dtype)
132132
else:
133133
X_new = centroids
134-
y_new = np.array([target_class] * centroids.shape[0])
134+
y_new = np.array([target_class] * centroids.shape[0], dtype=y.dtype)
135135

136136
return X_new, y_new
137137

@@ -191,4 +191,4 @@ def _sample(self, X, y):
191191
X_resampled = np.vstack(X_resampled)
192192
y_resampled = np.hstack(y_resampled)
193193

194-
return X_resampled, np.array(y_resampled)
194+
return X_resampled, np.array(y_resampled, dtype=y.dtype)

imblearn/utils/estimator_checks.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def _yield_sampler_checks(name, Estimator):
4949
yield check_samplers_sparse
5050
yield check_samplers_pandas
5151
yield check_samplers_multiclass_ova
52+
yield check_samplers_preserve_dtype
5253

5354

5455
def _yield_all_checks(name, estimator):
@@ -333,3 +334,20 @@ def check_samplers_multiclass_ova(name, Sampler):
333334
else:
334335
assert type_of_target(y_res_ova) == type_of_target(y_ova)
335336
assert_allclose(y_res, y_res_ova.argmax(axis=1))
337+
338+
339+
def check_samplers_preserve_dtype(name, Sampler):
340+
X, y = make_classification(
341+
n_samples=1000,
342+
n_classes=3,
343+
n_informative=4,
344+
weights=[0.2, 0.3, 0.5],
345+
random_state=0)
346+
# Cast X and y to non-default dtypes
347+
X = X.astype(np.float32)
348+
y = y.astype(np.int32)
349+
sampler = Sampler()
350+
set_random_state(sampler)
351+
X_res, y_res = sampler.fit_sample(X, y)
352+
assert X.dtype == X_res.dtype
353+
assert y.dtype == y_res.dtype

0 commit comments

Comments
 (0)