288 lines
9.3 KiB
Python
288 lines
9.3 KiB
Python
#!/usr/bin/env python3
|
|
"""Extract Mel spectrograms with teacher forcing."""
|
|
|
|
import argparse
|
|
import os
|
|
|
|
import numpy as np
|
|
import torch
|
|
from torch.utils.data import DataLoader
|
|
from tqdm import tqdm
|
|
|
|
from TTS.config import load_config
|
|
from TTS.tts.datasets import TTSDataset, load_tts_samples
|
|
from TTS.tts.models import setup_model
|
|
from TTS.tts.utils.speakers import SpeakerManager
|
|
from TTS.tts.utils.text.tokenizer import TTSTokenizer
|
|
from TTS.utils.audio import AudioProcessor
|
|
from TTS.utils.audio.numpy_transforms import quantize
|
|
from TTS.utils.generic_utils import count_parameters
|
|
|
|
# True when PyTorch reports an available CUDA device; format_data() and main()
# check this flag to decide whether to move batch tensors and the model to the GPU.
use_cuda = torch.cuda.is_available()
|
|
|
|
|
|
def setup_loader(ap, r, verbose=False):
    """Build a deterministic (non-shuffled) DataLoader over the extraction samples.

    Relies on module-level globals: ``c`` (the loaded config, set in the
    ``__main__`` block) plus ``meta_data`` and ``speaker_manager`` (set by
    ``main``).

    Args:
        ap: audio processor handed to the dataset.
        r: decoder outputs-per-step, forwarded as ``outputs_per_step``.
        verbose: forwarded to the dataset constructor.

    Returns:
        A ``DataLoader`` with ``shuffle=False`` and ``drop_last=False``, so every
        sample is visited exactly once, in dataset order.
    """
    tokenizer, _ = TTSTokenizer.init_from_config(c)

    # Speaker conditioning tables are only passed when the config enables them.
    id_mapping = None
    if c.use_speaker_embedding:
        id_mapping = speaker_manager.name_to_id
    vector_mapping = None
    if c.use_d_vector_file:
        vector_mapping = speaker_manager.embeddings

    dataset_kwargs = dict(
        outputs_per_step=r,
        compute_linear_spec=False,
        samples=meta_data,
        tokenizer=tokenizer,
        ap=ap,
        batch_group_size=0,
        min_text_len=c.min_text_len,
        max_text_len=c.max_text_len,
        min_audio_len=c.min_audio_len,
        max_audio_len=c.max_audio_len,
        phoneme_cache_path=c.phoneme_cache_path,
        precompute_num_workers=0,
        use_noise_augment=False,
        verbose=verbose,
        speaker_id_mapping=id_mapping,
        d_vector_mapping=vector_mapping,
    )
    dataset = TTSDataset(**dataset_kwargs)

    # Phonemizing up front gives a better estimate of sequence lengths.
    precompute_phonemes = c.use_phonemes and c.compute_input_seq_cache
    if precompute_phonemes:
        dataset.compute_input_seq(c.num_loader_workers)
    dataset.preprocess_samples()

    return DataLoader(
        dataset,
        batch_size=c.batch_size,
        shuffle=False,
        collate_fn=dataset.collate_fn,
        drop_last=False,
        sampler=None,
        num_workers=c.num_loader_workers,
        pin_memory=False,
    )
|
|
|
|
|
|
def set_filename(wav_path, out_path):
    """Derive the per-utterance output paths for ``wav_path`` under ``out_path``.

    Side effect: creates the ``quant``, ``mel``, ``wav_gl`` and ``wav``
    subdirectories of ``out_path`` if they do not already exist.

    Args:
        wav_path: path of the source audio file; only its basename is used.
        out_path: root output directory.

    Returns:
        Tuple ``(file_name, wavq_path, mel_path, wav_gl_path, wav_path)``.
        ``wavq_path`` and ``mel_path`` carry no extension (callers save them with
        ``np.save``, which appends ``.npy``); the two audio paths end in ``.wav``.
    """
    # Strip only the final extension. ``split(".")[0]`` truncated at the first
    # dot, so e.g. "spk.v1.utt.wav" and "spk.v2.utt.wav" both became "spk" and
    # silently overwrote each other's outputs.
    file_name = os.path.splitext(os.path.basename(wav_path))[0]

    sub_dirs = {name: os.path.join(out_path, name) for name in ("quant", "mel", "wav_gl", "wav")}
    for directory in sub_dirs.values():
        os.makedirs(directory, exist_ok=True)

    wavq_path = os.path.join(sub_dirs["quant"], file_name)
    mel_path = os.path.join(sub_dirs["mel"], file_name)
    wav_gl_path = os.path.join(sub_dirs["wav_gl"], file_name + ".wav")
    wav_out_path = os.path.join(sub_dirs["wav"], file_name + ".wav")
    return file_name, wavq_path, mel_path, wav_gl_path, wav_out_path
|
|
|
|
|
|
def format_data(data):
    """Unpack a collated batch and, when ``use_cuda`` is set, move it to the GPU.

    Args:
        data: batch dict from the dataset's ``collate_fn``. Keys read:
            ``token_id``, ``token_id_lengths``, ``mel``, ``mel_lengths``,
            ``item_idxs``, ``d_vectors``, ``speaker_ids``, ``attns``.

    Returns:
        Tuple ``(text_input, text_lengths, mel_input, mel_lengths, speaker_ids,
        d_vectors, avg_text_length, avg_spec_length, attn_mask, item_idx)``.
        ``speaker_ids``, ``d_vectors`` and ``attn_mask`` may be ``None``;
        ``item_idx`` is returned exactly as found in the batch.
    """
    text_input, text_lengths = data["token_id"], data["token_id_lengths"]
    mel_input, mel_lengths = data["mel"], data["mel_lengths"]
    item_idx = data["item_idxs"]
    d_vectors = data["d_vectors"]
    speaker_ids = data["speaker_ids"]
    attn_mask = data["attns"]

    # Batch-average lengths, computed on the tensors as collated (before any transfer).
    avg_text_length = text_lengths.float().mean()
    avg_spec_length = mel_lengths.float().mean()

    if use_cuda:
        # These four are always present in the batch.
        text_input, text_lengths, mel_input, mel_lengths = (
            tensor.cuda(non_blocking=True) for tensor in (text_input, text_lengths, mel_input, mel_lengths)
        )
        # Optional conditioning inputs: leave None untouched.
        speaker_ids, d_vectors, attn_mask = (
            None if tensor is None else tensor.cuda(non_blocking=True)
            for tensor in (speaker_ids, d_vectors, attn_mask)
        )

    return (
        text_input,
        text_lengths,
        mel_input,
        mel_lengths,
        speaker_ids,
        d_vectors,
        avg_text_length,
        avg_spec_length,
        attn_mask,
        item_idx,
    )
|
|
|
|
|
|
@torch.no_grad()
def inference(
    model_name,
    model,
    ap,
    text_input,
    text_lengths,
    mel_input,
    mel_lengths,
    speaker_ids=None,
    d_vectors=None,
):
    """Run teacher-forced inference and return the model's spectrograms as NumPy.

    Args:
        model_name: lower-cased model name; one of ``glow_tts``, ``tacotron``,
            ``tacotron2``.
        model: the TTS model to run.
        ap: audio processor; only used for ``tacotron`` (``out_linear_to_mel``).
        text_input, text_lengths, mel_input, mel_lengths: batch tensors used for
            teacher forcing.
        speaker_ids, d_vectors: optional speaker conditioning.

    Returns:
        A NumPy array with one spectrogram per batch item (batch on axis 0).

    Raises:
        ValueError: if ``model_name`` is not a supported model. Previously such
            names fell through every branch and hit an unbound ``model_output``.
    """
    if model_name == "glow_tts":
        # Speaker ids take precedence over d-vectors for the "d_vectors" slot.
        speaker_c = speaker_ids if speaker_ids is not None else d_vectors
        outputs = model.inference_with_MAS(
            text_input,
            text_lengths,
            mel_input,
            mel_lengths,
            aux_input={"d_vectors": speaker_c, "speaker_ids": speaker_ids},
        )
        return outputs["model_outputs"].detach().cpu().numpy()

    # Validate before running an expensive forward pass we could not post-process.
    if model_name not in ("tacotron", "tacotron2"):
        raise ValueError(
            f"Unsupported model for spectrogram extraction: {model_name!r} "
            "(expected 'glow_tts', 'tacotron' or 'tacotron2')."
        )

    aux_input = {"speaker_ids": speaker_ids, "d_vectors": d_vectors}
    outputs = model(text_input, text_lengths, mel_input, mel_lengths, aux_input)
    postnet_outputs = outputs["model_outputs"]

    if model_name == "tacotron2":
        return postnet_outputs.detach().cpu().numpy()

    # tacotron: map each item's postnet output through ap.out_linear_to_mel
    # (which works on the transposed layout, hence the .T round trip).
    postnet_outputs = postnet_outputs.detach().cpu().numpy()
    mel_specs = [torch.FloatTensor(ap.out_linear_to_mel(item.T).T) for item in postnet_outputs]
    return torch.stack(mel_specs).numpy()
|
|
|
|
|
|
def extract_spectrograms(
    data_loader, model, ap, output_path, quantize_bits=0, save_audio=False, debug=False, metada_name="metada.txt"
):
    """Run teacher-forced inference over ``data_loader`` and save per-item mels.

    For every item this writes ``<output_path>/mel/<name>.npy``: the predicted
    spectrogram cut to the batch's ``mel_lengths`` entry and transposed. Optionally
    it also writes:

    * a quantized copy of the source audio (``quant/``) when ``quantize_bits > 0``;
    * a copy of the source audio (``wav/``) when ``save_audio``;
    * audio re-synthesized via ``ap.inv_melspectrogram`` (``wav_gl/``) when ``debug``.

    Finally it writes ``<output_path>/<metada_name>`` with one
    ``<source wav>|<mel path>.npy`` line per item.

    Reads the module-level config ``c`` for the model name.
    """
    model.eval()
    model_name = c.model.lower()  # constant for the whole run; hoisted out of the loop
    # The source waveform is only consumed by quantization or re-saving; skipping the
    # read otherwise avoids loading every audio file from disk for nothing.
    needs_source_wav = quantize_bits > 0 or save_audio
    export_metadata = []

    for data in tqdm(data_loader, total=len(data_loader)):
        (
            text_input,
            text_lengths,
            mel_input,
            mel_lengths,
            speaker_ids,
            d_vectors,
            _,
            _,
            _,
            item_idx,
        ) = format_data(data)

        model_output = inference(
            model_name,
            model,
            ap,
            text_input,
            text_lengths,
            mel_input,
            mel_lengths,
            speaker_ids,
            d_vectors,
        )

        for idx in range(text_input.shape[0]):
            wav_file_path = item_idx[idx]
            wav = ap.load_wav(wav_file_path) if needs_source_wav else None
            _, wavq_path, mel_path, wav_gl_path, wav_path = set_filename(wav_file_path, output_path)

            # quantize and save wav
            if quantize_bits > 0:
                np.save(wavq_path, quantize(wav, quantize_bits))

            # save TTS mel; int() turns the 0-d length tensor into a plain slice bound
            mel_length = int(mel_lengths[idx])
            mel = model_output[idx][:mel_length, :].T
            np.save(mel_path, mel)
            export_metadata.append([wav_file_path, mel_path])

            if save_audio:
                ap.save_wav(wav, wav_path)

            if debug:
                print("Audio for debug saved at:", wav_gl_path)
                ap.save_wav(ap.inv_melspectrogram(mel), wav_gl_path)

    with open(os.path.join(output_path, metada_name), "w", encoding="utf-8") as f:
        f.writelines(f"{wav_file}|{mel_file}.npy\n" for wav_file, mel_file in export_metadata)
|
|
|
|
|
|
def main(args):  # pylint: disable=redefined-outer-name
    """Load samples and the checkpointed model, then extract spectrograms.

    Populates the module globals ``meta_data`` and ``speaker_manager`` (read by
    ``setup_loader``), builds the audio processor and model from the global
    config ``c``, and writes outputs under ``args.output_path``.

    Args:
        args: parsed CLI namespace; reads ``eval``, ``checkpoint_path``,
            ``output_path``, ``quantize_bits``, ``save_audio`` and ``debug``.
    """
    # pylint: disable=global-variable-undefined
    global meta_data, speaker_manager

    # Audio processor
    ap = AudioProcessor(**c.audio)

    # load data instances
    meta_data_train, meta_data_eval = load_tts_samples(
        c.datasets, eval_split=args.eval, eval_split_max_size=c.eval_split_max_size, eval_split_size=c.eval_split_size
    )

    # Use both partitions. NOTE(review): load_tts_samples may return None for the
    # eval partition when eval_split is disabled; guard so `list + None` cannot raise.
    meta_data = meta_data_train + (meta_data_eval or [])

    # init speaker manager
    if c.use_speaker_embedding:
        speaker_manager = SpeakerManager(data_items=meta_data)
    elif c.use_d_vector_file:
        speaker_manager = SpeakerManager(d_vectors_file_path=c.d_vector_file)
    else:
        speaker_manager = None

    # setup and restore model
    model = setup_model(c)
    model.load_checkpoint(c, args.checkpoint_path, eval=True)
    if use_cuda:
        model.cuda()

    num_params = count_parameters(model)
    print(f"\n > Model has {num_params} parameters", flush=True)

    # Glow-TTS has no decoder reduction factor; Tacotron models expose it as decoder.r.
    r = 1 if c.model.lower() == "glow_tts" else model.decoder.r
    own_loader = setup_loader(ap, r, verbose=True)

    extract_spectrograms(
        own_loader,
        model,
        ap,
        args.output_path,
        quantize_bits=args.quantize_bits,
        save_audio=args.save_audio,
        debug=args.debug,
        metada_name="metada.txt",
    )
|
|
|
|
|
|
def parse_bool_arg(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``argparse`` with ``type=bool`` treats every non-empty string, including
    ``"False"``, as ``True``, so ``--eval False`` could never disable eval.

    Raises:
        argparse.ArgumentTypeError: for values that are not a recognized boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    # "" is kept falsy for backward compatibility with the old bool("") behaviour.
    if normalized in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config_path", type=str, help="Path to config file for training.", required=True)
    parser.add_argument("--checkpoint_path", type=str, help="Model file to be restored.", required=True)
    parser.add_argument("--output_path", type=str, help="Path to save mel specs", required=True)
    parser.add_argument("--debug", default=False, action="store_true", help="Save audio files for debug")
    parser.add_argument("--save_audio", default=False, action="store_true", help="Save audio files")
    parser.add_argument("--quantize_bits", type=int, default=0, help="Save quantized audio files if non-zero")
    parser.add_argument("--eval", type=parse_bool_arg, help="compute eval.", default=True)
    args = parser.parse_args()

    c = load_config(args.config_path)
    c.audio.trim_silence = False
    main(args)
|