
Commit a04723d

mthrok, holly1238, and brianjo authored
Update speech recognition tutorial (#1741)
Backport the update from https://pytorch.org/audio/main/tutorials/speech_recognition_pipeline_tutorial.html

1. Fix rendering
2. Render audio

Co-authored-by: Holly Sweeney <77758406+holly1238@users.noreply.github.com>
Co-authored-by: Brian Johnson <brianjo@fb.com>
1 parent 5c44056 commit a04723d

File tree

1 file changed (+47, −53)


intermediate_source/speech_recognition_pipeline_tutorial.py

Lines changed: 47 additions & 53 deletions
@@ -1,6 +1,6 @@
 """
-Speech Recognition with Torchaudio
-==================================
+Speech Recognition with Wav2Vec2
+================================
 
 **Author**: `Moto Hira <moto@fb.com>`__
 
@@ -39,34 +39,31 @@
 
 # %matplotlib inline
 
+import os
+
 import torch
 import torchaudio
-
-print(torch.__version__)
-print(torchaudio.__version__)
-
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-
-print(device)
-
-
+import requests
 import matplotlib
 import matplotlib.pyplot as plt
-
-[width, height] = matplotlib.rcParams['figure.figsize']
-if width < 10:
-    matplotlib.rcParams['figure.figsize'] = [width * 2.5, height]
-
 import IPython
 
-import requests
+matplotlib.rcParams['figure.figsize'] = [16.0, 4.8]
+
+torch.random.manual_seed(0)
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
-SPEECH_FILE = "speech.wav"
+print(torch.__version__)
+print(torchaudio.__version__)
+print(device)
 
-url = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/source-16k/train/sp0307/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav"
-with open(SPEECH_FILE, 'wb') as file_:
-    file_.write(requests.get(url).content)
+SPEECH_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/source-16k/train/sp0307/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav"
+SPEECH_FILE = "_assets/speech.wav"
 
+if not os.path.exists(SPEECH_FILE):
+    os.makedirs('_assets', exist_ok=True)
+    with open(SPEECH_FILE, 'wb') as file:
+        file.write(requests.get(SPEECH_URL).content)
 
 
 ######################################################################
@@ -88,11 +85,10 @@
 # for other downstream tasks as well, but this tutorial does not
 # cover that.
 #
-# We will use ``torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H`` here.
+# We will use :py:func:`torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H` here.
 #
-# There are multiple models available in
-# ``torchaudio.pipelines``. Please check the
-# `documentation <https://pytorch.org/audio/stable/pipelines.html>`__ for
+# There are multiple models available as
+# :py:mod:`torchaudio.pipelines`. Please check the documentation for
 # the detail of how they are trained.
 #
 # The bundle object provides the interface to instantiate model and other
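
For context, the bundle interface referenced in this hunk works roughly as follows. This is a minimal sketch, not part of the commit; `get_model`, `sample_rate`, and `get_labels` are the documented torchaudio pipeline accessors:

bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H

print("Sample rate:", bundle.sample_rate)  # sampling rate the model expects
print("Labels:", bundle.get_labels())      # token set the model emits

model = bundle.get_model().to(device)      # fetch and instantiate the pretrained model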
@@ -125,21 +121,20 @@
 # Creative Commons BY 4.0.
 #
 
-IPython.display.display(IPython.display.Audio(SPEECH_FILE))
+IPython.display.Audio(SPEECH_FILE)
 
 
 ######################################################################
-# To load data, we use ``torchaudio.load``.
+# To load data, we use :py:func:`torchaudio.load`.
 #
 # If the sampling rate is different from what the pipeline expects, then
-# we can use ``torchaudio.functional.resample`` for resampling.
+# we can use :py:func:`torchaudio.functional.resample` for resampling.
 #
-# **Note** -
-# ```torchaudio.functional.resample`` <https://pytorch.org/audio/stable/functional.html#resample>`__
-# works on CUDA tensors as well. - When performing resampling multiple
-# times on the same set of sample rates, using
-# ```torchaudio.transforms.Resample`` <https://pytorch.org/audio/stable/transforms.html#resample>`__
-# might improve the performace.
+# .. note::
+#
+#    - :py:func:`torchaudio.functional.resample` works on CUDA tensors as well.
+#    - When performing resampling multiple times on the same set of sample rates,
+#      using :py:func:`torchaudio.transforms.Resample` might improve the performance.
 #
 
 waveform, sample_rate = torchaudio.load(SPEECH_FILE)
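
A minimal sketch of the two resampling paths the note above describes, assuming `waveform`, `sample_rate`, and `bundle` from the surrounding tutorial (not part of this commit):

# One-off resampling with the functional API:
if sample_rate != bundle.sample_rate:
    waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)

# Alternatively, when the same rate pair recurs, build the transform once
# and reuse it, which lets it cache the resampling kernel:
resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=bundle.sample_rate)
waveform = resampler(waveform)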
@@ -155,9 +150,10 @@
 #
 # The next step is to extract acoustic features from the audio.
 #
-# Note that Wav2Vec2 models fine-tuned for ASR task can perform feature
-# extraction and classification with one step, but for the sake of the
-# tutorial, we also show how to perform feature extraction here.
+# .. note::
+#    Wav2Vec2 models fine-tuned for ASR task can perform feature
+#    extraction and classification with one step, but for the sake of the
+#    tutorial, we also show how to perform feature extraction here.
 #
 
 with torch.inference_mode():
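
The body of this block is unchanged and elided by the diff; it presumably calls the torchaudio feature extractor. A self-contained sketch of that step:

with torch.inference_mode():
    features, _ = model.extract_features(waveform)
# `features` is a list of tensors, one per transformer layer,
# each shaped (batch, frames, feature_dim).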
@@ -169,12 +165,14 @@
 # a transformer layer.
 #
 
+fig, ax = plt.subplots(len(features), 1, figsize=(16, 4.3 * len(features)))
 for i, feats in enumerate(features):
-    plt.imshow(feats[0].cpu())
-    plt.title(f"Feature from transformer layer {i+1}")
-    plt.xlabel("Feature dimension")
-    plt.ylabel("Frame (time-axis)")
-    plt.show()
+    ax[i].imshow(feats[0].cpu())
+    ax[i].set_title(f"Feature from transformer layer {i+1}")
+    ax[i].set_xlabel("Feature dimension")
+    ax[i].set_ylabel("Frame (time-axis)")
+plt.tight_layout()
+plt.show()
 
 
 ######################################################################
@@ -203,7 +201,6 @@
 plt.title("Classification result")
 plt.xlabel("Frame (time-axis)")
 plt.ylabel("Class")
-plt.colorbar()
 plt.show()
 print("Class labels:", bundle.get_labels())
 
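For reference, the `emission` plotted here comes from an ordinary forward pass through the model's classification head. A sketch assuming the tutorial's `model` and `waveform` (not part of this commit):

with torch.inference_mode():
    emission, _ = model(waveform)  # per-frame logits over the label set
plt.imshow(emission[0].cpu().T)    # the "Classification result" plot above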
@@ -218,7 +215,6 @@
 # not used during the training.
 #
 
-
 ######################################################################
 # Generating transcripts
 # ----------------------
@@ -241,7 +237,7 @@
 # There are many decoding techniques proposed, and they require external
 # resources, such as word dictionary and language models.
 #
-# In this tutorial, for the sake of simplicity, we will perform greeding
+# In this tutorial, for the sake of simplicity, we will perform greedy
 # decoding which does not depend on such external components, and simply
 # pick up the best hypothesis at each time step. Therefore, the context
 # information are not used, and only one transcript can be generated.
@@ -259,6 +255,7 @@ def forward(self, emission: torch.Tensor) -> str:
         """Given a sequence emission over labels, get the best path string
         Args:
             emission (Tensor): Logit tensors. Shape `[num_seq, num_label]`.
+
         Returns:
             str: The resulting transcript
         """
@@ -273,24 +270,22 @@ def forward(self, emission: torch.Tensor) -> str:
 #
 
 decoder = GreedyCTCDecoder(
-  labels=bundle.get_labels(),
+    labels=bundle.get_labels(),
     ignore=(0, 1, 2, 3),
 )
 transcript = decoder(emission[0])
 
 
 ######################################################################
-# Let’s check the result and listen again the audio.
+# Let’s check the result and listen again to the audio.
 #
 
 print(transcript)
-IPython.display.display(IPython.display.Audio(SPEECH_FILE))
+IPython.display.Audio(SPEECH_FILE)
 
 
 ######################################################################
-# There are few remarks in decoding.
-#
-# Firstly, the ASR model is fine-tuned using a loss function called CTC.
+# The ASR model is fine-tuned using a loss function called Connectionist Temporal Classification (CTC).
 # The detail of CTC loss is explained
 # `here <https://distill.pub/2017/ctc/>`__. In CTC a blank token (ϵ) is a
 # special token which represents a repetition of the previous symbol. In
@@ -301,12 +296,11 @@ def forward(self, emission: torch.Tensor) -> str:
 # These also have to be ignored.
 #
 
-
 ######################################################################
 # Conclusion
 # ----------
 #
-# In this tutorial, we looked at how to use ``torchaudio.pipeline`` to
+# In this tutorial, we looked at how to use :py:mod:`torchaudio.pipelines` to
 # perform acoustic feature extraction and speech recognition. Constructing
 # a model and getting the emission is as short as two lines.
 #
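
The two lines the conclusion refers to, spelled out as a sketch with the names used throughout the tutorial:

model = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H.get_model()
emission, _ = model(waveform)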
