Skip to content

MNT remove deprecation warning in keras example #775

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions examples/applications/porto_seguro_keras_under_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,13 @@ def convert_float64(X):
###############################################################################
# Create a neural-network
###############################################################################

from keras.models import Sequential
from keras.layers import Activation, Dense, Dropout, BatchNormalization
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import (
Activation,
Dense,
Dropout,
BatchNormalization,
)


def make_model(n_features):
Expand Down Expand Up @@ -169,8 +173,8 @@ def fit_predict_balanced_model(X_train, y_train, X_test, y_test):
training_generator = BalancedBatchGenerator(X_train, y_train,
batch_size=1000,
random_state=42)
model.fit_generator(generator=training_generator, epochs=5, verbose=1)
y_pred = model.predict_proba(X_test, batch_size=1000)
model.fit(training_generator, epochs=5, verbose=1)
y_pred = model.predict(X_test, batch_size=1000)
return roc_auc_score(y_test, y_pred)


Expand Down
36 changes: 20 additions & 16 deletions imblearn/keras/_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,19 +112,21 @@ class BalancedBatchGenerator(*ParentClass):
>>> class_dict = dict()
>>> class_dict[0] = 30; class_dict[1] = 50; class_dict[2] = 40
>>> X, y = make_imbalance(iris.data, iris.target, class_dict)
>>> import keras
>>> y = keras.utils.to_categorical(y, 3)
>>> model = keras.models.Sequential()
>>> model.add(keras.layers.Dense(y.shape[1], input_dim=X.shape[1],
... activation='softmax'))
>>> import tensorflow
>>> y = tensorflow.keras.utils.to_categorical(y, 3)
>>> model = tensorflow.keras.models.Sequential()
>>> model.add(
... tensorflow.keras.layers.Dense(
... y.shape[1], input_dim=X.shape[1], activation='softmax'
... )
... )
>>> model.compile(optimizer='sgd', loss='categorical_crossentropy',
... metrics=['accuracy'])
>>> from imblearn.keras import BalancedBatchGenerator
>>> from imblearn.under_sampling import NearMiss
>>> training_generator = BalancedBatchGenerator(
... X, y, sampler=NearMiss(), batch_size=10, random_state=42)
>>> callback_history = model.fit_generator(generator=training_generator,
... epochs=10, verbose=0)
>>> callback_history = model.fit(training_generator, epochs=10, verbose=0)
"""

# flag for keras sequence duck-typing
Expand Down Expand Up @@ -264,21 +266,23 @@ def balanced_batch_generator(
>>> class_dict[0] = 30; class_dict[1] = 50; class_dict[2] = 40
>>> from imblearn.datasets import make_imbalance
>>> X, y = make_imbalance(X, y, class_dict)
>>> import keras
>>> y = keras.utils.to_categorical(y, 3)
>>> model = keras.models.Sequential()
>>> model.add(keras.layers.Dense(y.shape[1], input_dim=X.shape[1],
... activation='softmax'))
>>> import tensorflow
>>> y = tensorflow.keras.utils.to_categorical(y, 3)
>>> model = tensorflow.keras.models.Sequential()
>>> model.add(
... tensorflow.keras.layers.Dense(
... y.shape[1], input_dim=X.shape[1], activation='softmax'
... )
... )
>>> model.compile(optimizer='sgd', loss='categorical_crossentropy',
... metrics=['accuracy'])
>>> from imblearn.keras import balanced_batch_generator
>>> from imblearn.under_sampling import NearMiss
>>> training_generator, steps_per_epoch = balanced_batch_generator(
... X, y, sampler=NearMiss(), batch_size=10, random_state=42)
>>> callback_history = model.fit_generator(generator=training_generator,
... steps_per_epoch=steps_per_epoch,
... epochs=10, verbose=0)

>>> callback_history = model.fit(training_generator,
... steps_per_epoch=steps_per_epoch,
... epochs=10, verbose=0)
"""

return tf_bbg(
Expand Down