From 7f73274352d17673e0b214f7bc9e44e253aef4c9 Mon Sep 17 00:00:00 2001 From: "Daniel A. Wozniak" Date: Wed, 25 Sep 2024 19:40:31 -0700 Subject: [PATCH] The zmq socket poll method needs to be awaited When using zmq.asyncio.Context, the socket's poll method is a coroutine. --- salt/transport/zeromq.py | 2 +- .../scenarios/transport/test_zeromq.py | 82 +++++++++++++++++++ 2 files changed, 83 insertions(+), 1 deletion(-) create mode 100644 tests/pytests/scenarios/transport/test_zeromq.py diff --git a/salt/transport/zeromq.py b/salt/transport/zeromq.py index fe61cb8808f..19b41f7e273 100644 --- a/salt/transport/zeromq.py +++ b/salt/transport/zeromq.py @@ -776,7 +776,7 @@ class ZeroMQSocketMonitor: async def consume(self): while self._running.is_set(): try: - if self._monitor_socket.poll(): + if await self._monitor_socket.poll(): msg = await self._monitor_socket.recv_multipart() self.monitor_callback(msg) else: diff --git a/tests/pytests/scenarios/transport/test_zeromq.py b/tests/pytests/scenarios/transport/test_zeromq.py new file mode 100644 index 00000000000..35157c3e26e --- /dev/null +++ b/tests/pytests/scenarios/transport/test_zeromq.py @@ -0,0 +1,82 @@ +import asyncio +import logging +import multiprocessing +import time + +import pytest + +try: + import zmq + + import salt.transport.zeromq +except ImportError: + zmq = None + + +log = logging.getLogger(__name__) + + +def clients(recieved): + """ + Fire up 1000 publish socket clients and wait for a message. + """ + log.debug("Clients start") + context = zmq.asyncio.Context() + sockets = {} + for i in range(1000): + socket = context.socket(zmq.SUB) + socket.connect("tcp://127.0.0.1:5406") + socket.setsockopt(zmq.SUBSCRIBE, b"") + sockets[i] = socket + log.debug("Clients connected") + + async def check(): + start = time.time() + while time.time() - start < 60: + n = 0 + for i in list(sockets): + if await sockets[i].poll(): + msg = await sockets[i].recv() + n += 1 + log.debug( + "Client %d got message %s total %d", i, msg, recieved.value + ) + sockets[i].close(0) + sockets.pop(i) + with recieved.get_lock(): + recieved.value += n + await asyncio.sleep(0.3) + + asyncio.run(check()) + + +@pytest.mark.skipif(not zmq, reason="Zeromq not installed") +def test_issue_regression_65265(): + """ + Regression test for 65265. This test will not fail 100% of the time prior + to the fix for 65265. However, it does pass reliably with the issue fixed. + """ + recieved = multiprocessing.Value("i", 0) + process_manager = salt.utils.process.ProcessManager(wait_for_kill=5) + opts = {"ipv6": False, "zmq_filtering": False, "zmq_backlog": 1000, "pub_hwm": 1000} + process_manager.add_process(clients, args=(recieved,)) + process_manager.add_process(clients, args=(recieved,)) + process_manager.add_process(clients, args=(recieved,)) + # Give some time for all clients to start up before starting server. + time.sleep(10) + server = salt.transport.zeromq.PublishServer( + opts, pub_host="127.0.0.1", pub_port=5406, pull_path="/tmp/pull.ipc" + ) + process_manager.add_process(server.publish_daemon, args=(server.publish_payload,)) + # Wait some more for the server to start up completely. + time.sleep(10) + asyncio.run(server.publish(b"asdf")) + log.debug("After publish") + # Give time for clients to receive thier messages. + time.sleep(10) + try: + with recieved.get_lock(): + total = recieved.value + assert total == 3000 + finally: + process_manager.terminate()