Use send_multipart instead of send when sending multipart message.

This commit is contained in:
Insoo Ha 2023-12-10 18:50:07 +09:00 committed by Pedro Algarvio
parent 487a1ad3d0
commit 98c92a3fac
4 changed files with 220 additions and 28 deletions

1
changelog/65018.fixed.md Normal file
View file

@ -0,0 +1 @@
Use `send_multipart` instead of `send` when sending multipart message.

View file

@ -777,21 +777,18 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer):
htopic = salt.utils.stringutils.to_bytes(
hashlib.sha1(salt.utils.stringutils.to_bytes(topic)).hexdigest()
)
yield self.dpub_sock.send(htopic, flags=zmq.SNDMORE)
yield self.dpub_sock.send(payload)
yield self.dpub_sock.send_multipart([htopic, payload])
log.trace("Filtered data has been sent")
# Syndic broadcast
if self.opts.get("order_masters"):
log.trace("Sending filtered data to syndic")
yield self.dpub_sock.send(b"syndic", flags=zmq.SNDMORE)
yield self.dpub_sock.send(payload)
yield self.dpub_sock.send_multipart([b"syndic", payload])
log.trace("Filtered data has been sent to syndic")
# otherwise its a broadcast
else:
# TODO: constants file for "broadcast"
log.trace("Sending broadcasted data over publisher %s", self.pub_uri)
yield self.dpub_sock.send(b"broadcast", flags=zmq.SNDMORE)
yield self.dpub_sock.send(payload)
yield self.dpub_sock.send_multipart([b"broadcast", payload])
log.trace("Broadcasted data has been sent")
else:
log.trace("Sending ZMQ-unfiltered data over publisher %s", self.pub_uri)

View file

@ -1,10 +1,16 @@
from contextlib import contextmanager
import copy
import logging
import threading
import random
import time
import threading
import pytest
from saltfactories.utils import random_string
import salt.transport.zeromq
import salt.utils.process
from tests.support.mock import MagicMock, patch
from tests.support.pytest.transport import PubServerChannelProcess
@ -20,9 +26,151 @@ pytestmark = [
]
class PubServerChannelSender:
def __init__(self, pub_server_channel, payload_list):
self.pub_server_channel = pub_server_channel
self.payload_list = payload_list
def run(self):
for payload in self.payload_list:
self.pub_server_channel.publish(payload)
time.sleep(2)
def generate_msg_list(msg_cnt, minions_list, broadcast):
msg_list = []
for i in range(msg_cnt):
for idx, minion_id in enumerate(minions_list):
if broadcast:
msg_list.append({"tgt_type": "grain", "tgt": 'id:*', "jid": msg_cnt * idx + i})
else:
msg_list.append({"tgt_type": "list", "tgt": [minion_id], "jid": msg_cnt * idx + i})
return msg_list
@contextmanager
def channel_publisher_manager(msg_list, p_cnt, pub_server_channel):
process_list = []
msg_list = copy.deepcopy(msg_list)
random.shuffle(msg_list)
batch_size = len(msg_list) // p_cnt
list_batch = [[x * batch_size, x * batch_size + batch_size] for x in range(0, p_cnt)]
list_batch[-1][1] = list_batch[-1][1] + 1
try:
for i, j in list_batch:
c = PubServerChannelSender(pub_server_channel, msg_list[i:j])
p = salt.utils.process.Process(target=c.run)
process_list.append(p)
for p in process_list:
p.start()
yield
finally:
for p in process_list:
p.join()
@pytest.mark.skip_on_windows
@pytest.mark.slow_test
def test_zeromq_filtering(salt_master, salt_minion):
def test_zeromq_filtering_minion(salt_master, salt_minion):
opts = dict(
salt_master.config.copy(),
ipc_mode="ipc",
pub_hwm=0,
zmq_filtering=True,
acceptance_wait_time=5,
)
minion_opts = dict(
salt_minion.config.copy(),
zmq_filtering=True,
)
messages = 200
workers = 5
minions = 3
expect = set(range(messages))
target_minion_id = salt_minion.id
minions_list = [target_minion_id]
for _ in range(minions - 1):
minions_list.append(random_string("zeromq-minion-"))
msg_list = generate_msg_list(messages, minions_list, False)
with patch(
"salt.utils.minions.CkMinions.check_minions",
MagicMock(
return_value={
"minions": minions_list,
"missing": [],
"ssh_minions": False,
}
),
):
with PubServerChannelProcess(opts, minion_opts) as server_channel:
with channel_publisher_manager(msg_list, workers, server_channel.pub_server_channel):
cnt = 0
last_results_len = 0
while cnt < 20:
time.sleep(2)
results_len = len(server_channel.collector.results)
if last_results_len == results_len:
break
last_results_len = results_len
cnt += 1
results = set(server_channel.collector.results)
assert results == expect, \
f"{len(results)}, != {len(expect)}, difference: {expect.difference(results)} {results}"
@pytest.mark.skip_on_windows
@pytest.mark.slow_test
def test_zeromq_filtering_syndic(salt_master, salt_minion):
opts = dict(
salt_master.config.copy(),
ipc_mode="ipc",
pub_hwm=0,
zmq_filtering=True,
acceptance_wait_time=5,
order_masters=True,
)
minion_opts = dict(
salt_minion.config.copy(),
zmq_filtering=True,
__role='syndic',
)
messages = 200
workers = 5
minions = 3
expect = set(range(messages * minions))
minions_list = []
for _ in range(minions):
minions_list.append(random_string("zeromq-minion-"))
msg_list = generate_msg_list(messages, minions_list, False)
with patch(
"salt.utils.minions.CkMinions.check_minions",
MagicMock(
return_value={
"minions": minions_list,
"missing": [],
"ssh_minions": False,
}
),
):
with PubServerChannelProcess(opts, minion_opts) as server_channel:
with channel_publisher_manager(msg_list, workers, server_channel.pub_server_channel):
cnt = 0
last_results_len = 0
while cnt < 20:
time.sleep(2)
results_len = len(server_channel.collector.results)
if last_results_len == results_len:
break
last_results_len = results_len
cnt += 1
results = set(server_channel.collector.results)
assert results == expect, \
f"{len(results)}, != {len(expect)}, difference: {expect.difference(results)} {results}"
@pytest.mark.skip_on_windows
@pytest.mark.slow_test
def test_zeromq_filtering_broadcast(salt_master, salt_minion):
"""
Test sending messages to publisher using UDP with zeromq_filtering enabled
"""
@ -33,28 +181,43 @@ def test_zeromq_filtering(salt_master, salt_minion):
zmq_filtering=True,
acceptance_wait_time=5,
)
send_num = 1
expect = []
minion_opts = dict(
salt_minion.config.copy(),
zmq_filtering=True,
)
messages = 200
workers = 5
minions = 3
expect = set(range(messages * minions))
target_minion_id = salt_minion.id
minions_list = [target_minion_id]
for _ in range(minions - 1):
minions_list.append(random_string("zeromq-minion-"))
msg_list = generate_msg_list(messages, minions_list, True)
with patch(
"salt.utils.minions.CkMinions.check_minions",
MagicMock(
return_value={
"minions": [salt_minion.id],
"minions": minions_list,
"missing": [],
"ssh_minions": False,
}
),
):
with PubServerChannelProcess(
opts, salt_minion.config.copy(), zmq_filtering=True
) as server_channel:
expect.append(send_num)
load = {"tgt_type": "glob", "tgt": "*", "jid": send_num}
server_channel.publish(load)
results = server_channel.collector.results
assert len(results) == send_num, "{} != {}, difference: {}".format(
len(results), send_num, set(expect).difference(results)
)
with PubServerChannelProcess(opts, minion_opts) as server_channel:
with channel_publisher_manager(msg_list, workers, server_channel.pub_server_channel):
cnt = 0
last_results_len = 0
while cnt < 20:
time.sleep(2)
results_len = len(server_channel.collector.results)
if last_results_len == results_len:
break
last_results_len = results_len
cnt += 1
results = set(server_channel.collector.results)
assert results == expect, \
f"{len(results)}, != {len(expect)}, difference: {expect.difference(results)} {results}"
def test_pub_channel(master_opts):

View file

@ -1,4 +1,5 @@
import ctypes
import hashlib
import logging
import multiprocessing
import socket
@ -9,11 +10,13 @@ import zmq
from pytestshellutils.utils.processes import terminate_process
import salt.channel.server
import salt.crypt
import salt.exceptions
import salt.ext.tornado.gen
import salt.ext.tornado.ioloop
import salt.ext.tornado.iostream
import salt.master
import salt.payload
import salt.utils.msgpack
import salt.utils.process
import salt.utils.stringutils
@ -36,10 +39,10 @@ class Collector(salt.utils.process.SignalHandlingProcess):
port,
aes_key,
timeout=300,
zmq_filtering=False,
):
super().__init__()
self.minion_config = minion_config
self.hexid = hashlib.sha1(salt.utils.stringutils.to_bytes(self.minion_config["id"])).hexdigest()
self.interface = interface
self.port = port
self.aes_key = aes_key
@ -48,10 +51,11 @@ class Collector(salt.utils.process.SignalHandlingProcess):
self.hard_timeout = time.time() + timeout + 120
self.manager = multiprocessing.Manager()
self.results = self.manager.list()
self.zmq_filtering = zmq_filtering
self.zmq_filtering = minion_config['zmq_filtering']
self.stopped = multiprocessing.Event()
self.started = multiprocessing.Event()
self.running = multiprocessing.Event()
self.stop_running = multiprocessing.Event()
self.unpacker = salt.utils.msgpack.Unpacker(raw=False)
@property
@ -78,7 +82,14 @@ class Collector(salt.utils.process.SignalHandlingProcess):
ctx = zmq.Context()
self.sock = ctx.socket(zmq.SUB)
self.sock.setsockopt(zmq.LINGER, -1)
self.sock.setsockopt(zmq.SUBSCRIBE, b"")
if self.zmq_filtering:
self.sock.setsockopt(zmq.SUBSCRIBE, b"broadcast")
if self.minion_config.get("__role") == "syndic":
self.sock.setsockopt(zmq.SUBSCRIBE, b"syndic")
else:
self.sock.setsockopt(zmq.SUBSCRIBE, salt.utils.stringutils.to_bytes(self.hexid))
else:
self.sock.setsockopt(zmq.SUBSCRIBE, b"")
pub_uri = "tcp://{}:{}".format(self.interface, self.port)
self.sock.connect(pub_uri)
else:
@ -101,8 +112,23 @@ class Collector(salt.utils.process.SignalHandlingProcess):
# test_zeromq_filtering requires catching the
# SaltDeserializationError in order to pass.
try:
payload = self.sock.recv(zmq.NOBLOCK)
serial_payload = salt.payload.loads(payload)
messages = self.sock.recv_multipart(zmq.NOBLOCK)
messages_len = len(messages)
if messages_len == 1:
serial_payload = salt.payload.loads(messages[0])
elif messages_len == 2:
message_target = salt.utils.stringutils.to_str(messages[0])
is_syndic = self.minion_config.get("__role") == "syndic"
if (
not is_syndic and message_target not in ("broadcast", self.hexid)
) or (
is_syndic and message_target not in ("broadcast", "syndic")
):
log.debug("Publish received for not this minion: %s", message_target)
raise salt.ext.tornado.gen.Return(None)
serial_payload = salt.payload.loads(messages[1])
else:
raise Exception("Invalid number of messages")
raise salt.ext.tornado.gen.Return(serial_payload)
except (zmq.ZMQError, salt.exceptions.SaltDeserializationError):
raise RecvError("ZMQ Error")
@ -125,7 +151,6 @@ class Collector(salt.utils.process.SignalHandlingProcess):
return
self.started.set()
last_msg = time.time()
serial = salt.payload.Serial(self.minion_config)
crypticle = salt.crypt.Crypticle(self.minion_config, self.aes_key)
while True:
curr_time = time.time()
@ -150,6 +175,7 @@ class Collector(salt.utils.process.SignalHandlingProcess):
continue
if "stop" in payload:
log.info("Collector stopped")
self.stop_running.set()
break
last_msg = time.time()
self.results.append(payload["jid"])
@ -275,7 +301,12 @@ class PubServerChannelProcess(salt.utils.process.SignalHandlingProcess):
def __exit__(self, *args):
# Publish a payload to tell the collection it's done processing
self.publish({"tgt_type": "glob", "tgt": "*", "jid": -1, "stop": True})
attempts = 300
while attempts > 0:
self.publish({"tgt_type": "glob", "tgt": "*", "jid": -1, "stop": True})
if self.collector.stop_running.wait(1) is True:
break
attempts -= 1
# Now trigger the collector to also exit
self.collector.__exit__(*args)
# We can safely wait here without a timeout because the Collector instance has a