|
1 | 1 | """Implement generators for ``keras`` which will balance the data."""
|
2 |
| -from __future__ import division |
| 2 | + |
3 | 3 |
|
4 | 4 | # This is a trick to avoid an error during tests collection with pytest. We
|
5 | 5 | # avoid the error when importing the package raise the error at the moment of
|
6 | 6 | # creating the instance.
|
7 |
| -try: |
8 |
| - import keras |
9 |
| - ParentClass = keras.utils.Sequence |
10 |
| - HAS_KERAS = True |
11 |
| -except ImportError: |
12 |
| - ParentClass = object |
13 |
| - HAS_KERAS = False |
| 7 | +def import_keras(): |
| 8 | + """Try to import keras from keras and tensorflow. |
| 9 | +
|
| 10 | + This is possible to import the sequence from keras or tensorflow. |
| 11 | + Keras is not ducktyping ``Sequence`` before 2.3 and we need import from |
| 12 | + all possible library to ensure that the ``isinstance(...)`` is not going |
| 13 | + to fail. This function can be modified when we support Keras 2.3. |
| 14 | + """ |
| 15 | + |
| 16 | + def import_from_keras(): |
| 17 | + try: |
| 18 | + import keras |
| 19 | + return (keras.utils.Sequence,), True |
| 20 | + except ImportError: |
| 21 | + return tuple(), False |
| 22 | + |
| 23 | + def import_from_tensforflow(): |
| 24 | + try: |
| 25 | + from tensorflow import keras |
| 26 | + return (keras.utils.Sequence,), True |
| 27 | + except ImportError: |
| 28 | + return tuple(), False |
| 29 | + |
| 30 | + ParentClassKeras, has_keras_k = import_from_keras() |
| 31 | + ParentClassTensorflow, has_keras_tf = import_from_tensforflow() |
| 32 | + has_keras = has_keras_k or has_keras_tf |
| 33 | + if has_keras: |
| 34 | + ParentClass = (ParentClassKeras + ParentClassTensorflow) |
| 35 | + else: |
| 36 | + ParentClass = (object,) |
| 37 | + return ParentClass, has_keras |
| 38 | + |
| 39 | + |
| 40 | +ParentClass, HAS_KERAS = import_keras() |
14 | 41 |
|
15 | 42 | from scipy.sparse import issparse
|
16 | 43 |
|
|
29 | 56 | 'NeighbourhoodCleaningRule', 'TomekLinks')
|
30 | 57 |
|
31 | 58 |
|
32 |
| -class BalancedBatchGenerator(ParentClass): |
| 59 | +class BalancedBatchGenerator(*ParentClass): |
33 | 60 | """Create balanced batches when training a keras model.
|
34 | 61 |
|
35 | 62 | Create a keras ``Sequence`` which is given to ``fit_generator``. The
|
@@ -102,6 +129,10 @@ class BalancedBatchGenerator(ParentClass):
|
102 | 129 | ... epochs=10, verbose=0)
|
103 | 130 |
|
104 | 131 | """
|
| 132 | + |
| 133 | + # flag for keras sequence duck-typing |
| 134 | + use_sequence_api = True |
| 135 | + |
105 | 136 | def __init__(self, X, y, sample_weight=None, sampler=None, batch_size=32,
|
106 | 137 | keep_sparse=False, random_state=None):
|
107 | 138 | if not HAS_KERAS:
|
|
0 commit comments