104 lines
3.6 KiB
Python
104 lines
3.6 KiB
Python
"""Search a good noise schedule for WaveGrad for a given number of inference iterations"""
|
|
import argparse
|
|
from itertools import product as cartesian_product
|
|
|
|
import numpy as np
|
|
import torch
|
|
from torch.utils.data import DataLoader
|
|
from tqdm import tqdm
|
|
|
|
from TTS.config import load_config
|
|
from TTS.utils.audio import AudioProcessor
|
|
from TTS.vocoder.datasets.preprocess import load_wav_data
|
|
from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset
|
|
from TTS.vocoder.models import setup_model
|
|
|
|
if __name__ == "__main__":
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument("--model_path", type=str, help="Path to model checkpoint.")
|
|
parser.add_argument("--config_path", type=str, help="Path to model config file.")
|
|
parser.add_argument("--data_path", type=str, help="Path to data directory.")
|
|
parser.add_argument("--output_path", type=str, help="path for output file including file name and extension.")
|
|
parser.add_argument(
|
|
"--num_iter",
|
|
type=int,
|
|
help="Number of model inference iterations that you like to optimize noise schedule for.",
|
|
)
|
|
parser.add_argument("--use_cuda", action="store_true", help="enable CUDA.")
|
|
parser.add_argument("--num_samples", type=int, default=1, help="Number of datasamples used for inference.")
|
|
parser.add_argument(
|
|
"--search_depth",
|
|
type=int,
|
|
default=3,
|
|
help="Search granularity. Increasing this increases the run-time exponentially.",
|
|
)
|
|
|
|
# load config
|
|
args = parser.parse_args()
|
|
config = load_config(args.config_path)
|
|
|
|
# setup audio processor
|
|
ap = AudioProcessor(**config.audio)
|
|
|
|
# load dataset
|
|
_, train_data = load_wav_data(args.data_path, 0)
|
|
train_data = train_data[: args.num_samples]
|
|
dataset = WaveGradDataset(
|
|
ap=ap,
|
|
items=train_data,
|
|
seq_len=-1,
|
|
hop_len=ap.hop_length,
|
|
pad_short=config.pad_short,
|
|
conv_pad=config.conv_pad,
|
|
is_training=True,
|
|
return_segments=False,
|
|
use_noise_augment=False,
|
|
use_cache=False,
|
|
verbose=True,
|
|
)
|
|
loader = DataLoader(
|
|
dataset,
|
|
batch_size=1,
|
|
shuffle=False,
|
|
collate_fn=dataset.collate_full_clips,
|
|
drop_last=False,
|
|
num_workers=config.num_loader_workers,
|
|
pin_memory=False,
|
|
)
|
|
|
|
# setup the model
|
|
model = setup_model(config)
|
|
if args.use_cuda:
|
|
model.cuda()
|
|
|
|
# setup optimization parameters
|
|
base_values = sorted(10 * np.random.uniform(size=args.search_depth))
|
|
print(f" > base values: {base_values}")
|
|
exponents = 10 ** np.linspace(-6, -1, num=args.num_iter)
|
|
best_error = float("inf")
|
|
best_schedule = None # pylint: disable=C0103
|
|
total_search_iter = len(base_values) ** args.num_iter
|
|
for base in tqdm(cartesian_product(base_values, repeat=args.num_iter), total=total_search_iter):
|
|
beta = exponents * base
|
|
model.compute_noise_level(beta)
|
|
for data in loader:
|
|
mel, audio = data
|
|
y_hat = model.inference(mel.cuda() if args.use_cuda else mel)
|
|
|
|
if args.use_cuda:
|
|
y_hat = y_hat.cpu()
|
|
y_hat = y_hat.numpy()
|
|
|
|
mel_hat = []
|
|
for i in range(y_hat.shape[0]):
|
|
m = ap.melspectrogram(y_hat[i, 0])[:, :-1]
|
|
mel_hat.append(torch.from_numpy(m))
|
|
|
|
mel_hat = torch.stack(mel_hat)
|
|
mse = torch.sum((mel - mel_hat) ** 2).mean()
|
|
if mse.item() < best_error:
|
|
best_error = mse.item()
|
|
best_schedule = {"beta": beta}
|
|
print(f" > Found a better schedule. - MSE: {mse.item()}")
|
|
np.save(args.output_path, best_schedule)
|