166 lines
6.1 KiB
Python
166 lines
6.1 KiB
Python
|
import argparse
|
||
|
import importlib
|
||
|
import os
|
||
|
from argparse import RawTextHelpFormatter
|
||
|
|
||
|
import numpy as np
|
||
|
import torch
|
||
|
from torch.utils.data import DataLoader
|
||
|
from tqdm import tqdm
|
||
|
|
||
|
from TTS.config import load_config
|
||
|
from TTS.tts.datasets.TTSDataset import TTSDataset
|
||
|
from TTS.tts.models import setup_model
|
||
|
from TTS.tts.utils.text.characters import make_symbols, phonemes, symbols
|
||
|
from TTS.utils.audio import AudioProcessor
|
||
|
from TTS.utils.io import load_checkpoint
|
||
|
|
||
|
def _str2bool(value):
    """Convert a command-line string such as ``"True"``/``"false"`` to a bool.

    ``argparse`` with ``type=bool`` calls ``bool(text)``, which is ``True`` for
    every non-empty string (including ``"False"``), so boolean options need an
    explicit parser.

    Args:
        value: Raw command-line text, or an already-converted bool.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}.")


def _attention_file_path(wav_path):
    """Return the path of the attention-mask file written for ``wav_path``.

    The mask sits next to the wav file, with the extension replaced by
    ``_attn.npy`` (e.g. ``path/bla.wav`` -> ``path/bla_attn.npy``). Only the
    file-name component is rewritten; directory components are left untouched
    even if they happen to contain the file name.
    """
    directory, wav_file_name = os.path.split(wav_path)
    return os.path.join(directory, os.path.splitext(wav_file_name)[0] + "_attn.npy")


if __name__ == "__main__":
    # pylint: disable=bad-option-value
    parser = argparse.ArgumentParser(
        description="""Extract attention masks from trained Tacotron/Tacotron2 models.
These masks can be used for different purposes including training a TTS model with a Duration Predictor.\n\n"""
        """Each attention mask is written next to the input wav file, with the extension replaced by "_attn.npy".
(e.g. path/bla.wav (wav file) --> path/bla_attn.npy (attention mask))\n"""
        """
Example run:
    CUDA_VISIBLE_DEVICES="0" python TTS/bin/compute_attention_masks.py
        --model_path /data/rw/home/Models/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/checkpoint_200000.pth
        --config_path /data/rw/home/Models/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/config.json
        --dataset_metafile metadata.csv
        --data_path /root/LJSpeech-1.1/
        --batch_size 32
        --dataset ljspeech
        --use_cuda True
""",
        formatter_class=RawTextHelpFormatter,
    )
    parser.add_argument("--model_path", type=str, required=True, help="Path to Tacotron/Tacotron2 model file.")
    parser.add_argument(
        "--config_path",
        type=str,
        required=True,
        help="Path to Tacotron/Tacotron2 config file.",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        required=True,
        help="Target dataset formatter name from TTS.tts.datasets.formatters.",
    )
    parser.add_argument(
        "--dataset_metafile",
        type=str,
        required=True,
        help="Dataset metafile including file paths with transcripts.",
    )
    parser.add_argument("--data_path", type=str, default="", help="Defines the data path. It overwrites config.json.")
    # `type=bool` would turn the text "False" into True, so parse the text explicitly.
    parser.add_argument("--use_cuda", type=_str2bool, default=False, help="enable/disable cuda.")
    parser.add_argument(
        "--batch_size", default=16, type=int, help="Batch size for the model. Use batch_size=1 if you have no CUDA."
    )
    args = parser.parse_args()

    # Resolve the dataset formatter first so a misspelled name fails before the model is loaded.
    formatters = importlib.import_module("TTS.tts.datasets.formatters")
    preprocessor = getattr(formatters, args.dataset, None)
    if preprocessor is None:
        parser.error(f"unknown --dataset formatter: {args.dataset!r}")

    C = load_config(args.config_path)
    ap = AudioProcessor(**C.audio)

    # if the vocabulary was passed, replace the default
    if "characters" in C.keys():
        symbols, phonemes = make_symbols(**C.characters)

    # load the model
    # NOTE(review): num_chars is not used below; setup_model() only receives C.
    num_chars = len(phonemes) if C.use_phonemes else len(symbols)
    # TODO: handle multi-speaker
    model = setup_model(C)
    # NOTE(review): the trailing positional True is presumably load_checkpoint's eval-mode flag -- confirm.
    model, _ = load_checkpoint(model, args.model_path, args.use_cuda, True)
    r = model.decoder.r

    # data loader
    meta_data = preprocessor(args.data_path, args.dataset_metafile)
    dataset = TTSDataset(
        r,
        C.text_cleaner,
        compute_linear_spec=False,
        ap=ap,
        meta_data=meta_data,
        characters=C.characters if "characters" in C.keys() else None,
        add_blank=C["add_blank"] if "add_blank" in C.keys() else False,
        use_phonemes=C.use_phonemes,
        phoneme_cache_path=C.phoneme_cache_path,
        phoneme_language=C.phoneme_language,
        enable_eos_bos=C.enable_eos_bos_chars,
    )

    dataset.sort_and_filter_items(C.get("sort_by_audio_len", default=False))
    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        num_workers=4,
        collate_fn=dataset.collate_fn,
        shuffle=False,
        drop_last=False,
    )

    # compute attentions
    file_paths = []
    with torch.no_grad():
        for data in tqdm(loader):
            # setup input data; the other collated fields are not needed here
            text_input = data[0]
            text_lengths = data[1]
            mel_input = data[4]
            mel_lengths = data[5]
            # despite the name, these entries are used as wav file paths below
            item_idxs = data[7]

            # dispatch data to GPU
            if args.use_cuda:
                text_input = text_input.cuda()
                text_lengths = text_lengths.cuda()
                mel_input = mel_input.cuda()
                mel_lengths = mel_lengths.cuda()

            model_outputs = model.forward(text_input, text_lengths, mel_input)

            alignments = model_outputs["alignments"].detach()
            for idx, alignment in enumerate(alignments):
                item_idx = item_idxs[idx]
                # Upsample the first (decoder-step) axis by r with nearest-neighbour
                # repetition so its length lines up with mel_lengths before trimming.
                # With r == 1 this would be the identity, so it is skipped.
                if r > 1:
                    alignment = (
                        torch.nn.functional.interpolate(
                            alignment.transpose(0, 1).unsqueeze(0),
                            scale_factor=r,
                            mode="nearest",
                        )
                        .squeeze(0)
                        .transpose(0, 1)
                    )
                # remove paddings
                alignment = alignment[: mel_lengths[idx], : text_lengths[idx]].cpu().numpy()
                # save the mask next to the wav file and record both absolute paths
                file_path = _attention_file_path(item_idx)
                file_paths.append([os.path.abspath(item_idx), os.path.abspath(file_path)])
                np.save(file_path, alignment)

    # output metafile: one "<wav path>|<attention mask path>" line per item
    metafile = os.path.join(args.data_path, "metadata_attn_mask.txt")
    with open(metafile, "w", encoding="utf-8") as f:
        for wav_path, attn_path in file_paths:
            f.write(f"{wav_path}|{attn_path}\n")
    print(f" >> Metafile created: {metafile}")
|