Skip to content

Commit fe9d29c

Browse files
framoncgdvrogozh
andauthored
Update word_language_model/generate.py to remove duplicates, use abc order
Co-authored-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>
1 parent 4d5f787 commit fe9d29c

File tree

1 file changed

+3
-6
lines changed

1 file changed

+3
-6
lines changed

word_language_model/generate.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
with open(args.checkpoint, 'rb') as f:
5656
safe_globals = [
5757
PositionalEncoding,
58+
RNNModel,
5859
TransformerModel,
5960
torch.nn.functional.relu,
6061
torch.nn.modules.activation.MultiheadAttention,
@@ -64,15 +65,11 @@
6465
torch.nn.modules.linear.NonDynamicallyQuantizableLinear,
6566
torch.nn.modules.normalization.LayerNorm,
6667
torch.nn.modules.sparse.Embedding,
67-
torch.nn.modules.transformer.TransformerEncoder,
68-
torch.nn.modules.transformer.TransformerEncoderLayer,
69-
RNNModel,
70-
torch.nn.modules.dropout.Dropout,
71-
torch.nn.modules.linear.Linear,
7268
torch.nn.modules.rnn.GRU,
7369
torch.nn.modules.rnn.LSTM,
7470
torch.nn.modules.rnn.RNN,
75-
torch.nn.modules.sparse.Embedding,
71+
torch.nn.modules.transformer.TransformerEncoder,
72+
torch.nn.modules.transformer.TransformerEncoderLayer,
7673
]
7774
with torch.serialization.safe_globals(safe_globals):
7875
model = torch.load(f, map_location=device)

0 commit comments

Comments
 (0)