mirror of
https://github.com/saltstack/salt.git
synced 2025-04-17 10:10:20 +00:00
Cleanup and refactor TCP transport
This commit is contained in:
parent
b4e407a8a9
commit
31c59ce450
2 changed files with 229 additions and 52 deletions
|
@ -392,33 +392,31 @@ class TCPPubClient(salt.transport.base.PublishClient):
|
|||
finally:
|
||||
self._read_in_progress.release()
|
||||
|
||||
async def on_recv_handler(self, callback):
    """
    Receive messages forever and pass each one to ``callback``.

    Polls until ``self._stream`` is set, then loops on ``self.recv()``.
    When the stream is closed, the old stream is closed and dropped,
    ``self._connect()`` is awaited, ``self.disconnect_callback`` is invoked
    (when set) and the unpacker is replaced before receiving resumes.

    :param callback: Callable invoked synchronously with each received
        message.
    """
    while not self._stream:
        # No stream yet; poll until one has been established.
        await asyncio.sleep(0.003)
    while True:
        try:
            msg = await self.recv()
        except tornado.iostream.StreamClosedError:
            self._stream.close()
            self._stream = None
            await self._connect()
            if self.disconnect_callback:
                self.disconnect_callback()
            # Drop any partially buffered frame from the old connection.
            self.unpacker = salt.utils.msgpack.Unpacker()
            continue
        except Exception:  # pylint: disable=broad-except
            log.error("Other exception", exc_info=True)
            # Nothing was received on this pass: ``msg`` is either unbound
            # (first iteration) or still holds the previous message, so the
            # callback must not run. Yield briefly so a persistent failure
            # cannot monopolise the event loop.
            await asyncio.sleep(0.003)
            continue
        # Per-message trace; this is diagnostic output, not an error.
        log.debug("on recv got msg %r", msg)
        callback(msg)
|
||||
|
||||
def on_recv(self, callback):
|
||||
"""
|
||||
Register a callback for received messages (that we didn't initiate)
|
||||
"""
|
||||
|
||||
async def setup_callback():
|
||||
while not self._stream:
|
||||
await asyncio.sleep(0.003)
|
||||
while True:
|
||||
try:
|
||||
msg = await self.recv()
|
||||
logit = True
|
||||
except tornado.iostream.StreamClosedError:
|
||||
self._stream.close()
|
||||
self._stream = None
|
||||
await self._connect()
|
||||
await asyncio.sleep(0.03)
|
||||
# if self.disconnect_callback:
|
||||
# self.disconnect_callback()
|
||||
self.unpacker = salt.utils.msgpack.Unpacker()
|
||||
continue
|
||||
except Exception:
|
||||
log.error("Other exception", exc_info=True)
|
||||
log.error("on recv got msg %r", msg)
|
||||
callback(msg)
|
||||
|
||||
self.io_loop.spawn_callback(setup_callback)
|
||||
self.io_loop.spawn_callback(self.on_recv_handler, callback)
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
@ -678,9 +676,6 @@ class TCPClientKeepAlive(tornado.tcpclient.TCPClient):
|
|||
return stream, stream.connect(addr)
|
||||
|
||||
|
||||
# TODO consolidate with IPCClient
|
||||
# TODO: limit in-flight messages.
|
||||
# TODO: singleton? Something to not re-create the tcp connection so much
|
||||
class MessageClient:
|
||||
"""
|
||||
Low-level message sending client
|
||||
|
@ -699,6 +694,10 @@ class MessageClient:
|
|||
source_ip=None,
|
||||
source_port=None,
|
||||
):
|
||||
warn_until(
|
||||
3008,
|
||||
"MessageClient has been deprecated and will be removed.",
|
||||
)
|
||||
self.opts = opts
|
||||
self.host = host
|
||||
self.port = port
|
||||
|
@ -1135,19 +1134,12 @@ class PubServer(tornado.tcpserver.TCPServer):
|
|||
log.trace("TCP PubServer finished publishing payload")
|
||||
|
||||
|
||||
class TCPServer:
|
||||
class TCPPuller:
|
||||
"""
|
||||
A Tornado IPC server very similar to Tornado's TCPServer class
|
||||
but using either UNIX domain sockets or TCP sockets
|
||||
"""
|
||||
|
||||
async_methods = [
|
||||
"handle_stream",
|
||||
]
|
||||
close_methods = [
|
||||
"close",
|
||||
]
|
||||
|
||||
def __init__(self, socket_path, io_loop=None, payload_handler=None):
|
||||
"""
|
||||
Create a new Tornado IPC server
|
||||
|
@ -1426,7 +1418,7 @@ class TCPPublishServer(salt.transport.base.DaemonizedPublishServer):
|
|||
else:
|
||||
pull_uri = self.pull_port
|
||||
|
||||
self.pull_sock = TCPServer(
|
||||
self.pull_sock = TCPPuller(
|
||||
pull_uri,
|
||||
io_loop=io_loop,
|
||||
payload_handler=publish_payload,
|
||||
|
@ -1451,7 +1443,7 @@ class TCPPublishServer(salt.transport.base.DaemonizedPublishServer):
|
|||
def connect(self):
|
||||
log.debug("Connect pusher %s", self.pull_path)
|
||||
self.pub_sock = salt.utils.asynchronous.SyncWrapper(
|
||||
TCPMessageClient,
|
||||
_TCPPubServerPublisher,
|
||||
(self.pull_path,),
|
||||
loop_kwarg="io_loop",
|
||||
)
|
||||
|
@ -1483,7 +1475,7 @@ class TCPPublishServer(salt.transport.base.DaemonizedPublishServer):
|
|||
self.pub_sock = None
|
||||
|
||||
|
||||
class TCPMessageClient:
|
||||
class _TCPPubServerPublisher:
|
||||
"""
|
||||
Salt IPC message client
|
||||
|
||||
|
@ -1658,7 +1650,6 @@ class TCPMessageClient:
|
|||
|
||||
# FIXME timeout unimplemented
|
||||
# FIXME tries unimplemented
|
||||
# @tornado.gen.coroutine
|
||||
async def send(self, msg, timeout=None, tries=None):
|
||||
"""
|
||||
Send a message to an IPC socket
|
||||
|
@ -1684,28 +1675,214 @@ class TCPReqClient(salt.transport.base.RequestClient):
|
|||
def __init__(self, opts, io_loop, **kwargs): # pylint: disable=W0231
|
||||
self.opts = opts
|
||||
self.io_loop = io_loop
|
||||
|
||||
parse = urllib.parse.urlparse(self.opts["master_uri"])
|
||||
master_host, master_port = parse.netloc.rsplit(":", 1)
|
||||
master_addr = (master_host, int(master_port))
|
||||
# self.resolver = Resolver()
|
||||
resolver = kwargs.get("resolver")
|
||||
self.message_client = salt.transport.tcp.MessageClient(
|
||||
opts,
|
||||
master_host,
|
||||
int(master_port),
|
||||
io_loop=io_loop,
|
||||
resolver=resolver,
|
||||
source_ip=opts.get("source_ip"),
|
||||
source_port=opts.get("source_ret_port"),
|
||||
)
|
||||
|
||||
self.host = master_host
|
||||
self.port = int(master_port)
|
||||
self._tcp_client = TCPClientKeepAlive(opts, resolver=resolver)
|
||||
self.source_ip = opts.get("source_ip")
|
||||
self.source_port = opts.get("source_ret_port")
|
||||
self._mid = 1
|
||||
self._max_messages = int((1 << 31) - 2) # number of IDs before we wrap
|
||||
# TODO: max queue size
|
||||
self.send_queue = [] # queue of messages to be sent
|
||||
self.send_future_map = {} # mapping of request_id -> Future
|
||||
|
||||
self._read_until_future = None
|
||||
self._on_recv = None
|
||||
self._closing = False
|
||||
self._closed = False
|
||||
self._connecting_future = tornado.concurrent.Future()
|
||||
self._stream_return_running = False
|
||||
self._stream = None
|
||||
|
||||
async def getstream(self, **kwargs):
    """
    Open a stream to the peer, retrying until it connects or the client
    is closing.

    When ``self.host`` and ``self.port`` are both set, the connection is
    made through ``self._tcp_client``; otherwise a UNIX socket stream is
    connected to the path taken from ``self.url``.

    :param kwargs: Extra keyword arguments forwarded to
        ``self._tcp_client.connect``. ``source_ip`` / ``source_port`` are
        set from the instance when either is configured.
    :return: The connected stream, or ``None`` when the client was closed
        (or closing) before a connection could be made.
    """
    if self.source_ip or self.source_port:
        # Merge instead of replacing so caller supplied options survive;
        # the configured source address still takes precedence.
        kwargs.update(source_ip=self.source_ip, source_port=self.source_port)
    stream = None
    while stream is None and not (self._closed or self._closing):
        candidate = None
        try:
            if self.host and self.port:
                stream = await self._tcp_client.connect(
                    ip_bracket(self.host, strip=True),
                    self.port,
                    ssl_options=self.opts.get("ssl"),
                    **kwargs,
                )
            else:
                # NOTE(review): ``self.url`` is not assigned in the visible
                # ``__init__`` -- confirm where it is set.
                path = self.url.replace("ipc://", "")
                candidate = tornado.iostream.IOStream(
                    socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
                )
                await candidate.connect(path)
                # Only publish the stream once it is actually connected,
                # otherwise a failed connect would end the retry loop and
                # hand back a dead stream.
                stream = candidate
        except Exception as exc:  # pylint: disable=broad-except
            if candidate is not None:
                candidate.close()
            # ``backoff`` is not assigned in the visible ``__init__``; fall
            # back to the ``tcp_reconnect_backoff`` option.
            backoff = getattr(self, "backoff", None)
            if backoff is None:
                backoff = self.opts.get("tcp_reconnect_backoff", 1)
            log.warning(
                "TCP Message Client encountered an exception while connecting to"
                " %s:%s: %r, will reconnect in %d seconds",
                self.host,
                self.port,
                exc,
                backoff,
            )
            await asyncio.sleep(backoff)
    return stream
|
||||
|
||||
async def connect(self):
|
||||
await self.message_client.connect()
|
||||
if self._stream is None:
|
||||
self._stream = await self.getstream()
|
||||
if self._stream:
|
||||
if not self._stream_return_running:
|
||||
self.task = asyncio.create_task(self._stream_return())
|
||||
# self.io_loop.spawn_callback(self._stream_return)
|
||||
if self.connect_callback:
|
||||
self.connect_callback(True)
|
||||
|
||||
async def send(self, load, timeout=60):
|
||||
async def _stream_return(self):
    """
    Read framed replies from ``self._stream`` until the client is closing.

    Each decoded frame is matched by its ``mid`` header against
    ``self.send_future_map``; a match resolves the pending future with the
    frame body. Unmatched frames are passed to ``self._on_recv`` when it is
    set, otherwise an error is logged.

    When the stream closes or a frame cannot be processed, every pending
    future is failed with that exception and, unless the client is closing,
    the stream is dropped and ``self.connect()`` is awaited.

    :raises SaltClientError: On ``TypeError`` when ``detect_mode`` is not
        present in ``self.opts``.
    """
    self._stream_return_running = True
    unpacker = salt.utils.msgpack.Unpacker()

    async def _recover(exc):
        """
        Fail all pending requests with ``exc`` and reconnect.

        Returns ``True`` when the client is closing and the reader must stop.
        """
        pending, self.send_future_map = self.send_future_map, {}
        for future in pending.values():
            # A future the caller already cancelled or timed out cannot take
            # another outcome; setting one would raise InvalidStateError.
            if not future.done():
                future.set_exception(exc)
        if self._closing or self._closed:
            return True
        # ``disconnect_callback`` is not assigned in the visible ``__init__``.
        disconnect_callback = getattr(self, "disconnect_callback", None)
        if disconnect_callback:
            disconnect_callback()
        stream, self._stream = self._stream, None
        if stream:
            stream.close()
        await self.connect()
        return False

    try:
        while not self._closing:
            try:
                wire_bytes = await self._stream.read_bytes(4096, partial=True)
                unpacker.feed(wire_bytes)
                for framed_msg in unpacker:
                    framed_msg = salt.transport.frame.decode_embedded_strs(framed_msg)
                    header = framed_msg["head"]
                    body = framed_msg["body"]
                    message_id = header.get("mid")

                    future = self.send_future_map.pop(message_id, None)
                    if future is not None:
                        if not future.done():
                            future.set_result(body)
                    elif self._on_recv is not None:
                        self.io_loop.spawn_callback(self._on_recv, header, body)
                    else:
                        log.error(
                            "Got response for message_id %s that we are not"
                            " tracking",
                            message_id,
                        )
            except tornado.iostream.StreamClosedError as e:
                log.error(
                    "tcp stream to %s:%s closed, unable to recv",
                    self.host,
                    self.port,
                )
                if await _recover(e):
                    return
                # Discard any partial frame from the old connection.
                unpacker = salt.utils.msgpack.Unpacker()
            except TypeError:
                # This is an invalid transport
                if "detect_mode" in self.opts:
                    log.info(
                        "There was an error trying to use TCP transport; "
                        "attempting to fallback to another transport"
                    )
                else:
                    raise SaltClientError
            except Exception as e:  # pylint: disable=broad-except
                log.error("Exception parsing response", exc_info=True)
                if await _recover(e):
                    return
                unpacker = salt.utils.msgpack.Unpacker()
    finally:
        # Clear the flag on every exit path (return, raise, cancellation) so
        # a later connect() is able to start a new reader.
        self._stream_return_running = False
|
||||
|
||||
def _message_id(self):
    """
    Return a message id that is not currently in ``self.send_future_map``.

    Starts from ``self._mid`` and advances it past ids that are in use,
    wrapping back to 1 once ``self._max_messages`` is reached.

    :raises Exception: When the id space has been wrapped once already and
        no free id was found.
    """
    wrapped = False
    while True:
        if self._mid not in self.send_future_map:
            return self._mid
        if self._mid < self._max_messages:
            self._mid += 1
            continue
        if wrapped:
            # this shouldn't ever happen, but just in case
            raise Exception("Unable to find available messageid")
        self._mid = 1
        wrapped = True
|
||||
|
||||
def remove_message_timeout(self, message_id):
    """
    Cancel the timeout registered for ``message_id``, if there is one.

    Looks the id up in ``self.send_timeout_map`` and, when present, removes
    the entry and passes the stored handle to
    ``self.io_loop.remove_timeout``.

    :param message_id: Id of the request whose timeout should be cancelled.
    """
    # ``send_timeout_map`` is not created in the visible ``__init__``; treat
    # a missing map the same as "no timeout registered" instead of raising.
    timeouts = getattr(self, "send_timeout_map", None)
    if not timeouts or message_id not in timeouts:
        return
    self.io_loop.remove_timeout(timeouts.pop(message_id))
|
||||
|
||||
def timeout_message(self, message_id, msg):
    """
    Fail the pending request ``message_id`` with a timeout.

    The request's future is removed from ``self.send_future_map`` and, when
    it is still pending, completed with ``SaltReqTimeoutError``. Ids that
    are no longer tracked are ignored.

    :param message_id: Id of the request that timed out.
    :param msg: The original message; not used by this method.
    """
    future = self.send_future_map.pop(message_id, None)
    # A future that was already resolved or cancelled cannot take another
    # outcome; set_exception would raise InvalidStateError in the callback.
    if future is None or future.done():
        return
    future.set_exception(SaltReqTimeoutError("Message timed out"))
|
||||
|
||||
async def send(self, msg, timeout=None, callback=None, raw=False, reply=True):
    """
    Send ``msg`` over the stream and wait for the matching reply.

    The message is framed with a fresh ``mid`` header, a future for the
    reply is stored in ``self.send_future_map`` and the write is scheduled
    on the io loop. The coroutine then waits on that future.

    :param msg: Payload passed to ``salt.transport.frame.frame_msg``.
    :param timeout: Seconds before the request is failed through
        ``timeout_message``; ``None`` disables the timeout. Forced to 1
        when ``detect_mode`` is ``True`` in ``self.opts``.
    :param callback: Optional callable scheduled on the io loop with the
        reply once the request completes successfully.
    :param raw: Accepted but not used by this implementation.
    :param reply: Accepted but not used by this implementation.
    :return: The reply body.
    :raises ClosingError: When the client is closing.
    """
    await self.connect()
    if self._closing:
        raise ClosingError()
    while not self._stream:
        await asyncio.sleep(0.03)
        # Without this check a close() during the wait would leave the
        # caller polling forever for a stream that will never appear.
        if self._closing:
            raise ClosingError()
    message_id = self._message_id()
    header = {"mid": message_id}
    future = tornado.concurrent.Future()

    if callback is not None:

        def handle_future(future):
            # Only successful replies are delivered to the callback.
            if future.cancelled() or future.exception() is not None:
                return
            self.io_loop.add_callback(callback, future.result())

        future.add_done_callback(handle_future)
    # Add this future to the mapping
    self.send_future_map[message_id] = future

    if self.opts.get("detect_mode") is True:
        timeout = 1

    if timeout is not None:
        timer = self.io_loop.call_later(
            timeout, self.timeout_message, message_id, msg
        )
        # Message ids are reused once a request completes, so a timer left
        # running could expire a later request that received the same id.
        future.add_done_callback(lambda _done: self.io_loop.remove_timeout(timer))

    item = salt.transport.frame.frame_msg(msg, header=header)

    async def _do_send():
        await self.connect()
        # If the _stream is None, we failed to connect.
        if self._stream:
            await self._stream.write(item)

    # Run send in a callback so we can wait on the future, in case we time
    # out before we are able to connect.
    self.io_loop.add_callback(_do_send)
    recv = await future
    return recv
|
||||
|
||||
def close(self):
    """
    Shut the client down and release its connection.

    Marks the client as closing (which ends the retry loop in
    ``getstream`` and the read loop in ``_stream_return``), closes
    ``self.message_client`` when the instance has one, closes and drops
    ``self._stream`` and finally marks the client as closed. Calling it
    again is a no-op.
    """
    if self._closing or self._closed:
        return
    self._closing = True
    message_client = getattr(self, "message_client", None)
    if message_client is not None:
        message_client.close()
    stream, self._stream = self._stream, None
    if stream is not None:
        stream.close()
    self._closed = True
|
||||
|
|
|
@ -998,7 +998,7 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer):
|
|||
already exists "pub_close" is called before creating and connecting a
|
||||
new socket.
|
||||
"""
|
||||
log.error("Connecting to pub server: %s", self.pull_uri)
|
||||
log.debug("Connecting to pub server: %s", self.pull_uri)
|
||||
self.ctx = zmq.asyncio.Context()
|
||||
self.sock = self.ctx.socket(zmq.PUSH)
|
||||
self.sock.setsockopt(zmq.LINGER, 300)
|
||||
|
|
Loading…
Add table
Reference in a new issue