Skip to content

Commit 1a946b0

Browse files
committed
Fix TTS+TorchAudio tutorial
Workaround missing link by copy-n-pasting direct link from https://github.com/pytorch/audio/blob/0c5a8bf7c25c53ee14f48c33bc31e3390e6ae48e/examples/tutorials/tacotron2_pipeline_tutorial.py#L321 For more info see #2070
1 parent 39010e4 commit 1a946b0

File tree

1 file changed

+17
-1
lines changed

1 file changed

+17
-1
lines changed

intermediate_source/text_to_speech_with_torchaudio.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,8 +297,24 @@ def text_to_sequence(text):
297297
# publishe on Torch Hub. One can instantiate the model using ``torch.hub``
298298
# module.
299299
#
300+
if False:
301+
waveglow = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_waveglow', model_math='fp32')
302+
else:
303+
# Workaround to load model mapped on GPU
304+
# https://stackoverflow.com/a/61840832
305+
waveglow = torch.hub.load(
306+
"NVIDIA/DeepLearningExamples:torchhub",
307+
"nvidia_waveglow",
308+
model_math="fp32",
309+
pretrained=False,
310+
)
311+
checkpoint = torch.hub.load_state_dict_from_url(
312+
"https://api.ngc.nvidia.com/v2/models/nvidia/waveglowpyt_fp32/versions/1/files/nvidia_waveglowpyt_fp32_20190306.pth",
313+
progress=False,
314+
map_location=device,
315+
)
316+
state_dict = {key.replace("module.", ""): value for key, value in checkpoint["state_dict"].items()}
300317

301-
waveglow = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_waveglow', model_math='fp32')
302318
waveglow = waveglow.remove_weightnorm(waveglow)
303319
waveglow = waveglow.to(device)
304320
waveglow.eval()

0 commit comments

Comments
 (0)