ai-content-maker/.venv/Lib/site-packages/encodec/compress.py

212 lines
8.1 KiB
Python
Raw Normal View History

2024-05-03 04:18:51 +03:00
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""API to compress/decompress audio to bytestreams."""
import io
import math
import struct
import time
import typing as tp
import torch
from . import binary
from .quantization.ac import ArithmeticCoder, ArithmeticDecoder, build_stable_quantized_cdf
from .model import EncodecModel, EncodedFrame
# Registry of supported model names -> zero-argument factories that build the
# corresponding pre-trained EncodecModel. The key is written to the stream
# header ('m') by `compress_to_file` and looked up again by
# `decompress_from_file` to rebuild the same model.
MODELS = dict(
    encodec_24khz=EncodecModel.encodec_model_24khz,
    encodec_48khz=EncodecModel.encodec_model_48khz,
)
def compress_to_file(model: EncodecModel, wav: torch.Tensor, fo: tp.IO[bytes],
                     use_lm: bool = True):
    """Compress a waveform to a file-object using the given model.

    The data written to `fo` is an ECDC header (model name, audio length,
    number of codebooks, LM flag). The header is followed, for each encoded
    frame, by an optional 4-byte network-order float scale and then the
    frame's codes. The codes are either arithmetic-coded under the LM's
    probabilities or bit-packed with `model.bits_per_codebook` bits each.

    Args:
        model (EncodecModel): a pre-trained EncodecModel to use to compress the audio.
        wav (torch.Tensor): waveform to compress, should have a shape `[C, T]`, with `C`
            matching `model.channels`, and the proper sample rate (e.g. `model.sample_rate`).
            Use `utils.convert_audio` if this is not the case.
        fo (IO[bytes]): file-object to which the compressed bits will be written.
            See `compress` if you want to obtain a `bytes` object instead.
        use_lm (bool): if True, use a pre-trained language model to further
            compress the stream using entropy coding. This slows down compression
            quite a bit; expect a 20 to 30% size reduction.

    Raises:
        AssertionError: if `wav` is not 2-D.
        ValueError: if `model.name` is not a key of `MODELS` (the decompressor
            could not rebuild the model otherwise).
    """
    assert wav.dim() == 2, "Only single waveform can be encoded."
    if model.name not in MODELS:
        raise ValueError(f"The provided model {model.name} is not supported.")
    if use_lm:
        lm = model.get_lm_model()
    with torch.no_grad():
        # wav[None] adds a leading dimension of size 1 (presumably batch).
        frames = model.encode(wav[None])
    metadata = {
        'm': model.name,  # model name
        'al': wav.shape[-1],  # audio_length
        'nc': frames[0][0].shape[1],  # num_codebooks (K in frame shape [_, K, T])
        'lm': use_lm,  # use lm?
    }
    binary.write_ecdc_header(fo, metadata)
    for (frame, scale) in frames:
        if scale is not None:
            fo.write(struct.pack('!f', scale.cpu().item()))
        _, K, T = frame.shape
        if use_lm:
            coder = ArithmeticCoder(fo)
            states: tp.Any = None
            # Streaming position of the LM (distinct from any sample offset).
            lm_offset = 0
            input_ = torch.zeros(1, K, 1, dtype=torch.long, device=wav.device)
        else:
            packer = binary.BitPacker(model.bits_per_codebook, fo)
        for t in range(T):
            if use_lm:
                with torch.no_grad():
                    probas, states, lm_offset = lm(input_, states, lm_offset)
                # We emulate a streaming scenario even though we do not provide an API for it.
                # This gives us a more accurate benchmark.
                # The +1 presumably keeps 0 free for the all-zeros start input above.
                input_ = 1 + frame[:, :, t: t + 1]
            for k, value in enumerate(frame[0, :, t].tolist()):
                if use_lm:
                    q_cdf = build_stable_quantized_cdf(
                        probas[0, :, k, 0], coder.total_range_bits, check=False)
                    coder.push(value, q_cdf)
                else:
                    packer.push(value)
        # Each frame's coded payload is flushed independently.
        if use_lm:
            coder.flush()
        else:
            packer.flush()
def decompress_from_file(fo: tp.IO[bytes], device='cpu') -> tp.Tuple[torch.Tensor, int]:
    """Decompress from a file-object.

    Reads the layout written by `compress_to_file`. The stream starts with an
    ECDC header. Then, for each segment, it holds an optional network-order
    float scale (present when `model.normalize`) followed by the segment's
    codes, which are arithmetic-coded or bit-packed depending on the header's
    LM flag.

    Returns a tuple `(wav, sample_rate)`, where `wav` is trimmed to the audio
    length stored in the header.

    Args:
        fo (IO[bytes]): file-object from which to read. If you want to decompress
            from `bytes` instead, see `decompress`.
        device: device to use to perform the computations.

    Raises:
        ValueError: if the stream names a model that is not in `MODELS`.
        EOFError: if the stream ends before all codes have been read.
    """
    metadata = binary.read_ecdc_header(fo)
    model_name = metadata['m']
    audio_length = metadata['al']
    num_codebooks = metadata['nc']
    use_lm = metadata['lm']
    assert isinstance(audio_length, int)
    assert isinstance(num_codebooks, int)
    if model_name not in MODELS:
        raise ValueError(f"The audio was compressed with an unsupported model {model_name}.")
    model = MODELS[model_name]().to(device)
    if use_lm:
        lm = model.get_lm_model()

    frames: tp.List[EncodedFrame] = []
    # A falsy segment length/stride means the whole audio is one segment.
    segment_length = model.segment_length or audio_length
    segment_stride = model.segment_stride or audio_length
    for offset in range(0, audio_length, segment_stride):
        this_segment_length = min(audio_length - offset, segment_length)
        # Number of code time steps for this segment, rounded up.
        frame_length = int(math.ceil(this_segment_length / model.sample_rate * model.frame_rate))
        if model.normalize:
            scale_f, = struct.unpack('!f', binary._read_exactly(fo, struct.calcsize('!f')))
            scale = torch.tensor(scale_f, device=device).view(1)
        else:
            scale = None
        if use_lm:
            decoder = ArithmeticDecoder(fo)
            states: tp.Any = None
            # LM streaming position; named separately so it does not shadow
            # the segment `offset` loop variable above.
            lm_offset = 0
            input_ = torch.zeros(1, num_codebooks, 1, dtype=torch.long, device=device)
        else:
            unpacker = binary.BitUnpacker(model.bits_per_codebook, fo)
        frame = torch.zeros(1, num_codebooks, frame_length, dtype=torch.long, device=device)
        for t in range(frame_length):
            if use_lm:
                with torch.no_grad():
                    probas, states, lm_offset = lm(input_, states, lm_offset)
            code_list: tp.List[int] = []
            for k in range(num_codebooks):
                if use_lm:
                    q_cdf = build_stable_quantized_cdf(
                        probas[0, :, k, 0], decoder.total_range_bits, check=False)
                    code = decoder.pull(q_cdf)
                else:
                    code = unpacker.pull()
                if code is None:
                    raise EOFError("The stream ended sooner than expected.")
                code_list.append(code)
            codes = torch.tensor(code_list, dtype=torch.long, device=device)
            frame[0, :, t] = codes
            if use_lm:
                # Mirror the encoder: feed the just-decoded codes (+1) as the next LM input.
                input_ = 1 + frame[:, :, t: t + 1]
        frames.append((frame, scale))
    with torch.no_grad():
        wav = model.decode(frames)
    return wav[0, :, :audio_length], model.sample_rate
def compress(model: EncodecModel, wav: torch.Tensor, use_lm: bool = False) -> bytes:
    """Encode `wav` with `model` and return the compressed stream as bytes.

    This is the in-memory counterpart of `compress_to_file`: the stream is
    written into an `io.BytesIO` buffer whose contents are returned.

    Args:
        model (EncodecModel): a pre-trained EncodecModel to use to compress the audio.
        wav (torch.Tensor): waveform to compress, should have a shape `[C, T]`, with `C`
            matching `model.channels`, and the proper sample rate (e.g. `model.sample_rate`).
            Use `utils.convert_audio` if this is not the case.
        use_lm (bool): if True, use a pre-trained language model to further
            compress the stream using entropy coding. This slows down compression
            quite a bit; expect a 20 to 30% size reduction. Note that this
            defaults to False here, whereas `compress_to_file` defaults to True.
    """
    buffer = io.BytesIO()
    compress_to_file(model, wav, buffer, use_lm=use_lm)
    payload = buffer.getvalue()
    return payload
def decompress(compressed: bytes, device='cpu') -> tp.Tuple[torch.Tensor, int]:
    """Decompress from a `bytes` object produced by `compress`.

    This wraps `compressed` in an in-memory buffer and delegates to
    `decompress_from_file`. Use that function directly to read from a
    file-object.

    Returns a tuple `(wav, sample_rate)`.

    Args:
        compressed (bytes): compressed bytes.
        device: device to use to perform the computations.
    """
    return decompress_from_file(io.BytesIO(compressed), device=device)
def test():
    """Round-trip benchmark over every model in `MODELS`.

    For each model, this loads `test_{sr}k.wav` (e.g. `test_24k.wav`) from the
    current directory and keeps its first `5 * model.sample_rate` samples. It
    sets the target bandwidth to 12, then compresses and decompresses with
    and without the LM. It prints the bitrate and timings and checks that the
    decoded shape matches the input.

    NOTE(review): the sample rate returned by `torchaudio.load` is discarded,
    so the test files are assumed to already be at `model.sample_rate` with
    `model.channels` channels.
    """
    import torchaudio
    torch.set_num_threads(1)
    for name, model_factory in MODELS.items():
        model = model_factory()
        sr = model.sample_rate // 1000
        x, _ = torchaudio.load(f'test_{sr}k.wav')
        x = x[:, :model.sample_rate * 5]
        model.set_target_bandwidth(12)
        for use_lm in [False, True]:
            print(f"Doing {name}, use_lm={use_lm}")
            # perf_counter is monotonic and high-resolution; time.time() is
            # wall-clock and can jump, which makes it unsuitable for intervals.
            start = time.perf_counter()
            res = compress(model, x, use_lm=use_lm)
            mid = time.perf_counter()
            x_dec, _ = decompress(res)
            end = time.perf_counter()
            t_comp = mid - start
            t_decomp = end - mid
            kbps = 8 * len(res) / 1000 / (x.shape[-1] / model.sample_rate)
            print(f"kbps: {kbps:.1f}, time comp: {t_comp:.1f} sec. "
                  f"time decomp:{t_decomp:.1f}.")
            assert x_dec.shape == x.shape
# Run the round-trip benchmark when executed directly. Because this module uses
# relative imports (`from . import binary`), it presumably has to be run as a
# module (e.g. `python -m encodec.compress`) rather than as a plain script.
if __name__ == '__main__':
    test()