Add pub server tests

This commit is contained in:
Daniel A. Wozniak 2021-09-30 13:32:05 -07:00 committed by Gareth J. Greenaway
parent 5873780ee3
commit 9536db83ce
6 changed files with 186 additions and 32 deletions

View file

@ -106,28 +106,24 @@ class ReqServerChannel:
and call payload_handler. You will also be passed io_loop, for all of your
asynchronous needs
"""
import salt.master
if self.opts["pub_server_niceness"] and not salt.utils.platform.is_windows():
log.info(
"setting Publish daemon niceness to %i",
self.opts["pub_server_niceness"],
)
os.nice(self.opts["pub_server_niceness"])
self.payload_handler = payload_handler
self.io_loop = io_loop
self.transport.post_fork(self.handle_message, io_loop)
import salt.master
self.crypticle = salt.crypt.Crypticle(
self.opts, salt.master.SMaster.secrets["aes"]["secret"].value
)
# other things needed for _auth
# Create the event manager
self.event = salt.utils.event.get_master_event(
self.opts, self.opts["sock_dir"], listen=False
)
self.auto_key = salt.daemons.masterapi.AutoKey(self.opts)
# only create a con_cache-client if the con_cache is active
if self.opts["con_cache"]:
self.cache_cli = CacheCli(self.opts)
@ -135,14 +131,13 @@ class ReqServerChannel:
self.cache_cli = False
# Make an minion checker object
self.ckminions = salt.utils.minions.CkMinions(self.opts)
self.master_key = salt.crypt.MasterKeys(self.opts)
self.payload_handler = payload_handler
self.transport.post_fork(self.handle_message, io_loop)
@salt.ext.tornado.gen.coroutine
def handle_message(self, stream, payload, header=None):
stream = self.transport.wrap_stream(stream)
def handle_message(self, payload, send_reply=None, header=None):
try:
payload = self.transport.decode_payload(payload)
payload = self._decode_payload(payload)
except Exception as exc: # pylint: disable=broad-except
exc_type = type(exc).__name__
@ -155,7 +150,7 @@ class ReqServerChannel:
)
else:
log.error("Bad load from minion: %s: %s", exc_type, exc)
yield stream.send("bad load", header)
yield send_reply("bad load", header)
raise salt.ext.tornado.gen.Return()
# TODO helper functions to normalize payload?
@ -165,24 +160,24 @@ class ReqServerChannel:
payload,
payload.get("load"),
)
yield stream.send("payload and load must be a dict", header)
yield send_reply("payload and load must be a dict", header)
raise salt.ext.tornado.gen.Return()
try:
id_ = payload["load"].get("id", "")
if "\0" in id_:
log.error("Payload contains an id with a null byte: %s", payload)
stream.send("bad load: id contains a null byte", header)
yield send_reply("bad load: id contains a null byte", header)
raise salt.ext.tornado.gen.Return()
except TypeError:
log.error("Payload contains non-string id: %s", payload)
stream.send("bad load: id {} is not a string".format(id_), header)
yield send_reply("bad load: id {} is not a string".format(id_), header)
raise salt.ext.tornado.gen.Return()
# intercept the "_auth" commands, since the main daemon shouldn't know
# anything about our key auth
if payload["enc"] == "clear" and payload.get("load", {}).get("cmd") == "_auth":
stream.send(self._auth(payload["load"]), header)
yield send_reply(self._auth(payload["load"]), header)
raise salt.ext.tornado.gen.Return()
# TODO: test
@ -192,17 +187,17 @@ class ReqServerChannel:
ret, req_opts = yield self.payload_handler(payload)
except Exception as e: # pylint: disable=broad-except
# always attempt to return an error to the minion
stream.send("Some exception handling minion payload", header)
yield send_reply("Some exception handling minion payload", header)
log.error("Some exception handling a payload from minion", exc_info=True)
raise salt.ext.tornado.gen.Return()
req_fun = req_opts.get("fun", "send")
if req_fun == "send_clear":
stream.send(ret, header)
yield send_reply(ret, header)
elif req_fun == "send":
stream.send(self.crypticle.dumps(ret), header)
yield send_reply(self.crypticle.dumps(ret), header)
elif req_fun == "send_private":
stream.send(
yield send_reply(
self._encrypt_private(
ret,
req_opts["key"],
@ -213,7 +208,7 @@ class ReqServerChannel:
else:
log.error("Unknown req_fun %s", req_fun)
# always attempt to return an error to the minion
stream.send("Server-side exception handling payload", header)
yield send_reply("Server-side exception handling payload", header)
raise salt.ext.tornado.gen.Return()
def _encrypt_private(self, ret, dictkey, target):

View file

@ -18,6 +18,7 @@ import time
import salt.acl
import salt.auth
import salt.channel.server
import salt.client
import salt.client.ssh.client
import salt.crypt
@ -34,7 +35,6 @@ import salt.pillar
import salt.runner
import salt.serializers.msgpack
import salt.state
import salt.transport.server
import salt.utils.args
import salt.utils.atomicfile
import salt.utils.crypt
@ -675,7 +675,7 @@ class Master(SMaster):
log.info("Creating master publisher process")
log_queue = salt.log.setup.get_multiprocessing_logging_queue()
for _, opts in iter_transport_opts(self.opts):
chan = salt.transport.server.PubServerChannel.factory(opts)
chan = salt.channel.server.PubServerChannel.factory(opts)
chan.pre_fork(self.process_manager, kwargs={"log_queue": log_queue})
pub_channels.append(chan)
@ -856,7 +856,7 @@ class ReqServer(salt.utils.process.SignalHandlingProcess):
req_channels = []
tcp_only = True
for transport, opts in iter_transport_opts(self.opts):
chan = salt.transport.server.ReqServerChannel.factory(opts)
chan = salt.channel.server.ReqServerChannel.factory(opts)
chan.pre_fork(self.process_manager)
req_channels.append(chan)
if transport != "tcp":
@ -2295,7 +2295,7 @@ class ClearFuncs(TransportMethods):
Take a load and send it across the network to connected minions
"""
for transport, opts in iter_transport_opts(self.opts):
chan = salt.transport.server.PubServerChannel.factory(opts)
chan = salt.channel.server.PubServerChannel.factory(opts)
chan.publish(load)
@property

View file

@ -357,6 +357,7 @@ class TCPReqServerChannel:
payload_handler: function to call with your payloads
"""
self.message_handler = message_handler
with salt.utils.asynchronous.current_ioloop(io_loop):
if USE_LOAD_BALANCER:
@ -375,13 +376,19 @@ class TCPReqServerChannel:
(self.opts["interface"], int(self.opts["ret_port"]))
)
self.req_server = SaltMessageServer(
message_handler,
self.handle_message,
ssl_options=self.opts.get("ssl"),
io_loop=io_loop,
)
self.req_server.add_socket(self._socket)
self._socket.listen(self.backlog)
@salt.ext.tornado.gen.coroutine
def handle_message(self, stream, payload, header=None):
stream = self.wrap_stream(stream)
payload = self.decode_payload(payload)
yield self.message_handler(payload, send_reply=stream.send, header=header)
def wrap_stream(self, stream):
class Stream:
def __init__(self, stream):

View file

@ -414,7 +414,14 @@ class ZeroMQReqServerChannel:
log.info("Worker binding to socket %s", self.w_uri)
self._socket.connect(self.w_uri)
self.stream = zmq.eventloop.zmqstream.ZMQStream(self._socket, io_loop=io_loop)
self.stream.on_recv_stream(message_handler)
self.message_handler = message_handler
self.stream.on_recv_stream(self.handle_message)
@salt.ext.tornado.gen.coroutine
def handle_message(self, stream, payload, header=None):
stream = self.wrap_stream(stream)
payload = self.decode_payload(payload)
self.message_handler(payload, send_reply=stream.send, header=header)
def __setup_signals(self):
signal.signal(signal.SIGINT, self._handle_signals)
@ -508,7 +515,7 @@ class AsyncReqMessageClient:
# mapping of message -> future
self.send_future_map = {}
self.send_timeout_map = {} # message -> timeout
# self.send_timeout_map = {} # message -> timeout
self._closing = False
# TODO: timeout all in-flight sessions, or error
@ -565,7 +572,7 @@ class AsyncReqMessageClient:
elif hasattr(zmq, "IPV4ONLY"):
self.socket.setsockopt(zmq.IPV4ONLY, 0)
self.socket.linger = self.linger
log.debug("Trying to connect to: %s", self.addr)
log.debug("**** Trying to connect to: %s", self.addr)
self.socket.connect(self.addr)
self.stream = zmq.eventloop.zmqstream.ZMQStream(
self.socket, io_loop=self.io_loop
@ -582,7 +589,6 @@ class AsyncReqMessageClient:
# In a race condition the message might have been sent by the time
# we're timing it out. Make sure the future is not None
if future is not None:
del self.send_timeout_map[message]
if future.attempts < future.tries:
future.attempts += 1
log.debug(
@ -726,8 +732,8 @@ class ZeroMQPubServerChannel:
"""
ioloop = salt.ext.tornado.ioloop.IOLoop()
ioloop.make_current()
context = zmq.Context(1)
self.io_loop = ioloop
context = self.context = zmq.Context(1)
pub_sock = context.socket(zmq.PUB)
monitor = ZeroMQSocketMonitor(pub_sock)
monitor.start_io_loop(ioloop)
@ -770,7 +776,8 @@ class ZeroMQPubServerChannel:
try:
ioloop.start()
finally:
context.term()
pub_sock.close()
pull_sock.close()
@property
def pull_uri(self):
@ -894,6 +901,9 @@ class ZeroMQPubServerChannel:
def topic_support(self):
return self.opts.get("zmq_filtering", False)
def close(self):
self.pub_close()
class ZeroMQReqChannel:
ttype = "zeromq"

View file

@ -0,0 +1,142 @@
import ctypes
import io
import logging
import multiprocessing
import os
import threading
import time
from concurrent.futures.thread import ThreadPoolExecutor
import pytest
import salt.channel.client
import salt.channel.server
import salt.config
import salt.exceptions
import salt.ext.tornado.gen
import salt.ext.tornado.ioloop
import salt.log.setup
import salt.master
import salt.transport.client
import salt.transport.server
import salt.transport.zeromq
import salt.utils.platform
import salt.utils.process
import salt.utils.stringutils
import zmq
from saltfactories.utils.ports import get_unused_localhost_port
from saltfactories.utils.processes import terminate_process
from tests.support.mock import MagicMock, patch
log = logging.getLogger(__name__)
@pytest.fixture
def master_config(tmp_path):
master_conf = salt.config.master_config("")
master_conf["id"] = "master"
master_conf["root_dir"] = str(tmp_path)
master_conf["sock_dir"] = str(tmp_path)
master_conf["ret_port"] = get_unused_localhost_port()
master_conf["master_uri"] = "tcp://127.0.0.1:{}".format(master_conf["ret_port"])
master_conf["pki_dir"] = str(tmp_path / "pki")
os.makedirs(master_conf["pki_dir"])
salt.crypt.gen_keys(master_conf["pki_dir"], "master", 4096)
minions_keys = os.path.join(master_conf["pki_dir"], "minions")
os.makedirs(minions_keys)
yield master_conf
@pytest.fixture
def configs(master_config):
minion_conf = salt.config.minion_config("")
minion_conf["root_dir"] = master_config["root_dir"]
minion_conf["id"] = "minion"
minion_conf["sock_dir"] = master_config["sock_dir"]
minion_conf["pki_dir"] = os.path.join(master_config["root_dir"], "pki_minion")
os.makedirs(minion_conf["pki_dir"])
minion_conf["master_port"] = master_config["ret_port"]
minion_conf["master_ip"] = "127.0.0.1"
minion_conf["master_uri"] = "tcp://127.0.0.1:{}".format(master_config["ret_port"])
salt.crypt.gen_keys(minion_conf["pki_dir"], "minion", 4096)
minion_pub = os.path.join(minion_conf["pki_dir"], "minion.pub")
pub_on_master = os.path.join(master_config["pki_dir"], "minions", "minion")
with io.open(minion_pub, "r") as rfp:
with io.open(pub_on_master, "w") as wfp:
wfp.write(rfp.read())
return (minion_conf, master_config)
def test_pub_server_channel_with_zmq_transport(io_loop, configs):
minion_conf, master_conf = configs
process_manager = salt.utils.process.ProcessManager()
server_channel = salt.transport.server.PubServerChannel.factory(
master_conf,
)
server_channel.pre_fork(process_manager)
req_server_channel = salt.transport.server.ReqServerChannel.factory(master_conf)
req_server_channel.pre_fork(process_manager)
def handle_payload(payload):
log.info("TEST - Req Server handle payload {}".format(repr(payload)))
req_server_channel.post_fork(handle_payload, io_loop=io_loop)
pub_channel = salt.transport.client.AsyncPubChannel.factory(minion_conf)
@salt.ext.tornado.gen.coroutine
def doit(channel, server):
log.info("TEST - BEFORE CHANNEL CONNECT")
yield channel.connect()
log.info("TEST - AFTER CHANNEL CONNECT")
def cb(payload):
log.info("TEST - PUB SERVER MSG {}".format(repr(payload)))
io_loop.stop()
channel.on_recv(cb)
server.publish({"tgt_type": "glob", "tgt": ["carbon"], "WTF": "SON"})
io_loop.add_callback(doit, pub_channel, server_channel)
io_loop.start()
# server_channel.transport.stop()
process_manager.terminate()
def test_pub_server_channel_with_tcp_transport(io_loop, configs):
minion_conf, master_conf = configs
minion_conf["transport"] = "tcp"
master_conf["transport"] = "tcp"
process_manager = salt.utils.process.ProcessManager()
server_channel = salt.transport.server.PubServerChannel.factory(
master_conf,
)
server_channel.pre_fork(process_manager)
req_server_channel = salt.transport.server.ReqServerChannel.factory(master_conf)
req_server_channel.pre_fork(process_manager)
def handle_payload(payload):
log.info("TEST - Req Server handle payload {}".format(repr(payload)))
req_server_channel.post_fork(handle_payload, io_loop=io_loop)
pub_channel = salt.transport.client.AsyncPubChannel.factory(minion_conf)
@salt.ext.tornado.gen.coroutine
def doit(channel, server):
log.info("TEST - BEFORE CHANNEL CONNECT")
yield channel.connect()
log.info("TEST - AFTER CHANNEL CONNECT")
def cb(payload):
log.info("TEST - PUB SERVER MSG {}".format(repr(payload)))
io_loop.stop()
channel.on_recv(cb)
server.publish({"tgt_type": "glob", "tgt": ["carbon"], "WTF": "SON"})
io_loop.add_callback(doit, pub_channel, server_channel)
io_loop.start()
# server_channel.transport.stop()
process_manager.terminate()