1
1
import tensorflow as tf
2
2
3
3
class StackedBiLSTMClassifier (tf .keras .Model ):
4
def __init__(self, feature_columns, stack_units=None, hidden_size=64, n_classes=2):
    """StackedBiLSTMClassifier

    Builds a stack of bidirectional LSTM layers followed by a ReLU dense
    layer and an output dense layer.

    :param feature_columns: All columns must be embedding of sequence column with same sequence_length.
    :type feature_columns: list[tf.embedding_column].
    :param stack_units: Units of each bidirectional LSTM layer, one entry per
        layer in the stack. Defaults to ``[32]`` (a single layer) when omitted
        or None.
    :type stack_units: list[int].
    :param hidden_size: Units of the ReLU dense layer placed between the LSTM
        stack and the output layer.
    :type hidden_size: int.
    :param n_classes: Target number of classes.
    :type n_classes: int.
    :raises ValueError: If stack_units is empty.
    """
    super(StackedBiLSTMClassifier, self).__init__()

    # Avoid a mutable default argument: resolve the default here and copy any
    # caller-provided sequence so the model never aliases the caller's list.
    stack_units = [32] if stack_units is None else list(stack_units)
    if not stack_units:
        raise ValueError("stack_units must contain at least one layer size")

    self.feature_layer = tf.keras.experimental.SequenceFeatures(feature_columns)
    self.stack_size = len(stack_units)
    self.stack_units = stack_units
    self.n_classes = n_classes

    # Every layer except the last returns full sequences so the next LSTM in
    # the stack receives one vector per time step.
    self.stack_bilstm = [
        tf.keras.layers.Bidirectional(
            tf.keras.layers.LSTM(units, return_sequences=True)
        )
        for units in stack_units[:-1]
    ]
    # The last layer is created without return_sequences, unlike the layers above.
    self.lstm = tf.keras.layers.Bidirectional(
        tf.keras.layers.LSTM(stack_units[-1])
    )
    self.hidden = tf.keras.layers.Dense(hidden_size, activation='relu')

    if self.n_classes == 2:
        # special setup for binary classification
        # NOTE(review): the output layer still has n_classes (2) sigmoid
        # units here; confirm the label encoding matches binary_crossentropy.
        pred_act = 'sigmoid'
        self.loss = 'binary_crossentropy'
    else:
        pred_act = 'softmax'
        self.loss = 'categorical_crossentropy'
    self.pred = tf.keras.layers.Dense(n_classes, activation=pred_act)
27
35
28
36
def call (self , inputs ):
29
37
x , seq_len = self .feature_layer (inputs )
@@ -32,6 +40,7 @@ def call(self, inputs):
32
40
for i in range (self .stack_size - 1 ):
33
41
x = self .stack_bilstm [i ](x , mask = seq_mask )
34
42
x = self .lstm (x , mask = seq_mask )
43
+ x = self .hidden (x )
35
44
return self .pred (x )
36
45
37
46
def default_optimizer (self ):
@@ -40,7 +49,7 @@ def default_optimizer(self):
40
49
41
50
def default_loss(self):
    """Return the loss identifier stored on this model as ``self.loss``.

    Used in model.compile.
    """
    # NOTE(review): tf.keras.Model.compile may also assign an attribute named
    # ``loss``; confirm the stored default is not overwritten before this is read.
    chosen_loss = self.loss
    return chosen_loss
44
53
45
54
def default_training_epochs (self ):
46
55
"""Default training epochs. Used in model.fit."""
0 commit comments