916 lines
32 KiB
Python
916 lines
32 KiB
Python
#!/usr/bin/env python
|
|
# This file is distributed under the terms of the 2-clause BSD License.
|
|
# Copyright (c) 2017-2018, Almar Klein
|
|
|
|
"""
|
|
Python implementation of the Binary Structured Data Format (BSDF).
|
|
|
|
BSDF is a binary format for serializing structured (scientific) data.
|
|
See http://bsdf.io for more information.
|
|
|
|
This is the reference implementation, which is relatively
|
|
sophisticated, providing e.g. lazy loading of blobs and streamed
|
|
reading/writing. A simpler Python implementation is available as
|
|
``bsdf_lite.py``.
|
|
|
|
This module has no dependencies and works on Python 2.7 and 3.4+.
|
|
|
|
Note: on Legacy Python (Python 2.7), non-Unicode strings are encoded as bytes.
|
|
"""
|
|
|
|
# todo: in 2020, remove six stuff, __future__ and _isidentifier
|
|
# todo: in 2020, remove 'utf-8' args to encode/decode; it's faster
|
|
|
|
from __future__ import absolute_import, division, print_function
|
|
|
|
import bz2
|
|
import hashlib
|
|
import logging
|
|
import os
|
|
import re
|
|
import struct
|
|
import sys
|
|
import types
|
|
import zlib
|
|
from io import BytesIO
|
|
|
|
# Module-level logger; the serializer reports its "BSDF warning: ..." messages here.
logger = logging.getLogger(__name__)
|
|
|
|
# Notes on versioning: the major and minor numbers correspond to the
|
|
# BSDF format version. The major number is increased when backward
|
|
# incompatible changes are introduced. An implementation must raise an
|
|
# exception when the file being read has a higher major version. The
|
|
# minor number is increased when new backward compatible features are
|
|
# introduced. An implementation must display a warning when the file
|
|
# being read has a higher minor version. The patch version is increased
|
|
# for subsequent releases of the implementation.
|
|
# (major, minor, patch); major/minor are written into every encoded file.
VERSION = (2, 1, 2)
__version__ = ".".join(map(str, VERSION))
|
|
|
|
|
|
# %% The encoder and decoder implementation
|
|
|
|
# From six.py
|
|
# Minimal six-style compatibility names, chosen once at import time.
PY3 = sys.version_info >= (3,)

if PY3:
    text_type = string_types = unicode_types = str
    integer_types = int
    classtypes = type
else:  # pragma: no cover
    logging.basicConfig()  # avoid "no handlers found" error
    text_type = unicode  # noqa
    string_types = basestring  # noqa
    unicode_types = unicode  # noqa
    integer_types = (int, long)  # noqa
    classtypes = type, types.ClassType
|
|
|
|
# Shorthands
|
|
# Short local aliases for the struct packing helpers used throughout.
spack, strunpack = struct.pack, struct.unpack
|
|
|
|
|
|
def lencode(x):
    """Pack the unsigned integer ``x`` as a BSDF size field.

    Values up to 250 occupy a single byte. Anything larger is written as the
    marker byte 253 followed by an unsigned 64-bit little-endian integer.
    (16/32-bit variants would be possible, but saving a few bytes on
    collections of more than 250 elements is marginal.)
    """
    if x > 250:
        return spack("<BQ", 253, x)
    return spack("<B", x)
|
|
|
|
|
|
# Include len decoder for completeness; we've inlined it for performance.
|
|
def lendecode(f):
    """Read a BSDF size field (as produced by ``lencode``) from file ``f``."""
    (size,) = strunpack("<B", f.read(1))
    if size != 253:
        return size
    # 253 marks that the real size follows as a 64-bit little-endian value.
    (size,) = strunpack("<Q", f.read(8))
    return size
|
|
|
|
|
|
def encode_type_id(b, ext_id):
    """Build the type-identifier bytes for a value.

    Without an extension the single lowercase type byte ``b`` is returned
    as-is. With an extension, the byte is upper-cased and followed by the
    length-prefixed UTF-8 extension name.
    """
    if ext_id is None:
        return b
    name_bytes = ext_id.encode("UTF-8")
    return b"".join((b.upper(), lencode(len(name_bytes)), name_bytes))
|
|
|
|
|
|
def _isidentifier(s): # pragma: no cover
|
|
"""Use of str.isidentifier() for Legacy Python, but slower."""
|
|
# http://stackoverflow.com/questions/2544972/
|
|
return (
|
|
isinstance(s, string_types)
|
|
and re.match(r"^\w+$", s, re.UNICODE)
|
|
and re.match(r"^[0-9]", s) is None
|
|
)
|
|
|
|
|
|
class BsdfSerializer(object):
    """Instances of this class represent a BSDF encoder/decoder.

    It acts as a placeholder for a set of extensions and encoding/decoding
    options. Use this to predefine extensions and options for high
    performance encoding/decoding. For general use, see the functions
    `save()`, `encode()`, `load()`, and `decode()`.

    This implementation of BSDF supports streaming lists (keep adding
    to a list after writing the main file), lazy loading of blobs, and
    in-place editing of blobs (for streams opened with a+).

    Options for encoding:

    * compression (int or str): ``0`` or "no" for no compression (default),
      ``1`` or "zlib" for Zlib compression (same as zip files and PNG), and
      ``2`` or "bz2" for Bz2 compression (more compact but slower writing).
      Note that some BSDF implementations (e.g. JavaScript) may not support
      compression.
    * use_checksum (bool): whether to include a checksum with binary blobs.
    * float64 (bool): Whether to write floats as 64 bit (default) or 32 bit.

    Options for decoding:

    * load_streaming (bool): if True, and the final object in the structure was
      a stream, will make it available as a stream in the decoded object.
    * lazy_blob (bool): if True, bytes are represented as Blob objects that can
      be used to lazily access the data, and also overwrite the data if the
      file is open in a+ mode.
    """

    def __init__(self, extensions=None, **options):
        self._extensions = {}  # name -> extension
        self._extensions_by_cls = {}  # cls -> (name, extension.encode)
        if extensions is None:
            extensions = standard_extensions
        for extension in extensions:
            self.add_extension(extension)
        self._parse_options(**options)

    def _parse_options(
        self,
        compression=0,
        use_checksum=False,
        float64=True,
        load_streaming=False,
        lazy_blob=False,
    ):
        """Validate and store the encoding/decoding options.

        Raises TypeError if ``compression`` is not 0, 1, 2, "no", "zlib"
        or "bz2".
        """
        # Validate compression
        if isinstance(compression, string_types):
            m = {"no": 0, "zlib": 1, "bz2": 2}
            compression = m.get(compression.lower(), compression)
        if compression not in (0, 1, 2):
            raise TypeError("Compression must be 0, 1, 2, " '"no", "zlib", or "bz2"')
        self._compression = compression

        # Other encoding args
        self._use_checksum = bool(use_checksum)
        self._float64 = bool(float64)

        # Decoding args
        self._load_streaming = bool(load_streaming)
        self._lazy_blob = bool(lazy_blob)

    def add_extension(self, extension_class):
        """Add an extension to this serializer instance, which must be
        a subclass of Extension. Can be used as a decorator.

        Returns ``extension_class`` unchanged. Raises TypeError for a
        non-Extension class, a non-str name or non-type ``cls`` entries, and
        NameError for an empty name or one longer than 250 chars.
        """
        # Check class
        if not (
            isinstance(extension_class, type) and issubclass(extension_class, Extension)
        ):
            raise TypeError("add_extension() expects a Extension class.")
        extension = extension_class()

        # Get name
        name = extension.name
        if not isinstance(name, str):
            raise TypeError("Extension name must be str.")
        if len(name) == 0 or len(name) > 250:
            raise NameError(
                "Extension names must be nonempty and shorter " "than 251 chars."
            )
        if name in self._extensions:
            logger.warning(
                'BSDF warning: overwriting extension "%s", '
                "consider removing first" % name
            )

        # Get classes
        cls = extension.cls
        if not cls:
            clss = []
        elif isinstance(cls, (tuple, list)):
            clss = cls
        else:
            clss = [cls]
        for cls in clss:
            if not isinstance(cls, classtypes):
                raise TypeError("Extension classes must be types.")

        # Store
        for cls in clss:
            self._extensions_by_cls[cls] = name, extension.encode
        self._extensions[name] = extension
        return extension_class

    def remove_extension(self, name):
        """Remove an extension by its unique name (no-op if it is not present).

        Raises TypeError if ``name`` is not a str.
        """
        if not isinstance(name, str):
            raise TypeError("Extension name must be str.")
        if name in self._extensions:
            self._extensions.pop(name)
        for cls in list(self._extensions_by_cls.keys()):
            if self._extensions_by_cls[cls][0] == name:
                self._extensions_by_cls.pop(cls)

    def _encode(self, f, value, streams, ext_id):
        """Main encoder function.

        Writes ``value`` to file object ``f``. ``streams`` collects the (at
        most one) ListStream encountered. ``ext_id`` is the name of the
        extension that produced ``value``, or None.
        """
        x = encode_type_id

        if value is None:
            f.write(x(b"v", ext_id))  # V for void
        elif value is True:
            f.write(x(b"y", ext_id))  # Y for yes
        elif value is False:
            f.write(x(b"n", ext_id))  # N for no
        elif isinstance(value, integer_types):
            if -32768 <= value <= 32767:
                # Explicit little-endian to match the "<h" used by _decode; a
                # bare "h" would use native byte order (wrong on big-endian).
                f.write(x(b"h", ext_id) + spack("<h", value))  # H for ...
            else:
                f.write(x(b"i", ext_id) + spack("<q", value))  # I for int
        elif isinstance(value, float):
            if self._float64:
                f.write(x(b"d", ext_id) + spack("<d", value))  # D for double
            else:
                f.write(x(b"f", ext_id) + spack("<f", value))  # f for float
        elif isinstance(value, unicode_types):
            bb = value.encode("UTF-8")
            f.write(x(b"s", ext_id) + lencode(len(bb)))  # S for str
            f.write(bb)
        elif isinstance(value, (list, tuple)):
            f.write(x(b"l", ext_id) + lencode(len(value)))  # L for list
            for v in value:
                self._encode(f, v, streams, None)
        elif isinstance(value, dict):
            f.write(x(b"m", ext_id) + lencode(len(value)))  # M for mapping
            for key, v in value.items():
                if PY3:
                    assert key.isidentifier()  # faster
                else:  # pragma: no cover
                    assert _isidentifier(key)
                # yield ' ' * indent + key
                name_b = key.encode("UTF-8")
                f.write(lencode(len(name_b)))
                f.write(name_b)
                self._encode(f, v, streams, None)
        elif isinstance(value, bytes):
            f.write(x(b"b", ext_id))  # B for blob
            blob = Blob(
                value, compression=self._compression, use_checksum=self._use_checksum
            )
            blob._to_file(f)  # noqa
        elif isinstance(value, Blob):
            f.write(x(b"b", ext_id))  # B for blob
            value._to_file(f)  # noqa
        elif isinstance(value, BaseStream):
            # Initialize the stream
            if value.mode != "w":
                raise ValueError("Cannot serialize a read-mode stream.")
            elif isinstance(value, ListStream):
                f.write(x(b"l", ext_id) + spack("<BQ", 255, 0))  # L for list
            else:
                raise TypeError("Only ListStream is supported")
            # Mark this as *the* stream, and activate the stream.
            # The save() function verifies this is the last written object.
            if len(streams) > 0:
                raise ValueError("Can only have one stream per file.")
            streams.append(value)
            value._activate(f, self._encode, self._decode)  # noqa
        else:
            if ext_id is not None:
                raise ValueError(
                    "Extension %s wrongfully encodes object to another "
                    "extension object (though it may encode to a list/dict "
                    "that contains other extension objects)." % ext_id
                )
            # Try if the value is of a type we know
            ex = self._extensions_by_cls.get(value.__class__, None)
            # Maybe its a subclass of a type we know
            if ex is None:
                for name, c in self._extensions.items():
                    if c.match(self, value):
                        ex = name, c.encode
                        break
                else:
                    ex = None
            # Success or fail
            if ex is not None:
                ext_id2, extension_encode = ex
                self._encode(f, extension_encode(self, value), streams, ext_id2)
            else:
                t = (
                    "Class %r is not a valid base BSDF type, nor is it "
                    "handled by an extension."
                )
                raise TypeError(t % value.__class__.__name__)

    def _decode(self, f):
        """Main decoder function.

        Reads one value from ``f``. Raises EOFError at end of file and
        RuntimeError for an unknown type identifier.
        """

        # Get value
        char = f.read(1)
        c = char.lower()

        # Conversion (uppercase value identifiers signify converted values)
        if not char:
            raise EOFError()
        elif char != c:
            n = strunpack("<B", f.read(1))[0]
            # if n == 253: n = strunpack('<Q', f.read(8))[0]  # noqa - noneed
            ext_id = f.read(n).decode("UTF-8")
        else:
            ext_id = None

        if c == b"v":
            value = None
        elif c == b"y":
            value = True
        elif c == b"n":
            value = False
        elif c == b"h":
            value = strunpack("<h", f.read(2))[0]
        elif c == b"i":
            value = strunpack("<q", f.read(8))[0]
        elif c == b"f":
            value = strunpack("<f", f.read(4))[0]
        elif c == b"d":
            value = strunpack("<d", f.read(8))[0]
        elif c == b"s":
            n_s = strunpack("<B", f.read(1))[0]
            if n_s == 253:
                n_s = strunpack("<Q", f.read(8))[0]  # noqa
            value = f.read(n_s).decode("UTF-8")
        elif c == b"l":
            n = strunpack("<B", f.read(1))[0]
            if n >= 254:
                # Streaming
                closed = n == 254
                n = strunpack("<Q", f.read(8))[0]
                if self._load_streaming:
                    value = ListStream(n if closed else "r")
                    value._activate(f, self._encode, self._decode)  # noqa
                elif closed:
                    value = [self._decode(f) for i in range(n)]
                else:
                    value = []
                    try:
                        while True:
                            value.append(self._decode(f))
                    except EOFError:
                        pass
            else:
                # Normal
                if n == 253:
                    n = strunpack("<Q", f.read(8))[0]  # noqa
                value = [self._decode(f) for i in range(n)]
        elif c == b"m":
            value = dict()
            n = strunpack("<B", f.read(1))[0]
            if n == 253:
                n = strunpack("<Q", f.read(8))[0]  # noqa
            for i in range(n):
                n_name = strunpack("<B", f.read(1))[0]
                if n_name == 253:
                    n_name = strunpack("<Q", f.read(8))[0]  # noqa
                assert n_name > 0
                name = f.read(n_name).decode("UTF-8")
                value[name] = self._decode(f)
        elif c == b"b":
            if self._lazy_blob:
                value = Blob((f, True))
            else:
                blob = Blob((f, False))
                value = blob.get_bytes()
        else:
            raise RuntimeError("Parse error %r" % char)

        # Convert value if we have an extension for it
        if ext_id is not None:
            extension = self._extensions.get(ext_id, None)
            if extension is not None:
                value = extension.decode(self, value)
            else:
                logger.warning("BSDF warning: no extension found for %r" % ext_id)

        return value

    def encode(self, ob):
        """Save the given object to bytes."""
        f = BytesIO()
        self.save(f, ob)
        return f.getvalue()

    def save(self, f, ob):
        """Write the given object to the given file object.

        Raises ValueError if a ListStream in ``ob`` is not the last object
        to be encoded.
        """
        f.write(b"BSDF")
        f.write(struct.pack("<B", VERSION[0]))
        f.write(struct.pack("<B", VERSION[1]))

        # Prepare streaming, this list will have 0 or 1 item at the end
        streams = []

        self._encode(f, ob, streams, None)

        # Verify that stream object was at the end, and add initial elements
        if len(streams) > 0:
            stream = streams[0]
            if stream._start_pos != f.tell():
                raise ValueError(
                    "The stream object must be " "the last object to be encoded."
                )

    def decode(self, bb):
        """Load the data structure that is BSDF-encoded in the given bytes."""
        f = BytesIO(bb)
        return self.load(f)

    def load(self, f):
        """Load a BSDF-encoded object from the given file object.

        Raises RuntimeError if the magic string is missing or the file's
        major version differs from ours. Logs a warning for a higher minor
        version.
        """
        # Check magic string
        f4 = f.read(4)
        if f4 != b"BSDF":
            raise RuntimeError("This does not look like a BSDF file: %r" % f4)
        # Check version
        major_version = strunpack("<B", f.read(1))[0]
        minor_version = strunpack("<B", f.read(1))[0]
        file_version = "%i.%i" % (major_version, minor_version)
        if major_version != VERSION[0]:  # major version should be 2
            t = (
                "Reading file with different major version (%s) "
                "from the implementation (%s)."
            )
            # The file's version goes first, matching the message wording.
            raise RuntimeError(t % (file_version, __version__))
        if minor_version > VERSION[1]:  # minor should be <= ours
            t = (
                "BSDF warning: reading file with higher minor version (%s) "
                "than the implementation (%s)."
            )
            logger.warning(t % (file_version, __version__))

        return self._decode(f)
|
|
|
|
|
|
# %% Streaming and blob-files
|
|
|
|
|
|
class BaseStream(object):
    """Base class for streams.

    ``mode`` is "w" (write, count starts at 0), "r" (read, count unknown:
    -1), or an int, which means read mode with that known element count.
    """

    def __init__(self, mode="w"):
        self._i = 0
        self._count = -1
        if isinstance(mode, int):
            self._count = mode
            mode = "r"
        elif mode == "w":
            self._count = 0
        # Explicit check instead of ``assert``, which is stripped under -O.
        if mode not in ("r", "w"):
            raise ValueError(
                "Stream mode must be 'r', 'w', or an int count, not %r." % (mode,)
            )
        self._mode = mode
        self._f = None
        self._start_pos = 0

    def _activate(self, file, encode_func, decode_func):
        """Bind this stream to ``file`` at its current position.

        Raises IOError if the stream was already activated.
        """
        if self._f is not None:  # Associated with another write
            raise IOError("Stream object cannot be activated twice?")
        self._f = file
        self._start_pos = self._f.tell()
        self._encode = encode_func
        self._decode = decode_func

    @property
    def mode(self):
        """The mode of this stream: 'r' or 'w'."""
        return self._mode
|
|
|
|
|
|
class ListStream(BaseStream):
    """A list that is written (or read) one element at a time.

    In read mode it also acts as an iterator over its elements.
    """

    @property
    def count(self):
        """How many elements the stream holds. This is -1 for an unclosed
        stream in read mode until it has been read to the end.
        """
        return self._count

    @property
    def index(self):
        """Position of the element that will be read or written next."""
        return self._i

    def append(self, item):
        """Serialize ``item`` and write it directly to the underlying file.
        """
        # NOTE(review): the mode is not checked here, so a read-mode stream
        # that has been consumed to its end can be appended to as well.
        if self._i != self._count:
            raise IOError("Can only append items to the end of the stream.")
        f = self._f
        if f is None:
            raise IOError("List stream is not associated with a file yet.")
        if f.closed:
            raise IOError("Cannot stream to a close file.")
        # Passing [self] as the stream list makes the encoder refuse nested
        # streams inside the item.
        self._encode(f, item, [self], None)
        self._i += 1
        self._count += 1

    def close(self, unstream=False):
        """Record the number of written elements in the stream header.

        Later appends still succeed, but the decoder will not read them.
        If ``unstream`` is True, the header marks the list as a regular
        (non-streaming) list.
        """
        if self._i != self._count:
            raise IOError("Can only close when at the end of the stream.")
        f = self._f
        if f is None:
            raise IOError("ListStream is not associated with a file yet.")
        if f.closed:
            raise IOError("Cannot close a stream on a close file.")
        # The header is one size-marker byte plus an 8-byte count, and sits
        # right before the position where the stream was activated.
        header_pos = self._start_pos - 9
        resume_pos = f.tell()
        f.seek(header_pos)
        f.write(spack("<B", 253 if unstream else 254))
        f.write(spack("<Q", self._count))
        f.seek(resume_pos)

    def next(self):
        """Decode and return the next element of the stream.

        Raises StopIteration once the stream is exhausted.
        """
        if self._mode != "r":
            raise IOError("This ListStream in not in read mode.")
        if self._f is None:
            raise IOError("ListStream is not associated with a file yet.")
        if getattr(self._f, "closed", None):  # not present on 2.7 http req :/
            raise IOError("Cannot read a stream from a close file.")

        if self._count < 0:
            # Unclosed stream: keep decoding until the decoder hits EOF.
            try:
                item = self._decode(self._f)
            except EOFError:
                self._count = self._i
                raise StopIteration()
            self._i += 1
            return item

        if self._i >= self._count:
            raise StopIteration()
        self._i += 1
        return self._decode(self._f)

    def __iter__(self):
        if self._mode != "r":
            raise IOError("Cannot iterate: ListStream in not in read mode.")
        return self

    def __next__(self):
        return self.next()
|
|
|
|
|
|
class Blob(object):
    """Object to represent a blob of bytes. When used to write a BSDF file,
    it's a wrapper for bytes plus properties such as what compression to apply.
    When used to read a BSDF file, it can be used to read the data lazily, and
    also modify the data if reading in 'r+' mode and the blob isn't compressed.
    """

    # For now, this does not allow re-sizing blobs (within the allocated size)
    # but this can be added later.

    def __init__(self, bb, compression=0, extra_size=0, use_checksum=False):
        """Create a blob from ``bytes``, or from a ``(file, allow_seek)`` tuple.

        The tuple form is used by the decoder. With ``allow_seek`` True, the
        data is left in the file (lazy). Otherwise it is read into memory.
        Raises TypeError for any other argument.
        """
        # Initialised for both construction paths; update_checksum reads it.
        self._modified = False
        if isinstance(bb, bytes):
            self._f = None
            self.compressed = self._from_bytes(bb, compression)
            self.compression = compression
            self.allocated_size = self.used_size + extra_size
            self.use_checksum = use_checksum
        elif isinstance(bb, tuple) and len(bb) == 2 and hasattr(bb[0], "read"):
            self._f, allow_seek = bb
            self.compressed = None
            self._from_file(self._f, allow_seek)
        else:
            raise TypeError("Wrong argument to create Blob.")

    def _from_bytes(self, value, compression):
        """When used to wrap bytes in a blob."""
        if compression == 0:
            compressed = value
        elif compression == 1:
            compressed = zlib.compress(value, 9)
        elif compression == 2:
            compressed = bz2.compress(value, 9)
        else:  # pragma: no cover
            assert False, "Unknown compression identifier"

        self.data_size = len(value)
        self.used_size = len(compressed)
        return compressed

    def _get_compressed(self):
        """Return the stored (possibly compressed) payload as bytes.

        For a lazily loaded blob the payload is read from the file. The file
        position is restored afterwards, and the result is not cached, so
        later in-place writes are picked up.
        """
        if self.compressed is not None:
            return self.compressed
        i = self._f.tell()
        self.seek(0)
        compressed = self._f.read(self.used_size)
        self._f.seek(i)
        return compressed

    def _to_file(self, f):
        """Private friend method called by encoder to write a blob to a file."""
        # Lazily loaded blobs have compressed=None, so fetch the payload from
        # the source file; writing/hashing None would otherwise fail.
        compressed = self._get_compressed()
        # Write sizes - write at least in a size that allows resizing
        if self.allocated_size <= 250 and self.compression == 0:
            f.write(spack("<B", self.allocated_size))
            f.write(spack("<B", self.used_size))
            f.write(lencode(self.data_size))
        else:
            f.write(spack("<BQ", 253, self.allocated_size))
            f.write(spack("<BQ", 253, self.used_size))
            f.write(spack("<BQ", 253, self.data_size))
        # Compression and checksum
        f.write(spack("B", self.compression))
        if self.use_checksum:
            f.write(b"\xff" + hashlib.md5(compressed).digest())
        else:
            f.write(b"\x00")
        # Byte alignment (only necessary for uncompressed data)
        if self.compression == 0:
            alignment = 8 - (f.tell() + 1) % 8  # +1 for the byte to write
            f.write(spack("<B", alignment))  # padding for byte alignment
            f.write(b"\x00" * alignment)
        else:
            f.write(spack("<B", 0))
        # The actual data and extra space
        f.write(compressed)
        f.write(b"\x00" * (self.allocated_size - self.used_size))

    def _from_file(self, f, allow_seek):
        """Used when a blob is read by the decoder."""
        # Read blob header data (5 to 42 bytes)
        # Size
        allocated_size = strunpack("<B", f.read(1))[0]
        if allocated_size == 253:
            allocated_size = strunpack("<Q", f.read(8))[0]  # noqa
        used_size = strunpack("<B", f.read(1))[0]
        if used_size == 253:
            used_size = strunpack("<Q", f.read(8))[0]  # noqa
        data_size = strunpack("<B", f.read(1))[0]
        if data_size == 253:
            data_size = strunpack("<Q", f.read(8))[0]  # noqa
        # Compression and checksum
        compression = strunpack("<B", f.read(1))[0]
        has_checksum = strunpack("<B", f.read(1))[0]
        if has_checksum:
            checksum = f.read(16)
        # Skip alignment
        alignment = strunpack("<B", f.read(1))[0]
        f.read(alignment)
        # Get or skip data + extra space
        if allow_seek:
            self.start_pos = f.tell()
            self.end_pos = self.start_pos + used_size
            f.seek(self.start_pos + allocated_size)
        else:
            self.start_pos = None
            self.end_pos = None
            self.compressed = f.read(used_size)
            f.read(allocated_size - used_size)
        # Store info
        self.alignment = alignment
        self.compression = compression
        self.use_checksum = checksum if has_checksum else None
        self.used_size = used_size
        self.allocated_size = allocated_size
        self.data_size = data_size

    def seek(self, p):
        """Seek to the given position (relative to the blob start).

        Negative ``p`` counts back from the end of the allocated space.
        """
        if self._f is None:
            raise RuntimeError(
                "Cannot seek in a blob " "that is not created by the BSDF decoder."
            )
        if p < 0:
            p = self.allocated_size + p
        if p < 0 or p > self.allocated_size:
            raise IOError("Seek beyond blob boundaries.")
        self._f.seek(self.start_pos + p)

    def tell(self):
        """Get the current file pointer position (relative to the blob start)."""
        if self._f is None:
            raise RuntimeError(
                "Cannot tell in a blob " "that is not created by the BSDF decoder."
            )
        return self._f.tell() - self.start_pos

    def write(self, bb):
        """Write bytes to the blob (only for uncompressed, decoded blobs)."""
        if self._f is None:
            raise RuntimeError(
                "Cannot write in a blob " "that is not created by the BSDF decoder."
            )
        if self.compression:
            raise IOError("Cannot arbitrarily write in compressed blob.")
        if self._f.tell() + len(bb) > self.end_pos:
            raise IOError("Write beyond blob boundaries.")
        self._modified = True
        return self._f.write(bb)

    def read(self, n):
        """Read n bytes from the blob (only for uncompressed, decoded blobs)."""
        if self._f is None:
            raise RuntimeError(
                "Cannot read in a blob " "that is not created by the BSDF decoder."
            )
        if self.compression:
            raise IOError("Cannot arbitrarily read in compressed blob.")
        if self._f.tell() + n > self.end_pos:
            raise IOError("Read beyond blob boundaries.")
        return self._f.read(n)

    def get_bytes(self):
        """Get the (decompressed) contents of the blob as bytes."""
        compressed = self._get_compressed()
        if self.compression == 0:
            value = compressed
        elif self.compression == 1:
            value = zlib.decompress(compressed)
        elif self.compression == 2:
            value = bz2.decompress(compressed)
        else:  # pragma: no cover
            raise RuntimeError("Invalid compression %i" % self.compression)
        return value

    def update_checksum(self):
        """Reset the blob's checksum if present. Call this after modifying
        the data.
        """
        # or ... should the presence of a checksum mean that data is proteced?
        if self.use_checksum and self._modified:
            self.seek(0)
            compressed = self._f.read(self.used_size)
            # Checksum sits before the alignment byte and its padding.
            self._f.seek(self.start_pos - self.alignment - 1 - 16)
            self._f.write(hashlib.md5(compressed).digest())
|
|
|
|
|
|
# %% High-level functions
|
|
|
|
|
|
def encode(ob, extensions=None, **options):
    """Save (BSDF-encode) the given object to bytes.

    See `BsdfSerializer` for details on extensions and options.
    """
    s = BsdfSerializer(extensions, **options)
    return s.encode(ob)
|
|
|
|
|
|
def save(f, ob, extensions=None, **options):
    """Save (BSDF-encode) the given object to the given filename or
    file object. See `BsdfSerializer` for details on extensions and options.

    If ``f`` is a string it is treated as a filename, and a leading ``~``
    home-directory prefix is expanded, the same way `load()` does.
    """
    s = BsdfSerializer(extensions, **options)
    if isinstance(f, string_types):
        # Keep filename handling consistent with load().
        if f.startswith(("~/", "~\\")):  # pragma: no cover
            f = os.path.expanduser(f)
        with open(f, "wb") as fp:
            return s.save(fp, ob)
    else:
        return s.save(f, ob)
|
|
|
|
|
|
def decode(bb, extensions=None, **options):
    """Load a (BSDF-encoded) structure from bytes.

    See `BsdfSerializer` for details on extensions and options.
    """
    s = BsdfSerializer(extensions, **options)
    return s.decode(bb)
|
|
|
|
|
|
def load(f, extensions=None, **options):
    """Load a (BSDF-encoded) structure from the given filename or file object.

    See `BsdfSerializer` for details on extensions and options. A filename
    starting with ``~`` followed by a path separator is expanded to the
    user's home directory.
    """
    s = BsdfSerializer(extensions, **options)
    if isinstance(f, string_types):
        if f.startswith(("~/", "~\\")):  # pragma: no cover
            f = os.path.expanduser(f)
        with open(f, "rb") as fp:
            return s.load(fp)
    else:
        return s.load(f)
|
|
|
|
|
|
# Aliases for json compat
|
|
loads, dumps = decode, encode  # names familiar from the json module
|
|
|
|
|
|
# %% Standard extensions
|
|
|
|
# Defining extensions as a dict would be more compact and feel lighter, but
|
|
# that would only allow lambdas, which is too limiting, e.g. for ndarray
|
|
# extension.
|
|
|
|
|
|
class Extension(object):
    """Base class to implement BSDF extensions for special data types.

    Extension classes are provided to the BSDF serializer, which
    instantiates the class. That way, the extension can be somewhat dynamic:
    e.g. the NDArrayExtension exposes the ndarray class only when numpy
    is imported.

    An extension instance must have two attributes. These can be attributes of
    the class, or of the instance set in ``__init__()``:

    * name (str): the name by which encoded values will be identified.
    * cls (type): the type (or list of types) to match values with.
      This is optional, but it makes the encoder select extensions faster.

    Further, it needs 3 methods:

    * `match(serializer, value) -> bool`: return whether the extension can
      convert the given value. The default is ``isinstance(value, self.cls)``.
    * `encode(serializer, value) -> encoded_value`: the function to encode a
      value to more basic data types.
    * `decode(serializer, encoded_value) -> value`: the function to decode an
      encoded value back to its intended representation.

    """

    name = ""
    cls = ()

    def __repr__(self):
        # hex() already includes the "0x" prefix; don't add another one.
        return "<BSDF extension %r at %s>" % (self.name, hex(id(self)))

    def match(self, s, v):
        return isinstance(v, self.cls)

    def encode(self, s, v):
        raise NotImplementedError()

    def decode(self, s, v):
        raise NotImplementedError()
|
|
|
|
|
|
class ComplexExtension(Extension):
    """Store complex numbers as a two-element (real, imag) sequence."""

    name = "c"
    cls = complex

    def encode(self, s, v):
        pair = (v.real, v.imag)
        return pair

    def decode(self, s, v):
        real_part, imag_part = v[0], v[1]
        return complex(real_part, imag_part)
|
|
|
|
|
|
class NDArrayExtension(Extension):
    """Encode array-like objects as a dict with ``shape``, ``dtype`` and raw
    ``data`` bytes, and decode them back to numpy arrays.
    """

    name = "ndarray"

    def __init__(self):
        # Only advertise np.ndarray for the encoder's fast class lookup when
        # numpy is already imported; constructing this never imports numpy.
        if "numpy" in sys.modules:
            import numpy as np

            self.cls = np.ndarray

    def match(self, s, v):  # pragma: no cover - e.g. work for nd arrays in JS
        return hasattr(v, "shape") and hasattr(v, "dtype") and hasattr(v, "tobytes")

    def encode(self, s, v):
        return dict(shape=v.shape, dtype=text_type(v.dtype), data=v.tobytes())

    def decode(self, s, v):
        """Rebuild an array from the encoded dict. If numpy is unavailable,
        the dict is returned unchanged.
        """
        try:
            import numpy as np
        except ImportError:  # pragma: no cover
            return v
        data = v["data"]
        # With lazy_blob=True the blob arrives as a Blob object rather than
        # bytes; np.frombuffer cannot read that directly.
        if hasattr(data, "get_bytes"):
            data = data.get_bytes()
        a = np.frombuffer(data, dtype=v["dtype"])
        # reshape on the 1-D frombuffer result is a view, like setting .shape
        return a.reshape(v["shape"])
|
|
|
|
|
|
# Extension classes that BsdfSerializer registers when no ``extensions``
# argument is given.
standard_extensions = [ComplexExtension, NDArrayExtension]
|
|
|
|
|
|
if __name__ == "__main__":
    # Invoke CLI: when run as a script, delegate to bsdf_cli.main()
    # (the bsdf_cli module must be importable from the current path).
    import bsdf_cli

    bsdf_cli.main()
|