235 lines
8.4 KiB
Python
235 lines
8.4 KiB
Python
|
from unittest import TestCase
|
||
|
from numpy.testing import assert_equal, assert_array_equal
|
||
|
import numpy as np
|
||
|
from srsly import msgpack
|
||
|
|
||
|
|
||
|
class ThirdParty(object):
    """Minimal user-defined type used to test custom msgpack encode/decode hooks.

    Instances carry a single attribute, ``foo``, and compare equal when both
    operands are ``ThirdParty`` objects holding equal ``foo`` values.
    """

    def __init__(self, foo=b"bar"):
        """Store ``foo`` (defaults to ``b"bar"``) on the instance."""
        self.foo = foo

    def __eq__(self, other):
        """Return True when ``other`` is a ``ThirdParty`` with an equal ``foo``.

        For any other type ``NotImplemented`` is returned so Python can try the
        reflected comparison before falling back to "not equal".
        """
        if not isinstance(other, ThirdParty):
            return NotImplemented
        return self.foo == other.foo

    def __ne__(self, other):
        """Return the negation of ``__eq__``.

        Defined explicitly because Python 2 does not derive ``!=`` from ``==``,
        and this file is written to pass on both 2.7 and 3.x.
        """
        outcome = self.__eq__(other)
        if outcome is NotImplemented:
            return outcome
        return not outcome

    # Value-based equality on a mutable attribute: keep instances unhashable.
    # Python 3 already does this implicitly when only __eq__ is defined; being
    # explicit documents the intent.
    __hash__ = None

    def __repr__(self):
        """Return a readable form so assertion failures show the payload."""
        return "{}(foo={!r})".format(type(self).__name__, self.foo)
|
||
|
|
||
|
|
||
|
class test_numpy_msgpack(TestCase):
    """Round-trip tests for ``msgpack.packb`` / ``msgpack.unpackb``.

    Each test packs a value (plain Python scalars and containers, NumPy
    scalars and arrays, or a custom object), unpacks it again and checks that
    both the value and its type (or dtype, for arrays) are unchanged.
    """

    # ------------------------------------------------------------------
    # Pack/unpack helpers
    # ------------------------------------------------------------------

    def encode_decode(self, x, use_bin_type=False, raw=True):
        """Pack ``x`` and immediately unpack it, returning the result."""
        return msgpack.unpackb(msgpack.packb(x, use_bin_type=use_bin_type), raw=raw)

    def encode_thirdparty(self, obj):
        """``default`` hook: describe ``obj`` as a tagged dict carrying ``foo``."""
        return {"__thirdparty__": True, "foo": obj.foo}

    def decode_thirdparty(self, obj):
        """``object_hook``: rebuild a ``ThirdParty`` from a tagged dict.

        Dicts without the ``b"__thirdparty__"`` marker are returned unchanged.
        Keys are looked up as bytes; this hook is used with ``raw=True`` by
        default in ``encode_decode_thirdparty``.
        """
        if b"__thirdparty__" not in obj:
            return obj
        return ThirdParty(foo=obj[b"foo"])

    def encode_decode_thirdparty(self, x, use_bin_type=False, raw=True):
        """Pack and unpack ``x`` with the ThirdParty hooks installed."""
        packed = msgpack.packb(
            x, default=self.encode_thirdparty, use_bin_type=use_bin_type
        )
        return msgpack.unpackb(packed, raw=raw, object_hook=self.decode_thirdparty)

    # ------------------------------------------------------------------
    # Shared assertion helpers (not collected as tests: leading underscore)
    # ------------------------------------------------------------------

    def _roundtrip_scalar(self, original):
        """Check that a scalar keeps its value and exact type."""
        restored = self.encode_decode(original)
        assert_equal(original, restored)
        assert_equal(type(original), type(restored))

    def _roundtrip_list(self, original):
        """Check that a list keeps its elements and each element's type."""
        restored = self.encode_decode(original)
        assert_array_equal(original, restored)
        assert_array_equal(
            [type(item) for item in original], [type(item) for item in restored]
        )

    def _roundtrip_array(self, original):
        """Check that a NumPy array keeps its contents and dtype."""
        restored = self.encode_decode(original)
        assert_array_equal(original, restored)
        assert_equal(original.dtype, restored.dtype)

    def _roundtrip_dict(self, original, value_key=None):
        """Check that a dict keeps its keys, values and their types.

        Values are compared after sorting with ``value_key`` (needed for
        complex values, which have no natural ordering); keys are compared
        after a plain sort.
        """
        restored = self.encode_decode(original)

        expected_values = sorted(original.values(), key=value_key)
        actual_values = sorted(restored.values(), key=value_key)
        assert_array_equal(expected_values, actual_values)
        assert_array_equal(
            [type(value) for value in expected_values],
            [type(value) for value in actual_values],
        )

        expected_keys = sorted(original.keys())
        actual_keys = sorted(restored.keys())
        assert_array_equal(expected_keys, actual_keys)
        assert_array_equal(
            [type(key) for key in expected_keys], [type(key) for key in actual_keys]
        )

    # ------------------------------------------------------------------
    # Strings / bytes
    # ------------------------------------------------------------------

    def test_bin(self):
        # Since bytes == str in Python 2.7, the following
        # should pass on both 2.7 and 3.*
        restored = self.encode_decode(b"foo")
        assert_equal(type(restored), bytes)

    def test_str(self):
        restored = self.encode_decode("foo")
        assert_equal(type(restored), bytes)

    # ------------------------------------------------------------------
    # Scalars
    # ------------------------------------------------------------------

    def test_numpy_scalar_bool(self):
        for flag in (True, False):
            self._roundtrip_scalar(np.bool_(flag))

    def test_numpy_scalar_float(self):
        self._roundtrip_scalar(np.float32(np.random.rand()))

    def test_numpy_scalar_complex(self):
        self._roundtrip_scalar(
            np.complex64(np.random.rand() + 1j * np.random.rand())
        )

    def test_scalar_float(self):
        self._roundtrip_scalar(np.random.rand())

    def test_scalar_complex(self):
        self._roundtrip_scalar(np.random.rand() + 1j * np.random.rand())

    # ------------------------------------------------------------------
    # Lists
    # ------------------------------------------------------------------

    def test_list_numpy_float(self):
        self._roundtrip_list([np.float32(np.random.rand()) for _ in range(5)])

    def test_list_numpy_float_complex(self):
        floats = [np.float32(np.random.rand()) for _ in range(5)]
        complexes = [
            np.complex128(np.random.rand() + 1j * np.random.rand())
            for _ in range(5)
        ]
        self._roundtrip_list(floats + complexes)

    def test_list_float(self):
        self._roundtrip_list([np.random.rand() for _ in range(5)])

    def test_list_float_complex(self):
        self._roundtrip_list(
            [(np.random.rand() + 1j * np.random.rand()) for _ in range(5)]
        )

    def test_list_str(self):
        payload = [b"x" * count for count in range(5)]
        restored = self.encode_decode(payload)
        assert_array_equal(payload, restored)
        # Every recovered element is expected to be bytes.
        assert_array_equal([type(item) for item in restored], [bytes] * 5)

    def test_list_mixed(self):
        self._roundtrip_list([1.0, np.float32(3.5), np.complex128(4.25), b"foo"])

    # ------------------------------------------------------------------
    # Dicts
    # ------------------------------------------------------------------

    def test_dict_float(self):
        self._roundtrip_dict({b"foo": 1.0, b"bar": 2.0})

    def test_dict_complex(self):
        self._roundtrip_dict(
            {b"foo": 1.0 + 1.0j, b"bar": 2.0 + 2.0j}, value_key=np.linalg.norm
        )

    def test_dict_str(self):
        self._roundtrip_dict({b"foo": b"xxx", b"bar": b"yyyy"})

    def test_dict_numpy_float(self):
        self._roundtrip_dict({b"foo": np.float32(1.0), b"bar": np.float32(2.0)})

    def test_dict_numpy_complex(self):
        self._roundtrip_dict(
            {b"foo": np.complex128(1.0 + 1.0j), b"bar": np.complex128(2.0 + 2.0j)},
            value_key=np.linalg.norm,
        )

    # ------------------------------------------------------------------
    # NumPy arrays
    # ------------------------------------------------------------------

    def test_numpy_array_float(self):
        self._roundtrip_array(np.random.rand(5).astype(np.float32))

    def test_numpy_array_complex(self):
        samples = np.random.rand(5) + 1j * np.random.rand(5)
        self._roundtrip_array(samples.astype(np.complex128))

    def test_numpy_array_float_2d(self):
        self._roundtrip_array(np.random.rand(5, 5).astype(np.float32))

    def test_numpy_array_str(self):
        self._roundtrip_array(np.array([b"aaa", b"bbbb", b"ccccc"]))

    def test_numpy_array_mixed(self):
        # Structured dtype mixing integers, a 1-byte string and a float pair.
        record_dtype = np.dtype(
            [
                ("arg0", np.uint32),
                ("arg1", np.uint32),
                ("arg2", "S1"),
                ("arg3", np.float32, (2,)),
            ]
        )
        self._roundtrip_array(np.array([(1, 2, b"a", [1.0, 2.0])], record_dtype))

    def test_numpy_array_noncontiguous(self):
        # A 5x5 slice taken out of a 10x10 array.
        full = np.ones((10, 10), np.uint32)
        self._roundtrip_array(full[0:5, 0:5])

    # ------------------------------------------------------------------
    # Custom object hooks
    # ------------------------------------------------------------------

    def test_chain(self):
        original = ThirdParty(foo=b"test marshal/unmarshal")
        restored = self.encode_decode_thirdparty(original)
        self.assertEqual(original, restored)
|