Skip to content

Commit 0b882f1

Browse files
committed
improve GPU performance. add interactive demo at the end.
1 parent c9e9423 commit 0b882f1

File tree

1 file changed

+130
-48
lines changed

1 file changed

+130
-48
lines changed

intermediate_source/speech_command_recognition_with_torchaudio.py

Lines changed: 130 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,33 @@
33
==========================================
44
55
This tutorial will show you how to correctly format an audio dataset and
6-
then train/test an audio classifier network on the dataset. First, let’s
7-
import the common torch packages such as
8-
``torchaudio <https://github.com/pytorch/audio>``\ \_ and can be
6+
then train/test an audio classifier network on the dataset.
7+
8+
Colab has a GPU option available. In the menu tabs, select “Runtime” then
9+
“Change runtime type”. In the pop-up that follows, you can choose GPU.
10+
After the change, your runtime should automatically restart (which means
11+
information from executed cells disappears).
12+
13+
First, let’s import the common torch packages such as
14+
`torchaudio <https://github.com/pytorch/audio>`_ that can be
915
installed by following the instructions on the website.
1016
1117
"""
1218

1319
# Uncomment the following line to run in Google Colab
14-
# !pip install torch
15-
# !pip install torchaudio
20+
21+
# GPU:
22+
# !pip install torch==1.7.0+cu101 torchvision==0.8.1+cu101 torchaudio==0.7.0 -f https://download.pytorch.org/whl/torch_stable.html
23+
24+
# CPU:
25+
# !pip install torch==1.7.0+cpu torchvision==0.8.1+cpu torchaudio==0.7.0 -f https://download.pytorch.org/whl/torch_stable.html
26+
27+
# For interactive demo at the end:
28+
# !pip install pydub
1629

1730
import os
31+
from base64 import b64decode
32+
from io import BytesIO
1833

1934
import IPython.display as ipd
2035
import matplotlib.pyplot as plt
@@ -25,6 +40,8 @@
2540
import torch.nn.functional as F
2641
import torch.optim as optim
2742
import torchaudio
43+
from google.colab import output as colab_output
44+
from pydub import AudioSegment
2845
from torchaudio.datasets import SPEECHCOMMANDS
2946

3047
######################################################################
@@ -73,11 +90,12 @@ def load_list(filename):
7390
self._walker = load_list("testing_list.txt")
7491
elif subset == "training":
7592
excludes = load_list("validation_list.txt") + load_list("testing_list.txt")
93+
excludes = set(excludes)
7694
self._walker = [w for w in self._walker if w not in excludes]
7795

7896

97+
# Create training and testing split of the data. We do not use validation in this tutorial.
7998
train_set = SubsetSC("training")
80-
# valid_set = SubsetSC("validation")
8199
test_set = SubsetSC("testing")
82100

83101
waveform, sample_rate, label, speaker_id, utterance_number = train_set[0]
@@ -92,15 +110,14 @@ def load_list(filename):
92110
print("Shape of waveform: {}".format(waveform.size()))
93111
print("Sample rate of waveform: {}".format(sample_rate))
94112

95-
plt.figure();
96113
plt.plot(waveform.t().numpy());
97114

98115

99116
######################################################################
100117
# Let’s find the list of labels available in the dataset.
101118
#
102119

103-
labels = list(set(datapoint[2] for datapoint in train_set))
120+
labels = sorted(list(set(datapoint[2] for datapoint in train_set)))
104121
labels
105122

106123

@@ -170,6 +187,7 @@ def encode(word):
170187
# encoding.
171188
#
172189

190+
173191
def pad_sequence(batch):
174192
# Make all tensors in a batch the same length by padding with zeros
175193
batch = [item.t() for item in batch]
@@ -184,9 +202,9 @@ def collate_fn(batch):
184202

185203
tensors, targets = [], []
186204

187-
# Apply transform and encode
205+
# Gather in lists, and encode labels
188206
for waveform, _, label, *_ in batch:
189-
tensors += [transform(waveform)]
207+
tensors += [waveform]
190208
targets += [encode(label)]
191209

192210
# Group the list of tensors into a batched tensor
@@ -196,20 +214,31 @@ def collate_fn(batch):
196214
return tensors, targets
197215

198216

199-
batch_size = 128
217+
batch_size = 256
200218

201-
if device == 'cuda':
219+
if device == "cuda":
202220
num_workers = 1
203221
pin_memory = True
204222
else:
205223
num_workers = 0
206224
pin_memory = False
207225

208226
train_loader = torch.utils.data.DataLoader(
209-
train_set, batch_size=batch_size, shuffle=True, collate_fn=collate_fn, num_workers=num_workers, pin_memory=pin_memory,
227+
train_set,
228+
batch_size=batch_size,
229+
shuffle=True,
230+
collate_fn=collate_fn,
231+
num_workers=num_workers,
232+
pin_memory=pin_memory,
210233
)
211234
test_loader = torch.utils.data.DataLoader(
212-
test_set, batch_size=batch_size, shuffle=False, collate_fn=collate_fn, num_workers=num_workers, pin_memory=pin_memory,
235+
test_set,
236+
batch_size=batch_size,
237+
shuffle=False,
238+
drop_last=False,
239+
collate_fn=collate_fn,
240+
num_workers=num_workers,
241+
pin_memory=pin_memory,
213242
)
214243

215244

@@ -232,21 +261,21 @@ def collate_fn(batch):
232261

233262

234263
class M5(nn.Module):
235-
def __init__(self, stride=16, n_channel=32, n_output=35):
264+
def __init__(self, n_input=1, n_output=35, stride=16, n_channel=32):
236265
super().__init__()
237-
self.conv1 = nn.Conv1d(1, n_channel, 80, stride=stride)
266+
self.conv1 = nn.Conv1d(n_input, n_channel, kernel_size=80, stride=stride)
238267
self.bn1 = nn.BatchNorm1d(n_channel)
239268
self.pool1 = nn.MaxPool1d(4)
240-
self.conv2 = nn.Conv1d(n_channel, n_channel, 3)
269+
self.conv2 = nn.Conv1d(n_channel, n_channel, kernel_size=3)
241270
self.bn2 = nn.BatchNorm1d(n_channel)
242271
self.pool2 = nn.MaxPool1d(4)
243-
self.conv3 = nn.Conv1d(n_channel, 2*n_channel, 3)
244-
self.bn3 = nn.BatchNorm1d(2*n_channel)
272+
self.conv3 = nn.Conv1d(n_channel, 2 * n_channel, kernel_size=3)
273+
self.bn3 = nn.BatchNorm1d(2 * n_channel)
245274
self.pool3 = nn.MaxPool1d(4)
246-
self.conv4 = nn.Conv1d(2*n_channel, 2*n_channel, 3)
247-
self.bn4 = nn.BatchNorm1d(2*n_channel)
275+
self.conv4 = nn.Conv1d(2 * n_channel, 2 * n_channel, kernel_size=3)
276+
self.bn4 = nn.BatchNorm1d(2 * n_channel)
248277
self.pool4 = nn.MaxPool1d(4)
249-
self.fc1 = nn.Linear(2*n_channel, n_output)
278+
self.fc1 = nn.Linear(2 * n_channel, n_output)
250279

251280
def forward(self, x):
252281
x = self.conv1(x)
@@ -267,7 +296,7 @@ def forward(self, x):
267296
return F.log_softmax(x, dim=2)
268297

269298

270-
model = M5(n_output=len(labels))
299+
model = M5(n_input=transformed.shape[0], n_output=len(labels))
271300
model.to(device)
272301
print(model)
273302

@@ -284,7 +313,7 @@ def count_parameters(model):
284313
# We will use the same optimization technique used in the paper, an Adam
285314
# optimizer with weight decay set to 0.0001. At first, we will train with
286315
# a learning rate of 0.01, but we will use a ``scheduler`` to decrease it
287-
# to 0.001 during training.
316+
# to 0.001 during training after 20 epochs.
288317
#
289318

290319
optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=0.0001)
@@ -296,11 +325,9 @@ def count_parameters(model):
296325
# --------------------------------
297326
#
298327
# Now let’s define a training function that will feed our training data
299-
# into the model and perform the backward pass and optimization steps.
300-
#
301-
# Finally, we can train and test the network. We will train the network
302-
# for ten epochs then reduce the learn rate and train for ten more epochs.
303-
# The network will be tested after each epoch to see how the accuracy
328+
# into the model and perform the backward pass and optimization steps. For
329+
# training, the loss we will use is the negative log-likelihood. The
330+
# network will then be tested after each epoch to see how the accuracy
304331
# varies during the training.
305332
#
306333

@@ -312,6 +339,8 @@ def train(model, epoch, log_interval):
312339
data = data.to(device)
313340
target = target.to(device)
314341

342+
# apply transform and model on whole batch directly on device
343+
data = transform(data)
315344
output = model(data)
316345

317346
# negative log-likelihood for a tensor of size (batch x 1 x n_output)
@@ -323,10 +352,10 @@ def train(model, epoch, log_interval):
323352

324353
# print training stats
325354
if batch_idx % log_interval == 0:
326-
print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss:.6f}')
355+
print(f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss:.6f}")
327356

328-
if 'pbar' in globals():
329-
pbar.update()
357+
if "pbar" in globals() and "pbar_update" in globals():
358+
pbar.update(pbar_update)
330359

331360

332361
######################################################################
@@ -346,24 +375,28 @@ def argmax(tensor):
346375

347376
def number_of_correct(pred, target):
348377
# compute number of correct predictions
349-
return pred.squeeze().eq(target).cpu().sum().item()
378+
return pred.squeeze().eq(target).sum().item()
350379

351380

352381
def test(model, epoch):
353382
model.eval()
354383
correct = 0
355384
for data, target in test_loader:
385+
356386
data = data.to(device)
357387
target = target.to(device)
358388

389+
# apply transform and model on whole batch directly on device
390+
data = transform(data)
359391
output = model(data)
392+
360393
pred = argmax(output)
361394
correct += number_of_correct(pred, target)
362395

363-
if 'pbar' in globals():
364-
pbar.update()
396+
if "pbar" in globals() and "pbar_update" in globals():
397+
pbar.update(pbar_update)
365398

366-
print(f'\nTest Epoch: {epoch}\tAccuracy: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset):.0f}%)\n')
399+
print(f"\nTest Epoch: {epoch}\tAccuracy: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset):.0f}%)\n")
367400

368401

369402
######################################################################
@@ -375,21 +408,28 @@ def test(model, epoch):
375408

376409
log_interval = 20
377410
n_epoch = 2
411+
pbar_update = 1 / (len(train_loader) + len(test_loader))
412+
413+
# The transform needs to live on the same device as the model and the data.
414+
transform = transform.to(device)
378415

379-
with tqdm(total=n_epoch * (len(train_loader) + len(test_loader))) as pbar:
380-
for epoch in range(1, n_epoch+1):
416+
with tqdm(total=n_epoch) as pbar:
417+
for epoch in range(1, n_epoch + 1):
381418
train(model, epoch, log_interval)
382419
test(model, epoch)
383420
scheduler.step()
384421

385422

386423
######################################################################
387-
# Let’s look at the last words in the train set, and see how the model did
388-
# on it.
424+
# The network should be more than 65% accurate on the test set after 2
425+
# epochs, and 85% after 21 epochs. Let’s look at the last words in the
426+
# train set, and see how the model did on it.
389427
#
390428

429+
391430
def predict(waveform):
392-
# Take a waveform and use the model to predict
431+
# Use the model to predict the label of the waveform
432+
waveform = waveform.to(device)
393433
waveform = transform(waveform)
394434
output = model(waveform.unsqueeze(0))
395435
output = argmax(output).squeeze()
@@ -410,9 +450,9 @@ def predict(waveform):
410450
for i, (waveform, sample_rate, utterance, *_) in enumerate(test_set):
411451
output = predict(waveform)
412452
if output != utterance:
413-
ipd.Audio(waveform.numpy(), rate=sample_rate)
414-
print(f"Data point #{i}. Expected: {utterance}. Predicted: {output}.")
415-
break
453+
ipd.Audio(waveform.numpy(), rate=sample_rate)
454+
print(f"Data point #{i}. Expected: {utterance}. Predicted: {output}.")
455+
break
416456
else:
417457
print("All examples in this dataset were correctly classified!")
418458
print("In this case, let's just look at the last data point")
@@ -421,17 +461,59 @@ def predict(waveform):
421461

422462

423463
######################################################################
424-
# Feel free to try with one of your own recordings!
464+
# Feel free to try with one of your own recordings of one of the labels!
465+
# For example, using Colab, say “Go” while executing the cell below. This
466+
# will record one second of audio and try to classify it.
425467
#
426468

427469

470+
RECORD = """
471+
const sleep = time => new Promise(resolve => setTimeout(resolve, time))
472+
const b2text = blob => new Promise(resolve => {
473+
const reader = new FileReader()
474+
reader.onloadend = e => resolve(e.srcElement.result)
475+
reader.readAsDataURL(blob)
476+
})
477+
var record = time => new Promise(async resolve => {
478+
stream = await navigator.mediaDevices.getUserMedia({ audio: true })
479+
recorder = new MediaRecorder(stream)
480+
chunks = []
481+
recorder.ondataavailable = e => chunks.push(e.data)
482+
recorder.start()
483+
await sleep(time)
484+
recorder.onstop = async ()=>{
485+
blob = new Blob(chunks)
486+
text = await b2text(blob)
487+
resolve(text)
488+
}
489+
recorder.stop()
490+
})
491+
"""
492+
493+
494+
def record(seconds=1):
495+
display(ipd.Javascript(RECORD))
496+
print(f"Recording started for {seconds} seconds.")
497+
s = colab_output.eval_js("record(%d)" % (seconds * 1000))
498+
print("Recording ended.")
499+
b = b64decode(s.split(",")[1])
500+
501+
fileformat = "wav"
502+
filename = f"_audio.{fileformat}"
503+
AudioSegment.from_file(BytesIO(b)).export(filename, format=fileformat)
504+
505+
return torchaudio.load(filename)
506+
507+
508+
waveform, sample_rate = record()
509+
print(f"Predicted: {predict(waveform)}.")
510+
ipd.Audio(waveform.numpy(), rate=sample_rate)
511+
512+
428513
######################################################################
429514
# Conclusion
430515
# ----------
431516
#
432-
# The network should be more than 70% accurate on the test set after 2
433-
# epochs, 80% after 14 epochs, and 85% after 21 epochs.
434-
#
435517
# In this tutorial, we used torchaudio to load a dataset and resample the
436518
# signal. We have then defined a neural network that we trained to
437519
# recognize a given command. There are also other data preprocessing

0 commit comments

Comments
 (0)