mirror of
https://github.com/saltstack/salt.git
synced 2025-04-17 10:10:20 +00:00
Revert changes to MessageClient other than deprecation
This commit is contained in:
parent
13190ff89a
commit
1d1ca7b6cb
2 changed files with 43 additions and 99 deletions
|
@ -733,7 +733,6 @@ class MessageClient:
|
|||
opts,
|
||||
host,
|
||||
port,
|
||||
url=None,
|
||||
io_loop=None,
|
||||
resolver=None,
|
||||
connect_callback=None,
|
||||
|
@ -748,7 +747,6 @@ class MessageClient:
|
|||
self.opts = opts
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.url = url
|
||||
self.source_ip = source_ip
|
||||
self.source_port = source_port
|
||||
self.connect_callback = connect_callback
|
||||
|
@ -771,28 +769,20 @@ class MessageClient:
|
|||
self._stream = None
|
||||
|
||||
self.backoff = opts.get("tcp_reconnect_backoff", 1)
|
||||
self.callbacks = {}
|
||||
self.unpacker = salt.utils.msgpack.Unpacker()
|
||||
self._read_in_progress = Lock()
|
||||
self.task = None
|
||||
self.tcp_client = None
|
||||
|
||||
def _stop_io_loop(self):
|
||||
if self.io_loop is not None:
|
||||
self.io_loop.stop()
|
||||
|
||||
# TODO: timeout inflight sessions
|
||||
def close(self):
|
||||
if self._closing:
|
||||
return
|
||||
self._closing = True
|
||||
if self._stream is not None:
|
||||
self._stream.close()
|
||||
if self.tcp_client is not None:
|
||||
self.tcp_client.close()
|
||||
if self.task is not None:
|
||||
self.task.cancel()
|
||||
self.io_loop.add_timeout(1, self.check_close)
|
||||
|
||||
async def check_close(self):
|
||||
@tornado.gen.coroutine
|
||||
def check_close(self):
|
||||
if not self.send_future_map:
|
||||
self._tcp_client.close()
|
||||
self._stream = None
|
||||
|
@ -803,14 +793,12 @@ class MessageClient:
|
|||
|
||||
# pylint: disable=W1701
|
||||
def __del__(self):
|
||||
if not self._closing:
|
||||
warnings.warn(
|
||||
"unclosed message client {self!r}", ResourceWarning, source=self
|
||||
)
|
||||
self.close()
|
||||
|
||||
# pylint: enable=W1701
|
||||
|
||||
async def getstream(self, **kwargs):
|
||||
@tornado.gen.coroutine
|
||||
def getstream(self, **kwargs):
|
||||
if self.source_ip or self.source_port:
|
||||
kwargs = {
|
||||
"source_ip": self.source_ip,
|
||||
|
@ -819,20 +807,12 @@ class MessageClient:
|
|||
stream = None
|
||||
while stream is None and (not self._closed and not self._closing):
|
||||
try:
|
||||
if self.host and self.port:
|
||||
stream = await self._tcp_client.connect(
|
||||
ip_bracket(self.host, strip=True),
|
||||
self.port,
|
||||
ssl_options=self.opts.get("ssl"),
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
sock_type = socket.AF_UNIX
|
||||
path = self.url.replace("ipc://", "")
|
||||
stream = tornado.iostream.IOStream(
|
||||
socket.socket(sock_type, socket.SOCK_STREAM)
|
||||
)
|
||||
await stream.connect(path)
|
||||
stream = yield self._tcp_client.connect(
|
||||
ip_bracket(self.host, strip=True),
|
||||
self.port,
|
||||
ssl_options=self.opts.get("ssl"),
|
||||
**kwargs
|
||||
)
|
||||
except Exception as exc: # pylint: disable=broad-except
|
||||
log.warning(
|
||||
"TCP Message Client encountered an exception while connecting to"
|
||||
|
@ -842,24 +822,26 @@ class MessageClient:
|
|||
exc,
|
||||
self.backoff,
|
||||
)
|
||||
await asyncio.sleep(self.backoff)
|
||||
return stream
|
||||
yield tornado.gen.sleep(self.backoff)
|
||||
raise tornado.gen.Return(stream)
|
||||
|
||||
async def connect(self):
|
||||
@tornado.gen.coroutine
|
||||
def connect(self):
|
||||
if self._stream is None:
|
||||
self._stream = await self.getstream()
|
||||
self._stream = yield self.getstream()
|
||||
if self._stream:
|
||||
if not self._stream_return_running:
|
||||
self.task = asyncio.create_task(self._stream_return())
|
||||
self.io_loop.spawn_callback(self._stream_return)
|
||||
if self.connect_callback:
|
||||
self.connect_callback(True)
|
||||
|
||||
async def _stream_return(self):
|
||||
@tornado.gen.coroutine
|
||||
def _stream_return(self):
|
||||
self._stream_return_running = True
|
||||
unpacker = salt.utils.msgpack.Unpacker()
|
||||
while not self._closing:
|
||||
try:
|
||||
wire_bytes = await self._stream.read_bytes(4096, partial=True)
|
||||
wire_bytes = yield self._stream.read_bytes(4096, partial=True)
|
||||
unpacker.feed(wire_bytes)
|
||||
for framed_msg in unpacker:
|
||||
framed_msg = salt.transport.frame.decode_embedded_strs(framed_msg)
|
||||
|
@ -873,7 +855,6 @@ class MessageClient:
|
|||
else:
|
||||
if self._on_recv is not None:
|
||||
self.io_loop.spawn_callback(self._on_recv, header, body)
|
||||
# await self._on_recv(header, body)
|
||||
else:
|
||||
log.error(
|
||||
"Got response for message_id %s that we are not"
|
||||
|
@ -881,7 +862,7 @@ class MessageClient:
|
|||
message_id,
|
||||
)
|
||||
except tornado.iostream.StreamClosedError as e:
|
||||
log.error(
|
||||
log.debug(
|
||||
"tcp stream to %s:%s closed, unable to recv",
|
||||
self.host,
|
||||
self.port,
|
||||
|
@ -898,7 +879,7 @@ class MessageClient:
|
|||
if stream:
|
||||
stream.close()
|
||||
unpacker = salt.utils.msgpack.Unpacker()
|
||||
await self.connect()
|
||||
yield self.connect()
|
||||
except TypeError:
|
||||
# This is an invalid transport
|
||||
if "detect_mode" in self.opts:
|
||||
|
@ -922,7 +903,7 @@ class MessageClient:
|
|||
if stream:
|
||||
stream.close()
|
||||
unpacker = salt.utils.msgpack.Unpacker()
|
||||
await self.connect()
|
||||
yield self.connect()
|
||||
self._stream_return_running = False
|
||||
|
||||
def _message_id(self):
|
||||
|
@ -944,7 +925,14 @@ class MessageClient:
|
|||
"""
|
||||
Register a callback for received messages (that we didn't initiate)
|
||||
"""
|
||||
self._on_recv = callback
|
||||
if callback is None:
|
||||
self._on_recv = callback
|
||||
else:
|
||||
|
||||
def wrap_recv(header, body):
|
||||
callback(body)
|
||||
|
||||
self._on_recv = wrap_recv
|
||||
|
||||
def remove_message_timeout(self, message_id):
|
||||
if message_id not in self.send_timeout_map:
|
||||
|
@ -959,13 +947,13 @@ class MessageClient:
|
|||
if future is not None:
|
||||
future.set_exception(SaltReqTimeoutError("Message timed out"))
|
||||
|
||||
async def send(self, msg, timeout=None, callback=None, raw=False, reply=True):
|
||||
@tornado.gen.coroutine
|
||||
def send(self, msg, timeout=None, callback=None, raw=False):
|
||||
if self._closing:
|
||||
raise ClosingError()
|
||||
while not self._stream:
|
||||
await asyncio.sleep(0.03)
|
||||
message_id = self._message_id()
|
||||
header = {"mid": message_id}
|
||||
|
||||
future = tornado.concurrent.Future()
|
||||
|
||||
if callback is not None:
|
||||
|
@ -986,58 +974,18 @@ class MessageClient:
|
|||
|
||||
item = salt.transport.frame.frame_msg(msg, header=header)
|
||||
|
||||
async def _do_send():
|
||||
await self.connect()
|
||||
@tornado.gen.coroutine
|
||||
def _do_send():
|
||||
yield self.connect()
|
||||
# If the _stream is None, we failed to connect.
|
||||
if self._stream:
|
||||
await self._stream.write(item)
|
||||
yield self._stream.write(item)
|
||||
|
||||
# Run send in a callback so we can wait on the future, in case we time
|
||||
# out before we are able to connect.
|
||||
self.io_loop.add_callback(_do_send)
|
||||
recv = await future
|
||||
return recv
|
||||
|
||||
async def recv(self, timeout=None):
|
||||
try:
|
||||
await self._read_in_progress.acquire(timeout=0.00000001)
|
||||
except tornado.gen.TimeoutError:
|
||||
log.error("Timeout Error")
|
||||
return
|
||||
try:
|
||||
if timeout == 0:
|
||||
for msg in self.unpacker:
|
||||
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
|
||||
return framed_msg["body"]
|
||||
poller = select.poll()
|
||||
poller.register(self._stream.socket, select.POLLIN)
|
||||
try:
|
||||
events = poller.poll(0)
|
||||
except TimeoutError:
|
||||
events = []
|
||||
if events:
|
||||
while True:
|
||||
byts = await self._stream.read_bytes(4096, partial=True)
|
||||
self.unpacker.feed(byts)
|
||||
for msg in self.unpacker:
|
||||
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
|
||||
return framed_msg["body"]
|
||||
else:
|
||||
return
|
||||
elif timeout:
|
||||
return await asyncio.wait_for(self.recv(), timeout=timeout)
|
||||
else:
|
||||
for msg in self.unpacker:
|
||||
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
|
||||
return framed_msg["body"]
|
||||
while True:
|
||||
byts = await self._stream.read_bytes(4096, partial=True)
|
||||
self.unpacker.feed(byts)
|
||||
for msg in self.unpacker:
|
||||
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
|
||||
return framed_msg["body"]
|
||||
finally:
|
||||
self._read_in_progress.release()
|
||||
recv = yield future
|
||||
raise tornado.gen.Return(recv)
|
||||
|
||||
|
||||
class Subscriber:
|
||||
|
|
|
@ -22,11 +22,7 @@ def maintenence(maintenence_opts):
|
|||
"""
|
||||
The master's Maintenence class
|
||||
"""
|
||||
maintenence = salt.master.Maintenance(maintenence_opts)
|
||||
try:
|
||||
yield maintenence
|
||||
finally:
|
||||
pass
|
||||
return salt.master.Maintenance(maintenence_opts)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
Loading…
Add table
Reference in a new issue