# Preparation
# -----------
#
- # First, we install ``torchaudio``, import packages and fetch the speech
- # file.
+ # First we import the necessary packages, and fetch the data that we will work on.
#

- # !pip install torchaudio
-
# %matplotlib inline

import os
from dataclasses import dataclass

+ import torch
+ import torchaudio
import requests
import matplotlib
import matplotlib.pyplot as plt
+ import IPython

- [width, height] = matplotlib.rcParams['figure.figsize']
- if width < 10:
-     matplotlib.rcParams['figure.figsize'] = [width * 2.5, height]
+ matplotlib.rcParams['figure.figsize'] = [16.0, 4.8]

- import torch
- import torchaudio
+ torch.random.manual_seed(0)
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

- SPEECH_URL = 'https://download.pytorch.org/torchaudio/test-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.flac'
- SPEECH_FILE = 'speech.flac'
+ print(torch.__version__)
+ print(torchaudio.__version__)
+ print(device)
+
+ SPEECH_URL = 'https://download.pytorch.org/torchaudio/tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav'
+ SPEECH_FILE = '_assets/speech.wav'

if not os.path.exists(SPEECH_FILE):
+     os.makedirs('_assets', exist_ok=True)
    with open(SPEECH_FILE, 'wb') as file:
-         with requests.get(SPEECH_URL) as resp:
-             resp.raise_for_status()
-             file.write(resp.content)
-
- import IPython
-
+         file.write(requests.get(SPEECH_URL).content)

######################################################################
# Generate frame-wise label probability
# -------------------------------------
#
# The first step is to generate the label class probability of each audio
- # frame. We can use a Wav2Vec2 model that is trained for ASR.
+ # frame. We can use a Wav2Vec2 model that is trained for ASR. Here we use
+ # :py:func:`torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H`.
#
# ``torchaudio`` provides easy access to pretrained models with associated
# labels.
#
- # **Note** In the subsequent sections, we will compute the probability in
- # log-domain to avoid numerical instability. For this purpose, we
- # normalize the ``emission`` with ``log_softmax``.
+ # .. note::
+ #
+ #    In the subsequent sections, we will compute the probability in
+ #    log-domain to avoid numerical instability. For this purpose, we
+ #    normalize the ``emission`` with :py:func:`torch.log_softmax`.
#
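
######################################################################
# A toy illustration of the instability that the log-domain avoids (a
# minimal sketch, not tied to this model): multiplying many small
# probabilities underflows in float32, while summing their logs stays
# finite.

p = torch.full((100,), 0.01)    # one hundred probabilities of 0.01
print(torch.prod(p))            # tensor(0.) -- 1e-200 underflows float32
print(torch.sum(torch.log(p)))  # tensor(-460.5170) -- well-behaved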

bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
- model = bundle.get_model()
+ model = bundle.get_model().to(device)
labels = bundle.get_labels()
with torch.inference_mode():
    waveform, _ = torchaudio.load(SPEECH_FILE)
-     emissions, _ = model(waveform)
+     emissions, _ = model(waveform.to(device))
    emissions = torch.log_softmax(emissions, dim=-1)

emission = emissions[0].cpu().detach()
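
######################################################################
# For a quick sanity check, we can visualize ``emission`` with a minimal
# matplotlib sketch (not from the commit itself):

plt.imshow(emission.T)
plt.colorbar()
plt.title("Frame-wise class probability")
plt.xlabel("Time")
plt.ylabel("Labels")
plt.show()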

@@ -132,14 +133,13 @@
# Since we are looking for the most likely transitions, we take the more
# likely path for the value of :math:`k_{(t+1, j+1)}`, that is
#
- # $ k_{(t+1, j+1)} = max( k_{(t, j)} p(t+1, c_{j+1}), k_{(t, j+1)} p(t+1,
- # repeat) ) $
+ # :math:`k_{(t+1, j+1)} = max( k_{(t, j)} p(t+1, c_{j+1}), k_{(t, j+1)} p(t+1, repeat) )`
#
# where :math:`k` represents the trellis matrix, and :math:`p(t, c_j)`
# represents the probability of label :math:`c_j` at time step :math:`t`.
# :math:`repeat` represents the blank token from the CTC formulation. (For the
- # detail of CTC algorithm, please refer to the `Sequence Modeling with CTC
- # [distill.pub] <https://distill.pub/2017/ctc/>`__)
+ # details of the CTC algorithm, please refer to *Sequence Modeling with CTC*
+ # [`distill.pub <https://distill.pub/2017/ctc/>`__])
#

transcript = 'I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT'
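
######################################################################
# A minimal sketch of this recurrence in log-domain, where the products
# above become sums. Names here are illustrative; the tutorial's own
# ``get_trellis`` (next hunk) implements the full version. The transcript
# is first mapped to token indices via the labels from the bundle.

dictionary = {c: i for i, c in enumerate(labels)}
tokens = [dictionary[c] for c in transcript]

def get_trellis_sketch(emission, tokens, blank_id=0):
    num_frame = emission.size(0)
    num_tokens = len(tokens)
    # Extra row/column for time step zero and the start-of-sentence state.
    trellis = torch.full((num_frame + 1, num_tokens + 1), -float('inf'))
    trellis[0, 0] = 0
    # Staying in the start state means emitting blanks.
    trellis[1:, 0] = torch.cumsum(emission[:, blank_id], 0)
    for t in range(num_frame):
        trellis[t + 1, 1:] = torch.maximum(
            # Stay on the same token (the "repeat"/blank case).
            trellis[t, 1:] + emission[t, blank_id],
            # Advance to the next token.
            trellis[t, :-1] + emission[t, tokens],
        )
    return trellis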

@@ -176,8 +176,6 @@ def get_trellis(emission, tokens, blank_id=0):
plt.colorbar()
plt.show()

-
-

######################################################################
# In the above visualization, we can see that there is a trace of high
# probability crossing the matrix diagonally.

@@ -251,7 +249,7 @@ def backtrack(trellis, emission, tokens, blank_id=0):
print(path)

################################################################################
- # visualization
+ # Visualization
################################################################################
def plot_trellis_with_path(trellis, path):
    # To plot trellis with path, we take advantage of 'nan' value

@@ -264,7 +262,6 @@ def plot_trellis_with_path(trellis, path):
    plt.title("The path found by backtracking")
    plt.show()

-
######################################################################
# Looking good. Now this path contains repetitions for the same labels, so
# let's merge them to make it close to the original transcript.
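
######################################################################
# A minimal sketch of that merging step; names are illustrative, and the
# tutorial's own ``merge_repeats`` (next hunk) is the real implementation.
# Each entry of ``path`` carries ``token_index``, ``time_index`` and
# ``score`` (as used elsewhere in this tutorial); consecutive entries
# pointing at the same token are collapsed into one labeled segment.

@dataclass
class SegmentSketch:
    label: str
    start: int
    end: int
    score: float

def merge_repeats_sketch(path):
    segments = []
    i1 = 0
    while i1 < len(path):
        # Find the run of path entries that share the same token.
        i2 = i1
        while i2 < len(path) and path[i1].token_index == path[i2].token_index:
            i2 += 1
        # Average the per-frame scores over the run.
        score = sum(path[k].score for k in range(i1, i2)) / (i2 - i1)
        segments.append(SegmentSketch(
            transcript[path[i1].token_index],
            path[i1].time_index,
            path[i2 - 1].time_index + 1,
            score,
        ))
        i1 = i2
    return segments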

@@ -304,7 +301,7 @@ def merge_repeats(path):
    print(seg)

################################################################################
- # visualization
+ # Visualization
################################################################################
def plot_trellis_with_segments(trellis, segments, transcript):
    # To plot trellis with path, we take advantage of 'nan' value

@@ -313,39 +310,40 @@ def plot_trellis_with_segments(trellis, segments, transcript):
        if seg.label != '|':
            trellis_with_path[seg.start + 1:seg.end + 1, i + 1] = float('nan')

-     plt.figure()
-     plt.title("Path, label and probability for each label")
-     ax1 = plt.axes()
+     fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9.5))
+     ax1.set_title("Path, label and probability for each label")
    ax1.imshow(trellis_with_path.T, origin='lower')
    ax1.set_xticks([])

    for i, seg in enumerate(segments):
        if seg.label != '|':
-             ax1.annotate(seg.label, (seg.start + .7, i + 0.3))
+             ax1.annotate(seg.label, (seg.start + .7, i + 0.3), weight='bold')
            ax1.annotate(f'{seg.score:.2f}', (seg.start - .3, i + 4.3))
-
-     plt.figure()
-     plt.title("Probability for each label at each time index")
-     ax2 = plt.axes()
+
+     ax2.set_title("Label probability with and without repetition")
+     xs, hs, ws = [], [], []
+     for seg in segments:
+         if seg.label != '|':
+             xs.append((seg.end + seg.start) / 2 + .4)
+             hs.append(seg.score)
+             ws.append(seg.end - seg.start)
+             ax2.annotate(seg.label, (seg.start + .8, -0.07), weight='bold')
+     ax2.bar(xs, hs, width=ws, color='gray', alpha=0.5, edgecolor='black')
+
    xs, hs = [], []
    for p in path:
        label = transcript[p.token_index]
        if label != '|':
            xs.append(p.time_index + 1)
            hs.append(p.score)

-     for seg in segments:
-         if seg.label != '|':
-             ax2.axvspan(seg.start + .4, seg.end + .4, color='gray', alpha=0.2)
-             ax2.annotate(seg.label, (seg.start + .8, -0.07))
-
-     ax2.bar(xs, hs, width=0.5)
+     ax2.bar(xs, hs, width=0.5, alpha=0.5)
    ax2.axhline(0, color='black')
-     ax2.set_position(ax1.get_position())
    ax2.set_xlim(ax1.get_xlim())
    ax2.set_ylim(-0.1, 1.1)

plot_trellis_with_segments(trellis, segments, transcript)
+ plt.tight_layout()
plt.show()


@@ -380,64 +378,116 @@ def merge_words(segments, separator='|'):
    print(word)

################################################################################
- # visualization
+ # Visualization
################################################################################
- trellis_with_path = trellis.clone()
- for i, seg in enumerate(segments):
-     if seg.label != '|':
-         trellis_with_path[seg.start + 1:seg.end + 1, i + 1] = float('nan')
+ def plot_alignments(trellis, segments, word_segments, waveform):
+     trellis_with_path = trellis.clone()
+     for i, seg in enumerate(segments):
+         if seg.label != '|':
+             trellis_with_path[seg.start + 1:seg.end + 1, i + 1] = float('nan')

- plt.imshow(trellis_with_path[1:, 1:].T, origin='lower')
- ax1 = plt.gca()
- ax1.set_yticks([])
- ax1.set_xticks([])
+     fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9.5))

+     ax1.imshow(trellis_with_path[1:, 1:].T, origin='lower')
+     ax1.set_xticks([])
+     ax1.set_yticks([])

- for word in word_segments:
-     plt.axvline(word.start - 0.5)
-     plt.axvline(word.end - 0.5)
+     for word in word_segments:
+         ax1.axvline(word.start - 0.5)
+         ax1.axvline(word.end - 0.5)

- for i, seg in enumerate(segments):
-     if seg.label != '|':
-         plt.annotate(seg.label, (seg.start, i + 0.3))
-         plt.annotate(f'{seg.score:.2f}', (seg.start, i + 4), fontsize=8)
+     for i, seg in enumerate(segments):
+         if seg.label != '|':
+             ax1.annotate(seg.label, (seg.start, i + 0.3))
+             ax1.annotate(f'{seg.score:.2f}', (seg.start, i + 4), fontsize=8)
+
+     # The original waveform
+     ratio = waveform.size(0) / (trellis.size(0) - 1)
+     ax2.plot(waveform)
+     for word in word_segments:
+         x0 = ratio * word.start
+         x1 = ratio * word.end
+         ax2.axvspan(x0, x1, alpha=0.1, color='red')
+         ax2.annotate(f'{word.score:.2f}', (x0, 0.8))

+     for seg in segments:
+         if seg.label != '|':
+             ax2.annotate(seg.label, (seg.start * ratio, 0.9))
+     xticks = ax2.get_xticks()
+     plt.xticks(xticks, xticks / bundle.sample_rate)
+     ax2.set_xlabel('time [second]')
+     ax2.set_yticks([])
+     ax2.set_ylim(-1.0, 1.0)
+     ax2.set_xlim(0, waveform.size(-1))
+
+ plot_alignments(trellis, segments, word_segments, waveform[0])
plt.show()

- # The original waveform
- ratio = waveform.size(1) / (trellis.size(0) - 1)
- plt.plot(waveform[0])
- for word in word_segments:
-     x0 = ratio * word.start
-     x1 = ratio * word.end
-     plt.axvspan(x0, x1, alpha=0.1, color='red')
-     plt.annotate(f'{word.score:.2f}', (x0, 0.8))
+ # A trick to embed the resulting audio into the generated file.
+ # `IPython.display.Audio` has to be the last call in a cell,
+ # and there should be only one call per cell.
+ def display_segment(i):
+     ratio = waveform.size(1) / (trellis.size(0) - 1)
+     word = word_segments[i]
+     x0 = int(ratio * word.start)
+     x1 = int(ratio * word.end)
+     filename = f"_assets/{i}_{word.label}.wav"
+     torchaudio.save(filename, waveform[:, x0:x1], bundle.sample_rate)
+     print(f"{word.label} ({word.score:.2f}): {x0 / bundle.sample_rate:.3f} - {x1 / bundle.sample_rate:.3f} sec")
+     return IPython.display.Audio(filename)

- for seg in segments:
-     if seg.label != '|':
-         plt.annotate(seg.label, (seg.start * ratio, 0.9))
-
- ax2 = plt.gca()
- xticks = ax2.get_xticks()
- plt.xticks(xticks, xticks / bundle.sample_rate)
- plt.xlabel('time [second]')
- ax2.set_position(ax1.get_position())
- ax2.set_yticks([])
- ax2.set_ylim(-1.0, 1.0)
- ax2.set_xlim(0, waveform.size(-1))
- plt.show()
+ ######################################################################
+ #

# Generate the audio for each segment
print(transcript)
- IPython.display.display(IPython.display.Audio(SPEECH_FILE))
- for i, word in enumerate(word_segments):
-     x0 = int(ratio * word.start)
-     x1 = int(ratio * word.end)
-     filename = f"{i}_{word.label}.wav"
-     torchaudio.save(filename, waveform[:, x0:x1], bundle.sample_rate)
-     print(f"{word.label}: {x0 / bundle.sample_rate:.3f} - {x1 / bundle.sample_rate:.3f}")
-     IPython.display.display(IPython.display.Audio(filename))
+ IPython.display.Audio(SPEECH_FILE)
+
+
+ ######################################################################
+ #
+
+ display_segment(0)
+
+ ######################################################################
+ #
+
+ display_segment(1)
+
+ ######################################################################
+ #
+
+ display_segment(2)
+
+ ######################################################################
+ #
+
+ display_segment(3)
+
+ ######################################################################
+ #
+
+ display_segment(4)
+
+ ######################################################################
+ #
+
+ display_segment(5)
+
+ ######################################################################
+ #
+
+ display_segment(6)
+
+ ######################################################################
+ #
+
+ display_segment(7)
+
+ ######################################################################
+ #

+ display_segment(8)

######################################################################
# Conclusion