From 1393b6a682c635faf8e8d90752d5c918e9d56a3b Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Fri, 23 Oct 2020 17:46:44 -0400 Subject: [PATCH 1/5] add new speech tutorial. --- ...ech_command_recognition_with_torchaudio.py | 438 ++++++++++++++++++ 1 file changed, 438 insertions(+) create mode 100644 intermediate_source/speech_command_recognition_with_torchaudio.py diff --git a/intermediate_source/speech_command_recognition_with_torchaudio.py b/intermediate_source/speech_command_recognition_with_torchaudio.py new file mode 100644 index 00000000000..7d691be44ff --- /dev/null +++ b/intermediate_source/speech_command_recognition_with_torchaudio.py @@ -0,0 +1,438 @@ +""" +Speech Command Recognition with torchaudio +========================================== + +This tutorial will show you how to correctly format an audio dataset and +then train/test an audio classifier network on the dataset. First, let’s +import the common torch packages such as +``torchaudio ``\ \_ and can be +installed by following the instructions on the website. + +""" + +# Uncomment to run in Google Colab +# !pip install torch +# !pip install torchaudio + +import os + +import IPython.display as ipd +import matplotlib.pyplot as plt +from tqdm.notebook import tqdm + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import torchaudio +from torchaudio.datasets import SPEECHCOMMANDS + +###################################################################### +# Let’s check if a CUDA GPU is available and select our device. Running +# the network on a GPU will greatly decrease the training/testing runtime. +# + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +print(device) + + +###################################################################### +# Importing the Dataset +# --------------------- +# +# We use torchaudio to download and represent the dataset. Here we use +# SpeechCommands, which is a datasets of 35 commands spoken by different +# people. The dataset ``SPEECHCOMMANDS`` is a ``torch.utils.data.Dataset`` +# version of the dataset. +# +# The actual loading and formatting steps happen in the access function +# ``__getitem__``. In ``__getitem__``, we use ``torchaudio.load()`` to +# convert the audio files to tensors. ``torchaudio.load()`` returns a +# tuple containing the newly created tensor along with the sampling +# frequency of the audio file (16kHz for SpeechCommands). In this dataset, +# all audio files are about 1 second long (and so about 16000 time frames +# long). +# +# Here we wrap it to split it into standard training, validation, testing +# subsets. +# + + +class SubsetSC(SPEECHCOMMANDS): + def __init__(self, subset: str = None): + super().__init__("./", download=True) + + if subset in ["training", "validation"]: + filepath = os.path.join(self._path, "validation_list.txt") + with open(filepath) as f: + validation_list = [ + os.path.join(self._path, l.strip()) for l in f.readlines() + ] + + if subset in ["training", "testing"]: + filepath = os.path.join(self._path, "testing_list.txt") + with open(filepath) as f: + testing_list = [ + os.path.join(self._path, l.strip()) for l in f.readlines() + ] + + if subset == "validation": + walker = validation_list + elif subset == "testing": + walker = testing_list + elif subset in ["training", None]: + walker = self._walker # defined by SPEECHCOMMANDS parent class + else: + raise ValueError( + "When `subset` not None, it must take a value from {'training', 'validation', 'testing'}." 
+ ) + + if subset == "training": + walker = filter( + lambda w: not (w in validation_list or w in testing_list), walker + ) + + self._walker = list(walker) + + +train_set = SubsetSC("training") +# valid_set = SubsetSC("validation") +test_set = SubsetSC("testing") + + +waveform, sample_rate, label, speaker_id, utterance_number = train_set[0] + + +###################################################################### +# A data point in the SPEECHCOMMANDS dataset is a tuple made of a waveform +# (the audio signal), the sample rate, the utterance (label), the ID of +# the speaker, the number of the utterance. +# + +print("Shape of waveform: {}".format(waveform.size())) +print("Sample rate of waveform: {}".format(sample_rate)) + +plt.figure() +plt.plot(waveform.t().numpy()) + + +###################################################################### +# Let’s find the list of labels available in the dataset. +# + +labels = list(set(datapoint[2] for datapoint in train_set)) +labels + + +###################################################################### +# The 35 audio labels are commands that are said by users. The first few +# files are people saying “marvin”. +# + +waveform_first, *_ = train_set[0] +ipd.Audio(waveform_first.numpy(), rate=sample_rate) + +waveform_second, *_ = train_set[1] +ipd.Audio(waveform_second.numpy(), rate=sample_rate) + + +###################################################################### +# The last file is someone saying “visual”. +# + +waveform_last, *_ = train_set[-1] +ipd.Audio(waveform_last.numpy(), rate=sample_rate) + + +###################################################################### +# Formatting the Data +# ------------------- +# +# The dataset uses a single channel for audio. We do not need to down mix +# the audio channels (which we could do for instance by either taking the +# mean along the channel dimension, or simply keeping only one of the +# channels). +# + + +###################################################################### +# We downsample the audio for faster processing without losing too much of +# the classification power. +# + +new_sample_rate = 8000 +transform = torchaudio.transforms.Resample( + orig_freq=sample_rate, new_freq=new_sample_rate +) +transformed = transform(waveform) + +ipd.Audio(transformed.numpy(), rate=new_sample_rate) + + +###################################################################### +# To encode each word, we use a simple language model where we represent +# each fo the 35 words by its corresponding position of the command in the +# list above. +# + + +def encode(word): + return torch.tensor(labels.index(word)) + + +encode("yes") + + +###################################################################### +# We now define a collate function that assembles a list of audio +# recordings and a list of utterances into two batched tensors. In this +# function, we also apply the resampling, and the encoding. The collate +# function is used in the pytroch data loader that allow us to iterate +# over a dataset by batches. 
+# + + +def pad_sequence(batch): + # Make all tensor in a batch the same length by padding with zeros + batch = [item.t() for item in batch] + batch = torch.nn.utils.rnn.pad_sequence(batch, batch_first=True, padding_value=0.0) + return batch.permute(0, 2, 1) + + +def collate_fn(batch): + + # A data tuple has the form: + # waveform, sample_rate, label, speaker_id, utterance_number + # and so we are only interested in item 0 and 2 + + # Apply transforms to waveforms + tensors = [transform(b[0]) for b in batch] + tensors = pad_sequence(tensors) + + # Apply transform to target utterance + targets = [encode(b[2]) for b in batch] + targets = torch.stack(targets) + + return tensors, targets + + +kwargs = ( + {"num_workers": 1, "pin_memory": True} if device == "cuda" else {} +) # needed to run on gpu + +train_loader = torch.utils.data.DataLoader( + train_set, batch_size=128, shuffle=True, collate_fn=collate_fn, **kwargs +) +test_loader = torch.utils.data.DataLoader( + test_set, batch_size=128, shuffle=False, collate_fn=collate_fn, **kwargs +) + + +###################################################################### +# Define the Network +# ------------------ +# +# For this tutorial we will use a convolutional neural network to process +# the raw audio data. Usually more advanced transforms are applied to the +# audio data, however CNNs can be used to accurately process the raw data. +# The specific architecture is modeled after the M5 network architecture +# described in https://arxiv.org/pdf/1610.00087.pdf. An important aspect +# of models processing raw audio data is the receptive field of their +# first layer’s filters. Our model’s first filter is length 80 so when +# processing audio sampled at 8kHz the receptive field is around 10ms. +# This size is similar to speech processing applications that often use +# receptive fields ranging from 20ms to 40ms. +# + + +class Net(nn.Module): + def __init__(self, n_output=10): + super(Net, self).__init__() + self.conv1 = nn.Conv1d(1, 128, 80, 4) + self.bn1 = nn.BatchNorm1d(128) + self.pool1 = nn.MaxPool1d(4) + self.conv2 = nn.Conv1d(128, 128, 3) + self.bn2 = nn.BatchNorm1d(128) + self.pool2 = nn.MaxPool1d(4) + self.conv3 = nn.Conv1d(128, 256, 3) + self.bn3 = nn.BatchNorm1d(256) + self.pool3 = nn.MaxPool1d(4) + self.conv4 = nn.Conv1d(256, 512, 3) + self.bn4 = nn.BatchNorm1d(512) + self.pool4 = nn.MaxPool1d(4) + self.fc1 = nn.Linear(512, n_output) + + def forward(self, x): + x = self.conv1(x) + x = F.relu(self.bn1(x)) + x = self.pool1(x) + x = self.conv2(x) + x = F.relu(self.bn2(x)) + x = self.pool2(x) + x = self.conv3(x) + x = F.relu(self.bn3(x)) + x = self.pool3(x) + x = self.conv4(x) + x = F.relu(self.bn4(x)) + x = self.pool4(x) + x = F.avg_pool1d( + x, x.shape[-1] + ) # input should be 512x14 so this outputs a 512x1 + x = x.permute(0, 2, 1) # change the 512x1 to 1x512 + x = self.fc1(x) + return F.log_softmax(x, dim=2) + + +model = Net(n_output=len(labels)) +model.to(device) +print(model) + + +def count_parameters(model): + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + +n = count_parameters(model) +print("Number of parameters: %s" % n) + + +###################################################################### +# We will use the same optimization technique used in the paper, an Adam +# optimizer with weight decay set to 0.0001. At first, we will train with +# a learning rate of 0.01, but we will use a ``scheduler`` to decrease it +# to 0.001 during training. 
+# + +optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=0.0001) +scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1) + + +###################################################################### +# Training and Testing the Network +# -------------------------------- +# +# Now let’s define a training function that will feed our training data +# into the model and perform the backward pass and optimization steps. +# +# Finally, we can train and test the network. We will train the network +# for ten epochs then reduce the learn rate and train for ten more epochs. +# The network will be tested after each epoch to see how the accuracy +# varies during the training. +# + + +def nll_loss(tensor, target): + # negative log-likelihood for a tensor of size (batch x 1 x n_output) + return F.nll_loss(tensor.squeeze(), target) + + +def train(model, epoch, log_interval): + model.train() + for batch_idx, (data, target) in enumerate(train_loader): + + data = data.to(device) + target = target.to(device) + + output = model(data) + loss = nll_loss(output, target) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + # print training stats + if batch_idx % log_interval == 0: + print( + f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss:.6f}" + ) + + if "pbar" in globals(): + pbar.update() + + +###################################################################### +# Now that we have a training function, we need to make one for testing +# the networks accuracy. We will set the model to ``eval()`` mode and then +# run inference on the test dataset. Calling ``eval()`` sets the training +# variable in all modules in the network to false. Certain layers like +# batch normalization and dropout layers behave differently during +# training so this step is crucial for getting correct results. +# + + +def argmax(tensor): + # index of the max log-probability + return tensor.max(-1)[1] + + +def number_of_correct(pred, target): + # compute number of correct predictions + return pred.squeeze().eq(target).cpu().sum().item() + + +def test(model, epoch): + model.eval() + correct = 0 + for data, target in test_loader: + data = data.to(device) + target = target.to(device) + + output = model(data) + pred = argmax(output) + correct += number_of_correct(pred, target) + + if "pbar" in globals(): + pbar.update() + + print( + f"\nTest set: Accuracy: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset):.0f}%)\n" + ) + + +###################################################################### +# Finally, we can train and test the network. We will train the network +# for ten epochs then reduce the learn rate and train for ten more epochs. +# The network will be tested after each epoch to see how the accuracy +# varies during the training. +# + +log_interval = 20 +n_epoch = 2 + +with tqdm(total=n_epoch * (len(train_loader) + len(test_loader))) as pbar: + for epoch in range(1, n_epoch + 1): + train(model, epoch, log_interval) + test(model, epoch) + scheduler.step() + + +###################################################################### +# Let’s try looking at one of the last words in the train and test set. +# + +waveform, sample_rate, utterance, *_ = train_set[-1] +ipd.Audio(waveform.numpy(), rate=sample_rate) +output = model(waveform.unsqueeze(0)) +output = argmax(output).squeeze() +print(f"Expected: {utterance}. 
Predicted: {labels[output]}.") + +waveform, sample_rate, utterance, *_ = test_set[-1] +ipd.Audio(waveform.numpy(), rate=sample_rate) +output = model(waveform.unsqueeze(0)) +output = argmax(output).squeeze() +print(f"Expected: {utterance}. Predicted: {labels[output]}.") + + +###################################################################### +# Conclusion +# ---------- +# +# After one epoch, the network should be more than 65% accurate. +# +# In this tutorial, we used torchaudio to load a dataset and resample the +# signal. We have then defined a neural network that we trained to +# recognize a given command. There are also other data preprocessing +# methods, such as finding the mel frequency cepstral coefficients (MFCC), +# that can reduce the size of the dataset. This transform is also +# available in torchaudio as ``torchaudio.transforms.MFCC``. +# From f42ff8fd3cd979a47b10d69fe857a4b4044434f5 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Fri, 23 Oct 2020 23:40:51 -0400 Subject: [PATCH 2/5] update with a few parameter tuned. model takes less than 10 min to run now. --- ...ech_command_recognition_with_torchaudio.py | 80 ++++++++++--------- 1 file changed, 41 insertions(+), 39 deletions(-) diff --git a/intermediate_source/speech_command_recognition_with_torchaudio.py b/intermediate_source/speech_command_recognition_with_torchaudio.py index 7d691be44ff..d0397de2e90 100644 --- a/intermediate_source/speech_command_recognition_with_torchaudio.py +++ b/intermediate_source/speech_command_recognition_with_torchaudio.py @@ -10,7 +10,6 @@ """ -# Uncomment to run in Google Colab # !pip install torch # !pip install torchaudio @@ -66,14 +65,14 @@ def __init__(self, subset: str = None): filepath = os.path.join(self._path, "validation_list.txt") with open(filepath) as f: validation_list = [ - os.path.join(self._path, l.strip()) for l in f.readlines() + os.path.join(self._path, line.strip()) for line in f.readlines() ] if subset in ["training", "testing"]: filepath = os.path.join(self._path, "testing_list.txt") with open(filepath) as f: testing_list = [ - os.path.join(self._path, l.strip()) for l in f.readlines() + os.path.join(self._path, line.strip()) for line in f.readlines() ] if subset == "validation": @@ -216,15 +215,16 @@ def collate_fn(batch): return tensors, targets +batch_size = 128 + kwargs = ( {"num_workers": 1, "pin_memory": True} if device == "cuda" else {} -) # needed to run on gpu - +) # needed for using datasets on gpu train_loader = torch.utils.data.DataLoader( - train_set, batch_size=128, shuffle=True, collate_fn=collate_fn, **kwargs + train_set, batch_size=batch_size, shuffle=True, collate_fn=collate_fn, **kwargs ) test_loader = torch.utils.data.DataLoader( - test_set, batch_size=128, shuffle=False, collate_fn=collate_fn, **kwargs + test_set, batch_size=batch_size, shuffle=False, collate_fn=collate_fn, **kwargs ) @@ -236,31 +236,32 @@ def collate_fn(batch): # the raw audio data. Usually more advanced transforms are applied to the # audio data, however CNNs can be used to accurately process the raw data. # The specific architecture is modeled after the M5 network architecture -# described in https://arxiv.org/pdf/1610.00087.pdf. An important aspect -# of models processing raw audio data is the receptive field of their -# first layer’s filters. Our model’s first filter is length 80 so when -# processing audio sampled at 8kHz the receptive field is around 10ms. 
-# This size is similar to speech processing applications that often use -# receptive fields ranging from 20ms to 40ms. +# described in ``this paper ``\ \_. +# An important aspect of models processing raw audio data is the receptive +# field of their first layer’s filters. Our model’s first filter is length +# 80 so when processing audio sampled at 8kHz the receptive field is +# around 10ms (and at 4kHz, around 20 ms). This size is similar to speech +# processing applications that often use receptive fields ranging from +# 20ms to 40ms. # -class Net(nn.Module): - def __init__(self, n_output=10): - super(Net, self).__init__() - self.conv1 = nn.Conv1d(1, 128, 80, 4) - self.bn1 = nn.BatchNorm1d(128) +class M5(nn.Module): + def __init__(self, stride=16, n_channel=32, n_output=35): + super().__init__() + self.conv1 = nn.Conv1d(1, n_channel, 80, stride=stride) + self.bn1 = nn.BatchNorm1d(n_channel) self.pool1 = nn.MaxPool1d(4) - self.conv2 = nn.Conv1d(128, 128, 3) - self.bn2 = nn.BatchNorm1d(128) + self.conv2 = nn.Conv1d(n_channel, n_channel, 3) + self.bn2 = nn.BatchNorm1d(n_channel) self.pool2 = nn.MaxPool1d(4) - self.conv3 = nn.Conv1d(128, 256, 3) - self.bn3 = nn.BatchNorm1d(256) + self.conv3 = nn.Conv1d(n_channel, 2 * n_channel, 3) + self.bn3 = nn.BatchNorm1d(2 * n_channel) self.pool3 = nn.MaxPool1d(4) - self.conv4 = nn.Conv1d(256, 512, 3) - self.bn4 = nn.BatchNorm1d(512) + self.conv4 = nn.Conv1d(2 * n_channel, 2 * n_channel, 3) + self.bn4 = nn.BatchNorm1d(2 * n_channel) self.pool4 = nn.MaxPool1d(4) - self.fc1 = nn.Linear(512, n_output) + self.fc1 = nn.Linear(2 * n_channel, n_output) def forward(self, x): x = self.conv1(x) @@ -275,15 +276,13 @@ def forward(self, x): x = self.conv4(x) x = F.relu(self.bn4(x)) x = self.pool4(x) - x = F.avg_pool1d( - x, x.shape[-1] - ) # input should be 512x14 so this outputs a 512x1 - x = x.permute(0, 2, 1) # change the 512x1 to 1x512 + x = F.avg_pool1d(x, x.shape[-1]) + x = x.permute(0, 2, 1) x = self.fc1(x) return F.log_softmax(x, dim=2) -model = Net(n_output=len(labels)) +model = M5(n_output=len(labels)) model.to(device) print(model) @@ -304,7 +303,9 @@ def count_parameters(model): # optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=0.0001) -scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1) +scheduler = optim.lr_scheduler.StepLR( + optimizer, step_size=20, gamma=0.1 +) # reduce the learning after 20 epochs by a factor of 10 ###################################################################### @@ -321,11 +322,6 @@ def count_parameters(model): # -def nll_loss(tensor, target): - # negative log-likelihood for a tensor of size (batch x 1 x n_output) - return F.nll_loss(tensor.squeeze(), target) - - def train(model, epoch, log_interval): model.train() for batch_idx, (data, target) in enumerate(train_loader): @@ -334,7 +330,9 @@ def train(model, epoch, log_interval): target = target.to(device) output = model(data) - loss = nll_loss(output, target) + + # negative log-likelihood for a tensor of size (batch x 1 x n_output) + loss = F.nll_loss(output.squeeze(), target) optimizer.zero_grad() loss.backward() @@ -385,7 +383,7 @@ def test(model, epoch): pbar.update() print( - f"\nTest set: Accuracy: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset):.0f}%)\n" + f"\nTest Epoch: {epoch}\tAccuracy: {correct}/{len(test_loader.dataset)} ({100. 
* correct / len(test_loader.dataset):.0f}%)\n" ) @@ -412,12 +410,16 @@ def test(model, epoch): waveform, sample_rate, utterance, *_ = train_set[-1] ipd.Audio(waveform.numpy(), rate=sample_rate) + +waveform = transform(waveform) output = model(waveform.unsqueeze(0)) output = argmax(output).squeeze() print(f"Expected: {utterance}. Predicted: {labels[output]}.") waveform, sample_rate, utterance, *_ = test_set[-1] ipd.Audio(waveform.numpy(), rate=sample_rate) + +waveform = transform(waveform) output = model(waveform.unsqueeze(0)) output = argmax(output).squeeze() print(f"Expected: {utterance}. Predicted: {labels[output]}.") @@ -427,7 +429,7 @@ def test(model, epoch): # Conclusion # ---------- # -# After one epoch, the network should be more than 65% accurate. +# After two epochs, the network should be more than 70% accurate. # # In this tutorial, we used torchaudio to load a dataset and resample the # signal. We have then defined a neural network that we trained to From ac4c2f41d70319810a26bf2d06f4e02d9987f18c Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Mon, 26 Oct 2020 19:09:57 -0400 Subject: [PATCH 3/5] feedback. --- ...ech_command_recognition_with_torchaudio.py | 189 +++++++++--------- 1 file changed, 95 insertions(+), 94 deletions(-) diff --git a/intermediate_source/speech_command_recognition_with_torchaudio.py b/intermediate_source/speech_command_recognition_with_torchaudio.py index d0397de2e90..73a012fd468 100644 --- a/intermediate_source/speech_command_recognition_with_torchaudio.py +++ b/intermediate_source/speech_command_recognition_with_torchaudio.py @@ -10,6 +10,7 @@ """ +# Uncomment the following line to run in Google Colab # !pip install torch # !pip install torchaudio @@ -61,44 +62,24 @@ class SubsetSC(SPEECHCOMMANDS): def __init__(self, subset: str = None): super().__init__("./", download=True) - if subset in ["training", "validation"]: - filepath = os.path.join(self._path, "validation_list.txt") - with open(filepath) as f: - validation_list = [ - os.path.join(self._path, line.strip()) for line in f.readlines() - ] - - if subset in ["training", "testing"]: - filepath = os.path.join(self._path, "testing_list.txt") - with open(filepath) as f: - testing_list = [ - os.path.join(self._path, line.strip()) for line in f.readlines() - ] + def load_list(filename): + filepath = os.path.join(self._path, filename) + with open(filepath) as fileobj: + return [os.path.join(self._path, line.strip()) for line in fileobj] if subset == "validation": - walker = validation_list + self._walker = load_list("validation_list.txt") elif subset == "testing": - walker = testing_list - elif subset in ["training", None]: - walker = self._walker # defined by SPEECHCOMMANDS parent class - else: - raise ValueError( - "When `subset` not None, it must take a value from {'training', 'validation', 'testing'}." 
- ) - - if subset == "training": - walker = filter( - lambda w: not (w in validation_list or w in testing_list), walker - ) - - self._walker = list(walker) + self._walker = load_list("testing_list.txt") + elif subset == "training": + excludes = load_list("validation_list.txt") + load_list("testing_list.txt") + self._walker = [w for w in self._walker if w not in excludes] train_set = SubsetSC("training") # valid_set = SubsetSC("validation") test_set = SubsetSC("testing") - waveform, sample_rate, label, speaker_id, utterance_number = train_set[0] @@ -111,8 +92,8 @@ def __init__(self, subset: str = None): print("Shape of waveform: {}".format(waveform.size())) print("Sample rate of waveform: {}".format(sample_rate)) -plt.figure() -plt.plot(waveform.t().numpy()) +plt.figure(); +plt.plot(waveform.t().numpy()); ###################################################################### @@ -147,31 +128,26 @@ def __init__(self, subset: str = None): # Formatting the Data # ------------------- # -# The dataset uses a single channel for audio. We do not need to down mix -# the audio channels (which we could do for instance by either taking the -# mean along the channel dimension, or simply keeping only one of the -# channels). +# This is a good place to apply transformations to the data. For the +# waveform, we downsample the audio for faster processing without losing +# too much of the classification power. # - - -###################################################################### -# We downsample the audio for faster processing without losing too much of -# the classification power. +# We don’t need to apply other transformations here. It is common for some +# datasets though to have to reduce the number of channels (say from +# stereo to mono) by either taking the mean along the channel dimension, +# or simply keeping only one of the channels. Since SpeechCommands uses a +# single channel for audio, this is not needed here. # new_sample_rate = 8000 -transform = torchaudio.transforms.Resample( - orig_freq=sample_rate, new_freq=new_sample_rate -) +transform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=new_sample_rate) transformed = transform(waveform) ipd.Audio(transformed.numpy(), rate=new_sample_rate) ###################################################################### -# To encode each word, we use a simple language model where we represent -# each fo the 35 words by its corresponding position of the command in the -# list above. +# We are encoding each word using its index in the list of labels. # @@ -183,18 +159,21 @@ def encode(word): ###################################################################### -# We now define a collate function that assembles a list of audio -# recordings and a list of utterances into two batched tensors. In this -# function, we also apply the resampling, and the encoding. The collate -# function is used in the pytroch data loader that allow us to iterate -# over a dataset by batches. +# To turn a list of data point made of audio recordings and utterances +# into two batched tensors for the model, we implement a collate function +# which is used by the PyTorch DataLoader that allows us to iterate over a +# dataset by batches. Please see `the +# documentation `__ +# for more information about working with a collate function. +# +# In the collate function, we also apply the resampling, and the text +# encoding. 
# - def pad_sequence(batch): # Make all tensor in a batch the same length by padding with zeros batch = [item.t() for item in batch] - batch = torch.nn.utils.rnn.pad_sequence(batch, batch_first=True, padding_value=0.0) + batch = torch.nn.utils.rnn.pad_sequence(batch, batch_first=True, padding_value=0.) return batch.permute(0, 2, 1) @@ -202,14 +181,16 @@ def collate_fn(batch): # A data tuple has the form: # waveform, sample_rate, label, speaker_id, utterance_number - # and so we are only interested in item 0 and 2 - # Apply transforms to waveforms - tensors = [transform(b[0]) for b in batch] - tensors = pad_sequence(tensors) + tensors, targets = [], [] + + # Apply transform and encode + for waveform, _, label, *_ in batch: + tensors += [transform(waveform)] + targets += [encode(label)] - # Apply transform to target utterance - targets = [encode(b[2]) for b in batch] + # Group the list of tensors into a batched tensor + tensors = pad_sequence(tensors) targets = torch.stack(targets) return tensors, targets @@ -217,14 +198,18 @@ def collate_fn(batch): batch_size = 128 -kwargs = ( - {"num_workers": 1, "pin_memory": True} if device == "cuda" else {} -) # needed for using datasets on gpu +if device == 'cuda': + num_workers = 1 + pin_memory = True +else: + num_workers = 0 + pin_memory = False + train_loader = torch.utils.data.DataLoader( - train_set, batch_size=batch_size, shuffle=True, collate_fn=collate_fn, **kwargs + train_set, batch_size=batch_size, shuffle=True, collate_fn=collate_fn, num_workers=num_workers, pin_memory=pin_memory, ) test_loader = torch.utils.data.DataLoader( - test_set, batch_size=batch_size, shuffle=False, collate_fn=collate_fn, **kwargs + test_set, batch_size=batch_size, shuffle=False, collate_fn=collate_fn, num_workers=num_workers, pin_memory=pin_memory, ) @@ -255,13 +240,13 @@ def __init__(self, stride=16, n_channel=32, n_output=35): self.conv2 = nn.Conv1d(n_channel, n_channel, 3) self.bn2 = nn.BatchNorm1d(n_channel) self.pool2 = nn.MaxPool1d(4) - self.conv3 = nn.Conv1d(n_channel, 2 * n_channel, 3) - self.bn3 = nn.BatchNorm1d(2 * n_channel) + self.conv3 = nn.Conv1d(n_channel, 2*n_channel, 3) + self.bn3 = nn.BatchNorm1d(2*n_channel) self.pool3 = nn.MaxPool1d(4) - self.conv4 = nn.Conv1d(2 * n_channel, 2 * n_channel, 3) - self.bn4 = nn.BatchNorm1d(2 * n_channel) + self.conv4 = nn.Conv1d(2*n_channel, 2*n_channel, 3) + self.bn4 = nn.BatchNorm1d(2*n_channel) self.pool4 = nn.MaxPool1d(4) - self.fc1 = nn.Linear(2 * n_channel, n_output) + self.fc1 = nn.Linear(2*n_channel, n_output) def forward(self, x): x = self.conv1(x) @@ -303,9 +288,7 @@ def count_parameters(model): # optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=0.0001) -scheduler = optim.lr_scheduler.StepLR( - optimizer, step_size=20, gamma=0.1 -) # reduce the learning after 20 epochs by a factor of 10 +scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1) # reduce the learning after 20 epochs by a factor of 10 ###################################################################### @@ -340,11 +323,9 @@ def train(model, epoch, log_interval): # print training stats if batch_idx % log_interval == 0: - print( - f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss:.6f}" - ) + print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. 
* batch_idx / len(train_loader):.0f}%)]\tLoss: {loss:.6f}') - if "pbar" in globals(): + if 'pbar' in globals(): pbar.update() @@ -379,12 +360,10 @@ def test(model, epoch): pred = argmax(output) correct += number_of_correct(pred, target) - if "pbar" in globals(): - pbar.update() + if 'pbar' in globals(): + pbar.update() - print( - f"\nTest Epoch: {epoch}\tAccuracy: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset):.0f}%)\n" - ) + print(f'\nTest Epoch: {epoch}\tAccuracy: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset):.0f}%)\n') ###################################################################### @@ -398,38 +377,60 @@ def test(model, epoch): n_epoch = 2 with tqdm(total=n_epoch * (len(train_loader) + len(test_loader))) as pbar: - for epoch in range(1, n_epoch + 1): + for epoch in range(1, n_epoch+1): train(model, epoch, log_interval) test(model, epoch) scheduler.step() ###################################################################### -# Let’s try looking at one of the last words in the train and test set. +# Let’s look at the last words in the train set, and see how the model did +# on it. # +def predict(waveform): + # Take a waveform and use the model to predict + waveform = transform(waveform) + output = model(waveform.unsqueeze(0)) + output = argmax(output).squeeze() + output = labels[output] + return output + + waveform, sample_rate, utterance, *_ = train_set[-1] ipd.Audio(waveform.numpy(), rate=sample_rate) -waveform = transform(waveform) -output = model(waveform.unsqueeze(0)) -output = argmax(output).squeeze() -print(f"Expected: {utterance}. Predicted: {labels[output]}.") +print(f"Expected: {utterance}. Predicted: {predict(waveform)}.") -waveform, sample_rate, utterance, *_ = test_set[-1] -ipd.Audio(waveform.numpy(), rate=sample_rate) -waveform = transform(waveform) -output = model(waveform.unsqueeze(0)) -output = argmax(output).squeeze() -print(f"Expected: {utterance}. Predicted: {labels[output]}.") +###################################################################### +# Let’s find an example that isn’t classified correctly, if there is one. +# + +for i, (waveform, sample_rate, utterance, *_) in enumerate(test_set): + output = predict(waveform) + if output != utterance: + ipd.Audio(waveform.numpy(), rate=sample_rate) + print(f"Data point #{i}. Expected: {utterance}. Predicted: {output}.") + break +else: + print("All examples in this dataset were correctly classified!") + print("In this case, let's just look at the last data point") + ipd.Audio(waveform.numpy(), rate=sample_rate) + print(f"Data point #{i}. Expected: {utterance}. Predicted: {output}.") + + +###################################################################### +# Feel free to try with one of your own recordings! +# ###################################################################### # Conclusion # ---------- # -# After two epochs, the network should be more than 70% accurate. +# The network should be more than 70% accurate on the test set after 2 +# epochs, 80% after 14 epochs, and 85% after 21 epochs. # # In this tutorial, we used torchaudio to load a dataset and resample the # signal. We have then defined a neural network that we trained to From e4c641932286aa8ccd445cd9c4514598119aa764 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Wed, 4 Nov 2020 11:23:57 -0500 Subject: [PATCH 4/5] improve GPU performance. add interactive demo at the end. 
--- ...ech_command_recognition_with_torchaudio.py | 178 +++++++++++++----- 1 file changed, 130 insertions(+), 48 deletions(-) diff --git a/intermediate_source/speech_command_recognition_with_torchaudio.py b/intermediate_source/speech_command_recognition_with_torchaudio.py index 73a012fd468..854aa5cb028 100644 --- a/intermediate_source/speech_command_recognition_with_torchaudio.py +++ b/intermediate_source/speech_command_recognition_with_torchaudio.py @@ -3,18 +3,33 @@ ========================================== This tutorial will show you how to correctly format an audio dataset and -then train/test an audio classifier network on the dataset. First, let’s -import the common torch packages such as -``torchaudio ``\ \_ and can be +then train/test an audio classifier network on the dataset. + +Colab has GPU option available. In the menu tabs, select “Runtime” then +“Change runtime type”. In the pop-up that follows, you can choose GPU. +After the change, your runtime should automatically restart (which means +information from executed cells disappear). + +First, let’s import the common torch packages such as +``torchaudio ``\ \_ that can be installed by following the instructions on the website. """ # Uncomment the following line to run in Google Colab -# !pip install torch -# !pip install torchaudio + +# GPU: +# !pip install torch==1.7.0+cu101 torchvision==0.8.1+cu101 torchaudio==0.7.0 -f https://download.pytorch.org/whl/torch_stable.html + +# CPU: +# !pip install torch==1.7.0+cpu torchvision==0.8.1+cpu torchaudio==0.7.0 -f https://download.pytorch.org/whl/torch_stable.html + +# For interactive demo at the end: +# !pip install pydub import os +from base64 import b64decode +from io import BytesIO import IPython.display as ipd import matplotlib.pyplot as plt @@ -25,6 +40,8 @@ import torch.nn.functional as F import torch.optim as optim import torchaudio +from google.colab import output as colab_output +from pydub import AudioSegment from torchaudio.datasets import SPEECHCOMMANDS ###################################################################### @@ -73,11 +90,12 @@ def load_list(filename): self._walker = load_list("testing_list.txt") elif subset == "training": excludes = load_list("validation_list.txt") + load_list("testing_list.txt") + excludes = set(excludes) self._walker = [w for w in self._walker if w not in excludes] +# Create training and testing split of the data. We do not use validation in this tutorial. train_set = SubsetSC("training") -# valid_set = SubsetSC("validation") test_set = SubsetSC("testing") waveform, sample_rate, label, speaker_id, utterance_number = train_set[0] @@ -92,7 +110,6 @@ def load_list(filename): print("Shape of waveform: {}".format(waveform.size())) print("Sample rate of waveform: {}".format(sample_rate)) -plt.figure(); plt.plot(waveform.t().numpy()); @@ -100,7 +117,7 @@ def load_list(filename): # Let’s find the list of labels available in the dataset. # -labels = list(set(datapoint[2] for datapoint in train_set)) +labels = sorted(list(set(datapoint[2] for datapoint in train_set))) labels @@ -170,6 +187,7 @@ def encode(word): # encoding. 
# + def pad_sequence(batch): # Make all tensor in a batch the same length by padding with zeros batch = [item.t() for item in batch] @@ -184,9 +202,9 @@ def collate_fn(batch): tensors, targets = [], [] - # Apply transform and encode + # Gather in lists, and encode labels for waveform, _, label, *_ in batch: - tensors += [transform(waveform)] + tensors += [waveform] targets += [encode(label)] # Group the list of tensors into a batched tensor @@ -196,9 +214,9 @@ def collate_fn(batch): return tensors, targets -batch_size = 128 +batch_size = 256 -if device == 'cuda': +if device == "cuda": num_workers = 1 pin_memory = True else: @@ -206,10 +224,21 @@ def collate_fn(batch): pin_memory = False train_loader = torch.utils.data.DataLoader( - train_set, batch_size=batch_size, shuffle=True, collate_fn=collate_fn, num_workers=num_workers, pin_memory=pin_memory, + train_set, + batch_size=batch_size, + shuffle=True, + collate_fn=collate_fn, + num_workers=num_workers, + pin_memory=pin_memory, ) test_loader = torch.utils.data.DataLoader( - test_set, batch_size=batch_size, shuffle=False, collate_fn=collate_fn, num_workers=num_workers, pin_memory=pin_memory, + test_set, + batch_size=batch_size, + shuffle=False, + drop_last=False, + collate_fn=collate_fn, + num_workers=num_workers, + pin_memory=pin_memory, ) @@ -232,21 +261,21 @@ def collate_fn(batch): class M5(nn.Module): - def __init__(self, stride=16, n_channel=32, n_output=35): + def __init__(self, n_input=1, n_output=35, stride=16, n_channel=32): super().__init__() - self.conv1 = nn.Conv1d(1, n_channel, 80, stride=stride) + self.conv1 = nn.Conv1d(n_input, n_channel, kernel_size=80, stride=stride) self.bn1 = nn.BatchNorm1d(n_channel) self.pool1 = nn.MaxPool1d(4) - self.conv2 = nn.Conv1d(n_channel, n_channel, 3) + self.conv2 = nn.Conv1d(n_channel, n_channel, kernel_size=3) self.bn2 = nn.BatchNorm1d(n_channel) self.pool2 = nn.MaxPool1d(4) - self.conv3 = nn.Conv1d(n_channel, 2*n_channel, 3) - self.bn3 = nn.BatchNorm1d(2*n_channel) + self.conv3 = nn.Conv1d(n_channel, 2 * n_channel, kernel_size=3) + self.bn3 = nn.BatchNorm1d(2 * n_channel) self.pool3 = nn.MaxPool1d(4) - self.conv4 = nn.Conv1d(2*n_channel, 2*n_channel, 3) - self.bn4 = nn.BatchNorm1d(2*n_channel) + self.conv4 = nn.Conv1d(2 * n_channel, 2 * n_channel, kernel_size=3) + self.bn4 = nn.BatchNorm1d(2 * n_channel) self.pool4 = nn.MaxPool1d(4) - self.fc1 = nn.Linear(2*n_channel, n_output) + self.fc1 = nn.Linear(2 * n_channel, n_output) def forward(self, x): x = self.conv1(x) @@ -267,7 +296,7 @@ def forward(self, x): return F.log_softmax(x, dim=2) -model = M5(n_output=len(labels)) +model = M5(n_input=transformed.shape[0], n_output=len(labels)) model.to(device) print(model) @@ -284,7 +313,7 @@ def count_parameters(model): # We will use the same optimization technique used in the paper, an Adam # optimizer with weight decay set to 0.0001. At first, we will train with # a learning rate of 0.01, but we will use a ``scheduler`` to decrease it -# to 0.001 during training. +# to 0.001 during training after 20 epochs. # optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=0.0001) @@ -296,11 +325,9 @@ def count_parameters(model): # -------------------------------- # # Now let’s define a training function that will feed our training data -# into the model and perform the backward pass and optimization steps. -# -# Finally, we can train and test the network. We will train the network -# for ten epochs then reduce the learn rate and train for ten more epochs. 
-# The network will be tested after each epoch to see how the accuracy +# into the model and perform the backward pass and optimization steps. For +# training, the loss we will use is the negative log-likelihood. The +# network will then be tested after each epoch to see how the accuracy # varies during the training. # @@ -312,6 +339,8 @@ def train(model, epoch, log_interval): data = data.to(device) target = target.to(device) + # apply transform and model on whole batch directly on device + data = transform(data) output = model(data) # negative log-likelihood for a tensor of size (batch x 1 x n_output) @@ -323,10 +352,10 @@ def train(model, epoch, log_interval): # print training stats if batch_idx % log_interval == 0: - print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss:.6f}') + print(f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss:.6f}") - if 'pbar' in globals(): - pbar.update() + if "pbar" in globals() and "pbar_update" in globals(): + pbar.update(pbar_update) ###################################################################### @@ -346,24 +375,28 @@ def argmax(tensor): def number_of_correct(pred, target): # compute number of correct predictions - return pred.squeeze().eq(target).cpu().sum().item() + return pred.squeeze().eq(target).sum().item() def test(model, epoch): model.eval() correct = 0 for data, target in test_loader: + data = data.to(device) target = target.to(device) + # apply transform and model on whole batch directly on device + data = transform(data) output = model(data) + pred = argmax(output) correct += number_of_correct(pred, target) - if 'pbar' in globals(): - pbar.update() + if "pbar" in globals() and "pbar_update" in globals(): + pbar.update(pbar_update) - print(f'\nTest Epoch: {epoch}\tAccuracy: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset):.0f}%)\n') + print(f"\nTest Epoch: {epoch}\tAccuracy: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset):.0f}%)\n") ###################################################################### @@ -375,21 +408,28 @@ def test(model, epoch): log_interval = 20 n_epoch = 2 +pbar_update = 1 / (len(train_loader) + len(test_loader)) + +# The transform needs to live on the same device as the model and the data. +transform = transform.to(device) -with tqdm(total=n_epoch * (len(train_loader) + len(test_loader))) as pbar: - for epoch in range(1, n_epoch+1): +with tqdm(total=n_epoch) as pbar: + for epoch in range(1, n_epoch + 1): train(model, epoch, log_interval) test(model, epoch) scheduler.step() ###################################################################### -# Let’s look at the last words in the train set, and see how the model did -# on it. +# The network should be more than 65% accurate on the test set after 2 +# epochs, and 85% after 21 epochs. Let’s look at the last words in the +# train set, and see how the model did on it. 
# + def predict(waveform): - # Take a waveform and use the model to predict + # Use the model to predict the label of the waveform + waveform = waveform.to(device) waveform = transform(waveform) output = model(waveform.unsqueeze(0)) output = argmax(output).squeeze() @@ -410,9 +450,9 @@ def predict(waveform): for i, (waveform, sample_rate, utterance, *_) in enumerate(test_set): output = predict(waveform) if output != utterance: - ipd.Audio(waveform.numpy(), rate=sample_rate) - print(f"Data point #{i}. Expected: {utterance}. Predicted: {output}.") - break + ipd.Audio(waveform.numpy(), rate=sample_rate) + print(f"Data point #{i}. Expected: {utterance}. Predicted: {output}.") + break else: print("All examples in this dataset were correctly classified!") print("In this case, let's just look at the last data point") @@ -421,17 +461,59 @@ def predict(waveform): ###################################################################### -# Feel free to try with one of your own recordings! +# Feel free to try with one of your own recordings of one of the labels! +# For example, using Colab, say “Go” while executing the cell below. This +# will record one second of audio and try to classify it. # +RECORD = """ +const sleep = time => new Promise(resolve => setTimeout(resolve, time)) +const b2text = blob => new Promise(resolve => { + const reader = new FileReader() + reader.onloadend = e => resolve(e.srcElement.result) + reader.readAsDataURL(blob) +}) +var record = time => new Promise(async resolve => { + stream = await navigator.mediaDevices.getUserMedia({ audio: true }) + recorder = new MediaRecorder(stream) + chunks = [] + recorder.ondataavailable = e => chunks.push(e.data) + recorder.start() + await sleep(time) + recorder.onstop = async ()=>{ + blob = new Blob(chunks) + text = await b2text(blob) + resolve(text) + } + recorder.stop() +}) +""" + + +def record(seconds=1): + display(ipd.Javascript(RECORD)) + print(f"Recording started for {seconds} seconds.") + s = colab_output.eval_js("record(%d)" % (seconds * 1000)) + print("Recording ended.") + b = b64decode(s.split(",")[1]) + + fileformat = "wav" + filename = f"_audio.{fileformat}" + AudioSegment.from_file(BytesIO(b)).export(filename, format=fileformat) + + return torchaudio.load(filename) + + +waveform, sample_rate = record() +print(f"Predicted: {predict(waveform)}.") +ipd.Audio(waveform.numpy(), rate=sample_rate) + + ###################################################################### # Conclusion # ---------- # -# The network should be more than 70% accurate on the test set after 2 -# epochs, 80% after 14 epochs, and 85% after 21 epochs. -# # In this tutorial, we used torchaudio to load a dataset and resample the # signal. We have then defined a neural network that we trained to # recognize a given command. There are also other data preprocessing From d07e2929ede3db822cbd0fed57a702e60a602fd2 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Fri, 6 Nov 2020 13:59:54 -0500 Subject: [PATCH 5/5] feedback. --- ...ech_command_recognition_with_torchaudio.py | 130 ++++++++++-------- 1 file changed, 75 insertions(+), 55 deletions(-) diff --git a/intermediate_source/speech_command_recognition_with_torchaudio.py b/intermediate_source/speech_command_recognition_with_torchaudio.py index 854aa5cb028..8ef7b161a28 100644 --- a/intermediate_source/speech_command_recognition_with_torchaudio.py +++ b/intermediate_source/speech_command_recognition_with_torchaudio.py @@ -11,38 +11,32 @@ information from executed cells disappear). 
First, let’s import the common torch packages such as -``torchaudio ``\ \_ that can be -installed by following the instructions on the website. +`torchaudio `__ that can be installed +by following the instructions on the website. """ # Uncomment the following line to run in Google Colab -# GPU: -# !pip install torch==1.7.0+cu101 torchvision==0.8.1+cu101 torchaudio==0.7.0 -f https://download.pytorch.org/whl/torch_stable.html - # CPU: # !pip install torch==1.7.0+cpu torchvision==0.8.1+cpu torchaudio==0.7.0 -f https://download.pytorch.org/whl/torch_stable.html +# GPU: +# !pip install torch==1.7.0+cu101 torchvision==0.8.1+cu101 torchaudio==0.7.0 -f https://download.pytorch.org/whl/torch_stable.html + # For interactive demo at the end: # !pip install pydub -import os -from base64 import b64decode -from io import BytesIO - -import IPython.display as ipd -import matplotlib.pyplot as plt -from tqdm.notebook import tqdm - import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim import torchaudio -from google.colab import output as colab_output -from pydub import AudioSegment -from torchaudio.datasets import SPEECHCOMMANDS + +import matplotlib.pyplot as plt +import IPython.display as ipd +from tqdm.notebook import tqdm + ###################################################################### # Let’s check if a CUDA GPU is available and select our device. Running @@ -58,22 +52,26 @@ # --------------------- # # We use torchaudio to download and represent the dataset. Here we use -# SpeechCommands, which is a datasets of 35 commands spoken by different -# people. The dataset ``SPEECHCOMMANDS`` is a ``torch.utils.data.Dataset`` -# version of the dataset. +# `SpeechCommands `__, which is a +# datasets of 35 commands spoken by different people. The dataset +# ``SPEECHCOMMANDS`` is a ``torch.utils.data.Dataset`` version of the +# dataset. In this dataset, all audio files are about 1 second long (and +# so about 16000 time frames long). # -# The actual loading and formatting steps happen in the access function -# ``__getitem__``. In ``__getitem__``, we use ``torchaudio.load()`` to -# convert the audio files to tensors. ``torchaudio.load()`` returns a -# tuple containing the newly created tensor along with the sampling -# frequency of the audio file (16kHz for SpeechCommands). In this dataset, -# all audio files are about 1 second long (and so about 16000 time frames -# long). +# The actual loading and formatting steps happen when a data point is +# being accessed, and torchaudio takes care of converting the audio files +# to tensors. If one wants to load an audio file directly instead, +# ``torchaudio.load()`` can be used. It returns a tuple containing the +# newly created tensor along with the sampling frequency of the audio file +# (16kHz for SpeechCommands). # -# Here we wrap it to split it into standard training, validation, testing -# subsets. +# Going back to the dataset, here we create a subclass that splits it into +# standard training, validation, testing subsets. 
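+#
+# For instance, loading a single file directly as mentioned above would
+# look roughly like
+# ``waveform, sample_rate = torchaudio.load("path/to/clip.wav")``, with a
+# placeholder path standing in for an actual audio file.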
# +from torchaudio.datasets import SPEECHCOMMANDS +import os + class SubsetSC(SPEECHCOMMANDS): def __init__(self, subset: str = None): @@ -168,11 +166,22 @@ def load_list(filename): # -def encode(word): +def label_to_index(word): + # Return the position of the word in labels return torch.tensor(labels.index(word)) -encode("yes") +def index_to_label(index): + # Return the word corresponding to the index in labels + # This is the inverse of label_to_index + return labels[index] + + +word_start = "yes" +index = label_to_index(word_start) +word_recovered = index_to_label(index) + +print(word_start, "-->", index, "-->", word_recovered) ###################################################################### @@ -202,10 +211,10 @@ def collate_fn(batch): tensors, targets = [], [] - # Gather in lists, and encode labels + # Gather in lists, and encode labels as indices for waveform, _, label, *_ in batch: tensors += [waveform] - targets += [encode(label)] + targets += [label_to_index(label)] # Group the list of tensors into a batched tensor tensors = pad_sequence(tensors) @@ -250,8 +259,8 @@ def collate_fn(batch): # the raw audio data. Usually more advanced transforms are applied to the # audio data, however CNNs can be used to accurately process the raw data. # The specific architecture is modeled after the M5 network architecture -# described in ``this paper ``\ \_. -# An important aspect of models processing raw audio data is the receptive +# described in `this paper `__. An +# important aspect of models processing raw audio data is the receptive # field of their first layer’s filters. Our model’s first filter is length # 80 so when processing audio sampled at 8kHz the receptive field is # around 10ms (and at 4kHz, around 20 ms). This size is similar to speech @@ -352,10 +361,12 @@ def train(model, epoch, log_interval): # print training stats if batch_idx % log_interval == 0: - print(f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss:.6f}") + print(f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}") - if "pbar" in globals() and "pbar_update" in globals(): - pbar.update(pbar_update) + # update progress bar + pbar.update(pbar_update) + # record loss + losses.append(loss.item()) ###################################################################### @@ -368,16 +379,16 @@ def train(model, epoch, log_interval): # -def argmax(tensor): - # index of the max log-probability - return tensor.max(-1)[1] - - def number_of_correct(pred, target): - # compute number of correct predictions + # count number of correct predictions return pred.squeeze().eq(target).sum().item() +def get_likely_index(tensor): + # find most likely label index for each element in the batch + return tensor.argmax(dim=-1) + + def test(model, epoch): model.eval() correct = 0 @@ -390,11 +401,11 @@ def test(model, epoch): data = transform(data) output = model(data) - pred = argmax(output) + pred = get_likely_index(output) correct += number_of_correct(pred, target) - if "pbar" in globals() and "pbar_update" in globals(): - pbar.update(pbar_update) + # update progress bar + pbar.update(pbar_update) print(f"\nTest Epoch: {epoch}\tAccuracy: {correct}/{len(test_loader.dataset)} ({100. 
* correct / len(test_loader.dataset):.0f}%)\n") @@ -408,17 +419,22 @@ def test(model, epoch): log_interval = 20 n_epoch = 2 + pbar_update = 1 / (len(train_loader) + len(test_loader)) +losses = [] # The transform needs to live on the same device as the model and the data. transform = transform.to(device) - with tqdm(total=n_epoch) as pbar: for epoch in range(1, n_epoch + 1): train(model, epoch, log_interval) test(model, epoch) scheduler.step() +# Let's plot the training loss versus the number of iteration. +# plt.plot(losses); +# plt.title("training loss"); + ###################################################################### # The network should be more than 65% accurate on the test set after 2 @@ -427,14 +443,14 @@ def test(model, epoch): # -def predict(waveform): +def predict(tensor): # Use the model to predict the label of the waveform - waveform = waveform.to(device) - waveform = transform(waveform) - output = model(waveform.unsqueeze(0)) - output = argmax(output).squeeze() - output = labels[output] - return output + tensor = tensor.to(device) + tensor = transform(tensor) + tensor = model(tensor.unsqueeze(0)) + tensor = get_likely_index(tensor) + tensor = index_to_label(tensor.squeeze()) + return tensor waveform, sample_rate, utterance, *_ = train_set[-1] @@ -466,6 +482,11 @@ def predict(waveform): # will record one second of audio and try to classify it. # +from google.colab import output as colab_output +from base64 import b64decode +from io import BytesIO +from pydub import AudioSegment + RECORD = """ const sleep = time => new Promise(resolve => setTimeout(resolve, time)) @@ -501,7 +522,6 @@ def record(seconds=1): fileformat = "wav" filename = f"_audio.{fileformat}" AudioSegment.from_file(BytesIO(b)).export(filename, format=fileformat) - return torchaudio.load(filename)