154 lines
5.1 KiB
Python
154 lines
5.1 KiB
Python
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||
|
# All rights reserved.
|
||
|
#
|
||
|
# This source code is licensed under the license found in the
|
||
|
# LICENSE file in the root directory of this source tree.
|
||
|
|
||
|
"""Raw binary format for Encodec compressed audio. Actual compression API is in `encodec.compress`."""
|
||
|
|
||
|
import io
|
||
|
import json
|
||
|
import struct
|
||
|
import typing as tp
|
||
|
|
||
|
# The format is the `ECDC` magic code, followed by a uint8 indicating the
# protocol version (currently 0), then the size of the header as a uint32,
# all in network byte order.
# The header is then provided as JSON and should contain all the information
# required for decoding. A raw stream of bytes follows, which should be
# interpretable using the JSON header.
|
||
|
_encodec_header_struct = struct.Struct('!4sBI')
|
||
|
_ENCODEC_MAGIC = b'ECDC'
|
||
|
|
||
|
|
||
|
def write_ecdc_header(fo: tp.IO[bytes], metadata: tp.Any):
    """Write an ECDC header to `fo` and flush it.

    The fixed-size part (magic code, version 0, metadata length) is written
    first, followed by `metadata` serialized as UTF-8 encoded JSON.

    Args:
        fo (IO[bytes]): binary file-object to write to.
        metadata: any JSON-serializable object describing the stream.
    """
    payload = json.dumps(metadata).encode('utf-8')
    fo.write(_encodec_header_struct.pack(_ENCODEC_MAGIC, 0, len(payload)))
    fo.write(payload)
    fo.flush()
|
||
|
|
||
|
|
||
|
def _read_exactly(fo: tp.IO[bytes], size: int) -> bytes:
|
||
|
buf = b""
|
||
|
while len(buf) < size:
|
||
|
new_buf = fo.read(size)
|
||
|
if not new_buf:
|
||
|
raise EOFError("Impossible to read enough data from the stream, "
|
||
|
f"{size} bytes remaining.")
|
||
|
buf += new_buf
|
||
|
size -= len(new_buf)
|
||
|
return buf
|
||
|
|
||
|
|
||
|
def read_ecdc_header(fo: tp.IO[bytes]):
    """Read an ECDC header from `fo` and return the decoded JSON metadata.

    The fixed-size part (magic code, version, metadata length) is read and
    validated first; only then is the JSON metadata read and decoded.

    Raises:
        ValueError: if the magic code is not `ECDC` or the version is not 0.
        EOFError: if the stream ends before the header is complete.
    """
    raw_header = _read_exactly(fo, _encodec_header_struct.size)
    magic_code, protocol_version, json_length = _encodec_header_struct.unpack(raw_header)
    if magic_code != _ENCODEC_MAGIC:
        raise ValueError("File is not in ECDC format.")
    if protocol_version != 0:
        raise ValueError("Version not supported.")
    return json.loads(_read_exactly(fo, json_length).decode('utf-8'))
|
||
|
|
||
|
|
||
|
class BitPacker:
    """Simple bit packer to handle ints with a non standard width, e.g. 10 bits.

    Note that for some bandwidth (1.5, 3), the codebook representation
    will not cover an integer number of bytes.

    Values are packed least-significant-bit first: the first value pushed
    occupies the lowest bits of the first byte written.

    Args:
        bits (int): number of bits per value that will be pushed.
        fo (IO[bytes]): file-object to push the bytes to.
    """
    def __init__(self, bits: int, fo: tp.IO[bytes]):
        # Pending bits that do not yet form a complete byte, LSB first.
        self._current_value = 0
        # Number of meaningful bits currently held in `_current_value`.
        self._current_bits = 0
        self.bits = bits
        self.fo = fo

    def push(self, value: int):
        """Push a new value to the stream. This will immediately
        write as many uint8 as possible to the underlying file-object.

        Raises:
            ValueError: if `value` is negative or does not fit in `bits` bits.
                Such a value would otherwise silently corrupt the following
                values in the stream.
        """
        if not 0 <= value < (1 << self.bits):
            raise ValueError(f"Value {value} does not fit in {self.bits} bits.")
        self._current_value += (value << self._current_bits)
        self._current_bits += self.bits
        # Emit every complete byte, lowest bits first, in a single write.
        complete = bytearray()
        while self._current_bits >= 8:
            complete.append(self._current_value & 0xff)
            self._current_bits -= 8
            self._current_value >>= 8
        if complete:
            self.fo.write(bytes(complete))

    def flush(self):
        """Flushes the remaining partial uint8, call this at the end
        of the stream to encode.

        The unused high bits of the last byte are zero.
        """
        if self._current_bits:
            self.fo.write(bytes([self._current_value]))
            self._current_value = 0
            self._current_bits = 0
        self.fo.flush()
|
||
|
|
||
|
|
||
|
class BitUnpacker:
    """BitUnpacker does the opposite of `BitPacker`.

    Values are decoded in the order they were pushed, the first value being
    taken from the lowest bits of the first byte read.

    Note that the padding bits added when the packer flushes a partial last
    byte may decode as extra "ghost" values at the end of the stream. Callers
    should therefore know how many values they expect.

    Args:
        bits (int): number of bits of the values to decode.
        fo (IO[bytes]): file-object to read the bytes from.
    """
    def __init__(self, bits: int, fo: tp.IO[bytes]):
        self.bits = bits
        self.fo = fo
        # Mask selecting the lowest `bits` bits of the buffered value.
        self._mask = (1 << bits) - 1
        # Bits read from `fo` but not yet returned, LSB first.
        self._current_value = 0
        # Number of meaningful bits currently held in `_current_value`.
        self._current_bits = 0

    def pull(self) -> tp.Optional[int]:
        """
        Pull a single value from the stream, potentially reading some
        extra bytes from the underlying file-object.

        Returns `None` when reaching the end of the stream, i.e. when the
        file-object is exhausted before a full value could be assembled.
        """
        # Read one byte at a time until enough bits are buffered for a value.
        while self._current_bits < self.bits:
            buf = self.fo.read(1)
            if not buf:
                return None
            character = buf[0]
            self._current_value += character << self._current_bits
            self._current_bits += 8

        out = self._current_value & self._mask
        self._current_value >>= self.bits
        self._current_bits -= self.bits
        return out
|
||
|
|
||
|
|
||
|
def test():
    """Round-trip random token sequences through `BitPacker` / `BitUnpacker`.

    For 4 random (length, bit width) pairs, the tokens are packed into an
    in-memory buffer, unpacked again and compared with the originals.
    """
    import torch
    torch.manual_seed(1234)
    for _ in range(4):
        length: int = torch.randint(10, 2_000, (1,)).item()
        bits: int = torch.randint(1, 16, (1,)).item()
        tokens: tp.List[int] = torch.randint(2 ** bits, (length,)).tolist()

        stream = io.BytesIO()
        packer = BitPacker(bits, stream)
        for token in tokens:
            packer.push(token)
        packer.flush()

        stream.seek(0)
        unpacker = BitUnpacker(bits, stream)
        rebuilt: tp.List[int] = []
        value = unpacker.pull()
        while value is not None:
            rebuilt.append(value)
            value = unpacker.pull()

        assert len(rebuilt) >= len(tokens), (len(rebuilt), len(tokens))
        # Padding written by the final flush can decode as trailing "ghost" values.
        assert len(rebuilt) <= len(tokens) + 8 // bits, (len(rebuilt), len(tokens), bits)
        for position, (expected, got) in enumerate(zip(tokens, rebuilt)):
            assert expected == got, (position, expected, got)
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    # Running this module directly executes the bit packing round-trip self-test.
    test()
|