diff --git a/en-wordlist.txt b/en-wordlist.txt
index 2ccab08b094..50a577bedd4 100644
--- a/en-wordlist.txt
+++ b/en-wordlist.txt
@@ -33,6 +33,7 @@ Captum
Captum's
CartPole
Cayley
+CharRNN
Chatbots
Chen
Colab
@@ -155,6 +156,7 @@ MaskRCNN
Minifier
MobileNet
ModelABC
+MPS
Mypy
NAS
NCCL
@@ -376,6 +378,7 @@ enum
eq
equalities
et
+eval
evaluateInput
extensibility
fastai
@@ -427,6 +430,7 @@ jpg
json
judgements
jupyter
+kernels
keypoint
kwargs
labelled
@@ -613,6 +617,7 @@ uncomment
uncommented
underflowing
unfused
+unicode
unimodal
unigram
unnormalized
diff --git a/intermediate_source/char_rnn_classification_tutorial.py b/intermediate_source/char_rnn_classification_tutorial.py
index 2bae524b4a8..67c3f04cbe3 100644
--- a/intermediate_source/char_rnn_classification_tutorial.py
+++ b/intermediate_source/char_rnn_classification_tutorial.py
@@ -25,20 +25,7 @@
Specifically, we'll train on a few thousand surnames from 18 languages
of origin, and predict which language a name is from based on the
-spelling:
-
-.. code-block:: sh
-
- $ python predict.py Hinton
- (-0.47) Scottish
- (-1.52) English
- (-3.57) Irish
-
- $ python predict.py Schmidhuber
- (-0.19) German
- (-2.48) Czech
- (-2.68) Dutch
-
+spelling.
Recommended Preparation
=======================
@@ -61,79 +48,62 @@
Networks `__
is about LSTMs specifically but also informative about RNNs in
general
+"""
+######################################################################
+# Preparing Torch
+# ==========================
+#
+# Set up torch to default to the right device and use GPU acceleration depending on your hardware (CPU or CUDA).
+#
-Preparing the Data
-==================
+import torch
-.. note::
- Download the data from
- `here `_
- and extract it to the current directory.
+# Check if CUDA is available
+device = torch.device('cpu')
+if torch.cuda.is_available():
+ device = torch.device('cuda')
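+# Optionally, prefer Apple's MPS backend on Apple silicon. This branch is an
+# illustrative addition and assumes a PyTorch build with MPS support.
+elif torch.backends.mps.is_available():
+    device = torch.device('mps')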
-Included in the ``data/names`` directory are 18 text files named as
-``[Language].txt``. Each file contains a bunch of names, one name per
-line, mostly romanized (but we still need to convert from Unicode to
-ASCII).
+torch.set_default_device(device)
+print(f"Using device = {torch.get_default_device()}")
-We'll end up with a dictionary of lists of names per language,
-``{language: [names ...]}``. The generic variables "category" and "line"
-(for language and name in our case) are used for later extensibility.
-"""
-from io import open
-import glob
-import os
-
-def findFiles(path): return glob.glob(path)
-
-print(findFiles('data/names/*.txt'))
+######################################################################
+# Preparing the Data
+# ==================
+#
+# Download the data from `here <https://download.pytorch.org/tutorial/data.zip>`__
+# and extract it to the current directory.
+#
+# Included in the ``data/names`` directory are 18 text files named as
+# ``[Language].txt``. Each file contains a bunch of names, one name per
+# line, mostly romanized (but we still need to convert from Unicode to
+# ASCII).
+#
+# The first step is to define and clean our data. Initially, we need to convert Unicode to plain ASCII to
+# limit the size of the RNN input layer. This is accomplished by normalizing Unicode strings and keeping only a small set of allowed characters.
+import string
import unicodedata
-import string
-all_letters = string.ascii_letters + " .,;'"
-n_letters = len(all_letters)
+allowed_characters = string.ascii_letters + " .,;'"
+n_letters = len(allowed_characters)
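+
+# As a quick illustrative check: 52 ASCII letters plus five punctuation characters gives 57 allowed characters
+print(f"n_letters = {n_letters}") # expected: 57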
-# Turn a Unicode string to plain ASCII, thanks to https://stackoverflow.com/a/518232/2809427
+# Turn a Unicode string to plain ASCII, thanks to https://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):
return ''.join(
c for c in unicodedata.normalize('NFD', s)
if unicodedata.category(c) != 'Mn'
- and c in all_letters
+ and c in allowed_characters
)
-print(unicodeToAscii('Ślusàrski'))
-
-# Build the category_lines dictionary, a list of names per language
-category_lines = {}
-all_categories = []
-
-# Read a file and split into lines
-def readLines(filename):
- lines = open(filename, encoding='utf-8').read().strip().split('\n')
- return [unicodeToAscii(line) for line in lines]
-
-for filename in findFiles('data/names/*.txt'):
- category = os.path.splitext(os.path.basename(filename))[0]
- all_categories.append(category)
- lines = readLines(filename)
- category_lines[category] = lines
-
-n_categories = len(all_categories)
-
-
-######################################################################
-# Now we have ``category_lines``, a dictionary mapping each category
-# (language) to a list of lines (names). We also kept track of
-# ``all_categories`` (just a list of languages) and ``n_categories`` for
-# later reference.
+#########################
+# Here's an example of converting a Unicode name to plain ASCII. This simplifies the input layer.
#
-print(category_lines['Italian'][:5])
-
+print(f"converting 'Ślusàrski' to {unicodeToAscii('Ślusàrski')}")
######################################################################
# Turning Names into Tensors
-# --------------------------
+# ==========================
#
# Now that we have all the names organized, we need to turn them into
# Tensors to make any use of them.
@@ -147,19 +117,10 @@ def readLines(filename):
#
# That extra 1 dimension is because PyTorch assumes everything is in
# batches - we're just using a batch size of 1 here.
-#
-
-import torch
-# Find letter index from all_letters, e.g. "a" = 0
+# Find letter index from allowed_characters, e.g. "a" = 0
def letterToIndex(letter):
- return all_letters.find(letter)
-
-# Just for demonstration, turn a letter into a <1 x n_letters> Tensor
-def letterToTensor(letter):
- tensor = torch.zeros(1, n_letters)
- tensor[0][letterToIndex(letter)] = 1
- return tensor
+ return allowed_characters.find(letter)
# Turn a line into a <line_length x 1 x n_letters>,
# or an array of one-hot letter vectors
@@ -169,9 +130,87 @@ def lineToTensor(line):
tensor[li][0][letterToIndex(letter)] = 1
return tensor
-print(letterToTensor('J'))
+#########################
+# Here are some examples of how to use ``lineToTensor()`` for a single-character and a multi-character string.
+
+print(f"The letter 'a' becomes {lineToTensor('a')}") # notice that the first position in the tensor = 1
+print(f"The name 'Ahn' becomes {lineToTensor('Ahn')}") # notice 'A' sets index 26 (the 27th character) to 1
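+
+#########################
+# As an illustrative check (an addition here), the tensor shape is ``(line_length, 1, n_letters)``;
+# the middle dimension is the batch of size 1 discussed above.
+
+print(f"The shape of 'Ahn' tensor is {lineToTensor('Ahn').shape}") # expected: torch.Size([3, 1, 57])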
+
+#########################
+# Congratulations, you have built the foundational tensor objects for this learning task! You can use a similar approach
+# for other RNN tasks with text.
+#
+# Next, we need to combine all our examples into a dataset so we can train, test and validate our models. For this,
+# we will use the `Dataset and DataLoader <https://pytorch.org/tutorials/beginner/basics/data_tutorial.html>`__ classes
+# to hold our dataset. Each Dataset needs to implement three functions: ``__init__``, ``__len__``, and ``__getitem__``.
+from io import open
+import glob
+import os
+import time
+
+import torch
+from torch.utils.data import Dataset
+
+class NamesDataset(Dataset):
+
+ def __init__(self, data_dir):
+        self.data_dir = data_dir # for provenance of the dataset
+        self.load_time = time.localtime() # for provenance of the dataset
+        labels_set = set() # set of all classes
+
+ self.data = []
+ self.data_tensors = []
+ self.labels = []
+ self.labels_tensors = []
+
+        # Read all the ``.txt`` files in the specified directory
+ text_files = glob.glob(os.path.join(data_dir, '*.txt'))
+ for filename in text_files:
+ label = os.path.splitext(os.path.basename(filename))[0]
+ labels_set.add(label)
+ lines = open(filename, encoding='utf-8').read().strip().split('\n')
+ for name in lines:
+ self.data.append(name)
+ self.data_tensors.append(lineToTensor(name))
+ self.labels.append(label)
+
+        # Cache the tensor representation of the labels
+ self.labels_uniq = list(labels_set)
+ for idx in range(len(self.labels)):
+ temp_tensor = torch.tensor([self.labels_uniq.index(self.labels[idx])], dtype=torch.long)
+ self.labels_tensors.append(temp_tensor)
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, idx):
+ data_item = self.data[idx]
+ data_label = self.labels[idx]
+ data_tensor = self.data_tensors[idx]
+ label_tensor = self.labels_tensors[idx]
+
+ return label_tensor, data_tensor, data_label, data_item
-print(lineToTensor('Jones').size())
+
+#########################
+# Here we can load our example data into the ``NamesDataset``
+
+alldata = NamesDataset("data/names")
+print(f"loaded {len(alldata)} items of data")
+print(f"example = {alldata[0]}")
+
+#########################
+# Using the dataset object allows us to easily split the data into train and test sets. Here we create an 85/15
+# split, but the ``torch.utils.data`` module has more useful utilities. We specify a generator since it needs to
+# run on the same device that PyTorch defaults to above.
+
+train_set, test_set = torch.utils.data.random_split(alldata, [.85, .15], generator=torch.Generator(device=device).manual_seed(2024))
+
+print(f"train examples = {len(train_set)}, test examples = {len(test_set)}")
+
+#########################
+# Now we have a basic dataset containing **20074** examples where each example is a pairing of label and name. We have also
+# split the dataset into training and testing so we can validate the model that we build.
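+#
+# As an optional sanity check (an illustrative addition), we can peek at a single training example:
+
+label_tensor, text_tensor, label, text = train_set[0]
+print(f"label = {label}, text = {text}")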
######################################################################
@@ -183,114 +222,58 @@ def lineToTensor(line):
# held hidden state and gradients which are now entirely handled by the
# graph itself. This means you can implement a RNN in a very "pure" way,
# as regular feed-forward layers.
-#
-# This RNN module implements a "vanilla RNN" an is just 3 linear layers
-# which operate on an input and hidden state, with a ``LogSoftmax`` layer
-# after the output.
+#
+# This ``CharRNN`` class implements an RNN with three components.
+# First, we use the `nn.RNN implementation <https://pytorch.org/docs/stable/generated/torch.nn.RNN.html>`__.
+# Next, we define a layer that maps the RNN hidden layers to our output. Finally, we apply a ``LogSoftmax`` function.
+# Using ``nn.RNN`` leads to a significant performance improvement (for example, through cuDNN-accelerated kernels)
+# compared to implementing each layer as an ``nn.Linear``. It also simplifies the implementation in ``forward()``.
#
import torch.nn as nn
import torch.nn.functional as F
-class RNN(nn.Module):
+class CharRNN(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
- super(RNN, self).__init__()
+ super(CharRNN, self).__init__()
- self.hidden_size = hidden_size
-
- self.i2h = nn.Linear(input_size, hidden_size)
- self.h2h = nn.Linear(hidden_size, hidden_size)
+ self.rnn = nn.RNN(input_size, hidden_size)
self.h2o = nn.Linear(hidden_size, output_size)
self.softmax = nn.LogSoftmax(dim=1)
-
- def forward(self, input, hidden):
- hidden = F.tanh(self.i2h(input) + self.h2h(hidden))
- output = self.h2o(hidden)
+
+ def forward(self, line_tensor):
+ rnn_out, hidden = self.rnn(line_tensor)
+ output = self.h2o(hidden[0])
output = self.softmax(output)
- return output, hidden
- def initHidden(self):
- return torch.zeros(1, self.hidden_size)
+ return output
-n_hidden = 128
-rnn = RNN(n_letters, n_hidden, n_categories)
-
-
-######################################################################
-# To run a step of this network we need to pass an input (in our case, the
-# Tensor for the current letter) and a previous hidden state (which we
-# initialize as zeros at first). We'll get back the output (probability of
-# each language) and a next hidden state (which we keep for the next
-# step).
-#
-input = letterToTensor('A')
-hidden = torch.zeros(1, n_hidden)
-
-output, next_hidden = rnn(input, hidden)
+###########################
+# We can then create an RNN with 57 input nodes, 128 hidden nodes, and 18 outputs:
+n_hidden = 128
+rnn = CharRNN(n_letters, n_hidden, len(alldata.labels_uniq))
+print(rnn)
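+
+#########################
+# As an illustrative aside (an addition here), for this single-layer RNN the final hidden state equals the
+# output at the last time step, which is what ``forward()`` reads via ``hidden[0]``:
+
+rnn_out, hidden = rnn.rnn(lineToTensor('Ahn'))
+print(torch.allclose(rnn_out[-1], hidden[0])) # expected: True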
######################################################################
-# For the sake of efficiency we don't want to be creating a new Tensor for
-# every step, so we will use ``lineToTensor`` instead of
-# ``letterToTensor`` and use slices. This could be further optimized by
-# precomputing batches of Tensors.
-#
-
-input = lineToTensor('Albert')
-hidden = torch.zeros(1, n_hidden)
-
-output, next_hidden = rnn(input[0], hidden)
-print(output)
-
+# After that we can pass our Tensor to the RNN to obtain a predicted output. Subsequently,
+# we use a helper function, ``label_from_output``, to derive a text label for the class.
-######################################################################
-# As you can see the output is a ``<1 x n_categories>`` Tensor, where
-# every item is the likelihood of that category (higher is more likely).
-#
+def label_from_output(output, output_labels):
+ top_n, top_i = output.topk(1)
+ label_i = top_i[0].item()
+ return output_labels[label_i], label_i
+input = lineToTensor('Albert')
+output = rnn(input) # this is equivalent to ``output = rnn.forward(input)``
+print(output)
+print(label_from_output(output, alldata.labels_uniq))
######################################################################
#
# Training
# ========
-# Preparing for Training
-# ----------------------
-#
-# Before going into training we should make a few helper functions. The
-# first is to interpret the output of the network, which we know to be a
-# likelihood of each category. We can use ``Tensor.topk`` to get the index
-# of the greatest value:
-#
-
-def categoryFromOutput(output):
- top_n, top_i = output.topk(1)
- category_i = top_i[0].item()
- return all_categories[category_i], category_i
-
-print(categoryFromOutput(output))
-
-
-######################################################################
-# We will also want a quick way to get a training example (a name and its
-# language):
-#
-
-import random
-
-def randomChoice(l):
- return l[random.randint(0, len(l) - 1)]
-
-def randomTrainingExample():
- category = randomChoice(all_categories)
- line = randomChoice(category_lines[category])
- category_tensor = torch.tensor([all_categories.index(category)], dtype=torch.long)
- line_tensor = lineToTensor(line)
- return category, line, category_tensor, line_tensor
-
-for i in range(10):
- category, line, category_tensor, line_tensor = randomTrainingExample()
- print('category =', category, '/ line =', line)
######################################################################
@@ -300,93 +283,67 @@ def randomTrainingExample():
# Now all it takes to train this network is show it a bunch of examples,
# have it make guesses, and tell it if it's wrong.
#
-# For the loss function ``nn.NLLLoss`` is appropriate, since the last
-# layer of the RNN is ``nn.LogSoftmax``.
-#
-
-criterion = nn.NLLLoss()
-
-
-######################################################################
-# Each loop of training will:
-#
-# - Create input and target tensors
-# - Create a zeroed initial hidden state
-# - Read each letter in and
-#
-# - Keep hidden state for next letter
-#
-# - Compare final output to target
-# - Back-propagate
-# - Return the output and loss
-#
-
-learning_rate = 0.005 # If you set this too high, it might explode. If too low, it might not learn
-
-def train(category_tensor, line_tensor):
- hidden = rnn.initHidden()
-
- rnn.zero_grad()
-
- for i in range(line_tensor.size()[0]):
- output, hidden = rnn(line_tensor[i], hidden)
-
- loss = criterion(output, category_tensor)
- loss.backward()
-
- # Add parameters' gradients to their values, multiplied by learning rate
- for p in rnn.parameters():
- p.data.add_(p.grad.data, alpha=-learning_rate)
-
- return output, loss.item()
-
-
-######################################################################
-# Now we just have to run that with a bunch of examples. Since the
-# ``train`` function returns both the output and loss we can print its
-# guesses and also keep track of loss for plotting. Since there are 1000s
-# of examples we print only every ``print_every`` examples, and take an
-# average of the loss.
-#
-
-import time
-import math
-
-n_iters = 100000
-print_every = 5000
-plot_every = 1000
-
-
-
-# Keep track of losses for plotting
-current_loss = 0
-all_losses = []
+# We do this by defining a ``train()`` function which trains the model on a given dataset using minibatches.
+# RNNs are trained similarly to other networks; therefore, for completeness, we include a batched training method here.
+# The loop (``for i in batch``) computes the losses for each of the items in the batch before adjusting the
+# weights. This operation is repeated until the number of epochs is reached.
+
+import random
+import numpy as np
+
+def train(rnn, training_data, n_epoch=10, n_batch_size=64, report_every=50, learning_rate=0.2, criterion=nn.NLLLoss()):
+ """
+    Train on minibatches drawn from ``training_data`` for a given number of epochs, reporting the average batch loss at each reporting interval
+ """
+ # Keep track of losses for plotting
+ current_loss = 0
+ all_losses = []
+ rnn.train()
+ optimizer = torch.optim.SGD(rnn.parameters(), lr=learning_rate)
+
+ start = time.time()
+ print(f"training on data set with n = {len(training_data)}")
+
+ for iter in range(1, n_epoch + 1):
+ rnn.zero_grad() # clear the gradients
+
+ # create some minibatches
+ # we cannot use dataloaders because each of our names is a different length
+ batches = list(range(len(training_data)))
+ random.shuffle(batches)
+        batches = np.array_split(batches, len(batches) // n_batch_size)
+
+ for idx, batch in enumerate(batches):
+ batch_loss = 0
+            for i in batch: # for each example in this batch
+ (label_tensor, text_tensor, label, text) = training_data[i]
+                output = rnn(text_tensor)
+ loss = criterion(output, label_tensor)
+ batch_loss += loss
+
+ # optimize parameters
+ batch_loss.backward()
+ nn.utils.clip_grad_norm_(rnn.parameters(), 3)
+ optimizer.step()
+ optimizer.zero_grad()
+
+ current_loss += batch_loss.item() / len(batch)
+
+        all_losses.append(current_loss / len(batches))
+ if iter % report_every == 0:
+ print(f"{iter} ({iter / n_epoch:.0%}): \t average batch loss = {all_losses[-1]}")
+ current_loss = 0
+
+ return all_losses
-def timeSince(since):
- now = time.time()
- s = now - since
- m = math.floor(s / 60)
- s -= m * 60
- return '%dm %ds' % (m, s)
+##########################################################################
+# We can now train the model on a dataset with minibatches for a specified number of epochs. The number of epochs for this
+# example is reduced to speed up the build. You can get better results with different parameters.
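+#
+# For example (hypothetical settings), training longer at a similar learning rate may improve accuracy:
+#
+# .. code-block:: python
+#
+#    all_losses = train(rnn, train_set, n_epoch=60, learning_rate=0.2, report_every=10)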
start = time.time()
-
-for iter in range(1, n_iters + 1):
- category, line, category_tensor, line_tensor = randomTrainingExample()
- output, loss = train(category_tensor, line_tensor)
- current_loss += loss
-
- # Print ``iter`` number, loss, name and guess
- if iter % print_every == 0:
- guess, guess_i = categoryFromOutput(output)
- correct = '✓' if guess == category else '✗ (%s)' % category
- print('%d %d%% (%s) %.4f %s / %s %s' % (iter, iter / n_iters * 100, timeSince(start), loss, line, guess, correct))
-
- # Add current loss avg to list of losses
- if iter % plot_every == 0:
- all_losses.append(current_loss / plot_every)
- current_loss = 0
-
+all_losses = train(rnn, train_set, n_epoch=27, learning_rate=0.15, report_every=5)
+end = time.time()
+print(f"training took {end-start}s")
######################################################################
# Plotting the Results
@@ -401,7 +358,7 @@ def timeSince(since):
plt.figure()
plt.plot(all_losses)
-
+plt.show()
######################################################################
# Evaluating the Results
@@ -414,48 +371,45 @@ def timeSince(since):
# ``evaluate()``, which is the same as ``train()`` minus the backprop.
#
-# Keep track of correct guesses in a confusion matrix
-confusion = torch.zeros(n_categories, n_categories)
-n_confusion = 10000
-
-# Just return an output given a line
-def evaluate(line_tensor):
- hidden = rnn.initHidden()
-
- for i in range(line_tensor.size()[0]):
- output, hidden = rnn(line_tensor[i], hidden)
+def evaluate(rnn, testing_data, classes):
+ confusion = torch.zeros(len(classes), len(classes))
+
+    rnn.eval() # set to eval mode
+ with torch.no_grad(): # do not record the gradients during eval phase
+ for i in range(len(testing_data)):
+ (label_tensor, text_tensor, label, text) = testing_data[i]
+ output = rnn(text_tensor)
+ guess, guess_i = label_from_output(output, classes)
+ label_i = classes.index(label)
+ confusion[label_i][guess_i] += 1
- return output
+ # Normalize by dividing every row by its sum
+ for i in range(len(classes)):
+ denom = confusion[i].sum()
+ if denom > 0:
+ confusion[i] = confusion[i] / denom
-# Go through a bunch of examples and record which are correctly guessed
-for i in range(n_confusion):
- category, line, category_tensor, line_tensor = randomTrainingExample()
- output = evaluate(line_tensor)
- guess, guess_i = categoryFromOutput(output)
- category_i = all_categories.index(category)
- confusion[category_i][guess_i] += 1
+ # Set up plot
+ fig = plt.figure()
+ ax = fig.add_subplot(111)
+    cax = ax.matshow(confusion.cpu().numpy()) # numpy arrays live on the CPU, so move the tensor to the CPU first
+ fig.colorbar(cax)
-# Normalize by dividing every row by its sum
-for i in range(n_categories):
- confusion[i] = confusion[i] / confusion[i].sum()
+ # Set up axes
+ ax.set_xticks(np.arange(len(classes)), labels=classes, rotation=90)
+ ax.set_yticks(np.arange(len(classes)), labels=classes)
-# Set up plot
-fig = plt.figure()
-ax = fig.add_subplot(111)
-cax = ax.matshow(confusion.numpy())
-fig.colorbar(cax)
+ # Force label at every tick
+ ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
+ ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
-# Set up axes
-ax.set_xticklabels([''] + all_categories, rotation=90)
-ax.set_yticklabels([''] + all_categories)
+ # sphinx_gallery_thumbnail_number = 2
+ plt.show()
-# Force label at every tick
-ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
-ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
-# sphinx_gallery_thumbnail_number = 2
-plt.show()
+evaluate(rnn, test_set, classes=alldata.labels_uniq)
+
######################################################################
# You can pick out bright spots off the main axis that show which
@@ -465,72 +419,20 @@ def evaluate(line_tensor):
#
-######################################################################
-# Running on User Input
-# ---------------------
-#
-
-def predict(input_line, n_predictions=3):
- print('\n> %s' % input_line)
- with torch.no_grad():
- output = evaluate(lineToTensor(input_line))
-
- # Get top N categories
- topv, topi = output.topk(n_predictions, 1, True)
- predictions = []
-
- for i in range(n_predictions):
- value = topv[0][i].item()
- category_index = topi[0][i].item()
- print('(%.2f) %s' % (value, all_categories[category_index]))
- predictions.append([value, all_categories[category_index]])
-
-predict('Dovesky')
-predict('Jackson')
-predict('Satoshi')
-
-
-######################################################################
-# The final versions of the scripts `in the Practical PyTorch
-# repo `__
-# split the above code into a few files:
-#
-# - ``data.py`` (loads files)
-# - ``model.py`` (defines the RNN)
-# - ``train.py`` (runs training)
-# - ``predict.py`` (runs ``predict()`` with command line arguments)
-# - ``server.py`` (serve prediction as a JSON API with ``bottle.py``)
-#
-# Run ``train.py`` to train and save the network.
-#
-# Run ``predict.py`` with a name to view predictions:
-#
-# .. code-block:: sh
-#
-# $ python predict.py Hazaki
-# (-0.42) Japanese
-# (-1.39) Polish
-# (-3.51) Czech
-#
-# Run ``server.py`` and visit http://localhost:5533/Yourname to get JSON
-# output of predictions.
-#
-
-
######################################################################
# Exercises
# =========
#
-# - Try with a different dataset of line -> category, for example:
-#
-# - Any word -> language
-# - First name -> gender
-# - Character name -> writer
-# - Page title -> blog or subreddit
-#
# - Get better results with a bigger and/or better shaped network
#
-# - Add more linear layers
+# - Adjust the hyperparameters to enhance performance, such as changing the number of epochs, batch size, and learning rate
# - Try the ``nn.LSTM`` and ``nn.GRU`` layers
+# - Modify the size of the layers, such as increasing or decreasing the number of hidden nodes or adding additional linear layers
# - Combine multiple of these RNNs as a higher level network
+#
+# - Try with a different dataset of line -> label, for example:
#
+# - Any word -> language
+# - First name -> gender
+# - Character name -> writer
+# - Page title -> blog or subreddit
\ No newline at end of file