ai-content-maker/.venv/Lib/site-packages/TTS/encoder/utils/generic_utils.py

137 lines
5.1 KiB
Python
Raw Normal View History

2024-05-03 04:18:51 +03:00
import glob
import os
import random
import numpy as np
from scipy import signal
from TTS.encoder.models.lstm import LSTMSpeakerEncoder
from TTS.encoder.models.resnet import ResNetSpeakerEncoder
class AugmentWAV(object):
    """Randomly augment waveforms with additive noise and/or reverberation.

    Args:
        ap: Audio processor. It must expose a ``sample_rate`` attribute and a
            ``load_wav(path, sr=...)`` method returning a NumPy array.
        augmentation_config: Mapping with optional ``"additive"`` and
            ``"rir"`` entries.

            ``"additive"`` holds ``"sounds_path"`` (root folder of the noise
            files) plus one dict per noise type. Each noise-type dict is read
            for ``min_num_noises``, ``max_num_noises``, ``min_snr_in_db`` and
            ``max_snr_in_db``. The noise type is the first directory level
            below ``sounds_path``.

            ``"rir"`` holds ``"rir_path"`` (root folder of impulse responses)
            and ``"conv_mode"`` (forwarded to ``scipy.signal.convolve``).
    """

    def __init__(self, ap, augmentation_config):
        self.ap = ap

        self.use_additive_noise = False
        if "additive" in augmentation_config.keys():
            self.additive_noise_config = augmentation_config["additive"]
            additive_path = self.additive_noise_config["sounds_path"]
            if additive_path:
                self.use_additive_noise = True
                # Every dict-valued entry of the config describes one noise type.
                self.additive_noise_types = [
                    key
                    for key in self.additive_noise_config.keys()
                    if isinstance(self.additive_noise_config[key], dict)
                ]

                additive_files = glob.glob(os.path.join(additive_path, "**/*.wav"), recursive=True)

                self.noise_list = {}
                for wav_file in additive_files:
                    # relpath (rather than str.replace) yields the first
                    # sub-directory whether or not ``sounds_path`` ends with a
                    # path separator.
                    noise_dir = os.path.relpath(wav_file, additive_path).split(os.sep)[0]
                    # ignore not listed directories
                    if noise_dir not in self.additive_noise_types:
                        continue
                    self.noise_list.setdefault(noise_dir, []).append(wav_file)

                print(
                    f" | > Using Additive Noise Augmentation: with {len(additive_files)} audios instances from {self.additive_noise_types}"
                )

        self.use_rir = False
        if "rir" in augmentation_config.keys():
            self.rir_config = augmentation_config["rir"]
            if self.rir_config["rir_path"]:
                self.rir_files = glob.glob(os.path.join(self.rir_config["rir_path"], "**/*.wav"), recursive=True)
                self.use_rir = True
                # Reported only when RIR is enabled: ``rir_files`` does not
                # exist otherwise.
                print(f" | > Using RIR Noise Augmentation: with {len(self.rir_files)} audios instances")

        self.create_augmentation_global_list()

    def create_augmentation_global_list(self):
        """Build ``self.global_noise_list``, the pool ``apply_one`` draws from."""
        if self.use_additive_noise:
            # Copy so that appending "RIR_AUG" below does not also modify
            # ``self.additive_noise_types``.
            self.global_noise_list = list(self.additive_noise_types)
        else:
            self.global_noise_list = []
        if self.use_rir:
            self.global_noise_list.append("RIR_AUG")

    def additive_noise(self, noise_type, audio):
        """Return ``audio`` mixed with randomly chosen noises of ``noise_type``.

        Args:
            noise_type: Key into ``self.noise_list`` and
                ``self.additive_noise_config``.
            audio: Waveform to augment (assumed to be a 1-D NumPy array; only
                ``audio.shape[0]`` is used as its length).

        Returns:
            ``audio`` plus the sum of the scaled noise signals.

        Note:
            Noise files shorter than ``audio`` are skipped. If every sampled
            file is too short, the method samples again by calling itself;
            this retry is unbounded when no file of this type is long enough.
        """
        noise_cfg = self.additive_noise_config[noise_type]
        clean_db = 10 * np.log10(np.mean(audio**2) + 1e-4)

        noise_list = random.sample(
            self.noise_list[noise_type],
            random.randint(noise_cfg["min_num_noises"], noise_cfg["max_num_noises"]),
        )

        audio_len = audio.shape[0]
        noises_wav = None
        for noise in noise_list:
            noiseaudio = self.ap.load_wav(noise, sr=self.ap.sample_rate)[:audio_len]
            if noiseaudio.shape[0] < audio_len:
                continue

            # Draw the target SNR from the configured dB range. The upper
            # bound is ``max_snr_in_db`` (not the noise-count limit).
            noise_snr = random.uniform(noise_cfg["min_snr_in_db"], noise_cfg["max_snr_in_db"])
            noise_db = 10 * np.log10(np.mean(noiseaudio**2) + 1e-4)
            # Scale the noise so that its level sits ``noise_snr`` dB below
            # the clean signal level.
            noise_wav = np.sqrt(10 ** ((clean_db - noise_db - noise_snr) / 10)) * noiseaudio

            if noises_wav is None:
                noises_wav = noise_wav
            else:
                noises_wav += noise_wav

        # if all possible files is less than audio, choose other files
        if noises_wav is None:
            return self.additive_noise(noise_type, audio)

        return audio + noises_wav

    def reverberate(self, audio):
        """Convolve ``audio`` with a random, energy-normalised impulse response.

        The result is truncated to the length of the input.
        """
        audio_len = audio.shape[0]

        rir_file = random.choice(self.rir_files)
        rir = self.ap.load_wav(rir_file, sr=self.ap.sample_rate)
        # Normalise the impulse response to unit energy.
        rir = rir / np.sqrt(np.sum(rir**2))
        return signal.convolve(audio, rir, mode=self.rir_config["conv_mode"])[:audio_len]

    def apply_one(self, audio):
        """Apply one augmentation picked at random from ``global_noise_list``."""
        noise_type = random.choice(self.global_noise_list)
        if noise_type == "RIR_AUG":
            return self.reverberate(audio)

        return self.additive_noise(noise_type, audio)
def setup_encoder_model(config: "Coqpit"):
    """Instantiate the speaker encoder selected by ``config.model_params``.

    Args:
        config: Configuration object exposing ``model_params`` (a mapping
            with at least ``"model_name"``, ``"input_dim"`` and
            ``"proj_dim"``) and ``audio`` (passed through as ``audio_config``).

    Returns:
        An ``LSTMSpeakerEncoder`` when ``model_name`` is ``"lstm"`` or a
        ``ResNetSpeakerEncoder`` when it is ``"resnet"`` (case-insensitive).

    Raises:
        ValueError: If ``model_name`` is neither ``"lstm"`` nor ``"resnet"``.
    """
    model_params = config.model_params
    model_name = model_params["model_name"].lower()

    if model_name == "lstm":
        return LSTMSpeakerEncoder(
            model_params["input_dim"],
            model_params["proj_dim"],
            model_params["lstm_dim"],
            model_params["num_lstm_layers"],
            use_torch_spec=model_params.get("use_torch_spec", False),
            audio_config=config.audio,
        )

    if model_name == "resnet":
        return ResNetSpeakerEncoder(
            input_dim=model_params["input_dim"],
            proj_dim=model_params["proj_dim"],
            log_input=model_params.get("log_input", False),
            use_torch_spec=model_params.get("use_torch_spec", False),
            audio_config=config.audio,
        )

    # Fail with a clear message instead of reaching a return of an unbound
    # local variable.
    raise ValueError(
        f"Unknown encoder model name: {model_params['model_name']!r} (expected 'lstm' or 'resnet')"
    )