Fix up on_recv logic

This commit is contained in:
Daniel A. Wozniak 2023-07-02 16:07:26 -07:00 committed by Gareth J. Greenaway
parent d92df14ecb
commit 5540fd8111
2 changed files with 58 additions and 13 deletions

View file

@ -250,6 +250,7 @@ class TCPPubClient(salt.transport.base.PublishClient):
self.source_port = self.opts.get("source_publish_port")
self.connect_callback = None
self.disconnect_callback = None
self.on_recv_task = None
if self.host is None and self.port is None:
if self.path is None:
raise Exception("A host and port or a path must be provided")
@ -261,6 +262,9 @@ class TCPPubClient(salt.transport.base.PublishClient):
if self._closing:
return
self._closing = True
if self.on_recv_task:
self.on_recv_task.cancel()
self.on_recv_task = None
if self._stream is not None:
self._stream.close()
self._stream = None
@ -294,12 +298,14 @@ class TCPPubClient(salt.transport.base.PublishClient):
ssl_options=self.opts.get("ssl"),
**kwargs,
)
log.error("PubClient conencted to %r %r:%r", self, self.host, self.port)
else:
sock_type = socket.AF_UNIX
stream = tornado.iostream.IOStream(
socket.socket(sock_type, socket.SOCK_STREAM)
)
await stream.connect(self.path)
log.error("PubClient conencted to %r %r", self, self.path)
self.poller = select.poll()
self.poller.register(stream.socket, select.POLLIN)
except Exception as exc: # pylint: disable=broad-except
@ -317,6 +323,8 @@ class TCPPubClient(salt.transport.base.PublishClient):
async def _connect(self):
if self._stream is None:
self._closing = False
self._closed = False
self._stream = await self.getstream()
if self._stream:
# if not self._stream_return_running:
@ -353,13 +361,14 @@ class TCPPubClient(salt.transport.base.PublishClient):
await self._stream.send(msg)
async def recv(self, timeout=None):
if not self._stream:
log.error("PubClient recv called")
while self._stream is None:
await self.connect()
await asyncio.sleep(0.001)
return
if timeout == 0:
for msg in self.unpacker:
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
log.error("PUBCLIENT GOT %r", framed_msg["body"])
return framed_msg["body"]
poller = select.poll()
poller.register(self._stream.socket, select.POLLIN)
@ -368,12 +377,14 @@ class TCPPubClient(salt.transport.base.PublishClient):
except TimeoutError:
events = []
if events:
while True:
while not self._closing:
await self._read_in_progress.acquire()
try:
byts = await self._stream.read_bytes(4096, partial=True)
log.error("PUBCLIENT GOT BYTES %r", byts)
except tornado.iostream.StreamClosedError:
self.close()
await self.connect()
return
except Exception:
raise
@ -382,33 +393,39 @@ class TCPPubClient(salt.transport.base.PublishClient):
self.unpacker.feed(byts)
for msg in self.unpacker:
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
log.error("PUBCLIENT GOT %r", framed_msg["body"])
return framed_msg["body"]
elif timeout:
try:
return await asyncio.wait_for(self.recv(), timeout=timeout)
except asyncio.exceptions.TimeoutError:
except (TimeoutError, asyncio.exceptions.TimeoutError, asyncio.exceptions.CancelledError):
self.close()
await self.connect()
return
else:
for msg in self.unpacker:
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
log.error("PUBCLIENT GOT %r", framed_msg["body"])
return framed_msg["body"]
while True:
while not self._closing:
await self._read_in_progress.acquire()
try:
byts = await self._stream.read_bytes(4096, partial=True)
log.error("PUBCLIENT GOT BYTES %r", byts)
except tornado.iostream.StreamClosedError:
self.close()
await self.connect()
continue
except Exception:
raise
#except AttributeError:
# return
#except Exception:
# raise
finally:
self._read_in_progress.release()
self.unpacker.feed(byts)
for msg in self.unpacker:
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
log.error("PUBCLIENT GOT %r", framed_msg["body"])
return framed_msg["body"]
async def on_recv_handler(self, callback):
@ -416,7 +433,10 @@ class TCPPubClient(salt.transport.base.PublishClient):
await asyncio.sleep(0.003)
while True:
try:
log.error("On recv handler %r", self)
msg = await self.recv()
if msg:
callback(msg)
except tornado.iostream.StreamClosedError:
log.trace("Stream closed, reconnecting.")
self._stream.close()
@ -428,13 +448,17 @@ class TCPPubClient(salt.transport.base.PublishClient):
continue
except Exception: # py-lint: disable=broad-except
log.error("Unhandled exception in on_recv handler.", exc_info=True)
callback(msg)
def on_recv(self, callback):
"""
Register a callback for received messages (that we didn't initiate)
"""
self.io_loop.spawn_callback(self.on_recv_handler, callback)
if self.on_recv_task:
self.on_recv_task.cancel()
if callback is None:
self.on_recv_task = None
else:
self.on_recv_task = asyncio.create_task(self.on_recv_handler(callback))
def __enter__(self):
return self
@ -775,7 +799,9 @@ class MessageClient:
# pylint: disable=W1701
def __del__(self):
if not self._closing:
warnings.warn("%r not closed", self)
warnings.warn(
"unclosed message client {self!r}", ResourceWarning, source=self
)
# pylint: enable=W1701
@ -1040,7 +1066,9 @@ class Subscriber:
# pylint: disable=W1701
def __del__(self):
if not self._closing:
warnings.warn("%r not closed", self)
warnings.warn(
"unclosed publish subscriber {self!r}", ResourceWarning, source=self
)
# pylint: enable=W1701
@ -1116,6 +1144,7 @@ class PubServer(tornado.tcpserver.TCPServer):
log.trace(
"TCP PubServer sending payload: topic_list=%r %r", topic_list, package
)
log.error("PUBLISH PAYLOAD %r", package)
payload = salt.transport.frame.frame_msg(package)
to_remove = []
if topic_list:
@ -1135,6 +1164,7 @@ class PubServer(tornado.tcpserver.TCPServer):
else:
for client in self.clients:
try:
log.error("PUBLISH CLIENT %r", package)
# Write the packed str
await client.stream.write(payload)
except tornado.iostream.StreamClosedError:
@ -1295,7 +1325,9 @@ class TCPPuller:
# pylint: disable=W1701
def __del__(self):
if not self._closing:
warnings.warn("%r not closed", self)
warnings.warn(
"unclosed tcp puller {self!r}", ResourceWarning, source=self
)
# pylint: enable=W1701
@ -1436,6 +1468,7 @@ class TCPPublishServer(salt.transport.base.DaemonizedPublishServer):
"""
Publish "load" to minions
"""
log.error("PUBLISH %r", payload)
if not self.pub_sock:
self.connect()
self.pub_sock.send(payload)
@ -1605,7 +1638,9 @@ class _TCPPubServerPublisher:
# pylint: disable=W1701
def __del__(self):
if not self._closing:
warnings.warn("%r not closed", self)
warnings.warn(
"unclosed publisher client {self!r}", ResourceWarning, source=self
)
# pylint: enable=W1701

View file

@ -1,9 +1,13 @@
import asyncio
import time
import logging
import pytest
import salt.utils.event
from salt.netapi.rest_tornado import saltnado
from tests.support.events import eventpublisher_process
log = logging.getLogger(__name__)
def _check_skip(grains):
if grains["os"] == "MacOS":
@ -40,6 +44,7 @@ async def test_simple(sock_dir):
{}, # we don't use mod_opts, don't save?
{"sock_dir": sock_dir, "transport": "zeromq"},
)
await asyncio.sleep(1)
event_future = event_listener.get_event(
request, "evt1"
) # get an event future
@ -65,6 +70,7 @@ async def test_set_event_handler(sock_dir):
{}, # we don't use mod_opts, don't save?
{"sock_dir": sock_dir, "transport": "zeromq"},
)
await asyncio.sleep(1)
event_future = event_listener.get_event(
request,
tag="evt",
@ -88,6 +94,7 @@ async def test_timeout(sock_dir):
{}, # we don't use mod_opts, don't save?
{"sock_dir": sock_dir, "transport": "zeromq"},
)
await asyncio.sleep(1)
event_future = event_listener.get_event(
request,
tag="evt1",
@ -110,13 +117,16 @@ async def test_clean_by_request(sock_dir, io_loop):
"""
with eventpublisher_process(sock_dir):
log.error("After event pubserver start")
with salt.utils.event.MasterEvent(sock_dir) as me:
log.error("After master event start %r", me)
request1 = Request()
request2 = Request()
event_listener = saltnado.EventListener(
{}, # we don't use mod_opts, don't save?
{"sock_dir": sock_dir, "transport": "zeromq"},
)
await asyncio.sleep(1)
assert 0 == len(event_listener.tag_map)
assert 0 == len(event_listener.request_map)