
Commit 96f2247

Update forced alignment tutorial (#1740)
1 parent 254d927 commit 96f2247

1 file changed: +142 -92 lines changed

intermediate_source/forced_alignment_with_torchaudio_tutorial.py

Lines changed: 142 additions & 92 deletions
@@ -32,61 +32,62 @@
 # Preparation
 # -----------
 #
-# First, we install ``torchaudio``, import packages and fetch the speech
-# file.
+# First we import the necessary packages and fetch the data that we work on.
 #
 
-# !pip install torchaudio
-
 # %matplotlib inline
 
 import os
 from dataclasses import dataclass
 
+import torch
+import torchaudio
 import requests
 import matplotlib
 import matplotlib.pyplot as plt
+import IPython
 
-[width, height] = matplotlib.rcParams['figure.figsize']
-if width < 10:
-    matplotlib.rcParams['figure.figsize'] = [width * 2.5, height]
+matplotlib.rcParams['figure.figsize'] = [16.0, 4.8]
 
-import torch
-import torchaudio
+torch.random.manual_seed(0)
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
-SPEECH_URL = 'https://download.pytorch.org/torchaudio/test-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.flac'
-SPEECH_FILE = 'speech.flac'
+print(torch.__version__)
+print(torchaudio.__version__)
+print(device)
+
+SPEECH_URL = 'https://download.pytorch.org/torchaudio/tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav'
+SPEECH_FILE = '_assets/speech.wav'
 
 if not os.path.exists(SPEECH_FILE):
+    os.makedirs('_assets', exist_ok=True)
     with open(SPEECH_FILE, 'wb') as file:
-        with requests.get(SPEECH_URL) as resp:
-            resp.raise_for_status()
-            file.write(resp.content)
-
-import IPython
-
+        file.write(requests.get(SPEECH_URL).content)
 
 ######################################################################
 # Generate frame-wise label probability
 # -------------------------------------
 #
 # The first step is to generate the label class probability of each audio
-# frame. We can use a Wav2Vec2 model that is trained for ASR.
+# frame. We can use a Wav2Vec2 model that is trained for ASR. Here we use
+# :py:func:`torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H`.
 #
 # ``torchaudio`` provides easy access to pretrained models with associated
 # labels.
 #
-# **Note** In the subsequent sections, we will compute the probability in
-# log-domain to avoid numerical instability. For this purpose, we
-# normalize the ``emission`` with ``log_softmax``.
+# .. note::
+#
+#    In the subsequent sections, we will compute the probability in
+#    log-domain to avoid numerical instability. For this purpose, we
+#    normalize the ``emission`` with :py:func:`torch.log_softmax`.
 #
 
 bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
-model = bundle.get_model()
+model = bundle.get_model().to(device)
 labels = bundle.get_labels()
 with torch.inference_mode():
     waveform, _ = torchaudio.load(SPEECH_FILE)
-    emissions, _ = model(waveform)
+    emissions, _ = model(waveform.to(device))
     emissions = torch.log_softmax(emissions, dim=-1)
 
 emission = emissions[0].cpu().detach()
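The ``emission`` matrix produced by this hunk is indexed by the label set from ``bundle.get_labels()``, so the transcript has to be encoded as label indices before the trellis can be built. A minimal sketch of that encoding, assuming the ``labels`` and ``emission`` variables from the hunk above; the shape comment is illustrative, not output from this commit:

# Map each transcript character to its class index in the model's label set.
# `labels` and `emission` are assumed to come from the code above.
transcript = 'I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT'
dictionary = {c: i for i, c in enumerate(labels)}
tokens = [dictionary[c] for c in transcript]
print(emission.shape)  # (num_frames, num_labels)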
@@ -132,14 +133,13 @@
 # Since we are looking for the most likely transitions, we take the more
 # likely path for the value of :math:`k_{(t+1, j+1)}`, that is
 #
-# $ k_{(t+1, j+1)} = max( k_{(t, j)} p(t+1, c_{j+1}), k_{(t, j+1)} p(t+1,
-# repeat) ) $
+# :math:`k_{(t+1, j+1)} = max( k_{(t, j)} p(t+1, c_{j+1}), k_{(t, j+1)} p(t+1, repeat) )`
 #
 # where :math:`k` represents the trellis matrix, and :math:`p(t, c_j)`
 # represents the probability of label :math:`c_j` at time step :math:`t`.
 # :math:`repeat` represents the blank token from the CTC formulation. (For the
-# detail of CTC algorithm, please refer to the `Sequence Modeling with CTC
-# [distill.pub] <https://distill.pub/2017/ctc/>`__)
+# details of the CTC algorithm, please refer to the *Sequence Modeling with CTC*
+# [`distill.pub <https://distill.pub/2017/ctc/>`__])
 #
 
 transcript = 'I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT'
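The recursion above is easier to follow in code. Below is a minimal log-domain sketch of the trellis fill, consistent with the hunk's formula (log-probabilities, so the products become sums); it is an illustrative stand-in written for this note, not a verbatim copy of the tutorial's ``get_trellis``:

import torch

def get_trellis_sketch(emission, tokens, blank_id=0):
    # emission: (num_frames, num_labels) log-probabilities.
    # trellis[t, j]: best log-score of emitting the first j tokens
    # within the first t frames.
    num_frame, num_tokens = emission.size(0), len(tokens)
    trellis = torch.full((num_frame + 1, num_tokens + 1), -float('inf'))
    trellis[0, 0] = 0.0
    for t in range(num_frame):
        # Stay on column 0 by emitting blanks only.
        trellis[t + 1, 0] = trellis[t, 0] + emission[t, blank_id]
        trellis[t + 1, 1:] = torch.maximum(
            # stay: emit blank (repeat), keeping the same token position
            trellis[t, 1:] + emission[t, blank_id],
            # advance: emit the next transcript token
            trellis[t, :-1] + emission[t, tokens],
        )
    return trellis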
@@ -176,8 +176,6 @@ def get_trellis(emission, tokens, blank_id=0):
 plt.colorbar()
 plt.show()
 
-
-
 ######################################################################
 # In the above visualization, we can see that there is a trace of high
 # probability crossing the matrix diagonally.
@@ -251,7 +249,7 @@ def backtrack(trellis, emission, tokens, blank_id=0):
 print(path)
 
 ################################################################################
-# visualization
+# Visualization
 ################################################################################
 def plot_trellis_with_path(trellis, path):
     # To plot trellis with path, we take advantage of 'nan' value
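For reference, the backtracking step that produces ``path`` walks the filled trellis from the most probable final frame back to the start, at each step deciding whether the stored value came from staying (blank) or advancing (next token). A condensed sketch under the same log-domain convention as the trellis sketch above; the ``Point`` container and function name are illustrative, not this commit's exact code:

from dataclasses import dataclass
import torch

@dataclass
class Point:
    token_index: int  # position in the transcript
    time_index: int   # emission frame index
    score: float      # per-frame posterior of the chosen label

def backtrack_sketch(trellis, emission, tokens, blank_id=0):
    j = trellis.size(1) - 1
    t_start = torch.argmax(trellis[:, j]).item()
    path = []
    for t in range(t_start, 0, -1):
        # Score of having stayed (blank) vs advanced (emitted token j-1).
        stayed = trellis[t - 1, j] + emission[t - 1, blank_id]
        changed = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]
        label = tokens[j - 1] if changed > stayed else blank_id
        path.append(Point(j - 1, t - 1, emission[t - 1, label].exp().item()))
        if changed > stayed:
            j -= 1
            if j == 0:
                break
    return path[::-1]  # reverse into chronological order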
@@ -264,7 +262,6 @@ def plot_trellis_with_path(trellis, path):
 plt.title("The path found by backtracking")
 plt.show()
 
-
 ######################################################################
 # Looking good. Now this path contains repetitions of the same labels, so
 # let’s merge them to make it close to the original transcript.
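Merging amounts to grouping consecutive ``path`` entries that share a ``token_index`` and averaging their scores into one span. A minimal sketch, assuming the ``Point`` fields used above; the ``Segment`` container is illustrative:

from dataclasses import dataclass

@dataclass
class Segment:
    label: str
    start: int   # first frame (inclusive)
    end: int     # last frame (exclusive)
    score: float

def merge_repeats_sketch(path, transcript):
    segments = []
    i1 = 0
    while i1 < len(path):
        i2 = i1
        # Extend i2 over every entry emitting the same transcript token.
        while i2 < len(path) and path[i1].token_index == path[i2].token_index:
            i2 += 1
        score = sum(p.score for p in path[i1:i2]) / (i2 - i1)
        segments.append(Segment(
            transcript[path[i1].token_index],
            path[i1].time_index,
            path[i2 - 1].time_index + 1,
            score,
        ))
        i1 = i2
    return segments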
@@ -304,7 +301,7 @@ def merge_repeats(path):
     print(seg)
 
 ################################################################################
-# visualization
+# Visualization
 ################################################################################
 def plot_trellis_with_segments(trellis, segments, transcript):
     # To plot trellis with path, we take advantage of 'nan' value
@@ -313,39 +310,40 @@ def plot_trellis_with_segments(trellis, segments, transcript):
         if seg.label != '|':
             trellis_with_path[seg.start+1:seg.end+1, i+1] = float('nan')
 
-    plt.figure()
-    plt.title("Path, label and probability for each label")
-    ax1 = plt.axes()
+    fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9.5))
+    ax1.set_title("Path, label and probability for each label")
     ax1.imshow(trellis_with_path.T, origin='lower')
     ax1.set_xticks([])
 
     for i, seg in enumerate(segments):
         if seg.label != '|':
-            ax1.annotate(seg.label, (seg.start + .7, i + 0.3))
+            ax1.annotate(seg.label, (seg.start + .7, i + 0.3), weight='bold')
             ax1.annotate(f'{seg.score:.2f}', (seg.start - .3, i + 4.3))
-
-    plt.figure()
-    plt.title("Probability for each label at each time index")
-    ax2 = plt.axes()
+
+    ax2.set_title("Label probability with and without repetition")
+    xs, hs, ws = [], [], []
+    for seg in segments:
+        if seg.label != '|':
+            xs.append((seg.end + seg.start) / 2 + .4)
+            hs.append(seg.score)
+            ws.append(seg.end - seg.start)
+            ax2.annotate(seg.label, (seg.start + .8, -0.07), weight='bold')
+    ax2.bar(xs, hs, width=ws, color='gray', alpha=0.5, edgecolor='black')
+
     xs, hs = [], []
     for p in path:
         label = transcript[p.token_index]
         if label != '|':
             xs.append(p.time_index + 1)
             hs.append(p.score)
 
-    for seg in segments:
-        if seg.label != '|':
-            ax2.axvspan(seg.start+.4, seg.end+.4, color='gray', alpha=0.2)
-            ax2.annotate(seg.label, (seg.start + .8, -0.07))
-
-    ax2.bar(xs, hs, width=0.5)
+    ax2.bar(xs, hs, width=0.5, alpha=0.5)
     ax2.axhline(0, color='black')
-    ax2.set_position(ax1.get_position())
     ax2.set_xlim(ax1.get_xlim())
     ax2.set_ylim(-0.1, 1.1)
 
 plot_trellis_with_segments(trellis, segments, transcript)
+plt.tight_layout()
 plt.show()
 
 
@@ -380,64 +378,116 @@ def merge_words(segments, separator='|'):
     print(word)
 
 ################################################################################
-# visualization
+# Visualization
 ################################################################################
-trellis_with_path = trellis.clone()
-for i, seg in enumerate(segments):
-    if seg.label != '|':
-        trellis_with_path[seg.start+1:seg.end+1, i+1] = float('nan')
+def plot_alignments(trellis, segments, word_segments, waveform):
+    trellis_with_path = trellis.clone()
+    for i, seg in enumerate(segments):
+        if seg.label != '|':
+            trellis_with_path[seg.start+1:seg.end+1, i+1] = float('nan')
 
-plt.imshow(trellis_with_path[1:, 1:].T, origin='lower')
-ax1 = plt.gca()
-ax1.set_yticks([])
-ax1.set_xticks([])
+    fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9.5))
 
+    ax1.imshow(trellis_with_path[1:, 1:].T, origin='lower')
+    ax1.set_xticks([])
+    ax1.set_yticks([])
 
-for word in word_segments:
-    plt.axvline(word.start - 0.5)
-    plt.axvline(word.end - 0.5)
+    for word in word_segments:
+        ax1.axvline(word.start - 0.5)
+        ax1.axvline(word.end - 0.5)
 
-for i, seg in enumerate(segments):
-    if seg.label != '|':
-        plt.annotate(seg.label, (seg.start, i + 0.3))
-        plt.annotate(f'{seg.score:.2f}', (seg.start, i + 4), fontsize=8)
+    for i, seg in enumerate(segments):
+        if seg.label != '|':
+            ax1.annotate(seg.label, (seg.start, i + 0.3))
+            ax1.annotate(f'{seg.score:.2f}', (seg.start, i + 4), fontsize=8)
+
+    # The original waveform
+    ratio = waveform.size(0) / (trellis.size(0) - 1)
+    ax2.plot(waveform)
+    for word in word_segments:
+        x0 = ratio * word.start
+        x1 = ratio * word.end
+        ax2.axvspan(x0, x1, alpha=0.1, color='red')
+        ax2.annotate(f'{word.score:.2f}', (x0, 0.8))
 
+    for seg in segments:
+        if seg.label != '|':
+            ax2.annotate(seg.label, (seg.start * ratio, 0.9))
+    xticks = ax2.get_xticks()
+    plt.xticks(xticks, xticks / bundle.sample_rate)
+    ax2.set_xlabel('time [second]')
+    ax2.set_yticks([])
+    ax2.set_ylim(-1.0, 1.0)
+    ax2.set_xlim(0, waveform.size(-1))
+
+plot_alignments(trellis, segments, word_segments, waveform[0])
 plt.show()
 
-# The original waveform
-ratio = waveform.size(1) / (trellis.size(0) - 1)
-plt.plot(waveform[0])
-for word in word_segments:
-    x0 = ratio * word.start
-    x1 = ratio * word.end
-    plt.axvspan(x0, x1, alpha=0.1, color='red')
-    plt.annotate(f'{word.score:.2f}', (x0, 0.8))
+# A trick to embed the resulting audio into the generated file.
+# `IPython.display.Audio` has to be the last call in a cell,
+# and there should be only one call per cell.
+def display_segment(i):
+    ratio = waveform.size(1) / (trellis.size(0) - 1)
+    word = word_segments[i]
+    x0 = int(ratio * word.start)
+    x1 = int(ratio * word.end)
+    filename = f"_assets/{i}_{word.label}.wav"
+    torchaudio.save(filename, waveform[:, x0:x1], bundle.sample_rate)
+    print(f"{word.label} ({word.score:.2f}): {x0 / bundle.sample_rate:.3f} - {x1 / bundle.sample_rate:.3f} sec")
+    return IPython.display.Audio(filename)
 
-for seg in segments:
-    if seg.label != '|':
-        plt.annotate(seg.label, (seg.start * ratio, 0.9))
-
-ax2 = plt.gca()
-xticks = ax2.get_xticks()
-plt.xticks(xticks, xticks / bundle.sample_rate)
-plt.xlabel('time [second]')
-ax2.set_position(ax1.get_position())
-ax2.set_yticks([])
-ax2.set_ylim(-1.0, 1.0)
-ax2.set_xlim(0, waveform.size(-1))
-plt.show()
+######################################################################
+#
 
 # Generate the audio for each segment
 print(transcript)
-IPython.display.display(IPython.display.Audio(SPEECH_FILE))
-for i, word in enumerate(word_segments):
-    x0 = int(ratio * word.start)
-    x1 = int(ratio * word.end)
-    filename = f"{i}_{word.label}.wav"
-    torchaudio.save(filename, waveform[:, x0:x1], bundle.sample_rate)
-    print(f"{word.label}: {x0 / bundle.sample_rate:.3f} - {x1 / bundle.sample_rate:.3f}")
-    IPython.display.display(IPython.display.Audio(filename))
+IPython.display.Audio(SPEECH_FILE)
+
+
+######################################################################
+#
+
+display_segment(0)
+
+######################################################################
+#
+
+display_segment(1)
+
+######################################################################
+#
+
+display_segment(2)
+
+######################################################################
+#
+
+display_segment(3)
+
+######################################################################
+#
+
+display_segment(4)
+
+######################################################################
+#
+
+display_segment(5)
+
+######################################################################
+#
+
+display_segment(6)
+
+######################################################################
+#
+
+display_segment(7)
+
+######################################################################
+#
 
+display_segment(8)
 
 ######################################################################
 # Conclusion
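The slicing in ``display_segment`` relies on one conversion: the trellis has one row per emission frame, so ``ratio = num_samples / num_frames`` maps a frame index back to a sample index, and dividing by the sample rate yields seconds. A small worked example with assumed round numbers, not values from this commit:

# Assumed example: a 16 kHz clip of 54,400 samples reduced to 170
# emission frames (Wav2Vec2 downsamples by roughly 320x).
num_samples, num_frames, sample_rate = 54400, 170, 16000
ratio = num_samples / num_frames  # 320.0 samples per emission frame
x0 = int(ratio * 40)              # frame 40 -> sample 12800
print(x0 / sample_rate)           # 0.8 (seconds into the clip)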
