Reconnect without killing on_recv handler

This commit is contained in:
Daniel A. Wozniak 2023-07-03 15:32:37 -07:00 committed by Gareth J. Greenaway
parent d52df08f22
commit 8035d2418d
2 changed files with 68 additions and 54 deletions

View file

@ -56,6 +56,10 @@ class ClosingError(Exception):
""" """
def _null_callback(*args, **kwargs):
pass
def _get_socket(opts):
family = socket.AF_INET
if opts.get("ipv6", False):
@ -248,8 +252,8 @@ class TCPPubClient(salt.transport.base.PublishClient):
self.path = kwargs.get("path", None)
self.source_ip = self.opts.get("source_ip")
self.source_port = self.opts.get("source_publish_port")
self.connect_callback = None
self.disconnect_callback = None
self.connect_callback = _null_callback
self.disconnect_callback = _null_callback
self.on_recv_task = None
if self.host is None and self.port is None:
if self.path is None:
@ -289,25 +293,34 @@ class TCPPubClient(salt.transport.base.PublishClient):
while stream is None and (not self._closed and not self._closing):
try:
if self.host and self.port:
log.trace(
"PubClient connecting to %r %r:%r", self, self.host, self.port
)
self._tcp_client = TCPClientKeepAlive(
self.opts, resolver=self.resolver
)
stream = await self._tcp_client.connect(
ip_bracket(self.host, strip=True),
self.port,
ssl_options=self.opts.get("ssl"),
**kwargs,
stream = await asyncio.wait_for(
self._tcp_client.connect(
ip_bracket(self.host, strip=True),
self.port,
ssl_options=self.opts.get("ssl"),
**kwargs,
),
1,
)
log.error(
self.unpacker = salt.utils.msgpack.Unpacker()
log.debug(
"PubClient conencted to %r %r:%r", self, self.host, self.port
)
else:
log.trace("PubClient connecting to %r %r", self, self.path)
sock_type = socket.AF_UNIX
stream = tornado.iostream.IOStream(
socket.socket(sock_type, socket.SOCK_STREAM)
)
await stream.connect(self.path)
log.error("PubClient conencted to %r %r", self, self.path)
await asyncio.wait_for(stream.connect(self.path), 1)
self.unpacker = salt.utils.msgpack.Unpacker()
log.debug("PubClient conencted to %r %r", self, self.path)
self.poller = select.poll()
self.poller.register(stream.socket, select.POLLIN)
except Exception as exc: # pylint: disable=broad-except
@ -363,19 +376,15 @@ class TCPPubClient(salt.transport.base.PublishClient):
await self._stream.send(msg)
async def recv(self, timeout=None):
log.error("PubClient recv called")
while self._stream is None:
await self.connect()
await asyncio.sleep(0.001)
if timeout == 0:
for msg in self.unpacker:
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
log.error("PUBCLIENT GOT %r", framed_msg["body"])
return framed_msg["body"]
poller = select.poll()
poller.register(self._stream.socket, select.POLLIN)
try:
events = poller.poll(0)
events = self.poller.poll(0)
except TimeoutError:
events = []
if events:
@ -383,19 +392,21 @@ class TCPPubClient(salt.transport.base.PublishClient):
await self._read_in_progress.acquire()
try:
byts = await self._stream.read_bytes(4096, partial=True)
log.error("PUBCLIENT GOT BYTES %r", byts)
except tornado.iostream.StreamClosedError:
self.close()
log.trace("Stream closed, reconnecting.")
stream = self._stream
self._stream = None
stream.close()
await self.connect()
return
except Exception:
raise
# except Exception:
# log.error("Unhandled Exception")
# raise
finally:
self._read_in_progress.release()
self.unpacker.feed(byts)
for msg in self.unpacker:
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
log.error("PUBCLIENT GOT %r", framed_msg["body"])
return framed_msg["body"]
elif timeout:
try:
@ -411,16 +422,18 @@ class TCPPubClient(salt.transport.base.PublishClient):
else:
for msg in self.unpacker:
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
log.error("PUBCLIENT GOT %r", framed_msg["body"])
return framed_msg["body"]
while not self._closing:
await self._read_in_progress.acquire()
try:
byts = await self._stream.read_bytes(4096, partial=True)
log.error("PUBCLIENT GOT BYTES %r", byts)
except tornado.iostream.StreamClosedError:
self.close()
log.trace("Stream closed, reconnecting.")
stream = self._stream
self._stream = None
stream.close()
await self.connect()
log.error("Re-connected - continue")
continue
# except AttributeError:
# return
@ -431,29 +444,27 @@ class TCPPubClient(salt.transport.base.PublishClient):
self.unpacker.feed(byts)
for msg in self.unpacker:
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
log.error("PUBCLIENT GOT %r", framed_msg["body"])
return framed_msg["body"]
async def on_recv_handler(self, callback):
while not self._stream:
await asyncio.sleep(0.003)
while True:
try:
log.error("On recv handler %r", self)
msg = await self.recv()
if msg:
callback(msg)
except tornado.iostream.StreamClosedError:
log.trace("Stream closed, reconnecting.")
self._stream.close()
self._stream = None
await self._connect()
if self.disconnect_callback:
self.disconnect_callback()
self.unpacker = salt.utils.msgpack.Unpacker()
continue
except Exception: # py-lint: disable=broad-except
log.error("Unhandled exception in on_recv handler.", exc_info=True)
# try:
msg = await self.recv()
if msg:
callback(msg)
# except tornado.iostream.StreamClosedError:
# log.trace("Stream closed, reconnecting.")
# self._stream.close()
# self._stream = None
# await self._connect()
# if self.disconnect_callback:
# self.disconnect_callback()
# self.unpacker = salt.utils.msgpack.Unpacker()
# continue
# except Exception: # py-lint: disable=broad-except
# log.error("Unhandled exception in on_recv handler.", exc_info=True)
def on_recv(self, callback):
"""
@ -1150,7 +1161,7 @@ class PubServer(tornado.tcpserver.TCPServer):
log.trace(
"TCP PubServer sending payload: topic_list=%r %r", topic_list, package
)
log.error("PUBLISH PAYLOAD %r", package)
# log.error("PUBLISH PAYLOAD %r", package)
payload = salt.transport.frame.frame_msg(package)
to_remove = []
if topic_list:
@ -1170,7 +1181,7 @@ class PubServer(tornado.tcpserver.TCPServer):
else:
for client in self.clients:
try:
log.error("PUBLISH CLIENT %r", package)
# log.error("PUBLISH CLIENT %r", package)
# Write the packed str
await client.stream.write(payload)
except tornado.iostream.StreamClosedError:
@ -1484,7 +1495,7 @@ class TCPPublishServer(salt.transport.base.DaemonizedPublishServer):
"""
Publish "load" to minions
"""
log.error("PUBLISH %r", payload)
# log.error("PUBLISH %r", payload)
if not self.pub_sock:
self.connect()
self.pub_sock.send(payload)
@ -1632,8 +1643,6 @@ class _TCPPubServerPublisher:
self._connecting_future.set_exception(e)
break
await asyncio.sleep(1)
def close(self):
"""
Routines to handle any cleanup before the instance shuts down.
@ -1723,7 +1732,8 @@ class TCPReqClient(salt.transport.base.RequestClient):
self._connecting_future = tornado.concurrent.Future()
self._stream_return_running = False
self._stream = None
self.disconnect_callback = None
self.disconnect_callback = _null_callback
self.connect_callback = _null_callback
async def getstream(self, **kwargs):
if self.source_ip or self.source_port:
@ -1767,8 +1777,8 @@ class TCPReqClient(salt.transport.base.RequestClient):
if not self._stream_return_running:
self.task = asyncio.create_task(self._stream_return())
# self.io_loop.spawn_callback(self._stream_return)
if self.connect_callback:
self.connect_callback(True)
if self.connect_callback is not None:
self.connect_callback()
async def _stream_return(self):
self._stream_return_running = True
@ -1807,7 +1817,7 @@ class TCPReqClient(salt.transport.base.RequestClient):
self.send_future_map = {}
if self._closing or self._closed:
return
if self.disconnect_callback:
if self.disconnect_callback is not None:
self.disconnect_callback()
stream = self._stream
self._stream = None
@ -1831,7 +1841,7 @@ class TCPReqClient(salt.transport.base.RequestClient):
self.send_future_map = {}
if self._closing or self._closed:
return
if self.disconnect_callback:
if self.disconnect_callback is not None:
self.disconnect_callback()
stream = self._stream
self._stream = None

View file

@ -28,7 +28,7 @@ def server(config):
async def handle_stream(self, stream, address):
try:
log.error("Got stream")
log.error("Got stream %r", self.disconnect)
while self.disconnect is False:
for msg in self.send[:]:
msg = self.send.pop(0)
@ -42,11 +42,9 @@ def server(config):
log.error("SLEEP")
await asyncio.sleep(1)
log.error("Close stream")
log.error("After close stream")
except:
log.error("WTFSON", exc_info=True)
finally:
stream.close()
log.error("After close stream")
server = TestServer()
try:
@ -123,8 +121,14 @@ async def test_message_client_reconnect(config, client, server):
# rest of this test would fail.
log.error("Send pmsg %r", pmsg)
server.send.append(pmsg)
log.error("After - Send pmsg %r", pmsg)
while not received:
await tornado.gen.sleep(1)
log.error("received %r", received)
assert received == [msg, msg]
server.disconnect = True
# Close the client
client.close()
# Provide time for the on_recv task to complete
await tornado.gen.sleep(1)