mirror of
https://github.com/saltstack/salt.git
synced 2025-04-17 10:10:20 +00:00
Reconnect without killing on_recv handler
This commit is contained in:
parent
d52df08f22
commit
8035d2418d
2 changed files with 68 additions and 54 deletions
|
@ -56,6 +56,10 @@ class ClosingError(Exception):
|
|||
""" """
|
||||
|
||||
|
||||
def _null_callback(*args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
def _get_socket(opts):
|
||||
family = socket.AF_INET
|
||||
if opts.get("ipv6", False):
|
||||
|
@ -248,8 +252,8 @@ class TCPPubClient(salt.transport.base.PublishClient):
|
|||
self.path = kwargs.get("path", None)
|
||||
self.source_ip = self.opts.get("source_ip")
|
||||
self.source_port = self.opts.get("source_publish_port")
|
||||
self.connect_callback = None
|
||||
self.disconnect_callback = None
|
||||
self.connect_callback = _null_callback
|
||||
self.disconnect_callback = _null_callback
|
||||
self.on_recv_task = None
|
||||
if self.host is None and self.port is None:
|
||||
if self.path is None:
|
||||
|
@ -289,25 +293,34 @@ class TCPPubClient(salt.transport.base.PublishClient):
|
|||
while stream is None and (not self._closed and not self._closing):
|
||||
try:
|
||||
if self.host and self.port:
|
||||
log.trace(
|
||||
"PubClient connecting to %r %r:%r", self, self.host, self.port
|
||||
)
|
||||
self._tcp_client = TCPClientKeepAlive(
|
||||
self.opts, resolver=self.resolver
|
||||
)
|
||||
stream = await self._tcp_client.connect(
|
||||
ip_bracket(self.host, strip=True),
|
||||
self.port,
|
||||
ssl_options=self.opts.get("ssl"),
|
||||
**kwargs,
|
||||
stream = await asyncio.wait_for(
|
||||
self._tcp_client.connect(
|
||||
ip_bracket(self.host, strip=True),
|
||||
self.port,
|
||||
ssl_options=self.opts.get("ssl"),
|
||||
**kwargs,
|
||||
),
|
||||
1,
|
||||
)
|
||||
log.error(
|
||||
self.unpacker = salt.utils.msgpack.Unpacker()
|
||||
log.debug(
|
||||
"PubClient conencted to %r %r:%r", self, self.host, self.port
|
||||
)
|
||||
else:
|
||||
log.trace("PubClient connecting to %r %r", self, self.path)
|
||||
sock_type = socket.AF_UNIX
|
||||
stream = tornado.iostream.IOStream(
|
||||
socket.socket(sock_type, socket.SOCK_STREAM)
|
||||
)
|
||||
await stream.connect(self.path)
|
||||
log.error("PubClient conencted to %r %r", self, self.path)
|
||||
await asyncio.wait_for(stream.connect(self.path), 1)
|
||||
self.unpacker = salt.utils.msgpack.Unpacker()
|
||||
log.debug("PubClient conencted to %r %r", self, self.path)
|
||||
self.poller = select.poll()
|
||||
self.poller.register(stream.socket, select.POLLIN)
|
||||
except Exception as exc: # pylint: disable=broad-except
|
||||
|
@ -363,19 +376,15 @@ class TCPPubClient(salt.transport.base.PublishClient):
|
|||
await self._stream.send(msg)
|
||||
|
||||
async def recv(self, timeout=None):
|
||||
log.error("PubClient recv called")
|
||||
while self._stream is None:
|
||||
await self.connect()
|
||||
await asyncio.sleep(0.001)
|
||||
if timeout == 0:
|
||||
for msg in self.unpacker:
|
||||
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
|
||||
log.error("PUBCLIENT GOT %r", framed_msg["body"])
|
||||
return framed_msg["body"]
|
||||
poller = select.poll()
|
||||
poller.register(self._stream.socket, select.POLLIN)
|
||||
try:
|
||||
events = poller.poll(0)
|
||||
events = self.poller.poll(0)
|
||||
except TimeoutError:
|
||||
events = []
|
||||
if events:
|
||||
|
@ -383,19 +392,21 @@ class TCPPubClient(salt.transport.base.PublishClient):
|
|||
await self._read_in_progress.acquire()
|
||||
try:
|
||||
byts = await self._stream.read_bytes(4096, partial=True)
|
||||
log.error("PUBCLIENT GOT BYTES %r", byts)
|
||||
except tornado.iostream.StreamClosedError:
|
||||
self.close()
|
||||
log.trace("Stream closed, reconnecting.")
|
||||
stream = self._stream
|
||||
self._stream = None
|
||||
stream.close()
|
||||
await self.connect()
|
||||
return
|
||||
except Exception:
|
||||
raise
|
||||
# except Exception:
|
||||
# log.error("Unhandled Exception")
|
||||
# raise
|
||||
finally:
|
||||
self._read_in_progress.release()
|
||||
self.unpacker.feed(byts)
|
||||
for msg in self.unpacker:
|
||||
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
|
||||
log.error("PUBCLIENT GOT %r", framed_msg["body"])
|
||||
return framed_msg["body"]
|
||||
elif timeout:
|
||||
try:
|
||||
|
@ -411,16 +422,18 @@ class TCPPubClient(salt.transport.base.PublishClient):
|
|||
else:
|
||||
for msg in self.unpacker:
|
||||
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
|
||||
log.error("PUBCLIENT GOT %r", framed_msg["body"])
|
||||
return framed_msg["body"]
|
||||
while not self._closing:
|
||||
await self._read_in_progress.acquire()
|
||||
try:
|
||||
byts = await self._stream.read_bytes(4096, partial=True)
|
||||
log.error("PUBCLIENT GOT BYTES %r", byts)
|
||||
except tornado.iostream.StreamClosedError:
|
||||
self.close()
|
||||
log.trace("Stream closed, reconnecting.")
|
||||
stream = self._stream
|
||||
self._stream = None
|
||||
stream.close()
|
||||
await self.connect()
|
||||
log.error("Re-connected - continue")
|
||||
continue
|
||||
# except AttributeError:
|
||||
# return
|
||||
|
@ -431,29 +444,27 @@ class TCPPubClient(salt.transport.base.PublishClient):
|
|||
self.unpacker.feed(byts)
|
||||
for msg in self.unpacker:
|
||||
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
|
||||
log.error("PUBCLIENT GOT %r", framed_msg["body"])
|
||||
return framed_msg["body"]
|
||||
|
||||
async def on_recv_handler(self, callback):
    """
    Deliver every message received on the publish stream to ``callback``.

    Waits until a stream has been established, then loops forever: each
    message returned by ``self.recv()`` that is truthy is passed to
    ``callback``. Falsy results (for example ``None``) are skipped.

    No exception handling is done here. Any exception raised by
    ``self.recv()`` or by ``callback`` propagates to whoever awaits or owns
    this coroutine.

    :param callback: Callable invoked with one argument, the received
        message body.
    """
    # Do not start reading until a connection exists.
    while not self._stream:
        await asyncio.sleep(0.003)
    while True:
        # NOTE(review): reconnecting on a closed stream is expected to be
        # handled inside recv() - confirm against recv() before relying on it.
        msg = await self.recv()
        if msg:
            callback(msg)
|
||||
|
||||
def on_recv(self, callback):
|
||||
"""
|
||||
|
@ -1150,7 +1161,7 @@ class PubServer(tornado.tcpserver.TCPServer):
|
|||
log.trace(
|
||||
"TCP PubServer sending payload: topic_list=%r %r", topic_list, package
|
||||
)
|
||||
log.error("PUBLISH PAYLOAD %r", package)
|
||||
# log.error("PUBLISH PAYLOAD %r", package)
|
||||
payload = salt.transport.frame.frame_msg(package)
|
||||
to_remove = []
|
||||
if topic_list:
|
||||
|
@ -1170,7 +1181,7 @@ class PubServer(tornado.tcpserver.TCPServer):
|
|||
else:
|
||||
for client in self.clients:
|
||||
try:
|
||||
log.error("PUBLISH CLIENT %r", package)
|
||||
# log.error("PUBLISH CLIENT %r", package)
|
||||
# Write the packed str
|
||||
await client.stream.write(payload)
|
||||
except tornado.iostream.StreamClosedError:
|
||||
|
@ -1484,7 +1495,7 @@ class TCPPublishServer(salt.transport.base.DaemonizedPublishServer):
|
|||
"""
|
||||
Publish "load" to minions
|
||||
"""
|
||||
log.error("PUBLISH %r", payload)
|
||||
# log.error("PUBLISH %r", payload)
|
||||
if not self.pub_sock:
|
||||
self.connect()
|
||||
self.pub_sock.send(payload)
|
||||
|
@ -1632,8 +1643,6 @@ class _TCPPubServerPublisher:
|
|||
self._connecting_future.set_exception(e)
|
||||
break
|
||||
|
||||
await asyncio.sleep(1)
|
||||
|
||||
def close(self):
|
||||
"""
|
||||
Routines to handle any cleanup before the instance shuts down.
|
||||
|
@ -1723,7 +1732,8 @@ class TCPReqClient(salt.transport.base.RequestClient):
|
|||
self._connecting_future = tornado.concurrent.Future()
|
||||
self._stream_return_running = False
|
||||
self._stream = None
|
||||
self.disconnect_callback = None
|
||||
self.disconnect_callback = _null_callback
|
||||
self.connect_callback = _null_callback
|
||||
|
||||
async def getstream(self, **kwargs):
|
||||
if self.source_ip or self.source_port:
|
||||
|
@ -1767,8 +1777,8 @@ class TCPReqClient(salt.transport.base.RequestClient):
|
|||
if not self._stream_return_running:
|
||||
self.task = asyncio.create_task(self._stream_return())
|
||||
# self.io_loop.spawn_callback(self._stream_return)
|
||||
if self.connect_callback:
|
||||
self.connect_callback(True)
|
||||
if self.connect_callback is not None:
|
||||
self.connect_callback()
|
||||
|
||||
async def _stream_return(self):
|
||||
self._stream_return_running = True
|
||||
|
@ -1807,7 +1817,7 @@ class TCPReqClient(salt.transport.base.RequestClient):
|
|||
self.send_future_map = {}
|
||||
if self._closing or self._closed:
|
||||
return
|
||||
if self.disconnect_callback:
|
||||
if self.disconnect_callback is not None:
|
||||
self.disconnect_callback()
|
||||
stream = self._stream
|
||||
self._stream = None
|
||||
|
@ -1831,7 +1841,7 @@ class TCPReqClient(salt.transport.base.RequestClient):
|
|||
self.send_future_map = {}
|
||||
if self._closing or self._closed:
|
||||
return
|
||||
if self.disconnect_callback:
|
||||
if self.disconnect_callback is not None:
|
||||
self.disconnect_callback()
|
||||
stream = self._stream
|
||||
self._stream = None
|
||||
|
|
|
@ -28,7 +28,7 @@ def server(config):
|
|||
|
||||
async def handle_stream(self, stream, address):
|
||||
try:
|
||||
log.error("Got stream")
|
||||
log.error("Got stream %r", self.disconnect)
|
||||
while self.disconnect is False:
|
||||
for msg in self.send[:]:
|
||||
msg = self.send.pop(0)
|
||||
|
@ -42,11 +42,9 @@ def server(config):
|
|||
log.error("SLEEP")
|
||||
await asyncio.sleep(1)
|
||||
log.error("Close stream")
|
||||
log.error("After close stream")
|
||||
except:
|
||||
log.error("WTFSON", exc_info=True)
|
||||
finally:
|
||||
stream.close()
|
||||
log.error("After close stream")
|
||||
|
||||
server = TestServer()
|
||||
try:
|
||||
|
@ -123,8 +121,14 @@ async def test_message_client_reconnect(config, client, server):
|
|||
# rest of this test would fail.
|
||||
log.error("Send pmsg %r", pmsg)
|
||||
server.send.append(pmsg)
|
||||
log.error("After - Send pmsg %r", pmsg)
|
||||
while not received:
|
||||
await tornado.gen.sleep(1)
|
||||
log.error("received %r", received)
|
||||
assert received == [msg, msg]
|
||||
server.disconnect = True
|
||||
|
||||
# Close the client
|
||||
client.close()
|
||||
# Provide time for the on_recv task to complete
|
||||
await tornado.gen.sleep(1)
|
||||
|
|
Loading…
Add table
Reference in a new issue