
Update speech recognition tutorial #1741

Merged
3 commits merged on Nov 9, 2021
Changes from all commits
100 changes: 47 additions & 53 deletions intermediate_source/speech_recognition_pipeline_tutorial.py
@@ -1,6 +1,6 @@
"""
-Speech Recognition with Torchaudio
-==================================
+Speech Recognition with Wav2Vec2
+================================

**Author**: `Moto Hira <moto@fb.com>`__

@@ -39,34 +39,31 @@

# %matplotlib inline

+import os
+
import torch
import torchaudio

-print(torch.__version__)
-print(torchaudio.__version__)
-
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-
-print(device)
-
-
+import requests
import matplotlib
import matplotlib.pyplot as plt
-
-[width, height] = matplotlib.rcParams['figure.figsize']
-if width < 10:
-    matplotlib.rcParams['figure.figsize'] = [width * 2.5, height]
-
import IPython

-import requests
+matplotlib.rcParams['figure.figsize'] = [16.0, 4.8]

torch.random.manual_seed(0)
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

-SPEECH_FILE = "speech.wav"
+print(torch.__version__)
+print(torchaudio.__version__)
+print(device)

-url = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/source-16k/train/sp0307/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav"
-with open(SPEECH_FILE, 'wb') as file_:
-    file_.write(requests.get(url).content)
+SPEECH_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/source-16k/train/sp0307/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav"
+SPEECH_FILE = "_assets/speech.wav"
+
+if not os.path.exists(SPEECH_FILE):
+    os.makedirs('_assets', exist_ok=True)
+    with open(SPEECH_FILE, 'wb') as file:
+        file.write(requests.get(SPEECH_URL).content)


######################################################################
@@ -88,11 +85,10 @@
# for other downstream tasks as well, but this tutorial does not
# cover that.
#
-# We will use ``torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H`` here.
+# We will use :py:func:`torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H` here.
#
-# There are multiple models available in
-# ``torchaudio.pipelines``. Please check the
-# `documentation <https://pytorch.org/audio/stable/pipelines.html>`__ for
+# There are multiple models available as
+# :py:mod:`torchaudio.pipelines`. Please check the documentation for
# the details of how they are trained.
#
# The bundle object provides the interface to instantiate the model and other
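For reference while reading the diff, here is a minimal sketch of the bundle interface this paragraph describes, using only documented ``torchaudio.pipelines`` attributes (``sample_rate``, ``get_labels``, ``get_model``):

    import torchaudio

    # The bundle carries pretrained weights plus metadata such as the
    # sample rate the model expects and the label set it was fine-tuned on.
    bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H

    print(bundle.sample_rate)   # expected input sample rate (16000 for this bundle)
    print(bundle.get_labels())  # output characters, including special tokens

    model = bundle.get_model()  # instantiates the model and fetches the weights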
@@ -125,21 +121,20 @@
# Creative Commons BY 4.0.
#

-IPython.display.display(IPython.display.Audio(SPEECH_FILE))
+IPython.display.Audio(SPEECH_FILE)


######################################################################
-# To load data, we use ``torchaudio.load``.
+# To load data, we use :py:func:`torchaudio.load`.
#
# If the sampling rate is different from what the pipeline expects, then
-# we can use ``torchaudio.functional.resample`` for resampling.
+# we can use :py:func:`torchaudio.functional.resample` for resampling.
#
-# **Note** -
-# ```torchaudio.functional.resample`` <https://pytorch.org/audio/stable/functional.html#resample>`__
-# works on CUDA tensors as well. - When performing resampling multiple
-# times on the same set of sample rates, using
-# ```torchaudio.transforms.Resample`` <https://pytorch.org/audio/stable/transforms.html#resample>`__
-# might improve the performace.
+# .. note::
+#
+#    - :py:func:`torchaudio.functional.resample` works on CUDA tensors as well.
+#    - When performing resampling multiple times on the same set of sample rates,
+#      using :py:func:`torchaudio.transforms.Resample` might improve the performance.
#

waveform, sample_rate = torchaudio.load(SPEECH_FILE)
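As a sketch of the resampling branch the note above describes (assuming ``waveform`` and ``sample_rate`` come from the ``torchaudio.load`` call, and ``bundle`` is the pipeline bundle):

    import torchaudio.functional as F

    # Resample only when the file's rate differs from what the model expects.
    if sample_rate != bundle.sample_rate:
        waveform = F.resample(waveform, sample_rate, bundle.sample_rate)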
@@ -155,9 +150,10 @@
#
# The next step is to extract acoustic features from the audio.
#
-# Note that Wav2Vec2 models fine-tuned for ASR task can perform feature
-# extraction and classification with one step, but for the sake of the
-# tutorial, we also show how to perform feature extraction here.
+# .. note::
+#    Wav2Vec2 models fine-tuned for the ASR task can perform feature
+#    extraction and classification in one step, but for the sake of the
+#    tutorial, we also show how to perform feature extraction here.
#

with torch.inference_mode():
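The body of this ``torch.inference_mode()`` block is folded in the diff; a sketch of the feature-extraction call it wraps, per the ``Wav2Vec2Model.extract_features`` API:

    with torch.inference_mode():
        # Returns one tensor per transformer layer, each shaped
        # [batch, time-frame, feature-dimension].
        features, _ = model.extract_features(waveform)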
@@ -169,12 +165,14 @@
# a transformer layer.
#

+fig, ax = plt.subplots(len(features), 1, figsize=(16, 4.3 * len(features)))
for i, feats in enumerate(features):
-    plt.imshow(feats[0].cpu())
-    plt.title(f"Feature from transformer layer {i+1}")
-    plt.xlabel("Feature dimension")
-    plt.ylabel("Frame (time-axis)")
-    plt.show()
+    ax[i].imshow(feats[0].cpu())
+    ax[i].set_title(f"Feature from transformer layer {i+1}")
+    ax[i].set_xlabel("Feature dimension")
+    ax[i].set_ylabel("Frame (time-axis)")
+plt.tight_layout()
+plt.show()


######################################################################
@@ -203,7 +201,6 @@
plt.title("Classification result")
plt.xlabel("Frame (time-axis)")
plt.ylabel("Class")
-plt.colorbar()
plt.show()
print("Class labels:", bundle.get_labels())

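The step that produces ``emission`` is folded above. A minimal sketch of it: calling the model directly performs feature extraction and classification in one pass and returns per-frame logits over the label set, shaped ``[batch, frame, num_labels]``:

    with torch.inference_mode():
        emission, _ = model(waveform)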
@@ -218,7 +215,6 @@
# not used during the training.
#

-
######################################################################
# Generating transcripts
# ----------------------
@@ -241,7 +237,7 @@
# There are many decoding techniques proposed, and they require external
# resources, such as word dictionaries and language models.
#
-# In this tutorial, for the sake of simplicity, we will perform greeding
+# In this tutorial, for the sake of simplicity, we will perform greedy
# decoding which does not depend on such external components, and simply
# pick up the best hypothesis at each time step. Therefore, the context
# information is not used, and only one transcript can be generated.
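Most of the ``GreedyCTCDecoder`` definition is folded in this diff. A self-contained sketch of such a decoder, consistent with the ``labels``/``ignore`` arguments used further down:

    import torch

    class GreedyCTCDecoder(torch.nn.Module):
        def __init__(self, labels, ignore):
            super().__init__()
            self.labels = labels
            self.ignore = ignore

        def forward(self, emission: torch.Tensor) -> str:
            indices = torch.argmax(emission, dim=-1)             # best label per frame
            indices = torch.unique_consecutive(indices, dim=-1)  # collapse repeats (CTC)
            indices = [i for i in indices if i not in self.ignore]  # drop blank/special tokens
            return ''.join([self.labels[i] for i in indices])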
@@ -259,6 +255,7 @@ def forward(self, emission: torch.Tensor) -> str:
"""Given a sequence emission over labels, get the best path string
Args:
emission (Tensor): Logit tensors. Shape `[num_seq, num_label]`.

Returns:
str: The resulting transcript
"""
@@ -273,24 +270,22 @@ def forward(self, emission: torch.Tensor) -> str:
#

decoder = GreedyCTCDecoder(
-    labels=bundle.get_labels(),
+    labels=bundle.get_labels(),
+    ignore=(0, 1, 2, 3),
)
transcript = decoder(emission[0])


######################################################################
-# Let’s check the result and listen again the audio.
+# Let’s check the result and listen again to the audio.
#

print(transcript)
-IPython.display.display(IPython.display.Audio(SPEECH_FILE))
+IPython.display.Audio(SPEECH_FILE)


######################################################################
-# There are few remarks in decoding.
-#
-# Firstly, the ASR model is fine-tuned using a loss function called CTC.
+# The ASR model is fine-tuned using a loss function called Connectionist Temporal Classification (CTC).
# The detail of CTC loss is explained
# `here <https://distill.pub/2017/ctc/>`__. In CTC a blank token (ϵ) is a
# special token which represents a repetition of the previous symbol. In
@@ -301,12 +296,11 @@ def forward(self, emission: torch.Tensor) -> str:
# These also have to be ignored.
#

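As a tiny worked example of how greedy CTC decoding collapses repeats and drops blank tokens, with a hypothetical five-symbol label set (not the model's real one):

    import torch

    labels = ('ϵ', 'H', 'E', 'L', 'O')                  # index 0 is the blank token
    frames = torch.tensor([1, 1, 0, 2, 3, 0, 3, 3, 4])  # H H ϵ E L ϵ L L O

    collapsed = torch.unique_consecutive(frames)         # H ϵ E L ϵ L O
    chars = [labels[i] for i in collapsed if i != 0]     # drop the blanks
    print(''.join(chars))                                # -> HELLO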
-
######################################################################
# Conclusion
# ----------
#
-# In this tutorial, we looked at how to use ``torchaudio.pipeline`` to
+# In this tutorial, we looked at how to use :py:mod:`torchaudio.pipelines` to
# perform acoustic feature extraction and speech recognition. Constructing
# a model and getting the emission is as short as two lines.
#
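Concretely, the two lines the conclusion refers to are, in sketch form (assuming ``waveform`` has already been loaded and resampled):

    model = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H.get_model()
    emission, _ = model(waveform)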