Fix warts in TCPPubClient recv

This commit is contained in:
Daniel A. Wozniak 2023-07-01 00:47:57 -07:00 committed by Gareth J. Greenaway
parent e0bea13bf2
commit 4e22161bee
3 changed files with 46 additions and 44 deletions

View file

@ -239,7 +239,7 @@ class TCPPubClient(salt.transport.base.PublishClient):
self._closed = False
self.backoff = opts.get("tcp_reconnect_backoff", 1)
self.resolver = kwargs.get("resolver")
self._read_in_progress = Lock()
self._read_in_progress = asyncio.Lock()
self.poller = None
self.host = kwargs.get("host", None)
@ -303,7 +303,7 @@ class TCPPubClient(salt.transport.base.PublishClient):
self.poller.register(stream.socket, select.POLLIN)
except Exception as exc: # pylint: disable=broad-except
log.warning(
"TCP Message Client encountered an exception while connecting to"
"TCP Publish Client encountered an exception while connecting to"
" %s:%s %s: %r, will reconnect in %d seconds",
self.host,
self.port,
@ -352,46 +352,46 @@ class TCPPubClient(salt.transport.base.PublishClient):
await self._stream.send(msg)
async def recv(self, timeout=None):
try:
await self._read_in_progress.acquire(timeout=0.00000001)
except tornado.gen.TimeoutError:
log.error("Timeout Error")
if not self._stream:
await asyncio.sleep(0.001)
return
try:
if timeout == 0:
if not self._stream:
await asyncio.sleep(0.001)
return
for msg in self.unpacker:
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
return framed_msg["body"]
poller = select.poll()
poller.register(self._stream.socket, select.POLLIN)
try:
events = poller.poll(0)
except TimeoutError:
events = []
if events:
while True:
byts = await self._stream.read_bytes(4096, partial=True)
self.unpacker.feed(byts)
for msg in self.unpacker:
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
return framed_msg["body"]
elif timeout:
return await asyncio.wait_for(self.recv(), timeout=timeout)
else:
for msg in self.unpacker:
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
return framed_msg["body"]
if timeout == 0:
for msg in self.unpacker:
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
return framed_msg["body"]
poller = select.poll()
poller.register(self._stream.socket, select.POLLIN)
try:
events = poller.poll(0)
except TimeoutError:
events = []
if events:
while True:
byts = await self._stream.read_bytes(4096, partial=True)
await self._read_in_progress.acquire()
try:
byts = await self._stream.read_bytes(4096, partial=True)
finally:
self._read_in_progress.release()
self.unpacker.feed(byts)
for msg in self.unpacker:
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
return framed_msg["body"]
finally:
self._read_in_progress.release()
elif timeout:
return await asyncio.wait_for(self.recv(), timeout=timeout)
else:
for msg in self.unpacker:
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
return framed_msg["body"]
while True:
await self._read_in_progress.acquire()
try:
byts = await self._stream.read_bytes(4096, partial=True)
finally:
self._read_in_progress.release()
self.unpacker.feed(byts)
for msg in self.unpacker:
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
return framed_msg["body"]
async def on_recv_handler(self, callback):
while not self._stream:
@ -399,7 +399,6 @@ class TCPPubClient(salt.transport.base.PublishClient):
while True:
try:
msg = await self.recv()
logit = True
except tornado.iostream.StreamClosedError:
log.trace("Stream closed, reconnecting.")
self._stream.close()

View file

@ -431,6 +431,7 @@ class SaltEvent:
kwargs={
"pub_path": self.puburi,
"pull_path": self.pulluri,
"transport": "tcp",
},
)
try:
@ -448,7 +449,10 @@ class SaltEvent:
else:
if self.pusher is None:
self.pusher = salt.transport.publish_server(
self.opts, pub_path=self.puburi, pull_path=self.pulluri
self.opts,
pub_path=self.puburi,
pull_path=self.pulluri,
transport="tcp",
)
# For the asynchronous case, the connect will be deferred to when
# fire_event() is invoked.

View file

@ -21,12 +21,11 @@ def eventpublisher_process(sock_dir):
"ipv6": None,
"zmq_filtering": None,
}
ipc_publisher = salt.transport.publish_server(opts)
ipc_publisher.pub_uri = "ipc://{}".format(
os.path.join(opts["sock_dir"], "master_event_pub.ipc")
)
ipc_publisher.pull_uri = "ipc://{}".format(
os.path.join(opts["sock_dir"], "master_event_pull.ipc")
ipc_publisher = salt.transport.publish_server(
opts,
pub_path=os.path.join(opts["sock_dir"], "master_event_pub.ipc"),
pull_path=os.path.join(opts["sock_dir"], "master_event_pull.ipc"),
transport="tcp",
)
proc = Process(
target=ipc_publisher.publish_daemon,