ai-content-maker/.venv/Lib/site-packages/srsly/msgpack/_msgpack_numpy.py

95 lines
2.7 KiB
Python

#!/usr/bin/env python
"""
Support for serialization of numpy data types with msgpack.
"""
# Copyright (c) 2013-2018, Lev E. Givon
# All rights reserved.
# Distributed under the terms of the BSD license:
# http://www.opensource.org/licenses/bsd-license
# numpy is an optional dependency: remember whether the import succeeded so
# the encoder can skip numpy-specific handling when it is unavailable.
try:
    import numpy as np
except ImportError:
    has_numpy = False
else:
    has_numpy = True

# cupy is optional as well; the flag gates the cupy.ndarray check in the
# encoder.
try:
    import cupy
except ImportError:
    has_cupy = False
else:
    has_cupy = True
def encode_numpy(obj, chain=None):
    """
    Encode numpy arrays, numpy scalars and complex numbers as plain dicts.

    Parameters
    ----------
    obj : object
        The value being serialized.
    chain : callable, optional
        Fallback encoder applied to ``obj`` when this function does not
        handle it.

    Returns
    -------
    object
        * ``numpy.ndarray`` (or a ``cupy.ndarray``, first converted with
          ``.get()``): a dict with the byte-string keys ``b"nd"`` (True),
          ``b"type"``, ``b"kind"``, ``b"shape"`` and ``b"data"``.
        * ``numpy.bool_`` / ``numpy.number`` scalar: a dict with ``b"nd"``
          (False), ``b"type"`` and ``b"data"``.
        * ``complex``: a dict with ``b"complex"`` (True) and ``b"data"``
          holding the number's repr.
        * anything else, or any input when numpy is not installed:
          ``chain(obj)`` if ``chain`` was given, otherwise ``obj`` unchanged.
    """
    if has_numpy:
        if has_cupy and isinstance(obj, cupy.ndarray):
            # NOTE(review): .get() presumably copies the array to host memory
            # as a numpy array -- confirm against the cupy documentation.
            obj = obj.get()

        if isinstance(obj, np.ndarray):
            # Structured dtypes (kind "V") are described by their interface
            # description; every other dtype by its array-protocol type string.
            structured = obj.dtype.kind == "V"
            # Contiguous arrays expose their buffer directly; anything else is
            # first flattened into a bytes object.
            payload = obj.data if obj.flags["C_CONTIGUOUS"] else obj.tobytes()
            return {
                b"nd": True,
                b"type": obj.dtype.descr if structured else obj.dtype.str,
                b"kind": b"V" if structured else b"",
                b"shape": obj.shape,
                b"data": payload,
            }

        if isinstance(obj, (np.bool_, np.number)):
            return {b"nd": False, b"type": obj.dtype.str, b"data": obj.data}

        if isinstance(obj, complex):
            return {b"complex": True, b"data": obj.__repr__()}

    # Not handled here: defer to the chained encoder, if there is one.
    return obj if chain is None else chain(obj)
def tostr(x):
    """
    Return ``x`` as text.

    ``bytes`` instances are decoded with the default codec of
    ``bytes.decode()``; every other value is passed through ``str()``.
    """
    return x.decode() if isinstance(x, bytes) else str(x)
def decode_numpy(obj, chain=None):
    """
    Decoder for deserializing numpy data types.

    Recognises dicts with byte-string keys in one of three layouts:

    * ``b"nd"`` is True: an array, rebuilt from ``b"data"`` using the dtype
      in ``b"type"`` and reshaped to ``b"shape"``. When ``b"kind"`` is
      ``b"V"`` the dtype is a structured description whose byte-string
      entries are converted to ``str`` first.
    * ``b"nd"`` is present but not True: a scalar, taken as element 0 of the
      buffer in ``b"data"``.
    * ``b"complex"`` is present: a complex number parsed from the text in
      ``b"data"``.

    Parameters
    ----------
    obj : mapping
        The decoded msgpack map to inspect.
    chain : callable, optional
        Fallback decoder applied to ``obj`` when it is not one of the
        layouts above, or when an expected key is missing.

    Returns
    -------
    object
        The reconstructed numpy array, numpy scalar or ``complex``; otherwise
        ``chain(obj)`` if ``chain`` was given, else ``obj`` unchanged.

    Notes
    -----
    ``chain`` is called outside the ``try`` block, exactly once, so a
    ``KeyError`` raised by ``chain`` itself propagates to the caller instead
    of being mistaken for a malformed payload and triggering a second call.

    The ``b"nd"`` branches use ``np`` without consulting ``has_numpy``, so
    numpy must be importable when such a payload is decoded.
    """
    try:
        if b"nd" in obj:
            if obj[b"nd"] is True:
                # Check if b'kind' is in obj to enable decoding of data
                # serialized with older versions (#20):
                if b"kind" in obj and obj[b"kind"] == b"V":
                    descr = [
                        tuple(tostr(t) if isinstance(t, bytes) else t for t in d)
                        for d in obj[b"type"]
                    ]
                else:
                    descr = obj[b"type"]
                flat = np.frombuffer(obj[b"data"], dtype=np.dtype(descr))
                return flat.reshape(obj[b"shape"])
            # Scalars are stored as a one-element buffer.
            return np.frombuffer(obj[b"data"], dtype=np.dtype(obj[b"type"]))[0]
        if b"complex" in obj:
            return complex(tostr(obj[b"data"]))
    except KeyError:
        # A marker key was present but the payload is incomplete: treat the
        # object as not ours and fall through to the generic path.
        pass
    return obj if chain is None else chain(obj)