From a2f428e5b3519ae5d178b76b0645d24dd105e6a8 Mon Sep 17 00:00:00 2001 From: "Daniel A. Wozniak" Date: Mon, 26 Jun 2023 22:01:43 -0700 Subject: [PATCH] More test fixes --- salt/_logging/handlers.py | 4 +- salt/channel/client.py | 2 +- salt/transport/base.py | 3 +- salt/transport/tcp.py | 86 ++---- salt/transport/zeromq.py | 87 +++--- salt/utils/asynchronous.py | 23 +- tests/pytests/conftest.py | 3 +- .../transport/ipc/test_pub_server_channel.py | 3 +- .../unit/channel/test_request_channel.py | 292 +++++++++++------- .../unit/transport/test_publish_client.py | 9 +- tests/pytests/unit/transport/test_tcp.py | 2 +- tests/pytests/unit/utils/event/test_event.py | 63 ++-- tests/support/events.py | 5 +- 13 files changed, 289 insertions(+), 293 deletions(-) diff --git a/salt/_logging/handlers.py b/salt/_logging/handlers.py index fd4152e482b..fba73c7fc7a 100644 --- a/salt/_logging/handlers.py +++ b/salt/_logging/handlers.py @@ -14,6 +14,7 @@ from collections import deque from salt._logging.mixins import ExcInfoOnLogLevelFormatMixin from salt.utils.versions import warn_until_date + log = logging.getLogger(__name__) @@ -94,9 +95,6 @@ class DeferredStreamHandler(StreamHandler): super().__init__(stream) self.__messages = deque(maxlen=max_queue_size) self.__emitting = False - import traceback - - self.stack = "".join(traceback.format_stack()) def handle(self, record): self.acquire() diff --git a/salt/channel/client.py b/salt/channel/client.py index 6d2c9f14d84..9ea6d22f439 100644 --- a/salt/channel/client.py +++ b/salt/channel/client.py @@ -368,7 +368,7 @@ class AsyncReqChannel: self.close() async def __aenter__(self): - await self.connect() + await self.transport.connect() return self async def __aexit__(self, exc_type, exc, tb): diff --git a/salt/transport/base.py b/salt/transport/base.py index 9938656d37f..a5e3b57512a 100644 --- a/salt/transport/base.py +++ b/salt/transport/base.py @@ -1,4 +1,5 @@ import os + import tornado.gen TRANSPORTS = ( @@ -63,7 +64,7 @@ def publish_server(opts, **kwargs): if "pub_host" not in kwargs and "pub_path" not in kwargs: kwargs["pub_host"] = opts["interface"] if "pub_port" not in kwargs and "pub_path" not in kwargs: - kwargs["pub_port"] = opts["publish_port"] + kwargs["pub_port"] = opts.get("publish_port", 4506) if "pull_host" not in kwargs and "pull_path" not in kwargs: if opts.get("ipc_mode", "") == "tcp": diff --git a/salt/transport/tcp.py b/salt/transport/tcp.py index 71a271f8d46..defa7be7c27 100644 --- a/salt/transport/tcp.py +++ b/salt/transport/tcp.py @@ -10,7 +10,6 @@ import asyncio import errno import logging import multiprocessing -import os import queue import select import socket @@ -281,7 +280,6 @@ class TCPPubClient(salt.transport.base.PublishClient): while stream is None and (not self._closed and not self._closing): try: if self.host and self.port: - # log.error("GET STREAM TCP %r %s %s %s", self.url, self.host, self.port, self.stack) self._tcp_client = TCPClientKeepAlive( self.opts, resolver=self.resolver ) @@ -313,14 +311,8 @@ class TCPPubClient(salt.transport.base.PublishClient): return stream async def _connect(self): - # log.error("Connect %r %r", self, self._stream) - # import traceback - # stack = "".join(traceback.format_stack()) - # log.error(f"MessageClient Connect {name} {stack}") if self._stream is None: - log.error("Get stream") self._stream = await self.getstream() - log.error("Got stream") if self._stream: # if not self._stream_return_running: # self.io_loop.spawn_callback(self._stream_return) @@ -407,9 +399,7 @@ class TCPPubClient(salt.transport.base.PublishClient): """ async def setup_callback(): - logit = True while not self._stream: - # log.error("On recv wait stream %r", self._stream) await asyncio.sleep(0.003) while True: try: @@ -418,9 +408,6 @@ class TCPPubClient(salt.transport.base.PublishClient): except tornado.iostream.StreamClosedError: self._stream.close() self._stream = None - if logit: - log.error("Stream Closed", exc_info=True) - logit = False await self._connect() await asyncio.sleep(0.03) # if self.disconnect_callback: @@ -757,15 +744,6 @@ class MessageClient: if self.io_loop is not None: self.io_loop.stop() - # TODO: timeout inflight sessions - def close(self): - if self._closing: - return - self._closing = True - self.stream.close() - self.tcp_client.close() - self.io_loop.add_timeout(1, self.check_close) - def close(self): if self._closing: return @@ -803,7 +781,6 @@ class MessageClient: while stream is None and (not self._closed and not self._closing): try: if self.host and self.port: - # log.error("GET STREAM TCP %r %s %s %s", self.url, self.host, self.port, self.stack) stream = await self._tcp_client.connect( ip_bracket(self.host, strip=True), self.port, @@ -811,7 +788,6 @@ class MessageClient: **kwargs, ) else: - log.error("GET STREAM IPC") sock_type = socket.AF_UNIX path = self.url.replace("ipc://", "") stream = tornado.iostream.IOStream( @@ -831,15 +807,8 @@ class MessageClient: return stream async def connect(self): - log.error("Connect %r %r", self, self._stream) - import traceback - - # stack = "".join(traceback.format_stack()) - # log.error(f"MessageClient Connect {name} {stack}") if self._stream is None: - log.error("Get stream") self._stream = await self.getstream() - log.error("Got stream") if self._stream: if not self._stream_return_running: self.task = asyncio.create_task(self._stream_return()) @@ -852,9 +821,7 @@ class MessageClient: unpacker = salt.utils.msgpack.Unpacker() while not self._closing: try: - log.error("Stream read bytes %r", self._stream.socket) wire_bytes = await self._stream.read_bytes(4096, partial=True) - log.error("Stream got bytes") unpacker.feed(wire_bytes) for framed_msg in unpacker: framed_msg = salt.transport.frame.decode_embedded_strs(framed_msg) @@ -903,9 +870,6 @@ class MessageClient: ) else: raise SaltClientError - # except OSError: - # log.error("OSERROR", exc_info=True) - # raise except Exception as e: # pylint: disable=broad-except log.error("Exception parsing response", exc_info=True) for future in self.send_future_map.values(): @@ -958,11 +922,9 @@ class MessageClient: future.set_exception(SaltReqTimeoutError("Message timed out")) async def send(self, msg, timeout=None, callback=None, raw=False, reply=True): - # log.error("stream send %r %r %r %r", self, self.url, self.port, reply) if self._closing: raise ClosingError() while not self._stream: - # log.error("Wait stream %r %r", self, self._stream) await asyncio.sleep(0.03) message_id = self._message_id() header = {"mid": message_id} @@ -1011,7 +973,6 @@ class MessageClient: try: if timeout == 0: for msg in self.unpacker: - log.error("RECV a") framed_msg = salt.transport.frame.decode_embedded_strs(msg) return framed_msg["body"] poller = select.poll() @@ -1025,27 +986,21 @@ class MessageClient: byts = await self._stream.read_bytes(4096, partial=True) self.unpacker.feed(byts) for msg in self.unpacker: - log.error("RECV b") framed_msg = salt.transport.frame.decode_embedded_strs(msg) return framed_msg["body"] else: return elif timeout: - log.error("RECV c") return await asyncio.wait_for(self.recv(), timeout=timeout) else: for msg in self.unpacker: - log.error("RECV d") framed_msg = salt.transport.frame.decode_embedded_strs(msg) return framed_msg["body"] while True: - log.error("RECV e") byts = await self._stream.read_bytes(4096, partial=True) self.unpacker.feed(byts) for msg in self.unpacker: - log.error("RECV e %r", msg) framed_msg = salt.transport.frame.decode_embedded_strs(msg) - log.error("RECV e %r", framed_msg) return framed_msg["body"] finally: self._read_in_progress.release() @@ -1211,39 +1166,38 @@ class TCPPublishServer(salt.transport.base.DaemonizedPublishServer): self.opts = opts self.pub_sock = None # Set up Salt IPC server - #if self.opts.get("ipc_mode", "") == "tcp": + # if self.opts.get("ipc_mode", "") == "tcp": # self.pull_uri = int(self.opts.get("tcp_master_publish_pull", 4514)) - #else: + # else: # self.pull_uri = os.path.join(self.opts["sock_dir"], "publish_pull.ipc") - #interface = self.opts.get("interface", "127.0.0.1") - #self.publish_port = self.opts.get("publish_port", 4560) - #self.pub_uri = f"tcp://{interface}:{self.publish_port}" + # interface = self.opts.get("interface", "127.0.0.1") + # self.publish_port = self.opts.get("publish_port", 4560) + # self.pub_uri = f"tcp://{interface}:{self.publish_port}" self.pub_host = kwargs.get("pub_host", None) self.pub_port = kwargs.get("pub_port", None) self.pub_path = kwargs.get("pub_path", None) - #if pub_path: + # if pub_path: # self.pub_path = pub_path # self.pub_uri = f"ipc://{pub_path}" - #else: + # else: # self.pub_uri = f"tcp://{pub_host}:{pub_port}" - #self.publish_port = self.opts.get("publish_port", 4560) - + # self.publish_port = self.opts.get("publish_port", 4560) self.pull_host = kwargs.get("pull_host", None) self.pull_port = kwargs.get("pull_port", None) self.pull_path = kwargs.get("pull_path", None) - #if pull_path: + # if pull_path: # self.pull_uri = f"ipc://{pull_path}" - #else: + # else: # self.pull_uri = f"tcp://{pub_host}:{pub_port}" - #log.error( + # log.error( # "TCPPubServer %r %s %s", # self, # self.pull_uri, # #self.publish_port, # self.pub_uri, - #) + # ) @property def topic_support(self): @@ -1265,13 +1219,13 @@ class TCPPublishServer(salt.transport.base.DaemonizedPublishServer): Bind to the interface specified in the configuration file """ io_loop = tornado.ioloop.IOLoop() - #log.error( + # log.error( # "TCPPubServer daemon %r %s %s %s", # self, # self.pull_uri, # self.publish_port, # self.pub_uri, - #) + # ) # Spin up the publisher self.pub_server = pub_server = PubServer( @@ -1298,13 +1252,13 @@ class TCPPublishServer(salt.transport.base.DaemonizedPublishServer): # else: # pull_uri = os.path.join(self.opts["sock_dir"], "publish_pull.ipc") self.pub_server = pub_server - #if "ipc://" in self.pull_uri: + # if "ipc://" in self.pull_uri: # pull_uri = pull_uri = self.pull_uri.replace("ipc://", "") # log.error("WTF PULL URI %r", pull_uri) - #elif "tcp://" in self.pull_uri: + # elif "tcp://" in self.pull_uri: # log.error("Fallback to publish port %r", self.pull_uri) # pull_uri = self.publish_port - #else: + # else: # pull_uri = self.pull_uri if self.pull_path: pull_uri = self.pull_path @@ -1345,7 +1299,7 @@ class TCPPublishServer(salt.transport.base.DaemonizedPublishServer): raise tornado.gen.Return(ret) def connect(self): - #path = self.pull_uri.replace("ipc://", "") + # path = self.pull_uri.replace("ipc://", "") log.error("Connect pusher %s", self.pull_path) # self.pub_sock = salt.utils.asynchronous.SyncWrapper( # salt.transport.ipc.IPCMessageClient, @@ -1417,15 +1371,11 @@ class TCPReqClient(salt.transport.base.RequestClient): ) async def connect(self): - log.error("TCPReqClient Connect") await self.message_client.connect() - log.error("TCPReqClient Connected") async def send(self, load, timeout=60): await self.connect() - log.error("TCP Request %r", load) msg = await self.message_client.send(load, timeout=timeout) - log.error("TCP Reply %r", msg) return msg def close(self): diff --git a/salt/transport/zeromq.py b/salt/transport/zeromq.py index c3ddf74c806..f76fe3be7af 100644 --- a/salt/transport/zeromq.py +++ b/salt/transport/zeromq.py @@ -222,7 +222,7 @@ class PublishClient(salt.transport.base.PublishClient): if self.path: raise Exception("A host and port or a path must be provided, not both") - async def close(self): + def close(self): if self._closing is True: return self._closing = True @@ -238,21 +238,10 @@ class PublishClient(salt.transport.base.PublishClient): if self.callbacks: for cb in self.callbacks: running, task = self.callbacks[cb] - task.cancel() - - def close(self): - if self._closing is True: - return - self._closing = True - if hasattr(self, "_monitor") and self._monitor is not None: - self._monitor.stop() - self._monitor = None - if hasattr(self, "_stream"): - self._stream.close(0) - elif hasattr(self, "_socket"): - self._socket.close(0) - if hasattr(self, "context") and self.context.closed is False: - self.context.term() + try: + task.cancel() + except RuntimeError: + log.warning("Tasks loop already closed") # pylint: enable=W1701 def __enter__(self): @@ -402,7 +391,7 @@ class RequestServer(salt.transport.base.DaemonizedRequestServer): self._closing = False self._monitor = None self._w_monitor = None - self.task = None + self.tasks = set() self._event = asyncio.Event() def zmq_device(self): @@ -485,6 +474,11 @@ class RequestServer(salt.transport.base.DaemonizedRequestServer): self._socket.close() if hasattr(self, "context") and self.context.closed is False: self.context.term() + for task in list(self.tasks): + try: + task.cancel() + except RuntimeError: + log.error("IOLoop closed when trying to cancel task") def pre_fork(self, process_manager): """ @@ -518,8 +512,8 @@ class RequestServer(salt.transport.base.DaemonizedRequestServer): :param IOLoop io_loop: An instance of a Tornado IOLoop, to handle event scheduling """ # context = zmq.Context(1) - context = zmq.asyncio.Context(1) - self._socket = context.socket(zmq.REP) + self.context = zmq.asyncio.Context(1) + self._socket = self.context.socket(zmq.REP) # Linger -1 means we'll never discard messages. self._socket.setsockopt(zmq.LINGER, -1) self._start_zmq_monitor() @@ -541,16 +535,23 @@ class RequestServer(salt.transport.base.DaemonizedRequestServer): self.message_handler = message_handler async def callback(): - self.task = asyncio.create_task(self.request_handler()) - await self.task + task = asyncio.create_task(self.request_handler()) + task.add_done_callback(self.tasks.discard) + self.tasks.add(task) io_loop.add_callback(callback) async def request_handler(self): while not self._event.is_set(): - request = await self._socket.recv() - reply = await self.handle_message(None, request) - await self._socket.send(self.encode_payload(reply)) + try: + request = await asyncio.wait_for(self._socket.recv(), 1) + reply = await self.handle_message(None, request) + await self._socket.send(self.encode_payload(reply)) + except TimeoutError: + continue + except Exception: + log.error("Exception in request handler", exc_info=True) + break async def handle_message(self, stream, payload): payload = self.decode_payload(payload) @@ -605,6 +606,7 @@ def _set_tcp_keepalive(zmq_socket, opts): ctx = zmq.asyncio.Context() + # TODO: unit tests! class AsyncReqMessageClient: """ @@ -643,25 +645,20 @@ class AsyncReqMessageClient: self.socket = None async def connect(self): - # wire up sockets - self._init_socket() + if self.socket is None: + # wire up sockets + self._init_socket() # TODO: timeout all in-flight sessions, or error def close(self): - try: - if self._closing: - return - except AttributeError: - # We must have been called from __del__ - # The python interpreter has nuked most attributes already + if self._closing: return - else: - self._closing = True - if self.socket: - self.socket.close() - if self.context.closed is False: - # This hangs if closing the stream causes an import error - self.context.term() + self._closing = True + if self.socket: + self.socket.close() + if self.context.closed is False: + # This hangs if closing the stream causes an import error + self.context.term() def _init_socket(self): if self.socket is not None: @@ -809,17 +806,17 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer): def __init__(self, opts, **kwargs): self.opts = opts - #if self.opts.get("ipc_mode", "") == "tcp": + # if self.opts.get("ipc_mode", "") == "tcp": # self.pull_uri = "tcp://127.0.0.1:{}".format( # self.opts.get("tcp_master_publish_pull", 4514) # ) - #else: + # else: # self.pull_uri = "ipc://{}".format( # os.path.join(self.opts["sock_dir"], "publish_pull.ipc") # ) - #interface = self.opts.get("interface", "127.0.0.1") - #publish_port = self.opts.get("publish_port", 4560) - #self.pub_uri = f"tcp://{interface}:{publish_port}" + # interface = self.opts.get("interface", "127.0.0.1") + # publish_port = self.opts.get("publish_port", 4560) + # self.pub_uri = f"tcp://{interface}:{publish_port}" pub_host = kwargs.get("pub_host", None) pub_port = kwargs.get("pub_port", None) @@ -829,7 +826,6 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer): else: self.pub_uri = f"tcp://{pub_host}:{pub_port}" - pull_host = kwargs.get("pull_host", None) pull_port = kwargs.get("pull_port", None) pull_path = kwargs.get("pull_path", None) @@ -838,7 +834,6 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer): else: self.pull_uri = f"tcp://{pull_host}:{pull_port}" - self.ctx = None self.sock = None self.daemon_context = None diff --git a/salt/utils/asynchronous.py b/salt/utils/asynchronous.py index bef923e434f..5c456f68154 100644 --- a/salt/utils/asynchronous.py +++ b/salt/utils/asynchronous.py @@ -110,9 +110,6 @@ class SyncWrapper: log.error("No async method %s on object %r", method, self.obj) except Exception: # pylint: disable=broad-except log.exception("Exception encountered while running stop method") - # thread = threading.Thread(target=self._run_loop_final, args=(self.asyncio_loop,)) - # thread.start() - # thread.join() io_loop = self.io_loop io_loop.stop() try: @@ -143,18 +140,6 @@ class SyncWrapper: return wrap - def _run_loop_final(self, asyncio_loop): - asyncio.set_event_loop(asyncio_loop) - io_loop = tornado.ioloop.IOLoop.current() - try: - - async def noop(): - await asyncio.sleep(0.2) - - result = io_loop.run_sync(lambda: noop()) - except Exception: # pylint: disable=broad-except - log.error("Error on last loop run") - def _target(self, key, args, kwargs, results, asyncio_loop): asyncio.set_event_loop(asyncio_loop) io_loop = tornado.ioloop.IOLoop.current() @@ -173,10 +158,16 @@ class SyncWrapper: return self else: return ret + elif hasattr(self.obj, "__enter__"): + ret = self.obj.__enter__() + if ret == self.obj: + return self + else: + return ret return self def __exit__(self, exc_type, exc_val, tb): if hasattr(self.obj, "__aexit__"): - return self._wrap("__aexit__") + return self._wrap("__aexit__")(exc_type, exc_val, tb) else: self.close() diff --git a/tests/pytests/conftest.py b/tests/pytests/conftest.py index 8b81f01e9cb..9f0aa14f931 100644 --- a/tests/pytests/conftest.py +++ b/tests/pytests/conftest.py @@ -631,12 +631,11 @@ def io_loop(): loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) loop = tornado.ioloop.IOLoop.current() - loop.make_current() try: yield loop finally: - loop.clear_current() loop.close(all_fds=True) + asyncio.set_event_loop(None) # <---- Async Test Fixtures ------------------------------------------------------------------------------------------ diff --git a/tests/pytests/functional/transport/ipc/test_pub_server_channel.py b/tests/pytests/functional/transport/ipc/test_pub_server_channel.py index 83e4cfdf5de..c66ee8ebbac 100644 --- a/tests/pytests/functional/transport/ipc/test_pub_server_channel.py +++ b/tests/pytests/functional/transport/ipc/test_pub_server_channel.py @@ -1,3 +1,4 @@ +import asyncio import logging import time import tracemalloc @@ -64,7 +65,6 @@ def test_publish_to_pubserv_ipc(salt_master, salt_minion, transport): ZMQ's ipc transport not supported on Windows """ - import asyncio opts = dict( salt_master.config.copy(), @@ -98,7 +98,6 @@ def test_issue_36469_tcp(salt_master, salt_minion, transport): """ if transport == "tcp": pytest.skip("Test not applicable to the ZeroMQ transport.") - import asyncio def _send_small(opts, sid, num=10): loop = asyncio.new_event_loop() diff --git a/tests/pytests/unit/channel/test_request_channel.py b/tests/pytests/unit/channel/test_request_channel.py index 7161f05a33d..0fed2249c15 100644 --- a/tests/pytests/unit/channel/test_request_channel.py +++ b/tests/pytests/unit/channel/test_request_channel.py @@ -1,7 +1,7 @@ """ :codeauthor: Thomas Jackson """ - +import asyncio import ctypes import logging import multiprocessing @@ -337,7 +337,7 @@ def run_loop_in_thread(loop, evt): """ Run the provided loop until an event is set """ - loop.make_current() + asyncio.set_event_loop(loop.asyncio_loop) @tornado.gen.coroutine def stopper(): @@ -377,7 +377,7 @@ class MockSaltMinionMaster: master_opts.update({"transport": "zeromq"}) self.server_channel = salt.channel.server.ReqServerChannel.factory(master_opts) self.server_channel.pre_fork(self.process_manager) - self.io_loop = tornado.ioloop.IOLoop() + self.io_loop = tornado.ioloop.IOLoop(make_current=False) self.evt = threading.Event() self.server_channel.post_fork(self._handle_payload, io_loop=self.io_loop) self.server_thread = threading.Thread( @@ -403,6 +403,7 @@ class MockSaltMinionMaster: def __exit__(self, *args, **kwargs): self.channel.__exit__(*args, **kwargs) del self.channel + self.server_channel.close() # Attempting to kill the children hangs the test suite. # Let the test suite handle this instead. self.process_manager.stop_restarting() @@ -411,7 +412,6 @@ class MockSaltMinionMaster: self.server_thread.join() # Give the procs a chance to fully close before we stop the io_loop time.sleep(2) - self.server_channel.close() SMaster.secrets.pop("aes") del self.server_channel del self.io_loop @@ -481,28 +481,33 @@ def test_req_server_chan_encrypt_v2(pki_dir): dictkey = "pillar" nonce = "abcdefg" pillar_data = {"pillar1": "meh"} - ret = server._encrypt_private(pillar_data, dictkey, "minion", nonce) - assert "key" in ret - assert dictkey in ret + try: + ret = server._encrypt_private(pillar_data, dictkey, "minion", nonce) + assert "key" in ret + assert dictkey in ret - key = salt.crypt.get_rsa_key(str(pki_dir.joinpath("minion", "minion.pem")), None) - if HAS_M2: - aes = key.private_decrypt(ret["key"], RSA.pkcs1_oaep_padding) - else: - cipher = PKCS1_OAEP.new(key) - aes = cipher.decrypt(ret["key"]) - pcrypt = salt.crypt.Crypticle(opts, aes) - signed_msg = pcrypt.loads(ret[dictkey]) + key = salt.crypt.get_rsa_key( + str(pki_dir.joinpath("minion", "minion.pem")), None + ) + if HAS_M2: + aes = key.private_decrypt(ret["key"], RSA.pkcs1_oaep_padding) + else: + cipher = PKCS1_OAEP.new(key) + aes = cipher.decrypt(ret["key"]) + pcrypt = salt.crypt.Crypticle(opts, aes) + signed_msg = pcrypt.loads(ret[dictkey]) - assert "sig" in signed_msg - assert "data" in signed_msg - data = salt.payload.loads(signed_msg["data"]) - assert "key" in data - assert data["key"] == ret["key"] - assert "key" in data - assert data["nonce"] == nonce - assert "pillar" in data - assert data["pillar"] == pillar_data + assert "sig" in signed_msg + assert "data" in signed_msg + data = salt.payload.loads(signed_msg["data"]) + assert "key" in data + assert data["key"] == ret["key"] + assert "key" in data + assert data["nonce"] == nonce + assert "pillar" in data + assert data["pillar"] == pillar_data + finally: + server.close() def test_req_server_chan_encrypt_v1(pki_dir): @@ -525,20 +530,27 @@ def test_req_server_chan_encrypt_v1(pki_dir): dictkey = "pillar" nonce = "abcdefg" pillar_data = {"pillar1": "meh"} - ret = server._encrypt_private(pillar_data, dictkey, "minion", sign_messages=False) + try: + ret = server._encrypt_private( + pillar_data, dictkey, "minion", sign_messages=False + ) - assert "key" in ret - assert dictkey in ret + assert "key" in ret + assert dictkey in ret - key = salt.crypt.get_rsa_key(str(pki_dir.joinpath("minion", "minion.pem")), None) - if HAS_M2: - aes = key.private_decrypt(ret["key"], RSA.pkcs1_oaep_padding) - else: - cipher = PKCS1_OAEP.new(key) - aes = cipher.decrypt(ret["key"]) - pcrypt = salt.crypt.Crypticle(opts, aes) - data = pcrypt.loads(ret[dictkey]) - assert data == pillar_data + key = salt.crypt.get_rsa_key( + str(pki_dir.joinpath("minion", "minion.pem")), None + ) + if HAS_M2: + aes = key.private_decrypt(ret["key"], RSA.pkcs1_oaep_padding) + else: + cipher = PKCS1_OAEP.new(key) + aes = cipher.decrypt(ret["key"]) + pcrypt = salt.crypt.Crypticle(opts, aes) + data = pcrypt.loads(ret[dictkey]) + assert data == pillar_data + finally: + server.close() def test_req_chan_decode_data_dict_entry_v1(pki_dir): @@ -559,19 +571,23 @@ def test_req_chan_decode_data_dict_entry_v1(pki_dir): master_opts = dict(opts, pki_dir=str(pki_dir.joinpath("master"))) server = salt.channel.server.ReqServerChannel.factory(master_opts) client = salt.channel.client.ReqChannel.factory(opts, io_loop=mockloop) - dictkey = "pillar" - target = "minion" - pillar_data = {"pillar1": "meh"} - ret = server._encrypt_private(pillar_data, dictkey, target, sign_messages=False) - key = client.auth.get_keys() - if HAS_M2: - aes = key.private_decrypt(ret["key"], RSA.pkcs1_oaep_padding) - else: - cipher = PKCS1_OAEP.new(key) - aes = cipher.decrypt(ret["key"]) - pcrypt = salt.crypt.Crypticle(client.opts, aes) - ret_pillar_data = pcrypt.loads(ret[dictkey]) - assert ret_pillar_data == pillar_data + try: + dictkey = "pillar" + target = "minion" + pillar_data = {"pillar1": "meh"} + ret = server._encrypt_private(pillar_data, dictkey, target, sign_messages=False) + key = client.auth.get_keys() + if HAS_M2: + aes = key.private_decrypt(ret["key"], RSA.pkcs1_oaep_padding) + else: + cipher = PKCS1_OAEP.new(key) + aes = cipher.decrypt(ret["key"]) + pcrypt = salt.crypt.Crypticle(client.opts, aes) + ret_pillar_data = pcrypt.loads(ret[dictkey]) + assert ret_pillar_data == pillar_data + finally: + client.close() + server.close() async def test_req_chan_decode_data_dict_entry_v2(pki_dir): @@ -606,7 +622,9 @@ async def test_req_chan_decode_data_dict_entry_v2(pki_dir): client.auth.get_keys = auth.get_keys client.auth.crypticle.dumps = auth.crypticle.dumps client.auth.crypticle.loads = auth.crypticle.loads + real_transport = client.transport client.transport = MagicMock() + real_transport.close() @tornado.gen.coroutine def mocksend(msg, timeout=60, tries=3): @@ -631,13 +649,17 @@ async def test_req_chan_decode_data_dict_entry_v2(pki_dir): "ver": "2", "cmd": "_pillar", } - ret = await client.crypted_transfer_decode_dictentry( - load, - dictkey="pillar", - ) - assert "version" in client.transport.msg - assert client.transport.msg["version"] == 2 - assert ret == {"pillar1": "meh"} + try: + ret = await client.crypted_transfer_decode_dictentry( + load, + dictkey="pillar", + ) + assert "version" in client.transport.msg + assert client.transport.msg["version"] == 2 + assert ret == {"pillar1": "meh"} + finally: + client.close() + server.close() async def test_req_chan_decode_data_dict_entry_v2_bad_nonce(pki_dir): @@ -673,7 +695,9 @@ async def test_req_chan_decode_data_dict_entry_v2_bad_nonce(pki_dir): client.auth.get_keys = auth.get_keys client.auth.crypticle.dumps = auth.crypticle.dumps client.auth.crypticle.loads = auth.crypticle.loads + real_transport = client.transport client.transport = MagicMock() + real_transport.close() ret = server._encrypt_private( pillar_data, dictkey, target, nonce=badnonce, sign_messages=True ) @@ -698,12 +722,16 @@ async def test_req_chan_decode_data_dict_entry_v2_bad_nonce(pki_dir): "cmd": "_pillar", } - with pytest.raises(salt.crypt.AuthenticationError) as excinfo: - ret = await client.crypted_transfer_decode_dictentry( - load, - dictkey="pillar", - ) - assert "Pillar nonce verification failed." == excinfo.value.message + try: + with pytest.raises(salt.crypt.AuthenticationError) as excinfo: + ret = await client.crypted_transfer_decode_dictentry( + load, + dictkey="pillar", + ) + assert "Pillar nonce verification failed." == excinfo.value.message + finally: + client.close() + server.close() async def test_req_chan_decode_data_dict_entry_v2_bad_signature(pki_dir): @@ -739,7 +767,9 @@ async def test_req_chan_decode_data_dict_entry_v2_bad_signature(pki_dir): client.auth.get_keys = auth.get_keys client.auth.crypticle.dumps = auth.crypticle.dumps client.auth.crypticle.loads = auth.crypticle.loads + real_transport = client.transport client.transport = MagicMock() + real_transport.close() @tornado.gen.coroutine def mocksend(msg, timeout=60, tries=3): @@ -780,12 +810,16 @@ async def test_req_chan_decode_data_dict_entry_v2_bad_signature(pki_dir): "cmd": "_pillar", } - with pytest.raises(salt.crypt.AuthenticationError) as excinfo: - ret = await client.crypted_transfer_decode_dictentry( - load, - dictkey="pillar", - ) - assert "Pillar payload signature failed to validate." == excinfo.value.message + try: + with pytest.raises(salt.crypt.AuthenticationError) as excinfo: + ret = await client.crypted_transfer_decode_dictentry( + load, + dictkey="pillar", + ) + assert "Pillar payload signature failed to validate." == excinfo.value.message + finally: + client.close() + server.close() async def test_req_chan_decode_data_dict_entry_v2_bad_key(pki_dir): @@ -821,7 +855,9 @@ async def test_req_chan_decode_data_dict_entry_v2_bad_key(pki_dir): client.auth.get_keys = auth.get_keys client.auth.crypticle.dumps = auth.crypticle.dumps client.auth.crypticle.loads = auth.crypticle.loads + real_transport = client.transport client.transport = MagicMock() + real_transport.close() @tornado.gen.coroutine def mocksend(msg, timeout=60, tries=3): @@ -869,12 +905,16 @@ async def test_req_chan_decode_data_dict_entry_v2_bad_key(pki_dir): "cmd": "_pillar", } - with pytest.raises(salt.crypt.AuthenticationError) as excinfo: - await client.crypted_transfer_decode_dictentry( - load, - dictkey="pillar", - ) - assert "Key verification failed." == excinfo.value.message + try: + with pytest.raises(salt.crypt.AuthenticationError) as excinfo: + await client.crypted_transfer_decode_dictentry( + load, + dictkey="pillar", + ) + assert "Key verification failed." == excinfo.value.message + finally: + client.close() + server.close() async def test_req_serv_auth_v1(pki_dir): @@ -926,8 +966,11 @@ async def test_req_serv_auth_v1(pki_dir): "token": token, "pub": pub_key, } - ret = server._auth(load, sign_messages=False) - assert "load" not in ret + try: + ret = server._auth(load, sign_messages=False) + assert "load" not in ret + finally: + server.close() async def test_req_serv_auth_v2(pki_dir): @@ -980,9 +1023,12 @@ async def test_req_serv_auth_v2(pki_dir): "token": token, "pub": pub_key, } - ret = server._auth(load, sign_messages=True) - assert "sig" in ret - assert "load" in ret + try: + ret = server._auth(load, sign_messages=True) + assert "sig" in ret + assert "load" in ret + finally: + server.close() async def test_req_chan_auth_v2(pki_dir, io_loop): @@ -1023,15 +1069,19 @@ async def test_req_chan_auth_v2(pki_dir, io_loop): client = salt.channel.client.AsyncReqChannel.factory(opts, io_loop=io_loop) signin_payload = client.auth.minion_sign_in_payload() pload = client._package_load(signin_payload) - assert "version" in pload - assert pload["version"] == 2 + try: + assert "version" in pload + assert pload["version"] == 2 - ret = server._auth(pload["load"], sign_messages=True) - assert "sig" in ret - ret = client.auth.handle_signin_response(signin_payload, ret) - assert "aes" in ret - assert "master_uri" in ret - assert "publish_port" in ret + ret = server._auth(pload["load"], sign_messages=True) + assert "sig" in ret + ret = client.auth.handle_signin_response(signin_payload, ret) + assert "aes" in ret + assert "master_uri" in ret + assert "publish_port" in ret + finally: + client.close() + server.close() async def test_req_chan_auth_v2_with_master_signing(pki_dir, io_loop): @@ -1113,16 +1163,20 @@ async def test_req_chan_auth_v2_with_master_signing(pki_dir, io_loop): signin_payload = client.auth.minion_sign_in_payload() pload = client._package_load(signin_payload) server_reply = server._auth(pload["load"], sign_messages=True) - ret = client.auth.handle_signin_response(signin_payload, server_reply) + try: + ret = client.auth.handle_signin_response(signin_payload, server_reply) - assert "aes" in ret - assert "master_uri" in ret - assert "publish_port" in ret + assert "aes" in ret + assert "master_uri" in ret + assert "publish_port" in ret - assert ( - pki_dir.joinpath("minion", "minion_master.pub").read_text() - == pki_dir.joinpath("master", "master.pub").read_text() - ) + assert ( + pki_dir.joinpath("minion", "minion_master.pub").read_text() + == pki_dir.joinpath("master", "master.pub").read_text() + ) + finally: + client.close() + server.close() async def test_req_chan_auth_v2_new_minion_with_master_pub(pki_dir, io_loop): @@ -1165,13 +1219,17 @@ async def test_req_chan_auth_v2_new_minion_with_master_pub(pki_dir, io_loop): client = salt.channel.client.AsyncReqChannel.factory(opts, io_loop=io_loop) signin_payload = client.auth.minion_sign_in_payload() pload = client._package_load(signin_payload) - assert "version" in pload - assert pload["version"] == 2 + try: + assert "version" in pload + assert pload["version"] == 2 - ret = server._auth(pload["load"], sign_messages=True) - assert "sig" in ret - ret = client.auth.handle_signin_response(signin_payload, ret) - assert ret == "retry" + ret = server._auth(pload["load"], sign_messages=True) + assert "sig" in ret + ret = client.auth.handle_signin_response(signin_payload, ret) + assert ret == "retry" + finally: + client.close() + server.close() async def test_req_chan_auth_v2_new_minion_with_master_pub_bad_sig(pki_dir, io_loop): @@ -1223,13 +1281,17 @@ async def test_req_chan_auth_v2_new_minion_with_master_pub_bad_sig(pki_dir, io_l client = salt.channel.client.AsyncReqChannel.factory(opts, io_loop=io_loop) signin_payload = client.auth.minion_sign_in_payload() pload = client._package_load(signin_payload) - assert "version" in pload - assert pload["version"] == 2 + try: + assert "version" in pload + assert pload["version"] == 2 - ret = server._auth(pload["load"], sign_messages=True) - assert "sig" in ret - with pytest.raises(salt.crypt.SaltClientError, match="Invalid signature"): - ret = client.auth.handle_signin_response(signin_payload, ret) + ret = server._auth(pload["load"], sign_messages=True) + assert "sig" in ret + with pytest.raises(salt.crypt.SaltClientError, match="Invalid signature"): + ret = client.auth.handle_signin_response(signin_payload, ret) + finally: + client.close() + server.close() async def test_req_chan_auth_v2_new_minion_without_master_pub(pki_dir, io_loop): @@ -1273,10 +1335,14 @@ async def test_req_chan_auth_v2_new_minion_without_master_pub(pki_dir, io_loop): client = salt.channel.client.AsyncReqChannel.factory(opts, io_loop=io_loop) signin_payload = client.auth.minion_sign_in_payload() pload = client._package_load(signin_payload) - assert "version" in pload - assert pload["version"] == 2 + try: + assert "version" in pload + assert pload["version"] == 2 - ret = server._auth(pload["load"], sign_messages=True) - assert "sig" in ret - ret = client.auth.handle_signin_response(signin_payload, ret) - assert ret == "retry" + ret = server._auth(pload["load"], sign_messages=True) + assert "sig" in ret + ret = client.auth.handle_signin_response(signin_payload, ret) + assert ret == "retry" + finally: + client.close() + server.close() diff --git a/tests/pytests/unit/transport/test_publish_client.py b/tests/pytests/unit/transport/test_publish_client.py index db7ef358e4e..a96db798514 100644 --- a/tests/pytests/unit/transport/test_publish_client.py +++ b/tests/pytests/unit/transport/test_publish_client.py @@ -174,6 +174,8 @@ async def test_publish_client_connect_server_down(transport, io_loop): await client.connect(background=True) except TimeoutError: pass + except Exception: + log.error("Got exception", exc_info=True) assert client._stream is None client.close() @@ -188,7 +190,6 @@ async def test_publish_client_connect_server_comes_up(transport, io_loop): import zmq - return ctx = zmq.asyncio.Context() uri = f"tcp://{opts['master_ip']}:{port}" msg = salt.payload.dumps({"meh": 123}) @@ -196,23 +197,21 @@ async def test_publish_client_connect_server_comes_up(transport, io_loop): client = salt.transport.zeromq.PublishClient( opts, io_loop, host=host, port=port ) - await client.connect(background=True) + await client.connect() assert client._socket socket = ctx.socket(zmq.PUB) socket.setsockopt(zmq.BACKLOG, 1000) socket.setsockopt(zmq.LINGER, -1) socket.setsockopt(zmq.SNDHWM, 1000) - print(f"bind {uri}") socket.bind(uri) - await asyncio.sleep(10) + await asyncio.sleep(20) async def recv(): return await client.recv(timeout=1) task = asyncio.create_task(recv()) # Sleep to allow zmq to do it's thing. - await asyncio.sleep(0.03) await socket.send(msg) await task response = task.result() diff --git a/tests/pytests/unit/transport/test_tcp.py b/tests/pytests/unit/transport/test_tcp.py index 7468f8ade2d..92173461fda 100644 --- a/tests/pytests/unit/transport/test_tcp.py +++ b/tests/pytests/unit/transport/test_tcp.py @@ -112,7 +112,7 @@ async def test_message_client_cleanup_on_close( # The run_sync call will set stop_called, reset it # orig_loop.stop_called = False - await client.close() + client.close() # Stop should be called again, client's io_loop should be None # assert orig_loop.stop_called is True diff --git a/tests/pytests/unit/utils/event/test_event.py b/tests/pytests/unit/utils/event/test_event.py index 4816272cfda..bf52b9b717d 100644 --- a/tests/pytests/unit/utils/event/test_event.py +++ b/tests/pytests/unit/utils/event/test_event.py @@ -5,9 +5,8 @@ import time from pathlib import Path import pytest -import tornado.ioloop import tornado.iostream -import zmq.eventloop.ioloop +import zmq import salt.config import salt.utils.event @@ -296,36 +295,36 @@ def test_send_master_event(sock_dir): ) -#def test_connect_pull_should_debug_log_on_StreamClosedError(): -# event = SaltEvent(node=None) -# with patch.object(event, "pusher") as mock_pusher: -# with patch.object( -# salt.utils.event.log, "debug", autospec=True -# ) as mock_log_debug: -# mock_pusher.connect.side_effect = tornado.iostream.StreamClosedError -# event.connect_pull() -# call = mock_log_debug.mock_calls[0] -# assert call.args[0] == "Unable to connect pusher: %s" -# assert isinstance(call.args[1], tornado.iostream.StreamClosedError) -# assert call.args[1].args[0] == "Stream is closed" -# -# -#@pytest.mark.parametrize("error", [Exception, KeyError, IOError]) -#def test_connect_pull_should_error_log_on_other_errors(error): -# event = SaltEvent(node=None) -# with patch.object(event, "pusher") as mock_pusher: -# with patch.object( -# salt.utils.event.log, "debug", autospec=True -# ) as mock_log_debug: -# with patch.object( -# salt.utils.event.log, "error", autospec=True -# ) as mock_log_error: -# mock_pusher.connect.side_effect = error -# event.connect_pull() -# mock_log_debug.assert_not_called() -# call = mock_log_error.mock_calls[0] -# assert call.args[0] == "Unable to connect pusher: %s" -# assert not isinstance(call.args[1], tornado.iostream.StreamClosedError) +def test_connect_pull_should_debug_log_on_StreamClosedError(): + event = SaltEvent(node=None) + with patch.object(event, "pusher") as mock_pusher: + with patch.object( + salt.utils.event.log, "debug", autospec=True + ) as mock_log_debug: + mock_pusher.connect.side_effect = tornado.iostream.StreamClosedError + event.connect_pull() + call = mock_log_debug.mock_calls[0] + assert call.args[0] == "Unable to connect pusher: %s" + assert isinstance(call.args[1], tornado.iostream.StreamClosedError) + assert call.args[1].args[0] == "Stream is closed" + + +@pytest.mark.parametrize("error", [Exception, KeyError, IOError]) +def test_connect_pull_should_error_log_on_other_errors(error): + event = SaltEvent(node=None) + with patch.object(event, "pusher") as mock_pusher: + with patch.object( + salt.utils.event.log, "debug", autospec=True + ) as mock_log_debug: + with patch.object( + salt.utils.event.log, "error", autospec=True + ) as mock_log_error: + mock_pusher.connect.side_effect = error + event.connect_pull() + mock_log_debug.assert_not_called() + call = mock_log_error.mock_calls[0] + assert call.args[0] == "Unable to connect pusher: %s" + assert not isinstance(call.args[1], tornado.iostream.StreamClosedError) @pytest.mark.slow_test diff --git a/tests/support/events.py b/tests/support/events.py index 2d918d29451..06f528c00c8 100644 --- a/tests/support/events.py +++ b/tests/support/events.py @@ -3,14 +3,13 @@ ~~~~~~~~~~~~~~~~~~~~ """ - import multiprocessing import os import time from contextlib import contextmanager import salt.utils.event -from salt.utils.process import clean_proc, Process +from salt.utils.process import Process, clean_proc @contextmanager @@ -35,7 +34,7 @@ def eventpublisher_process(sock_dir): ipc_publisher.publish_payload, ], ) - #proc = salt.utils.event.EventPublisher({"sock_dir": sock_dir}) + # proc = salt.utils.event.EventPublisher({"sock_dir": sock_dir}) proc.start() try: if os.environ.get("TRAVIS_PYTHON_VERSION", None) is not None: