We depend on `msgpack >= 1.0`, simplify logic

This commit is contained in:
Pedro Algarvio 2024-02-26 09:32:20 +00:00
parent 890140fba1
commit 5840ab68a5
7 changed files with 45 additions and 117 deletions

View file

@ -2,7 +2,7 @@
Jinja2
jmespath
msgpack>=0.5,!=0.5.5
msgpack>=1.0.0
PyYAML
MarkupSafe
requests>=1.0.0

View file

@ -84,26 +84,15 @@ def loads(msg, encoding=None, raw=False):
gc.disable() # performance optimization for msgpack
loads_kwargs = {"use_list": True, "ext_hook": ext_type_decoder}
if salt.utils.msgpack.version >= (0, 4, 0):
# msgpack only supports 'encoding' starting in 0.4.0.
# Due to this, if we don't need it, don't pass it at all so
# that under Python 2 we can still work with older versions
# of msgpack.
if salt.utils.msgpack.version >= (0, 5, 2):
if encoding is None:
loads_kwargs["raw"] = True
else:
loads_kwargs["raw"] = False
else:
loads_kwargs["encoding"] = encoding
try:
ret = salt.utils.msgpack.unpackb(msg, **loads_kwargs)
except UnicodeDecodeError:
# msg contains binary data
loads_kwargs.pop("raw", None)
loads_kwargs.pop("encoding", None)
ret = salt.utils.msgpack.loads(msg, **loads_kwargs)
if encoding is None:
loads_kwargs["raw"] = True
else:
loads_kwargs["raw"] = False
try:
ret = salt.utils.msgpack.unpackb(msg, **loads_kwargs)
except UnicodeDecodeError:
# msg contains binary data
loads_kwargs.pop("raw", None)
ret = salt.utils.msgpack.loads(msg, **loads_kwargs)
if encoding is None and not raw:
ret = salt.transport.frame.decode_embedded_strs(ret)

View file

@ -15,7 +15,7 @@ log = logging.getLogger(__name__)
__all__ = ["deserialize", "serialize", "available"]
available = True
available = salt.utils.msgpack.HAS_MSGPACK
def serialize(obj, **options):

View file

@ -171,13 +171,7 @@ class IPCServer:
else:
return _null
# msgpack deprecated `encoding` starting with version 0.5.2
if salt.utils.msgpack.version >= (0, 5, 2):
# Under Py2 we still want raw to be set to True
msgpack_kwargs = {"raw": False}
else:
msgpack_kwargs = {"encoding": "utf-8"}
unpacker = salt.utils.msgpack.Unpacker(**msgpack_kwargs)
unpacker = salt.utils.msgpack.Unpacker(raw=False)
while not stream.closed():
try:
wire_bytes = yield stream.read_bytes(4096, partial=True)
@ -280,13 +274,7 @@ class IPCClient:
self.socket_path = socket_path
self._closing = False
self.stream = None
# msgpack deprecated `encoding` starting with version 0.5.2
if salt.utils.msgpack.version >= (0, 5, 2):
# Under Py2 we still want raw to be set to True
msgpack_kwargs = {"raw": False}
else:
msgpack_kwargs = {"encoding": "utf-8"}
self.unpacker = salt.utils.msgpack.Unpacker(**msgpack_kwargs)
self.unpacker = salt.utils.msgpack.Unpacker(raw=False)
self._connecting_future = None
def connected(self):

View file

@ -11,13 +11,10 @@ try:
# There is a serialization issue on ARM and potentially other platforms for some msgpack bindings, check for it
if (
msgpack.version >= (0, 4, 0)
and msgpack.loads(msgpack.dumps([1, 2, 3], use_bin_type=False), use_list=True)
msgpack.loads(msgpack.dumps([1, 2, 3], use_bin_type=False), use_list=True)
is None
):
raise ImportError
elif msgpack.loads(msgpack.dumps([1, 2, 3]), use_list=True) is None:
raise ImportError
HAS_MSGPACK = True
except ImportError:
try:
@ -59,13 +56,7 @@ def _sanitize_msgpack_kwargs(kwargs):
https://github.com/msgpack/msgpack-python/blob/master/ChangeLog.rst
"""
assert isinstance(kwargs, dict)
if version < (0, 6, 0) and kwargs.pop("strict_map_key", None) is not None:
log.info("removing unsupported `strict_map_key` argument from msgpack call")
if version < (0, 5, 2) and kwargs.pop("raw", None) is not None:
log.info("removing unsupported `raw` argument from msgpack call")
if version < (0, 4, 0) and kwargs.pop("use_bin_type", None) is not None:
log.info("removing unsupported `use_bin_type` argument from msgpack call")
if version >= (1, 0, 0) and kwargs.pop("encoding", None) is not None:
if kwargs.pop("encoding", None) is not None:
log.debug("removing unsupported `encoding` argument from msgpack call")
return kwargs
@ -78,32 +69,20 @@ def _sanitize_msgpack_unpack_kwargs(kwargs):
https://github.com/msgpack/msgpack-python/blob/master/ChangeLog.rst
"""
assert isinstance(kwargs, dict)
if version >= (1, 0, 0):
kwargs.setdefault("raw", True)
kwargs.setdefault("strict_map_key", False)
kwargs.setdefault("raw", True)
kwargs.setdefault("strict_map_key", False)
return _sanitize_msgpack_kwargs(kwargs)
def _add_msgpack_unpack_kwargs(kwargs):
"""
Add any msgpack unpack kwargs here.
max_buffer_size: will make sure the buffer is set to a minimum
of 100MiB in versions >=6 and <1.0
"""
assert isinstance(kwargs, dict)
if version >= (0, 6, 0) and version < (1, 0, 0):
kwargs["max_buffer_size"] = 100 * 1024 * 1024
return _sanitize_msgpack_unpack_kwargs(kwargs)
class Unpacker(msgpack.Unpacker):
"""
Wraps the msgpack.Unpacker and removes non-relevant arguments
"""
def __init__(self, *args, **kwargs):
msgpack.Unpacker.__init__(self, *args, **_add_msgpack_unpack_kwargs(kwargs))
msgpack.Unpacker.__init__(
self, *args, **_sanitize_msgpack_unpack_kwargs(kwargs)
)
def pack(o, stream, **kwargs):

View file

@ -26,9 +26,13 @@ def test_load_encoding(tmp_path):
@pytest.mark.parametrize(
"version,encoding", [((2, 1, 3), False), ((1, 0, 0), False), ((0, 6, 2), True)]
"version",
[
(2, 1, 3),
(1, 0, 0),
],
)
def test_load_multiple_versions(version, encoding, tmp_path):
def test_load_multiple_versions(version, tmp_path):
"""
test when using msgpack on multiple versions that
we only remove encoding on >= 1.0.0
@ -47,26 +51,18 @@ def test_load_multiple_versions(version, encoding, tmp_path):
with patch_dump, patch_load:
with salt.utils.files.fopen(fname, "wb") as wfh:
salt.utils.msgpack.dump(data, wfh, encoding="utf-8")
if encoding:
assert "encoding" in mock_dump.call_args.kwargs
else:
assert "encoding" not in mock_dump.call_args.kwargs
assert "encoding" not in mock_dump.call_args.kwargs
with salt.utils.files.fopen(fname, "rb") as rfh:
salt.utils.msgpack.load(rfh, **kwargs)
if encoding:
assert "encoding" in mock_load.call_args.kwargs
else:
assert "encoding" not in mock_load.call_args.kwargs
assert "encoding" not in mock_load.call_args.kwargs
@pytest.mark.parametrize(
"version,exp_kwargs",
[
((1, 0, 0), {"raw": True, "strict_map_key": True, "use_bin_type": True}),
((0, 6, 0), {"raw": True, "strict_map_key": True, "use_bin_type": True}),
((0, 5, 2), {"raw": True, "use_bin_type": True}),
((0, 4, 0), {"use_bin_type": True}),
((0, 3, 0), {}),
],
)
def test_sanitize_msgpack_kwargs(version, exp_kwargs):
@ -82,14 +78,8 @@ def test_sanitize_msgpack_kwargs(version, exp_kwargs):
@pytest.mark.parametrize(
"version,exp_kwargs",
[
((2, 0, 0), {"raw": True, "strict_map_key": True, "use_bin_type": True}),
((1, 0, 0), {"raw": True, "strict_map_key": True, "use_bin_type": True}),
(
(0, 6, 0),
{"strict_map_key": True, "use_bin_type": True, "encoding": "utf-8"},
),
((0, 5, 2), {"use_bin_type": True, "encoding": "utf-8"}),
((0, 4, 0), {"use_bin_type": True, "encoding": "utf-8"}),
((0, 3, 0), {"encoding": "utf-8"}),
],
)
def test_sanitize_msgpack_unpack_kwargs(version, exp_kwargs):

View file

@ -14,14 +14,7 @@ import salt.utils.msgpack
from salt.utils.odict import OrderedDict
from tests.support.unit import TestCase
try:
import msgpack
except ImportError:
import msgpack_pure as msgpack # pylint: disable=import-error
# A keyword to pass to tests that use `raw`, which was added in msgpack 0.5.2
raw = {"raw": False} if msgpack.version > (0, 5, 2) else {}
msgpack = pytest.importorskip("msgpack")
@pytest.mark.skipif(
@ -156,10 +149,7 @@ class TestMsgpack(TestCase):
bio.write(packer.pack(i * 2)) # value
bio.seek(0)
if salt.utils.msgpack.version > (0, 6, 0):
unpacker = salt.utils.msgpack.Unpacker(bio, strict_map_key=False)
else:
unpacker = salt.utils.msgpack.Unpacker(bio)
unpacker = salt.utils.msgpack.Unpacker(bio, strict_map_key=False)
for size in sizes:
self.assertEqual(unpacker.unpack(), {i: i * 2 for i in range(size)})
@ -293,7 +283,7 @@ class TestMsgpack(TestCase):
class MyUnpacker(salt.utils.msgpack.Unpacker):
def __init__(self):
my_kwargs = {}
super().__init__(ext_hook=self._hook, **raw)
super().__init__(ext_hook=self._hook, raw=False)
def _hook(self, code, data):
if code == 1:
@ -314,21 +304,20 @@ class TestMsgpack(TestCase):
def _check(
self, data, pack_func, unpack_func, use_list=False, strict_map_key=False
):
my_kwargs = {}
if salt.utils.msgpack.version >= (0, 6, 0):
my_kwargs["strict_map_key"] = strict_map_key
ret = unpack_func(pack_func(data), use_list=use_list, **my_kwargs)
ret = unpack_func(
pack_func(data), use_list=use_list, strict_map_key=strict_map_key
)
self.assertEqual(ret, data)
def _test_pack_unicode(self, pack_func, unpack_func):
test_data = ["", "abcd", ["defgh"], "Русский текст"]
for td in test_data:
ret = unpack_func(pack_func(td), use_list=True, **raw)
ret = unpack_func(pack_func(td), use_list=True, raw=False)
self.assertEqual(ret, td)
packer = salt.utils.msgpack.Packer()
data = packer.pack(td)
ret = salt.utils.msgpack.Unpacker(
BytesIO(data), use_list=True, **raw
BytesIO(data), use_list=True, raw=False
).unpack()
self.assertEqual(ret, td)
@ -352,19 +341,23 @@ class TestMsgpack(TestCase):
def _test_ignore_unicode_errors(self, pack_func, unpack_func):
ret = unpack_func(
pack_func(b"abc\xeddef", use_bin_type=False), unicode_errors="ignore", **raw
pack_func(b"abc\xeddef", use_bin_type=False),
unicode_errors="ignore",
raw=False,
)
self.assertEqual("abcdef", ret)
def _test_strict_unicode_unpack(self, pack_func, unpack_func):
packed = pack_func(b"abc\xeddef", use_bin_type=False)
self.assertRaises(UnicodeDecodeError, unpack_func, packed, use_list=True, **raw)
self.assertRaises(
UnicodeDecodeError, unpack_func, packed, use_list=True, raw=False
)
def _test_ignore_errors_pack(self, pack_func, unpack_func):
ret = unpack_func(
pack_func("abc\uDC80\uDCFFdef", use_bin_type=True, unicode_errors="ignore"),
use_list=True,
**raw
raw=False,
)
self.assertEqual("abcdef", ret)
@ -372,10 +365,6 @@ class TestMsgpack(TestCase):
ret = unpack_func(pack_func(b"abc"), use_list=True)
self.assertEqual(b"abc", ret)
@pytest.mark.skipif(
salt.utils.msgpack.version < (0, 2, 2),
"use_single_float was added in msgpack==0.2.2",
)
def _test_pack_float(self, pack_func, **kwargs):
self.assertEqual(
b"\xca" + struct.pack(">f", 1.0), pack_func(1.0, use_single_float=True)
@ -402,16 +391,9 @@ class TestMsgpack(TestCase):
pairlist = [(b"a", 1), (2, b"b"), (b"foo", b"bar")]
packer = salt.utils.msgpack.Packer()
packed = packer.pack_map_pairs(pairlist)
if salt.utils.msgpack.version > (0, 6, 0):
unpacked = unpack_func(packed, object_pairs_hook=list, strict_map_key=False)
else:
unpacked = unpack_func(packed, object_pairs_hook=list)
unpacked = unpack_func(packed, object_pairs_hook=list, strict_map_key=False)
self.assertEqual(pairlist, unpacked)
@pytest.mark.skipif(
salt.utils.msgpack.version < (0, 6, 0),
"getbuffer() was added to Packer in msgpack 0.6.0",
)
def _test_get_buffer(self, pack_func, **kwargs):
packer = msgpack.Packer(autoreset=False, use_bin_type=True)
packer.pack([1, 2])