Skip to content

Commit 2fa0596

Browse files
authored
BUG allow to import keras from tensorflow (#532)
1 parent 85422e8 commit 2fa0596

File tree

2 files changed

+44
-9
lines changed

2 files changed

+44
-9
lines changed

doc/whats_new/v0.5.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,3 +77,7 @@ Bug
7777
:class:`imblearn.ensemble.RUSBoostClassifier` to get a decision stump as a
7878
weak learner as in the original paper.
7979
:pr:`545` by :user:`Christos Aridas <chkoar>`.
80+
81+
- Allow to import ``keras`` directly from ``tensorflow`` in the
82+
:mod:`imblearn.keras`.
83+
:pr:`531` by :user:`Guillaume Lemaitre <glemaitre>`.

imblearn/keras/_generator.py

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,43 @@
11
"""Implement generators for ``keras`` which will balance the data."""
2-
from __future__ import division
2+
33

44
# This is a trick to avoid an error during tests collection with pytest. We
55
# avoid the error when importing the package raise the error at the moment of
66
# creating the instance.
7-
try:
8-
import keras
9-
ParentClass = keras.utils.Sequence
10-
HAS_KERAS = True
11-
except ImportError:
12-
ParentClass = object
13-
HAS_KERAS = False
7+
def import_keras():
8+
"""Try to import keras from keras and tensorflow.
9+
10+
This is possible to import the sequence from keras or tensorflow.
11+
Keras is not ducktyping ``Sequence`` before 2.3 and we need import from
12+
all possible library to ensure that the ``isinstance(...)`` is not going
13+
to fail. This function can be modified when we support Keras 2.3.
14+
"""
15+
16+
def import_from_keras():
17+
try:
18+
import keras
19+
return (keras.utils.Sequence,), True
20+
except ImportError:
21+
return tuple(), False
22+
23+
def import_from_tensforflow():
24+
try:
25+
from tensorflow import keras
26+
return (keras.utils.Sequence,), True
27+
except ImportError:
28+
return tuple(), False
29+
30+
ParentClassKeras, has_keras_k = import_from_keras()
31+
ParentClassTensorflow, has_keras_tf = import_from_tensforflow()
32+
has_keras = has_keras_k or has_keras_tf
33+
if has_keras:
34+
ParentClass = (ParentClassKeras + ParentClassTensorflow)
35+
else:
36+
ParentClass = (object,)
37+
return ParentClass, has_keras
38+
39+
40+
ParentClass, HAS_KERAS = import_keras()
1441

1542
from scipy.sparse import issparse
1643

@@ -29,7 +56,7 @@
2956
'NeighbourhoodCleaningRule', 'TomekLinks')
3057

3158

32-
class BalancedBatchGenerator(ParentClass):
59+
class BalancedBatchGenerator(*ParentClass):
3360
"""Create balanced batches when training a keras model.
3461
3562
Create a keras ``Sequence`` which is given to ``fit_generator``. The
@@ -102,6 +129,10 @@ class BalancedBatchGenerator(ParentClass):
102129
... epochs=10, verbose=0)
103130
104131
"""
132+
133+
# flag for keras sequence duck-typing
134+
use_sequence_api = True
135+
105136
def __init__(self, X, y, sample_weight=None, sampler=None, batch_size=32,
106137
keep_sparse=False, random_state=None):
107138
if not HAS_KERAS:

0 commit comments

Comments
 (0)