mirror of
https://github.com/saltstack/salt.git
synced 2025-04-17 10:10:20 +00:00
Add pub server tests
This commit is contained in:
parent
5873780ee3
commit
9536db83ce
6 changed files with 186 additions and 32 deletions
|
@ -106,28 +106,24 @@ class ReqServerChannel:
|
|||
and call payload_handler. You will also be passed io_loop, for all of your
|
||||
asynchronous needs
|
||||
"""
|
||||
import salt.master
|
||||
|
||||
if self.opts["pub_server_niceness"] and not salt.utils.platform.is_windows():
|
||||
log.info(
|
||||
"setting Publish daemon niceness to %i",
|
||||
self.opts["pub_server_niceness"],
|
||||
)
|
||||
os.nice(self.opts["pub_server_niceness"])
|
||||
self.payload_handler = payload_handler
|
||||
self.io_loop = io_loop
|
||||
self.transport.post_fork(self.handle_message, io_loop)
|
||||
import salt.master
|
||||
|
||||
self.crypticle = salt.crypt.Crypticle(
|
||||
self.opts, salt.master.SMaster.secrets["aes"]["secret"].value
|
||||
)
|
||||
|
||||
# other things needed for _auth
|
||||
# Create the event manager
|
||||
self.event = salt.utils.event.get_master_event(
|
||||
self.opts, self.opts["sock_dir"], listen=False
|
||||
)
|
||||
self.auto_key = salt.daemons.masterapi.AutoKey(self.opts)
|
||||
|
||||
# only create a con_cache-client if the con_cache is active
|
||||
if self.opts["con_cache"]:
|
||||
self.cache_cli = CacheCli(self.opts)
|
||||
|
@ -135,14 +131,13 @@ class ReqServerChannel:
|
|||
self.cache_cli = False
|
||||
# Make an minion checker object
|
||||
self.ckminions = salt.utils.minions.CkMinions(self.opts)
|
||||
|
||||
self.master_key = salt.crypt.MasterKeys(self.opts)
|
||||
self.payload_handler = payload_handler
|
||||
self.transport.post_fork(self.handle_message, io_loop)
|
||||
|
||||
@salt.ext.tornado.gen.coroutine
|
||||
def handle_message(self, stream, payload, header=None):
|
||||
stream = self.transport.wrap_stream(stream)
|
||||
def handle_message(self, payload, send_reply=None, header=None):
|
||||
try:
|
||||
payload = self.transport.decode_payload(payload)
|
||||
payload = self._decode_payload(payload)
|
||||
except Exception as exc: # pylint: disable=broad-except
|
||||
exc_type = type(exc).__name__
|
||||
|
@ -155,7 +150,7 @@ class ReqServerChannel:
|
|||
)
|
||||
else:
|
||||
log.error("Bad load from minion: %s: %s", exc_type, exc)
|
||||
yield stream.send("bad load", header)
|
||||
yield send_reply("bad load", header)
|
||||
raise salt.ext.tornado.gen.Return()
|
||||
|
||||
# TODO helper functions to normalize payload?
|
||||
|
@ -165,24 +160,24 @@ class ReqServerChannel:
|
|||
payload,
|
||||
payload.get("load"),
|
||||
)
|
||||
yield stream.send("payload and load must be a dict", header)
|
||||
yield send_reply("payload and load must be a dict", header)
|
||||
raise salt.ext.tornado.gen.Return()
|
||||
|
||||
try:
|
||||
id_ = payload["load"].get("id", "")
|
||||
if "\0" in id_:
|
||||
log.error("Payload contains an id with a null byte: %s", payload)
|
||||
stream.send("bad load: id contains a null byte", header)
|
||||
yield send_reply("bad load: id contains a null byte", header)
|
||||
raise salt.ext.tornado.gen.Return()
|
||||
except TypeError:
|
||||
log.error("Payload contains non-string id: %s", payload)
|
||||
stream.send("bad load: id {} is not a string".format(id_), header)
|
||||
yield send_reply("bad load: id {} is not a string".format(id_), header)
|
||||
raise salt.ext.tornado.gen.Return()
|
||||
|
||||
# intercept the "_auth" commands, since the main daemon shouldn't know
|
||||
# anything about our key auth
|
||||
if payload["enc"] == "clear" and payload.get("load", {}).get("cmd") == "_auth":
|
||||
stream.send(self._auth(payload["load"]), header)
|
||||
yield send_reply(self._auth(payload["load"]), header)
|
||||
raise salt.ext.tornado.gen.Return()
|
||||
|
||||
# TODO: test
|
||||
|
@ -192,17 +187,17 @@ class ReqServerChannel:
|
|||
ret, req_opts = yield self.payload_handler(payload)
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
# always attempt to return an error to the minion
|
||||
stream.send("Some exception handling minion payload", header)
|
||||
yield send_reply("Some exception handling minion payload", header)
|
||||
log.error("Some exception handling a payload from minion", exc_info=True)
|
||||
raise salt.ext.tornado.gen.Return()
|
||||
|
||||
req_fun = req_opts.get("fun", "send")
|
||||
if req_fun == "send_clear":
|
||||
stream.send(ret, header)
|
||||
yield send_reply(ret, header)
|
||||
elif req_fun == "send":
|
||||
stream.send(self.crypticle.dumps(ret), header)
|
||||
yield send_reply(self.crypticle.dumps(ret), header)
|
||||
elif req_fun == "send_private":
|
||||
stream.send(
|
||||
yield send_reply(
|
||||
self._encrypt_private(
|
||||
ret,
|
||||
req_opts["key"],
|
||||
|
@ -213,7 +208,7 @@ class ReqServerChannel:
|
|||
else:
|
||||
log.error("Unknown req_fun %s", req_fun)
|
||||
# always attempt to return an error to the minion
|
||||
stream.send("Server-side exception handling payload", header)
|
||||
yield send_reply("Server-side exception handling payload", header)
|
||||
raise salt.ext.tornado.gen.Return()
|
||||
|
||||
def _encrypt_private(self, ret, dictkey, target):
|
||||
|
|
|
@ -18,6 +18,7 @@ import time
|
|||
|
||||
import salt.acl
|
||||
import salt.auth
|
||||
import salt.channel.server
|
||||
import salt.client
|
||||
import salt.client.ssh.client
|
||||
import salt.crypt
|
||||
|
@ -34,7 +35,6 @@ import salt.pillar
|
|||
import salt.runner
|
||||
import salt.serializers.msgpack
|
||||
import salt.state
|
||||
import salt.transport.server
|
||||
import salt.utils.args
|
||||
import salt.utils.atomicfile
|
||||
import salt.utils.crypt
|
||||
|
@ -675,7 +675,7 @@ class Master(SMaster):
|
|||
log.info("Creating master publisher process")
|
||||
log_queue = salt.log.setup.get_multiprocessing_logging_queue()
|
||||
for _, opts in iter_transport_opts(self.opts):
|
||||
chan = salt.transport.server.PubServerChannel.factory(opts)
|
||||
chan = salt.channel.server.PubServerChannel.factory(opts)
|
||||
chan.pre_fork(self.process_manager, kwargs={"log_queue": log_queue})
|
||||
pub_channels.append(chan)
|
||||
|
||||
|
@ -856,7 +856,7 @@ class ReqServer(salt.utils.process.SignalHandlingProcess):
|
|||
req_channels = []
|
||||
tcp_only = True
|
||||
for transport, opts in iter_transport_opts(self.opts):
|
||||
chan = salt.transport.server.ReqServerChannel.factory(opts)
|
||||
chan = salt.channel.server.ReqServerChannel.factory(opts)
|
||||
chan.pre_fork(self.process_manager)
|
||||
req_channels.append(chan)
|
||||
if transport != "tcp":
|
||||
|
@ -2295,7 +2295,7 @@ class ClearFuncs(TransportMethods):
|
|||
Take a load and send it across the network to connected minions
|
||||
"""
|
||||
for transport, opts in iter_transport_opts(self.opts):
|
||||
chan = salt.transport.server.PubServerChannel.factory(opts)
|
||||
chan = salt.channel.server.PubServerChannel.factory(opts)
|
||||
chan.publish(load)
|
||||
|
||||
@property
|
||||
|
|
|
@ -357,6 +357,7 @@ class TCPReqServerChannel:
|
|||
|
||||
payload_handler: function to call with your payloads
|
||||
"""
|
||||
self.message_handler = message_handler
|
||||
|
||||
with salt.utils.asynchronous.current_ioloop(io_loop):
|
||||
if USE_LOAD_BALANCER:
|
||||
|
@ -375,13 +376,19 @@ class TCPReqServerChannel:
|
|||
(self.opts["interface"], int(self.opts["ret_port"]))
|
||||
)
|
||||
self.req_server = SaltMessageServer(
|
||||
message_handler,
|
||||
self.handle_message,
|
||||
ssl_options=self.opts.get("ssl"),
|
||||
io_loop=io_loop,
|
||||
)
|
||||
self.req_server.add_socket(self._socket)
|
||||
self._socket.listen(self.backlog)
|
||||
|
||||
@salt.ext.tornado.gen.coroutine
|
||||
def handle_message(self, stream, payload, header=None):
|
||||
stream = self.wrap_stream(stream)
|
||||
payload = self.decode_payload(payload)
|
||||
yield self.message_handler(payload, send_reply=stream.send, header=header)
|
||||
|
||||
def wrap_stream(self, stream):
|
||||
class Stream:
|
||||
def __init__(self, stream):
|
||||
|
|
|
@ -414,7 +414,14 @@ class ZeroMQReqServerChannel:
|
|||
log.info("Worker binding to socket %s", self.w_uri)
|
||||
self._socket.connect(self.w_uri)
|
||||
self.stream = zmq.eventloop.zmqstream.ZMQStream(self._socket, io_loop=io_loop)
|
||||
self.stream.on_recv_stream(message_handler)
|
||||
self.message_handler = message_handler
|
||||
self.stream.on_recv_stream(self.handle_message)
|
||||
|
||||
@salt.ext.tornado.gen.coroutine
|
||||
def handle_message(self, stream, payload, header=None):
|
||||
stream = self.wrap_stream(stream)
|
||||
payload = self.decode_payload(payload)
|
||||
self.message_handler(payload, send_reply=stream.send, header=header)
|
||||
|
||||
def __setup_signals(self):
|
||||
signal.signal(signal.SIGINT, self._handle_signals)
|
||||
|
@ -508,7 +515,7 @@ class AsyncReqMessageClient:
|
|||
# mapping of message -> future
|
||||
self.send_future_map = {}
|
||||
|
||||
self.send_timeout_map = {} # message -> timeout
|
||||
# self.send_timeout_map = {} # message -> timeout
|
||||
self._closing = False
|
||||
|
||||
# TODO: timeout all in-flight sessions, or error
|
||||
|
@ -565,7 +572,7 @@ class AsyncReqMessageClient:
|
|||
elif hasattr(zmq, "IPV4ONLY"):
|
||||
self.socket.setsockopt(zmq.IPV4ONLY, 0)
|
||||
self.socket.linger = self.linger
|
||||
log.debug("Trying to connect to: %s", self.addr)
|
||||
log.debug("**** Trying to connect to: %s", self.addr)
|
||||
self.socket.connect(self.addr)
|
||||
self.stream = zmq.eventloop.zmqstream.ZMQStream(
|
||||
self.socket, io_loop=self.io_loop
|
||||
|
@ -582,7 +589,6 @@ class AsyncReqMessageClient:
|
|||
# In a race condition the message might have been sent by the time
|
||||
# we're timing it out. Make sure the future is not None
|
||||
if future is not None:
|
||||
del self.send_timeout_map[message]
|
||||
if future.attempts < future.tries:
|
||||
future.attempts += 1
|
||||
log.debug(
|
||||
|
@ -726,8 +732,8 @@ class ZeroMQPubServerChannel:
|
|||
"""
|
||||
ioloop = salt.ext.tornado.ioloop.IOLoop()
|
||||
ioloop.make_current()
|
||||
|
||||
context = zmq.Context(1)
|
||||
self.io_loop = ioloop
|
||||
context = self.context = zmq.Context(1)
|
||||
pub_sock = context.socket(zmq.PUB)
|
||||
monitor = ZeroMQSocketMonitor(pub_sock)
|
||||
monitor.start_io_loop(ioloop)
|
||||
|
@ -770,7 +776,8 @@ class ZeroMQPubServerChannel:
|
|||
try:
|
||||
ioloop.start()
|
||||
finally:
|
||||
context.term()
|
||||
pub_sock.close()
|
||||
pull_sock.close()
|
||||
|
||||
@property
|
||||
def pull_uri(self):
|
||||
|
@ -894,6 +901,9 @@ class ZeroMQPubServerChannel:
|
|||
def topic_support(self):
|
||||
return self.opts.get("zmq_filtering", False)
|
||||
|
||||
def close(self):
|
||||
self.pub_close()
|
||||
|
||||
|
||||
class ZeroMQReqChannel:
|
||||
ttype = "zeromq"
|
||||
|
|
0
tests/pytests/functional/channel/__init__.py
Normal file
0
tests/pytests/functional/channel/__init__.py
Normal file
142
tests/pytests/functional/channel/test_server.py
Normal file
142
tests/pytests/functional/channel/test_server.py
Normal file
|
@ -0,0 +1,142 @@
|
|||
import ctypes
|
||||
import io
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from concurrent.futures.thread import ThreadPoolExecutor
|
||||
|
||||
import pytest
|
||||
import salt.channel.client
|
||||
import salt.channel.server
|
||||
import salt.config
|
||||
import salt.exceptions
|
||||
import salt.ext.tornado.gen
|
||||
import salt.ext.tornado.ioloop
|
||||
import salt.log.setup
|
||||
import salt.master
|
||||
import salt.transport.client
|
||||
import salt.transport.server
|
||||
import salt.transport.zeromq
|
||||
import salt.utils.platform
|
||||
import salt.utils.process
|
||||
import salt.utils.stringutils
|
||||
import zmq
|
||||
from saltfactories.utils.ports import get_unused_localhost_port
|
||||
from saltfactories.utils.processes import terminate_process
|
||||
from tests.support.mock import MagicMock, patch
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def master_config(tmp_path):
|
||||
master_conf = salt.config.master_config("")
|
||||
master_conf["id"] = "master"
|
||||
master_conf["root_dir"] = str(tmp_path)
|
||||
master_conf["sock_dir"] = str(tmp_path)
|
||||
master_conf["ret_port"] = get_unused_localhost_port()
|
||||
master_conf["master_uri"] = "tcp://127.0.0.1:{}".format(master_conf["ret_port"])
|
||||
master_conf["pki_dir"] = str(tmp_path / "pki")
|
||||
os.makedirs(master_conf["pki_dir"])
|
||||
salt.crypt.gen_keys(master_conf["pki_dir"], "master", 4096)
|
||||
minions_keys = os.path.join(master_conf["pki_dir"], "minions")
|
||||
os.makedirs(minions_keys)
|
||||
yield master_conf
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def configs(master_config):
|
||||
minion_conf = salt.config.minion_config("")
|
||||
minion_conf["root_dir"] = master_config["root_dir"]
|
||||
minion_conf["id"] = "minion"
|
||||
minion_conf["sock_dir"] = master_config["sock_dir"]
|
||||
minion_conf["pki_dir"] = os.path.join(master_config["root_dir"], "pki_minion")
|
||||
os.makedirs(minion_conf["pki_dir"])
|
||||
minion_conf["master_port"] = master_config["ret_port"]
|
||||
minion_conf["master_ip"] = "127.0.0.1"
|
||||
minion_conf["master_uri"] = "tcp://127.0.0.1:{}".format(master_config["ret_port"])
|
||||
salt.crypt.gen_keys(minion_conf["pki_dir"], "minion", 4096)
|
||||
minion_pub = os.path.join(minion_conf["pki_dir"], "minion.pub")
|
||||
pub_on_master = os.path.join(master_config["pki_dir"], "minions", "minion")
|
||||
with io.open(minion_pub, "r") as rfp:
|
||||
with io.open(pub_on_master, "w") as wfp:
|
||||
wfp.write(rfp.read())
|
||||
return (minion_conf, master_config)
|
||||
|
||||
|
||||
def test_pub_server_channel_with_zmq_transport(io_loop, configs):
|
||||
minion_conf, master_conf = configs
|
||||
|
||||
process_manager = salt.utils.process.ProcessManager()
|
||||
server_channel = salt.transport.server.PubServerChannel.factory(
|
||||
master_conf,
|
||||
)
|
||||
server_channel.pre_fork(process_manager)
|
||||
req_server_channel = salt.transport.server.ReqServerChannel.factory(master_conf)
|
||||
req_server_channel.pre_fork(process_manager)
|
||||
|
||||
def handle_payload(payload):
|
||||
log.info("TEST - Req Server handle payload {}".format(repr(payload)))
|
||||
|
||||
req_server_channel.post_fork(handle_payload, io_loop=io_loop)
|
||||
|
||||
pub_channel = salt.transport.client.AsyncPubChannel.factory(minion_conf)
|
||||
|
||||
@salt.ext.tornado.gen.coroutine
|
||||
def doit(channel, server):
|
||||
log.info("TEST - BEFORE CHANNEL CONNECT")
|
||||
yield channel.connect()
|
||||
log.info("TEST - AFTER CHANNEL CONNECT")
|
||||
|
||||
def cb(payload):
|
||||
log.info("TEST - PUB SERVER MSG {}".format(repr(payload)))
|
||||
io_loop.stop()
|
||||
|
||||
channel.on_recv(cb)
|
||||
server.publish({"tgt_type": "glob", "tgt": ["carbon"], "WTF": "SON"})
|
||||
|
||||
io_loop.add_callback(doit, pub_channel, server_channel)
|
||||
io_loop.start()
|
||||
# server_channel.transport.stop()
|
||||
process_manager.terminate()
|
||||
|
||||
|
||||
def test_pub_server_channel_with_tcp_transport(io_loop, configs):
|
||||
minion_conf, master_conf = configs
|
||||
minion_conf["transport"] = "tcp"
|
||||
master_conf["transport"] = "tcp"
|
||||
|
||||
process_manager = salt.utils.process.ProcessManager()
|
||||
server_channel = salt.transport.server.PubServerChannel.factory(
|
||||
master_conf,
|
||||
)
|
||||
server_channel.pre_fork(process_manager)
|
||||
req_server_channel = salt.transport.server.ReqServerChannel.factory(master_conf)
|
||||
req_server_channel.pre_fork(process_manager)
|
||||
|
||||
def handle_payload(payload):
|
||||
log.info("TEST - Req Server handle payload {}".format(repr(payload)))
|
||||
|
||||
req_server_channel.post_fork(handle_payload, io_loop=io_loop)
|
||||
|
||||
pub_channel = salt.transport.client.AsyncPubChannel.factory(minion_conf)
|
||||
|
||||
@salt.ext.tornado.gen.coroutine
|
||||
def doit(channel, server):
|
||||
log.info("TEST - BEFORE CHANNEL CONNECT")
|
||||
yield channel.connect()
|
||||
log.info("TEST - AFTER CHANNEL CONNECT")
|
||||
|
||||
def cb(payload):
|
||||
log.info("TEST - PUB SERVER MSG {}".format(repr(payload)))
|
||||
io_loop.stop()
|
||||
|
||||
channel.on_recv(cb)
|
||||
server.publish({"tgt_type": "glob", "tgt": ["carbon"], "WTF": "SON"})
|
||||
|
||||
io_loop.add_callback(doit, pub_channel, server_channel)
|
||||
io_loop.start()
|
||||
# server_channel.transport.stop()
|
||||
process_manager.terminate()
|
Loading…
Add table
Reference in a new issue