diff --git a/.travis.yml b/.travis.yml
index 36b502320..650c14bb1 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -38,11 +38,11 @@ matrix:
         NUMPY_VERSION="1.13.1" SCIPY_VERSION="0.19.1" SKLEARN_VERSION="0.19.0"
     - env: DISTRIB="conda" PYTHON_VERSION="3.6"
         NUMPY_VERSION="1.13.1" SCIPY_VERSION="0.19.1" SKLEARN_VERSION="0.19.0"
-    - env: DISTRIB="conda" PYTHON_VERSION="3.6"
-        NUMPY_VERSION="1.13.1" SCIPY_VERSION="0.19.1" SKLEARN_VERSION="master"
+    - env: DISTRIB="conda" PYTHON_VERSION="3.7"
+        NUMPY_VERSION="*" SCIPY_VERSION="*" SKLEARN_VERSION="master"
   allow_failures:
-    - env: DISTRIB="conda" PYTHON_VERSION="3.6"
-        NUMPY_VERSION="1.13.1" SCIPY_VERSION="0.19.1" SKLEARN_VERSION="master"
+    - env: DISTRIB="conda" PYTHON_VERSION="3.7"
+        NUMPY_VERSION="*" SCIPY_VERSION="*" SKLEARN_VERSION="master"

install: source build_tools/travis/install.sh
script: bash build_tools/travis/test_script.sh
diff --git a/appveyor.yml b/appveyor.yml
index a09272080..6bb885553 100644
--- a/appveyor.yml
+++ b/appveyor.yml
@@ -10,30 +10,37 @@ environment:
   - PYTHON: "C:\\Miniconda-x64"
     PYTHON_VERSION: "2.7.x"
     PYTHON_ARCH: "64"
+    OPTIONAL_DEP: "pandas"

   - PYTHON: "C:\\Miniconda"
     PYTHON_VERSION: "2.7.x"
    PYTHON_ARCH: "32"
+    OPTIONAL_DEP: "pandas"

   - PYTHON: "C:\\Miniconda35-x64"
     PYTHON_VERSION: "3.5.x"
     PYTHON_ARCH: "64"
+    OPTIONAL_DEP: "pandas keras tensorflow"

   - PYTHON: "C:\\Miniconda36-x64"
     PYTHON_VERSION: "3.6.x"
     PYTHON_ARCH: "64"
+    OPTIONAL_DEP: "pandas keras tensorflow"

   - PYTHON: "C:\\Miniconda36"
     PYTHON_VERSION: "3.6.x"
     PYTHON_ARCH: "32"
+    OPTIONAL_DEP: "pandas"

install:
  # Prepend miniconda installed Python to the PATH of this build
  # Add Library/bin directory to fix issue
  # https://github.com/conda/conda/issues/1753
  - "SET PATH=%PYTHON%;%PYTHON%\\Scripts;%PYTHON%\\Library\\bin;%PATH%"
-  - conda install pip scipy numpy scikit-learn=0.19 pandas -y -q
+  - conda install pip scipy numpy scikit-learn=0.19 -y -q
+  - "conda install %OPTIONAL_DEP% -y -q"
  - conda install pytest pytest-cov -y -q
+  - pip install codecov
  - conda install nose -y -q  # FIXME: remove this line when using sklearn > 0.19
  - pip install .
@@ -41,3 +48,8 @@ test_script:
  - mkdir for_test
  - cd for_test
  - pytest --pyargs imblearn --cov-report term-missing --cov=imblearn
+
+after_test:
+  - cp .coverage %APPVEYOR_BUILD_FOLDER%
+  - cd %APPVEYOR_BUILD_FOLDER%
+  - codecov
diff --git a/build_tools/circle/build_doc.sh b/build_tools/circle/build_doc.sh
index e49088ae6..4b4ff915e 100755
--- a/build_tools/circle/build_doc.sh
+++ b/build_tools/circle/build_doc.sh
@@ -92,7 +92,7 @@
 conda create -n $CONDA_ENV_NAME --yes --quiet python=3
 source activate $CONDA_ENV_NAME

 conda install --yes pip numpy scipy scikit-learn pillow matplotlib sphinx \
-    sphinx_rtd_theme numpydoc
+    sphinx_rtd_theme numpydoc pandas keras
 pip install -U git+https://github.com/sphinx-gallery/sphinx-gallery.git

 # Build and install imbalanced-learn in dev mode
diff --git a/build_tools/travis/install.sh b/build_tools/travis/install.sh
index 415e4ce5d..3a56bac81 100755
--- a/build_tools/travis/install.sh
+++ b/build_tools/travis/install.sh
@@ -38,7 +38,15 @@ if [[ "$DISTRIB" == "conda" ]]; then
     # provided versions
     conda create -n testenv --yes python=$PYTHON_VERSION pip
     source activate testenv
-    conda install --yes numpy=$NUMPY_VERSION scipy=$SCIPY_VERSION pandas
+    conda install --yes numpy=$NUMPY_VERSION scipy=$SCIPY_VERSION
+
+    if [[ $PYTHON_VERSION == "3.6" ]]; then
+        conda install --yes pandas
+        conda install --yes -c conda-forge keras
+        KERAS_BACKEND=tensorflow
+        python -c "import keras.backend"
+        sed -i -e 's/"backend":[[:space:]]*"[^"]*/"backend":\ "'$KERAS_BACKEND'/g' ~/.keras/keras.json;
+    fi

     if [[ "$SKLEARN_VERSION" == "master" ]]; then
         conda install --yes cython
@@ -59,8 +67,9 @@ elif [[ "$DISTRIB" == "ubuntu" ]]; then
     # Create a new virtualenv using system site packages for python, numpy
     virtualenv --system-site-packages testvenv
     source testvenv/bin/activate
-    pip install scikit-learn pandas nose nose-timer pytest pytest-cov codecov \
-    sphinx numpydoc
+    pip install scikit-learn
+    pip install pandas keras tensorflow
+    pip install nose nose-timer pytest pytest-cov codecov sphinx numpydoc

 fi

@@ -68,7 +77,7 @@
 python --version
 python -c "import numpy; print('numpy %s' % numpy.__version__)"
 python -c "import scipy; print('scipy %s' % scipy.__version__)"

-python setup.py develop
+pip install -e .

 ccache --show-stats # Useful for debugging how ccache is used
 # cat $CCACHE_LOGFILE
diff --git a/conftest.py b/conftest.py
index 110fdd479..d3ff91025 100644
--- a/conftest.py
+++ b/conftest.py
@@ -7,8 +7,27 @@
 # Set numpy array str/repr to legacy behaviour on numpy > 1.13 to make
 # the doctests pass
+import os
+import pytest
 import numpy as np
+
 try:
     np.set_printoptions(legacy='1.13')
 except TypeError:
     pass
+
+
+def pytest_runtest_setup(item):
+    fname = item.fspath.strpath
+    # miscellaneous.rst needs both keras and tensorflow, so it is checked
+    # against each package independently (two ``if`` blocks, not ``elif``).
+    if (fname.endswith(os.path.join('keras', '_generator.py')) or
+            fname.endswith('miscellaneous.rst')):
+        try:
+            import keras
+        except ImportError:
+            pytest.skip('The keras package is not installed.')
+    if (fname.endswith(os.path.join('tensorflow', '_generator.py')) or
+            fname.endswith('miscellaneous.rst')):
+        try:
+            import tensorflow
+        except ImportError:
+            pytest.skip('The tensorflow package is not installed.')
diff --git a/doc/api.rst b/doc/api.rst
index f9566146f..4abc49d33 100644
--- a/doc/api.rst
+++ b/doc/api.rst
@@ -111,6 +111,46 @@ Prototype selection
    ensemble.BalancedBaggingClassifier
    ensemble.EasyEnsemble

+.. _keras_ref:
+
+:mod:`imblearn.keras`: Batch generator for Keras
+================================================
+
+.. automodule:: imblearn.keras
+    :no-members:
+    :no-inherited-members:
+
+.. currentmodule:: imblearn
+
+.. autosummary::
+   :toctree: generated/
+   :template: class.rst
+
+   keras.BalancedBatchGenerator
+
+.. autosummary::
+   :toctree: generated/
+   :template: function.rst
+
+   keras.balanced_batch_generator
+
+.. _tensorflow_ref:
+
+:mod:`imblearn.tensorflow`: Batch generator for TensorFlow
+==========================================================
+
+.. automodule:: imblearn.tensorflow
+    :no-members:
+    :no-inherited-members:
+
+.. currentmodule:: imblearn
+
+.. autosummary::
+   :toctree: generated/
+   :template: function.rst
+
+   tensorflow.balanced_batch_generator
+
 .. _misc_ref:

 Miscellaneous
diff --git a/doc/miscellaneous.rst b/doc/miscellaneous.rst
index ef263a21b..5734f5c66 100644
--- a/doc/miscellaneous.rst
+++ b/doc/miscellaneous.rst
@@ -38,3 +38,114 @@
 We illustrate the use of such a sampler to implement an outlier rejection
 estimator which can be easily used within a
 :class:`imblearn.pipeline.Pipeline`:
 :ref:`sphx_glr_auto_examples_plot_outlier_rejections.py`
+
+.. _generators:
+
+Custom generators
+-----------------
+
+Imbalanced-learn provides specific generators for TensorFlow and Keras which
+will generate balanced mini-batches.
+
+.. _tensorflow_generator:
+
+TensorFlow generator
+~~~~~~~~~~~~~~~~~~~~
+
+The :func:`imblearn.tensorflow.balanced_batch_generator` function allows
+generating balanced mini-batches using an imbalanced-learn sampler which
+returns indices::
+
+    >>> X = X.astype(np.float32)
+    >>> from imblearn.under_sampling import RandomUnderSampler
+    >>> from imblearn.tensorflow import balanced_batch_generator
+    >>> training_generator, steps_per_epoch = balanced_batch_generator(
+    ...     X, y, sample_weight=None, sampler=RandomUnderSampler(),
+    ...     batch_size=10, random_state=42)
+
+The ``generator`` and ``steps_per_epoch`` are used during the training of the
+TensorFlow model. We will illustrate how to use this generator. First, we can
+define a logistic regression model which will be optimized by gradient
+descent::
+
+    >>> learning_rate, epochs = 0.01, 10
+    >>> input_size, output_size = X.shape[1], 3
+    >>> import tensorflow as tf
+    >>> def init_weights(shape):
+    ...     return tf.Variable(tf.random_normal(shape, stddev=0.01))
+    >>> def accuracy(y_true, y_pred):
+    ...     return np.mean(np.argmax(y_pred, axis=1) == y_true)
+    >>> # input and output
+    >>> data = tf.placeholder("float32", shape=[None, input_size])
+    >>> targets = tf.placeholder("int32", shape=[None])
+    >>> # build the model and weights
+    >>> W = init_weights([input_size, output_size])
+    >>> b = init_weights([output_size])
+    >>> out_act = tf.nn.sigmoid(tf.matmul(data, W) + b)
+    >>> # build the loss, predict, and train operator
+    >>> cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
+    ...     logits=out_act, labels=targets)
+    >>> loss = tf.reduce_sum(cross_entropy)
+    >>> optimizer = tf.train.GradientDescentOptimizer(learning_rate)
+    >>> train_op = optimizer.minimize(loss)
+    >>> predict = tf.nn.softmax(out_act)
+    >>> # Initialization of all variables in the graph
+    >>> init = tf.global_variables_initializer()
+
+Once initialized, the model is trained by iterating over balanced mini-batches
+of data and minimizing the loss previously defined::
+
+    >>> with tf.Session() as sess:
+    ...     print('Starting training')
+    ...     sess.run(init)
+    ...     for e in range(epochs):
+    ...         for i in range(steps_per_epoch):
+    ...             X_batch, y_batch = next(training_generator)
+    ...             sess.run([train_op, loss],
+    ...                      feed_dict={data: X_batch, targets: y_batch})
+    ...         # For each epoch, run accuracy on train and test
+    ...         predicts_train = sess.run(predict, feed_dict={data: X})
+    ...         print("epoch: {} train accuracy: {:.3f}"
+    ...               .format(e, accuracy(y, predicts_train)))
+    ...     # doctest: +ELLIPSIS
+    Starting training
+    [...
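+
+The generator itself is infinite: it keeps cycling over the balanced set of
+indices, which is why ``steps_per_epoch`` is needed to delimit an epoch. When
+a ``sample_weight`` array is provided, each batch is a 3-tuple instead of a
+pair. A minimal sketch of this behaviour, reusing ``X`` and ``y`` from
+above::
+
+    >>> import numpy as np
+    >>> generator_sw, steps = balanced_batch_generator(
+    ...     X, y, sample_weight=np.ones(y.shape[0]), batch_size=10,
+    ...     random_state=42)
+    >>> X_batch, y_batch, sw_batch = next(generator_sw)
+    >>> X_batch.shape[0] == sw_batch.shape[0] == 10
+    True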
+
+.. _keras_generator:
+
+Keras generator
+~~~~~~~~~~~~~~~
+
+Keras provides a higher-level API in which a model can be defined and trained
+by calling the ``fit_generator`` method. To illustrate, we will define a
+logistic regression model::
+
+    >>> import keras
+    >>> y = keras.utils.to_categorical(y, 3)
+    >>> model = keras.Sequential()
+    >>> model.add(keras.layers.Dense(y.shape[1], input_dim=X.shape[1],
+    ...                              activation='softmax'))
+    >>> model.compile(optimizer='sgd', loss='categorical_crossentropy',
+    ...               metrics=['accuracy'])
+
+:func:`imblearn.keras.balanced_batch_generator` creates a balanced mini-batch
+generator together with the number of mini-batches which will be generated::
+
+    >>> from imblearn.keras import balanced_batch_generator
+    >>> training_generator, steps_per_epoch = balanced_batch_generator(
+    ...     X, y, sampler=RandomUnderSampler(), batch_size=10, random_state=42)
+
+Then, ``fit_generator`` can be called, passing the generator and the number of
+steps::
+
+    >>> callback_history = model.fit_generator(generator=training_generator,
+    ...                                        steps_per_epoch=steps_per_epoch,
+    ...                                        epochs=10, verbose=0)
+
+The second possibility is to use
+:class:`imblearn.keras.BalancedBatchGenerator`. In this case, only an instance
+of this class needs to be passed to ``fit_generator``::
+
+    >>> from imblearn.keras import BalancedBatchGenerator
+    >>> training_generator = BalancedBatchGenerator(
+    ...     X, y, sampler=RandomUnderSampler(), batch_size=10, random_state=42)
+    >>> callback_history = model.fit_generator(generator=training_generator,
+    ...                                        epochs=10, verbose=0)
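+
+Because :class:`imblearn.keras.BalancedBatchGenerator` implements the keras
+``Sequence`` interface, the batches can also be inspected directly. The
+following is only an illustrative sketch; the reported sizes follow from the
+example data used above (90 balanced samples, 4 features)::
+
+    >>> len(training_generator)  # number of balanced mini-batches per epoch
+    9
+    >>> X_batch, y_batch = training_generator[0]
+    >>> X_batch.shape
+    (10, 4)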
diff --git a/doc/whats_new/v0.0.4.rst b/doc/whats_new/v0.0.4.rst
index e359a80fe..8cd326d2f 100644
--- a/doc/whats_new/v0.0.4.rst
+++ b/doc/whats_new/v0.0.4.rst
@@ -18,6 +18,12 @@ API
 - Enable to use a ``list`` for the cleaning methods to specify the class to
   sample. :issue:`411` by :user:`Guillaume Lemaitre `.

+New features
+............
+
+- Add ``keras`` and ``tensorflow`` modules to create balanced mini-batch
+  generators. :issue:`409` by :user:`Guillaume Lemaitre `.
+
 Enhancement
 ...........
diff --git a/examples/applications/porto_seguro_keras_under_sampling.py b/examples/applications/porto_seguro_keras_under_sampling.py
new file mode 100644
index 000000000..c154362d9
--- /dev/null
+++ b/examples/applications/porto_seguro_keras_under_sampling.py
@@ -0,0 +1,238 @@
+"""
+==========================================================
+Porto Seguro: balancing samples in mini-batches with Keras
+==========================================================
+
+This example compares two strategies to train a neural network on the Porto
+Seguro Kaggle data set [1]_. The data set is imbalanced and we show that
+balancing each mini-batch improves performance and reduces the training time.
+
+References
+----------
+
+.. [1] https://www.kaggle.com/c/porto-seguro-safe-driver-prediction/data
+
+"""
+
+# Authors: Guillaume Lemaitre
+# License: MIT
+
+print(__doc__)
+
+###############################################################################
+# Data loading
+###############################################################################
+
+from collections import Counter
+import pandas as pd
+import numpy as np
+
+###############################################################################
+# First, you should download the Porto Seguro data set from Kaggle. See the
+# link in the introduction.
+
+training_data = pd.read_csv('./input/train.csv')
+testing_data = pd.read_csv('./input/test.csv')
+
+y_train = training_data[['id', 'target']].set_index('id')
+X_train = training_data.drop(['target'], axis=1).set_index('id')
+X_test = testing_data.set_index('id')
+
+###############################################################################
+# The data set is imbalanced and this will have an effect on the fitting.
+
+print('The data set is imbalanced: {}'.format(Counter(y_train['target'])))
+
+###############################################################################
+# Define the pre-processing pipeline
+###############################################################################
+
+from sklearn.compose import ColumnTransformer
+from sklearn.pipeline import Pipeline, make_pipeline
+from sklearn.preprocessing import OneHotEncoder
+from sklearn.preprocessing import StandardScaler
+from sklearn.preprocessing import FunctionTransformer
+from sklearn.impute import SimpleImputer
+
+
+def convert_float64(X):
+    return X.astype(np.float64)
+
+
+###############################################################################
+# We want to standard-scale the numerical features and one-hot encode the
+# categorical features. In this regard, we make use of the
+# :class:`sklearn.compose.ColumnTransformer`.
+
+numerical_columns = [name for name in X_train.columns
+                     if '_calc_' in name and '_bin' not in name]
+numerical_pipeline = make_pipeline(
+    FunctionTransformer(func=convert_float64, validate=False),
+    StandardScaler())
+
+categorical_columns = [name for name in X_train.columns
+                       if '_cat' in name]
+categorical_pipeline = make_pipeline(
+    SimpleImputer(missing_values=-1, strategy='most_frequent'),
+    OneHotEncoder(categories='auto'))
+
+preprocessor = ColumnTransformer(
+    [('numerical_preprocessing', numerical_pipeline, numerical_columns),
+     ('categorical_preprocessing', categorical_pipeline,
+      categorical_columns)],
+    remainder='drop')
+
+# Create an environment variable to avoid using the GPU. This can be changed.
+import os
+os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
+
+###############################################################################
+# Create a neural network
+###############################################################################
+
+from keras.models import Sequential
+from keras.layers import Activation, Dense, Dropout, BatchNormalization
+
+
+def make_model(n_features):
+    model = Sequential()
+    model.add(Dense(200, input_shape=(n_features,),
+                    kernel_initializer='glorot_normal'))
+    model.add(Activation('relu'))
+    model.add(BatchNormalization())
+    model.add(Dropout(0.5))
+    model.add(Dense(100, kernel_initializer='glorot_normal'))
+    model.add(Activation('relu'))
+    model.add(BatchNormalization())
+    model.add(Dropout(0.25))
+    model.add(Dense(50, kernel_initializer='glorot_normal'))
+    model.add(Activation('relu'))
+    model.add(BatchNormalization())
+    model.add(Dropout(0.15))
+    model.add(Dense(25, kernel_initializer='glorot_normal'))
+    model.add(Activation('relu'))
+    model.add(BatchNormalization())
+    model.add(Dropout(0.1))
+    model.add(Dense(1, activation='sigmoid'))
+
+    model.compile(loss='binary_crossentropy',
+                  optimizer='adam',
+                  metrics=['accuracy'])
+
+    return model
+
+
+###############################################################################
+# We create a decorator to report the computation time.
+
+import time
+from functools import wraps
+
+
+def timeit(f):
+    @wraps(f)
+    def wrapper(*args, **kwds):
+        start_time = time.time()
+        result = f(*args, **kwds)
+        elapsed_time = time.time() - start_time
+        print('Elapsed computation time: {:.3f} secs'
+              .format(elapsed_time))
+        return (elapsed_time, result)
+    return wrapper
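+
+
+###############################################################################
+# Note that ``timeit`` returns an ``(elapsed_time, result)`` tuple rather than
+# the bare result. A minimal, purely illustrative sketch of that contract (the
+# ``_double`` helper below is only for demonstration and not part of the
+# benchmark):
+
+
+@timeit
+def _double(x):
+    return 2 * x
+
+
+_elapsed_time, _result = _double(21)  # prints the elapsed time; _result == 42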
+
+
+###############################################################################
+# The first model will be trained using the ``fit`` method and with imbalanced
+# mini-batches.
+
+from sklearn.metrics import roc_auc_score
+
+
+@timeit
+def fit_predict_imbalanced_model(X_train, y_train, X_test, y_test):
+    model = make_model(X_train.shape[1])
+    model.fit(X_train, y_train, epochs=2, verbose=1, batch_size=1000)
+    y_pred = model.predict_proba(X_test, batch_size=1000)
+    return roc_auc_score(y_test, y_pred)
+
+
+###############################################################################
+# In contrast, we will use imbalanced-learn to create a generator of
+# mini-batches which will yield balanced mini-batches.
+
+from imblearn.keras import BalancedBatchGenerator
+
+
+@timeit
+def fit_predict_balanced_model(X_train, y_train, X_test, y_test):
+    model = make_model(X_train.shape[1])
+    training_generator = BalancedBatchGenerator(X_train, y_train,
+                                                batch_size=1000,
+                                                random_state=42)
+    model.fit_generator(generator=training_generator, epochs=5, verbose=1)
+    y_pred = model.predict_proba(X_test, batch_size=1000)
+    return roc_auc_score(y_test, y_pred)
+
+
+###############################################################################
+# Classification loop
+###############################################################################
+
+###############################################################################
+# We will perform a 10-fold cross-validation and train the neural network with
+# the two different strategies previously presented.
+
+from sklearn.model_selection import StratifiedKFold
+
+skf = StratifiedKFold(n_splits=10)
+
+cv_results_imbalanced = []
+cv_time_imbalanced = []
+cv_results_balanced = []
+cv_time_balanced = []
+for train_idx, valid_idx in skf.split(X_train, y_train):
+    X_local_train = preprocessor.fit_transform(X_train.iloc[train_idx])
+    y_local_train = y_train.iloc[train_idx].values.ravel()
+    X_local_test = preprocessor.transform(X_train.iloc[valid_idx])
+    y_local_test = y_train.iloc[valid_idx].values.ravel()
+
+    elapsed_time, roc_auc = fit_predict_imbalanced_model(
+        X_local_train, y_local_train, X_local_test, y_local_test)
+    cv_time_imbalanced.append(elapsed_time)
+    cv_results_imbalanced.append(roc_auc)
+
+    elapsed_time, roc_auc = fit_predict_balanced_model(
+        X_local_train, y_local_train, X_local_test, y_local_test)
+    cv_time_balanced.append(elapsed_time)
+    cv_results_balanced.append(roc_auc)
+
+###############################################################################
+# Plot of the results and computation time
+###############################################################################
+
+df_results = (pd.DataFrame({'Balanced model': cv_results_balanced,
+                            'Imbalanced model': cv_results_imbalanced})
+              .unstack().reset_index())
+df_time = (pd.DataFrame({'Balanced model': cv_time_balanced,
+                         'Imbalanced model': cv_time_imbalanced})
+           .unstack().reset_index())
+
+import seaborn as sns
+import matplotlib.pyplot as plt
+
+plt.figure()
+sns.boxplot(y='level_0', x=0, data=df_time)
+sns.despine(top=True, right=True, left=True)
+plt.xlabel('time [s]')
+plt.ylabel('')
+plt.title('Computation time difference using a random under-sampling')
+
+plt.figure()
+sns.boxplot(y='level_0', x=0, data=df_results, whis=10.0)
+sns.despine(top=True, right=True, left=True)
+ax = plt.gca()
+ax.xaxis.set_major_formatter(
+    plt.FuncFormatter(lambda x, pos: "%i%%" % (100 * x)))
+plt.xlabel('ROC-AUC')
+plt.ylabel('')
+plt.title('Difference in terms of ROC-AUC using a random under-sampling')
diff --git a/imblearn/__init__.py b/imblearn/__init__.py
index 9f05adb1f..0cb3ca8fe 100644
--- a/imblearn/__init__.py
+++ b/imblearn/__init__.py
@@ -13,11 +13,17 @@
 exceptions
     Module including custom warnings and error classes used across
     imbalanced-learn.
+keras
+    Module which provides custom generators and layers for deep learning
+    using keras.
 metrics
     Module which provides metrics to quantify the classification
     performance with imbalanced datasets.
 over_sampling
     Module which provides methods to over-sample a dataset.
+tensorflow
+    Module which provides custom generators and layers for deep learning
+    using tensorflow.
 under-sampling
     Module which provides methods to under-sample a dataset.
 utils
diff --git a/imblearn/keras/__init__.py b/imblearn/keras/__init__.py
new file mode 100644
index 000000000..407e0c7dd
--- /dev/null
+++ b/imblearn/keras/__init__.py
@@ -0,0 +1,8 @@
+"""The :mod:`imblearn.keras` module provides utilities to deal with imbalanced
+datasets in keras."""
+
+from ._generator import BalancedBatchGenerator
+from ._generator import balanced_batch_generator
+
+__all__ = ['BalancedBatchGenerator',
+           'balanced_batch_generator']
diff --git a/imblearn/keras/_generator.py b/imblearn/keras/_generator.py
new file mode 100644
index 000000000..c3f5a4af1
--- /dev/null
+++ b/imblearn/keras/_generator.py
@@ -0,0 +1,229 @@
+"""Implement generators for ``keras`` which will balance the data."""
+from __future__ import division
+
+# This is a trick to avoid an error during test collection with pytest: we
+# silence the import error when the package is missing and raise it only at
+# the moment an instance is created.
We +# avoid the error when importing the package raise the error at the moment of +# creating the instance. +try: + import keras + ParentClass = keras.utils.Sequence + HAS_KERAS = True +except ImportError: + ParentClass = object + HAS_KERAS = False + +from scipy.sparse import issparse + +from sklearn.base import clone +from sklearn.utils import safe_indexing +from sklearn.utils import check_random_state +from sklearn.utils.testing import set_random_state + +from ..under_sampling import RandomUnderSampler +from ..utils import Substitution +from ..utils._docstring import _random_state_docstring +from ..tensorflow import balanced_batch_generator as tf_bbg + + +class BalancedBatchGenerator(ParentClass): + """Create balanced batches when training a keras model. + + Create a keras ``Sequence`` which is given to ``fit_generator``. The + sampler defines the sampling strategy used to balance the dataset ahead of + creating the batch. The sampler should have an attribute + ``return_indices``. + + Parameters + ---------- + X : ndarray, shape (n_samples, n_features) + Original imbalanced dataset. + + y : ndarray, shape (n_samples,) or (n_samples, n_classes) + Associated targets. + + sample_weight : ndarray, shape (n_samples,) + Sample weight. + + sampler : object or None, optional (default=RandomUnderSampler) + A sampler instance which has an attribute ``return_indices``. + By default, the sampler used is a + :class:`imblearn.under_sampling.RandomUnderSampler`. + + batch_size : int, optional (default=32) + Number of samples per gradient update. + + sparse : bool, optional (default=False) + Either or not to conserve or not the sparsity of the input (i.e. ``X``, + ``y``, ``sample_weight``). By default, the returned batches will be + dense. + + random_state : int, RandomState instance or None, optional (default=None) + Control the randomization of the algorithm + - If int, ``random_state`` is the seed used by the random number + generator; + - If ``RandomState`` instance, random_state is the random number + generator; + - If ``None``, the random number generator is the ``RandomState`` + instance used by ``np.random``. + + Attributes + ---------- + sampler_ : object + The sampler used to balance the dataset. + + indices_ : ndarray, shape (n_samples, n_features) + The indices of the samples selected during sampling. + + Examples + -------- + >>> from sklearn.datasets import load_iris + >>> iris = load_iris() + >>> from imblearn.datasets import make_imbalance + >>> class_dict = dict() + >>> class_dict[0] = 30; class_dict[1] = 50; class_dict[2] = 40 + >>> X, y = make_imbalance(iris.data, iris.target, class_dict) + >>> import keras + >>> y = keras.utils.to_categorical(y, 3) + >>> model = keras.models.Sequential() + >>> model.add(keras.layers.Dense(y.shape[1], input_dim=X.shape[1], + ... activation='softmax')) + >>> model.compile(optimizer='sgd', loss='categorical_crossentropy', + ... metrics=['accuracy']) + >>> from imblearn.keras import BalancedBatchGenerator + >>> from imblearn.under_sampling import NearMiss + >>> training_generator = BalancedBatchGenerator( + ... X, y, sampler=NearMiss(), batch_size=10, random_state=42) + >>> callback_history = model.fit_generator(generator=training_generator, + ... 
+    def __init__(self, X, y, sample_weight=None, sampler=None, batch_size=32,
+                 sparse=False, random_state=None):
+        if not HAS_KERAS:
+            raise ImportError("No module named 'keras'")
+        self.X = X
+        self.y = y
+        self.sample_weight = sample_weight
+        self.sampler = sampler
+        self.batch_size = batch_size
+        self.sparse = sparse
+        self.random_state = random_state
+        self._sample()
+
+    def _sample(self):
+        random_state = check_random_state(self.random_state)
+        if self.sampler is None:
+            self.sampler_ = RandomUnderSampler(return_indices=True,
+                                               random_state=random_state)
+        else:
+            if not hasattr(self.sampler, 'return_indices'):
+                raise ValueError("'sampler' needs to return the indices of "
+                                 "the samples selected. Provide a sampler "
+                                 "which has an attribute 'return_indices'.")
+            self.sampler_ = clone(self.sampler)
+            self.sampler_.set_params(return_indices=True)
+            set_random_state(self.sampler_, random_state)
+
+        _, _, self.indices_ = self.sampler_.fit_sample(self.X, self.y)
+        # shuffle the indices since the sampler packs them by class
+        random_state.shuffle(self.indices_)
+
+    def __len__(self):
+        return int(self.indices_.size // self.batch_size)
+
+    def __getitem__(self, index):
+        X_resampled = safe_indexing(
+            self.X, self.indices_[index * self.batch_size:
+                                  (index + 1) * self.batch_size])
+        y_resampled = safe_indexing(
+            self.y, self.indices_[index * self.batch_size:
+                                  (index + 1) * self.batch_size])
+        if issparse(X_resampled) and not self.sparse:
+            X_resampled = X_resampled.toarray()
+        if self.sample_weight is not None:
+            sample_weight_resampled = safe_indexing(
+                self.sample_weight,
+                self.indices_[index * self.batch_size:
+                              (index + 1) * self.batch_size])
+
+        if self.sample_weight is None:
+            return X_resampled, y_resampled
+        else:
+            return X_resampled, y_resampled, sample_weight_resampled
+
+
+@Substitution(random_state=_random_state_docstring)
+def balanced_batch_generator(X, y, sample_weight=None, sampler=None,
+                             batch_size=32, sparse=False, random_state=None):
+    """Create a balanced batch generator to train a keras model.
+
+    Returns a generator --- as well as the number of steps per epoch --- which
+    is given to ``fit_generator``. The sampler defines the sampling strategy
+    used to balance the dataset ahead of creating the batch. The sampler
+    should have an attribute ``return_indices``.
+
+    Parameters
+    ----------
+    X : ndarray, shape (n_samples, n_features)
+        Original imbalanced dataset.
+
+    y : ndarray, shape (n_samples,) or (n_samples, n_classes)
+        Associated targets.
+
+    sample_weight : ndarray, shape (n_samples,)
+        Sample weight.
+
+    sampler : object or None, optional (default=RandomUnderSampler)
+        A sampler instance which has an attribute ``return_indices``.
+        By default, the sampler used is a
+        :class:`imblearn.under_sampling.RandomUnderSampler`.
+
+    batch_size : int, optional (default=32)
+        Number of samples per gradient update.
+
+    sparse : bool, optional (default=False)
+        Whether or not to conserve the sparsity of the input (i.e. ``X``,
+        ``y``, ``sample_weight``). By default, the returned batches will be
+        dense.
+
+    {random_state}
+
+    Returns
+    -------
+    generator : generator of tuple
+        Generate batches of data. The tuples generated are either
+        (X_batch, y_batch) or (X_batch, y_batch, sample_weight_batch).
+
+    steps_per_epoch : int
+        The number of batches per epoch. Required by ``fit_generator`` in
+        keras.
+
+    Examples
+    --------
+    >>> from sklearn.datasets import load_iris
+    >>> X, y = load_iris(return_X_y=True)
+    >>> class_dict = dict()
+    >>> class_dict[0] = 30; class_dict[1] = 50; class_dict[2] = 40
+    >>> from imblearn.datasets import make_imbalance
+    >>> X, y = make_imbalance(X, y, class_dict)
+    >>> import keras
+    >>> y = keras.utils.to_categorical(y, 3)
+    >>> model = keras.models.Sequential()
+    >>> model.add(keras.layers.Dense(y.shape[1], input_dim=X.shape[1],
+    ...                              activation='softmax'))
+    >>> model.compile(optimizer='sgd', loss='categorical_crossentropy',
+    ...               metrics=['accuracy'])
+    >>> from imblearn.keras import balanced_batch_generator
+    >>> from imblearn.under_sampling import NearMiss
+    >>> training_generator, steps_per_epoch = balanced_batch_generator(
+    ...     X, y, sampler=NearMiss(), batch_size=10, random_state=42)
+    >>> callback_history = model.fit_generator(generator=training_generator,
+    ...                                        steps_per_epoch=steps_per_epoch,
+    ...                                        epochs=10, verbose=0)
+
+    """
+
+    return tf_bbg(X=X, y=y, sample_weight=sample_weight,
+                  sampler=sampler, batch_size=batch_size,
+                  sparse=sparse, random_state=random_state)
diff --git a/imblearn/keras/tests/__init__.py b/imblearn/keras/tests/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/imblearn/keras/tests/test_generator.py b/imblearn/keras/tests/test_generator.py
new file mode 100644
index 000000000..7b0491146
--- /dev/null
+++ b/imblearn/keras/tests/test_generator.py
@@ -0,0 +1,101 @@
+import pytest
+
+import numpy as np
+from scipy import sparse
+
+from sklearn.datasets import load_iris
+
+keras = pytest.importorskip('keras')
+from keras.models import Sequential
+from keras.layers import Dense
+from keras.utils import to_categorical
+
+from imblearn.datasets import make_imbalance
+from imblearn.under_sampling import ClusterCentroids
+from imblearn.under_sampling import NearMiss
+
+from imblearn.keras import BalancedBatchGenerator
+from imblearn.keras import balanced_batch_generator
+
+iris = load_iris()
+X, y = make_imbalance(iris.data, iris.target, {0: 30, 1: 50, 2: 40})
+y = to_categorical(y, 3)
+
+
+def _build_keras_model(n_classes, n_features):
+    model = Sequential()
+    model.add(Dense(n_classes, input_dim=n_features, activation='softmax'))
+    model.compile(optimizer='sgd', loss='categorical_crossentropy',
+                  metrics=['accuracy'])
+    return model
+
+
+def test_balanced_batch_generator_class_no_return_indices():
+    with pytest.raises(ValueError, match='needs to return the indices'):
+        BalancedBatchGenerator(X, y, sampler=ClusterCentroids(), batch_size=10)
+
+
+@pytest.mark.parametrize(
+    "sampler, sample_weight",
+    [(None, None),
+     (NearMiss(), None),
+     (None, np.random.uniform(size=(y.shape[0])))]
+)
+def test_balanced_batch_generator_class(sampler, sample_weight):
+    model = _build_keras_model(y.shape[1], X.shape[1])
+    training_generator = BalancedBatchGenerator(X, y,
+                                                sample_weight=sample_weight,
+                                                sampler=sampler,
+                                                batch_size=10,
+                                                random_state=42)
+    model.fit_generator(generator=training_generator,
+                        epochs=10)
+
+
+@pytest.mark.parametrize("is_sparse", [True, False])
+def test_balanced_batch_generator_class_sparse(is_sparse):
+    training_generator = BalancedBatchGenerator(sparse.csr_matrix(X), y,
+                                                batch_size=10,
+                                                sparse=is_sparse,
+                                                random_state=42)
+    for idx in range(len(training_generator)):
+        X_batch, y_batch = training_generator.__getitem__(idx)
+        if is_sparse:
+            assert sparse.issparse(X_batch)
+        else:
+            assert not sparse.issparse(X_batch)
+
+
+def test_balanced_batch_generator_function_no_return_indices():
+    with pytest.raises(ValueError, match='needs to return the indices'):
+        balanced_batch_generator(
+            X, y, sampler=ClusterCentroids(), batch_size=10, random_state=42)
+
+
+@pytest.mark.parametrize(
+    "sampler, sample_weight",
+    [(None, None),
+     (NearMiss(), None),
+     (None, np.random.uniform(size=(y.shape[0])))]
+)
+def test_balanced_batch_generator_function(sampler, sample_weight):
+    model = _build_keras_model(y.shape[1], X.shape[1])
+    training_generator, steps_per_epoch = balanced_batch_generator(
+        X, y, sample_weight=sample_weight, sampler=sampler, batch_size=10,
+        random_state=42)
+    model.fit_generator(generator=training_generator,
+                        steps_per_epoch=steps_per_epoch,
+                        epochs=10)
+
+
+@pytest.mark.parametrize("is_sparse", [True, False])
+def test_balanced_batch_generator_function_sparse(is_sparse):
+    training_generator, steps_per_epoch = balanced_batch_generator(
+        sparse.csr_matrix(X), y, sparse=is_sparse, batch_size=10,
+        random_state=42)
+    for idx in range(steps_per_epoch):
+        X_batch, y_batch = next(training_generator)
+        if is_sparse:
+            assert sparse.issparse(X_batch)
+        else:
+            assert not sparse.issparse(X_batch)
diff --git a/imblearn/tensorflow/__init__.py b/imblearn/tensorflow/__init__.py
new file mode 100644
index 000000000..3224a7db1
--- /dev/null
+++ b/imblearn/tensorflow/__init__.py
@@ -0,0 +1,6 @@
+"""The :mod:`imblearn.tensorflow` module provides utilities to deal with
+imbalanced datasets in tensorflow."""
+
+from ._generator import balanced_batch_generator
+
+__all__ = ['balanced_batch_generator']
diff --git a/imblearn/tensorflow/_generator.py b/imblearn/tensorflow/_generator.py
new file mode 100644
index 000000000..9b0cb06d5
--- /dev/null
+++ b/imblearn/tensorflow/_generator.py
@@ -0,0 +1,150 @@
+"""Implement generators for ``tensorflow`` which will balance the data."""
+
+from __future__ import division
+
+from scipy.sparse import issparse
+
+from sklearn.base import clone
+from sklearn.utils import safe_indexing
+from sklearn.utils import check_random_state
+from sklearn.utils.testing import set_random_state
+
+from ..under_sampling import RandomUnderSampler
+from ..utils import Substitution
+from ..utils._docstring import _random_state_docstring
+
+
+@Substitution(random_state=_random_state_docstring)
+def balanced_batch_generator(X, y, sample_weight=None, sampler=None,
+                             batch_size=32, sparse=False, random_state=None):
+    """Create a balanced batch generator to train a tensorflow model.
+
+    Returns a generator --- as well as the number of steps per epoch --- which
+    can be used to iterate over the mini-batches. The sampler defines the
+    sampling strategy used to balance the dataset ahead of creating the batch.
+    The sampler should have an attribute ``return_indices``.
+
+    Parameters
+    ----------
+    X : ndarray, shape (n_samples, n_features)
+        Original imbalanced dataset.
+
+    y : ndarray, shape (n_samples,) or (n_samples, n_classes)
+        Associated targets.
+
+    sample_weight : ndarray, shape (n_samples,)
+        Sample weight.
+
+    sampler : object or None, optional (default=RandomUnderSampler)
+        A sampler instance which has an attribute ``return_indices``.
+        By default, the sampler used is a
+        :class:`imblearn.under_sampling.RandomUnderSampler`.
+
+    batch_size : int, optional (default=32)
+        Number of samples per gradient update.
+
+    sparse : bool, optional (default=False)
+        Whether or not to conserve the sparsity of the input ``X``. By
+        default, the returned batches will be dense.
+
+    {random_state}
+
+    Returns
+    -------
+    generator : generator of tuple
+        Generate batches of data. The tuples generated are either
+        (X_batch, y_batch) or (X_batch, y_batch, sample_weight_batch).
+
+    steps_per_epoch : int
+        The number of batches per epoch.
+
+    Examples
+    --------
+    >>> import numpy as np
+    >>> from sklearn.datasets import load_iris
+    >>> X, y = load_iris(return_X_y=True)
+    >>> class_dict = dict()
+    >>> class_dict[0] = 30; class_dict[1] = 50; class_dict[2] = 40
+    >>> from imblearn.datasets import make_imbalance
+    >>> X, y = make_imbalance(X, y, class_dict)
+    >>> X = X.astype(np.float32)
+    >>> batch_size, learning_rate, epochs = 10, 0.01, 10
+    >>> training_generator, steps_per_epoch = balanced_batch_generator(
+    ...     X, y, sample_weight=None, sampler=None,
+    ...     batch_size=batch_size, random_state=42)
+    >>> input_size, output_size = X.shape[1], 3
+    >>> import tensorflow as tf
+    >>> def init_weights(shape):
+    ...     return tf.Variable(tf.random_normal(shape, stddev=0.01))
+    >>> def accuracy(y_true, y_pred):
+    ...     return np.mean(np.argmax(y_pred, axis=1) == y_true)
+    >>> # input and output
+    >>> data = tf.placeholder("float32", shape=[None, input_size])
+    >>> targets = tf.placeholder("int32", shape=[None])
+    >>> # build the model and weights
+    >>> W = init_weights([input_size, output_size])
+    >>> b = init_weights([output_size])
+    >>> out_act = tf.nn.sigmoid(tf.matmul(data, W) + b)
+    >>> # build the loss, predict, and train operator
+    >>> cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
+    ...     logits=out_act, labels=targets)
+    >>> loss = tf.reduce_sum(cross_entropy)
+    >>> optimizer = tf.train.GradientDescentOptimizer(learning_rate)
+    >>> train_op = optimizer.minimize(loss)
+    >>> predict = tf.nn.softmax(out_act)
+    >>> # Initialization of all variables in the graph
+    >>> init = tf.global_variables_initializer()
+    >>> with tf.Session() as sess:
+    ...     print('Starting training')
+    ...     sess.run(init)
+    ...     for e in range(epochs):
+    ...         for i in range(steps_per_epoch):
+    ...             X_batch, y_batch = next(training_generator)
+    ...             feed_dict = dict()
+    ...             feed_dict[data] = X_batch; feed_dict[targets] = y_batch
+    ...             sess.run([train_op, loss], feed_dict=feed_dict)
+    ...         # For each epoch, run accuracy on train and test
+    ...         feed_dict = dict()
+    ...         feed_dict[data] = X
+    ...         predicts_train = sess.run(predict, feed_dict=feed_dict)
+    ...         print("epoch: {{}} train accuracy: {{:.3f}}"
+    ...               .format(e, accuracy(y, predicts_train)))
+    ...     # doctest: +ELLIPSIS
+    Starting training
+    [...
+
+    """
+
+    random_state = check_random_state(random_state)
+    if sampler is None:
+        sampler_ = RandomUnderSampler(return_indices=True,
+                                      random_state=random_state)
+    else:
+        if not hasattr(sampler, 'return_indices'):
+            raise ValueError("'sampler' needs to return the indices of "
+                             "the samples selected. Provide a sampler "
+                             "which has an attribute 'return_indices'.")
+        sampler_ = clone(sampler)
+        sampler_.set_params(return_indices=True)
+        set_random_state(sampler_, random_state)
+
+    _, _, indices = sampler_.fit_sample(X, y)
+    # shuffle the indices since the sampler packs them by class
+    random_state.shuffle(indices)
+
+    def generator(X, y, sample_weight, indices, batch_size):
+        while True:
+            for index in range(0, len(indices), batch_size):
+                X_res = safe_indexing(X, indices[index:index + batch_size])
+                y_res = safe_indexing(y, indices[index:index + batch_size])
+                if issparse(X_res) and not sparse:
+                    X_res = X_res.toarray()
+                if sample_weight is None:
+                    yield X_res, y_res
+                else:
+                    sw_res = safe_indexing(sample_weight,
+                                           indices[index:index + batch_size])
+                    yield X_res, y_res, sw_res
+
+    return (generator(X, y, sample_weight, indices, batch_size),
+            int(indices.size // batch_size))
diff --git a/imblearn/tensorflow/tests/test_generator.py b/imblearn/tensorflow/tests/test_generator.py
new file mode 100644
index 000000000..48bce2af6
--- /dev/null
+++ b/imblearn/tensorflow/tests/test_generator.py
@@ -0,0 +1,89 @@
+from __future__ import division
+
+import pytest
+import numpy as np
+from scipy import sparse
+
+from sklearn.datasets import load_iris
+
+from imblearn.datasets import make_imbalance
+from imblearn.under_sampling import NearMiss
+
+from imblearn.tensorflow import balanced_batch_generator
+
+tf = pytest.importorskip('tensorflow')
+
+
+@pytest.mark.parametrize("sampler", [None, NearMiss()])
+def test_balanced_batch_generator(sampler):
+    X, y = load_iris(return_X_y=True)
+    X, y = make_imbalance(X, y, {0: 30, 1: 50, 2: 40})
+    X = X.astype(np.float32)
+
+    batch_size = 10
+    training_generator, steps_per_epoch = balanced_batch_generator(
+        X, y, sample_weight=None, sampler=sampler,
+        batch_size=batch_size, random_state=42)
+
+    learning_rate = 0.01
+    epochs = 10
+    input_size = X.shape[1]
+    output_size = 3
+
+    # helper functions
+    def init_weights(shape):
+        return tf.Variable(tf.random_normal(shape, stddev=0.01))
+
+    def accuracy(y_true, y_pred):
+        return np.mean(np.argmax(y_pred, axis=1) == y_true)
+
+    # input and output
+    data = tf.placeholder("float32", shape=[None, input_size])
+    targets = tf.placeholder("int32", shape=[None])
+
+    # build the model and weights
+    W = init_weights([input_size, output_size])
+    b = init_weights([output_size])
+    out_act = tf.nn.sigmoid(tf.matmul(data, W) + b)
+
+    # build the loss, predict, and train operator
+    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
+        logits=out_act, labels=targets)
+    loss = tf.reduce_sum(cross_entropy)
+    optimizer = tf.train.GradientDescentOptimizer(learning_rate)
+    train_op = optimizer.minimize(loss)
+    predict = tf.nn.softmax(out_act)
+
+    # Initialization of all variables in the graph
+    init = tf.global_variables_initializer()
+
+    with tf.Session() as sess:
+        sess.run(init)
+
+        for e in range(epochs):
+            for i in range(steps_per_epoch):
+                X_batch, y_batch = next(training_generator)
+                sess.run([train_op, loss],
+                         feed_dict={data: X_batch, targets: y_batch})
+
+            # For each epoch, run accuracy on train and test
+            predicts_train = sess.run(predict, feed_dict={data: X})
+            print("epoch: {} train accuracy: {:.3f}"
+                  .format(e, accuracy(y, predicts_train)))
+
+
+@pytest.mark.parametrize("is_sparse", [True, False])
+def test_balanced_batch_generator_function_sparse(is_sparse):
+    X, y = load_iris(return_X_y=True)
+    X, y = make_imbalance(X, y, {0: 30, 1: 50, 2: 40})
+    X = X.astype(np.float32)
+
+    training_generator, steps_per_epoch = balanced_batch_generator(
+        sparse.csr_matrix(X), y, sparse=is_sparse, batch_size=10,
+        random_state=42)
+    for idx in range(steps_per_epoch):
+        X_batch, y_batch = next(training_generator)
+        if is_sparse:
+            assert sparse.issparse(X_batch)
+        else:
+            assert not sparse.issparse(X_batch)
diff --git a/imblearn/utils/_docstring.py b/imblearn/utils/_docstring.py
index a47bef8af..f036f31da 100644
--- a/imblearn/utils/_docstring.py
+++ b/imblearn/utils/_docstring.py
@@ -25,7 +25,7 @@ def __call__(self, obj):

 _random_state_docstring = \
     """random_state : int, RandomState instance or None, optional (default=None)
-        Control the randomization of the algorithm
+        Control the randomization of the algorithm.

        - If int, ``random_state`` is the seed used by the random number
          generator;
diff --git a/requirements.optional.txt b/requirements.optional.txt
new file mode 100644
index 000000000..826277d5e
--- /dev/null
+++ b/requirements.optional.txt
@@ -0,0 +1,2 @@
+keras
+tensorflow
diff --git a/setup.cfg b/setup.cfg
index 56cfb932a..50f9c583a 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -33,6 +33,4 @@ doctest-extension = rst
 doctest-fixtures = _fixture

 [tool:pytest]
-addopts =
-    --doctest-modules
-
+addopts = --doctest-modules