import os
import pickle as pickle_tts
import types
from typing import Any, Callable, Dict, Optional, Union

import fsspec
import torch

from TTS.utils.generic_utils import get_user_data_dir


class RenamingUnpickler(pickle_tts.Unpickler):
    """Unpickler that rewrites the ``mozilla_voice_tts`` module path to ``TTS``.

    Overloads the default unpickler to solve the module renaming problem:
    every occurrence of ``mozilla_voice_tts`` in a pickled object's module
    path is replaced with ``TTS`` before the class is looked up.
    """

    def find_class(self, module, name):
        return super().find_class(module.replace("mozilla_voice_tts", "TTS"), name)


def _renaming_pickle_module() -> types.ModuleType:
    """Return a private copy of the ``pickle`` module whose ``Unpickler`` is ``RenamingUnpickler``.

    The copy is handed to ``torch.load`` as ``pickle_module`` so the renaming
    applies only to that one load. The stdlib ``pickle`` module itself is never
    modified. A real ``ModuleType`` is used (not a plain namespace) so
    attributes such as ``__name__`` that torch may inspect are present.
    """
    shim = types.ModuleType(pickle_tts.__name__)
    shim.__dict__.update(vars(pickle_tts))
    shim.Unpickler = RenamingUnpickler
    return shim


class AttrDict(dict):
    """A dict whose keys are also accessible as attributes.

    The instance's ``__dict__`` is the dict itself, so ``d["x"]`` and ``d.x``
    read and write the same entry.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__dict__ = self


def load_fsspec(
    path: str,
    map_location: Optional[
        Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]]
    ] = None,
    cache: bool = True,
    **kwargs,
) -> Any:
    """Like torch.load but can load from other locations (e.g. s3:// , gs://).

    Args:
        path: Any path or url supported by fsspec.
        map_location: torch.device, str, callable or device mapping forwarded
            to ``torch.load``.
        cache: If True, cache a remote file locally for subsequent calls. It is
            cached under `get_user_data_dir()/tts_cache`. Ignored for paths that
            exist on the local filesystem. Defaults to True.
        **kwargs: Keyword arguments forwarded to torch.load.

    Returns:
        Object stored in path.
    """
    is_local = os.path.isdir(path) or os.path.isfile(path)
    if cache and not is_local:
        # Route remote paths through fsspec's "filecache" protocol so the
        # download is stored under the tts_cache directory and reused.
        with fsspec.open(
            f"filecache::{path}",
            filecache={"cache_storage": str(get_user_data_dir("tts_cache"))},
            mode="rb",
        ) as f:
            return torch.load(f, map_location=map_location, **kwargs)
    with fsspec.open(path, "rb") as f:
        return torch.load(f, map_location=map_location, **kwargs)


def load_checkpoint(
    model, checkpoint_path, use_cuda=False, eval=False, cache=False
):  # pylint: disable=redefined-builtin
    """Load the ``"model"`` state dict of a checkpoint into ``model``.

    The checkpoint is first deserialized onto the CPU
    (``map_location=torch.device("cpu")``).

    Args:
        model: Object whose ``load_state_dict`` receives ``state["model"]``.
        checkpoint_path: Local path or fsspec URL of the checkpoint.
        use_cuda: If True, call ``model.cuda()`` after loading the weights.
        eval: If True, call ``model.eval()`` after loading the weights.
        cache: Forwarded to ``load_fsspec``; when True, remote checkpoints are
            cached locally.

    Returns:
        Tuple ``(model, state)``, where ``state`` is the full object loaded
        from the checkpoint.

    Raises:
        ModuleNotFoundError: If the checkpoint still references a missing
            module after the ``mozilla_voice_tts`` -> ``TTS`` renaming retry.
    """
    try:
        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
    except ModuleNotFoundError:
        # Retry with an unpickler that renames the legacy package. A private
        # pickle copy is used instead of reassigning ``pickle.Unpickler``,
        # which would permanently change unpickling for the whole process.
        state = load_fsspec(
            checkpoint_path,
            map_location=torch.device("cpu"),
            pickle_module=_renaming_pickle_module(),
            cache=cache,
        )
    model.load_state_dict(state["model"])
    if use_cuda:
        model.cuda()
    if eval:
        model.eval()
    return model, state