Revert changes to MessageClient other than deprecation

This commit is contained in:
Daniel A. Wozniak 2023-07-31 14:29:02 -07:00 committed by Gareth J. Greenaway
parent 13190ff89a
commit 1d1ca7b6cb
2 changed files with 43 additions and 99 deletions

View file

@ -733,7 +733,6 @@ class MessageClient:
opts,
host,
port,
url=None,
io_loop=None,
resolver=None,
connect_callback=None,
@ -748,7 +747,6 @@ class MessageClient:
self.opts = opts
self.host = host
self.port = port
self.url = url
self.source_ip = source_ip
self.source_port = source_port
self.connect_callback = connect_callback
@ -771,28 +769,20 @@ class MessageClient:
self._stream = None
self.backoff = opts.get("tcp_reconnect_backoff", 1)
self.callbacks = {}
self.unpacker = salt.utils.msgpack.Unpacker()
self._read_in_progress = Lock()
self.task = None
self.tcp_client = None
def _stop_io_loop(self):
if self.io_loop is not None:
self.io_loop.stop()
# TODO: timeout inflight sessions
def close(self):
if self._closing:
return
self._closing = True
if self._stream is not None:
self._stream.close()
if self.tcp_client is not None:
self.tcp_client.close()
if self.task is not None:
self.task.cancel()
self.io_loop.add_timeout(1, self.check_close)
async def check_close(self):
@tornado.gen.coroutine
def check_close(self):
if not self.send_future_map:
self._tcp_client.close()
self._stream = None
@ -803,14 +793,12 @@ class MessageClient:
# pylint: disable=W1701
def __del__(self):
if not self._closing:
warnings.warn(
"unclosed message client {self!r}", ResourceWarning, source=self
)
self.close()
# pylint: enable=W1701
async def getstream(self, **kwargs):
@tornado.gen.coroutine
def getstream(self, **kwargs):
if self.source_ip or self.source_port:
kwargs = {
"source_ip": self.source_ip,
@ -819,20 +807,12 @@ class MessageClient:
stream = None
while stream is None and (not self._closed and not self._closing):
try:
if self.host and self.port:
stream = await self._tcp_client.connect(
ip_bracket(self.host, strip=True),
self.port,
ssl_options=self.opts.get("ssl"),
**kwargs,
)
else:
sock_type = socket.AF_UNIX
path = self.url.replace("ipc://", "")
stream = tornado.iostream.IOStream(
socket.socket(sock_type, socket.SOCK_STREAM)
)
await stream.connect(path)
stream = yield self._tcp_client.connect(
ip_bracket(self.host, strip=True),
self.port,
ssl_options=self.opts.get("ssl"),
**kwargs
)
except Exception as exc: # pylint: disable=broad-except
log.warning(
"TCP Message Client encountered an exception while connecting to"
@ -842,24 +822,26 @@ class MessageClient:
exc,
self.backoff,
)
await asyncio.sleep(self.backoff)
return stream
yield tornado.gen.sleep(self.backoff)
raise tornado.gen.Return(stream)
async def connect(self):
@tornado.gen.coroutine
def connect(self):
if self._stream is None:
self._stream = await self.getstream()
self._stream = yield self.getstream()
if self._stream:
if not self._stream_return_running:
self.task = asyncio.create_task(self._stream_return())
self.io_loop.spawn_callback(self._stream_return)
if self.connect_callback:
self.connect_callback(True)
async def _stream_return(self):
@tornado.gen.coroutine
def _stream_return(self):
self._stream_return_running = True
unpacker = salt.utils.msgpack.Unpacker()
while not self._closing:
try:
wire_bytes = await self._stream.read_bytes(4096, partial=True)
wire_bytes = yield self._stream.read_bytes(4096, partial=True)
unpacker.feed(wire_bytes)
for framed_msg in unpacker:
framed_msg = salt.transport.frame.decode_embedded_strs(framed_msg)
@ -873,7 +855,6 @@ class MessageClient:
else:
if self._on_recv is not None:
self.io_loop.spawn_callback(self._on_recv, header, body)
# await self._on_recv(header, body)
else:
log.error(
"Got response for message_id %s that we are not"
@ -881,7 +862,7 @@ class MessageClient:
message_id,
)
except tornado.iostream.StreamClosedError as e:
log.error(
log.debug(
"tcp stream to %s:%s closed, unable to recv",
self.host,
self.port,
@ -898,7 +879,7 @@ class MessageClient:
if stream:
stream.close()
unpacker = salt.utils.msgpack.Unpacker()
await self.connect()
yield self.connect()
except TypeError:
# This is an invalid transport
if "detect_mode" in self.opts:
@ -922,7 +903,7 @@ class MessageClient:
if stream:
stream.close()
unpacker = salt.utils.msgpack.Unpacker()
await self.connect()
yield self.connect()
self._stream_return_running = False
def _message_id(self):
@ -944,7 +925,14 @@ class MessageClient:
"""
Register a callback for received messages (that we didn't initiate)
"""
self._on_recv = callback
if callback is None:
self._on_recv = callback
else:
def wrap_recv(header, body):
callback(body)
self._on_recv = wrap_recv
def remove_message_timeout(self, message_id):
if message_id not in self.send_timeout_map:
@ -959,13 +947,13 @@ class MessageClient:
if future is not None:
future.set_exception(SaltReqTimeoutError("Message timed out"))
async def send(self, msg, timeout=None, callback=None, raw=False, reply=True):
@tornado.gen.coroutine
def send(self, msg, timeout=None, callback=None, raw=False):
if self._closing:
raise ClosingError()
while not self._stream:
await asyncio.sleep(0.03)
message_id = self._message_id()
header = {"mid": message_id}
future = tornado.concurrent.Future()
if callback is not None:
@ -986,58 +974,18 @@ class MessageClient:
item = salt.transport.frame.frame_msg(msg, header=header)
async def _do_send():
await self.connect()
@tornado.gen.coroutine
def _do_send():
yield self.connect()
# If the _stream is None, we failed to connect.
if self._stream:
await self._stream.write(item)
yield self._stream.write(item)
# Run send in a callback so we can wait on the future, in case we time
# out before we are able to connect.
self.io_loop.add_callback(_do_send)
recv = await future
return recv
async def recv(self, timeout=None):
try:
await self._read_in_progress.acquire(timeout=0.00000001)
except tornado.gen.TimeoutError:
log.error("Timeout Error")
return
try:
if timeout == 0:
for msg in self.unpacker:
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
return framed_msg["body"]
poller = select.poll()
poller.register(self._stream.socket, select.POLLIN)
try:
events = poller.poll(0)
except TimeoutError:
events = []
if events:
while True:
byts = await self._stream.read_bytes(4096, partial=True)
self.unpacker.feed(byts)
for msg in self.unpacker:
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
return framed_msg["body"]
else:
return
elif timeout:
return await asyncio.wait_for(self.recv(), timeout=timeout)
else:
for msg in self.unpacker:
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
return framed_msg["body"]
while True:
byts = await self._stream.read_bytes(4096, partial=True)
self.unpacker.feed(byts)
for msg in self.unpacker:
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
return framed_msg["body"]
finally:
self._read_in_progress.release()
recv = yield future
raise tornado.gen.Return(recv)
class Subscriber:

View file

@ -22,11 +22,7 @@ def maintenence(maintenence_opts):
"""
The master's Maintenence class
"""
maintenence = salt.master.Maintenance(maintenence_opts)
try:
yield maintenence
finally:
pass
return salt.master.Maintenance(maintenence_opts)
@pytest.fixture