diff --git a/run_python_examples.sh b/run_python_examples.sh
index e075a28ed2..2d769c0ae1 100755
--- a/run_python_examples.sh
+++ b/run_python_examples.sh
@@ -154,6 +154,9 @@ function vision_transformer() {
 
 function word_language_model() {
   uv run main.py --epochs 1 --dry-run $CUDA_FLAG --mps || error "word_language_model failed"
+  for model in "RNN_TANH" "RNN_RELU" "LSTM" "GRU" "Transformer"; do
+    uv run main.py --model $model --epochs 1 --dry-run $CUDA_FLAG --mps || error "word_language_model failed"
+  done
 }
 
 function gcn() {
diff --git a/word_language_model/main.py b/word_language_model/main.py
index 23bda03e73..72fee6cd3b 100644
--- a/word_language_model/main.py
+++ b/word_language_model/main.py
@@ -8,7 +8,7 @@
 import torch.onnx
 
 import data
-import model
+from model import PositionalEncoding, RNNModel, TransformerModel
 
 parser = argparse.ArgumentParser(description='PyTorch Wikitext-2 RNN/LSTM/GRU/Transformer Language Model')
 parser.add_argument('--data', type=str, default='./data/wikitext-2',
@@ -108,9 +108,9 @@ def batchify(data, bsz):
 
 ntokens = len(corpus.dictionary)
 if args.model == 'Transformer':
-    model = model.TransformerModel(ntokens, args.emsize, args.nhead, args.nhid, args.nlayers, args.dropout).to(device)
+    model = TransformerModel(ntokens, args.emsize, args.nhead, args.nhid, args.nlayers, args.dropout).to(device)
 else:
-    model = model.RNNModel(args.model, ntokens, args.emsize, args.nhid, args.nlayers, args.dropout, args.tied).to(device)
+    model = RNNModel(args.model, ntokens, args.emsize, args.nhid, args.nlayers, args.dropout, args.tied).to(device)
 
 criterion = nn.NLLLoss()
 
@@ -243,7 +243,33 @@ def export_onnx(path, batch_size, seq_len):
 
 # Load the best saved model.
 with open(args.save, 'rb') as f:
-    model = torch.load(f)
+    if args.model == 'Transformer':
+        safe_globals = [
+            PositionalEncoding,
+            TransformerModel,
+            torch.nn.functional.relu,
+            torch.nn.modules.activation.MultiheadAttention,
+            torch.nn.modules.container.ModuleList,
+            torch.nn.modules.dropout.Dropout,
+            torch.nn.modules.linear.Linear,
+            torch.nn.modules.linear.NonDynamicallyQuantizableLinear,
+            torch.nn.modules.normalization.LayerNorm,
+            torch.nn.modules.sparse.Embedding,
+            torch.nn.modules.transformer.TransformerEncoder,
+            torch.nn.modules.transformer.TransformerEncoderLayer,
+        ]
+    else:
+        safe_globals = [
+            RNNModel,
+            torch.nn.modules.dropout.Dropout,
+            torch.nn.modules.linear.Linear,
+            torch.nn.modules.rnn.GRU,
+            torch.nn.modules.rnn.LSTM,
+            torch.nn.modules.rnn.RNN,
+            torch.nn.modules.sparse.Embedding,
+        ]
+    with torch.serialization.safe_globals(safe_globals):
+        model = torch.load(f)
 # after load the rnn params are not a continuous chunk of memory
 # this makes them a continuous chunk, and will speed up forward pass
 # Currently, only rnn model supports flatten_parameters function.
diff --git a/word_language_model/requirements.txt b/word_language_model/requirements.txt
index 43dbf9ee52..90b9e8b11d 100644
--- a/word_language_model/requirements.txt
+++ b/word_language_model/requirements.txt
@@ -1 +1 @@
-torch<2.6
+torch>=2.6
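
Context for the `safe_globals` hunk above: since torch 2.6, `torch.load` defaults to `weights_only=True`, so loading a checkpoint that was saved with `torch.save(model, ...)` (a full pickled module, as this example does) requires allowlisting every class baked into the pickle; that is what the two `safe_globals` lists enumerate. A minimal, self-contained sketch of the same pattern, assuming a hypothetical `TinyModel` class and `tiny.pt` path; `torch.serialization.safe_globals` is the actual context manager the patch uses:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the example's model classes.
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

# Saving the module object pickles its class, not just its tensors.
torch.save(TinyModel(), "tiny.pt")

# Under torch>=2.6 the default weights_only=True load rejects any class
# that is not allowlisted, so both TinyModel and nn.Linear go on the list
# (mirroring how the patch lists torch.nn.modules.linear.Linear).
with torch.serialization.safe_globals([TinyModel, nn.Linear]):
    restored = torch.load("tiny.pt")
print(type(restored).__name__)  # TinyModel
```

The allowlist is per-checkpoint, which is why the patch branches on `args.model`: a Transformer checkpoint pickles attention/encoder classes that an RNN checkpoint never contains.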