
Commit e9bc491

Commit message: feedback.

1 parent 0b882f1 commit e9bc491
File tree: 1 file changed (+75, -55 lines)

intermediate_source/speech_command_recognition_with_torchaudio.py

Lines changed: 75 additions & 55 deletions
@@ -11,38 +11,32 @@
 information from executed cells disappear).
 
 First, let’s import the common torch packages such as
-``torchaudio <https://github.com/pytorch/audio>``\ \_ that can be
-installed by following the instructions on the website.
+`torchaudio <https://github.com/pytorch/audio>`__ that can be installed
+by following the instructions on the website.
 
 """
 
 # Uncomment the following line to run in Google Colab
 
-# GPU:
-# !pip install torch==1.7.0+cu101 torchvision==0.8.1+cu101 torchaudio==0.7.0 -f https://download.pytorch.org/whl/torch_stable.html
-
 # CPU:
 # !pip install torch==1.7.0+cpu torchvision==0.8.1+cpu torchaudio==0.7.0 -f https://download.pytorch.org/whl/torch_stable.html
 
+# GPU:
+# !pip install torch==1.7.0+cu101 torchvision==0.8.1+cu101 torchaudio==0.7.0 -f https://download.pytorch.org/whl/torch_stable.html
+
 # For interactive demo at the end:
 # !pip install pydub
 
-import os
-from base64 import b64decode
-from io import BytesIO
-
-import IPython.display as ipd
-import matplotlib.pyplot as plt
-from tqdm.notebook import tqdm
-
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
 import torchaudio
-from google.colab import output as colab_output
-from pydub import AudioSegment
-from torchaudio.datasets import SPEECHCOMMANDS
+
+import matplotlib.pyplot as plt
+import IPython.display as ipd
+from tqdm.notebook import tqdm
+
 
 ######################################################################
 # Let’s check if a CUDA GPU is available and select our device. Running
@@ -58,22 +52,26 @@
 # ---------------------
 #
 # We use torchaudio to download and represent the dataset. Here we use
-# SpeechCommands, which is a datasets of 35 commands spoken by different
-# people. The dataset ``SPEECHCOMMANDS`` is a ``torch.utils.data.Dataset``
-# version of the dataset.
+# `SpeechCommands <https://arxiv.org/abs/1804.03209>`__, which is a
+# datasets of 35 commands spoken by different people. The dataset
+# ``SPEECHCOMMANDS`` is a ``torch.utils.data.Dataset`` version of the
+# dataset. In this dataset, all audio files are about 1 second long (and
+# so about 16000 time frames long).
 #
-# The actual loading and formatting steps happen in the access function
-# ``__getitem__``. In ``__getitem__``, we use ``torchaudio.load()`` to
-# convert the audio files to tensors. ``torchaudio.load()`` returns a
-# tuple containing the newly created tensor along with the sampling
-# frequency of the audio file (16kHz for SpeechCommands). In this dataset,
-# all audio files are about 1 second long (and so about 16000 time frames
-# long).
+# The actual loading and formatting steps happen when a data point is
+# being accessed, and torchaudio takes care of converting the audio files
+# to tensors. If one wants to load an audio file directly instead,
+# ``torchaudio.load()`` can be used. It returns a tuple containing the
+# newly created tensor along with the sampling frequency of the audio file
+# (16kHz for SpeechCommands).
 #
-# Here we wrap it to split it into standard training, validation, testing
-# subsets.
+# Going back to the dataset, here we create a subclass that splits it into
+# standard training, validation, testing subsets.
 #
 
+from torchaudio.datasets import SPEECHCOMMANDS
+import os
+
 
 class SubsetSC(SPEECHCOMMANDS):
     def __init__(self, subset: str = None):
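
A minimal sketch of the ``torchaudio.load()`` call described in the comments above (the file path here is hypothetical; any WAV from the downloaded dataset works the same way):

    import torchaudio

    # Hypothetical path into the downloaded SpeechCommands folder.
    waveform, sample_rate = torchaudio.load("SpeechCommands/speech_commands_v0.02/yes/00f0204f_nohash_0.wav")
    print(waveform.shape)  # about torch.Size([1, 16000]) for a 1-second clip
    print(sample_rate)     # 16000 for SpeechCommands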
@@ -168,11 +166,22 @@ def load_list(filename):
 #
 
 
-def encode(word):
+def label_to_index(word):
+    # Return the position of the word in labels
     return torch.tensor(labels.index(word))
 
 
-encode("yes")
+def index_to_label(index):
+    # Return the word corresponding to the index in labels
+    # This is the inverse of label_to_index
+    return labels[index]
+
+
+word_start = "yes"
+index = label_to_index(word_start)
+word_recovered = index_to_label(index)
+
+print(word_start, "-->", index, "-->", word_recovered)
 
 
 ######################################################################
@@ -202,10 +211,10 @@ def collate_fn(batch):
 
     tensors, targets = [], []
 
-    # Gather in lists, and encode labels
+    # Gather in lists, and encode labels as indices
     for waveform, _, label, *_ in batch:
         tensors += [waveform]
-        targets += [encode(label)]
+        targets += [label_to_index(label)]
 
     # Group the list of tensors into a batched tensor
     tensors = pad_sequence(tensors)
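
For intuition, a sketch of the zero-padding the collate step relies on, written against ``torch.nn.utils.rnn.pad_sequence`` (the tutorial's own ``pad_sequence`` helper is assumed to wrap a similar call):

    import torch
    from torch.nn.utils.rnn import pad_sequence as rnn_pad_sequence

    short_clip = torch.rand(15435)  # a clip slightly shorter than 1 second at 16kHz
    full_clip = torch.rand(16000)   # exactly 1 second
    batch = rnn_pad_sequence([short_clip, full_clip], batch_first=True, padding_value=0.0)
    print(batch.shape)  # torch.Size([2, 16000]); the short clip is zero-padded at the end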
@@ -250,8 +259,8 @@ def collate_fn(batch):
 # the raw audio data. Usually more advanced transforms are applied to the
 # audio data, however CNNs can be used to accurately process the raw data.
 # The specific architecture is modeled after the M5 network architecture
-# described in ``this paper <https://arxiv.org/pdf/1610.00087.pdf>``\ \_.
-# An important aspect of models processing raw audio data is the receptive
+# described in `this paper <https://arxiv.org/pdf/1610.00087.pdf>`__. An
+# important aspect of models processing raw audio data is the receptive
 # field of their first layer’s filters. Our model’s first filter is length
 # 80 so when processing audio sampled at 8kHz the receptive field is
 # around 10ms (and at 4kHz, around 20 ms). This size is similar to speech
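
The receptive-field figures quoted in that comment follow from simple arithmetic (filter length divided by sampling rate):

    filter_length = 80  # samples spanned by the first layer's filters
    for sample_rate in (8000, 4000):
        print(f"{filter_length / sample_rate * 1000:.0f} ms at {sample_rate} Hz")  # 10 ms, 20 ms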
@@ -352,10 +361,12 @@ def train(model, epoch, log_interval):
 
         # print training stats
         if batch_idx % log_interval == 0:
-            print(f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss:.6f}")
+            print(f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}")
 
-        if "pbar" in globals() and "pbar_update" in globals():
-            pbar.update(pbar_update)
+        # update progress bar
+        pbar.update(pbar_update)
+        # record loss
+        losses.append(loss.item())
 
 
 ######################################################################
@@ -368,16 +379,16 @@ def train(model, epoch, log_interval):
 #
 
 
-def argmax(tensor):
-    # index of the max log-probability
-    return tensor.max(-1)[1]
-
-
 def number_of_correct(pred, target):
-    # compute number of correct predictions
+    # count number of correct predictions
     return pred.squeeze().eq(target).sum().item()
 
 
+def get_likely_index(tensor):
+    # find most likely label index for each element in the batch
+    return tensor.argmax(dim=-1)
+
+
 def test(model, epoch):
     model.eval()
     correct = 0
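
To see what ``get_likely_index`` computes, a small self-contained example with made-up scores (three batch elements, three classes):

    import torch

    scores = torch.tensor([[0.1, 2.0, -1.0],
                           [1.5, 0.2, 0.3],
                           [-0.5, 0.0, 4.0]])
    print(scores.argmax(dim=-1))  # tensor([1, 0, 2]), one label index per batch element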
@@ -390,11 +401,11 @@ def test(model, epoch):
         data = transform(data)
         output = model(data)
 
-        pred = argmax(output)
+        pred = get_likely_index(output)
         correct += number_of_correct(pred, target)
 
-        if "pbar" in globals() and "pbar_update" in globals():
-            pbar.update(pbar_update)
+        # update progress bar
+        pbar.update(pbar_update)
 
     print(f"\nTest Epoch: {epoch}\tAccuracy: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset):.0f}%)\n")
 
@@ -408,17 +419,22 @@ def test(model, epoch):
 
 log_interval = 20
 n_epoch = 2
+
 pbar_update = 1 / (len(train_loader) + len(test_loader))
+losses = []
 
 # The transform needs to live on the same device as the model and the data.
 transform = transform.to(device)
-
 with tqdm(total=n_epoch) as pbar:
     for epoch in range(1, n_epoch + 1):
         train(model, epoch, log_interval)
         test(model, epoch)
         scheduler.step()
 
+# Let's plot the training loss versus the number of iteration.
+# plt.plot(losses);
+# plt.title("training loss");
+
 
 ######################################################################
 # The network should be more than 65% accurate on the test set after 2
@@ -427,14 +443,14 @@ def test(model, epoch):
 #
 
 
-def predict(waveform):
+def predict(tensor):
     # Use the model to predict the label of the waveform
-    waveform = waveform.to(device)
-    waveform = transform(waveform)
-    output = model(waveform.unsqueeze(0))
-    output = argmax(output).squeeze()
-    output = labels[output]
-    return output
+    tensor = tensor.to(device)
+    tensor = transform(tensor)
+    tensor = model(tensor.unsqueeze(0))
+    tensor = get_likely_index(tensor)
+    tensor = index_to_label(tensor.squeeze())
+    return tensor
 
 
 waveform, sample_rate, utterance, *_ = train_set[-1]
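
A plausible usage sketch (assuming the tutorial code above has run, so ``predict``, ``waveform``, and ``utterance`` are all in scope):

    # Compare the model's prediction against the clip's labeled utterance.
    print(f"Expected: {utterance}. Predicted: {predict(waveform)}.")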
@@ -466,6 +482,11 @@ def predict(waveform):
 # will record one second of audio and try to classify it.
 #
 
+from google.colab import output as colab_output
+from base64 import b64decode
+from io import BytesIO
+from pydub import AudioSegment
+
 
 RECORD = """
 const sleep = time => new Promise(resolve => setTimeout(resolve, time))
@@ -501,7 +522,6 @@ def record(seconds=1):
     fileformat = "wav"
     filename = f"_audio.{fileformat}"
     AudioSegment.from_file(BytesIO(b)).export(filename, format=fileformat)
-
     return torchaudio.load(filename)
 
 