# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""EnCodec model implementation."""

import math
from pathlib import Path
import typing as tp

import numpy as np
import torch
from torch import nn

from . import quantization as qt
from . import modules as m
from .utils import _check_checksum, _linear_overlap_add, _get_checkpoint_url

ROOT_URL = 'https://dl.fbaipublicfiles.com/encodec/v0/'

# One encoded segment: `(codes, scale)`, where `scale` is None unless the
# model was built with `normalize=True`.
EncodedFrame = tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]


class LMModel(nn.Module):
    """Language Model to estimate probabilities of each codebook entry.
    We predict all codebooks in parallel for a given time step.

    Args:
        n_q (int): number of codebooks.
        card (int): codebook cardinality.
        dim (int): transformer dimension.
        **kwargs: passed to `encodec.modules.transformer.StreamingTransformerEncoder`.
    """
    def __init__(self, n_q: int = 32, card: int = 1024, dim: int = 200, **kwargs):
        super().__init__()
        self.card = card
        self.n_q = n_q
        self.dim = dim
        self.transformer = m.StreamingTransformerEncoder(dim=dim, **kwargs)
        # `card + 1` entries: input index 0 is reserved for "missing" (see `forward`).
        self.emb = nn.ModuleList([nn.Embedding(card + 1, dim) for _ in range(n_q)])
        self.linears = nn.ModuleList([nn.Linear(dim, card) for _ in range(n_q)])

    def forward(self, indices: torch.Tensor,
                states: tp.Optional[tp.List[torch.Tensor]] = None, offset: int = 0):
        """
        Args:
            indices (torch.Tensor): indices from the previous time step. Indices
                should be 1 + actual index in the codebook. The value 0 is reserved for
                when the index is missing (i.e. first time step). Shape should be
                `[B, n_q, T]`.
            states: state for the streaming decoding.
            offset: offset of the current time step.

        Returns a 3-tuple `(probabilities, new_states, new_offset)` with probabilities
        with a shape `[B, card, n_q, T]`.
        """
        B, K, T = indices.shape
        # The per-codebook embeddings are summed into a single transformer input.
        input_ = sum([self.emb[k](indices[:, k]) for k in range(K)])
        out, states, offset = self.transformer(input_, states, offset)
        # One linear head per codebook, stacked on dim 1 and then permuted so
        # that the softmax below runs over the codebook entries (dim 1).
        logits = torch.stack([self.linears[k](out) for k in range(K)], dim=1).permute(0, 3, 1, 2)
        return torch.softmax(logits, dim=1), states, offset


class EncodecModel(nn.Module):
    """EnCodec model operating on the raw waveform.

    Args:
        encoder (nn.Module): Encoder network.
        decoder (nn.Module): Decoder network.
        quantizer (qt.ResidualVectorQuantizer): Quantizer turning the encoder
            output into discrete codes and back.
        target_bandwidths (list of float): Target bandwidths.
        sample_rate (int): Audio sample rate.
        channels (int): Number of audio channels.
        normalize (bool): Whether to apply audio normalization.
        segment (float or None): segment duration in sec. when doing overlap-add.
        overlap (float): overlap between segment, given as a fraction of the segment duration.
        name (str): name of the model, used as metadata when compressing audio.
    """
    def __init__(self,
                 encoder: m.SEANetEncoder,
                 decoder: m.SEANetDecoder,
                 quantizer: qt.ResidualVectorQuantizer,
                 target_bandwidths: tp.List[float],
                 sample_rate: int,
                 channels: int,
                 normalize: bool = False,
                 segment: tp.Optional[float] = None,
                 overlap: float = 0.01,
                 name: str = 'unset'):
        super().__init__()
        # Stays None until `set_target_bandwidth` is called.
        self.bandwidth: tp.Optional[float] = None
        self.target_bandwidths = target_bandwidths
        self.encoder = encoder
        self.quantizer = quantizer
        self.decoder = decoder
        self.sample_rate = sample_rate
        self.channels = channels
        self.normalize = normalize
        self.segment = segment
        self.overlap = overlap
        self.frame_rate = math.ceil(self.sample_rate / np.prod(self.encoder.ratios))
        self.name = name
        self.bits_per_codebook = int(math.log2(self.quantizer.bins))
        assert 2 ** self.bits_per_codebook == self.quantizer.bins, \
            "quantizer bins must be a power of 2."

    @property
    def segment_length(self) -> tp.Optional[int]:
        """Segment length in samples, or None when `segment` is None."""
        if self.segment is None:
            return None
        return int(self.segment * self.sample_rate)

    @property
    def segment_stride(self) -> tp.Optional[int]:
        """Hop in samples between consecutive segments, or None when `segment` is None.

        Clamped to at least 1.
        """
        segment_length = self.segment_length
        if segment_length is None:
            return None
        return max(1, int((1 - self.overlap) * segment_length))

    def encode(self, x: torch.Tensor) -> tp.List[EncodedFrame]:
        """Given a tensor `x`, returns a list of frames containing
        the discrete encoded codes for `x`, along with rescaling factors
        for each segment, when `self.normalize` is True.

        Each frame is a tuple `(codebook, scale)`, with `codebook` of
        shape `[B, K, T]`, with `K` the number of codebooks.
        """
        assert x.dim() == 3
        _, channels, length = x.shape
        assert channels > 0 and channels <= 2
        segment_length = self.segment_length
        if segment_length is None:
            # No segmenting: the whole input is encoded as a single frame.
            segment_length = length
            stride = length
        else:
            stride = self.segment_stride  # type: ignore
            assert stride is not None

        encoded_frames: tp.List[EncodedFrame] = []
        for offset in range(0, length, stride):
            frame = x[:, :, offset: offset + segment_length]
            encoded_frames.append(self._encode_frame(frame))
        return encoded_frames

    def _encode_frame(self, x: torch.Tensor) -> EncodedFrame:
        """Encode one segment into `(codes, scale)`.

        `scale` is the per-item volume used to normalize the segment when
        `self.normalize` is True, and None otherwise.
        """
        length = x.shape[-1]
        duration = length / self.sample_rate
        assert self.segment is None or duration <= 1e-5 + self.segment

        if self.normalize:
            # Volume is the RMS over time of the channel-averaged signal.
            mono = x.mean(dim=1, keepdim=True)
            volume = mono.pow(2).mean(dim=2, keepdim=True).sqrt()
            # The small constant keeps the division below finite for silent input.
            scale = 1e-8 + volume
            x = x / scale
            scale = scale.view(-1, 1)
        else:
            scale = None

        emb = self.encoder(x)
        codes = self.quantizer.encode(emb, self.frame_rate, self.bandwidth)
        codes = codes.transpose(0, 1)
        # codes is [B, K, T], with T frames, K nb of codebooks.
        return codes, scale

    def decode(self, encoded_frames: tp.List[EncodedFrame]) -> torch.Tensor:
        """Decode the given frames into a waveform.
        Note that the output might be a bit bigger than the input. In that case,
        any extra steps at the end can be trimmed.
        """
        segment_length = self.segment_length
        if segment_length is None:
            # Without segmenting, `encode` produces exactly one frame.
            assert len(encoded_frames) == 1
            return self._decode_frame(encoded_frames[0])

        frames = [self._decode_frame(frame) for frame in encoded_frames]
        return _linear_overlap_add(frames, self.segment_stride or 1)

    def _decode_frame(self, encoded_frame: EncodedFrame) -> torch.Tensor:
        """Decode one `(codes, scale)` frame, undoing the normalization if `scale` is set."""
        codes, scale = encoded_frame
        # Back to the layout the quantizer produced before `_encode_frame` transposed it.
        codes = codes.transpose(0, 1)
        emb = self.quantizer.decode(codes)
        out = self.decoder(emb)
        if scale is not None:
            out = out * scale.view(-1, 1, 1)
        return out

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode then decode `x`, trimming the output to the input length."""
        frames = self.encode(x)
        return self.decode(frames)[:, :, :x.shape[-1]]

    def set_target_bandwidth(self, bandwidth: float):
        """Select the bandwidth used by subsequent calls to `encode`.

        Raises:
            ValueError: if `bandwidth` is not one of `self.target_bandwidths`.
        """
        if bandwidth not in self.target_bandwidths:
            raise ValueError(f"This model doesn't support the bandwidth {bandwidth}. "
                             f"Select one of {self.target_bandwidths}.")
        self.bandwidth = bandwidth

    def get_lm_model(self) -> LMModel:
        """Return the associated LM model to improve the compression rate.

        Raises:
            RuntimeError: if no pretrained LM is registered for `self.name`.
        """
        checkpoints = {
            'encodec_24khz': 'encodec_lm_24khz-1608e3c0.th',
            'encodec_48khz': 'encodec_lm_48khz-7add9fc3.th',
        }
        # Resolve the checkpoint first so that an unsupported model fails fast,
        # before the global RNG is reseeded and the transformer is built.
        try:
            checkpoint_name = checkpoints[self.name]
        except KeyError:
            raise RuntimeError("No LM pre-trained for the current Encodec model.") from None

        # TODO: remove this. It reseeds the global RNG, and the freshly
        # initialized weights are replaced by the checkpoint loaded below.
        torch.manual_seed(1234)
        device = next(self.parameters()).device
        lm = LMModel(self.quantizer.n_q, self.quantizer.bins, num_layers=5, dim=200,
                     past_context=int(3.5 * self.frame_rate)).to(device)
        url = _get_checkpoint_url(ROOT_URL, checkpoint_name)
        state = torch.hub.load_state_dict_from_url(
            url, map_location='cpu', check_hash=True)  # type: ignore
        lm.load_state_dict(state)
        lm.eval()
        return lm

    @staticmethod
    def _get_model(target_bandwidths: tp.List[float],
                   sample_rate: int = 24_000,
                   channels: int = 1,
                   causal: bool = True,
                   model_norm: str = 'weight_norm',
                   audio_normalize: bool = False,
                   segment: tp.Optional[float] = None,
                   name: str = 'unset'):
        """Build an untrained `EncodecModel` from the given hyper-parameters."""
        encoder = m.SEANetEncoder(channels=channels, norm=model_norm, causal=causal)
        decoder = m.SEANetDecoder(channels=channels, norm=model_norm, causal=causal)
        # Number of codebooks derived from the last entry of `target_bandwidths`,
        # using 10 bits per codebook to match `bins=1024` below.
        n_q = int(1000 * target_bandwidths[-1] // (math.ceil(sample_rate / encoder.hop_length) * 10))
        quantizer = qt.ResidualVectorQuantizer(
            dimension=encoder.dimension,
            n_q=n_q,
            bins=1024,
        )
        model = EncodecModel(
            encoder,
            decoder,
            quantizer,
            target_bandwidths,
            sample_rate,
            channels,
            normalize=audio_normalize,
            segment=segment,
            name=name,
        )
        return model

    @staticmethod
    def _get_pretrained(checkpoint_name: str, repository: tp.Optional[Path] = None):
        """Load the state dict for `checkpoint_name`.

        The checkpoint is read from `repository` when given (after verifying the
        checksum embedded in its file name), and downloaded from `ROOT_URL`
        otherwise.

        Raises:
            ValueError: if `repository` is given but is not a directory.
        """
        if repository is not None:
            if not repository.is_dir():
                raise ValueError(f"{repository} must exist and be a directory.")
            file = repository / checkpoint_name
            # File names look like `<name>-<checksum>.th`.
            checksum = file.stem.split('-')[1]
            _check_checksum(file, checksum)
            # Load on CPU, same as the download branch below, so a checkpoint
            # saved from GPU tensors also loads on a CPU-only machine.
            return torch.load(file, map_location='cpu')
        else:
            url = _get_checkpoint_url(ROOT_URL, checkpoint_name)
            return torch.hub.load_state_dict_from_url(url, map_location='cpu', check_hash=True)  # type:ignore

    @staticmethod
    def encodec_model_24khz(pretrained: bool = True, repository: tp.Optional[Path] = None):
        """Return the pretrained causal 24khz model.
        """
        if repository:
            assert pretrained
        target_bandwidths = [1.5, 3., 6, 12., 24.]
        checkpoint_name = 'encodec_24khz-d7cc33bc.th'
        sample_rate = 24_000
        channels = 1
        model = EncodecModel._get_model(
            target_bandwidths, sample_rate, channels,
            causal=True, model_norm='weight_norm', audio_normalize=False,
            name='encodec_24khz' if pretrained else 'unset')
        if pretrained:
            state_dict = EncodecModel._get_pretrained(checkpoint_name, repository)
            model.load_state_dict(state_dict)
        model.eval()
        return model

    @staticmethod
    def encodec_model_48khz(pretrained: bool = True, repository: tp.Optional[Path] = None):
        """Return the pretrained 48khz model.
        """
        if repository:
            assert pretrained
        target_bandwidths = [3., 6., 12., 24.]
        checkpoint_name = 'encodec_48khz-7e698e3e.th'
        sample_rate = 48_000
        channels = 2
        model = EncodecModel._get_model(
            target_bandwidths, sample_rate, channels,
            causal=False, model_norm='time_group_norm', audio_normalize=True,
            segment=1., name='encodec_48khz' if pretrained else 'unset')
        if pretrained:
            state_dict = EncodecModel._get_pretrained(checkpoint_name, repository)
            model.load_state_dict(state_dict)
        model.eval()
        return model


def test():
    """Smoke test: every pretrained model round-trips 2 seconds of audio at each bandwidth.

    Expects `test_24k.wav` and `test_48k.wav` in the current directory.
    """
    import torchaudio
    bandwidths = [3, 6, 12, 24]
    models = {
        'encodec_24khz': EncodecModel.encodec_model_24khz,
        'encodec_48khz': EncodecModel.encodec_model_48khz
    }
    for model_name, build_model in models.items():
        # Build each pretrained model once and reuse it for every bandwidth.
        model = build_model()
        audio_suffix = model_name.split('_')[1][:3]
        for bw in bandwidths:
            model.set_target_bandwidth(bw)
            wav, _ = torchaudio.load(f"test_{audio_suffix}.wav")
            wav = wav[:, :model.sample_rate * 2]
            wav_in = wav.unsqueeze(0)
            wav_dec = model(wav_in)[0]
            assert wav.shape == wav_dec.shape, (wav.shape, wav_dec.shape)


if __name__ == '__main__':
    test()