Skip to content

BUG: allow to import keras from tensorflow #532

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jun 12, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions doc/whats_new/v0.5.rst
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,7 @@ Bug
:class:`imblearn.ensemble.RUSBoostClassifier` to get a decision stump as a
weak learner as in the original paper.
:pr:`545` by :user:`Christos Aridas <chkoar>`.

- Allow to import ``keras`` directly from ``tensorflow`` in the
:mod:`imblearn.keras`.
:pr:`531` by :user:`Guillaume Lemaitre <glemaitre>`.
49 changes: 40 additions & 9 deletions imblearn/keras/_generator.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,43 @@
"""Implement generators for ``keras`` which will balance the data."""
from __future__ import division


# This is a trick to avoid an error during tests collection with pytest. We
# avoid the error when importing the package raise the error at the moment of
# creating the instance.
try:
import keras
ParentClass = keras.utils.Sequence
HAS_KERAS = True
except ImportError:
ParentClass = object
HAS_KERAS = False
def import_keras():
"""Try to import keras from keras and tensorflow.

This is possible to import the sequence from keras or tensorflow.
Keras is not ducktyping ``Sequence`` before 2.3 and we need import from
all possible library to ensure that the ``isinstance(...)`` is not going
to fail. This function can be modified when we support Keras 2.3.
"""

def import_from_keras():
try:
import keras
return (keras.utils.Sequence,), True
except ImportError:
return tuple(), False

def import_from_tensforflow():
try:
from tensorflow import keras
return (keras.utils.Sequence,), True
except ImportError:
return tuple(), False

ParentClassKeras, has_keras_k = import_from_keras()
ParentClassTensorflow, has_keras_tf = import_from_tensforflow()
has_keras = has_keras_k or has_keras_tf
if has_keras:
ParentClass = (ParentClassKeras + ParentClassTensorflow)
else:
ParentClass = (object,)
return ParentClass, has_keras


ParentClass, HAS_KERAS = import_keras()

from scipy.sparse import issparse

Expand All @@ -29,7 +56,7 @@
'NeighbourhoodCleaningRule', 'TomekLinks')


class BalancedBatchGenerator(ParentClass):
class BalancedBatchGenerator(*ParentClass):
"""Create balanced batches when training a keras model.

Create a keras ``Sequence`` which is given to ``fit_generator``. The
Expand Down Expand Up @@ -102,6 +129,10 @@ class BalancedBatchGenerator(ParentClass):
... epochs=10, verbose=0)

"""

# flag for keras sequence duck-typing
use_sequence_api = True

def __init__(self, X, y, sample_weight=None, sampler=None, batch_size=32,
keep_sparse=False, random_state=None):
if not HAS_KERAS:
Expand Down