mirror of
https://github.com/saltstack/salt.git
synced 2025-04-17 10:10:20 +00:00
Use send_multipart instead of send when sending multipart message.
This commit is contained in:
parent
487a1ad3d0
commit
98c92a3fac
4 changed files with 220 additions and 28 deletions
1
changelog/65018.fixed.md
Normal file
1
changelog/65018.fixed.md
Normal file
|
@ -0,0 +1 @@
|
|||
Use `send_multipart` instead of `send` when sending multipart message.
|
|
@ -777,21 +777,18 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer):
|
|||
htopic = salt.utils.stringutils.to_bytes(
|
||||
hashlib.sha1(salt.utils.stringutils.to_bytes(topic)).hexdigest()
|
||||
)
|
||||
yield self.dpub_sock.send(htopic, flags=zmq.SNDMORE)
|
||||
yield self.dpub_sock.send(payload)
|
||||
yield self.dpub_sock.send_multipart([htopic, payload])
|
||||
log.trace("Filtered data has been sent")
|
||||
# Syndic broadcast
|
||||
if self.opts.get("order_masters"):
|
||||
log.trace("Sending filtered data to syndic")
|
||||
yield self.dpub_sock.send(b"syndic", flags=zmq.SNDMORE)
|
||||
yield self.dpub_sock.send(payload)
|
||||
yield self.dpub_sock.send_multipart([b"syndic", payload])
|
||||
log.trace("Filtered data has been sent to syndic")
|
||||
# otherwise its a broadcast
|
||||
else:
|
||||
# TODO: constants file for "broadcast"
|
||||
log.trace("Sending broadcasted data over publisher %s", self.pub_uri)
|
||||
yield self.dpub_sock.send(b"broadcast", flags=zmq.SNDMORE)
|
||||
yield self.dpub_sock.send(payload)
|
||||
yield self.dpub_sock.send_multipart([b"broadcast", payload])
|
||||
log.trace("Broadcasted data has been sent")
|
||||
else:
|
||||
log.trace("Sending ZMQ-unfiltered data over publisher %s", self.pub_uri)
|
||||
|
|
|
@ -1,10 +1,16 @@
|
|||
from contextlib import contextmanager
|
||||
import copy
|
||||
import logging
|
||||
import threading
|
||||
import random
|
||||
import time
|
||||
import threading
|
||||
|
||||
import pytest
|
||||
|
||||
from saltfactories.utils import random_string
|
||||
import salt.transport.zeromq
|
||||
import salt.utils.process
|
||||
|
||||
from tests.support.mock import MagicMock, patch
|
||||
from tests.support.pytest.transport import PubServerChannelProcess
|
||||
|
||||
|
@ -20,9 +26,151 @@ pytestmark = [
|
|||
]
|
||||
|
||||
|
||||
class PubServerChannelSender:
|
||||
def __init__(self, pub_server_channel, payload_list):
|
||||
self.pub_server_channel = pub_server_channel
|
||||
self.payload_list = payload_list
|
||||
|
||||
def run(self):
|
||||
for payload in self.payload_list:
|
||||
self.pub_server_channel.publish(payload)
|
||||
time.sleep(2)
|
||||
|
||||
|
||||
def generate_msg_list(msg_cnt, minions_list, broadcast):
|
||||
msg_list = []
|
||||
for i in range(msg_cnt):
|
||||
for idx, minion_id in enumerate(minions_list):
|
||||
if broadcast:
|
||||
msg_list.append({"tgt_type": "grain", "tgt": 'id:*', "jid": msg_cnt * idx + i})
|
||||
else:
|
||||
msg_list.append({"tgt_type": "list", "tgt": [minion_id], "jid": msg_cnt * idx + i})
|
||||
return msg_list
|
||||
|
||||
|
||||
@contextmanager
|
||||
def channel_publisher_manager(msg_list, p_cnt, pub_server_channel):
|
||||
process_list = []
|
||||
msg_list = copy.deepcopy(msg_list)
|
||||
random.shuffle(msg_list)
|
||||
batch_size = len(msg_list) // p_cnt
|
||||
list_batch = [[x * batch_size, x * batch_size + batch_size] for x in range(0, p_cnt)]
|
||||
list_batch[-1][1] = list_batch[-1][1] + 1
|
||||
try:
|
||||
for i, j in list_batch:
|
||||
c = PubServerChannelSender(pub_server_channel, msg_list[i:j])
|
||||
p = salt.utils.process.Process(target=c.run)
|
||||
process_list.append(p)
|
||||
for p in process_list:
|
||||
p.start()
|
||||
yield
|
||||
finally:
|
||||
for p in process_list:
|
||||
p.join()
|
||||
|
||||
|
||||
@pytest.mark.skip_on_windows
|
||||
@pytest.mark.slow_test
|
||||
def test_zeromq_filtering(salt_master, salt_minion):
|
||||
def test_zeromq_filtering_minion(salt_master, salt_minion):
|
||||
opts = dict(
|
||||
salt_master.config.copy(),
|
||||
ipc_mode="ipc",
|
||||
pub_hwm=0,
|
||||
zmq_filtering=True,
|
||||
acceptance_wait_time=5,
|
||||
)
|
||||
minion_opts = dict(
|
||||
salt_minion.config.copy(),
|
||||
zmq_filtering=True,
|
||||
)
|
||||
messages = 200
|
||||
workers = 5
|
||||
minions = 3
|
||||
expect = set(range(messages))
|
||||
target_minion_id = salt_minion.id
|
||||
minions_list = [target_minion_id]
|
||||
for _ in range(minions - 1):
|
||||
minions_list.append(random_string("zeromq-minion-"))
|
||||
msg_list = generate_msg_list(messages, minions_list, False)
|
||||
with patch(
|
||||
"salt.utils.minions.CkMinions.check_minions",
|
||||
MagicMock(
|
||||
return_value={
|
||||
"minions": minions_list,
|
||||
"missing": [],
|
||||
"ssh_minions": False,
|
||||
}
|
||||
),
|
||||
):
|
||||
with PubServerChannelProcess(opts, minion_opts) as server_channel:
|
||||
with channel_publisher_manager(msg_list, workers, server_channel.pub_server_channel):
|
||||
cnt = 0
|
||||
last_results_len = 0
|
||||
while cnt < 20:
|
||||
time.sleep(2)
|
||||
results_len = len(server_channel.collector.results)
|
||||
if last_results_len == results_len:
|
||||
break
|
||||
last_results_len = results_len
|
||||
cnt += 1
|
||||
results = set(server_channel.collector.results)
|
||||
assert results == expect, \
|
||||
f"{len(results)}, != {len(expect)}, difference: {expect.difference(results)} {results}"
|
||||
|
||||
|
||||
@pytest.mark.skip_on_windows
|
||||
@pytest.mark.slow_test
|
||||
def test_zeromq_filtering_syndic(salt_master, salt_minion):
|
||||
opts = dict(
|
||||
salt_master.config.copy(),
|
||||
ipc_mode="ipc",
|
||||
pub_hwm=0,
|
||||
zmq_filtering=True,
|
||||
acceptance_wait_time=5,
|
||||
order_masters=True,
|
||||
)
|
||||
minion_opts = dict(
|
||||
salt_minion.config.copy(),
|
||||
zmq_filtering=True,
|
||||
__role='syndic',
|
||||
)
|
||||
messages = 200
|
||||
workers = 5
|
||||
minions = 3
|
||||
expect = set(range(messages * minions))
|
||||
minions_list = []
|
||||
for _ in range(minions):
|
||||
minions_list.append(random_string("zeromq-minion-"))
|
||||
msg_list = generate_msg_list(messages, minions_list, False)
|
||||
with patch(
|
||||
"salt.utils.minions.CkMinions.check_minions",
|
||||
MagicMock(
|
||||
return_value={
|
||||
"minions": minions_list,
|
||||
"missing": [],
|
||||
"ssh_minions": False,
|
||||
}
|
||||
),
|
||||
):
|
||||
with PubServerChannelProcess(opts, minion_opts) as server_channel:
|
||||
with channel_publisher_manager(msg_list, workers, server_channel.pub_server_channel):
|
||||
cnt = 0
|
||||
last_results_len = 0
|
||||
while cnt < 20:
|
||||
time.sleep(2)
|
||||
results_len = len(server_channel.collector.results)
|
||||
if last_results_len == results_len:
|
||||
break
|
||||
last_results_len = results_len
|
||||
cnt += 1
|
||||
results = set(server_channel.collector.results)
|
||||
assert results == expect, \
|
||||
f"{len(results)}, != {len(expect)}, difference: {expect.difference(results)} {results}"
|
||||
|
||||
|
||||
@pytest.mark.skip_on_windows
|
||||
@pytest.mark.slow_test
|
||||
def test_zeromq_filtering_broadcast(salt_master, salt_minion):
|
||||
"""
|
||||
Test sending messages to publisher using UDP with zeromq_filtering enabled
|
||||
"""
|
||||
|
@ -33,28 +181,43 @@ def test_zeromq_filtering(salt_master, salt_minion):
|
|||
zmq_filtering=True,
|
||||
acceptance_wait_time=5,
|
||||
)
|
||||
send_num = 1
|
||||
expect = []
|
||||
minion_opts = dict(
|
||||
salt_minion.config.copy(),
|
||||
zmq_filtering=True,
|
||||
)
|
||||
messages = 200
|
||||
workers = 5
|
||||
minions = 3
|
||||
expect = set(range(messages * minions))
|
||||
target_minion_id = salt_minion.id
|
||||
minions_list = [target_minion_id]
|
||||
for _ in range(minions - 1):
|
||||
minions_list.append(random_string("zeromq-minion-"))
|
||||
msg_list = generate_msg_list(messages, minions_list, True)
|
||||
with patch(
|
||||
"salt.utils.minions.CkMinions.check_minions",
|
||||
MagicMock(
|
||||
return_value={
|
||||
"minions": [salt_minion.id],
|
||||
"minions": minions_list,
|
||||
"missing": [],
|
||||
"ssh_minions": False,
|
||||
}
|
||||
),
|
||||
):
|
||||
with PubServerChannelProcess(
|
||||
opts, salt_minion.config.copy(), zmq_filtering=True
|
||||
) as server_channel:
|
||||
expect.append(send_num)
|
||||
load = {"tgt_type": "glob", "tgt": "*", "jid": send_num}
|
||||
server_channel.publish(load)
|
||||
results = server_channel.collector.results
|
||||
assert len(results) == send_num, "{} != {}, difference: {}".format(
|
||||
len(results), send_num, set(expect).difference(results)
|
||||
)
|
||||
with PubServerChannelProcess(opts, minion_opts) as server_channel:
|
||||
with channel_publisher_manager(msg_list, workers, server_channel.pub_server_channel):
|
||||
cnt = 0
|
||||
last_results_len = 0
|
||||
while cnt < 20:
|
||||
time.sleep(2)
|
||||
results_len = len(server_channel.collector.results)
|
||||
if last_results_len == results_len:
|
||||
break
|
||||
last_results_len = results_len
|
||||
cnt += 1
|
||||
results = set(server_channel.collector.results)
|
||||
assert results == expect, \
|
||||
f"{len(results)}, != {len(expect)}, difference: {expect.difference(results)} {results}"
|
||||
|
||||
|
||||
def test_pub_channel(master_opts):
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import ctypes
|
||||
import hashlib
|
||||
import logging
|
||||
import multiprocessing
|
||||
import socket
|
||||
|
@ -9,11 +10,13 @@ import zmq
|
|||
from pytestshellutils.utils.processes import terminate_process
|
||||
|
||||
import salt.channel.server
|
||||
import salt.crypt
|
||||
import salt.exceptions
|
||||
import salt.ext.tornado.gen
|
||||
import salt.ext.tornado.ioloop
|
||||
import salt.ext.tornado.iostream
|
||||
import salt.master
|
||||
import salt.payload
|
||||
import salt.utils.msgpack
|
||||
import salt.utils.process
|
||||
import salt.utils.stringutils
|
||||
|
@ -36,10 +39,10 @@ class Collector(salt.utils.process.SignalHandlingProcess):
|
|||
port,
|
||||
aes_key,
|
||||
timeout=300,
|
||||
zmq_filtering=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.minion_config = minion_config
|
||||
self.hexid = hashlib.sha1(salt.utils.stringutils.to_bytes(self.minion_config["id"])).hexdigest()
|
||||
self.interface = interface
|
||||
self.port = port
|
||||
self.aes_key = aes_key
|
||||
|
@ -48,10 +51,11 @@ class Collector(salt.utils.process.SignalHandlingProcess):
|
|||
self.hard_timeout = time.time() + timeout + 120
|
||||
self.manager = multiprocessing.Manager()
|
||||
self.results = self.manager.list()
|
||||
self.zmq_filtering = zmq_filtering
|
||||
self.zmq_filtering = minion_config['zmq_filtering']
|
||||
self.stopped = multiprocessing.Event()
|
||||
self.started = multiprocessing.Event()
|
||||
self.running = multiprocessing.Event()
|
||||
self.stop_running = multiprocessing.Event()
|
||||
self.unpacker = salt.utils.msgpack.Unpacker(raw=False)
|
||||
|
||||
@property
|
||||
|
@ -78,7 +82,14 @@ class Collector(salt.utils.process.SignalHandlingProcess):
|
|||
ctx = zmq.Context()
|
||||
self.sock = ctx.socket(zmq.SUB)
|
||||
self.sock.setsockopt(zmq.LINGER, -1)
|
||||
self.sock.setsockopt(zmq.SUBSCRIBE, b"")
|
||||
if self.zmq_filtering:
|
||||
self.sock.setsockopt(zmq.SUBSCRIBE, b"broadcast")
|
||||
if self.minion_config.get("__role") == "syndic":
|
||||
self.sock.setsockopt(zmq.SUBSCRIBE, b"syndic")
|
||||
else:
|
||||
self.sock.setsockopt(zmq.SUBSCRIBE, salt.utils.stringutils.to_bytes(self.hexid))
|
||||
else:
|
||||
self.sock.setsockopt(zmq.SUBSCRIBE, b"")
|
||||
pub_uri = "tcp://{}:{}".format(self.interface, self.port)
|
||||
self.sock.connect(pub_uri)
|
||||
else:
|
||||
|
@ -101,8 +112,23 @@ class Collector(salt.utils.process.SignalHandlingProcess):
|
|||
# test_zeromq_filtering requires catching the
|
||||
# SaltDeserializationError in order to pass.
|
||||
try:
|
||||
payload = self.sock.recv(zmq.NOBLOCK)
|
||||
serial_payload = salt.payload.loads(payload)
|
||||
messages = self.sock.recv_multipart(zmq.NOBLOCK)
|
||||
messages_len = len(messages)
|
||||
if messages_len == 1:
|
||||
serial_payload = salt.payload.loads(messages[0])
|
||||
elif messages_len == 2:
|
||||
message_target = salt.utils.stringutils.to_str(messages[0])
|
||||
is_syndic = self.minion_config.get("__role") == "syndic"
|
||||
if (
|
||||
not is_syndic and message_target not in ("broadcast", self.hexid)
|
||||
) or (
|
||||
is_syndic and message_target not in ("broadcast", "syndic")
|
||||
):
|
||||
log.debug("Publish received for not this minion: %s", message_target)
|
||||
raise salt.ext.tornado.gen.Return(None)
|
||||
serial_payload = salt.payload.loads(messages[1])
|
||||
else:
|
||||
raise Exception("Invalid number of messages")
|
||||
raise salt.ext.tornado.gen.Return(serial_payload)
|
||||
except (zmq.ZMQError, salt.exceptions.SaltDeserializationError):
|
||||
raise RecvError("ZMQ Error")
|
||||
|
@ -125,7 +151,6 @@ class Collector(salt.utils.process.SignalHandlingProcess):
|
|||
return
|
||||
self.started.set()
|
||||
last_msg = time.time()
|
||||
serial = salt.payload.Serial(self.minion_config)
|
||||
crypticle = salt.crypt.Crypticle(self.minion_config, self.aes_key)
|
||||
while True:
|
||||
curr_time = time.time()
|
||||
|
@ -150,6 +175,7 @@ class Collector(salt.utils.process.SignalHandlingProcess):
|
|||
continue
|
||||
if "stop" in payload:
|
||||
log.info("Collector stopped")
|
||||
self.stop_running.set()
|
||||
break
|
||||
last_msg = time.time()
|
||||
self.results.append(payload["jid"])
|
||||
|
@ -275,7 +301,12 @@ class PubServerChannelProcess(salt.utils.process.SignalHandlingProcess):
|
|||
|
||||
def __exit__(self, *args):
|
||||
# Publish a payload to tell the collection it's done processing
|
||||
self.publish({"tgt_type": "glob", "tgt": "*", "jid": -1, "stop": True})
|
||||
attempts = 300
|
||||
while attempts > 0:
|
||||
self.publish({"tgt_type": "glob", "tgt": "*", "jid": -1, "stop": True})
|
||||
if self.collector.stop_running.wait(1) is True:
|
||||
break
|
||||
attempts -= 1
|
||||
# Now trigger the collector to also exit
|
||||
self.collector.__exit__(*args)
|
||||
# We can safely wait here without a timeout because the Collector instance has a
|
||||
|
|
Loading…
Add table
Reference in a new issue