Skip to content

Commit 9cfac37

Browse files
authored
refine bilstm model (#11)
* Refine BiLSTM model * Update per review comments
1 parent 711f8c3 commit 9cfac37

File tree

2 files changed

+22
-13
lines changed

2 files changed

+22
-13
lines changed

sqlflow_models/lstmclassifier.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,37 @@
11
import tensorflow as tf
22

33
class StackedBiLSTMClassifier(tf.keras.Model):
4-
def __init__(self, feature_columns, units=64, stack_size=1, n_classes=2):
4+
def __init__(self, feature_columns, stack_units=[32], hidden_size=64, n_classes=2):
55
"""StackedBiLSTMClassifier
66
:param feature_columns: All columns must be embedding of sequence column with same sequence_length.
77
:type feature_columns: list[tf.embedding_column].
8-
:param units: Units for LSTM layer.
9-
:type units: int.
10-
:param stack_size: number of bidirectional LSTM layers in the stack, default 1.
11-
:type stack_size: int.
8+
:param stack_units: Number of units for each bidirectional LSTM layer in the stack, one entry per layer.
9+
:type stack_units: list[int].
1210
:param n_classes: Target number of classes.
1311
:type n_classes: int.
1412
"""
1513
super(StackedBiLSTMClassifier, self).__init__()
1614

1715
self.feature_layer = tf.keras.experimental.SequenceFeatures(feature_columns)
1816
self.stack_bilstm = []
19-
self.stack_size = stack_size
20-
if stack_size > 1:
21-
for i in range(stack_size - 1):
17+
self.stack_size = len(stack_units)
18+
self.stack_units = stack_units
19+
self.n_classes = n_classes
20+
if self.stack_size > 1:
21+
for i in range(self.stack_size - 1):
2222
self.stack_bilstm.append(
23-
tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(units, return_sequences=True))
23+
tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(self.stack_units[i], return_sequences=True))
2424
)
25-
self.lstm = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(units))
26-
self.pred = tf.keras.layers.Dense(n_classes, activation='softmax')
25+
self.lstm = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(self.stack_units[-1]))
26+
self.hidden = tf.keras.layers.Dense(hidden_size, activation='relu')
27+
if self.n_classes == 2:
28+
# special setup for binary classification
29+
pred_act = 'sigmoid'
30+
self.loss = 'binary_crossentropy'
31+
else:
32+
pred_act = 'softmax'
33+
self.loss = 'categorical_crossentropy'
34+
self.pred = tf.keras.layers.Dense(n_classes, activation=pred_act)
2735

2836
def call(self, inputs):
2937
x, seq_len = self.feature_layer(inputs)
@@ -32,6 +40,7 @@ def call(self, inputs):
3240
for i in range(self.stack_size - 1):
3341
x = self.stack_bilstm[i](x, mask=seq_mask)
3442
x = self.lstm(x, mask=seq_mask)
43+
x = self.hidden(x)
3544
return self.pred(x)
3645

3746
def default_optimizer(self):
@@ -40,7 +49,7 @@ def default_optimizer(self):
4049

4150
def default_loss(self):
4251
"""Default loss function. Used in model.compile."""
43-
return 'categorical_crossentropy'
52+
return self.loss
4453

4554
def default_training_epochs(self):
4655
"""Default training epochs. Used in model.fit."""

tests/test_lstm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def setUp(self):
1919
fea,
2020
dimension=32)
2121
feature_columns = [emb]
22-
self.model = sqlflow_models.StackedBiLSTMClassifier(feature_columns=feature_columns)
22+
self.model = sqlflow_models.StackedBiLSTMClassifier(feature_columns=feature_columns, stack_units=[64, 32])
2323

2424

2525
if __name__ == '__main__':

0 commit comments

Comments
 (0)