From 8035d2418d51fe053196b868b229d938fcf6f6e4 Mon Sep 17 00:00:00 2001 From: "Daniel A. Wozniak" Date: Mon, 3 Jul 2023 15:32:37 -0700 Subject: [PATCH] Reconnect without killing on_recv handler --- salt/transport/tcp.py | 110 ++++++++++-------- .../transport/tcp/test_message_client.py | 12 +- 2 files changed, 68 insertions(+), 54 deletions(-) diff --git a/salt/transport/tcp.py b/salt/transport/tcp.py index 6c57362fb63..e3db4aa85d7 100644 --- a/salt/transport/tcp.py +++ b/salt/transport/tcp.py @@ -56,6 +56,10 @@ class ClosingError(Exception): """ """ +def _null_callback(*args, **kwargs): + pass + + def _get_socket(opts): family = socket.AF_INET if opts.get("ipv6", False): @@ -248,8 +252,8 @@ class TCPPubClient(salt.transport.base.PublishClient): self.path = kwargs.get("path", None) self.source_ip = self.opts.get("source_ip") self.source_port = self.opts.get("source_publish_port") - self.connect_callback = None - self.disconnect_callback = None + self.connect_callback = _null_callback + self.disconnect_callback = _null_callback self.on_recv_task = None if self.host is None and self.port is None: if self.path is None: @@ -289,25 +293,34 @@ class TCPPubClient(salt.transport.base.PublishClient): while stream is None and (not self._closed and not self._closing): try: if self.host and self.port: + log.trace( + "PubClient connecting to %r %r:%r", self, self.host, self.port + ) self._tcp_client = TCPClientKeepAlive( self.opts, resolver=self.resolver ) - stream = await self._tcp_client.connect( - ip_bracket(self.host, strip=True), - self.port, - ssl_options=self.opts.get("ssl"), - **kwargs, + stream = await asyncio.wait_for( + self._tcp_client.connect( + ip_bracket(self.host, strip=True), + self.port, + ssl_options=self.opts.get("ssl"), + **kwargs, + ), + 1, ) - log.error( + self.unpacker = salt.utils.msgpack.Unpacker() + log.debug( "PubClient conencted to %r %r:%r", self, self.host, self.port ) else: + log.trace("PubClient connecting to %r %r", self, self.path) sock_type = socket.AF_UNIX stream = tornado.iostream.IOStream( socket.socket(sock_type, socket.SOCK_STREAM) ) - await stream.connect(self.path) - log.error("PubClient conencted to %r %r", self, self.path) + await asyncio.wait_for(stream.connect(self.path), 1) + self.unpacker = salt.utils.msgpack.Unpacker() + log.debug("PubClient conencted to %r %r", self, self.path) self.poller = select.poll() self.poller.register(stream.socket, select.POLLIN) except Exception as exc: # pylint: disable=broad-except @@ -363,19 +376,15 @@ class TCPPubClient(salt.transport.base.PublishClient): await self._stream.send(msg) async def recv(self, timeout=None): - log.error("PubClient recv called") while self._stream is None: await self.connect() await asyncio.sleep(0.001) if timeout == 0: for msg in self.unpacker: framed_msg = salt.transport.frame.decode_embedded_strs(msg) - log.error("PUBCLIENT GOT %r", framed_msg["body"]) return framed_msg["body"] - poller = select.poll() - poller.register(self._stream.socket, select.POLLIN) try: - events = poller.poll(0) + events = self.poller.poll(0) except TimeoutError: events = [] if events: @@ -383,19 +392,21 @@ class TCPPubClient(salt.transport.base.PublishClient): await self._read_in_progress.acquire() try: byts = await self._stream.read_bytes(4096, partial=True) - log.error("PUBCLIENT GOT BYTES %r", byts) except tornado.iostream.StreamClosedError: - self.close() + log.trace("Stream closed, reconnecting.") + stream = self._stream + self._stream = None + stream.close() await self.connect() return - except Exception: - raise + # except Exception: + # log.error("Unhandled Exception") + # raise finally: self._read_in_progress.release() self.unpacker.feed(byts) for msg in self.unpacker: framed_msg = salt.transport.frame.decode_embedded_strs(msg) - log.error("PUBCLIENT GOT %r", framed_msg["body"]) return framed_msg["body"] elif timeout: try: @@ -411,16 +422,18 @@ class TCPPubClient(salt.transport.base.PublishClient): else: for msg in self.unpacker: framed_msg = salt.transport.frame.decode_embedded_strs(msg) - log.error("PUBCLIENT GOT %r", framed_msg["body"]) return framed_msg["body"] while not self._closing: await self._read_in_progress.acquire() try: byts = await self._stream.read_bytes(4096, partial=True) - log.error("PUBCLIENT GOT BYTES %r", byts) except tornado.iostream.StreamClosedError: - self.close() + log.trace("Stream closed, reconnecting.") + stream = self._stream + self._stream = None + stream.close() await self.connect() + log.error("Re-connected - continue") continue # except AttributeError: # return @@ -431,29 +444,27 @@ class TCPPubClient(salt.transport.base.PublishClient): self.unpacker.feed(byts) for msg in self.unpacker: framed_msg = salt.transport.frame.decode_embedded_strs(msg) - log.error("PUBCLIENT GOT %r", framed_msg["body"]) return framed_msg["body"] async def on_recv_handler(self, callback): while not self._stream: await asyncio.sleep(0.003) while True: - try: - log.error("On recv handler %r", self) - msg = await self.recv() - if msg: - callback(msg) - except tornado.iostream.StreamClosedError: - log.trace("Stream closed, reconnecting.") - self._stream.close() - self._stream = None - await self._connect() - if self.disconnect_callback: - self.disconnect_callback() - self.unpacker = salt.utils.msgpack.Unpacker() - continue - except Exception: # py-lint: disable=broad-except - log.error("Unhandled exception in on_recv handler.", exc_info=True) + # try: + msg = await self.recv() + if msg: + callback(msg) + # except tornado.iostream.StreamClosedError: + # log.trace("Stream closed, reconnecting.") + # self._stream.close() + # self._stream = None + # await self._connect() + # if self.disconnect_callback: + # self.disconnect_callback() + # self.unpacker = salt.utils.msgpack.Unpacker() + # continue + # except Exception: # py-lint: disable=broad-except + # log.error("Unhandled exception in on_recv handler.", exc_info=True) def on_recv(self, callback): """ @@ -1150,7 +1161,7 @@ class PubServer(tornado.tcpserver.TCPServer): log.trace( "TCP PubServer sending payload: topic_list=%r %r", topic_list, package ) - log.error("PUBLISH PAYLOAD %r", package) + # log.error("PUBLISH PAYLOAD %r", package) payload = salt.transport.frame.frame_msg(package) to_remove = [] if topic_list: @@ -1170,7 +1181,7 @@ class PubServer(tornado.tcpserver.TCPServer): else: for client in self.clients: try: - log.error("PUBLISH CLIENT %r", package) + # log.error("PUBLISH CLIENT %r", package) # Write the packed str await client.stream.write(payload) except tornado.iostream.StreamClosedError: @@ -1484,7 +1495,7 @@ class TCPPublishServer(salt.transport.base.DaemonizedPublishServer): """ Publish "load" to minions """ - log.error("PUBLISH %r", payload) + # log.error("PUBLISH %r", payload) if not self.pub_sock: self.connect() self.pub_sock.send(payload) @@ -1632,8 +1643,6 @@ class _TCPPubServerPublisher: self._connecting_future.set_exception(e) break - await asyncio.sleep(1) - def close(self): """ Routines to handle any cleanup before the instance shuts down. @@ -1723,7 +1732,8 @@ class TCPReqClient(salt.transport.base.RequestClient): self._connecting_future = tornado.concurrent.Future() self._stream_return_running = False self._stream = None - self.disconnect_callback = None + self.disconnect_callback = _null_callback + self.connect_callback = _null_callback async def getstream(self, **kwargs): if self.source_ip or self.source_port: @@ -1767,8 +1777,8 @@ class TCPReqClient(salt.transport.base.RequestClient): if not self._stream_return_running: self.task = asyncio.create_task(self._stream_return()) # self.io_loop.spawn_callback(self._stream_return) - if self.connect_callback: - self.connect_callback(True) + if self.connect_callback is not None: + self.connect_callback() async def _stream_return(self): self._stream_return_running = True @@ -1807,7 +1817,7 @@ class TCPReqClient(salt.transport.base.RequestClient): self.send_future_map = {} if self._closing or self._closed: return - if self.disconnect_callback: + if self.disconnect_callback is not None: self.disconnect_callback() stream = self._stream self._stream = None @@ -1831,7 +1841,7 @@ class TCPReqClient(salt.transport.base.RequestClient): self.send_future_map = {} if self._closing or self._closed: return - if self.disconnect_callback: + if self.disconnect_callback is not None: self.disconnect_callback() stream = self._stream self._stream = None diff --git a/tests/pytests/functional/transport/tcp/test_message_client.py b/tests/pytests/functional/transport/tcp/test_message_client.py index e15006f51dc..c9b1b302a1b 100644 --- a/tests/pytests/functional/transport/tcp/test_message_client.py +++ b/tests/pytests/functional/transport/tcp/test_message_client.py @@ -28,7 +28,7 @@ def server(config): async def handle_stream(self, stream, address): try: - log.error("Got stream") + log.error("Got stream %r", self.disconnect) while self.disconnect is False: for msg in self.send[:]: msg = self.send.pop(0) @@ -42,11 +42,9 @@ def server(config): log.error("SLEEP") await asyncio.sleep(1) log.error("Close stream") - log.error("After close stream") - except: - log.error("WTFSON", exc_info=True) finally: stream.close() + log.error("After close stream") server = TestServer() try: @@ -123,8 +121,14 @@ async def test_message_client_reconnect(config, client, server): # rest of this test would fail. log.error("Send pmsg %r", pmsg) server.send.append(pmsg) + log.error("After - Send pmsg %r", pmsg) while not received: await tornado.gen.sleep(1) + log.error("received %r", received) assert received == [msg, msg] server.disconnect = True + + # Close the client + client.close() + # Provide time for the on_recv task to complete await tornado.gen.sleep(1)