Initial pass at consolidating ipc and tcp/zeromq

This commit is contained in:
Daniel A. Wozniak 2023-06-15 17:58:47 -07:00 committed by Gareth J. Greenaway
parent 7ec90b0fcd
commit e102f8f11e
7 changed files with 149 additions and 117 deletions

View file

@ -469,7 +469,7 @@ class AsyncPubChannel:
@tornado.gen.coroutine
def wrap_callback(messages):
payload = yield self.transport._decode_messages(messages)
payload = self.transport._decode_messages(messages)
decoded = yield self._decode_payload(payload)
log.debug("PubChannel received: %r", decoded)
if decoded is not None:

View file

@ -253,7 +253,6 @@ class TCPPubClient(salt.transport.base.PublishClient):
yield self.message_client.connect() # wait for the client to be connected
self.connected = True
@tornado.gen.coroutine
def _decode_messages(self, messages):
if not isinstance(messages, dict):
# TODO: For some reason we need to decode here for things
@ -262,7 +261,7 @@ class TCPPubClient(salt.transport.base.PublishClient):
body = salt.transport.frame.decode_embedded_strs(body)
else:
body = messages
raise tornado.gen.Return(body)
return body
@tornado.gen.coroutine
def send(self, msg):

View file

@ -1,6 +1,7 @@
"""
Zeromq transport classes
"""
import asyncio
import errno
import hashlib
import logging
@ -15,6 +16,7 @@ import tornado.concurrent
import tornado.gen
import tornado.ioloop
import zmq.error
import zmq.asyncio
import zmq.eventloop.zmqstream
import salt.payload
@ -108,13 +110,15 @@ class PublishClient(salt.transport.base.PublishClient):
def __init__(self, opts, io_loop, **kwargs):
super().__init__(opts, io_loop, **kwargs)
self.callbacks = {}
self.opts = opts
self.io_loop = io_loop
self.hexid = hashlib.sha1(
salt.utils.stringutils.to_bytes(self.opts["id"])
).hexdigest()
self._closing = False
self.context = zmq.Context()
import zmq.asyncio
self.context = zmq.asyncio.Context()
self._socket = self.context.socket(zmq.SUB)
if self.opts["zmq_filtering"]:
# TODO: constants file for "broadcast"
@ -177,9 +181,11 @@ class PublishClient(salt.transport.base.PublishClient):
# IPv6 sockets work for both IPv6 and IPv4 addresses
self._socket.setsockopt(zmq.IPV4ONLY, 0)
if HAS_ZMQ_MONITOR and self.opts["zmq_monitor"]:
self._monitor = ZeroMQSocketMonitor(self._socket)
self._monitor.start_io_loop(self.io_loop)
# if HAS_ZMQ_MONITOR and self.opts["zmq_monitor"]:
# self._monitor = ZeroMQSocketMonitor(self._socket)
# self._monitor.start_io_loop(self.io_loop)
self._monitor = None
self.task = None
def close(self):
if self._closing is True:
@ -203,16 +209,16 @@ class PublishClient(salt.transport.base.PublishClient):
self.close()
# TODO: this is the time to see if we are connected, maybe use the req channel to guess?
@tornado.gen.coroutine
def connect(self, publish_port, connect_callback=None, disconnect_callback=None):
async def connect(
self, publish_port, connect_callback=None, disconnect_callback=None
):
self.publish_port = publish_port
log.debug(
"Connecting the Minion to the Master publish port, using the URI: %s",
self.master_pub,
)
log.debug("%r connecting to %s", self, self.master_pub)
self._socket.connect(self.master_pub)
connect_callback(True)
# await connect_callback(True)
@property
def master_pub(self):
@ -226,7 +232,6 @@ class PublishClient(salt.transport.base.PublishClient):
source_port=self.opts.get("source_publish_port"),
)
@tornado.gen.coroutine
def _decode_messages(self, messages):
"""
Take the zmq messages, decrypt/decode them into a payload
@ -248,7 +253,7 @@ class PublishClient(salt.transport.base.PublishClient):
and message_target not in ("broadcast", "syndic")
):
log.debug("Publish received for not this minion: %s", message_target)
raise tornado.gen.Return(None)
return None
payload = salt.payload.loads(messages[1])
else:
raise Exception(
@ -258,18 +263,7 @@ class PublishClient(salt.transport.base.PublishClient):
)
# Hand the decoded payload back to the caller.
raise tornado.gen.Return(payload)
@property
def stream(self):
"""
Return the current zmqstream, creating one if necessary
"""
if not hasattr(self, "_stream"):
self._stream = zmq.eventloop.zmqstream.ZMQStream(
self._socket, io_loop=self.io_loop
)
return self._stream
return payload
def on_recv(self, callback):
"""
@ -277,11 +271,36 @@ class PublishClient(salt.transport.base.PublishClient):
:param func callback: A function which should be called when data is received
"""
return self.stream.on_recv(callback)
running = asyncio.Event()
running.set()
async def recv(self, timeout=None):
    """
    Wait for the next message on this client's socket and return it.

    :param timeout: Currently unused by this implementation; the call waits
                    until ``self._socket.recv()`` completes.
    :return: The value produced by ``self._socket.recv()``.
    """
    message = await self._socket.recv()
    return message
@tornado.gen.coroutine
def send(self, msg):
self.stream.send(msg, noblock=True)
async def consume(running):
while running.is_set():
try:
msg = await self._socket.recv_multipart()
except zmq.error.ZMQError:
# We've disconnected just die
break
except Exception: # pylint: disable=broad-except
log.error("Exception while reading", exc_info=True)
break
try:
await callback(msg)
except Exception: # pylint: disable=broad-except
log.error("Exception while running callback", exc_info=True)
log.debug("Callback done %r", callback)
task = self.io_loop.create_task(consume(running))
self.callbacks[callback] = running, task
async def send(self, msg):
    """
    Send ``msg`` on this client's socket.

    :param msg: The data handed unchanged to ``self._socket.send()``.
    :return: None
    """
    sock = self._socket
    await sock.send(msg)
class RequestServer(salt.transport.base.DaemonizedRequestServer):
@ -290,6 +309,8 @@ class RequestServer(salt.transport.base.DaemonizedRequestServer):
self._closing = False
self._monitor = None
self._w_monitor = None
self.task = None
self._event = asyncio.Event()
def zmq_device(self):
"""
@ -354,6 +375,7 @@ class RequestServer(salt.transport.base.DaemonizedRequestServer):
return
log.info("MWorkerQueue under PID %s is closing", os.getpid())
self._closing = True
self._event.set()
if getattr(self, "_monitor", None) is not None:
self._monitor.stop()
self._monitor = None
@ -402,7 +424,8 @@ class RequestServer(salt.transport.base.DaemonizedRequestServer):
they are picked up off the wire
:param IOLoop io_loop: An instance of a Tornado IOLoop, to handle event scheduling
"""
context = zmq.Context(1)
#context = zmq.Context(1)
context = zmq.asyncio.Context()
self._socket = context.socket(zmq.REP)
# Linger -1 means we'll never discard messages.
self._socket.setsockopt(zmq.LINGER, -1)
@ -422,16 +445,23 @@ class RequestServer(salt.transport.base.DaemonizedRequestServer):
os.path.join(self.opts["sock_dir"], "workers.ipc")
):
os.chmod(os.path.join(self.opts["sock_dir"], "workers.ipc"), 0o600)
self.stream = zmq.eventloop.zmqstream.ZMQStream(self._socket, io_loop=io_loop)
self.message_handler = message_handler
self.stream.on_recv_stream(self.handle_message)
async def callback():
self.task = asyncio.create_task(self.request_handler())
await self.task
io_loop.add_callback(callback)
@tornado.gen.coroutine
def handle_message(self, stream, payload):
async def request_handler(self):
    """
    Serve requests one at a time until ``self._event`` is set.

    Each iteration receives one request from ``self._socket``, passes it to
    ``self.handle_message`` (with ``None`` in place of a stream), encodes the
    result with ``self.encode_payload`` and sends it back on the same socket.
    The event is only checked between iterations, so a receive that is
    already pending is not interrupted by setting it.
    """
    while True:
        if self._event.is_set():
            break
        incoming = await self._socket.recv()
        outgoing = await self.handle_message(None, incoming)
        await self._socket.send(self.encode_payload(outgoing))
async def handle_message(self, stream, payload):
payload = self.decode_payload(payload)
# XXX: Is header really needed?
reply = yield self.message_handler(payload)
self.stream.send(self.encode_payload(reply))
return await self.message_handler(payload)
def encode_payload(self, payload):
    """
    Serialize ``payload`` with ``salt.payload.dumps`` and return the result.
    """
    encoded = salt.payload.dumps(payload)
    return encoded
@ -452,7 +482,7 @@ class RequestServer(salt.transport.base.DaemonizedRequestServer):
sys.exit(salt.defaults.exitcodes.EX_OK)
def decode_payload(self, payload):
payload = salt.payload.loads(payload[0])
payload = salt.payload.loads(payload)
return payload
@ -508,7 +538,7 @@ class AsyncReqMessageClient:
else:
self.io_loop = io_loop
self.context = zmq.Context()
self.context = zmq.asyncio.Context()
self.send_queue = []
# mapping of message -> future
@ -531,28 +561,14 @@ class AsyncReqMessageClient:
return
else:
self._closing = True
if hasattr(self, "stream") and self.stream is not None:
if ZMQ_VERSION_INFO < (14, 3, 0):
# stream.close() doesn't work properly on pyzmq < 14.3.0
if self.stream.socket:
self.stream.socket.close()
self.stream.io_loop.remove_handler(self.stream.socket)
# set this to None, more hacks for messed up pyzmq
self.stream.socket = None
self.socket.close()
else:
self.stream.close(1)
self.socket = None
self.stream = None
self.socket.close()
if self.context.closed is False:
# This hangs if closing the stream causes an import error
self.context.term()
def _init_socket(self):
if hasattr(self, "stream"):
self.stream.close() # pylint: disable=E0203
if hasattr(self, "socket"):
self.socket.close() # pylint: disable=E0203
del self.stream
del self.socket
self.socket = self.context.socket(zmq.REQ)
@ -570,9 +586,6 @@ class AsyncReqMessageClient:
self.socket.setsockopt(zmq.IPV4ONLY, 0)
self.socket.linger = self.linger
self.socket.connect(self.addr)
self.stream = zmq.eventloop.zmqstream.ZMQStream(
self.socket, io_loop=self.io_loop
)
def timeout_message(self, message):
"""
@ -587,44 +600,22 @@ class AsyncReqMessageClient:
if future is not None:
future.set_exception(SaltReqTimeoutError("Message timed out"))
@tornado.gen.coroutine
def send(self, message, timeout=None, callback=None):
async def _send_recv(self, message):
    """
    Perform one request/response round trip on ``self.socket``.

    ``message`` is serialized with ``salt.payload.dumps`` and sent; the next
    message received on the socket is deserialized with
    ``salt.payload.loads`` and returned.
    """
    await self.socket.send(salt.payload.dumps(message))
    return salt.payload.loads(await self.socket.recv())
async def send(self, message, timeout=None, callback=None):
"""
Return a future which will be completed when the message has a response
"""
future = tornado.concurrent.Future()
message = salt.payload.dumps(message)
if callback is not None:
def handle_future(future):
response = future.result()
self.io_loop.add_callback(callback, response)
future.add_done_callback(handle_future)
# Add this future to the mapping
self.send_future_map[message] = future
if self.opts.get("detect_mode") is True:
timeout = 1
if timeout is not None:
send_timeout = self.io_loop.call_later(
timeout, self.timeout_message, message
)
def mark_future(msg):
if not future.done():
data = salt.payload.loads(msg[0])
future.set_result(data)
self.send_future_map.pop(message)
self.stream.on_recv(mark_future)
yield self.stream.send(message)
recv = yield future
raise tornado.gen.Return(recv)
response = await asyncio.wait_for(self._send_recv(message), timeout=timeout)
if callback:
callback(response)
return response
class ZeroMQSocketMonitor:
@ -643,6 +634,7 @@ class ZeroMQSocketMonitor:
def start_io_loop(self, io_loop):
log.trace("Event monitor start!")
return
self._monitor_stream = zmq.eventloop.zmqstream.ZMQStream(
self._monitor_socket, io_loop=io_loop
)
@ -712,14 +704,14 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer):
This method represents the Publish Daemon process. It is intended to be
run in a thread or process as it creates and runs its own ioloop.
"""
ioloop = tornado.ioloop.IOLoop()
ioloop = tornado.ioloop.IOLoop.current()
self.io_loop = ioloop
context = zmq.Context(1)
context = zmq.asyncio.Context()
pub_sock = context.socket(zmq.PUB)
monitor = ZeroMQSocketMonitor(pub_sock)
monitor.start_io_loop(ioloop)
_set_tcp_keepalive(pub_sock, self.opts)
self.dpub_sock = pub_sock = zmq.eventloop.zmqstream.ZMQStream(pub_sock)
self.dpub_sock = pub_sock #= zmq.eventloop.zmqstream.ZMQStream(pub_sock)
# if 2.1 >= zmq < 3.0, we only have one HWM setting
try:
pub_sock.setsockopt(zmq.HWM, self.opts.get("pub_hwm", 1000))
@ -736,7 +728,8 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer):
pub_sock.setsockopt(zmq.LINGER, -1)
# Prepare minion pull socket
pull_sock = context.socket(zmq.PULL)
pull_sock = zmq.eventloop.zmqstream.ZMQStream(pull_sock)
pull_sock.setsockopt(zmq.LINGER, -1)
#pull_sock = zmq.eventloop.zmqstream.ZMQStream(pull_sock)
pull_sock.setsockopt(zmq.LINGER, -1)
salt.utils.zeromq.check_ipc_path_max_len(self.pull_uri)
# Start the minion command publisher
@ -747,18 +740,22 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer):
with salt.utils.files.set_umask(0o177):
pull_sock.bind(self.pull_uri)
@tornado.gen.coroutine
def on_recv(packages):
async def on_recv(packages):
for package in packages:
payload = salt.payload.loads(package)
yield publish_payload(payload)
await publish_payload(payload)
pull_sock.on_recv(on_recv)
self.task = None
async def callback():
self.task = asyncio.create_task(self.publisher(pull_sock, publish_payload))
ioloop.add_callback(callback)
try:
ioloop.start()
finally:
pub_sock.close()
pull_sock.close()
if self.task:
self.task.cancel()
@property
def pull_uri(self):
@ -779,6 +776,34 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer):
@tornado.gen.coroutine
def publish_payload(self, payload, topic_list=None):
payload = salt.payload.dumps(payload)
async def publisher(self, pull_sock, publish_payload):
    """
    Forward every package received on ``pull_sock`` to ``publish_payload``.

    Loops indefinitely. Each received package is deserialized with
    ``salt.payload.loads`` and awaited through ``publish_payload``. Any
    ``Exception`` raised while receiving, decoding or publishing is logged
    and the loop carries on with the next package.

    :param pull_sock: Object whose awaitable ``recv()`` yields packages.
    :param publish_payload: Awaitable callable invoked with each payload.
    """
    while True:
        try:
            raw = await pull_sock.recv()
            await publish_payload(salt.payload.loads(raw))
        except Exception as exc:  # pylint: disable=broad-except
            log.error(
                "Exception in publisher %s %s", self.pull_uri, exc, exc_info=True
            )
@property
def pull_uri(self):
    """
    URI of the pull socket, chosen from ``self.opts``.

    When ``ipc_mode`` is ``"tcp"`` this is a loopback TCP address using the
    ``tcp_master_publish_pull`` option (4514 when unset); otherwise it is an
    ``ipc://`` path to ``publish_pull.ipc`` inside ``sock_dir``.
    """
    if self.opts.get("ipc_mode", "") != "tcp":
        ipc_path = os.path.join(self.opts["sock_dir"], "publish_pull.ipc")
        return "ipc://{}".format(ipc_path)
    port = self.opts.get("tcp_master_publish_pull", 4514)
    return "tcp://127.0.0.1:{}".format(port)
@property
def pub_uri(self):
    """
    TCP URI built from the ``interface`` and ``publish_port`` options.
    """
    template = "tcp://{interface}:{publish_port}"
    return template.format(**self.opts)
async def publish_payload(self, payload, topic_list=None):
payload = salt.payload.dumps(payload)
if self.opts["zmq_filtering"]:
if topic_list:
for topic in topic_list:
@ -788,25 +813,25 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer):
htopic = salt.utils.stringutils.to_bytes(
hashlib.sha1(salt.utils.stringutils.to_bytes(topic)).hexdigest()
)
yield self.dpub_sock.send(htopic, flags=zmq.SNDMORE)
yield self.dpub_sock.send(payload)
await self.dpub_sock.send(htopic, flags=zmq.SNDMORE)
await self.dpub_sock.send(payload)
log.trace("Filtered data has been sent")
# Syndic broadcast
if self.opts.get("order_masters"):
log.trace("Sending filtered data to syndic")
yield self.dpub_sock.send(b"syndic", flags=zmq.SNDMORE)
yield self.dpub_sock.send(payload)
await self.dpub_sock.send(b"syndic", flags=zmq.SNDMORE)
await self.dpub_sock.send(payload)
log.trace("Filtered data has been sent to syndic")
# otherwise its a broadcast
else:
# TODO: constants file for "broadcast"
log.trace("Sending broadcasted data over publisher %s", self.pub_uri)
yield self.dpub_sock.send(b"broadcast", flags=zmq.SNDMORE)
yield self.dpub_sock.send(payload)
await self.dpub_sock.send(b"broadcast", flags=zmq.SNDMORE)
await self.dpub_sock.send(payload)
log.trace("Broadcasted data has been sent")
else:
log.trace("Sending ZMQ-unfiltered data over publisher %s", self.pub_uri)
yield self.dpub_sock.send(payload)
await self.dpub_sock.send(payload)
log.trace("Unfiltered data has been sent")
def pre_fork(self, process_manager):
@ -909,11 +934,9 @@ class RequestClient(salt.transport.base.RequestClient):
def connect(self):
    """
    Connect the underlying message client by calling its ``connect()``.
    """
    client = self.message_client
    client.connect()
@tornado.gen.coroutine
def send(self, load, timeout=60):
async def send(self, load, timeout=60):
self.connect()
ret = yield self.message_client.send(load, timeout=timeout)
raise tornado.gen.Return(ret)
return await self.message_client.send(load, timeout=timeout)
def close(self):
self.message_client.close()

View file

@ -416,6 +416,12 @@ class SaltEvent:
kwargs={"io_loop": self.io_loop},
loop_kwarg="io_loop",
)
#self.pusher = salt.utils.asynchronous.SyncWrapper(
# salt.transport.ipc.IPCMessageClient,
# args=(self.pulluri,),
# kwargs={"io_loop": self.io_loop},
# loop_kwarg="io_loop",
#)
try:
self.pusher.connect(timeout=timeout)
self.cpush = True
@ -770,9 +776,10 @@ class SaltEvent:
]
)
msg = salt.utils.stringutils.to_bytes(event, "utf-8")
ret = yield self.pusher.send(msg)
if cb is not None:
cb(ret)
self.pusher.publish(msg, noserial=True)
#ret = yield self.pusher.send(msg)
#if cb is not None:
# cb(ret)
def fire_event(self, data, tag, timeout=1000):
"""
@ -826,6 +833,7 @@ class SaltEvent:
with salt.utils.asynchronous.current_ioloop(self.io_loop):
try:
self.pusher.send(msg)
#self.pusher.send(msg)
except Exception as exc: # pylint: disable=broad-except
log.debug(
"Publisher send failed with exception: %s",

View file

@ -19,7 +19,7 @@ pytestmark = [
]
@pytest.fixture(scope="module", params=["tcp", "zeromq"])
@pytest.fixture(scope="module", params=["zeromq", "tcp"])
def transport(request):
yield request.param

View file

@ -381,7 +381,6 @@ class MockSaltMinionMaster:
master_opts.update({"transport": "zeromq"})
self.server_channel = salt.channel.server.ReqServerChannel.factory(master_opts)
self.server_channel.pre_fork(self.process_manager)
self.io_loop = tornado.ioloop.IOLoop()
self.evt = threading.Event()
self.server_channel.post_fork(self._handle_payload, io_loop=self.io_loop)
@ -466,7 +465,7 @@ def test_serverside_exception(temp_salt_minion, temp_salt_master):
assert ret == "Server-side exception handling payload"
def test_zeromq_async_pub_channel_publish_port(temp_salt_master):
async def test_zeromq_async_pub_channel_publish_port(temp_salt_master):
"""
test when connecting that we use the publish_port set in opts when its not 4506
"""
@ -490,7 +489,7 @@ def test_zeromq_async_pub_channel_publish_port(temp_salt_master):
patch_socket = MagicMock(return_value=True)
patch_auth = MagicMock(return_value=True)
with patch.object(transport, "_socket", patch_socket):
transport.connect(455505)
await transport.connect(455505)
assert str(opts["publish_port"]) in patch_socket.mock_calls[0][1][0]
@ -534,7 +533,7 @@ def test_zeromq_async_pub_channel_filtering_decode_message_no_match(
MagicMock(return_value={"tgt_type": "glob", "tgt": "*", "jid": 1}),
):
res = channel._decode_messages(message)
assert res.result() is None
assert res is None
def test_zeromq_async_pub_channel_filtering_decode_message(
@ -582,7 +581,7 @@ def test_zeromq_async_pub_channel_filtering_decode_message(
) as mock_test:
res = channel._decode_messages(message)
assert res.result()["enc"] == "aes"
assert res["enc"] == "aes"
def test_req_server_chan_encrypt_v2(pki_dir):

View file

@ -125,6 +125,7 @@ class Collector(salt.utils.process.SignalHandlingProcess):
return
self.started.set()
last_msg = time.time()
self.start = last_msg
serial = salt.payload.Serial(self.minion_config)
crypticle = salt.crypt.Crypticle(self.minion_config, self.aes_key)
while True:
@ -158,6 +159,8 @@ class Collector(salt.utils.process.SignalHandlingProcess):
if not self.zmq_filtering:
log.exception("Failed to deserialize...")
break
self.end = time.time()
print(f"Total time {self.end - self.start}")
loop.stop()
def run(self):