384 lines
13 KiB
Python
384 lines
13 KiB
Python
|
import json
|
||
|
import random
|
||
|
from typing import Any, Dict, List, Tuple, Union
|
||
|
|
||
|
import fsspec
|
||
|
import numpy as np
|
||
|
import torch
|
||
|
|
||
|
from TTS.config import load_config
|
||
|
from TTS.encoder.utils.generic_utils import setup_encoder_model
|
||
|
from TTS.utils.audio import AudioProcessor
|
||
|
|
||
|
|
||
|
def load_file(path: str):
    """Deserialize an object from ``path``, opened through ``fsspec``.

    The file suffix selects the decoder: ``.json`` files are opened in text
    mode and parsed with ``json.load``; ``.pth`` files are opened in binary
    mode and read with ``torch.load`` using ``map_location="cpu"``.

    Args:
        path (str): Path to the file to read.

    Returns:
        The object decoded from the file.

    Raises:
        ValueError: If ``path`` ends with neither ``.json`` nor ``.pth``.
    """
    is_json = path.endswith(".json")
    # Reject unknown suffixes up front so nothing is opened for them.
    if not (is_json or path.endswith(".pth")):
        raise ValueError("Unsupported file type")

    if is_json:
        with fsspec.open(path, "r") as handle:
            return json.load(handle)

    with fsspec.open(path, "rb") as handle:
        return torch.load(handle, map_location="cpu")
|
||
|
|
||
|
|
||
|
def save_file(obj: Any, path: str):
    """Serialize ``obj`` to ``path``, opened through ``fsspec``.

    The file suffix selects the encoder: ``.json`` files are written in text
    mode with ``json.dump`` (4-space indent); ``.pth`` files are written in
    binary mode with ``torch.save``.

    Args:
        obj (Any): Object to store.
        path (str): Destination path.

    Raises:
        ValueError: If ``path`` ends with neither ``.json`` nor ``.pth``.
    """
    as_json = path.endswith(".json")
    # Reject unknown suffixes up front so no file is created for them.
    if not (as_json or path.endswith(".pth")):
        raise ValueError("Unsupported file type")

    if as_json:
        with fsspec.open(path, "w") as handle:
            json.dump(obj, handle, indent=4)
        return

    with fsspec.open(path, "wb") as handle:
        torch.save(obj, handle)
|
||
|
|
||
|
|
||
|
class BaseIDManager:
    """Base `ID` Manager class. Every new `ID` manager must inherit this.
    It defines common `ID` manager specific functions.
    """

    def __init__(self, id_file_path: str = ""):
        """Create the manager, optionally loading an existing ID mapping.

        Args:
            id_file_path (str, optional): Path to a file holding a previously
                saved mapping. When empty, the mapping starts empty.
        """
        # Filled by `parse_ids_from_data` as {name: integer index}, or set to
        # whatever object `load_file` returns for `id_file_path`.
        self.name_to_id = {}

        if id_file_path:
            self.load_ids_from_file(id_file_path)

    @staticmethod
    def _load_json(json_file_path: str) -> Dict:
        """Read and parse a JSON file opened through `fsspec`."""
        with fsspec.open(json_file_path, "r") as f:
            return json.load(f)

    @staticmethod
    def _save_json(json_file_path: str, data: dict) -> None:
        """Write `data` as JSON (4-space indent) to a file opened through `fsspec`."""
        with fsspec.open(json_file_path, "w") as f:
            json.dump(data, f, indent=4)

    def set_ids_from_data(self, items: List, parse_key: str) -> None:
        """Set IDs from data samples.

        Args:
            items (List): Data sampled returned by `load_tts_samples()`.
            parse_key (str): Key read from every item to obtain its name.
        """
        self.name_to_id = self.parse_ids_from_data(items, parse_key=parse_key)

    def load_ids_from_file(self, file_path: str) -> None:
        """Set IDs from a file.

        Args:
            file_path (str): Path to the file. It is read with `load_file`,
                so the suffix must be one that `load_file` supports.
        """
        self.name_to_id = load_file(file_path)

    def save_ids_to_file(self, file_path: str) -> None:
        """Save IDs to a file.

        Args:
            file_path (str): Path to the output file. It is written with
                `save_file`, so the suffix must be one that `save_file`
                supports.
        """
        save_file(self.name_to_id, file_path)

    def get_random_id(self) -> Any:
        """Get the ID of a randomly chosen name.

        Returns:
            Any: The value stored in `name_to_id` for a random key, or `None`
            when no IDs are loaded.
        """
        if self.name_to_id:
            # `random.choices` is kept (rather than `random.choice`) so that a
            # seeded run draws the same name as before.
            return self.name_to_id[random.choices(list(self.name_to_id.keys()))[0]]

        return None

    @staticmethod
    def parse_ids_from_data(items: List, parse_key: str) -> Dict:
        """Parse IDs from data samples returned by `load_tts_samples()`.

        Args:
            items (list): Data sampled returned by `load_tts_samples()`.
            parse_key (str): The key to being used to parse the data.

        Returns:
            Dict: Mapping from each distinct `item[parse_key]` value to its
            index in sorted order.
        """
        # De-duplicate, then sort so the numbering does not depend on the
        # order of `items`.
        classes = sorted({item[parse_key] for item in items})
        ids = {name: i for i, name in enumerate(classes)}
        return ids
|
||
|
|
||
|
|
||
|
class EmbeddingManager(BaseIDManager):
    """Base `Embedding` Manager class. Every new `Embedding` manager must inherit this.
    It defines common `Embedding` manager specific functions.

    It expects embeddings files in the following format:

    ::

        {
            'audio_file_key':{
                'name': 'category_name',
                'embedding': [<embedding_values>]
            },
            ...
        }

    `audio_file_key` is a unique key to the audio file in the dataset. It can be the path to the file or any other unique key.
    `embedding` is the embedding vector of the audio file.
    `name` can be name of the speaker of the audio file.
    """

    def __init__(
        self,
        embedding_file_path: Union[str, List[str]] = "",
        id_file_path: str = "",
        encoder_model_path: str = "",
        encoder_config_path: str = "",
        use_cuda: bool = False,
    ):
        """Create the manager, optionally loading embeddings and an encoder.

        Args:
            embedding_file_path (Union[str, List[str]], optional): One
                embeddings file, or a list of files to merge.
            id_file_path (str, optional): Forwarded to the base class.
            encoder_model_path (str, optional): Encoder checkpoint path.
            encoder_config_path (str, optional): Encoder config path. The
                encoder is only initialized when both encoder paths are given.
            use_cuda (bool, optional): Move encoder inputs to CUDA.
        """
        super().__init__(id_file_path=id_file_path)

        self.embeddings = {}  # clip key -> {"name": ..., "embedding": ...}
        self.embeddings_by_names = {}  # name -> list of embedding vectors
        self.clip_ids = []
        self.encoder = None
        self.encoder_ap = None
        self.use_cuda = use_cuda

        if embedding_file_path:
            if isinstance(embedding_file_path, list):
                self.load_embeddings_from_list_of_files(embedding_file_path)
            else:
                self.load_embeddings_from_file(embedding_file_path)

        if encoder_model_path and encoder_config_path:
            self.init_encoder(encoder_model_path, encoder_config_path, use_cuda)

    @property
    def num_embeddings(self):
        """Get number of embeddings."""
        return len(self.embeddings)

    @property
    def num_names(self):
        """Get number of distinct names that have at least one embedding."""
        return len(self.embeddings_by_names)

    @property
    def embedding_dim(self):
        """Dimensionality of embeddings. If embeddings are not loaded, returns zero."""
        if self.embeddings:
            # Read the first entry without materializing the full key list.
            return len(next(iter(self.embeddings.values()))["embedding"])
        return 0

    @property
    def embedding_names(self):
        """Get embedding names."""
        return list(self.embeddings_by_names.keys())

    def save_embeddings_to_file(self, file_path: str) -> None:
        """Save embeddings to a file.

        Args:
            file_path (str): Path to the output file. It is written with
                `save_file`, so the suffix must be one that `save_file`
                supports.
        """
        save_file(self.embeddings, file_path)

    @staticmethod
    def _group_embeddings_by_name(embeddings: Dict) -> Dict:
        """Group the `embedding` vectors of `embeddings` by each entry's `name`."""
        grouped = {}
        for entry in embeddings.values():
            grouped.setdefault(entry["name"], []).append(entry["embedding"])
        return grouped

    @staticmethod
    def read_embeddings_from_file(file_path: str):
        """Load embeddings from a file.

        Args:
            file_path (str): Path to the file, read with `load_file`.

        Returns:
            tuple: `(name_to_id, clip_ids, embeddings, embeddings_by_names)`.
            `name_to_id` numbers the distinct names in sorted order,
            `clip_ids` is the sorted list of clip keys, `embeddings` is the
            loaded mapping, and `embeddings_by_names` groups the vectors by
            name.
        """
        embeddings = load_file(file_path)
        speakers = sorted({x["name"] for x in embeddings.values()})
        name_to_id = {name: i for i, name in enumerate(speakers)}
        # Dict keys are already unique. Sorting them directly keeps the order
        # deterministic; passing the sorted keys through a set would make the
        # order depend on string hashing and vary between runs.
        clip_ids = sorted(embeddings)
        # cache embeddings_by_names for fast inference using a bigger speakers.json
        embeddings_by_names = EmbeddingManager._group_embeddings_by_name(embeddings)
        return name_to_id, clip_ids, embeddings, embeddings_by_names

    def load_embeddings_from_file(self, file_path: str) -> None:
        """Load embeddings from a file, replacing anything loaded before.

        Args:
            file_path (str): Path to the target file.
        """
        self.name_to_id, self.clip_ids, self.embeddings, self.embeddings_by_names = self.read_embeddings_from_file(
            file_path
        )

    def load_embeddings_from_list_of_files(self, file_paths: List[str]) -> None:
        """Load embeddings from a list of files and don't allow duplicate keys.

        Args:
            file_paths (List[str]): List of paths to the target files.

        Raises:
            ValueError: If a clip key appears in more than one file.
        """
        self.name_to_id = {}
        self.clip_ids = []
        self.embeddings_by_names = {}
        self.embeddings = {}
        for file_path in file_paths:
            ids, clip_ids, embeddings, embeddings_by_names = self.read_embeddings_from_file(file_path)
            # check colliding keys
            duplicates = self.embeddings.keys() & embeddings.keys()
            if duplicates:
                raise ValueError(f" [!] Duplicate embedding names <{duplicates}> in {file_path}")
            # store values
            self.name_to_id.update(ids)
            self.clip_ids.extend(clip_ids)
            # A name may occur in several files: accumulate its vectors
            # instead of letting a later file overwrite an earlier one, so
            # this cache matches `get_embeddings_by_names()`.
            for name, vectors in embeddings_by_names.items():
                self.embeddings_by_names.setdefault(name, []).extend(vectors)
            self.embeddings.update(embeddings)

        # reset name_to_id to get the right speaker ids
        # (names are numbered in first-seen order across the files)
        self.name_to_id = {name: i for i, name in enumerate(self.name_to_id)}

    def get_embedding_by_clip(self, clip_idx: str) -> List:
        """Get embedding by clip ID.

        Args:
            clip_idx (str): Target clip ID.

        Returns:
            List: embedding as a list.
        """
        return self.embeddings[clip_idx]["embedding"]

    def get_embeddings_by_name(self, idx: str) -> List[List]:
        """Get all embeddings of a speaker.

        Args:
            idx (str): Target name.

        Returns:
            List[List]: all the embeddings of the given speaker.
        """
        return self.embeddings_by_names[idx]

    def get_embeddings_by_names(self) -> Dict:
        """Get all embeddings by names.

        The grouping is rebuilt from `self.embeddings` on every call rather
        than read from the `embeddings_by_names` cache.

        Returns:
            Dict: all the embeddings of each speaker.
        """
        return self._group_embeddings_by_name(self.embeddings)

    def get_mean_embedding(
        self, idx: str, num_samples: Union[int, None] = None, randomize: bool = False
    ) -> np.ndarray:
        """Get mean embedding of a idx.

        Args:
            idx (str): Target name.
            num_samples (int, optional): Number of samples to be averaged. Defaults to None.
            randomize (bool, optional): Pick random `num_samples` of embeddings. Defaults to False.

        Returns:
            np.ndarray: Mean embedding.
        """
        embeddings = self.get_embeddings_by_name(idx)
        if num_samples is None:
            embeddings = np.stack(embeddings).mean(0)
        else:
            assert len(embeddings) >= num_samples, f" [!] {idx} has number of samples < {num_samples}"
            if randomize:
                # NOTE: `random.choices` samples with replacement, so the same
                # embedding can be picked more than once.
                embeddings = np.stack(random.choices(embeddings, k=num_samples)).mean(0)
            else:
                embeddings = np.stack(embeddings[:num_samples]).mean(0)
        return embeddings

    def get_random_embedding(self) -> Any:
        """Get the embedding of a randomly chosen clip.

        Returns:
            Any: The stored `embedding` value of a random clip, or `None`
            when no embeddings are loaded.
        """
        if self.embeddings:
            # `random.choices` is kept so seeded runs draw the same clip as before.
            return self.embeddings[random.choices(list(self.embeddings.keys()))[0]]["embedding"]

        return None

    def get_clips(self) -> List:
        """Get the sorted list of clip keys that have an embedding."""
        return sorted(self.embeddings.keys())

    def init_encoder(self, model_path: str, config_path: str, use_cuda=False) -> None:
        """Initialize a speaker encoder model.

        Sets `encoder_config`, `encoder`, `encoder_criterion` and `encoder_ap`
        on the instance.

        Args:
            model_path (str): Model file path.
            config_path (str): Model config file path.
            use_cuda (bool, optional): Use CUDA. Defaults to False.
        """
        self.use_cuda = use_cuda
        self.encoder_config = load_config(config_path)
        self.encoder = setup_encoder_model(self.encoder_config)
        self.encoder_criterion = self.encoder.load_checkpoint(
            self.encoder_config, model_path, eval=True, use_cuda=use_cuda, cache=True
        )
        self.encoder_ap = AudioProcessor(**self.encoder_config.audio)

    def compute_embedding_from_clip(self, wav_file: Union[str, List[str]]) -> list:
        """Compute a embedding from a given audio file.

        Requires `init_encoder` to have been called.

        Args:
            wav_file (Union[str, List[str]]): Target file path, or a list of
                paths whose embeddings are averaged.

        Returns:
            list: Computed embedding.

        Raises:
            ValueError: If `wav_file` is an empty list.
        """

        def _compute(wav_file: str):
            waveform = self.encoder_ap.load_wav(wav_file, sr=self.encoder_ap.sample_rate)
            # The encoder receives a mel spectrogram unless its config sets
            # `use_torch_spec`, in which case it receives the raw waveform.
            if not self.encoder_config.model_params.get("use_torch_spec", False):
                m_input = self.encoder_ap.melspectrogram(waveform)
                m_input = torch.from_numpy(m_input)
            else:
                m_input = torch.from_numpy(waveform)

            if self.use_cuda:
                m_input = m_input.cuda()
            m_input = m_input.unsqueeze(0)  # add a leading batch axis
            embedding = self.encoder.compute_embedding(m_input)
            return embedding

        if isinstance(wav_file, list):
            if not wav_file:
                # Without this guard the mean below fails with an opaque
                # `None / 0` TypeError.
                raise ValueError(" [!] Cannot compute an embedding from an empty list of audio files.")
            # compute the mean embedding
            embeddings = None
            for wf in wav_file:
                embedding = _compute(wf)
                if embeddings is None:
                    embeddings = embedding
                else:
                    embeddings += embedding
            return (embeddings / len(wav_file))[0].tolist()
        embedding = _compute(wav_file)
        return embedding[0].tolist()

    def compute_embeddings(self, feats: Union[torch.Tensor, np.ndarray]) -> List:
        """Compute embedding from features.

        Requires `init_encoder` to have been called.

        Args:
            feats (Union[torch.Tensor, np.ndarray]): Input features. A 2-D
                input gets a leading batch axis added.

        Returns:
            List: computed embedding, exactly as returned by
            `self.encoder.compute_embedding`.
            NOTE(review): the result is not converted here, so it may be a
            tensor rather than a list - confirm against the encoder.
        """
        if isinstance(feats, np.ndarray):
            feats = torch.from_numpy(feats)
        if feats.ndim == 2:
            feats = feats.unsqueeze(0)
        if self.use_cuda:
            feats = feats.cuda()
        return self.encoder.compute_embedding(feats)
|