
Commit 2c599be

mthrok, brianjo, and Caroline Chen authored
Add speech recognition tutorial (#1714)
* Update build.sh
* Update audio tutorial (#1713)
* Update audio tutorial
* fix
* Add speech recognition tutorial
* update title
* Fix some
* Apply suggestions from code review
  Co-authored-by: Caroline Chen <carolinechen@fb.com>
* Update intermediate_source/speech_recognition_pipeline_tutorial.py
* Update intermediate_source/speech_recognition_pipeline_tutorial.py
* Update intermediate_source/speech_recognition_pipeline_tutorial.py
* Add link to the tutorial page
* Fix

Co-authored-by: Brian Johnson <brianjo@fb.com>
Co-authored-by: Caroline Chen <carolinechen@fb.com>
1 parent f725a0e commit 2c599be

3 files changed: +325 -0 lines changed


index.rst

Lines changed: 8 additions & 0 deletions
@@ -130,6 +130,13 @@ Welcome to PyTorch Tutorials
    :link: beginner/audio_preprocessing_tutorial.html
    :tags: Audio

+.. customcarditem::
+   :header: Automatic Speech Recognition with Wav2Vec2 in torchaudio
+   :card_description: Learn how to use torchaudio's pretrained models for building a speech recognition application.
+   :image: _static/img/thumbnails/cropped/torchaudio-asr.png
+   :link: intermediate_source/speech_recognition_pipeline_tutorial.html
+   :tags: Audio
+
 .. customcarditem::
    :header: Speech Command Recognition
    :card_description: Learn how to correctly format an audio dataset and then train/test an audio classifier network on the dataset.

@@ -615,6 +622,7 @@ Additional Resources
    :caption: Audio

    beginner/audio_preprocessing_tutorial
+   intermediate/speech_recognition_pipeline_tutorial
    intermediate/speech_command_recognition_with_torchaudio_tutorial
    intermediate/text_to_speech_with_torchaudio
intermediate_source/speech_recognition_pipeline_tutorial.py

Lines changed: 317 additions & 0 deletions
@@ -0,0 +1,317 @@
"""
Speech Recognition with Torchaudio
==================================

**Author**: `Moto Hira <moto@fb.com>`__

This tutorial shows how to perform speech recognition using
pre-trained models from wav2vec 2.0
[`paper <https://arxiv.org/abs/2006.11477>`__].

"""


######################################################################
# Overview
# --------
#
# The process of speech recognition looks like the following.
#
# 1. Extract the acoustic features from the audio waveform.
#
# 2. Estimate the class of the acoustic features frame-by-frame.
#
# 3. Generate a hypothesis from the sequence of class probabilities.
#
# Torchaudio provides easy access to the pre-trained weights and
# associated information, such as the expected sample rate and class
# labels. They are bundled together and available under the
# ``torchaudio.pipelines`` module.
#


######################################################################
# Preparation
# -----------
#
# First we import the necessary packages and fetch the data that we will
# work on.
#

# %matplotlib inline

import torch
import torchaudio

print(torch.__version__)
print(torchaudio.__version__)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(device)


import matplotlib
import matplotlib.pyplot as plt

[width, height] = matplotlib.rcParams['figure.figsize']
if width < 10:
    matplotlib.rcParams['figure.figsize'] = [width * 2.5, height]

import IPython

import requests

SPEECH_FILE = "speech.wav"

url = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/source-16k/train/sp0307/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav"
with open(SPEECH_FILE, 'wb') as file_:
    file_.write(requests.get(url).content)


######################################################################
# Creating a pipeline
# -------------------
#
# First, we will create a Wav2Vec2 model that performs the feature
# extraction and the classification.
#
# There are two types of Wav2Vec2 pre-trained weights available in
# torchaudio: the ones fine-tuned for the ASR task, and the ones that
# are not fine-tuned.
#
# Wav2Vec2 (and HuBERT) models are trained in a self-supervised manner.
# They are first trained with audio only for representation learning,
# then fine-tuned for a specific task with additional labels.
#
# The pre-trained weights without fine-tuning can be fine-tuned
# for other downstream tasks as well, but this tutorial does not
# cover that.
#
# We will use ``torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H`` here.
#
# There are multiple models available in
# ``torchaudio.pipelines``. Please check the
# `documentation <https://pytorch.org/audio/stable/pipelines.html>`__ for
# the details of how they are trained.
#
# The bundle object provides the interface to instantiate the model and
# other information. The sample rate and the class labels are found as
# follows.
#

bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H

print("Sample Rate:", bundle.sample_rate)

print("Labels:", bundle.get_labels())


######################################################################
# The model can be constructed as follows. This process will automatically
# fetch the pre-trained weights and load them into the model.
#

model = bundle.get_model().to(device)

print(model.__class__)


######################################################################
# Loading data
# ------------
#
# We will use the speech data from the `VOiCES
# dataset <https://iqtlabs.github.io/voices/>`__, which is licensed under
# Creative Commons BY 4.0.
#

IPython.display.display(IPython.display.Audio(SPEECH_FILE))


######################################################################
# To load data, we use ``torchaudio.load``.
#
# If the sampling rate is different from what the pipeline expects, then
# we can use ``torchaudio.functional.resample`` for resampling.
#
# **Note**
#
# - `torchaudio.functional.resample <https://pytorch.org/audio/stable/functional.html#resample>`__
#   works on CUDA tensors as well.
# - When performing resampling multiple times on the same set of sample
#   rates, using
#   `torchaudio.transforms.Resample <https://pytorch.org/audio/stable/transforms.html#resample>`__
#   might improve the performance (a short sketch follows the loading
#   code below).
#

waveform, sample_rate = torchaudio.load(SPEECH_FILE)
waveform = waveform.to(device)

if sample_rate != bundle.sample_rate:
    waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)
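

######################################################################
# As a side note on the resampling tip above, the following is a minimal
# sketch (not executed here) of how a single
# ``torchaudio.transforms.Resample`` object could be reused across many
# recordings that share the same source sample rate. The 8000 Hz source
# rate and the ``waveform_8k`` variable are illustrative assumptions, not
# part of this tutorial's data.
#
# ::
#
#    # Build the resampling transform once, then apply it to every 8 kHz waveform.
#    resampler = torchaudio.transforms.Resample(
#        orig_freq=8000, new_freq=bundle.sample_rate).to(device)
#    resampled_waveform = resampler(waveform_8k)
#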


######################################################################
# Extracting acoustic features
# ----------------------------
#
# The next step is to extract acoustic features from the audio.
#
# Note that Wav2Vec2 models fine-tuned for the ASR task can perform feature
# extraction and classification in one step, but for the sake of the
# tutorial, we also show how to perform feature extraction here.
#

with torch.inference_mode():
    features, _ = model.extract_features(waveform)


######################################################################
# The returned ``features`` is a list of tensors. Each tensor is the
# output of a transformer layer.
#

for i, feats in enumerate(features):
    plt.imshow(feats[0].cpu())
    plt.title(f"Feature from transformer layer {i+1}")
    plt.xlabel("Feature dimension")
    plt.ylabel("Frame (time-axis)")
    plt.show()
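

######################################################################
# For a quick sanity check, we can also print the shape of each feature
# tensor; each one has the shape ``[batch, num_frames, feature_dim]``.
#

for i, feats in enumerate(features):
    print(f"Transformer layer {i+1}: {tuple(feats.shape)}")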


######################################################################
# Feature classification
# ----------------------
#
# Once the acoustic features are extracted, the next step is to classify
# them into a set of categories.
#
# The Wav2Vec2 model provides a method to perform the feature extraction
# and classification in one step.
#

with torch.inference_mode():
    emission, _ = model(waveform)
195+
######################################################################
196+
# The output is in the form of logits. It is not in the form of
197+
# probability.
198+
#
199+
# Let’s visualize this.
200+
#
201+
202+
plt.imshow(emission[0].cpu().T)
203+
plt.title("Classification result")
204+
plt.xlabel("Frame (time-axis)")
205+
plt.ylabel("Class")
206+
plt.colorbar()
207+
plt.show()
208+
print("Class labels:", bundle.get_labels())
209+
210+
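

######################################################################
# If probabilities are preferred, for example to inspect how confident the
# model is at each frame, the logits can be converted with a softmax over
# the class dimension. This is a small optional step and is not required
# by the decoding below.
#

probs = torch.softmax(emission, dim=-1)
print("Shape of probabilities:", probs.shape)
print("Per-frame max probability (first 10 frames):", probs[0].max(dim=-1).values[:10])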


######################################################################
# We can see that there are strong indications for certain labels across
# the timeline.
#
# Note that classes 1 to 3 (``<pad>``, ``</s>`` and ``<unk>``) have
# mostly huge negative values. This is an artifact from the original
# ``fairseq`` implementation, where these labels are added by default but
# not used during training.
#


######################################################################
# Generating transcripts
# ----------------------
#
# From the sequence of label probabilities, now we want to generate
# transcripts. The process of generating hypotheses is often called
# "decoding".
#
# Decoding is more elaborate than simple classification because
# decoding at a certain time step can be affected by surrounding
# observations.
#
# For example, take words like ``night`` and ``knight``. Even if their
# prior probability distributions are different (in typical conversations,
# ``night`` would occur way more often than ``knight``), to accurately
# generate transcripts with ``knight``, such as ``a knight with a sword``,
# the decoding process has to postpone the final decision until it sees
# enough context.
#
# There are many decoding techniques proposed, and they require external
# resources, such as a word dictionary and language models.
#
# In this tutorial, for the sake of simplicity, we will perform greedy
# decoding, which does not depend on such external components, and simply
# picks up the best hypothesis at each time step. Therefore, the context
# information is not used, and only one transcript can be generated.
#
# We start by defining the greedy decoding algorithm.
#

class GreedyCTCDecoder(torch.nn.Module):
    def __init__(self, labels, ignore):
        super().__init__()
        self.labels = labels
        self.ignore = ignore

    def forward(self, emission: torch.Tensor) -> str:
        """Given a sequence emission over labels, get the best path string

        Args:
            emission (Tensor): Logit tensor. Shape `[num_seq, num_label]`.

        Returns:
            str: The resulting transcript
        """
        indices = torch.argmax(emission, dim=-1)  # [num_seq,]
        indices = torch.unique_consecutive(indices, dim=-1)
        indices = [i for i in indices if i not in self.ignore]
        return ''.join([self.labels[i] for i in indices])


######################################################################
# Now create the decoder object and decode the transcript.
#

decoder = GreedyCTCDecoder(
    labels=bundle.get_labels(),
    ignore=(0, 1, 2, 3),
)
transcript = decoder(emission[0])


######################################################################
# Let's check the result and listen to the audio again.
#

print(transcript)
IPython.display.display(IPython.display.Audio(SPEECH_FILE))


######################################################################
# There are a few remarks on decoding.
#
# Firstly, the ASR model is fine-tuned using a loss function called CTC.
# The details of CTC loss are explained
# `here <https://distill.pub/2017/ctc/>`__. In CTC a blank token (ϵ) is a
# special token which represents a repetition of the previous symbol. In
# decoding, these are simply ignored.
#
# Secondly, as explained in the feature extraction section, the
# Wav2Vec2 model originating from ``fairseq`` has labels that are not
# used. These also have to be ignored.
#
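

######################################################################
# As a toy illustration of these two points, the following sketch (the
# index values are made up and are not taken from the model output) shows
# how repeated indices are collapsed and how the ignored indices are
# dropped, mirroring what ``GreedyCTCDecoder`` does above.
#
# ::
#
#    toy_indices = torch.tensor([5, 5, 0, 5, 0, 0, 7, 7])    # hypothetical per-frame argmax
#    collapsed = torch.unique_consecutive(toy_indices)       # tensor([5, 0, 5, 0, 7])
#    kept = [i for i in collapsed if i not in (0, 1, 2, 3)]  # drops blank/unused labels
#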


######################################################################
# Conclusion
# ----------
#
# In this tutorial, we looked at how to use ``torchaudio.pipelines`` to
# perform acoustic feature extraction and speech recognition. Constructing
# a model and getting the emission is as short as two lines.
#
# ::
#
#    model = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H.get_model()
#    emission = model(waveforms, ...)
#
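

######################################################################
# For reference, the pieces used in this tutorial can be combined into the
# following minimal sketch (the file name and CPU-only setup are
# illustrative; device placement is omitted).
#
# ::
#
#    bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
#    model = bundle.get_model()
#    waveform, sample_rate = torchaudio.load("speech.wav")
#    if sample_rate != bundle.sample_rate:
#        waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)
#    with torch.inference_mode():
#        emission, _ = model(waveform)
#    decoder = GreedyCTCDecoder(labels=bundle.get_labels(), ignore=(0, 1, 2, 3))
#    print(decoder(emission[0]))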
