Preserve futures with messages

This also make the message client behave more like it does after having
been refactored to use native asyncio coroutines.
This commit is contained in:
Daniel A. Wozniak 2023-09-27 18:00:41 -07:00 committed by Pedro Algarvio
parent f6ad1a5f94
commit 8a872eff08

View file

@ -12,6 +12,7 @@ from random import randint
import zmq.error
import zmq.eventloop.zmqstream
import zmq.eventloop.future
import salt.ext.tornado
import salt.ext.tornado.concurrent
@ -513,17 +514,17 @@ class AsyncReqMessageClient:
else:
self.io_loop = io_loop
self.context = zmq.Context()
self.context = zmq.eventloop.future.Context()
self.send_queue = []
self._closing = False
self._future = None
self._send_future_map = {}
self.lock = salt.ext.tornado.locks.Lock()
self.ident = threading.get_ident()
def connect(self):
if hasattr(self, "stream"):
if hasattr(self, "socket") and self.socket:
return
# wire up sockets
self._init_socket()
@ -539,24 +540,10 @@ class AsyncReqMessageClient:
return
else:
self._closing = True
if hasattr(self, "stream") and self.stream is not None:
if ZMQ_VERSION_INFO < (14, 3, 0):
# stream.close() doesn't work properly on pyzmq < 14.3.0
if self.stream.socket:
self.stream.socket.close()
self.stream.io_loop.remove_handler(self.stream.socket)
# set this to None, more hacks for messed up pyzmq
self.stream.socket = None
self.socket.close()
else:
self.stream.close(1)
self.socket = None
self.stream = None
if self._future:
self._future.set_exception(SaltException("Closing connection"))
self._future = None
if hasattr(self, "socket") and self.socket is not None:
self.socket.close(0)
self.socket = None
if self.context.closed is False:
# This hangs if closing the stream causes an import error
self.context.term()
def _init_socket(self):
@ -573,11 +560,8 @@ class AsyncReqMessageClient:
self.socket.setsockopt(zmq.IPV6, 1)
elif hasattr(zmq, "IPV4ONLY"):
self.socket.setsockopt(zmq.IPV4ONLY, 0)
self.socket.linger = self.linger
self.socket.setsockopt(zmq.LINGER, self.linger)
self.socket.connect(self.addr)
self.stream = zmq.eventloop.zmqstream.ZMQStream(
self.socket, io_loop=self.io_loop
)
@salt.ext.tornado.gen.coroutine
def send(self, message, timeout=None, callback=None):
@ -599,27 +583,30 @@ class AsyncReqMessageClient:
if self.opts.get("detect_mode") is True:
timeout = 1
def timeout_message(future):
if not future.done():
future.set_exception(SaltReqTimeoutError("Message timed out"))
if timeout is not None:
send_timeout = self.io_loop.call_later(
timeout, timeout_message, future
timeout, self._timeout_message, future
)
def mark_future(msg):
if not future.done():
data = salt.payload.loads(msg[0])
future.set_result(data)
self.io_loop.spawn_callback(self._send_recv, message, future)
with (yield self.lock.acquire()):
self.stream.on_recv(mark_future)
yield self.stream.send(message)
recv = yield future
recv = yield future
raise salt.ext.tornado.gen.Return(recv)
def _timeout_message(self, future):
if not future.done():
future.set_exception(SaltReqTimeoutError("Message timed out"))
@salt.ext.tornado.gen.coroutine
def _send_recv(self, message, future):
with (yield self.lock.acquire()):
yield self.socket.send(message)
recv = yield self.socket.recv()
if not future.done():
data = salt.payload.loads(recv)
future.set_result(data)
class ZeroMQSocketMonitor:
__EVENT_MAP = None