Skip to content

Commit 5fdf751

Browse files
authored
Merge pull request #687 from vincentqb/reinstate-tutorial-audio
Reinstate torchaudio tutorial
2 parents 514a513 + 0b5aa0d commit 5fdf751

File tree

2 files changed

+258
-0
lines changed

2 files changed

+258
-0
lines changed
Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,237 @@
1+
"""
2+
torchaudio Tutorial
3+
===================
4+
5+
PyTorch is an open source deep learning platform that provides a
6+
seamless path from research prototyping to production deployment with
7+
GPU support.
8+
9+
Significant effort in solving machine learning problems goes into data
10+
preparation. ``torchaudio`` leverages PyTorch’s GPU support, and provides
11+
many tools to make data loading easy and more readable. In this
12+
tutorial, we will see how to load and preprocess data from a simple
13+
dataset.
14+
15+
For this tutorial, please make sure the ``matplotlib`` package is
16+
installed for easier visualization.
17+
18+
"""
19+
20+
import torch
21+
import torchaudio
22+
import matplotlib.pyplot as plt
23+
24+
25+
######################################################################
26+
# Opening a dataset
27+
# -----------------
28+
#
29+
30+
31+
######################################################################
32+
# torchaudio supports loading sound files in the wav and mp3 formats. We
33+
# call waveform the resulting raw audio signal.
34+
#
35+
36+
# Location of the sample recording used throughout this tutorial.
filename = "../_static/img/steam-train-whistle-daniel_simon-converted-from-mp3.wav"
# ``torchaudio.load`` hands back the audio tensor together with its sample rate.
waveform, sample_rate = torchaudio.load(filename)

print("Shape of waveform: {}".format(waveform.shape))
print("Sample rate of waveform: {}".format(sample_rate))

# The waveform is transposed before plotting because matplotlib draws one
# line per column of a 2-D input.
plt.figure()
plt.plot(waveform.t().numpy())
44+
45+
46+
######################################################################
47+
# Transformations
48+
# ---------------
49+
#
50+
# torchaudio supports a growing list of
51+
# `transformations <https://pytorch.org/audio/transforms.html>`_.
52+
#
53+
# - **Resample**: Resample waveform to a different sample rate.
54+
# - **Spectrogram**: Create a spectrogram from a waveform.
55+
# - **MelScale**: This turns a normal STFT into a Mel-frequency STFT,
56+
# using a conversion matrix.
57+
# - **AmplitudeToDB**: This turns a spectrogram from the
58+
# power/amplitude scale to the decibel scale.
59+
# - **MFCC**: Create the Mel-frequency cepstrum coefficients from a
60+
# waveform.
61+
# - **MelSpectrogram**: Create MEL Spectrograms from a waveform using the
62+
# STFT function in PyTorch.
63+
# - **MuLawEncoding**: Encode waveform based on mu-law companding.
64+
# - **MuLawDecoding**: Decode mu-law encoded waveform.
65+
#
66+
# Since all transforms are nn.Modules or jit.ScriptModules, they can be
67+
# used as part of a neural network at any point.
68+
#
69+
70+
71+
######################################################################
72+
# To start, we can look at the spectrogram on a log scale.
73+
#
74+
75+
# Build the transform with its default settings and apply it right away.
specgram = torchaudio.transforms.Spectrogram()(waveform)

print("Shape of spectrogram: {}".format(specgram.shape))

# Show the base-2 log of the first entry along the leading dimension.
plt.figure()
plt.imshow(specgram.log2()[0].numpy(), cmap="gray")
81+
82+
83+
######################################################################
84+
# Or we can look at the Mel Spectrogram on a log scale.
85+
#
86+
87+
# Same display as for the plain spectrogram, this time with MelSpectrogram.
specgram = torchaudio.transforms.MelSpectrogram()(waveform)

print("Shape of spectrogram: {}".format(specgram.shape))

plt.figure()
# ``detach()`` is called before ``numpy()``, which is required when the
# tensor is tracked by autograd.
p = plt.imshow(specgram.log2()[0].detach().numpy(), cmap="gray")
93+
94+
95+
######################################################################
96+
# We can resample the waveform, one channel at a time.
97+
#
98+
99+
# Target rate: one tenth of the original sample rate.
new_sample_rate = sample_rate / 10

# Since Resample applies to a single channel, we resample first channel here.
# The selected channel is reshaped to (1, n) before it is passed on.
channel = 0
transformed = torchaudio.transforms.Resample(sample_rate, new_sample_rate)(
    waveform[channel].view(1, -1)
)

print("Shape of transformed waveform: {}".format(transformed.shape))

plt.figure()
plt.plot(transformed[0].numpy())
109+
110+
111+
######################################################################
112+
# As another example of transformations, we can encode the signal based on
113+
# Mu-Law encoding. But to do so, we need the signal to be between -1 and
114+
# 1. Since the tensor is just a regular PyTorch tensor, we can apply
115+
# standard operators on it.
116+
#
117+
118+
# Print the extrema and the mean to see whether the samples already lie in
# the interval [-1,1].
print(
    "Min of waveform: {}\n"
    "Max of waveform: {}\n"
    "Mean of waveform: {}".format(waveform.min(), waveform.max(), waveform.mean())
)
120+
121+
122+
######################################################################
123+
# Since the waveform is already between -1 and 1, we do not need to
124+
# normalize it.
125+
#
126+
127+
def normalize(tensor):
    """Center ``tensor`` on zero and scale it into the interval [-1, 1].

    The mean over all elements is subtracted, then the result is divided by
    its largest absolute value.

    Args:
        tensor: a ``torch.Tensor``; the code relies on its ``mean()``,
            ``abs()`` and ``max()`` methods.

    Returns:
        A tensor of the same shape whose values lie in [-1, 1]. A constant
        input is returned as all zeros rather than NaN.
    """
    # Subtract the mean, and scale to the interval [-1,1]
    tensor_minusmean = tensor - tensor.mean()
    peak = tensor_minusmean.abs().max()
    # A constant signal has a zero peak once centered; dividing by it would
    # turn every element into NaN, so hand back the centered zeros instead.
    if peak == 0:
        return tensor_minusmean
    return tensor_minusmean / peak
131+
132+
# Let's normalize to the full interval [-1,1]
133+
# waveform = normalize(waveform)
134+
135+
136+
######################################################################
137+
# Let’s encode the waveform.
138+
#
139+
140+
# Encode the waveform with MuLawEncoding at its default settings.
transformed = torchaudio.transforms.MuLawEncoding()(waveform)

print("Shape of transformed waveform: {}".format(transformed.shape))

# Plot the first row of the encoded signal.
plt.figure()
plt.plot(transformed[0].numpy())
146+
147+
148+
######################################################################
149+
# And now decode.
150+
#
151+
152+
# Decode the mu-law encoded signal with MuLawDecoding at its default settings.
reconstructed = torchaudio.transforms.MuLawDecoding()(transformed)

print("Shape of recovered waveform: {}".format(reconstructed.shape))

# Plot the first row of the decoded signal.
plt.figure()
plt.plot(reconstructed[0].numpy())
158+
159+
160+
######################################################################
161+
# We can finally compare the original waveform with its reconstructed
162+
# version.
163+
#
164+
165+
# Compute median relative difference: the element-wise absolute error divided
# by the magnitude of the original sample, reduced with the median.
err = ((waveform - reconstructed).abs() / waveform.abs()).median()

print("Median relative difference between original and MuLaw reconstructed signals: {:.2%}".format(err))
169+
170+
171+
######################################################################
172+
# Migrating to torchaudio from Kaldi
173+
# ----------------------------------
174+
#
175+
# Users may be familiar with
176+
# `Kaldi <http://github.com/kaldi-asr/kaldi>`_, a toolkit for speech
177+
# recognition. torchaudio offers compatibility with it in
178+
# ``torchaudio.kaldi_io``. It can indeed read from Kaldi scp or ark files
179+
# or streams with:
180+
#
181+
# - read_vec_int_ark
182+
# - read_vec_flt_scp
183+
# - read_vec_flt_arkfile/stream
184+
# - read_mat_scp
185+
# - read_mat_ark
186+
#
187+
# torchaudio provides Kaldi-compatible transforms for ``spectrogram`` and
188+
# ``fbank`` with the benefit of GPU support, see
189+
# `here <compliance.kaldi.html>`__ for more information.
190+
#
191+
192+
# Window size in samples, converted to milliseconds (assuming ``sample_rate``
# is in Hz); the shift between frames is half a window.
n_fft = 400.0
frame_length = n_fft / sample_rate * 1000.0
frame_shift = frame_length / 2.0

# Keyword arguments passed to the Kaldi-compatible feature functions.
params = dict(
    channel=0,
    dither=0.0,
    window_type="hanning",
    frame_length=frame_length,
    frame_shift=frame_shift,
    remove_dc_offset=False,
    round_to_power_of_two=False,
    sample_frequency=sample_rate,
)

specgram = torchaudio.compliance.kaldi.spectrogram(waveform, **params)

print("Shape of spectrogram: {}".format(specgram.shape))

# The result is transposed before it is displayed.
plt.figure()
plt.imshow(specgram.t().numpy(), cmap="gray")
213+
214+
215+
######################################################################
216+
# We also support computing the filterbank features from waveforms,
217+
# matching Kaldi’s implementation.
218+
#
219+
220+
# Filterbank features computed with the same keyword arguments in ``params``.
fbank = torchaudio.compliance.kaldi.fbank(waveform, **params)

print("Shape of fbank: {}".format(fbank.shape))

# The result is transposed before it is displayed.
plt.figure()
plt.imshow(fbank.t().numpy(), cmap="gray")
226+
227+
228+
######################################################################
229+
# Conclusion
230+
# ----------
231+
#
232+
# We used an example raw audio signal, or waveform, to illustrate how to
233+
# open an audio file using torchaudio, and how to pre-process and
234+
# transform such waveform. Given that torchaudio is built on PyTorch,
235+
# these techniques can be used as building blocks for more advanced audio
236+
# applications, such as speech recognition, while leveraging GPUs.
237+
#

index.rst

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,19 @@ Named Tensor (experimental)
9797

9898
<div style='clear:both'></div>
9999

100+
Audio
101+
----------------------
102+
103+
.. customgalleryitem::
104+
:figure: /_static/img/audio_preprocessing_tutorial_waveform.png
105+
:tooltip: Preprocessing with torchaudio Tutorial
106+
:description: :doc:`beginner/audio_preprocessing_tutorial`
107+
108+
.. raw:: html
109+
110+
<div style='clear:both'></div>
111+
112+
100113
Text
101114
----------------------
102115

@@ -295,6 +308,14 @@ PyTorch Fundamentals In-Depth
295308
beginner/fgsm_tutorial
296309
beginner/dcgan_faces_tutorial
297310

311+
.. toctree::
312+
:maxdepth: 2
313+
:includehidden:
314+
:hidden:
315+
:caption: Audio
316+
317+
beginner/audio_preprocessing_tutorial
318+
298319
.. toctree::
299320
:maxdepth: 2
300321
:includehidden:

0 commit comments

Comments
 (0)