# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""EnCodec model implementation."""

import math
from pathlib import Path
import typing as tp

import numpy as np
import torch
from torch import nn

from . import quantization as qt
from . import modules as m
from .utils import _check_checksum, _linear_overlap_add, _get_checkpoint_url


ROOT_URL = 'https://dl.fbaipublicfiles.com/encodec/v0/'

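# Each encoded frame pairs the discrete codes `[B, K, T]` for one segment with
# the optional rescale factor produced when `normalize` is enabled.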
EncodedFrame = tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]
class LMModel(nn.Module):
    """Language Model to estimate probabilities of each codebook entry.
    We predict all codebooks in parallel for a given time step.

    Args:
        n_q (int): number of codebooks.
        card (int): codebook cardinality.
        dim (int): transformer dimension.
        **kwargs: passed to `encodec.modules.transformer.StreamingTransformerEncoder`.
    """
    def __init__(self, n_q: int = 32, card: int = 1024, dim: int = 200, **kwargs):
        super().__init__()
        self.card = card
        self.n_q = n_q
        self.dim = dim
        self.transformer = m.StreamingTransformerEncoder(dim=dim, **kwargs)
        self.emb = nn.ModuleList([nn.Embedding(card + 1, dim) for _ in range(n_q)])
        self.linears = nn.ModuleList([nn.Linear(dim, card) for _ in range(n_q)])

    def forward(self, indices: torch.Tensor,
                states: tp.Optional[tp.List[torch.Tensor]] = None, offset: int = 0):
        """
        Args:
            indices (torch.Tensor): indices from the previous time step. Indices
                should be 1 + actual index in the codebook. The value 0 is reserved for
                when the index is missing (i.e. first time step). Shape should be
                `[B, n_q, T]`.
            states: state for the streaming decoding.
            offset: offset of the current time step.

        Returns a 3-tuple `(probabilities, new_states, new_offset)`, where
        `probabilities` has shape `[B, card, n_q, T]`.
        """
        B, K, T = indices.shape
        input_ = sum([self.emb[k](indices[:, k]) for k in range(K)])
        out, states, offset = self.transformer(input_, states, offset)
        logits = torch.stack([self.linears[k](out) for k in range(K)], dim=1).permute(0, 3, 1, 2)
        return torch.softmax(logits, dim=1), states, offset

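
# A minimal usage sketch (not part of the original API): it assumes the default
# `StreamingTransformerEncoder` keyword arguments are sufficient and only
# illustrates the shapes documented in `LMModel.forward`. `_lm_usage_example`
# is a hypothetical helper and is never called by the library.
def _lm_usage_example() -> None:
    lm = LMModel(n_q=8, card=1024, dim=200)
    # Indices are shifted by +1; 0 marks "no previous index" at the first step.
    indices = torch.zeros(1, 8, 1, dtype=torch.long)
    probabilities, states, offset = lm(indices, states=None, offset=0)
    assert probabilities.shape == (1, 1024, 8, 1)  # [B, card, n_q, T]
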
class EncodecModel(nn.Module):
    """EnCodec model operating on the raw waveform.
    Args:
        target_bandwidths (list of float): Target bandwidths.
        encoder (nn.Module): Encoder network.
        decoder (nn.Module): Decoder network.
        quantizer (qt.ResidualVectorQuantizer): Residual vector quantizer.
        sample_rate (int): Audio sample rate.
        channels (int): Number of audio channels.
        normalize (bool): Whether to apply audio normalization.
        segment (float or None): segment duration in sec. when doing overlap-add.
        overlap (float): overlap between segments, given as a fraction of the segment duration.
        name (str): name of the model, used as metadata when compressing audio.
    """
    def __init__(self,
                 encoder: m.SEANetEncoder,
                 decoder: m.SEANetDecoder,
                 quantizer: qt.ResidualVectorQuantizer,
                 target_bandwidths: tp.List[float],
                 sample_rate: int,
                 channels: int,
                 normalize: bool = False,
                 segment: tp.Optional[float] = None,
                 overlap: float = 0.01,
                 name: str = 'unset'):
        super().__init__()
        self.bandwidth: tp.Optional[float] = None
        self.target_bandwidths = target_bandwidths
        self.encoder = encoder
        self.quantizer = quantizer
        self.decoder = decoder
        self.sample_rate = sample_rate
        self.channels = channels
        self.normalize = normalize
        self.segment = segment
        self.overlap = overlap
        self.frame_rate = math.ceil(self.sample_rate / np.prod(self.encoder.ratios))
        self.name = name
        self.bits_per_codebook = int(math.log2(self.quantizer.bins))
        assert 2 ** self.bits_per_codebook == self.quantizer.bins, \
            "quantizer bins must be a power of 2."

    @property
    def segment_length(self) -> tp.Optional[int]:
        if self.segment is None:
            return None
        return int(self.segment * self.sample_rate)

    @property
    def segment_stride(self) -> tp.Optional[int]:
        segment_length = self.segment_length
        if segment_length is None:
            return None
        return max(1, int((1 - self.overlap) * segment_length))

    def encode(self, x: torch.Tensor) -> tp.List[EncodedFrame]:
        """Given a tensor `x`, returns a list of frames containing
        the discrete encoded codes for `x`, along with rescaling factors
        for each segment, when `self.normalize` is True.

        Each frame is a tuple `(codebook, scale)`, with `codebook` of
        shape `[B, K, T]`, with `K` the number of codebooks.
        """
        assert x.dim() == 3
        _, channels, length = x.shape
        assert channels > 0 and channels <= 2
        segment_length = self.segment_length
        if segment_length is None:
            segment_length = length
            stride = length
        else:
            stride = self.segment_stride  # type: ignore
            assert stride is not None

        encoded_frames: tp.List[EncodedFrame] = []
        for offset in range(0, length, stride):
            frame = x[:, :, offset: offset + segment_length]
            encoded_frames.append(self._encode_frame(frame))
        return encoded_frames

    def _encode_frame(self, x: torch.Tensor) -> EncodedFrame:
        length = x.shape[-1]
        duration = length / self.sample_rate
        assert self.segment is None or duration <= 1e-5 + self.segment

        if self.normalize:
            # Normalize each segment by its mono RMS volume; the scale is kept
            # so the decoder can undo the normalization.
            mono = x.mean(dim=1, keepdim=True)
            volume = mono.pow(2).mean(dim=2, keepdim=True).sqrt()
            scale = 1e-8 + volume
            x = x / scale
            scale = scale.view(-1, 1)
        else:
            scale = None

        emb = self.encoder(x)
        codes = self.quantizer.encode(emb, self.frame_rate, self.bandwidth)
        codes = codes.transpose(0, 1)
        # codes is [B, K, T], with T frames, K nb of codebooks.
        return codes, scale

    def decode(self, encoded_frames: tp.List[EncodedFrame]) -> torch.Tensor:
        """Decode the given frames into a waveform.
        Note that the output might be a bit bigger than the input. In that case,
        any extra steps at the end can be trimmed.
        """
        segment_length = self.segment_length
        if segment_length is None:
            assert len(encoded_frames) == 1
            return self._decode_frame(encoded_frames[0])

        frames = [self._decode_frame(frame) for frame in encoded_frames]
        return _linear_overlap_add(frames, self.segment_stride or 1)

    def _decode_frame(self, encoded_frame: EncodedFrame) -> torch.Tensor:
        codes, scale = encoded_frame
        codes = codes.transpose(0, 1)
        emb = self.quantizer.decode(codes)
        out = self.decoder(emb)
        if scale is not None:
            out = out * scale.view(-1, 1, 1)
        return out

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        frames = self.encode(x)
        return self.decode(frames)[:, :, :x.shape[-1]]

    def set_target_bandwidth(self, bandwidth: float):
        if bandwidth not in self.target_bandwidths:
            raise ValueError(f"This model doesn't support the bandwidth {bandwidth}. "
                             f"Select one of {self.target_bandwidths}.")
        self.bandwidth = bandwidth

    def get_lm_model(self) -> LMModel:
        """Return the associated LM model to improve the compression rate."""
        torch.manual_seed(1234)  # TODO: remove this.
        device = next(self.parameters()).device
        lm = LMModel(self.quantizer.n_q, self.quantizer.bins, num_layers=5, dim=200,
                     past_context=int(3.5 * self.frame_rate)).to(device)
        checkpoints = {
            'encodec_24khz': 'encodec_lm_24khz-1608e3c0.th',
            'encodec_48khz': 'encodec_lm_48khz-7add9fc3.th',
        }
        try:
            checkpoint_name = checkpoints[self.name]
        except KeyError:
            raise RuntimeError("No LM pre-trained for the current Encodec model.")
        url = _get_checkpoint_url(ROOT_URL, checkpoint_name)
        state = torch.hub.load_state_dict_from_url(
            url, map_location='cpu', check_hash=True)  # type: ignore
        lm.load_state_dict(state)
        lm.eval()
        return lm

    @staticmethod
    def _get_model(target_bandwidths: tp.List[float],
                   sample_rate: int = 24_000,
                   channels: int = 1,
                   causal: bool = True,
                   model_norm: str = 'weight_norm',
                   audio_normalize: bool = False,
                   segment: tp.Optional[float] = None,
                   name: str = 'unset'):
        encoder = m.SEANetEncoder(channels=channels, norm=model_norm, causal=causal)
        decoder = m.SEANetDecoder(channels=channels, norm=model_norm, causal=causal)
        # Number of quantizers needed to reach the largest target bandwidth,
        # assuming 10 bits (1024 bins) per codebook per frame.
        n_q = int(1000 * target_bandwidths[-1] // (math.ceil(sample_rate / encoder.hop_length) * 10))
        quantizer = qt.ResidualVectorQuantizer(
            dimension=encoder.dimension,
            n_q=n_q,
            bins=1024,
        )
        model = EncodecModel(
            encoder,
            decoder,
            quantizer,
            target_bandwidths,
            sample_rate,
            channels,
            normalize=audio_normalize,
            segment=segment,
            name=name,
        )
        return model

    @staticmethod
    def _get_pretrained(checkpoint_name: str, repository: tp.Optional[Path] = None):
        if repository is not None:
            if not repository.is_dir():
                raise ValueError(f"{repository} must exist and be a directory.")
            file = repository / checkpoint_name
            checksum = file.stem.split('-')[1]
            _check_checksum(file, checksum)
            return torch.load(file)
        else:
            url = _get_checkpoint_url(ROOT_URL, checkpoint_name)
            return torch.hub.load_state_dict_from_url(url, map_location='cpu', check_hash=True)  # type: ignore

    @staticmethod
    def encodec_model_24khz(pretrained: bool = True, repository: tp.Optional[Path] = None):
        """Return the pretrained causal 24khz model."""
        if repository:
            assert pretrained
        target_bandwidths = [1.5, 3., 6., 12., 24.]
        checkpoint_name = 'encodec_24khz-d7cc33bc.th'
        sample_rate = 24_000
        channels = 1
        model = EncodecModel._get_model(
            target_bandwidths, sample_rate, channels,
            causal=True, model_norm='weight_norm', audio_normalize=False,
            name='encodec_24khz' if pretrained else 'unset')
        if pretrained:
            state_dict = EncodecModel._get_pretrained(checkpoint_name, repository)
            model.load_state_dict(state_dict)
        model.eval()
        return model

    @staticmethod
    def encodec_model_48khz(pretrained: bool = True, repository: tp.Optional[Path] = None):
        """Return the pretrained 48khz model."""
        if repository:
            assert pretrained
        target_bandwidths = [3., 6., 12., 24.]
        checkpoint_name = 'encodec_48khz-7e698e3e.th'
        sample_rate = 48_000
        channels = 2
        model = EncodecModel._get_model(
            target_bandwidths, sample_rate, channels,
            causal=False, model_norm='time_group_norm', audio_normalize=True,
            segment=1., name='encodec_48khz' if pretrained else 'unset')
        if pretrained:
            state_dict = EncodecModel._get_pretrained(checkpoint_name, repository)
            model.load_state_dict(state_dict)
        model.eval()
        return model

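
# A minimal end-to-end sketch (not part of the original file): it assumes a
# 24 kHz mono waveform tensor of shape `[B, 1, T]` and a working network
# connection to fetch the pretrained weights. `_encodec_usage_example` is a
# hypothetical helper that mirrors what `EncodecModel.forward` does while
# keeping the intermediate codes around.
def _encodec_usage_example(wav: torch.Tensor) -> torch.Tensor:
    model = EncodecModel.encodec_model_24khz()
    model.set_target_bandwidth(6.0)
    with torch.no_grad():
        frames = model.encode(wav)      # list of (codes [B, K, T'], scale)
        decoded = model.decode(frames)  # may be slightly longer than the input
    return decoded[:, :, :wav.shape[-1]]
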
def test():
    from itertools import product
    import torchaudio
    bandwidths = [3, 6, 12, 24]
    models = {
        'encodec_24khz': EncodecModel.encodec_model_24khz,
        'encodec_48khz': EncodecModel.encodec_model_48khz
    }
    for model_name, bw in product(models.keys(), bandwidths):
        model = models[model_name]()
        model.set_target_bandwidth(bw)
        audio_suffix = model_name.split('_')[1][:3]
        wav, sr = torchaudio.load(f"test_{audio_suffix}.wav")
        wav = wav[:, :model.sample_rate * 2]
        wav_in = wav.unsqueeze(0)
        wav_dec = model(wav_in)[0]
        assert wav.shape == wav_dec.shape, (wav.shape, wav_dec.shape)


if __name__ == '__main__':
    test()