Skip to content

Commit ea4fcc4

Browse files
authored
FIX: rename sparse to keep_sparse for keras and tensorflow (#453)
1 parent 6fc3207 commit ea4fcc4

File tree

4 files changed

+24
-22
lines changed

4 files changed

+24
-22
lines changed

imblearn/keras/_generator.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ class BalancedBatchGenerator(ParentClass):
5252
batch_size : int, optional (default=32)
5353
Number of samples per gradient update.
5454
55-
sparse : bool, optional (default=False)
55+
keep_sparse : bool, optional (default=False)
5656
Whether or not to conserve the sparsity of the input (i.e. ``X``,
5757
``y``, ``sample_weight``). By default, the returned batches will be
5858
dense.
@@ -98,15 +98,15 @@ class BalancedBatchGenerator(ParentClass):
9898
9999
"""
100100
def __init__(self, X, y, sample_weight=None, sampler=None, batch_size=32,
101-
sparse=False, random_state=None):
101+
keep_sparse=False, random_state=None):
102102
if not HAS_KERAS:
103103
raise ImportError("'No module named 'keras'")
104104
self.X = X
105105
self.y = y
106106
self.sample_weight = sample_weight
107107
self.sampler = sampler
108108
self.batch_size = batch_size
109-
self.sparse = sparse
109+
self.keep_sparse = keep_sparse
110110
self.random_state = random_state
111111
self._sample()
112112

@@ -138,7 +138,7 @@ def __getitem__(self, index):
138138
y_resampled = safe_indexing(
139139
self.y, self.indices_[index * self.batch_size:
140140
(index + 1) * self.batch_size])
141-
if issparse(X_resampled) and not self.sparse:
141+
if issparse(X_resampled) and not self.keep_sparse:
142142
X_resampled = X_resampled.toarray()
143143
if self.sample_weight is not None:
144144
sample_weight_resampled = safe_indexing(
@@ -154,7 +154,8 @@ def __getitem__(self, index):
154154

155155
@Substitution(random_state=_random_state_docstring)
156156
def balanced_batch_generator(X, y, sample_weight=None, sampler=None,
157-
batch_size=32, sparse=False, random_state=None):
157+
batch_size=32, keep_sparse=False,
158+
random_state=None):
158159
"""Create a balanced batch generator to train a keras model.
159160
160161
Returns a generator --- as well as the number of steps per epoch --- which
@@ -181,7 +182,7 @@ def balanced_batch_generator(X, y, sample_weight=None, sampler=None,
181182
batch_size : int, optional (default=32)
182183
Number of samples per gradient update.
183184
184-
sparse : bool, optional (default=False)
185+
keep_sparse : bool, optional (default=False)
185186
Whether or not to conserve the sparsity of the input (i.e. ``X``,
186187
``y``, ``sample_weight``). By default, the returned batches will be
187188
dense.
@@ -226,4 +227,4 @@ def balanced_batch_generator(X, y, sample_weight=None, sampler=None,
226227

227228
return tf_bbg(X=X, y=y, sample_weight=sample_weight,
228229
sampler=sampler, batch_size=batch_size,
229-
sparse=sparse, random_state=random_state)
230+
keep_sparse=keep_sparse, random_state=random_state)

imblearn/keras/tests/test_generator.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -52,15 +52,15 @@ def test_balanced_batch_generator_class(sampler, sample_weight):
5252
epochs=10)
5353

5454

55-
@pytest.mark.parametrize("is_sparse", [True, False])
56-
def test_balanced_batch_generator_class_sparse(is_sparse):
55+
@pytest.mark.parametrize("keep_sparse", [True, False])
56+
def test_balanced_batch_generator_class_sparse(keep_sparse):
5757
training_generator = BalancedBatchGenerator(sparse.csr_matrix(X), y,
5858
batch_size=10,
59-
sparse=is_sparse,
59+
keep_sparse=keep_sparse,
6060
random_state=42)
6161
for idx in range(len(training_generator)):
6262
X_batch, y_batch = training_generator.__getitem__(idx)
63-
if is_sparse:
63+
if keep_sparse:
6464
assert sparse.issparse(X_batch)
6565
else:
6666
assert not sparse.issparse(X_batch)
@@ -88,14 +88,14 @@ def test_balanced_batch_generator_function(sampler, sample_weight):
8888
epochs=10)
8989

9090

91-
@pytest.mark.parametrize("is_sparse", [True, False])
92-
def test_balanced_batch_generator_function_sparse(is_sparse):
91+
@pytest.mark.parametrize("keep_sparse", [True, False])
92+
def test_balanced_batch_generator_function_sparse(keep_sparse):
9393
training_generator, steps_per_epoch = balanced_batch_generator(
94-
sparse.csr_matrix(X), y, sparse=is_sparse, batch_size=10,
94+
sparse.csr_matrix(X), y, keep_sparse=keep_sparse, batch_size=10,
9595
random_state=42)
9696
for idx in range(steps_per_epoch):
9797
X_batch, y_batch = next(training_generator)
98-
if is_sparse:
98+
if keep_sparse:
9999
assert sparse.issparse(X_batch)
100100
else:
101101
assert not sparse.issparse(X_batch)

imblearn/tensorflow/_generator.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616

1717
@Substitution(random_state=_random_state_docstring)
1818
def balanced_batch_generator(X, y, sample_weight=None, sampler=None,
19-
batch_size=32, sparse=False, random_state=None):
19+
batch_size=32, keep_sparse=False,
20+
random_state=None):
2021
"""Create a balanced batch generator to train a keras model.
2122
2223
Returns a generator --- as well as the number of steps per epoch --- which
@@ -43,7 +44,7 @@ def balanced_batch_generator(X, y, sample_weight=None, sampler=None,
4344
batch_size : int, optional (default=32)
4445
Number of samples per gradient update.
4546
46-
sparse : bool, optional (default=False)
47+
keep_sparse : bool, optional (default=False)
4748
Whether or not to conserve the sparsity of the input ``X``. By
4849
default, the returned batches will be dense.
4950
@@ -137,7 +138,7 @@ def generator(X, y, sample_weight, indices, batch_size):
137138
for index in range(0, len(indices), batch_size):
138139
X_res = safe_indexing(X, indices[index:index + batch_size])
139140
y_res = safe_indexing(y, indices[index:index + batch_size])
140-
if issparse(X_res) and not sparse:
141+
if issparse(X_res) and not keep_sparse:
141142
X_res = X_res.toarray()
142143
if sample_weight is None:
143144
yield X_res, y_res

imblearn/tensorflow/tests/test_generator.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,18 +72,18 @@ def accuracy(y_true, y_pred):
7272
.format(e, accuracy(y, predicts_train)))
7373

7474

75-
@pytest.mark.parametrize("is_sparse", [True, False])
76-
def test_balanced_batch_generator_function_sparse(is_sparse):
75+
@pytest.mark.parametrize("keep_sparse", [True, False])
76+
def test_balanced_batch_generator_function_sparse(keep_sparse):
7777
X, y = load_iris(return_X_y=True)
7878
X, y = make_imbalance(X, y, {0: 30, 1: 50, 2: 40})
7979
X = X.astype(np.float32)
8080

8181
training_generator, steps_per_epoch = balanced_batch_generator(
82-
sparse.csr_matrix(X), y, sparse=is_sparse, batch_size=10,
82+
sparse.csr_matrix(X), y, keep_sparse=keep_sparse, batch_size=10,
8383
random_state=42)
8484
for idx in range(steps_per_epoch):
8585
X_batch, y_batch = next(training_generator)
86-
if is_sparse:
86+
if keep_sparse:
8787
assert sparse.issparse(X_batch)
8888
else:
8989
assert not sparse.issparse(X_batch)

0 commit comments

Comments
 (0)