Cleanup and refactor TCP transport

Authored by Jenkins on 2023-06-28 13:49:53 -07:00, committed by Gareth J. Greenaway
parent b4e407a8a9
commit 31c59ce450
2 changed files with 229 additions and 52 deletions


@@ -392,33 +392,31 @@ class TCPPubClient(salt.transport.base.PublishClient):
finally:
self._read_in_progress.release()
async def on_recv_handler(self, callback):
while not self._stream:
await asyncio.sleep(0.003)
while True:
try:
msg = await self.recv()
logit = True
except tornado.iostream.StreamClosedError:
self._stream.close()
self._stream = None
await self._connect()
if self.disconnect_callback:
self.disconnect_callback()
self.unpacker = salt.utils.msgpack.Unpacker()
continue
except Exception:
log.error("Other exception", exc_info=True)
log.error("on recv got msg %r", msg)
callback(msg)
def on_recv(self, callback):
"""
Register a callback for received messages (that we didn't initiate)
"""
async def setup_callback():
while not self._stream:
await asyncio.sleep(0.003)
while True:
try:
msg = await self.recv()
logit = True
except tornado.iostream.StreamClosedError:
self._stream.close()
self._stream = None
await self._connect()
await asyncio.sleep(0.03)
# if self.disconnect_callback:
# self.disconnect_callback()
self.unpacker = salt.utils.msgpack.Unpacker()
continue
except Exception:
log.error("Other exception", exc_info=True)
log.error("on recv got msg %r", msg)
callback(msg)
self.io_loop.spawn_callback(setup_callback)
self.io_loop.spawn_callback(self.on_recv_handler, callback)
def __enter__(self):
return self
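The first hunk replaces the nested setup_callback closure with a named on_recv_handler coroutine, so on_recv only has to schedule that handler on the IO loop and the receive/reconnect loop stands on its own. A minimal usage sketch follows; the pub_client object and the payload handling are assumptions for illustration, not part of this diff:

import asyncio

def handle_publish(payload):
    # Called by on_recv_handler for every message the publisher pushes
    # that this client did not initiate.
    print("received publish payload:", payload)

async def subscribe(pub_client):
    # pub_client is assumed to be an already-connected TCPPubClient.
    pub_client.on_recv(handle_publish)  # spawns on_recv_handler on the io_loop
    while True:
        # Keep the coroutine alive so the handler keeps draining the stream.
        await asyncio.sleep(1)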
@@ -678,9 +676,6 @@ class TCPClientKeepAlive(tornado.tcpclient.TCPClient):
return stream, stream.connect(addr)
# TODO consolidate with IPCClient
# TODO: limit in-flight messages.
# TODO: singleton? Something to not re-create the tcp connection so much
class MessageClient:
"""
Low-level message sending client
@@ -699,6 +694,10 @@ class MessageClient:
source_ip=None,
source_port=None,
):
warn_until(
3008,
"MessageClient has been deprecated and will be removed.",
)
self.opts = opts
self.host = host
self.port = port
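MessageClient is deprecated here via warn_until, which emits a DeprecationWarning on current releases and is intended to become a hard error once Salt reaches 3008. Code that still constructs a MessageClient directly can suppress the warning while migrating to TCPReqClient; a sketch, assuming warn_until's default DeprecationWarning category and the constructor arguments shown in the removed TCPReqClient code further down:

import warnings

import salt.transport.tcp

def make_legacy_client(opts, host, port, io_loop):
    # Transitional shim (illustrative): keep the deprecated MessageClient
    # working without flooding logs with DeprecationWarning until callers
    # move to TCPReqClient.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", DeprecationWarning)
        return salt.transport.tcp.MessageClient(opts, host, port, io_loop=io_loop)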
@@ -1135,19 +1134,12 @@ class PubServer(tornado.tcpserver.TCPServer):
log.trace("TCP PubServer finished publishing payload")
class TCPServer:
class TCPPuller:
"""
A Tornado IPC server very similar to Tornado's TCPServer class
but using either UNIX domain sockets or TCP sockets
"""
async_methods = [
"handle_stream",
]
close_methods = [
"close",
]
def __init__(self, socket_path, io_loop=None, payload_handler=None):
"""
Create a new Tornado IPC server
@@ -1426,7 +1418,7 @@ class TCPPublishServer(salt.transport.base.DaemonizedPublishServer):
else:
pull_uri = self.pull_port
self.pull_sock = TCPServer(
self.pull_sock = TCPPuller(
pull_uri,
io_loop=io_loop,
payload_handler=publish_payload,
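TCPServer is renamed to TCPPuller, which better describes its role: it is the pull end of the publish path, accepting framed payloads from the publishing process and handing each one to payload_handler. A standalone sketch modeled on the instantiation above; the handler body, port value, and start()/IO-loop wiring are assumptions for illustration:

import tornado.ioloop

import salt.transport.tcp

async def publish_payload(payload, *args, **kwargs):
    # Illustrative handler: in TCPPublishServer this forwards the payload
    # to the PubServer, which fans it out to connected subscribers.
    print("pulled payload:", payload)

def run_puller(pull_uri=4514):
    io_loop = tornado.ioloop.IOLoop.current()
    puller = salt.transport.tcp.TCPPuller(
        pull_uri,                        # a TCP port here; a socket path also works
        io_loop=io_loop,
        payload_handler=publish_payload,
    )
    puller.start()                       # assumed to bind and begin accepting
    io_loop.start()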
@@ -1451,7 +1443,7 @@ class TCPPublishServer(salt.transport.base.DaemonizedPublishServer):
def connect(self):
log.debug("Connect pusher %s", self.pull_path)
self.pub_sock = salt.utils.asynchronous.SyncWrapper(
TCPMessageClient,
_TCPPubServerPublisher,
(self.pull_path,),
loop_kwarg="io_loop",
)
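TCPMessageClient becomes _TCPPubServerPublisher; the leading underscore marks it as an internal helper whose only job is pushing payloads from the publishing process into the pull socket above. Wrapping it in SyncWrapper lets synchronous publisher code drive the async client; a sketch of that pattern, with the socket path and payload invented for illustration:

import salt.transport.tcp
import salt.utils.asynchronous

# SyncWrapper constructs _TCPPubServerPublisher on its own IO loop (passed via
# the "io_loop" keyword named by loop_kwarg) and exposes the client's coroutine
# methods as blocking calls.
pub_sock = salt.utils.asynchronous.SyncWrapper(
    salt.transport.tcp._TCPPubServerPublisher,
    ("/var/run/salt/master/publish_pull.ipc",),   # illustrative pull_path
    loop_kwarg="io_loop",
)
pub_sock.connect()                    # blocking wrapper around the coroutine
pub_sock.send({"fun": "test.ping"})   # framed and written to the pull socket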
@@ -1483,7 +1475,7 @@ class TCPPublishServer(salt.transport.base.DaemonizedPublishServer):
self.pub_sock = None
class TCPMessageClient:
class _TCPPubServerPublisher:
"""
Salt IPC message client
@@ -1658,7 +1650,6 @@ class TCPMessageClient:
# FIXME timeout unimplemented
# FIXME tries unimplemented
# @tornado.gen.coroutine
async def send(self, msg, timeout=None, tries=None):
"""
Send a message to an IPC socket
@@ -1684,28 +1675,214 @@ class TCPReqClient(salt.transport.base.RequestClient):
def __init__(self, opts, io_loop, **kwargs): # pylint: disable=W0231
self.opts = opts
self.io_loop = io_loop
parse = urllib.parse.urlparse(self.opts["master_uri"])
master_host, master_port = parse.netloc.rsplit(":", 1)
master_addr = (master_host, int(master_port))
# self.resolver = Resolver()
resolver = kwargs.get("resolver")
self.message_client = salt.transport.tcp.MessageClient(
opts,
master_host,
int(master_port),
io_loop=io_loop,
resolver=resolver,
source_ip=opts.get("source_ip"),
source_port=opts.get("source_ret_port"),
)
self.host = master_host
self.port = int(master_port)
self._tcp_client = TCPClientKeepAlive(opts, resolver=resolver)
self.source_ip = opts.get("source_ip")
self.source_port = opts.get("source_ret_port")
self._mid = 1
self._max_messages = int((1 << 31) - 2) # number of IDs before we wrap
# TODO: max queue size
self.send_queue = [] # queue of messages to be sent
self.send_future_map = {} # mapping of request_id -> Future
self._read_until_future = None
self._on_recv = None
self._closing = False
self._closed = False
self._connecting_future = tornado.concurrent.Future()
self._stream_return_running = False
self._stream = None
async def getstream(self, **kwargs):
if self.source_ip or self.source_port:
kwargs = {
"source_ip": self.source_ip,
"source_port": self.source_port,
}
stream = None
while stream is None and (not self._closed and not self._closing):
try:
if self.host and self.port:
stream = await self._tcp_client.connect(
ip_bracket(self.host, strip=True),
self.port,
ssl_options=self.opts.get("ssl"),
**kwargs,
)
else:
sock_type = socket.AF_UNIX
path = self.url.replace("ipc://", "")
stream = tornado.iostream.IOStream(
socket.socket(sock_type, socket.SOCK_STREAM)
)
await stream.connect(path)
except Exception as exc: # pylint: disable=broad-except
log.warning(
"TCP Message Client encountered an exception while connecting to"
" %s:%s: %r, will reconnect in %d seconds",
self.host,
self.port,
exc,
self.backoff,
)
await asyncio.sleep(self.backoff)
return stream
async def connect(self):
await self.message_client.connect()
if self._stream is None:
self._stream = await self.getstream()
if self._stream:
if not self._stream_return_running:
self.task = asyncio.create_task(self._stream_return())
# self.io_loop.spawn_callback(self._stream_return)
if self.connect_callback:
self.connect_callback(True)
async def send(self, load, timeout=60):
async def _stream_return(self):
self._stream_return_running = True
unpacker = salt.utils.msgpack.Unpacker()
while not self._closing:
try:
wire_bytes = await self._stream.read_bytes(4096, partial=True)
unpacker.feed(wire_bytes)
for framed_msg in unpacker:
framed_msg = salt.transport.frame.decode_embedded_strs(framed_msg)
header = framed_msg["head"]
body = framed_msg["body"]
message_id = header.get("mid")
if message_id in self.send_future_map:
self.send_future_map.pop(message_id).set_result(body)
# self.remove_message_timeout(message_id)
else:
if self._on_recv is not None:
self.io_loop.spawn_callback(self._on_recv, header, body)
# await self._on_recv(header, body)
else:
log.error(
"Got response for message_id %s that we are not"
" tracking",
message_id,
)
except tornado.iostream.StreamClosedError as e:
log.error(
"tcp stream to %s:%s closed, unable to recv",
self.host,
self.port,
)
for future in self.send_future_map.values():
future.set_exception(e)
self.send_future_map = {}
if self._closing or self._closed:
return
if self.disconnect_callback:
self.disconnect_callback()
stream = self._stream
self._stream = None
if stream:
stream.close()
unpacker = salt.utils.msgpack.Unpacker()
await self.connect()
except TypeError:
# This is an invalid transport
if "detect_mode" in self.opts:
log.info(
"There was an error trying to use TCP transport; "
"attempting to fallback to another transport"
)
else:
raise SaltClientError
except Exception as e: # pylint: disable=broad-except
log.error("Exception parsing response", exc_info=True)
for future in self.send_future_map.values():
future.set_exception(e)
self.send_future_map = {}
if self._closing or self._closed:
return
if self.disconnect_callback:
self.disconnect_callback()
stream = self._stream
self._stream = None
if stream:
stream.close()
unpacker = salt.utils.msgpack.Unpacker()
await self.connect()
self._stream_return_running = False
def _message_id(self):
wrap = False
while self._mid in self.send_future_map:
if self._mid >= self._max_messages:
if wrap:
# this shouldn't ever happen, but just in case
raise Exception("Unable to find available messageid")
self._mid = 1
wrap = True
else:
self._mid += 1
return self._mid
def remove_message_timeout(self, message_id):
if message_id not in self.send_timeout_map:
return
timeout = self.send_timeout_map.pop(message_id)
self.io_loop.remove_timeout(timeout)
def timeout_message(self, message_id, msg):
if message_id not in self.send_future_map:
return
future = self.send_future_map.pop(message_id)
if future is not None:
future.set_exception(SaltReqTimeoutError("Message timed out"))
async def send(self, msg, timeout=None, callback=None, raw=False, reply=True):
await self.connect()
msg = await self.message_client.send(load, timeout=timeout)
return msg
if self._closing:
raise ClosingError()
while not self._stream:
await asyncio.sleep(0.03)
message_id = self._message_id()
header = {"mid": message_id}
future = tornado.concurrent.Future()
if callback is not None:
def handle_future(future):
response = future.result()
self.io_loop.add_callback(callback, response)
future.add_done_callback(handle_future)
# Add this future to the mapping
self.send_future_map[message_id] = future
if self.opts.get("detect_mode") is True:
timeout = 1
if timeout is not None:
self.io_loop.call_later(timeout, self.timeout_message, message_id, msg)
item = salt.transport.frame.frame_msg(msg, header=header)
async def _do_send():
await self.connect()
# If the _stream is None, we failed to connect.
if self._stream:
await self._stream.write(item)
# Run send in a callback so we can wait on the future, in case we time
# out before we are able to connect.
self.io_loop.add_callback(_do_send)
recv = await future
return recv
def close(self):
self.message_client.close()
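The largest addition replaces the deprecated MessageClient inside TCPReqClient itself: every request is assigned a message ID by _message_id(), framed as a {"head": {"mid": ...}, "body": ...} envelope, and _stream_return() resolves the future stored in send_future_map when a response carrying the same mid arrives. A self-contained sketch of that correlation pattern, independent of Salt's transport classes (all names below are illustrative):

import asyncio
import itertools

import salt.utils.msgpack


class MiniReqClient:
    # Toy request/response correlator illustrating the mid -> Future pattern
    # used by TCPReqClient above.

    def __init__(self):
        self._mids = itertools.count(1)
        self._pending = {}                      # mid -> asyncio.Future
        self.unpacker = salt.utils.msgpack.Unpacker()

    def feed(self, wire_bytes):
        # Feed whatever bytes arrived on the stream; resolve the future for
        # any complete frame whose mid we are tracking.
        self.unpacker.feed(wire_bytes)
        for framed in self.unpacker:
            fut = self._pending.pop(framed["head"]["mid"], None)
            if fut is not None and not fut.done():
                fut.set_result(framed["body"])

    async def request(self, send, body, timeout=60):
        # `send` is any coroutine that writes bytes, e.g. an IOStream write.
        mid = next(self._mids)
        fut = asyncio.get_running_loop().create_future()
        self._pending[mid] = fut
        wire = salt.utils.msgpack.packb({"head": {"mid": mid}, "body": body})
        await send(wire)
        try:
            return await asyncio.wait_for(fut, timeout)
        finally:
            self._pending.pop(mid, None)        # drop the entry on timeout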


@@ -998,7 +998,7 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer):
already exists "pub_close" is called before creating and connecting a
new socket.
"""
log.error("Connecting to pub server: %s", self.pull_uri)
log.debug("Connecting to pub server: %s", self.pull_uri)
self.ctx = zmq.asyncio.Context()
self.sock = self.ctx.socket(zmq.PUSH)
self.sock.setsockopt(zmq.LINGER, 300)
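The only change to the ZeroMQ transport is lowering this connect message from ERROR to DEBUG. For context, the socket being set up is a plain PUSH feeding the publisher's PULL end; a minimal standalone sketch of that pattern with pyzmq's asyncio bindings (the URI and payload are illustrative):

import asyncio

import zmq
import zmq.asyncio


async def push_payload(pull_uri="tcp://127.0.0.1:4514", payload=b"payload"):
    # Same shape as the connect path above: PUSH socket with a short LINGER so
    # close() does not block for long on unsent messages.
    ctx = zmq.asyncio.Context()
    sock = ctx.socket(zmq.PUSH)
    sock.setsockopt(zmq.LINGER, 300)
    sock.connect(pull_uri)
    await sock.send(payload)
    sock.close()
    ctx.term()


if __name__ == "__main__":
    asyncio.run(push_payload())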