From 24257072bbd1a1268d1e4f52a3c26853eedb840f Mon Sep 17 00:00:00 2001 From: "Daniel A. Wozniak" Date: Fri, 18 Aug 2023 14:34:48 -0700 Subject: [PATCH] wip --- salt/channel/client.py | 11 ++--- salt/netapi/rest_tornado/saltnado.py | 4 +- salt/transport/tcp.py | 18 ++++++- salt/transport/ws.py | 12 +++-- salt/transport/zeromq.py | 49 +++++++------------ salt/utils/event.py | 6 +++ .../pytests/functional/channel/test_server.py | 15 +++--- .../transport/tcp/test_message_client.py | 5 +- .../rest_tornado/test_minions_api_handler.py | 1 + .../unit/transport/test_publish_client.py | 2 +- 10 files changed, 67 insertions(+), 56 deletions(-) diff --git a/salt/channel/client.py b/salt/channel/client.py index 49562ce6fae..4368beb73df 100644 --- a/salt/channel/client.py +++ b/salt/channel/client.py @@ -478,13 +478,12 @@ class AsyncPubChannel: if callback is None: return self.transport.on_recv(None) - @tornado.gen.coroutine - def wrap_callback(messages): + async def wrap_callback(messages): payload = self.transport._decode_messages(messages) - decoded = yield self._decode_payload(payload) - log.debug("PubChannel received: %r", decoded) - if decoded is not None: - callback(decoded) + decoded = await self._decode_payload(payload) + log.debug("PubChannel received: %r %r", decoded, callback) + if decoded is not None and callback is not None: + await callback(decoded) return self.transport.on_recv(wrap_callback) diff --git a/salt/netapi/rest_tornado/saltnado.py b/salt/netapi/rest_tornado/saltnado.py index a57986d48b1..3ef0a412509 100644 --- a/salt/netapi/rest_tornado/saltnado.py +++ b/salt/netapi/rest_tornado/saltnado.py @@ -1100,7 +1100,9 @@ class SaltAPIHandler(BaseSaltAPIHandler): # pylint: disable=W0223 minions, is_finished, ) - + print("$" * 80) + print(f"Get minion returns {events!r}") + print("$" * 80) result = yield self.get_minion_returns( events=events, is_finished=is_finished, diff --git a/salt/transport/tcp.py b/salt/transport/tcp.py index 8998ab4591f..eed00e25def 100644 --- a/salt/transport/tcp.py +++ b/salt/transport/tcp.py @@ -353,8 +353,9 @@ class PublishClient(salt.transport.base.PublishClient): if not isinstance(messages, dict): # TODO: For some reason we need to decode here for things # to work. Fix this. - body = salt.utils.msgpack.loads(messages) - body = salt.transport.frame.decode_embedded_strs(body) + body = salt.payload.loads(messages) + #body = salt.utils.msgpack.loads(messages) + #body = salt.transport.frame.decode_embedded_strs(body) else: body = messages return body @@ -368,6 +369,9 @@ class PublishClient(salt.transport.base.PublishClient): await asyncio.sleep(0.001) if timeout == 0: for msg in self.unpacker: + print("^" * 80) + print(f"RECV {msg!r}") + print("^" * 80) return msg[b"body"] try: events, _, _ = select.select([self._stream.socket], [], [], 0) @@ -389,6 +393,9 @@ class PublishClient(salt.transport.base.PublishClient): return self.unpacker.feed(byts) for msg in self.unpacker: + print("^" * 80) + print(f"RECV {msg!r}") + print("^" * 80) return msg[b"body"] elif timeout: try: @@ -403,6 +410,9 @@ class PublishClient(salt.transport.base.PublishClient): return else: for msg in self.unpacker: + print("^" * 80) + print(f"RECV {msg!r}") + print("^" * 80) return msg[b"body"] while not self._closing: async with self._read_in_progress: @@ -420,6 +430,9 @@ class PublishClient(salt.transport.base.PublishClient): continue self.unpacker.feed(byts) for msg in self.unpacker: + print("^" * 80) + print(f"RECV {msg!r}") + print("^" * 80) return msg[b"body"] async def on_recv_handler(self, callback): @@ -427,6 +440,7 @@ class PublishClient(salt.transport.base.PublishClient): # Retry quickly, we may want to increase this if it's hogging cpu. await asyncio.sleep(0.003) while True: + print("On RECV READ") msg = await self.recv() if msg: try: diff --git a/salt/transport/ws.py b/salt/transport/ws.py index 3522c5b4cb9..4a85336a5db 100644 --- a/salt/transport/ws.py +++ b/salt/transport/ws.py @@ -102,6 +102,13 @@ class PublishClient(salt.transport.base.PublishClient): ) # pylint: enable=W1701 + def _decode_messages(self, messages): + if not isinstance(messages, dict): + body =salt.payload.loads(messages) + else: + body = messages + return body + async def getstream(self, **kwargs): if self.source_ip or self.source_port: @@ -327,7 +334,6 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer): except (KeyboardInterrupt, SystemExit): pass finally: - print("CLOSE") self.close() async def publisher( @@ -364,9 +370,7 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer): await runner.setup() site = aiohttp.web.SockSite(runner, sock, ssl_context=ctx) log.info("Publisher binding to socket %s:%s", self.pub_host, self.pub_port) - print("start site") await site.start() - print("start puller") self._pub_payload = publish_payload if self.pull_path: @@ -378,14 +382,12 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer): self.puller = await asyncio.start_server( self.pull_handler, self.pull_host, self.pull_port ) - print("puller started") while self._run.is_set(): await asyncio.sleep(0.3) await self.server.stop() await self.puller.wait_closed() async def pull_handler(self, reader, writer): - print("puller got connection") unpacker = salt.utils.msgpack.Unpacker() while True: data = await reader.read(1024) diff --git a/salt/transport/zeromq.py b/salt/transport/zeromq.py index 1f9704cbcb9..1b8f8ef0f4a 100644 --- a/salt/transport/zeromq.py +++ b/salt/transport/zeromq.py @@ -223,6 +223,7 @@ class PublishClient(salt.transport.base.PublishClient): elif self.host and self.port: if self.path: raise Exception("A host and port or a path must be provided, not both") + self.on_recv_task = None def close(self): if self._closing is True: @@ -341,44 +342,28 @@ class PublishClient(salt.transport.base.PublishClient): # raise Exception("Send not supported") # await self._socket.send(msg) - def on_recv(self, callback): + async def on_recv_handler(self, callback): + while not self._socket: + # Retry quickly, we may want to increase this if it's hogging cpu. + await asyncio.sleep(0.003) + while True: + msg = await self.recv() + if msg: + await callback(msg) + def on_recv(self, callback): """ Register a callback for received messages (that we didn't initiate) - - :param func callback: A function which should be called when data is received """ + if self.on_recv_task: + # XXX: We are not awaiting this canceled task. This still needs to + # be addressed. + self.on_recv_task.cancel() if callback is None: - callbacks = self.callbacks - self.callbacks = {} - for callback, (running, task) in callbacks.items(): - running.clear() - return + self.on_recv_task = None + else: + self.on_recv_task = asyncio.create_task(self.on_recv_handler(callback)) - running = asyncio.Event() - running.set() - - async def consume(running): - try: - while running.is_set(): - try: - msg = await self.recv(timeout=None) - except zmq.error.ZMQError as exc: - # We've disconnected just die - break - if msg: - try: - await callback(msg) - except Exception: # pylint: disable=broad-except - log.error("Exception while running callback", exc_info=True) - # log.debug("Callback done %r", callback) - except Exception as exc: # pylint: disable=broad-except - log.error( - "Exception while consuming%s %s", self.uri, exc, exc_info=True - ) - - task = self.io_loop.spawn_callback(consume, running) - self.callbacks[callback] = running, task class RequestServer(salt.transport.base.DaemonizedRequestServer): diff --git a/salt/utils/event.py b/salt/utils/event.py index f66e8dee230..f222e0dfc71 100644 --- a/salt/utils/event.py +++ b/salt/utils/event.py @@ -550,6 +550,9 @@ class SaltEvent: try: if not self.cpub and not self.connect_pub(timeout=wait): break + print("%" * 80) + print(f"get event {wait}") + print("%" * 80) raw = self.subscriber.recv(timeout=wait) if raw is None: break @@ -636,6 +639,9 @@ class SaltEvent: request, it MUST subscribe the result to ensure the response is not lost should other regions of code call get_event for other purposes. """ + print("%" * 80) + print("GET EVENT CALLED") + print("%" * 80) log.trace("Get event. tag: %s", tag) assert self._run_io_loop_sync diff --git a/tests/pytests/functional/channel/test_server.py b/tests/pytests/functional/channel/test_server.py index f254860e813..42cf045a882 100644 --- a/tests/pytests/functional/channel/test_server.py +++ b/tests/pytests/functional/channel/test_server.py @@ -1,3 +1,4 @@ +import asyncio import ctypes import logging import multiprocessing @@ -53,7 +54,8 @@ def transport_ids(value): return f"transport({value})" -@pytest.fixture(params=["ws", "tcp", "zeromq"], ids=transport_ids) +#@pytest.fixture(params=["ws", "tcp", "zeromq"], ids=transport_ids) +@pytest.fixture(params=["ws",], ids=transport_ids) def transport(request): return request.param @@ -123,13 +125,12 @@ def master_secrets(): salt.master.SMaster.secrets.pop("aes") -@tornado.gen.coroutine -def _connect_and_publish( +async def _connect_and_publish( io_loop, channel_minion_id, channel, server, received, timeout=60 ): - yield channel.connect() + await channel.connect() - def cb(payload): + async def cb(payload): received.append(payload) io_loop.stop() @@ -139,7 +140,7 @@ def _connect_and_publish( ) start = time.time() while time.time() - start < timeout: - yield tornado.gen.sleep(1) + await asyncio.sleep(1) io_loop.stop() @@ -158,7 +159,7 @@ def test_pub_server_channel( req_server_channel = salt.channel.server.ReqServerChannel.factory(master_config) req_server_channel.pre_fork(process_manager) - def handle_payload(payload): + async def handle_payload(payload): log.debug("Payload handler got %r", payload) req_server_channel.post_fork(handle_payload, io_loop=io_loop) diff --git a/tests/pytests/functional/transport/tcp/test_message_client.py b/tests/pytests/functional/transport/tcp/test_message_client.py index 7dd8dbe1961..7feaab82b24 100644 --- a/tests/pytests/functional/transport/tcp/test_message_client.py +++ b/tests/pytests/functional/transport/tcp/test_message_client.py @@ -76,7 +76,7 @@ async def test_message_client_reconnect(config, client, server): received = [] - def handler(msg): + async def handler(msg): received.append(msg) client.on_recv(handler) @@ -119,5 +119,6 @@ async def test_message_client_reconnect(config, client, server): # Close the client client.close() + # Provide time for the on_recv task to complete - await tornado.gen.sleep(1) + await asyncio.sleep(.3) diff --git a/tests/pytests/integration/netapi/rest_tornado/test_minions_api_handler.py b/tests/pytests/integration/netapi/rest_tornado/test_minions_api_handler.py index 080ba4698da..48321f0397e 100644 --- a/tests/pytests/integration/netapi/rest_tornado/test_minions_api_handler.py +++ b/tests/pytests/integration/netapi/rest_tornado/test_minions_api_handler.py @@ -19,6 +19,7 @@ async def test_get_no_mid(http_client, salt_minion, salt_sub_minion): method="GET", follow_redirects=False, ) + print(f"{response!r}") response_obj = salt.utils.json.loads(response.body) assert len(response_obj["return"]) == 1 assert isinstance(response_obj["return"][0], dict) diff --git a/tests/pytests/unit/transport/test_publish_client.py b/tests/pytests/unit/transport/test_publish_client.py index 266e60997c6..5d796295e70 100644 --- a/tests/pytests/unit/transport/test_publish_client.py +++ b/tests/pytests/unit/transport/test_publish_client.py @@ -276,7 +276,7 @@ async def test_publish_client_connect_server_comes_up(transport, io_loop): async def handler(request): ws = aiohttp.web.WebSocketResponse() await ws.prepare(request) - data = salt.transport.frame.frame_msg(msg, header=None) + data = salt.transport.dumps(msg) await ws.send_bytes(data) server = aiohttp.web.Server(handler)