# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""API to compress/decompress audio to bytestreams."""

import io
import math
import struct
import time
import typing as tp

import torch

from . import binary
from .quantization.ac import ArithmeticCoder, ArithmeticDecoder, build_stable_quantized_cdf
from .model import EncodecModel, EncodedFrame


# Maps the model name stored in the stream header to the factory that builds
# the corresponding pre-trained model.
MODELS = {
    'encodec_24khz': EncodecModel.encodec_model_24khz,
    'encodec_48khz': EncodecModel.encodec_model_48khz,
}


def compress_to_file(model: EncodecModel, wav: torch.Tensor, fo: tp.IO[bytes],
                     use_lm: bool = True):
    """Compress a waveform to a file-object using the given model.

    The stream starts with a header (written by `binary.write_ecdc_header`)
    followed, for each encoded frame, by an optional scale (network byte
    order float) and the codes, either bit-packed or entropy coded.

    Args:
        model (EncodecModel): a pre-trained EncodecModel to use to compress the audio.
        wav (torch.Tensor): waveform to compress, should have a shape `[C, T]`, with `C`
            matching `model.channels`, and the proper sample rate (e.g. `model.sample_rate`).
            Use `utils.convert_audio` if this is not the case.
        fo (IO[bytes]): file-object to which the compressed bits will be written.
            See `compress` if you want to obtain a `bytes` object instead.
        use_lm (bool): if True, use a pre-trained language model to further
            compress the stream using Entropy Coding. This will slow down compression
            quite a bit, expect between 20 to 30% of size reduction.

    Raises:
        AssertionError: if `wav` is not 2-dimensional.
        ValueError: if `model.name` is not one of the keys of `MODELS`.
    """
    assert wav.dim() == 2, "Only single waveform can be encoded."
    if model.name not in MODELS:
        raise ValueError(f"The provided model {model.name} is not supported.")

    if use_lm:
        lm = model.get_lm_model()

    with torch.no_grad():
        # Add a leading batch dimension of size 1.
        frames = model.encode(wav[None])

    metadata = {
        'm': model.name,                 # model name
        'al': wav.shape[-1],             # audio_length
        'nc': frames[0][0].shape[1],     # num_codebooks
        'lm': use_lm,                    # use lm?
    }
    binary.write_ecdc_header(fo, metadata)

    for (frame, scale) in frames:
        if scale is not None:
            # '!f' is a 4-byte float in network (big-endian) byte order.
            fo.write(struct.pack('!f', scale.cpu().item()))
        _, K, T = frame.shape
        if use_lm:
            # Coder and language model state are reset for every frame.
            coder = ArithmeticCoder(fo)
            states: tp.Any = None
            lm_offset = 0
            input_ = torch.zeros(1, K, 1, dtype=torch.long, device=wav.device)
        else:
            packer = binary.BitPacker(model.bits_per_codebook, fo)
        for t in range(T):
            if use_lm:
                with torch.no_grad():
                    probas, states, lm_offset = lm(input_, states, lm_offset)
                # We emulate a streaming scenario even though we do not provide an API for it.
                # This gives us a more accurate benchmark.
                # Codes are shifted by one; the very first input is all zeros
                # (presumably 0 acts as the start token -- confirm in the LM).
                input_ = 1 + frame[:, :, t: t + 1]
            for k, value in enumerate(frame[0, :, t].tolist()):
                if use_lm:
                    q_cdf = build_stable_quantized_cdf(
                        probas[0, :, k, 0], coder.total_range_bits, check=False)
                    coder.push(value, q_cdf)
                else:
                    packer.push(value)
        if use_lm:
            coder.flush()
        else:
            packer.flush()


def decompress_from_file(fo: tp.IO[bytes], device='cpu') -> tp.Tuple[torch.Tensor, int]:
    """Decompress from a file-object.
    Returns a tuple `(wav, sample_rate)`.

    Args:
        fo (IO[bytes]): file-object from which to read. If you want to decompress
            from `bytes` instead, see `decompress`.
        device: device to use to perform the computations.

    Returns:
        A tuple `(wav, sample_rate)`: `wav` is the first item of the decoded
        batch, trimmed to the audio length stored in the header along its last
        dimension; `sample_rate` is `model.sample_rate`.

    Raises:
        ValueError: if the header names a model that is not in `MODELS`.
        EOFError: if the stream ends before all expected codes were read.
    """
    metadata = binary.read_ecdc_header(fo)
    model_name = metadata['m']
    audio_length = metadata['al']
    num_codebooks = metadata['nc']
    use_lm = metadata['lm']
    assert isinstance(audio_length, int)
    assert isinstance(num_codebooks, int)
    if model_name not in MODELS:
        raise ValueError(f"The audio was compressed with an unsupported model {model_name}.")
    model = MODELS[model_name]().to(device)

    if use_lm:
        lm = model.get_lm_model()

    frames: tp.List[EncodedFrame] = []
    # When the model defines no segment length/stride, the whole audio is
    # handled as a single segment.
    segment_length = model.segment_length or audio_length
    segment_stride = model.segment_stride or audio_length
    # `segment_offset` (position in the audio) is kept distinct from
    # `lm_offset` (position tracked by the language model) so that the loop
    # variable is never overwritten inside the loop body.
    for segment_offset in range(0, audio_length, segment_stride):
        this_segment_length = min(audio_length - segment_offset, segment_length)
        frame_length = int(math.ceil(this_segment_length / model.sample_rate * model.frame_rate))
        if model.normalize:
            # Mirror of the '!f' scale written ahead of each frame at compression time.
            scale_f, = struct.unpack('!f', binary._read_exactly(fo, struct.calcsize('!f')))
            scale = torch.tensor(scale_f, device=device).view(1)
        else:
            scale = None
        if use_lm:
            # Decoder and language model state are reset for every frame.
            decoder = ArithmeticDecoder(fo)
            states: tp.Any = None
            lm_offset = 0
            input_ = torch.zeros(1, num_codebooks, 1, dtype=torch.long, device=device)
        else:
            unpacker = binary.BitUnpacker(model.bits_per_codebook, fo)
        frame = torch.zeros(1, num_codebooks, frame_length, dtype=torch.long, device=device)
        for t in range(frame_length):
            if use_lm:
                with torch.no_grad():
                    probas, states, lm_offset = lm(input_, states, lm_offset)
            code_list: tp.List[int] = []
            for k in range(num_codebooks):
                if use_lm:
                    q_cdf = build_stable_quantized_cdf(
                        probas[0, :, k, 0], decoder.total_range_bits, check=False)
                    code = decoder.pull(q_cdf)
                else:
                    code = unpacker.pull()
                if code is None:
                    raise EOFError("The stream ended sooner than expected.")
                code_list.append(code)
            codes = torch.tensor(code_list, dtype=torch.long, device=device)
            frame[0, :, t] = codes
            if use_lm:
                # Feed the codes just decoded (shifted by one) back to the LM,
                # matching what is done at compression time.
                input_ = 1 + frame[:, :, t: t + 1]
        frames.append((frame, scale))
    with torch.no_grad():
        wav = model.decode(frames)
    return wav[0, :, :audio_length], model.sample_rate


def compress(model: EncodecModel, wav: torch.Tensor, use_lm: bool = False) -> bytes:
    """Compress a waveform using the given model. Returns the compressed bytes.

    This is `compress_to_file` writing into an in-memory buffer. Note that
    `use_lm` defaults to False here, whereas it defaults to True in
    `compress_to_file`.

    Args:
        model (EncodecModel): a pre-trained EncodecModel to use to compress the audio.
        wav (torch.Tensor): waveform to compress, should have a shape `[C, T]`, with `C`
            matching `model.channels`, and the proper sample rate (e.g. `model.sample_rate`).
            Use `utils.convert_audio` if this is not the case.
        use_lm (bool): if True, use a pre-trained language model to further
            compress the stream using Entropy Coding. This will slow down compression
            quite a bit, expect between 20 to 30% of size reduction.

    Returns:
        The full content of the compressed stream.
    """
    fo = io.BytesIO()
    compress_to_file(model, wav, fo, use_lm=use_lm)
    return fo.getvalue()


def decompress(compressed: bytes, device='cpu') -> tp.Tuple[torch.Tensor, int]:
    """Decompress from a `bytes` object.
    Returns a tuple `(wav, sample_rate)`.

    This is `decompress_from_file` reading from an in-memory buffer.

    Args:
        compressed (bytes): compressed bytes.
        device: device to use to perform the computations.
    """
    fo = io.BytesIO(compressed)
    return decompress_from_file(fo, device=device)


def test():
    """Round-trip the first 5 seconds of a test file through every model.

    For each model in `MODELS`, with and without the language model, this
    prints the bitrate and timings and checks that the decoded waveform has
    the same shape as the input. Expects `test_<sr>k.wav` files (with `<sr>`
    the model sample rate in kHz) in the current working directory.
    """
    import torchaudio
    torch.set_num_threads(1)
    for name, factory in MODELS.items():
        model = factory()
        sr = model.sample_rate // 1000
        x, _ = torchaudio.load(f'test_{sr}k.wav')
        x = x[:, :model.sample_rate * 5]
        model.set_target_bandwidth(12)
        for use_lm in [False, True]:
            print(f"Doing {name}, use_lm={use_lm}")
            begin = time.time()
            res = compress(model, x, use_lm=use_lm)
            t_comp = time.time() - begin
            x_dec, _ = decompress(res)
            # Total elapsed time minus the compression time.
            t_decomp = time.time() - begin - t_comp
            kbps = 8 * len(res) / 1000 / (x.shape[-1] / model.sample_rate)
            print(f"kbps: {kbps:.1f}, time comp: {t_comp:.1f} sec. "
                  f"time decomp:{t_decomp:.1f}.")
            assert x_dec.shape == x.shape


if __name__ == '__main__':
    test()