Skip to content

Commit a996498

Browse files
committed
Reducing the number of epochs by 50% to speed up the build
1 parent 7240041 commit a996498

File tree

1 file changed

+17
-3
lines changed

1 file changed

+17
-3
lines changed

intermediate_source/char_rnn_classification_tutorial.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,22 @@
4949
is about LSTMs specifically but also informative about RNNs in
5050
general
5151
"""
52+
######################################################################
53+
# Preparing Torch
54+
# ==========================
55+
#
56+
# Set up torch to default to the right device and use GPU acceleration depending on your hardware (CPU or CUDA).
57+
#
58+
59+
import torch
60+
61+
# Check if CUDA is available
62+
device = torch.device('cpu')
63+
if torch.cuda.is_available():
64+
device = torch.device('cuda')
65+
66+
torch.set_default_device(device)
67+
print(f"Using device = {torch.get_default_device()}")
5268

5369
######################################################################
5470
# Preparing the Data
@@ -65,8 +81,6 @@
6581
# The first step is to define and clean our data. Initially, we need to convert Unicode to plain ASCII to
6682
# limit the RNN input layers. This is accomplished by converting Unicode strings to ASCII and allowing only a small set of allowed characters.
6783

68-
import torch
69-
device = torch.device('cpu')
7084
import string
7185
import unicodedata
7286

@@ -326,7 +340,7 @@ def train(rnn, training_data, n_epoch = 10, n_batch_size = 64, report_every = 50
326340
# We can now train a dataset with minibatches for a specified number of epochs
327341

328342
start = time.time()
329-
all_losses = train(rnn, train_set, n_epoch=55, learning_rate=0.15, report_every=5)
343+
all_losses = train(rnn, train_set, n_epoch=27, learning_rate=0.15, report_every=5)
330344
end = time.time()
331345
print(f"training took {end-start}s")
332346

0 commit comments

Comments
 (0)