Commit c9e9423 (parent 5e728e9)
Commit message: feedback.
1 file changed: intermediate_source/speech_command_recognition_with_torchaudio.py (+95 −94 lines)
@@ -10,6 +10,7 @@
 
 """
 
+# Uncomment the following line to run in Google Colab
 # !pip install torch
 # !pip install torchaudio
 
@@ -61,44 +62,24 @@ class SubsetSC(SPEECHCOMMANDS):
     def __init__(self, subset: str = None):
         super().__init__("./", download=True)
 
-        if subset in ["training", "validation"]:
-            filepath = os.path.join(self._path, "validation_list.txt")
-            with open(filepath) as f:
-                validation_list = [
-                    os.path.join(self._path, line.strip()) for line in f.readlines()
-                ]
-
-        if subset in ["training", "testing"]:
-            filepath = os.path.join(self._path, "testing_list.txt")
-            with open(filepath) as f:
-                testing_list = [
-                    os.path.join(self._path, line.strip()) for line in f.readlines()
-                ]
+        def load_list(filename):
+            filepath = os.path.join(self._path, filename)
+            with open(filepath) as fileobj:
+                return [os.path.join(self._path, line.strip()) for line in fileobj]
 
         if subset == "validation":
-            walker = validation_list
+            self._walker = load_list("validation_list.txt")
         elif subset == "testing":
-            walker = testing_list
-        elif subset in ["training", None]:
-            walker = self._walker  # defined by SPEECHCOMMANDS parent class
-        else:
-            raise ValueError(
-                "When `subset` not None, it must take a value from {'training', 'validation', 'testing'}."
-            )
-
-        if subset == "training":
-            walker = filter(
-                lambda w: not (w in validation_list or w in testing_list), walker
-            )
-
-        self._walker = list(walker)
+            self._walker = load_list("testing_list.txt")
+        elif subset == "training":
+            excludes = load_list("validation_list.txt") + load_list("testing_list.txt")
+            self._walker = [w for w in self._walker if w not in excludes]
 
 
 train_set = SubsetSC("training")
 # valid_set = SubsetSC("validation")
 test_set = SubsetSC("testing")
 
-
 waveform, sample_rate, label, speaker_id, utterance_number = train_set[0]
 
 
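A note on the new `training` branch: `excludes` is a plain list, so every `w not in excludes` check scans it end to end. A minimal standalone sketch of the same exclusion using a set for constant-time membership tests (the `root`, `walker`, and `training_walker` names here are hypothetical, for illustration only):

```python
import os

def load_list(root, filename):
    # Standalone version of the load_list helper added above: read one of
    # the dataset's split files and return absolute paths.
    filepath = os.path.join(root, filename)
    with open(filepath) as fileobj:
        return [os.path.join(root, line.strip()) for line in fileobj]

def training_walker(root, walker):
    # Same exclusion as the new "training" branch, but with a set so each
    # membership test is O(1) instead of a scan of the combined list.
    excludes = set(load_list(root, "validation_list.txt"))
    excludes |= set(load_list(root, "testing_list.txt"))
    return [w for w in walker if w not in excludes]
```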
@@ -111,8 +92,8 @@ def __init__(self, subset: str = None):
 print("Shape of waveform: {}".format(waveform.size()))
 print("Sample rate of waveform: {}".format(sample_rate))
 
-plt.figure()
-plt.plot(waveform.t().numpy())
+plt.figure();
+plt.plot(waveform.t().numpy());
 
 
 ######################################################################
@@ -147,31 +128,26 @@ def __init__(self, subset: str = None):
 # Formatting the Data
 # -------------------
 #
-# The dataset uses a single channel for audio. We do not need to down mix
-# the audio channels (which we could do for instance by either taking the
-# mean along the channel dimension, or simply keeping only one of the
-# channels).
+# This is a good place to apply transformations to the data. For the
+# waveform, we downsample the audio for faster processing without losing
+# too much of the classification power.
 #
-
-
-######################################################################
-# We downsample the audio for faster processing without losing too much of
-# the classification power.
+# We don’t need to apply other transformations here. It is common for some
+# datasets though to have to reduce the number of channels (say from
+# stereo to mono) by either taking the mean along the channel dimension,
+# or simply keeping only one of the channels. Since SpeechCommands uses a
+# single channel for audio, this is not needed here.
 #
 
 new_sample_rate = 8000
-transform = torchaudio.transforms.Resample(
-    orig_freq=sample_rate, new_freq=new_sample_rate
-)
+transform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=new_sample_rate)
 transformed = transform(waveform)
 
 ipd.Audio(transformed.numpy(), rate=new_sample_rate)
 
 
 ######################################################################
-# To encode each word, we use a simple language model where we represent
-# each fo the 35 words by its corresponding position of the command in the
-# list above.
+# We are encoding each word using its index in the list of labels.
 #
 
 
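The new one-line comment ("using its index in the list of labels") is backed by the `encode` function visible in the next hunk's context; its definition is outside this diff. A minimal sketch of what an index-based encoder and its inverse can look like, assuming a `labels` list collected from the dataset (the list shown here is a hypothetical subset of the 35 commands):

```python
import torch

# Hypothetical subset of the SpeechCommands labels; the tutorial builds
# the full list from the dataset itself.
labels = ["go", "left", "no", "off", "on", "right", "stop", "yes"]

def encode(word):
    # Represent a word by its index in the list of labels.
    return torch.tensor(labels.index(word))

def decode(index):
    # Inverse of encode: recover the word from its index.
    return labels[index]

assert decode(encode("left")) == "left"
```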
@@ -183,48 +159,57 @@ def encode(word):
 
 
 ######################################################################
-# We now define a collate function that assembles a list of audio
-# recordings and a list of utterances into two batched tensors. In this
-# function, we also apply the resampling, and the encoding. The collate
-# function is used in the pytroch data loader that allow us to iterate
-# over a dataset by batches.
+# To turn a list of data points made of audio recordings and utterances
+# into two batched tensors for the model, we implement a collate function
+# which is used by the PyTorch DataLoader that allows us to iterate over a
+# dataset by batches. Please see `the
+# documentation <https://pytorch.org/docs/stable/data.html#working-with-collate-fn>`__
+# for more information about working with a collate function.
+#
+# In the collate function, we also apply the resampling, and the text
+# encoding.
 #
-
 
 def pad_sequence(batch):
     # Make all tensors in a batch the same length by padding with zeros
     batch = [item.t() for item in batch]
-    batch = torch.nn.utils.rnn.pad_sequence(batch, batch_first=True, padding_value=0.0)
+    batch = torch.nn.utils.rnn.pad_sequence(batch, batch_first=True, padding_value=0.)
     return batch.permute(0, 2, 1)
 
 
 def collate_fn(batch):
 
     # A data tuple has the form:
     # waveform, sample_rate, label, speaker_id, utterance_number
-    # and so we are only interested in item 0 and 2
 
-    # Apply transforms to waveforms
-    tensors = [transform(b[0]) for b in batch]
-    tensors = pad_sequence(tensors)
+    tensors, targets = [], []
+
+    # Apply transform and encode
+    for waveform, _, label, *_ in batch:
+        tensors += [transform(waveform)]
+        targets += [encode(label)]
 
-    # Apply transform to target utterance
-    targets = [encode(b[2]) for b in batch]
+    # Group the list of tensors into a batched tensor
+    tensors = pad_sequence(tensors)
     targets = torch.stack(targets)
 
     return tensors, targets
 
 
 batch_size = 128
 
-kwargs = (
-    {"num_workers": 1, "pin_memory": True} if device == "cuda" else {}
-)  # needed for using datasets on gpu
+if device == 'cuda':
+    num_workers = 1
+    pin_memory = True
+else:
+    num_workers = 0
+    pin_memory = False
+
 train_loader = torch.utils.data.DataLoader(
-    train_set, batch_size=batch_size, shuffle=True, collate_fn=collate_fn, **kwargs
+    train_set, batch_size=batch_size, shuffle=True, collate_fn=collate_fn, num_workers=num_workers, pin_memory=pin_memory,
 )
 test_loader = torch.utils.data.DataLoader(
-    test_set, batch_size=batch_size, shuffle=False, collate_fn=collate_fn, **kwargs
+    test_set, batch_size=batch_size, shuffle=False, collate_fn=collate_fn, num_workers=num_workers, pin_memory=pin_memory,
 )
 
 
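With the loaders built, one batch is enough to sanity-check the collate function. A sketch, assuming the file's `train_loader` and the 8 kHz resampling above are in scope (the exact time length depends on the padding in the batch):

```python
# Pull a single batch and inspect the shapes produced by collate_fn.
tensors, targets = next(iter(train_loader))
print(tensors.shape)  # [batch, channel, time], e.g. roughly [128, 1, 8000]
print(targets.shape)  # [batch], one encoded label index per recording
```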
@@ -255,13 +240,13 @@ def __init__(self, stride=16, n_channel=32, n_output=35):
         self.conv2 = nn.Conv1d(n_channel, n_channel, 3)
         self.bn2 = nn.BatchNorm1d(n_channel)
         self.pool2 = nn.MaxPool1d(4)
-        self.conv3 = nn.Conv1d(n_channel, 2 * n_channel, 3)
-        self.bn3 = nn.BatchNorm1d(2 * n_channel)
+        self.conv3 = nn.Conv1d(n_channel, 2*n_channel, 3)
+        self.bn3 = nn.BatchNorm1d(2*n_channel)
         self.pool3 = nn.MaxPool1d(4)
-        self.conv4 = nn.Conv1d(2 * n_channel, 2 * n_channel, 3)
-        self.bn4 = nn.BatchNorm1d(2 * n_channel)
+        self.conv4 = nn.Conv1d(2*n_channel, 2*n_channel, 3)
+        self.bn4 = nn.BatchNorm1d(2*n_channel)
         self.pool4 = nn.MaxPool1d(4)
-        self.fc1 = nn.Linear(2 * n_channel, n_output)
+        self.fc1 = nn.Linear(2*n_channel, n_output)
 
     def forward(self, x):
         x = self.conv1(x)
@@ -303,9 +288,7 @@ def count_parameters(model):
 #
 
 optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=0.0001)
-scheduler = optim.lr_scheduler.StepLR(
-    optimizer, step_size=20, gamma=0.1
-)  # reduce the learning after 20 epochs by a factor of 10
+scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)  # reduce the learning rate after 20 epochs by a factor of 10
 
 
 ######################################################################
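The collapsed scheduler line keeps the behavior described in its comment. A self-contained sketch (throwaway parameter only) that confirms what `StepLR` with these arguments does:

```python
import torch
from torch import optim

# StepLR multiplies the learning rate by gamma every step_size epochs:
# with lr=0.01, step_size=20, gamma=0.1, the rate stays 0.01 for the
# first 20 scheduler steps and becomes 0.001 afterwards.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = optim.Adam([param], lr=0.01, weight_decay=0.0001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)
for epoch in range(25):
    optimizer.step()
    scheduler.step()
    print(epoch, scheduler.get_last_lr())  # drops from [0.01] to [0.001] after 20 steps
```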
@@ -340,11 +323,9 @@ def train(model, epoch, log_interval):
 
         # print training stats
         if batch_idx % log_interval == 0:
-            print(
-                f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss:.6f}"
-            )
+            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss:.6f}')
 
-        if "pbar" in globals():
+        if 'pbar' in globals():
             pbar.update()
 
 
@@ -379,12 +360,10 @@ def test(model, epoch):
         pred = argmax(output)
         correct += number_of_correct(pred, target)
 
-        if "pbar" in globals():
-            pbar.update()
+        if 'pbar' in globals():
+            pbar.update()
 
-    print(
-        f"\nTest Epoch: {epoch}\tAccuracy: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset):.0f}%)\n"
-    )
+    print(f'\nTest Epoch: {epoch}\tAccuracy: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset):.0f}%)\n')
 
 
 ######################################################################
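`test` leans on `argmax` and `number_of_correct`, both defined outside this hunk. For orientation, a plausible reconstruction of such helpers (hypothetical, not necessarily the file's exact code):

```python
import torch

def argmax(tensor):
    # Index of the most likely class along the last dimension.
    return tensor.argmax(dim=-1)

def number_of_correct(pred, target):
    # Count of predictions that match the target labels.
    return pred.squeeze().eq(target).sum().item()

# Tiny check with fake logits for 3 classes:
logits = torch.tensor([[0.1, 0.7, 0.2], [0.9, 0.05, 0.05]])
print(argmax(logits))                                           # tensor([1, 0])
print(number_of_correct(argmax(logits), torch.tensor([1, 0])))  # 2
```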
@@ -398,38 +377,60 @@ def test(model, epoch):
 n_epoch = 2
 
 with tqdm(total=n_epoch * (len(train_loader) + len(test_loader))) as pbar:
-    for epoch in range(1, n_epoch + 1):
+    for epoch in range(1, n_epoch+1):
         train(model, epoch, log_interval)
         test(model, epoch)
         scheduler.step()
 
 
 ######################################################################
-# Let’s try looking at one of the last words in the train and test set.
+# Let’s look at the last word in the train set, and see how the model did
+# on it.
 #
 
+def predict(waveform):
+    # Take a waveform and use the model to predict its label
+    waveform = transform(waveform)
+    output = model(waveform.unsqueeze(0))
+    output = argmax(output).squeeze()
+    output = labels[output]
+    return output
+
+
 waveform, sample_rate, utterance, *_ = train_set[-1]
 ipd.Audio(waveform.numpy(), rate=sample_rate)
 
-waveform = transform(waveform)
-output = model(waveform.unsqueeze(0))
-output = argmax(output).squeeze()
-print(f"Expected: {utterance}. Predicted: {labels[output]}.")
+print(f"Expected: {utterance}. Predicted: {predict(waveform)}.")
 
-waveform, sample_rate, utterance, *_ = test_set[-1]
-ipd.Audio(waveform.numpy(), rate=sample_rate)
 
-waveform = transform(waveform)
-output = model(waveform.unsqueeze(0))
-output = argmax(output).squeeze()
-print(f"Expected: {utterance}. Predicted: {labels[output]}.")
+######################################################################
+# Let’s find an example that isn’t classified correctly, if there is one.
+#
+
+for i, (waveform, sample_rate, utterance, *_) in enumerate(test_set):
+    output = predict(waveform)
+    if output != utterance:
+        ipd.Audio(waveform.numpy(), rate=sample_rate)
+        print(f"Data point #{i}. Expected: {utterance}. Predicted: {output}.")
+        break
+else:
+    print("All examples in this dataset were correctly classified!")
+    print("In this case, let's just look at the last data point")
+    ipd.Audio(waveform.numpy(), rate=sample_rate)
+    print(f"Data point #{i}. Expected: {utterance}. Predicted: {output}.")
+
+
+######################################################################
+# Feel free to try with one of your own recordings!
+#
 
 
 ######################################################################
 # Conclusion
 # ----------
 #
-# After two epochs, the network should be more than 70% accurate.
+# The network should be more than 70% accurate on the test set after 2
+# epochs, 80% after 14 epochs, and 85% after 21 epochs.
 #
 # In this tutorial, we used torchaudio to load a dataset and resample the
 # signal. We have then defined a neural network that we trained to
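For the "Feel free to try with one of your own recordings!" prompt added above, a hedged sketch of feeding a personal recording through the new `predict` helper (the path is a placeholder; `predict` and the dataset's 16 kHz rate are taken from the file):

```python
import torchaudio

# Load a personal recording; torchaudio.load returns (waveform, sample_rate).
waveform, sample_rate = torchaudio.load("my_recording.wav")  # placeholder path

# The model was trained on mono 16 kHz audio; predict() applies the 8 kHz
# resampling transform itself, so match the dataset rate first.
if waveform.shape[0] > 1:
    waveform = waveform.mean(0, keepdim=True)  # downmix stereo to mono
if sample_rate != 16000:
    waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)

print(predict(waveform))
```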
