73 lines
2.4 KiB
Python
73 lines
2.4 KiB
Python
from typing import Dict
|
|
|
|
import numpy as np
|
|
import torch
|
|
from matplotlib import pyplot as plt
|
|
|
|
from TTS.tts.utils.visual import plot_spectrogram
|
|
from TTS.utils.audio import AudioProcessor
|
|
|
|
|
|
def interpolate_vocoder_input(scale_factor, spec):
    """Interpolate spectrogram by the scale factor.

    It is mainly used to match the sampling rates of
    the tts and vocoder models.

    Args:
        scale_factor (float): scale factor to interpolate the spectrogram
        spec (np.array or torch.Tensor): 2D spectrogram to be interpolated. Bilinear
            interpolation needs a 4D input, which is built here by adding two leading
            dimensions, so ``spec`` itself must be 2D.

    Returns:
        torch.tensor: interpolated spectrogram with one leading singleton dimension,
            i.e. shape ``(1, H', W')`` for an input of shape ``(H, W)``.
    """
    print(" > before interpolation :", spec.shape)
    if isinstance(spec, torch.Tensor):
        # `torch.tensor(tensor)` emits a "copy construct" UserWarning and copies the data.
        # `detach()` yields the same values, cut from the autograd graph, without either.
        spec = spec.detach()
    else:
        spec = torch.tensor(spec)  # pylint: disable=not-callable
    # (H, W) -> (1, 1, H, W): batch and channel dimensions expected by `interpolate`
    spec = spec.unsqueeze(0).unsqueeze(0)
    spec = torch.nn.functional.interpolate(
        spec, scale_factor=scale_factor, recompute_scale_factor=True, mode="bilinear", align_corners=False
    ).squeeze(0)  # drop only the first added dimension; the other one is kept
    print(" > after interpolation :", spec.shape)
    return spec
|
|
|
|
|
|
def plot_results(y_hat: torch.Tensor, y: torch.Tensor, ap: AudioProcessor, name_prefix: str = None) -> Dict:
    """Plot the predicted and the real waveform and their spectrograms.

    Only the first item of each batch is plotted.

    Args:
        y_hat (torch.Tensor): Predicted waveform.
        y (torch.Tensor): Real waveform.
        ap (AudioProcessor): Audio processor used to process the waveform.
        name_prefix (str, optional): Name prefix used to name the figures. Defaults to None,
            which is treated as an empty prefix.

    Returns:
        Dict: output figures keyed by the name of the figures.
    """
    if name_prefix is None:
        name_prefix = ""

    # select an instance from batch
    y_hat = y_hat[0].squeeze().detach().cpu().numpy()
    y = y[0].squeeze().detach().cpu().numpy()

    # mel spectrograms of both waveforms (transposed) and their absolute difference
    spec_fake = ap.melspectrogram(y_hat).T
    spec_real = ap.melspectrogram(y).T
    spec_diff = np.abs(spec_fake - spec_real)

    # waveform comparison: ground truth on top, generated speech below
    fig_wave = plt.figure()
    plt.subplot(2, 1, 1)
    plt.plot(y)
    plt.title("groundtruth speech")
    plt.subplot(2, 1, 2)
    plt.plot(y_hat)
    plt.title("generated speech")
    plt.tight_layout()
    # Close by handle rather than relying on pyplot's implicit "current figure";
    # the Figure object itself stays usable and is returned to the caller.
    plt.close(fig_wave)

    figures = {
        name_prefix + "spectrogram/fake": plot_spectrogram(spec_fake),
        name_prefix + "spectrogram/real": plot_spectrogram(spec_real),
        name_prefix + "spectrogram/diff": plot_spectrogram(spec_diff),
        name_prefix + "speech_comparison": fig_wave,
    }
    return figures
|