diff --git a/en-wordlist.txt b/en-wordlist.txt
index 2ccab08b094..50a577bedd4 100644
--- a/en-wordlist.txt
+++ b/en-wordlist.txt
@@ -33,6 +33,7 @@ Captum
 Captum's
 CartPole
 Cayley
+CharRNN
 Chatbots
 Chen
 Colab
@@ -155,6 +156,7 @@ MaskRCNN
 Minifier
 MobileNet
 ModelABC
+MPS
 Mypy
 NAS
 NCCL
@@ -376,6 +378,7 @@ enum
 eq
 equalities
 et
+eval
 evaluateInput
 extensibility
 fastai
@@ -427,6 +430,7 @@ jpg
 json
 judgements
 jupyter
+kernels
 keypoint
 kwargs
 labelled
@@ -613,6 +617,7 @@ uncomment
 uncommented
 underflowing
 unfused
+unicode
 unimodal
 unigram
 unnormalized
diff --git a/intermediate_source/char_rnn_classification_tutorial.py b/intermediate_source/char_rnn_classification_tutorial.py
index 2bae524b4a8..67c3f04cbe3 100644
--- a/intermediate_source/char_rnn_classification_tutorial.py
+++ b/intermediate_source/char_rnn_classification_tutorial.py
@@ -25,20 +25,7 @@
 Specifically, we'll train on a few thousand surnames from 18 languages
 of origin, and predict which language a name is from based on the
-spelling:
-
-.. code-block:: sh
-
-    $ python predict.py Hinton
-    (-0.47) Scottish
-    (-1.52) English
-    (-3.57) Irish
-
-    $ python predict.py Schmidhuber
-    (-0.19) German
-    (-2.48) Czech
-    (-2.68) Dutch
-
+spelling.

 Recommended Preparation
 =======================

@@ -61,79 +48,62 @@
 Networks <https://colah.github.io/posts/2015-08-Understanding-LSTMs/>`__
 is about LSTMs specifically but also informative about RNNs in
 general
+"""
+######################################################################
+# Preparing Torch
+# ===============
+#
+# Set up torch to default to the right device and use GPU acceleration depending on your hardware (CPU or CUDA).
+#

-Preparing the Data
-==================
+import torch

-.. note::
-   Download the data from
-   `here <https://download.pytorch.org/tutorial/data.zip>`_
-   and extract it to the current directory.
+# Check if CUDA is available
+device = torch.device('cpu')
+if torch.cuda.is_available():
+    device = torch.device('cuda')

-Included in the ``data/names`` directory are 18 text files named as
-``[Language].txt``. Each file contains a bunch of names, one name per
-line, mostly romanized (but we still need to convert from Unicode to
-ASCII).
+torch.set_default_device(device)
+print(f"Using device = {torch.get_default_device()}")

-We'll end up with a dictionary of lists of names per language,
-``{language: [names ...]}``. The generic variables "category" and "line"
-(for language and name in our case) are used for later extensibility.
-"""
-from io import open
-import glob
-import os
-
-def findFiles(path): return glob.glob(path)
-
-print(findFiles('data/names/*.txt'))
+######################################################################
+# Preparing the Data
+# ==================
+#
+# Download the data from `here <https://download.pytorch.org/tutorial/data.zip>`__
+# and extract it to the current directory.
+#
+# Included in the ``data/names`` directory are 18 text files named as
+# ``[Language].txt``. Each file contains a bunch of names, one name per
+# line, mostly romanized (but we still need to convert from Unicode to
+# ASCII).
+#
+# The first step is to define and clean our data. We need to convert Unicode to plain ASCII to
+# limit the size of the RNN input layer. This is accomplished by converting Unicode strings to ASCII and keeping only a small set of allowed characters.
+import string
 import unicodedata
-import string

-all_letters = string.ascii_letters + " .,;'"
-n_letters = len(all_letters)
+allowed_characters = string.ascii_letters + " .,;'"
+n_letters = len(allowed_characters)

-# Turn a Unicode string to plain ASCII, thanks to https://stackoverflow.com/a/518232/2809427
+# Turn a Unicode string to plain ASCII, thanks to https://stackoverflow.com/a/518232/2809427
 def unicodeToAscii(s):
     return ''.join(
         c for c in unicodedata.normalize('NFD', s)
         if unicodedata.category(c) != 'Mn'
-        and c in all_letters
+        and c in allowed_characters
     )

-print(unicodeToAscii('Ślusàrski'))
-
-# Build the category_lines dictionary, a list of names per language
-category_lines = {}
-all_categories = []
-
-# Read a file and split into lines
-def readLines(filename):
-    lines = open(filename, encoding='utf-8').read().strip().split('\n')
-    return [unicodeToAscii(line) for line in lines]
-
-for filename in findFiles('data/names/*.txt'):
-    category = os.path.splitext(os.path.basename(filename))[0]
-    all_categories.append(category)
-    lines = readLines(filename)
-    category_lines[category] = lines
-
-n_categories = len(all_categories)
-
-
-######################################################################
-# Now we have ``category_lines``, a dictionary mapping each category
-# (language) to a list of lines (names). We also kept track of
-# ``all_categories`` (just a list of languages) and ``n_categories`` for
-# later reference.
+#########################
+# Here's an example of converting a Unicode name to plain ASCII. This simplifies the input layer.
 #
-print(category_lines['Italian'][:5])
-
+print(f"converting 'Ślusàrski' to {unicodeToAscii('Ślusàrski')}")

 ######################################################################
 # Turning Names into Tensors
-# --------------------------
+# ==========================
 #
 # Now that we have all the names organized, we need to turn them into
 # Tensors to make any use of them.
@@ -147,19 +117,10 @@ def readLines(filename):
 #
 # That extra 1 dimension is because PyTorch assumes everything is in
 # batches - we're just using a batch size of 1 here.
-#
-
-import torch

 # Find letter index from all_letters, e.g. "a" = 0
 def letterToIndex(letter):
-    return all_letters.find(letter)
-
-# Just for demonstration, turn a letter into a <1 x n_letters> Tensor
-def letterToTensor(letter):
-    tensor = torch.zeros(1, n_letters)
-    tensor[0][letterToIndex(letter)] = 1
-    return tensor
+    return allowed_characters.find(letter)

 # Turn a line into a <line_length x 1 x n_letters>,
 # or an array of one-hot letter vectors
@@ -169,9 +130,87 @@ def lineToTensor(line):
         tensor[li][0][letterToIndex(letter)] = 1
     return tensor

-print(letterToTensor('J'))
+#########################
+# Here are some examples of how to use ``lineToTensor()`` for a single and a multiple character string.
+
+print(f"The letter 'a' becomes {lineToTensor('a')}") # notice that the first position in the tensor = 1
+print(f"The name 'Ahn' becomes {lineToTensor('Ahn')}") # notice 'A' sets index 26 (the 27th position) to 1
+
+#########################
+# Congratulations, you have built the foundational tensor objects for this learning task! You can use a similar approach
+# for other RNN tasks with text.
+#
+# Next, we need to combine all our examples into a dataset so we can train, test and validate our models. For this,
+# we will use the `Dataset and DataLoader <https://pytorch.org/tutorials/beginner/basics/data_tutorial.html>`__ classes
+# to hold our dataset. Each Dataset needs to implement three functions: ``__init__``, ``__len__``, and ``__getitem__``.
+from io import open
+import glob
+import os
+import time
+
+import torch
+from torch.utils.data import Dataset
+
+class NamesDataset(Dataset):
+
+    def __init__(self, data_dir):
+        self.data_dir = data_dir # for provenance of the dataset
+        self.load_time = time.localtime() # for provenance of the dataset
+        labels_set = set() # set of all classes
+
+        self.data = []
+        self.data_tensors = []
+        self.labels = []
+        self.labels_tensors = []
+
+        # read all the ``.txt`` files in the specified directory
+        text_files = glob.glob(os.path.join(data_dir, '*.txt'))
+        for filename in text_files:
+            label = os.path.splitext(os.path.basename(filename))[0]
+            labels_set.add(label)
+            lines = open(filename, encoding='utf-8').read().strip().split('\n')
+            for name in lines:
+                self.data.append(name)
+                self.data_tensors.append(lineToTensor(name))
+                self.labels.append(label)
+
+        # Cache the tensor representation of the labels
+        self.labels_uniq = list(labels_set)
+        for idx in range(len(self.labels)):
+            temp_tensor = torch.tensor([self.labels_uniq.index(self.labels[idx])], dtype=torch.long)
+            self.labels_tensors.append(temp_tensor)
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, idx):
+        data_item = self.data[idx]
+        data_label = self.labels[idx]
+        data_tensor = self.data_tensors[idx]
+        label_tensor = self.labels_tensors[idx]
+
+        return label_tensor, data_tensor, data_label, data_item

-print(lineToTensor('Jones').size())
+
+#########################
+# Here we can load our example data into the ``NamesDataset`` object.
+
+alldata = NamesDataset("data/names")
+print(f"loaded {len(alldata)} items of data")
+print(f"example = {alldata[0]}")
+
+#########################
+# Using the dataset object allows us to easily split the data into train and test sets. Here we create an 85/15
+# split, but the ``torch.utils.data`` package has more useful utilities. Here we specify a generator since we need to use the
+# same device that PyTorch defaults to above.
+
+train_set, test_set = torch.utils.data.random_split(alldata, [.85, .15], generator=torch.Generator(device=device).manual_seed(2024))
+
+print(f"train examples = {len(train_set)}, validation examples = {len(test_set)}")
+
+#########################
+# Now we have a basic dataset containing **20074** examples where each example is a pairing of label and name. We have also
+# split the dataset into training and testing so we can validate the model that we build.


 ######################################################################
@@ -183,114 +222,58 @@ def lineToTensor(line):
 # held hidden state and gradients which are now entirely handled by the
 # graph itself. This means you can implement a RNN in a very "pure" way,
 # as regular feed-forward layers.
-#
-# This RNN module implements a "vanilla RNN" an is just 3 linear layers
-# which operate on an input and hidden state, with a ``LogSoftmax`` layer
-# after the output.
+#
+# This CharRNN class implements an RNN with three components.
+# First, we use the `nn.RNN implementation <https://pytorch.org/docs/stable/generated/torch.nn.RNN.html>`__.
+# Next, we define a layer that maps the RNN hidden layers to our output. Finally, we apply a ``LogSoftmax`` function.
+# Using ``nn.RNN`` instead of implementing each layer as a ``nn.Linear`` leads to a significant improvement in
+# performance, for example through cuDNN-accelerated kernels. It also simplifies the implementation in ``forward()``.
 #

 import torch.nn as nn
 import torch.nn.functional as F

-class RNN(nn.Module):
+class CharRNN(nn.Module):
     def __init__(self, input_size, hidden_size, output_size):
-        super(RNN, self).__init__()
+        super(CharRNN, self).__init__()

-        self.hidden_size = hidden_size
-
-        self.i2h = nn.Linear(input_size, hidden_size)
-        self.h2h = nn.Linear(hidden_size, hidden_size)
+        self.rnn = nn.RNN(input_size, hidden_size)
         self.h2o = nn.Linear(hidden_size, output_size)
         self.softmax = nn.LogSoftmax(dim=1)
-
-    def forward(self, input, hidden):
-        hidden = F.tanh(self.i2h(input) + self.h2h(hidden))
-        output = self.h2o(hidden)
+
+    def forward(self, line_tensor):
+        rnn_out, hidden = self.rnn(line_tensor)
+        output = self.h2o(hidden[0])
         output = self.softmax(output)
-        return output, hidden

-    def initHidden(self):
-        return torch.zeros(1, self.hidden_size)
+        return output

-n_hidden = 128
-rnn = RNN(n_letters, n_hidden, n_categories)
-
-
-######################################################################
-# To run a step of this network we need to pass an input (in our case, the
-# Tensor for the current letter) and a previous hidden state (which we
-# initialize as zeros at first). We'll get back the output (probability of
-# each language) and a next hidden state (which we keep for the next
-# step).
-#
-input = letterToTensor('A')
-hidden = torch.zeros(1, n_hidden)
-
-output, next_hidden = rnn(input, hidden)
+###########################
+# We can then create an RNN with 57 input nodes, 128 hidden nodes, and 18 outputs:
+n_hidden = 128
+rnn = CharRNN(n_letters, n_hidden, len(alldata.labels_uniq))
+print(rnn)

 ######################################################################
-# For the sake of efficiency we don't want to be creating a new Tensor for
-# every step, so we will use ``lineToTensor`` instead of
-# ``letterToTensor`` and use slices. This could be further optimized by
-# precomputing batches of Tensors.
-#
-
-input = lineToTensor('Albert')
-hidden = torch.zeros(1, n_hidden)
-
-output, next_hidden = rnn(input[0], hidden)
-print(output)
-
+# After that we can pass our Tensor to the RNN to obtain a predicted output. Subsequently,
+# we use a helper function, ``label_from_output``, to derive a text label for the class.

-######################################################################
-# As you can see the output is a ``<1 x n_categories>`` Tensor, where
-# every item is the likelihood of that category (higher is more likely).
-#
+def label_from_output(output, output_labels):
+    top_n, top_i = output.topk(1)
+    label_i = top_i[0].item()
+    return output_labels[label_i], label_i

+input = lineToTensor('Albert')
+output = rnn(input) # this is equivalent to ``output = rnn.forward(input)``
+print(output)
+print(label_from_output(output, alldata.labels_uniq))

 ######################################################################
 #
 # Training
 # ========
-# Preparing for Training
-# ----------------------
-#
-# Before going into training we should make a few helper functions. The
-# first is to interpret the output of the network, which we know to be a
-# likelihood of each category. We can use ``Tensor.topk`` to get the index
-# of the greatest value:
-#
-
-def categoryFromOutput(output):
-    top_n, top_i = output.topk(1)
-    category_i = top_i[0].item()
-    return all_categories[category_i], category_i
-
-print(categoryFromOutput(output))
-
-
-######################################################################
-# We will also want a quick way to get a training example (a name and its
-# language):
-#
-
-import random
-
-def randomChoice(l):
-    return l[random.randint(0, len(l) - 1)]
-
-def randomTrainingExample():
-    category = randomChoice(all_categories)
-    line = randomChoice(category_lines[category])
-    category_tensor = torch.tensor([all_categories.index(category)], dtype=torch.long)
-    line_tensor = lineToTensor(line)
-    return category, line, category_tensor, line_tensor
-
-for i in range(10):
-    category, line, category_tensor, line_tensor = randomTrainingExample()
-    print('category =', category, '/ line =', line)


 ######################################################################
@@ -300,93 +283,67 @@ def randomTrainingExample():
 # Now all it takes to train this network is show it a bunch of examples,
 # have it make guesses, and tell it if it's wrong.
 #
-# For the loss function ``nn.NLLLoss`` is appropriate, since the last
-# layer of the RNN is ``nn.LogSoftmax``.
-#
-
-criterion = nn.NLLLoss()
-
-
-######################################################################
-# Each loop of training will:
-#
-# -  Create input and target tensors
-# -  Create a zeroed initial hidden state
-# -  Read each letter in and
-#
-#    -  Keep hidden state for next letter
-#
-# -  Compare final output to target
-# -  Back-propagate
-# -  Return the output and loss
-#
-
-learning_rate = 0.005 # If you set this too high, it might explode. If too low, it might not learn
-
-def train(category_tensor, line_tensor):
-    hidden = rnn.initHidden()
-
-    rnn.zero_grad()
-
-    for i in range(line_tensor.size()[0]):
-        output, hidden = rnn(line_tensor[i], hidden)
-
-    loss = criterion(output, category_tensor)
-    loss.backward()
-
-    # Add parameters' gradients to their values, multiplied by learning rate
-    for p in rnn.parameters():
-        p.data.add_(p.grad.data, alpha=-learning_rate)
-
-    return output, loss.item()
-
-
-######################################################################
-# Now we just have to run that with a bunch of examples. Since the
-# ``train`` function returns both the output and loss we can print its
-# guesses and also keep track of loss for plotting. Since there are 1000s
-# of examples we print only every ``print_every`` examples, and take an
-# average of the loss.
-#
-
-import time
-import math
-
-n_iters = 100000
-print_every = 5000
-plot_every = 1000
-
-
-
-# Keep track of losses for plotting
-current_loss = 0
-all_losses = []
+# We do this by defining a ``train()`` function which trains the model on a given dataset using minibatches.
+# RNNs are trained similarly to other networks; therefore, for completeness, we include a batched training method here.
+# The loop (``for i in batch``) computes the losses for each of the items in the batch before adjusting the
+# weights. This operation is repeated until the number of epochs is reached.
+ +import random +import numpy as np + +def train(rnn, training_data, n_epoch = 10, n_batch_size = 64, report_every = 50, learning_rate = 0.2, criterion = nn.NLLLoss()): + """ + Learn on a batch of training_data for a specified number of iterations and reporting thresholds + """ + # Keep track of losses for plotting + current_loss = 0 + all_losses = [] + rnn.train() + optimizer = torch.optim.SGD(rnn.parameters(), lr=learning_rate) + + start = time.time() + print(f"training on data set with n = {len(training_data)}") + + for iter in range(1, n_epoch + 1): + rnn.zero_grad() # clear the gradients + + # create some minibatches + # we cannot use dataloaders because each of our names is a different length + batches = list(range(len(training_data))) + random.shuffle(batches) + batches = np.array_split(batches, len(batches) //n_batch_size ) + + for idx, batch in enumerate(batches): + batch_loss = 0 + for i in batch: #for each example in this batch + (label_tensor, text_tensor, label, text) = training_data[i] + output = rnn.forward(text_tensor) + loss = criterion(output, label_tensor) + batch_loss += loss + + # optimize parameters + batch_loss.backward() + nn.utils.clip_grad_norm_(rnn.parameters(), 3) + optimizer.step() + optimizer.zero_grad() + + current_loss += batch_loss.item() / len(batch) + + all_losses.append(current_loss / len(batches) ) + if iter % report_every == 0: + print(f"{iter} ({iter / n_epoch:.0%}): \t average batch loss = {all_losses[-1]}") + current_loss = 0 + + return all_losses -def timeSince(since): - now = time.time() - s = now - since - m = math.floor(s / 60) - s -= m * 60 - return '%dm %ds' % (m, s) +########################################################################## +# We can now train a dataset with minibatches for a specified number of epochs. The number of epochs for this +# example is reduced to speed up the build. You can get better results with different parameters. start = time.time() - -for iter in range(1, n_iters + 1): - category, line, category_tensor, line_tensor = randomTrainingExample() - output, loss = train(category_tensor, line_tensor) - current_loss += loss - - # Print ``iter`` number, loss, name and guess - if iter % print_every == 0: - guess, guess_i = categoryFromOutput(output) - correct = '✓' if guess == category else '✗ (%s)' % category - print('%d %d%% (%s) %.4f %s / %s %s' % (iter, iter / n_iters * 100, timeSince(start), loss, line, guess, correct)) - - # Add current loss avg to list of losses - if iter % plot_every == 0: - all_losses.append(current_loss / plot_every) - current_loss = 0 - +all_losses = train(rnn, train_set, n_epoch=27, learning_rate=0.15, report_every=5) +end = time.time() +print(f"training took {end-start}s") ###################################################################### # Plotting the Results @@ -401,7 +358,7 @@ def timeSince(since): plt.figure() plt.plot(all_losses) - +plt.show() ###################################################################### # Evaluating the Results @@ -414,48 +371,45 @@ def timeSince(since): # ``evaluate()``, which is the same as ``train()`` minus the backprop. 
# -# Keep track of correct guesses in a confusion matrix -confusion = torch.zeros(n_categories, n_categories) -n_confusion = 10000 - -# Just return an output given a line -def evaluate(line_tensor): - hidden = rnn.initHidden() - - for i in range(line_tensor.size()[0]): - output, hidden = rnn(line_tensor[i], hidden) +def evaluate(rnn, testing_data, classes): + confusion = torch.zeros(len(classes), len(classes)) + + rnn.eval() #set to eval mode + with torch.no_grad(): # do not record the gradients during eval phase + for i in range(len(testing_data)): + (label_tensor, text_tensor, label, text) = testing_data[i] + output = rnn(text_tensor) + guess, guess_i = label_from_output(output, classes) + label_i = classes.index(label) + confusion[label_i][guess_i] += 1 - return output + # Normalize by dividing every row by its sum + for i in range(len(classes)): + denom = confusion[i].sum() + if denom > 0: + confusion[i] = confusion[i] / denom -# Go through a bunch of examples and record which are correctly guessed -for i in range(n_confusion): - category, line, category_tensor, line_tensor = randomTrainingExample() - output = evaluate(line_tensor) - guess, guess_i = categoryFromOutput(output) - category_i = all_categories.index(category) - confusion[category_i][guess_i] += 1 + # Set up plot + fig = plt.figure() + ax = fig.add_subplot(111) + cax = ax.matshow(confusion.cpu().numpy()) #numpy uses cpu here so we need to use a cpu version + fig.colorbar(cax) -# Normalize by dividing every row by its sum -for i in range(n_categories): - confusion[i] = confusion[i] / confusion[i].sum() + # Set up axes + ax.set_xticks(np.arange(len(classes)), labels=classes, rotation=90) + ax.set_yticks(np.arange(len(classes)), labels=classes) -# Set up plot -fig = plt.figure() -ax = fig.add_subplot(111) -cax = ax.matshow(confusion.numpy()) -fig.colorbar(cax) + # Force label at every tick + ax.xaxis.set_major_locator(ticker.MultipleLocator(1)) + ax.yaxis.set_major_locator(ticker.MultipleLocator(1)) -# Set up axes -ax.set_xticklabels([''] + all_categories, rotation=90) -ax.set_yticklabels([''] + all_categories) + # sphinx_gallery_thumbnail_number = 2 + plt.show() -# Force label at every tick -ax.xaxis.set_major_locator(ticker.MultipleLocator(1)) -ax.yaxis.set_major_locator(ticker.MultipleLocator(1)) -# sphinx_gallery_thumbnail_number = 2 -plt.show() +evaluate(rnn, test_set, classes=alldata.labels_uniq) + ###################################################################### # You can pick out bright spots off the main axis that show which @@ -465,72 +419,20 @@ def evaluate(line_tensor): # -###################################################################### -# Running on User Input -# --------------------- -# - -def predict(input_line, n_predictions=3): - print('\n> %s' % input_line) - with torch.no_grad(): - output = evaluate(lineToTensor(input_line)) - - # Get top N categories - topv, topi = output.topk(n_predictions, 1, True) - predictions = [] - - for i in range(n_predictions): - value = topv[0][i].item() - category_index = topi[0][i].item() - print('(%.2f) %s' % (value, all_categories[category_index])) - predictions.append([value, all_categories[category_index]]) - -predict('Dovesky') -predict('Jackson') -predict('Satoshi') - - -###################################################################### -# The final versions of the scripts `in the Practical PyTorch -# repo `__ -# split the above code into a few files: -# -# - ``data.py`` (loads files) -# - ``model.py`` (defines the RNN) -# - ``train.py`` (runs training) -# - 
``predict.py`` (runs ``predict()`` with command line arguments)
-# -  ``server.py`` (serve prediction as a JSON API with ``bottle.py``)
-#
-# Run ``train.py`` to train and save the network.
-#
-# Run ``predict.py`` with a name to view predictions:
-#
-# .. code-block:: sh
-#
-#     $ python predict.py Hazaki
-#     (-0.42) Japanese
-#     (-1.39) Polish
-#     (-3.51) Czech
-#
-# Run ``server.py`` and visit http://localhost:5533/Yourname to get JSON
-# output of predictions.
-#
-
-
 ######################################################################
 # Exercises
 # =========
 #
-# - Try with a different dataset of line -> category, for example:
-#
-#    - Any word -> language
-#    - First name -> gender
-#    - Character name -> writer
-#    - Page title -> blog or subreddit
-#
 # - Get better results with a bigger and/or better shaped network
 #
-#    - Add more linear layers
+#   - Adjust the hyperparameters to enhance performance, such as changing the number of epochs, batch size, and learning rate
 #   - Try the ``nn.LSTM`` and ``nn.GRU`` layers
+#   - Modify the size of the layers, such as increasing or decreasing the number of hidden nodes or adding additional linear layers
 #   - Combine multiple of these RNNs as a higher level network
+#
+# - Try with a different dataset of line -> label, for example:
+#
+#    - Any word -> language
+#    - First name -> gender
+#    - Character name -> writer
+#    - Page title -> blog or subreddit
\ No newline at end of file
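+
+######################################################################
+# As a starting point for the ``nn.LSTM`` exercise above, here is a minimal sketch of a drop-in
+# variant of ``CharRNN``. This sketch is illustrative rather than part of the tutorial's own code:
+# it assumes the ``n_letters``, ``n_hidden``, ``alldata``, ``train()``, and ``evaluate()`` objects
+# defined earlier, and the class name ``CharLSTM`` is hypothetical.
+
+class CharLSTM(nn.Module):
+    def __init__(self, input_size, hidden_size, output_size):
+        super(CharLSTM, self).__init__()
+
+        # nn.LSTM returns (output, (h_n, c_n)); we classify from the final hidden state h_n
+        self.lstm = nn.LSTM(input_size, hidden_size)
+        self.h2o = nn.Linear(hidden_size, output_size)
+        self.softmax = nn.LogSoftmax(dim=1)
+
+    def forward(self, line_tensor):
+        lstm_out, (hidden, cell) = self.lstm(line_tensor)
+        output = self.h2o(hidden[0])
+        output = self.softmax(output)
+        return output
+
+# The same training and evaluation helpers defined above can be reused, for example:
+#
+#     lstm_model = CharLSTM(n_letters, n_hidden, len(alldata.labels_uniq))
+#     lstm_losses = train(lstm_model, train_set, n_epoch=27, learning_rate=0.15, report_every=5)
+#     evaluate(lstm_model, test_set, classes=alldata.labels_uniq)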