From e102f8f11e1849df6f5533571230ba67327d7c03 Mon Sep 17 00:00:00 2001 From: "Daniel A. Wozniak" Date: Thu, 15 Jun 2023 17:58:47 -0700 Subject: [PATCH] Initial pass at consolidating ipc and tcp/zeromq --- salt/channel/client.py | 2 +- salt/transport/tcp.py | 3 +- salt/transport/zeromq.py | 233 ++++++++++-------- salt/utils/event.py | 14 +- .../transport/ipc/test_pub_server_channel.py | 2 +- tests/pytests/unit/transport/test_zeromq.py | 9 +- tests/support/pytest/transport.py | 3 + 7 files changed, 149 insertions(+), 117 deletions(-) diff --git a/salt/channel/client.py b/salt/channel/client.py index 079ed296dd9..18999c78e03 100644 --- a/salt/channel/client.py +++ b/salt/channel/client.py @@ -469,7 +469,7 @@ class AsyncPubChannel: @tornado.gen.coroutine def wrap_callback(messages): - payload = yield self.transport._decode_messages(messages) + payload = self.transport._decode_messages(messages) decoded = yield self._decode_payload(payload) log.debug("PubChannel received: %r", decoded) if decoded is not None: diff --git a/salt/transport/tcp.py b/salt/transport/tcp.py index 835f9db2d5b..09303fbdeae 100644 --- a/salt/transport/tcp.py +++ b/salt/transport/tcp.py @@ -253,7 +253,6 @@ class TCPPubClient(salt.transport.base.PublishClient): yield self.message_client.connect() # wait for the client to be connected self.connected = True - @tornado.gen.coroutine def _decode_messages(self, messages): if not isinstance(messages, dict): # TODO: For some reason we need to decode here for things @@ -262,7 +261,7 @@ class TCPPubClient(salt.transport.base.PublishClient): body = salt.transport.frame.decode_embedded_strs(body) else: body = messages - raise tornado.gen.Return(body) + return body @tornado.gen.coroutine def send(self, msg): diff --git a/salt/transport/zeromq.py b/salt/transport/zeromq.py index 50c2a66deaa..a5e1a6876ca 100644 --- a/salt/transport/zeromq.py +++ b/salt/transport/zeromq.py @@ -1,6 +1,7 @@ """ Zeromq transport classes """ +import asyncio import errno import hashlib import logging @@ -15,6 +16,7 @@ import tornado.concurrent import tornado.gen import tornado.ioloop import zmq.error +import zmq.asyncio import zmq.eventloop.zmqstream import salt.payload @@ -108,13 +110,15 @@ class PublishClient(salt.transport.base.PublishClient): def __init__(self, opts, io_loop, **kwargs): super().__init__(opts, io_loop, **kwargs) + self.callbacks = {} self.opts = opts self.io_loop = io_loop self.hexid = hashlib.sha1( salt.utils.stringutils.to_bytes(self.opts["id"]) ).hexdigest() self._closing = False - self.context = zmq.Context() + import zmq.asyncio + self.context = zmq.asyncio.Context() self._socket = self.context.socket(zmq.SUB) if self.opts["zmq_filtering"]: # TODO: constants file for "broadcast" @@ -177,9 +181,11 @@ class PublishClient(salt.transport.base.PublishClient): # IPv6 sockets work for both IPv6 and IPv4 addresses self._socket.setsockopt(zmq.IPV4ONLY, 0) - if HAS_ZMQ_MONITOR and self.opts["zmq_monitor"]: - self._monitor = ZeroMQSocketMonitor(self._socket) - self._monitor.start_io_loop(self.io_loop) + # if HAS_ZMQ_MONITOR and self.opts["zmq_monitor"]: + # self._monitor = ZeroMQSocketMonitor(self._socket) + # self._monitor.start_io_loop(self.io_loop) + self._monitor = None + self.task = None def close(self): if self._closing is True: @@ -203,16 +209,16 @@ class PublishClient(salt.transport.base.PublishClient): self.close() # TODO: this is the time to see if we are connected, maybe use the req channel to guess? - @tornado.gen.coroutine - def connect(self, publish_port, connect_callback=None, disconnect_callback=None): + async def connect( + self, publish_port, connect_callback=None, disconnect_callback=None + ): self.publish_port = publish_port log.debug( "Connecting the Minion to the Master publish port, using the URI: %s", self.master_pub, ) - log.debug("%r connecting to %s", self, self.master_pub) self._socket.connect(self.master_pub) - connect_callback(True) + # await connect_callback(True) @property def master_pub(self): @@ -226,7 +232,6 @@ class PublishClient(salt.transport.base.PublishClient): source_port=self.opts.get("source_publish_port"), ) - @tornado.gen.coroutine def _decode_messages(self, messages): """ Take the zmq messages, decrypt/decode them into a payload @@ -248,7 +253,7 @@ class PublishClient(salt.transport.base.PublishClient): and message_target not in ("broadcast", "syndic") ): log.debug("Publish received for not this minion: %s", message_target) - raise tornado.gen.Return(None) + return None payload = salt.payload.loads(messages[1]) else: raise Exception( @@ -258,18 +263,7 @@ class PublishClient(salt.transport.base.PublishClient): ) # Yield control back to the caller. When the payload has been decoded, assign # the decoded payload to 'ret' and resume operation - raise tornado.gen.Return(payload) - - @property - def stream(self): - """ - Return the current zmqstream, creating one if necessary - """ - if not hasattr(self, "_stream"): - self._stream = zmq.eventloop.zmqstream.ZMQStream( - self._socket, io_loop=self.io_loop - ) - return self._stream + return payload def on_recv(self, callback): """ @@ -277,11 +271,36 @@ class PublishClient(salt.transport.base.PublishClient): :param func callback: A function which should be called when data is received """ - return self.stream.on_recv(callback) + running = asyncio.Event() + running.set() + + async def recv(self, timeout=None): + return await self._socket.recv() @tornado.gen.coroutine def send(self, msg): self.stream.send(msg, noblock=True) + async def consume(running): + while running.is_set(): + try: + msg = await self._socket.recv_multipart() + except zmq.error.ZMQError: + # We've disconnected just die + break + except Exception: # pylint: disable=broad-except + log.error("Exception while reading", exc_info=True) + break + try: + await callback(msg) + except Exception: # pylint: disable=broad-except + log.error("Exception while running callback", exc_info=True) + log.debug("Callback done %r", callback) + + task = self.io_loop.create_task(consume(running)) + self.callbacks[callback] = running, task + + async def send(self, msg): + await self._socket.send(msg) class RequestServer(salt.transport.base.DaemonizedRequestServer): @@ -290,6 +309,8 @@ class RequestServer(salt.transport.base.DaemonizedRequestServer): self._closing = False self._monitor = None self._w_monitor = None + self.task = None + self._event = asyncio.Event() def zmq_device(self): """ @@ -354,6 +375,7 @@ class RequestServer(salt.transport.base.DaemonizedRequestServer): return log.info("MWorkerQueue under PID %s is closing", os.getpid()) self._closing = True + self._event.set() if getattr(self, "_monitor", None) is not None: self._monitor.stop() self._monitor = None @@ -402,7 +424,8 @@ class RequestServer(salt.transport.base.DaemonizedRequestServer): they are picked up off the wire :param IOLoop io_loop: An instance of a Tornado IOLoop, to handle event scheduling """ - context = zmq.Context(1) + #context = zmq.Context(1) + context = zmq.asyncio.Context() self._socket = context.socket(zmq.REP) # Linger -1 means we'll never discard messages. self._socket.setsockopt(zmq.LINGER, -1) @@ -422,16 +445,23 @@ class RequestServer(salt.transport.base.DaemonizedRequestServer): os.path.join(self.opts["sock_dir"], "workers.ipc") ): os.chmod(os.path.join(self.opts["sock_dir"], "workers.ipc"), 0o600) - self.stream = zmq.eventloop.zmqstream.ZMQStream(self._socket, io_loop=io_loop) self.message_handler = message_handler - self.stream.on_recv_stream(self.handle_message) + async def callback(): + self.task = asyncio.create_task(self.request_handler()) + await self.task + io_loop.add_callback(callback) - @tornado.gen.coroutine - def handle_message(self, stream, payload): + async def request_handler(self): + while not self._event.is_set(): + request = await self._socket.recv() + reply = await self.handle_message(None, request) + await self._socket.send(self.encode_payload(reply)) + + + async def handle_message(self, stream, payload): payload = self.decode_payload(payload) - # XXX: Is header really needed? - reply = yield self.message_handler(payload) - self.stream.send(self.encode_payload(reply)) + return await self.message_handler(payload) + def encode_payload(self, payload): return salt.payload.dumps(payload) @@ -452,7 +482,7 @@ class RequestServer(salt.transport.base.DaemonizedRequestServer): sys.exit(salt.defaults.exitcodes.EX_OK) def decode_payload(self, payload): - payload = salt.payload.loads(payload[0]) + payload = salt.payload.loads(payload) return payload @@ -508,7 +538,7 @@ class AsyncReqMessageClient: else: self.io_loop = io_loop - self.context = zmq.Context() + self.context = zmq.asyncio.Context() self.send_queue = [] # mapping of message -> future @@ -531,28 +561,14 @@ class AsyncReqMessageClient: return else: self._closing = True - if hasattr(self, "stream") and self.stream is not None: - if ZMQ_VERSION_INFO < (14, 3, 0): - # stream.close() doesn't work properly on pyzmq < 14.3.0 - if self.stream.socket: - self.stream.socket.close() - self.stream.io_loop.remove_handler(self.stream.socket) - # set this to None, more hacks for messed up pyzmq - self.stream.socket = None - self.socket.close() - else: - self.stream.close(1) - self.socket = None - self.stream = None + self.socket.close() if self.context.closed is False: # This hangs if closing the stream causes an import error self.context.term() def _init_socket(self): - if hasattr(self, "stream"): - self.stream.close() # pylint: disable=E0203 + if hasattr(self, "socket"): self.socket.close() # pylint: disable=E0203 - del self.stream del self.socket self.socket = self.context.socket(zmq.REQ) @@ -570,9 +586,6 @@ class AsyncReqMessageClient: self.socket.setsockopt(zmq.IPV4ONLY, 0) self.socket.linger = self.linger self.socket.connect(self.addr) - self.stream = zmq.eventloop.zmqstream.ZMQStream( - self.socket, io_loop=self.io_loop - ) def timeout_message(self, message): """ @@ -587,44 +600,22 @@ class AsyncReqMessageClient: if future is not None: future.set_exception(SaltReqTimeoutError("Message timed out")) - @tornado.gen.coroutine - def send(self, message, timeout=None, callback=None): + async def _send_recv(self, message): + message = salt.payload.dumps(message) + await self.socket.send(message) + ret = await self.socket.recv() + data = salt.payload.loads(ret) + return data + + async def send(self, message, timeout=None, callback=None): """ Return a future which will be completed when the message has a response """ - future = tornado.concurrent.Future() - message = salt.payload.dumps(message) - - if callback is not None: - - def handle_future(future): - response = future.result() - self.io_loop.add_callback(callback, response) - - future.add_done_callback(handle_future) - - # Add this future to the mapping - self.send_future_map[message] = future - - if self.opts.get("detect_mode") is True: - timeout = 1 - - if timeout is not None: - send_timeout = self.io_loop.call_later( - timeout, self.timeout_message, message - ) - - def mark_future(msg): - if not future.done(): - data = salt.payload.loads(msg[0]) - future.set_result(data) - self.send_future_map.pop(message) - - self.stream.on_recv(mark_future) - yield self.stream.send(message) - recv = yield future - raise tornado.gen.Return(recv) + response = await asyncio.wait_for(self._send_recv(message), timeout=timeout) + if callback: + callback(response) + return response class ZeroMQSocketMonitor: @@ -643,6 +634,7 @@ class ZeroMQSocketMonitor: def start_io_loop(self, io_loop): log.trace("Event monitor start!") + return self._monitor_stream = zmq.eventloop.zmqstream.ZMQStream( self._monitor_socket, io_loop=io_loop ) @@ -712,14 +704,14 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer): This method represents the Publish Daemon process. It is intended to be run in a thread or process as it creates and runs an it's own ioloop. """ - ioloop = tornado.ioloop.IOLoop() + ioloop = tornado.ioloop.IOLoop.current() self.io_loop = ioloop - context = zmq.Context(1) + context = zmq.asyncio.Context() pub_sock = context.socket(zmq.PUB) monitor = ZeroMQSocketMonitor(pub_sock) monitor.start_io_loop(ioloop) _set_tcp_keepalive(pub_sock, self.opts) - self.dpub_sock = pub_sock = zmq.eventloop.zmqstream.ZMQStream(pub_sock) + self.dpub_sock = pub_sock #= zmq.eventloop.zmqstream.ZMQStream(pub_sock) # if 2.1 >= zmq < 3.0, we only have one HWM setting try: pub_sock.setsockopt(zmq.HWM, self.opts.get("pub_hwm", 1000)) @@ -736,7 +728,8 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer): pub_sock.setsockopt(zmq.LINGER, -1) # Prepare minion pull socket pull_sock = context.socket(zmq.PULL) - pull_sock = zmq.eventloop.zmqstream.ZMQStream(pull_sock) + pull_sock.setsockopt(zmq.LINGER, -1) + #pull_sock = zmq.eventloop.zmqstream.ZMQStream(pull_sock) pull_sock.setsockopt(zmq.LINGER, -1) salt.utils.zeromq.check_ipc_path_max_len(self.pull_uri) # Start the minion command publisher @@ -747,18 +740,22 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer): with salt.utils.files.set_umask(0o177): pull_sock.bind(self.pull_uri) - @tornado.gen.coroutine - def on_recv(packages): + async def on_recv(packages): for package in packages: payload = salt.payload.loads(package) - yield publish_payload(payload) + await publish_payload(payload) - pull_sock.on_recv(on_recv) + self.task = None + async def callback(): + self.task = asyncio.create_task(self.publisher(pull_sock, publish_payload)) + ioloop.add_callback(callback) try: ioloop.start() finally: pub_sock.close() pull_sock.close() + if self.task: + self.task.cancel() @property def pull_uri(self): @@ -779,6 +776,34 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer): @tornado.gen.coroutine def publish_payload(self, payload, topic_list=None): payload = salt.payload.dumps(payload) + + async def publisher(self, pull_sock, publish_payload): + while True: + try: + package = await pull_sock.recv() + payload = salt.payload.loads(package) + await publish_payload(payload) + except Exception as exc: + log.error("Exception in publisher %s %s", self.pull_uri, exc, exc_info=True) + + @property + def pull_uri(self): + if self.opts.get("ipc_mode", "") == "tcp": + pull_uri = "tcp://127.0.0.1:{}".format( + self.opts.get("tcp_master_publish_pull", 4514) + ) + else: + pull_uri = "ipc://{}".format( + os.path.join(self.opts["sock_dir"], "publish_pull.ipc") + ) + return pull_uri + + @property + def pub_uri(self): + return "tcp://{interface}:{publish_port}".format(**self.opts) + + async def publish_payload(self, payload, topic_list=None): + payload = salt.payload.dumps(payload) if self.opts["zmq_filtering"]: if topic_list: for topic in topic_list: @@ -788,25 +813,25 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer): htopic = salt.utils.stringutils.to_bytes( hashlib.sha1(salt.utils.stringutils.to_bytes(topic)).hexdigest() ) - yield self.dpub_sock.send(htopic, flags=zmq.SNDMORE) - yield self.dpub_sock.send(payload) + await self.dpub_sock.send(htopic, flags=zmq.SNDMORE) + await self.dpub_sock.send(payload) log.trace("Filtered data has been sent") # Syndic broadcast if self.opts.get("order_masters"): log.trace("Sending filtered data to syndic") - yield self.dpub_sock.send(b"syndic", flags=zmq.SNDMORE) - yield self.dpub_sock.send(payload) + await self.dpub_sock.send(b"syndic", flags=zmq.SNDMORE) + await self.dpub_sock.send(payload) log.trace("Filtered data has been sent to syndic") # otherwise its a broadcast else: # TODO: constants file for "broadcast" log.trace("Sending broadcasted data over publisher %s", self.pub_uri) - yield self.dpub_sock.send(b"broadcast", flags=zmq.SNDMORE) - yield self.dpub_sock.send(payload) + await self.dpub_sock.send(b"broadcast", flags=zmq.SNDMORE) + await self.dpub_sock.send(payload) log.trace("Broadcasted data has been sent") else: log.trace("Sending ZMQ-unfiltered data over publisher %s", self.pub_uri) - yield self.dpub_sock.send(payload) + await self.dpub_sock.send(payload) log.trace("Unfiltered data has been sent") def pre_fork(self, process_manager): @@ -909,11 +934,9 @@ class RequestClient(salt.transport.base.RequestClient): def connect(self): self.message_client.connect() - @tornado.gen.coroutine - def send(self, load, timeout=60): + async def send(self, load, timeout=60): self.connect() - ret = yield self.message_client.send(load, timeout=timeout) - raise tornado.gen.Return(ret) + return await self.message_client.send(load, timeout=timeout) def close(self): self.message_client.close() diff --git a/salt/utils/event.py b/salt/utils/event.py index b48a1b0b1cf..bfbfeaa7f4b 100644 --- a/salt/utils/event.py +++ b/salt/utils/event.py @@ -416,6 +416,12 @@ class SaltEvent: kwargs={"io_loop": self.io_loop}, loop_kwarg="io_loop", ) + #self.pusher = salt.utils.asynchronous.SyncWrapper( + # salt.transport.ipc.IPCMessageClient, + # args=(self.pulluri,), + # kwargs={"io_loop": self.io_loop}, + # loop_kwarg="io_loop", + #) try: self.pusher.connect(timeout=timeout) self.cpush = True @@ -770,9 +776,10 @@ class SaltEvent: ] ) msg = salt.utils.stringutils.to_bytes(event, "utf-8") - ret = yield self.pusher.send(msg) - if cb is not None: - cb(ret) + self.pusher.publish(msg, noserial=True) + #ret = yield self.pusher.send(msg) + #if cb is not None: + # cb(ret) def fire_event(self, data, tag, timeout=1000): """ @@ -826,6 +833,7 @@ class SaltEvent: with salt.utils.asynchronous.current_ioloop(self.io_loop): try: self.pusher.send(msg) + #self.pusher.send(msg) except Exception as exc: # pylint: disable=broad-except log.debug( "Publisher send failed with exception: %s", diff --git a/tests/pytests/functional/transport/ipc/test_pub_server_channel.py b/tests/pytests/functional/transport/ipc/test_pub_server_channel.py index f9360297aa4..1d42aac7485 100644 --- a/tests/pytests/functional/transport/ipc/test_pub_server_channel.py +++ b/tests/pytests/functional/transport/ipc/test_pub_server_channel.py @@ -19,7 +19,7 @@ pytestmark = [ ] -@pytest.fixture(scope="module", params=["tcp", "zeromq"]) +@pytest.fixture(scope="module", params=["zeromq", "tcp"]) def transport(request): yield request.param diff --git a/tests/pytests/unit/transport/test_zeromq.py b/tests/pytests/unit/transport/test_zeromq.py index 8561b05b3f9..b8c010af6f2 100644 --- a/tests/pytests/unit/transport/test_zeromq.py +++ b/tests/pytests/unit/transport/test_zeromq.py @@ -381,7 +381,6 @@ class MockSaltMinionMaster: master_opts.update({"transport": "zeromq"}) self.server_channel = salt.channel.server.ReqServerChannel.factory(master_opts) self.server_channel.pre_fork(self.process_manager) - self.io_loop = tornado.ioloop.IOLoop() self.evt = threading.Event() self.server_channel.post_fork(self._handle_payload, io_loop=self.io_loop) @@ -466,7 +465,7 @@ def test_serverside_exception(temp_salt_minion, temp_salt_master): assert ret == "Server-side exception handling payload" -def test_zeromq_async_pub_channel_publish_port(temp_salt_master): +async def test_zeromq_async_pub_channel_publish_port(temp_salt_master): """ test when connecting that we use the publish_port set in opts when its not 4506 """ @@ -490,7 +489,7 @@ def test_zeromq_async_pub_channel_publish_port(temp_salt_master): patch_socket = MagicMock(return_value=True) patch_auth = MagicMock(return_value=True) with patch.object(transport, "_socket", patch_socket): - transport.connect(455505) + await transport.connect(455505) assert str(opts["publish_port"]) in patch_socket.mock_calls[0][1][0] @@ -534,7 +533,7 @@ def test_zeromq_async_pub_channel_filtering_decode_message_no_match( MagicMock(return_value={"tgt_type": "glob", "tgt": "*", "jid": 1}), ): res = channel._decode_messages(message) - assert res.result() is None + assert res is None def test_zeromq_async_pub_channel_filtering_decode_message( @@ -582,7 +581,7 @@ def test_zeromq_async_pub_channel_filtering_decode_message( ) as mock_test: res = channel._decode_messages(message) - assert res.result()["enc"] == "aes" + assert res["enc"] == "aes" def test_req_server_chan_encrypt_v2(pki_dir): diff --git a/tests/support/pytest/transport.py b/tests/support/pytest/transport.py index 038c8a6cde6..4972f316144 100644 --- a/tests/support/pytest/transport.py +++ b/tests/support/pytest/transport.py @@ -125,6 +125,7 @@ class Collector(salt.utils.process.SignalHandlingProcess): return self.started.set() last_msg = time.time() + self.start = last_msg serial = salt.payload.Serial(self.minion_config) crypticle = salt.crypt.Crypticle(self.minion_config, self.aes_key) while True: @@ -158,6 +159,8 @@ class Collector(salt.utils.process.SignalHandlingProcess): if not self.zmq_filtering: log.exception("Failed to deserialize...") break + self.end = time.time() + print(f"Total time {self.end - self.start}") loop.stop() def run(self):