More test fixes

This commit is contained in:
Daniel A. Wozniak 2023-06-26 22:01:43 -07:00 committed by Gareth J. Greenaway
parent fea99b1335
commit a2f428e5b3
13 changed files with 289 additions and 293 deletions

View file

@ -14,6 +14,7 @@ from collections import deque
from salt._logging.mixins import ExcInfoOnLogLevelFormatMixin
from salt.utils.versions import warn_until_date
log = logging.getLogger(__name__)
@ -94,9 +95,6 @@ class DeferredStreamHandler(StreamHandler):
super().__init__(stream)
self.__messages = deque(maxlen=max_queue_size)
self.__emitting = False
import traceback
self.stack = "".join(traceback.format_stack())
def handle(self, record):
self.acquire()

View file

@ -368,7 +368,7 @@ class AsyncReqChannel:
self.close()
async def __aenter__(self):
await self.connect()
await self.transport.connect()
return self
async def __aexit__(self, exc_type, exc, tb):

View file

@ -1,4 +1,5 @@
import os
import tornado.gen
TRANSPORTS = (
@ -63,7 +64,7 @@ def publish_server(opts, **kwargs):
if "pub_host" not in kwargs and "pub_path" not in kwargs:
kwargs["pub_host"] = opts["interface"]
if "pub_port" not in kwargs and "pub_path" not in kwargs:
kwargs["pub_port"] = opts["publish_port"]
kwargs["pub_port"] = opts.get("publish_port", 4506)
if "pull_host" not in kwargs and "pull_path" not in kwargs:
if opts.get("ipc_mode", "") == "tcp":

View file

@ -10,7 +10,6 @@ import asyncio
import errno
import logging
import multiprocessing
import os
import queue
import select
import socket
@ -281,7 +280,6 @@ class TCPPubClient(salt.transport.base.PublishClient):
while stream is None and (not self._closed and not self._closing):
try:
if self.host and self.port:
# log.error("GET STREAM TCP %r %s %s %s", self.url, self.host, self.port, self.stack)
self._tcp_client = TCPClientKeepAlive(
self.opts, resolver=self.resolver
)
@ -313,14 +311,8 @@ class TCPPubClient(salt.transport.base.PublishClient):
return stream
async def _connect(self):
# log.error("Connect %r %r", self, self._stream)
# import traceback
# stack = "".join(traceback.format_stack())
# log.error(f"MessageClient Connect {name} {stack}")
if self._stream is None:
log.error("Get stream")
self._stream = await self.getstream()
log.error("Got stream")
if self._stream:
# if not self._stream_return_running:
# self.io_loop.spawn_callback(self._stream_return)
@ -407,9 +399,7 @@ class TCPPubClient(salt.transport.base.PublishClient):
"""
async def setup_callback():
logit = True
while not self._stream:
# log.error("On recv wait stream %r", self._stream)
await asyncio.sleep(0.003)
while True:
try:
@ -418,9 +408,6 @@ class TCPPubClient(salt.transport.base.PublishClient):
except tornado.iostream.StreamClosedError:
self._stream.close()
self._stream = None
if logit:
log.error("Stream Closed", exc_info=True)
logit = False
await self._connect()
await asyncio.sleep(0.03)
# if self.disconnect_callback:
@ -757,15 +744,6 @@ class MessageClient:
if self.io_loop is not None:
self.io_loop.stop()
# TODO: timeout inflight sessions
def close(self):
if self._closing:
return
self._closing = True
self.stream.close()
self.tcp_client.close()
self.io_loop.add_timeout(1, self.check_close)
def close(self):
if self._closing:
return
@ -803,7 +781,6 @@ class MessageClient:
while stream is None and (not self._closed and not self._closing):
try:
if self.host and self.port:
# log.error("GET STREAM TCP %r %s %s %s", self.url, self.host, self.port, self.stack)
stream = await self._tcp_client.connect(
ip_bracket(self.host, strip=True),
self.port,
@ -811,7 +788,6 @@ class MessageClient:
**kwargs,
)
else:
log.error("GET STREAM IPC")
sock_type = socket.AF_UNIX
path = self.url.replace("ipc://", "")
stream = tornado.iostream.IOStream(
@ -831,15 +807,8 @@ class MessageClient:
return stream
async def connect(self):
log.error("Connect %r %r", self, self._stream)
import traceback
# stack = "".join(traceback.format_stack())
# log.error(f"MessageClient Connect {name} {stack}")
if self._stream is None:
log.error("Get stream")
self._stream = await self.getstream()
log.error("Got stream")
if self._stream:
if not self._stream_return_running:
self.task = asyncio.create_task(self._stream_return())
@ -852,9 +821,7 @@ class MessageClient:
unpacker = salt.utils.msgpack.Unpacker()
while not self._closing:
try:
log.error("Stream read bytes %r", self._stream.socket)
wire_bytes = await self._stream.read_bytes(4096, partial=True)
log.error("Stream got bytes")
unpacker.feed(wire_bytes)
for framed_msg in unpacker:
framed_msg = salt.transport.frame.decode_embedded_strs(framed_msg)
@ -903,9 +870,6 @@ class MessageClient:
)
else:
raise SaltClientError
# except OSError:
# log.error("OSERROR", exc_info=True)
# raise
except Exception as e: # pylint: disable=broad-except
log.error("Exception parsing response", exc_info=True)
for future in self.send_future_map.values():
@ -958,11 +922,9 @@ class MessageClient:
future.set_exception(SaltReqTimeoutError("Message timed out"))
async def send(self, msg, timeout=None, callback=None, raw=False, reply=True):
# log.error("stream send %r %r %r %r", self, self.url, self.port, reply)
if self._closing:
raise ClosingError()
while not self._stream:
# log.error("Wait stream %r %r", self, self._stream)
await asyncio.sleep(0.03)
message_id = self._message_id()
header = {"mid": message_id}
@ -1011,7 +973,6 @@ class MessageClient:
try:
if timeout == 0:
for msg in self.unpacker:
log.error("RECV a")
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
return framed_msg["body"]
poller = select.poll()
@ -1025,27 +986,21 @@ class MessageClient:
byts = await self._stream.read_bytes(4096, partial=True)
self.unpacker.feed(byts)
for msg in self.unpacker:
log.error("RECV b")
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
return framed_msg["body"]
else:
return
elif timeout:
log.error("RECV c")
return await asyncio.wait_for(self.recv(), timeout=timeout)
else:
for msg in self.unpacker:
log.error("RECV d")
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
return framed_msg["body"]
while True:
log.error("RECV e")
byts = await self._stream.read_bytes(4096, partial=True)
self.unpacker.feed(byts)
for msg in self.unpacker:
log.error("RECV e %r", msg)
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
log.error("RECV e %r", framed_msg)
return framed_msg["body"]
finally:
self._read_in_progress.release()
@ -1211,39 +1166,38 @@ class TCPPublishServer(salt.transport.base.DaemonizedPublishServer):
self.opts = opts
self.pub_sock = None
# Set up Salt IPC server
#if self.opts.get("ipc_mode", "") == "tcp":
# if self.opts.get("ipc_mode", "") == "tcp":
# self.pull_uri = int(self.opts.get("tcp_master_publish_pull", 4514))
#else:
# else:
# self.pull_uri = os.path.join(self.opts["sock_dir"], "publish_pull.ipc")
#interface = self.opts.get("interface", "127.0.0.1")
#self.publish_port = self.opts.get("publish_port", 4560)
#self.pub_uri = f"tcp://{interface}:{self.publish_port}"
# interface = self.opts.get("interface", "127.0.0.1")
# self.publish_port = self.opts.get("publish_port", 4560)
# self.pub_uri = f"tcp://{interface}:{self.publish_port}"
self.pub_host = kwargs.get("pub_host", None)
self.pub_port = kwargs.get("pub_port", None)
self.pub_path = kwargs.get("pub_path", None)
#if pub_path:
# if pub_path:
# self.pub_path = pub_path
# self.pub_uri = f"ipc://{pub_path}"
#else:
# else:
# self.pub_uri = f"tcp://{pub_host}:{pub_port}"
#self.publish_port = self.opts.get("publish_port", 4560)
# self.publish_port = self.opts.get("publish_port", 4560)
self.pull_host = kwargs.get("pull_host", None)
self.pull_port = kwargs.get("pull_port", None)
self.pull_path = kwargs.get("pull_path", None)
#if pull_path:
# if pull_path:
# self.pull_uri = f"ipc://{pull_path}"
#else:
# else:
# self.pull_uri = f"tcp://{pub_host}:{pub_port}"
#log.error(
# log.error(
# "TCPPubServer %r %s %s",
# self,
# self.pull_uri,
# #self.publish_port,
# self.pub_uri,
#)
# )
@property
def topic_support(self):
@ -1265,13 +1219,13 @@ class TCPPublishServer(salt.transport.base.DaemonizedPublishServer):
Bind to the interface specified in the configuration file
"""
io_loop = tornado.ioloop.IOLoop()
#log.error(
# log.error(
# "TCPPubServer daemon %r %s %s %s",
# self,
# self.pull_uri,
# self.publish_port,
# self.pub_uri,
#)
# )
# Spin up the publisher
self.pub_server = pub_server = PubServer(
@ -1298,13 +1252,13 @@ class TCPPublishServer(salt.transport.base.DaemonizedPublishServer):
# else:
# pull_uri = os.path.join(self.opts["sock_dir"], "publish_pull.ipc")
self.pub_server = pub_server
#if "ipc://" in self.pull_uri:
# if "ipc://" in self.pull_uri:
# pull_uri = pull_uri = self.pull_uri.replace("ipc://", "")
# log.error("WTF PULL URI %r", pull_uri)
#elif "tcp://" in self.pull_uri:
# elif "tcp://" in self.pull_uri:
# log.error("Fallback to publish port %r", self.pull_uri)
# pull_uri = self.publish_port
#else:
# else:
# pull_uri = self.pull_uri
if self.pull_path:
pull_uri = self.pull_path
@ -1345,7 +1299,7 @@ class TCPPublishServer(salt.transport.base.DaemonizedPublishServer):
raise tornado.gen.Return(ret)
def connect(self):
#path = self.pull_uri.replace("ipc://", "")
# path = self.pull_uri.replace("ipc://", "")
log.error("Connect pusher %s", self.pull_path)
# self.pub_sock = salt.utils.asynchronous.SyncWrapper(
# salt.transport.ipc.IPCMessageClient,
@ -1417,15 +1371,11 @@ class TCPReqClient(salt.transport.base.RequestClient):
)
async def connect(self):
log.error("TCPReqClient Connect")
await self.message_client.connect()
log.error("TCPReqClient Connected")
async def send(self, load, timeout=60):
await self.connect()
log.error("TCP Request %r", load)
msg = await self.message_client.send(load, timeout=timeout)
log.error("TCP Reply %r", msg)
return msg
def close(self):

View file

@ -222,7 +222,7 @@ class PublishClient(salt.transport.base.PublishClient):
if self.path:
raise Exception("A host and port or a path must be provided, not both")
async def close(self):
def close(self):
if self._closing is True:
return
self._closing = True
@ -238,21 +238,10 @@ class PublishClient(salt.transport.base.PublishClient):
if self.callbacks:
for cb in self.callbacks:
running, task = self.callbacks[cb]
task.cancel()
def close(self):
if self._closing is True:
return
self._closing = True
if hasattr(self, "_monitor") and self._monitor is not None:
self._monitor.stop()
self._monitor = None
if hasattr(self, "_stream"):
self._stream.close(0)
elif hasattr(self, "_socket"):
self._socket.close(0)
if hasattr(self, "context") and self.context.closed is False:
self.context.term()
try:
task.cancel()
except RuntimeError:
log.warning("Tasks loop already closed")
# pylint: enable=W1701
def __enter__(self):
@ -402,7 +391,7 @@ class RequestServer(salt.transport.base.DaemonizedRequestServer):
self._closing = False
self._monitor = None
self._w_monitor = None
self.task = None
self.tasks = set()
self._event = asyncio.Event()
def zmq_device(self):
@ -485,6 +474,11 @@ class RequestServer(salt.transport.base.DaemonizedRequestServer):
self._socket.close()
if hasattr(self, "context") and self.context.closed is False:
self.context.term()
for task in list(self.tasks):
try:
task.cancel()
except RuntimeError:
log.error("IOLoop closed when trying to cancel task")
def pre_fork(self, process_manager):
"""
@ -518,8 +512,8 @@ class RequestServer(salt.transport.base.DaemonizedRequestServer):
:param IOLoop io_loop: An instance of a Tornado IOLoop, to handle event scheduling
"""
# context = zmq.Context(1)
context = zmq.asyncio.Context(1)
self._socket = context.socket(zmq.REP)
self.context = zmq.asyncio.Context(1)
self._socket = self.context.socket(zmq.REP)
# Linger -1 means we'll never discard messages.
self._socket.setsockopt(zmq.LINGER, -1)
self._start_zmq_monitor()
@ -541,16 +535,23 @@ class RequestServer(salt.transport.base.DaemonizedRequestServer):
self.message_handler = message_handler
async def callback():
self.task = asyncio.create_task(self.request_handler())
await self.task
task = asyncio.create_task(self.request_handler())
task.add_done_callback(self.tasks.discard)
self.tasks.add(task)
io_loop.add_callback(callback)
async def request_handler(self):
while not self._event.is_set():
request = await self._socket.recv()
reply = await self.handle_message(None, request)
await self._socket.send(self.encode_payload(reply))
try:
request = await asyncio.wait_for(self._socket.recv(), 1)
reply = await self.handle_message(None, request)
await self._socket.send(self.encode_payload(reply))
except TimeoutError:
continue
except Exception:
log.error("Exception in request handler", exc_info=True)
break
async def handle_message(self, stream, payload):
payload = self.decode_payload(payload)
@ -605,6 +606,7 @@ def _set_tcp_keepalive(zmq_socket, opts):
ctx = zmq.asyncio.Context()
# TODO: unit tests!
class AsyncReqMessageClient:
"""
@ -643,25 +645,20 @@ class AsyncReqMessageClient:
self.socket = None
async def connect(self):
# wire up sockets
self._init_socket()
if self.socket is None:
# wire up sockets
self._init_socket()
# TODO: timeout all in-flight sessions, or error
def close(self):
try:
if self._closing:
return
except AttributeError:
# We must have been called from __del__
# The python interpreter has nuked most attributes already
if self._closing:
return
else:
self._closing = True
if self.socket:
self.socket.close()
if self.context.closed is False:
# This hangs if closing the stream causes an import error
self.context.term()
self._closing = True
if self.socket:
self.socket.close()
if self.context.closed is False:
# This hangs if closing the stream causes an import error
self.context.term()
def _init_socket(self):
if self.socket is not None:
@ -809,17 +806,17 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer):
def __init__(self, opts, **kwargs):
self.opts = opts
#if self.opts.get("ipc_mode", "") == "tcp":
# if self.opts.get("ipc_mode", "") == "tcp":
# self.pull_uri = "tcp://127.0.0.1:{}".format(
# self.opts.get("tcp_master_publish_pull", 4514)
# )
#else:
# else:
# self.pull_uri = "ipc://{}".format(
# os.path.join(self.opts["sock_dir"], "publish_pull.ipc")
# )
#interface = self.opts.get("interface", "127.0.0.1")
#publish_port = self.opts.get("publish_port", 4560)
#self.pub_uri = f"tcp://{interface}:{publish_port}"
# interface = self.opts.get("interface", "127.0.0.1")
# publish_port = self.opts.get("publish_port", 4560)
# self.pub_uri = f"tcp://{interface}:{publish_port}"
pub_host = kwargs.get("pub_host", None)
pub_port = kwargs.get("pub_port", None)
@ -829,7 +826,6 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer):
else:
self.pub_uri = f"tcp://{pub_host}:{pub_port}"
pull_host = kwargs.get("pull_host", None)
pull_port = kwargs.get("pull_port", None)
pull_path = kwargs.get("pull_path", None)
@ -838,7 +834,6 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer):
else:
self.pull_uri = f"tcp://{pull_host}:{pull_port}"
self.ctx = None
self.sock = None
self.daemon_context = None

View file

@ -110,9 +110,6 @@ class SyncWrapper:
log.error("No async method %s on object %r", method, self.obj)
except Exception: # pylint: disable=broad-except
log.exception("Exception encountered while running stop method")
# thread = threading.Thread(target=self._run_loop_final, args=(self.asyncio_loop,))
# thread.start()
# thread.join()
io_loop = self.io_loop
io_loop.stop()
try:
@ -143,18 +140,6 @@ class SyncWrapper:
return wrap
def _run_loop_final(self, asyncio_loop):
asyncio.set_event_loop(asyncio_loop)
io_loop = tornado.ioloop.IOLoop.current()
try:
async def noop():
await asyncio.sleep(0.2)
result = io_loop.run_sync(lambda: noop())
except Exception: # pylint: disable=broad-except
log.error("Error on last loop run")
def _target(self, key, args, kwargs, results, asyncio_loop):
asyncio.set_event_loop(asyncio_loop)
io_loop = tornado.ioloop.IOLoop.current()
@ -173,10 +158,16 @@ class SyncWrapper:
return self
else:
return ret
elif hasattr(self.obj, "__enter__"):
ret = self.obj.__enter__()
if ret == self.obj:
return self
else:
return ret
return self
def __exit__(self, exc_type, exc_val, tb):
if hasattr(self.obj, "__aexit__"):
return self._wrap("__aexit__")
return self._wrap("__aexit__")(exc_type, exc_val, tb)
else:
self.close()

View file

@ -631,12 +631,11 @@ def io_loop():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop = tornado.ioloop.IOLoop.current()
loop.make_current()
try:
yield loop
finally:
loop.clear_current()
loop.close(all_fds=True)
asyncio.set_event_loop(None)
# <---- Async Test Fixtures ------------------------------------------------------------------------------------------

View file

@ -1,3 +1,4 @@
import asyncio
import logging
import time
import tracemalloc
@ -64,7 +65,6 @@ def test_publish_to_pubserv_ipc(salt_master, salt_minion, transport):
ZMQ's ipc transport not supported on Windows
"""
import asyncio
opts = dict(
salt_master.config.copy(),
@ -98,7 +98,6 @@ def test_issue_36469_tcp(salt_master, salt_minion, transport):
"""
if transport == "tcp":
pytest.skip("Test not applicable to the ZeroMQ transport.")
import asyncio
def _send_small(opts, sid, num=10):
loop = asyncio.new_event_loop()

View file

@ -1,7 +1,7 @@
"""
:codeauthor: Thomas Jackson <jacksontj.89@gmail.com>
"""
import asyncio
import ctypes
import logging
import multiprocessing
@ -337,7 +337,7 @@ def run_loop_in_thread(loop, evt):
"""
Run the provided loop until an event is set
"""
loop.make_current()
asyncio.set_event_loop(loop.asyncio_loop)
@tornado.gen.coroutine
def stopper():
@ -377,7 +377,7 @@ class MockSaltMinionMaster:
master_opts.update({"transport": "zeromq"})
self.server_channel = salt.channel.server.ReqServerChannel.factory(master_opts)
self.server_channel.pre_fork(self.process_manager)
self.io_loop = tornado.ioloop.IOLoop()
self.io_loop = tornado.ioloop.IOLoop(make_current=False)
self.evt = threading.Event()
self.server_channel.post_fork(self._handle_payload, io_loop=self.io_loop)
self.server_thread = threading.Thread(
@ -403,6 +403,7 @@ class MockSaltMinionMaster:
def __exit__(self, *args, **kwargs):
self.channel.__exit__(*args, **kwargs)
del self.channel
self.server_channel.close()
# Attempting to kill the children hangs the test suite.
# Let the test suite handle this instead.
self.process_manager.stop_restarting()
@ -411,7 +412,6 @@ class MockSaltMinionMaster:
self.server_thread.join()
# Give the procs a chance to fully close before we stop the io_loop
time.sleep(2)
self.server_channel.close()
SMaster.secrets.pop("aes")
del self.server_channel
del self.io_loop
@ -481,28 +481,33 @@ def test_req_server_chan_encrypt_v2(pki_dir):
dictkey = "pillar"
nonce = "abcdefg"
pillar_data = {"pillar1": "meh"}
ret = server._encrypt_private(pillar_data, dictkey, "minion", nonce)
assert "key" in ret
assert dictkey in ret
try:
ret = server._encrypt_private(pillar_data, dictkey, "minion", nonce)
assert "key" in ret
assert dictkey in ret
key = salt.crypt.get_rsa_key(str(pki_dir.joinpath("minion", "minion.pem")), None)
if HAS_M2:
aes = key.private_decrypt(ret["key"], RSA.pkcs1_oaep_padding)
else:
cipher = PKCS1_OAEP.new(key)
aes = cipher.decrypt(ret["key"])
pcrypt = salt.crypt.Crypticle(opts, aes)
signed_msg = pcrypt.loads(ret[dictkey])
key = salt.crypt.get_rsa_key(
str(pki_dir.joinpath("minion", "minion.pem")), None
)
if HAS_M2:
aes = key.private_decrypt(ret["key"], RSA.pkcs1_oaep_padding)
else:
cipher = PKCS1_OAEP.new(key)
aes = cipher.decrypt(ret["key"])
pcrypt = salt.crypt.Crypticle(opts, aes)
signed_msg = pcrypt.loads(ret[dictkey])
assert "sig" in signed_msg
assert "data" in signed_msg
data = salt.payload.loads(signed_msg["data"])
assert "key" in data
assert data["key"] == ret["key"]
assert "key" in data
assert data["nonce"] == nonce
assert "pillar" in data
assert data["pillar"] == pillar_data
assert "sig" in signed_msg
assert "data" in signed_msg
data = salt.payload.loads(signed_msg["data"])
assert "key" in data
assert data["key"] == ret["key"]
assert "key" in data
assert data["nonce"] == nonce
assert "pillar" in data
assert data["pillar"] == pillar_data
finally:
server.close()
def test_req_server_chan_encrypt_v1(pki_dir):
@ -525,20 +530,27 @@ def test_req_server_chan_encrypt_v1(pki_dir):
dictkey = "pillar"
nonce = "abcdefg"
pillar_data = {"pillar1": "meh"}
ret = server._encrypt_private(pillar_data, dictkey, "minion", sign_messages=False)
try:
ret = server._encrypt_private(
pillar_data, dictkey, "minion", sign_messages=False
)
assert "key" in ret
assert dictkey in ret
assert "key" in ret
assert dictkey in ret
key = salt.crypt.get_rsa_key(str(pki_dir.joinpath("minion", "minion.pem")), None)
if HAS_M2:
aes = key.private_decrypt(ret["key"], RSA.pkcs1_oaep_padding)
else:
cipher = PKCS1_OAEP.new(key)
aes = cipher.decrypt(ret["key"])
pcrypt = salt.crypt.Crypticle(opts, aes)
data = pcrypt.loads(ret[dictkey])
assert data == pillar_data
key = salt.crypt.get_rsa_key(
str(pki_dir.joinpath("minion", "minion.pem")), None
)
if HAS_M2:
aes = key.private_decrypt(ret["key"], RSA.pkcs1_oaep_padding)
else:
cipher = PKCS1_OAEP.new(key)
aes = cipher.decrypt(ret["key"])
pcrypt = salt.crypt.Crypticle(opts, aes)
data = pcrypt.loads(ret[dictkey])
assert data == pillar_data
finally:
server.close()
def test_req_chan_decode_data_dict_entry_v1(pki_dir):
@ -559,19 +571,23 @@ def test_req_chan_decode_data_dict_entry_v1(pki_dir):
master_opts = dict(opts, pki_dir=str(pki_dir.joinpath("master")))
server = salt.channel.server.ReqServerChannel.factory(master_opts)
client = salt.channel.client.ReqChannel.factory(opts, io_loop=mockloop)
dictkey = "pillar"
target = "minion"
pillar_data = {"pillar1": "meh"}
ret = server._encrypt_private(pillar_data, dictkey, target, sign_messages=False)
key = client.auth.get_keys()
if HAS_M2:
aes = key.private_decrypt(ret["key"], RSA.pkcs1_oaep_padding)
else:
cipher = PKCS1_OAEP.new(key)
aes = cipher.decrypt(ret["key"])
pcrypt = salt.crypt.Crypticle(client.opts, aes)
ret_pillar_data = pcrypt.loads(ret[dictkey])
assert ret_pillar_data == pillar_data
try:
dictkey = "pillar"
target = "minion"
pillar_data = {"pillar1": "meh"}
ret = server._encrypt_private(pillar_data, dictkey, target, sign_messages=False)
key = client.auth.get_keys()
if HAS_M2:
aes = key.private_decrypt(ret["key"], RSA.pkcs1_oaep_padding)
else:
cipher = PKCS1_OAEP.new(key)
aes = cipher.decrypt(ret["key"])
pcrypt = salt.crypt.Crypticle(client.opts, aes)
ret_pillar_data = pcrypt.loads(ret[dictkey])
assert ret_pillar_data == pillar_data
finally:
client.close()
server.close()
async def test_req_chan_decode_data_dict_entry_v2(pki_dir):
@ -606,7 +622,9 @@ async def test_req_chan_decode_data_dict_entry_v2(pki_dir):
client.auth.get_keys = auth.get_keys
client.auth.crypticle.dumps = auth.crypticle.dumps
client.auth.crypticle.loads = auth.crypticle.loads
real_transport = client.transport
client.transport = MagicMock()
real_transport.close()
@tornado.gen.coroutine
def mocksend(msg, timeout=60, tries=3):
@ -631,13 +649,17 @@ async def test_req_chan_decode_data_dict_entry_v2(pki_dir):
"ver": "2",
"cmd": "_pillar",
}
ret = await client.crypted_transfer_decode_dictentry(
load,
dictkey="pillar",
)
assert "version" in client.transport.msg
assert client.transport.msg["version"] == 2
assert ret == {"pillar1": "meh"}
try:
ret = await client.crypted_transfer_decode_dictentry(
load,
dictkey="pillar",
)
assert "version" in client.transport.msg
assert client.transport.msg["version"] == 2
assert ret == {"pillar1": "meh"}
finally:
client.close()
server.close()
async def test_req_chan_decode_data_dict_entry_v2_bad_nonce(pki_dir):
@ -673,7 +695,9 @@ async def test_req_chan_decode_data_dict_entry_v2_bad_nonce(pki_dir):
client.auth.get_keys = auth.get_keys
client.auth.crypticle.dumps = auth.crypticle.dumps
client.auth.crypticle.loads = auth.crypticle.loads
real_transport = client.transport
client.transport = MagicMock()
real_transport.close()
ret = server._encrypt_private(
pillar_data, dictkey, target, nonce=badnonce, sign_messages=True
)
@ -698,12 +722,16 @@ async def test_req_chan_decode_data_dict_entry_v2_bad_nonce(pki_dir):
"cmd": "_pillar",
}
with pytest.raises(salt.crypt.AuthenticationError) as excinfo:
ret = await client.crypted_transfer_decode_dictentry(
load,
dictkey="pillar",
)
assert "Pillar nonce verification failed." == excinfo.value.message
try:
with pytest.raises(salt.crypt.AuthenticationError) as excinfo:
ret = await client.crypted_transfer_decode_dictentry(
load,
dictkey="pillar",
)
assert "Pillar nonce verification failed." == excinfo.value.message
finally:
client.close()
server.close()
async def test_req_chan_decode_data_dict_entry_v2_bad_signature(pki_dir):
@ -739,7 +767,9 @@ async def test_req_chan_decode_data_dict_entry_v2_bad_signature(pki_dir):
client.auth.get_keys = auth.get_keys
client.auth.crypticle.dumps = auth.crypticle.dumps
client.auth.crypticle.loads = auth.crypticle.loads
real_transport = client.transport
client.transport = MagicMock()
real_transport.close()
@tornado.gen.coroutine
def mocksend(msg, timeout=60, tries=3):
@ -780,12 +810,16 @@ async def test_req_chan_decode_data_dict_entry_v2_bad_signature(pki_dir):
"cmd": "_pillar",
}
with pytest.raises(salt.crypt.AuthenticationError) as excinfo:
ret = await client.crypted_transfer_decode_dictentry(
load,
dictkey="pillar",
)
assert "Pillar payload signature failed to validate." == excinfo.value.message
try:
with pytest.raises(salt.crypt.AuthenticationError) as excinfo:
ret = await client.crypted_transfer_decode_dictentry(
load,
dictkey="pillar",
)
assert "Pillar payload signature failed to validate." == excinfo.value.message
finally:
client.close()
server.close()
async def test_req_chan_decode_data_dict_entry_v2_bad_key(pki_dir):
@ -821,7 +855,9 @@ async def test_req_chan_decode_data_dict_entry_v2_bad_key(pki_dir):
client.auth.get_keys = auth.get_keys
client.auth.crypticle.dumps = auth.crypticle.dumps
client.auth.crypticle.loads = auth.crypticle.loads
real_transport = client.transport
client.transport = MagicMock()
real_transport.close()
@tornado.gen.coroutine
def mocksend(msg, timeout=60, tries=3):
@ -869,12 +905,16 @@ async def test_req_chan_decode_data_dict_entry_v2_bad_key(pki_dir):
"cmd": "_pillar",
}
with pytest.raises(salt.crypt.AuthenticationError) as excinfo:
await client.crypted_transfer_decode_dictentry(
load,
dictkey="pillar",
)
assert "Key verification failed." == excinfo.value.message
try:
with pytest.raises(salt.crypt.AuthenticationError) as excinfo:
await client.crypted_transfer_decode_dictentry(
load,
dictkey="pillar",
)
assert "Key verification failed." == excinfo.value.message
finally:
client.close()
server.close()
async def test_req_serv_auth_v1(pki_dir):
@ -926,8 +966,11 @@ async def test_req_serv_auth_v1(pki_dir):
"token": token,
"pub": pub_key,
}
ret = server._auth(load, sign_messages=False)
assert "load" not in ret
try:
ret = server._auth(load, sign_messages=False)
assert "load" not in ret
finally:
server.close()
async def test_req_serv_auth_v2(pki_dir):
@ -980,9 +1023,12 @@ async def test_req_serv_auth_v2(pki_dir):
"token": token,
"pub": pub_key,
}
ret = server._auth(load, sign_messages=True)
assert "sig" in ret
assert "load" in ret
try:
ret = server._auth(load, sign_messages=True)
assert "sig" in ret
assert "load" in ret
finally:
server.close()
async def test_req_chan_auth_v2(pki_dir, io_loop):
@ -1023,15 +1069,19 @@ async def test_req_chan_auth_v2(pki_dir, io_loop):
client = salt.channel.client.AsyncReqChannel.factory(opts, io_loop=io_loop)
signin_payload = client.auth.minion_sign_in_payload()
pload = client._package_load(signin_payload)
assert "version" in pload
assert pload["version"] == 2
try:
assert "version" in pload
assert pload["version"] == 2
ret = server._auth(pload["load"], sign_messages=True)
assert "sig" in ret
ret = client.auth.handle_signin_response(signin_payload, ret)
assert "aes" in ret
assert "master_uri" in ret
assert "publish_port" in ret
ret = server._auth(pload["load"], sign_messages=True)
assert "sig" in ret
ret = client.auth.handle_signin_response(signin_payload, ret)
assert "aes" in ret
assert "master_uri" in ret
assert "publish_port" in ret
finally:
client.close()
server.close()
async def test_req_chan_auth_v2_with_master_signing(pki_dir, io_loop):
@ -1113,16 +1163,20 @@ async def test_req_chan_auth_v2_with_master_signing(pki_dir, io_loop):
signin_payload = client.auth.minion_sign_in_payload()
pload = client._package_load(signin_payload)
server_reply = server._auth(pload["load"], sign_messages=True)
ret = client.auth.handle_signin_response(signin_payload, server_reply)
try:
ret = client.auth.handle_signin_response(signin_payload, server_reply)
assert "aes" in ret
assert "master_uri" in ret
assert "publish_port" in ret
assert "aes" in ret
assert "master_uri" in ret
assert "publish_port" in ret
assert (
pki_dir.joinpath("minion", "minion_master.pub").read_text()
== pki_dir.joinpath("master", "master.pub").read_text()
)
assert (
pki_dir.joinpath("minion", "minion_master.pub").read_text()
== pki_dir.joinpath("master", "master.pub").read_text()
)
finally:
client.close()
server.close()
async def test_req_chan_auth_v2_new_minion_with_master_pub(pki_dir, io_loop):
@ -1165,13 +1219,17 @@ async def test_req_chan_auth_v2_new_minion_with_master_pub(pki_dir, io_loop):
client = salt.channel.client.AsyncReqChannel.factory(opts, io_loop=io_loop)
signin_payload = client.auth.minion_sign_in_payload()
pload = client._package_load(signin_payload)
assert "version" in pload
assert pload["version"] == 2
try:
assert "version" in pload
assert pload["version"] == 2
ret = server._auth(pload["load"], sign_messages=True)
assert "sig" in ret
ret = client.auth.handle_signin_response(signin_payload, ret)
assert ret == "retry"
ret = server._auth(pload["load"], sign_messages=True)
assert "sig" in ret
ret = client.auth.handle_signin_response(signin_payload, ret)
assert ret == "retry"
finally:
client.close()
server.close()
async def test_req_chan_auth_v2_new_minion_with_master_pub_bad_sig(pki_dir, io_loop):
@ -1223,13 +1281,17 @@ async def test_req_chan_auth_v2_new_minion_with_master_pub_bad_sig(pki_dir, io_l
client = salt.channel.client.AsyncReqChannel.factory(opts, io_loop=io_loop)
signin_payload = client.auth.minion_sign_in_payload()
pload = client._package_load(signin_payload)
assert "version" in pload
assert pload["version"] == 2
try:
assert "version" in pload
assert pload["version"] == 2
ret = server._auth(pload["load"], sign_messages=True)
assert "sig" in ret
with pytest.raises(salt.crypt.SaltClientError, match="Invalid signature"):
ret = client.auth.handle_signin_response(signin_payload, ret)
ret = server._auth(pload["load"], sign_messages=True)
assert "sig" in ret
with pytest.raises(salt.crypt.SaltClientError, match="Invalid signature"):
ret = client.auth.handle_signin_response(signin_payload, ret)
finally:
client.close()
server.close()
async def test_req_chan_auth_v2_new_minion_without_master_pub(pki_dir, io_loop):
@ -1273,10 +1335,14 @@ async def test_req_chan_auth_v2_new_minion_without_master_pub(pki_dir, io_loop):
client = salt.channel.client.AsyncReqChannel.factory(opts, io_loop=io_loop)
signin_payload = client.auth.minion_sign_in_payload()
pload = client._package_load(signin_payload)
assert "version" in pload
assert pload["version"] == 2
try:
assert "version" in pload
assert pload["version"] == 2
ret = server._auth(pload["load"], sign_messages=True)
assert "sig" in ret
ret = client.auth.handle_signin_response(signin_payload, ret)
assert ret == "retry"
ret = server._auth(pload["load"], sign_messages=True)
assert "sig" in ret
ret = client.auth.handle_signin_response(signin_payload, ret)
assert ret == "retry"
finally:
client.close()
server.close()

View file

@ -174,6 +174,8 @@ async def test_publish_client_connect_server_down(transport, io_loop):
await client.connect(background=True)
except TimeoutError:
pass
except Exception:
log.error("Got exception", exc_info=True)
assert client._stream is None
client.close()
@ -188,7 +190,6 @@ async def test_publish_client_connect_server_comes_up(transport, io_loop):
import zmq
return
ctx = zmq.asyncio.Context()
uri = f"tcp://{opts['master_ip']}:{port}"
msg = salt.payload.dumps({"meh": 123})
@ -196,23 +197,21 @@ async def test_publish_client_connect_server_comes_up(transport, io_loop):
client = salt.transport.zeromq.PublishClient(
opts, io_loop, host=host, port=port
)
await client.connect(background=True)
await client.connect()
assert client._socket
socket = ctx.socket(zmq.PUB)
socket.setsockopt(zmq.BACKLOG, 1000)
socket.setsockopt(zmq.LINGER, -1)
socket.setsockopt(zmq.SNDHWM, 1000)
print(f"bind {uri}")
socket.bind(uri)
await asyncio.sleep(10)
await asyncio.sleep(20)
async def recv():
return await client.recv(timeout=1)
task = asyncio.create_task(recv())
# Sleep to allow zmq to do it's thing.
await asyncio.sleep(0.03)
await socket.send(msg)
await task
response = task.result()

View file

@ -112,7 +112,7 @@ async def test_message_client_cleanup_on_close(
# The run_sync call will set stop_called, reset it
# orig_loop.stop_called = False
await client.close()
client.close()
# Stop should be called again, client's io_loop should be None
# assert orig_loop.stop_called is True

View file

@ -5,9 +5,8 @@ import time
from pathlib import Path
import pytest
import tornado.ioloop
import tornado.iostream
import zmq.eventloop.ioloop
import zmq
import salt.config
import salt.utils.event
@ -296,36 +295,36 @@ def test_send_master_event(sock_dir):
)
#def test_connect_pull_should_debug_log_on_StreamClosedError():
# event = SaltEvent(node=None)
# with patch.object(event, "pusher") as mock_pusher:
# with patch.object(
# salt.utils.event.log, "debug", autospec=True
# ) as mock_log_debug:
# mock_pusher.connect.side_effect = tornado.iostream.StreamClosedError
# event.connect_pull()
# call = mock_log_debug.mock_calls[0]
# assert call.args[0] == "Unable to connect pusher: %s"
# assert isinstance(call.args[1], tornado.iostream.StreamClosedError)
# assert call.args[1].args[0] == "Stream is closed"
#
#
#@pytest.mark.parametrize("error", [Exception, KeyError, IOError])
#def test_connect_pull_should_error_log_on_other_errors(error):
# event = SaltEvent(node=None)
# with patch.object(event, "pusher") as mock_pusher:
# with patch.object(
# salt.utils.event.log, "debug", autospec=True
# ) as mock_log_debug:
# with patch.object(
# salt.utils.event.log, "error", autospec=True
# ) as mock_log_error:
# mock_pusher.connect.side_effect = error
# event.connect_pull()
# mock_log_debug.assert_not_called()
# call = mock_log_error.mock_calls[0]
# assert call.args[0] == "Unable to connect pusher: %s"
# assert not isinstance(call.args[1], tornado.iostream.StreamClosedError)
def test_connect_pull_should_debug_log_on_StreamClosedError():
event = SaltEvent(node=None)
with patch.object(event, "pusher") as mock_pusher:
with patch.object(
salt.utils.event.log, "debug", autospec=True
) as mock_log_debug:
mock_pusher.connect.side_effect = tornado.iostream.StreamClosedError
event.connect_pull()
call = mock_log_debug.mock_calls[0]
assert call.args[0] == "Unable to connect pusher: %s"
assert isinstance(call.args[1], tornado.iostream.StreamClosedError)
assert call.args[1].args[0] == "Stream is closed"
@pytest.mark.parametrize("error", [Exception, KeyError, IOError])
def test_connect_pull_should_error_log_on_other_errors(error):
event = SaltEvent(node=None)
with patch.object(event, "pusher") as mock_pusher:
with patch.object(
salt.utils.event.log, "debug", autospec=True
) as mock_log_debug:
with patch.object(
salt.utils.event.log, "error", autospec=True
) as mock_log_error:
mock_pusher.connect.side_effect = error
event.connect_pull()
mock_log_debug.assert_not_called()
call = mock_log_error.mock_calls[0]
assert call.args[0] == "Unable to connect pusher: %s"
assert not isinstance(call.args[1], tornado.iostream.StreamClosedError)
@pytest.mark.slow_test

View file

@ -3,14 +3,13 @@
~~~~~~~~~~~~~~~~~~~~
"""
import multiprocessing
import os
import time
from contextlib import contextmanager
import salt.utils.event
from salt.utils.process import clean_proc, Process
from salt.utils.process import Process, clean_proc
@contextmanager
@ -35,7 +34,7 @@ def eventpublisher_process(sock_dir):
ipc_publisher.publish_payload,
],
)
#proc = salt.utils.event.EventPublisher({"sock_dir": sock_dir})
# proc = salt.utils.event.EventPublisher({"sock_dir": sock_dir})
proc.start()
try:
if os.environ.get("TRAVIS_PYTHON_VERSION", None) is not None: