@@ -1,6 +1,6 @@
"""
- Speech Recognition with Torchaudio
- ==================================
+ Speech Recognition with Wav2Vec2
+ ================================

**Author**: `Moto Hira <moto@fb.com>`__
@@ -39,34 +39,31 @@

# %matplotlib inline

+ import os
+
import torch
import torchaudio
-
- print(torch.__version__)
- print(torchaudio.__version__)
-
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-
- print(device)
-
-
+ import requests
import matplotlib
import matplotlib.pyplot as plt
-
- [width, height] = matplotlib.rcParams['figure.figsize']
- if width < 10:
-     matplotlib.rcParams['figure.figsize'] = [width * 2.5, height]
-
import IPython

- import requests
+ matplotlib.rcParams['figure.figsize'] = [16.0, 4.8]
+
+ torch.random.manual_seed(0)
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

- SPEECH_FILE = "speech.wav"
+ print(torch.__version__)
+ print(torchaudio.__version__)
+ print(device)

- url = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/source-16k/train/sp0307/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav"
- with open(SPEECH_FILE, 'wb') as file_:
-     file_.write(requests.get(url).content)
+ SPEECH_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/source-16k/train/sp0307/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav"
+ SPEECH_FILE = "_assets/speech.wav"

+ if not os.path.exists(SPEECH_FILE):
+     os.makedirs('_assets', exist_ok=True)
+     with open(SPEECH_FILE, 'wb') as file:
+         file.write(requests.get(SPEECH_URL).content)


######################################################################
@@ -88,11 +85,10 @@
# for other downstream tasks as well, but this tutorial does not
# cover that.
#
- # We will use ``torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H`` here.
+ # We will use :py:func:`torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H` here.
#
- # There are multiple models available in
- # ``torchaudio.pipelines``. Please check the
- # `documentation <https://pytorch.org/audio/stable/pipelines.html>`__ for
+ # There are multiple models available as
+ # :py:mod:`torchaudio.pipelines`. Please check the documentation for
# the detail of how they are trained.
#
# The bundle object provides the interface to instantiate model and other
@@ -125,21 +121,20 @@
# Creative Commons BY 4.0.
#

- IPython.display.display(IPython.display.Audio(SPEECH_FILE))
+ IPython.display.Audio(SPEECH_FILE)


######################################################################
- # To load data, we use ``torchaudio.load``.
+ # To load data, we use :py:func:`torchaudio.load`.
#
# If the sampling rate is different from what the pipeline expects, then
- # we can use ``torchaudio.functional.resample`` for resampling.
+ # we can use :py:func:`torchaudio.functional.resample` for resampling.
#
- # **Note** -
- # ```torchaudio.functional.resample`` <https://pytorch.org/audio/stable/functional.html#resample>`__
- # works on CUDA tensors as well. - When performing resampling multiple
- # times on the same set of sample rates, using
- # ```torchaudio.transforms.Resample`` <https://pytorch.org/audio/stable/transforms.html#resample>`__
- # might improve the performance.
+ # .. note::
+ #
+ #    - :py:func:`torchaudio.functional.resample` works on CUDA tensors as well.
+ #    - When performing resampling multiple times on the same set of sample rates,
+ #      using :py:func:`torchaudio.transforms.Resample` might improve the performance.
#

waveform, sample_rate = torchaudio.load(SPEECH_FILE)
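To make the note above concrete, here is a minimal resampling sketch; it assumes the bundle used in this tutorial (``torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H``) and the downloaded ``_assets/speech.wav``:

    import torchaudio

    bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
    waveform, sample_rate = torchaudio.load("_assets/speech.wav")

    if sample_rate != bundle.sample_rate:
        # One-off conversion; the functional API also accepts CUDA tensors.
        waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)

    # When converting between the same pair of rates repeatedly, a cached
    # transform reuses the resampling kernel and can be faster:
    resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=bundle.sample_rate)
    waveform = resampler(waveform)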
@@ -155,9 +150,10 @@
#
# The next step is to extract acoustic features from the audio.
#
- # Note that Wav2Vec2 models fine-tuned for ASR task can perform feature
- # extraction and classification with one step, but for the sake of the
- # tutorial, we also show how to perform feature extraction here.
+ # .. note::
+ #    Wav2Vec2 models fine-tuned for ASR task can perform feature
+ #    extraction and classification with one step, but for the sake of the
+ #    tutorial, we also show how to perform feature extraction here.
#

with torch.inference_mode():
@@ -169,12 +165,14 @@
# a transformer layer.
#

+ fig, ax = plt.subplots(len(features), 1, figsize=(16, 4.3 * len(features)))
for i, feats in enumerate(features):
-     plt.imshow(feats[0].cpu())
-     plt.title(f"Feature from transformer layer {i+1}")
-     plt.xlabel("Feature dimension")
-     plt.ylabel("Frame (time-axis)")
-     plt.show()
+     ax[i].imshow(feats[0].cpu())
+     ax[i].set_title(f"Feature from transformer layer {i+1}")
+     ax[i].set_xlabel("Feature dimension")
+     ax[i].set_ylabel("Frame (time-axis)")
+ plt.tight_layout()
+ plt.show()


######################################################################
@@ -203,7 +201,6 @@
plt.title("Classification result")
plt.xlabel("Frame (time-axis)")
plt.ylabel("Class")
- plt.colorbar()
plt.show()
print("Class labels:", bundle.get_labels())
@@ -218,7 +215,6 @@
# not used during the training.
#

-
######################################################################
# Generating transcripts
# ----------------------
@@ -241,7 +237,7 @@
# There are many decoding techniques proposed, and they require external
# resources, such as word dictionary and language models.
#
- # In this tutorial, for the sake of simplicity, we will perform greeding
+ # In this tutorial, for the sake of simplicity, we will perform greedy
# decoding which does not depend on such external components, and simply
# pick up the best hypothesis at each time step. Therefore, the context
# information is not used, and only one transcript can be generated.
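Conceptually, the greedy decoding described here is just an argmax per frame, a collapse of consecutive repeats, and removal of the blank/special tokens. A minimal sketch, independent of the tutorial's ``GreedyCTCDecoder`` class (treating index 0 as the blank token is an assumption here):

    import torch

    def greedy_decode(emission: torch.Tensor, labels, blank: int = 0) -> str:
        # emission: logits of shape [num_frames, num_labels]
        indices = torch.argmax(emission, dim=-1)        # best label per frame
        indices = torch.unique_consecutive(indices)     # collapse repeats
        indices = [i for i in indices.tolist() if i != blank]  # drop blank tokens
        return "".join(labels[i] for i in indices)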
@@ -259,6 +255,7 @@ def forward(self, emission: torch.Tensor) -> str:
        """Given a sequence emission over labels, get the best path string
        Args:
            emission (Tensor): Logit tensors. Shape `[num_seq, num_label]`.
+
        Returns:
            str: The resulting transcript
        """
@@ -273,24 +270,22 @@ def forward(self, emission: torch.Tensor) -> str:
#

decoder = GreedyCTCDecoder(
-     labels=bundle.get_labels(),
+     labels=bundle.get_labels(),
    ignore=(0, 1, 2, 3),
)
transcript = decoder(emission[0])


######################################################################
- # Let’s check the result and listen again the audio.
+ # Let’s check the result and listen again to the audio.
#

print(transcript)
- IPython.display.display(IPython.display.Audio(SPEECH_FILE))
+ IPython.display.Audio(SPEECH_FILE)


######################################################################
- # There are few remarks in decoding.
- #
- # Firstly, the ASR model is fine-tuned using a loss function called CTC.
+ # The ASR model is fine-tuned using a loss function called Connectionist Temporal Classification (CTC).
# The detail of CTC loss is explained
# `here <https://distill.pub/2017/ctc/>`__. In CTC a blank token (ϵ) is a
# special token which represents a repetition of the previous symbol. In
@@ -301,12 +296,11 @@ def forward(self, emission: torch.Tensor) -> str:
# These also have to be ignored.
#

-
######################################################################
# Conclusion
# ----------
#
- # In this tutorial, we looked at how to use ``torchaudio.pipeline`` to
+ # In this tutorial, we looked at how to use :py:mod:`torchaudio.pipelines` to
# perform acoustic feature extraction and speech recognition. Constructing
# a model and getting the emission is as short as two lines.
#
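As the conclusion notes, the model-plus-emission part is only a couple of lines. A minimal end-to-end sketch, assuming the same bundle and speech file as above:

    import torch
    import torchaudio

    bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
    model = bundle.get_model()

    waveform, sample_rate = torchaudio.load("_assets/speech.wav")
    waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)

    with torch.inference_mode():
        emission, _ = model(waveform)  # frame-wise logits over bundle.get_labels()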