From 5540fd8111c7f0ebd171e4d0b5124692cce372b4 Mon Sep 17 00:00:00 2001 From: "Daniel A. Wozniak" Date: Sun, 2 Jul 2023 16:07:26 -0700 Subject: [PATCH] Fix up on_recv logic --- salt/transport/tcp.py | 61 +++++++++++++++---- .../rest_tornado/test_event_listener.py | 10 +++ 2 files changed, 58 insertions(+), 13 deletions(-) diff --git a/salt/transport/tcp.py b/salt/transport/tcp.py index 1b18090f189..09ae4099d14 100644 --- a/salt/transport/tcp.py +++ b/salt/transport/tcp.py @@ -250,6 +250,7 @@ class TCPPubClient(salt.transport.base.PublishClient): self.source_port = self.opts.get("source_publish_port") self.connect_callback = None self.disconnect_callback = None + self.on_recv_task = None if self.host is None and self.port is None: if self.path is None: raise Exception("A host and port or a path must be provided") @@ -261,6 +262,9 @@ class TCPPubClient(salt.transport.base.PublishClient): if self._closing: return self._closing = True + if self.on_recv_task: + self.on_recv_task.cancel() + self.on_recv_task = None if self._stream is not None: self._stream.close() self._stream = None @@ -294,12 +298,14 @@ class TCPPubClient(salt.transport.base.PublishClient): ssl_options=self.opts.get("ssl"), **kwargs, ) + log.error("PubClient conencted to %r %r:%r", self, self.host, self.port) else: sock_type = socket.AF_UNIX stream = tornado.iostream.IOStream( socket.socket(sock_type, socket.SOCK_STREAM) ) await stream.connect(self.path) + log.error("PubClient conencted to %r %r", self, self.path) self.poller = select.poll() self.poller.register(stream.socket, select.POLLIN) except Exception as exc: # pylint: disable=broad-except @@ -317,6 +323,8 @@ class TCPPubClient(salt.transport.base.PublishClient): async def _connect(self): if self._stream is None: + self._closing = False + self._closed = False self._stream = await self.getstream() if self._stream: # if not self._stream_return_running: @@ -353,13 +361,14 @@ class TCPPubClient(salt.transport.base.PublishClient): await self._stream.send(msg) async def recv(self, timeout=None): - if not self._stream: + log.error("PubClient recv called") + while self._stream is None: await self.connect() await asyncio.sleep(0.001) - return if timeout == 0: for msg in self.unpacker: framed_msg = salt.transport.frame.decode_embedded_strs(msg) + log.error("PUBCLIENT GOT %r", framed_msg["body"]) return framed_msg["body"] poller = select.poll() poller.register(self._stream.socket, select.POLLIN) @@ -368,12 +377,14 @@ class TCPPubClient(salt.transport.base.PublishClient): except TimeoutError: events = [] if events: - while True: + while not self._closing: await self._read_in_progress.acquire() try: byts = await self._stream.read_bytes(4096, partial=True) + log.error("PUBCLIENT GOT BYTES %r", byts) except tornado.iostream.StreamClosedError: self.close() + await self.connect() return except Exception: raise @@ -382,33 +393,39 @@ class TCPPubClient(salt.transport.base.PublishClient): self.unpacker.feed(byts) for msg in self.unpacker: framed_msg = salt.transport.frame.decode_embedded_strs(msg) + log.error("PUBCLIENT GOT %r", framed_msg["body"]) return framed_msg["body"] elif timeout: try: return await asyncio.wait_for(self.recv(), timeout=timeout) - except asyncio.exceptions.TimeoutError: + except (TimeoutError, asyncio.exceptions.TimeoutError, asyncio.exceptions.CancelledError): self.close() await self.connect() return else: for msg in self.unpacker: framed_msg = salt.transport.frame.decode_embedded_strs(msg) + log.error("PUBCLIENT GOT %r", framed_msg["body"]) return framed_msg["body"] - while True: + while not self._closing: await self._read_in_progress.acquire() try: byts = await self._stream.read_bytes(4096, partial=True) + log.error("PUBCLIENT GOT BYTES %r", byts) except tornado.iostream.StreamClosedError: self.close() await self.connect() continue - except Exception: - raise + #except AttributeError: + # return + #except Exception: + # raise finally: self._read_in_progress.release() self.unpacker.feed(byts) for msg in self.unpacker: framed_msg = salt.transport.frame.decode_embedded_strs(msg) + log.error("PUBCLIENT GOT %r", framed_msg["body"]) return framed_msg["body"] async def on_recv_handler(self, callback): @@ -416,7 +433,10 @@ class TCPPubClient(salt.transport.base.PublishClient): await asyncio.sleep(0.003) while True: try: + log.error("On recv handler %r", self) msg = await self.recv() + if msg: + callback(msg) except tornado.iostream.StreamClosedError: log.trace("Stream closed, reconnecting.") self._stream.close() @@ -428,13 +448,17 @@ class TCPPubClient(salt.transport.base.PublishClient): continue except Exception: # py-lint: disable=broad-except log.error("Unhandled exception in on_recv handler.", exc_info=True) - callback(msg) def on_recv(self, callback): """ Register a callback for received messages (that we didn't initiate) """ - self.io_loop.spawn_callback(self.on_recv_handler, callback) + if self.on_recv_task: + self.on_recv_task.cancel() + if callback is None: + self.on_recv_task = None + else: + self.on_recv_task = asyncio.create_task(self.on_recv_handler(callback)) def __enter__(self): return self @@ -775,7 +799,9 @@ class MessageClient: # pylint: disable=W1701 def __del__(self): if not self._closing: - warnings.warn("%r not closed", self) + warnings.warn( + "unclosed message client {self!r}", ResourceWarning, source=self + ) # pylint: enable=W1701 @@ -1040,7 +1066,9 @@ class Subscriber: # pylint: disable=W1701 def __del__(self): if not self._closing: - warnings.warn("%r not closed", self) + warnings.warn( + "unclosed publish subscriber {self!r}", ResourceWarning, source=self + ) # pylint: enable=W1701 @@ -1116,6 +1144,7 @@ class PubServer(tornado.tcpserver.TCPServer): log.trace( "TCP PubServer sending payload: topic_list=%r %r", topic_list, package ) + log.error("PUBLISH PAYLOAD %r", package) payload = salt.transport.frame.frame_msg(package) to_remove = [] if topic_list: @@ -1135,6 +1164,7 @@ class PubServer(tornado.tcpserver.TCPServer): else: for client in self.clients: try: + log.error("PUBLISH CLIENT %r", package) # Write the packed str await client.stream.write(payload) except tornado.iostream.StreamClosedError: @@ -1295,7 +1325,9 @@ class TCPPuller: # pylint: disable=W1701 def __del__(self): if not self._closing: - warnings.warn("%r not closed", self) + warnings.warn( + "unclosed tcp puller {self!r}", ResourceWarning, source=self + ) # pylint: enable=W1701 @@ -1436,6 +1468,7 @@ class TCPPublishServer(salt.transport.base.DaemonizedPublishServer): """ Publish "load" to minions """ + log.error("PUBLISH %r", payload) if not self.pub_sock: self.connect() self.pub_sock.send(payload) @@ -1605,7 +1638,9 @@ class _TCPPubServerPublisher: # pylint: disable=W1701 def __del__(self): if not self._closing: - warnings.warn("%r not closed", self) + warnings.warn( + "unclosed publisher client {self!r}", ResourceWarning, source=self + ) # pylint: enable=W1701 diff --git a/tests/pytests/functional/netapi/rest_tornado/test_event_listener.py b/tests/pytests/functional/netapi/rest_tornado/test_event_listener.py index 5ed798ad45a..2b66981b91d 100644 --- a/tests/pytests/functional/netapi/rest_tornado/test_event_listener.py +++ b/tests/pytests/functional/netapi/rest_tornado/test_event_listener.py @@ -1,9 +1,13 @@ +import asyncio +import time +import logging import pytest import salt.utils.event from salt.netapi.rest_tornado import saltnado from tests.support.events import eventpublisher_process +log = logging.getLogger(__name__) def _check_skip(grains): if grains["os"] == "MacOS": @@ -40,6 +44,7 @@ async def test_simple(sock_dir): {}, # we don't use mod_opts, don't save? {"sock_dir": sock_dir, "transport": "zeromq"}, ) + await asyncio.sleep(1) event_future = event_listener.get_event( request, "evt1" ) # get an event future @@ -65,6 +70,7 @@ async def test_set_event_handler(sock_dir): {}, # we don't use mod_opts, don't save? {"sock_dir": sock_dir, "transport": "zeromq"}, ) + await asyncio.sleep(1) event_future = event_listener.get_event( request, tag="evt", @@ -88,6 +94,7 @@ async def test_timeout(sock_dir): {}, # we don't use mod_opts, don't save? {"sock_dir": sock_dir, "transport": "zeromq"}, ) + await asyncio.sleep(1) event_future = event_listener.get_event( request, tag="evt1", @@ -110,13 +117,16 @@ async def test_clean_by_request(sock_dir, io_loop): """ with eventpublisher_process(sock_dir): + log.error("After event pubserver start") with salt.utils.event.MasterEvent(sock_dir) as me: + log.error("After master event start %r", me) request1 = Request() request2 = Request() event_listener = saltnado.EventListener( {}, # we don't use mod_opts, don't save? {"sock_dir": sock_dir, "transport": "zeromq"}, ) + await asyncio.sleep(1) assert 0 == len(event_listener.tag_map) assert 0 == len(event_listener.request_map)