
Support torch>=2.6 in word_language_model example #1347


Merged
merged 1 commit on May 18, 2025
3 changes: 3 additions & 0 deletions run_python_examples.sh
@@ -154,6 +154,9 @@ function vision_transformer() {

 function word_language_model() {
   uv run main.py --epochs 1 --dry-run $CUDA_FLAG --mps || error "word_language_model failed"
+  for model in "RNN_TANH" "RNN_RELU" "LSTM" "GRU" "Transformer"; do
+    uv run main.py --model $model --epochs 1 --dry-run $CUDA_FLAG --mps || error "word_language_model failed"
+  done
 }

 function gcn() {
34 changes: 30 additions & 4 deletions word_language_model/main.py
@@ -8,7 +8,7 @@
 import torch.onnx

 import data
-import model
+from model import PositionalEncoding, RNNModel, TransformerModel

 parser = argparse.ArgumentParser(description='PyTorch Wikitext-2 RNN/LSTM/GRU/Transformer Language Model')
 parser.add_argument('--data', type=str, default='./data/wikitext-2',
@@ -108,9 +108,9 @@ def batchify(data, bsz):

 ntokens = len(corpus.dictionary)
 if args.model == 'Transformer':
-    model = model.TransformerModel(ntokens, args.emsize, args.nhead, args.nhid, args.nlayers, args.dropout).to(device)
+    model = TransformerModel(ntokens, args.emsize, args.nhead, args.nhid, args.nlayers, args.dropout).to(device)
 else:
-    model = model.RNNModel(args.model, ntokens, args.emsize, args.nhid, args.nlayers, args.dropout, args.tied).to(device)
+    model = RNNModel(args.model, ntokens, args.emsize, args.nhid, args.nlayers, args.dropout, args.tied).to(device)

 criterion = nn.NLLLoss()

@@ -243,7 +243,33 @@ def export_onnx(path, batch_size, seq_len):

 # Load the best saved model.
 with open(args.save, 'rb') as f:
-    model = torch.load(f)
+    if args.model == 'Transformer':
+        safe_globals = [
+            PositionalEncoding,
+            TransformerModel,
+            torch.nn.functional.relu,
+            torch.nn.modules.activation.MultiheadAttention,
+            torch.nn.modules.container.ModuleList,
+            torch.nn.modules.dropout.Dropout,
+            torch.nn.modules.linear.Linear,
+            torch.nn.modules.linear.NonDynamicallyQuantizableLinear,
+            torch.nn.modules.normalization.LayerNorm,
+            torch.nn.modules.sparse.Embedding,
+            torch.nn.modules.transformer.TransformerEncoder,
+            torch.nn.modules.transformer.TransformerEncoderLayer,
+        ]
+    else:
+        safe_globals = [
+            RNNModel,
+            torch.nn.modules.dropout.Dropout,
+            torch.nn.modules.linear.Linear,
+            torch.nn.modules.rnn.GRU,
+            torch.nn.modules.rnn.LSTM,
+            torch.nn.modules.rnn.RNN,
+            torch.nn.modules.sparse.Embedding,
+        ]
+    with torch.serialization.safe_globals(safe_globals):
+        model = torch.load(f)
 # after load the rnn params are not a continuous chunk of memory
 # this makes them a continuous chunk, and will speed up forward pass
 # Currently, only rnn model supports flatten_parameters function.
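For context on the pattern above: starting with torch 2.6, torch.load defaults to weights_only=True and refuses to unpickle any class that has not been allowlisted, which is why the diff enumerates every module class the saved model can contain. Below is a minimal, self-contained sketch of the same mechanism; SmallNet and toy_model.pt are hypothetical stand-ins, not part of this PR.

import torch
import torch.nn as nn

class SmallNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

# Saving the whole module (not just a state_dict) pickles the class itself.
torch.save(SmallNet(), "toy_model.pt")

# Under torch>=2.6 a plain torch.load("toy_model.pt") raises an
# UnpicklingError, because weights_only=True rejects SmallNet and nn.Linear.
# Allowlisting them for the duration of the load fixes that:
with torch.serialization.safe_globals([SmallNet, nn.Linear]):
    model = torch.load("toy_model.pt")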
2 changes: 1 addition & 1 deletion word_language_model/requirements.txt
@@ -1 +1 @@
-torch<2.6
+torch>=2.6
Contributor Author
The torch.serialization.safe_globals context manager I've used in this PR appeared in torch 2.5. The torch.load default for weights_only changed to True in 2.6. Since we plan to bump the dependency to 2.6 anyway (to use torch.accelerator), I think we can update the requirement to 2.6 right away.
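For anyone upgrading their own checkpoints, a hedged sketch of the two options the comment alludes to; the checkpoint path is hypothetical, and the allowlist must name every class pickled into the file, as the diff above does.

import torch
from model import RNNModel  # plus any submodule classes the checkpoint contains

# Option 1 (used by this PR; needs torch>=2.5): a scoped allowlist.
with torch.serialization.safe_globals([RNNModel]):
    model = torch.load("model.pt")

# Option 2: restore the pre-2.6 behavior explicitly. Only do this for
# checkpoints you trust, since unpickling can execute arbitrary code.
model = torch.load("model.pt", weights_only=False)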