Mirror of https://github.com/saltstack/salt.git (synced 2025-04-17 10:10:20 +00:00)

commit a2f428e5b3
parent fea99b1335

    More test fixes

13 changed files with 289 additions and 293 deletions
@@ -14,6 +14,7 @@ from collections import deque
from salt._logging.mixins import ExcInfoOnLogLevelFormatMixin
from salt.utils.versions import warn_until_date


log = logging.getLogger(__name__)


@@ -94,9 +95,6 @@ class DeferredStreamHandler(StreamHandler):
        super().__init__(stream)
        self.__messages = deque(maxlen=max_queue_size)
        self.__emitting = False
        import traceback

        self.stack = "".join(traceback.format_stack())

    def handle(self, record):
        self.acquire()

@@ -368,7 +368,7 @@ class AsyncReqChannel:
        self.close()

    async def __aenter__(self):
        await self.connect()
        await self.transport.connect()
        return self

    async def __aexit__(self, exc_type, exc, tb):

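The hunk above is the async context manager protocol in action: implementing `__aenter__`/`__aexit__` lets callers drive the channel with `async with`. A minimal sketch of the same pattern, using illustrative names rather than Salt's actual API:

    import asyncio


    class Channel:
        """Toy stand-in for a transport channel; connect/close are illustrative."""

        async def connect(self):
            await asyncio.sleep(0)  # pretend to dial out

        def close(self):
            pass  # release sockets, tasks, etc.

        async def __aenter__(self):
            # Runs on entry to `async with`; the return value is bound by `as`.
            await self.connect()
            return self

        async def __aexit__(self, exc_type, exc, tb):
            # Runs on exit, even if the body raised.
            self.close()


    async def main():
        async with Channel() as ch:
            print("connected", ch)


    asyncio.run(main())
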
@@ -1,4 +1,5 @@
import os

import tornado.gen

TRANSPORTS = (

@@ -63,7 +64,7 @@ def publish_server(opts, **kwargs):
    if "pub_host" not in kwargs and "pub_path" not in kwargs:
        kwargs["pub_host"] = opts["interface"]
    if "pub_port" not in kwargs and "pub_path" not in kwargs:
        kwargs["pub_port"] = opts["publish_port"]
        kwargs["pub_port"] = opts.get("publish_port", 4506)

    if "pull_host" not in kwargs and "pull_path" not in kwargs:
        if opts.get("ipc_mode", "") == "tcp":

@@ -10,7 +10,6 @@ import asyncio
import errno
import logging
import multiprocessing
import os
import queue
import select
import socket

@@ -281,7 +280,6 @@ class TCPPubClient(salt.transport.base.PublishClient):
        while stream is None and (not self._closed and not self._closing):
            try:
                if self.host and self.port:
                    # log.error("GET STREAM TCP %r %s %s %s", self.url, self.host, self.port, self.stack)
                    self._tcp_client = TCPClientKeepAlive(
                        self.opts, resolver=self.resolver
                    )

@@ -313,14 +311,8 @@ class TCPPubClient(salt.transport.base.PublishClient):
        return stream

    async def _connect(self):
        # log.error("Connect %r %r", self, self._stream)
        # import traceback
        # stack = "".join(traceback.format_stack())
        # log.error(f"MessageClient Connect {name} {stack}")
        if self._stream is None:
            log.error("Get stream")
            self._stream = await self.getstream()
            log.error("Got stream")
            if self._stream:
                # if not self._stream_return_running:
                #     self.io_loop.spawn_callback(self._stream_return)

@@ -407,9 +399,7 @@ class TCPPubClient(salt.transport.base.PublishClient):
        """

        async def setup_callback():
            logit = True
            while not self._stream:
                # log.error("On recv wait stream %r", self._stream)
                await asyncio.sleep(0.003)
            while True:
                try:

@@ -418,9 +408,6 @@ class TCPPubClient(salt.transport.base.PublishClient):
                except tornado.iostream.StreamClosedError:
                    self._stream.close()
                    self._stream = None
                    if logit:
                        log.error("Stream Closed", exc_info=True)
                        logit = False
                    await self._connect()
                    await asyncio.sleep(0.03)
                    # if self.disconnect_callback:

@@ -757,15 +744,6 @@ class MessageClient:
        if self.io_loop is not None:
            self.io_loop.stop()

    # TODO: timeout inflight sessions
    def close(self):
        if self._closing:
            return
        self._closing = True
        self.stream.close()
        self.tcp_client.close()
        self.io_loop.add_timeout(1, self.check_close)

    def close(self):
        if self._closing:
            return

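Worth noting in the hunk above: the old class body defined `close` twice, and in Python the later definition silently replaces the earlier one, so the first version (with `self.stream.close()` and the `add_timeout` call) was dead code. A two-line demonstration:

    class C:
        def close(self):
            return "first"

        def close(self):  # silently replaces the previous definition
            return "second"


    assert C().close() == "second"
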
@@ -803,7 +781,6 @@ class MessageClient:
        while stream is None and (not self._closed and not self._closing):
            try:
                if self.host and self.port:
                    # log.error("GET STREAM TCP %r %s %s %s", self.url, self.host, self.port, self.stack)
                    stream = await self._tcp_client.connect(
                        ip_bracket(self.host, strip=True),
                        self.port,

@@ -811,7 +788,6 @@ class MessageClient:
                        **kwargs,
                    )
                else:
                    log.error("GET STREAM IPC")
                    sock_type = socket.AF_UNIX
                    path = self.url.replace("ipc://", "")
                    stream = tornado.iostream.IOStream(

@@ -831,15 +807,8 @@ class MessageClient:
        return stream

    async def connect(self):
        log.error("Connect %r %r", self, self._stream)
        import traceback

        # stack = "".join(traceback.format_stack())
        # log.error(f"MessageClient Connect {name} {stack}")
        if self._stream is None:
            log.error("Get stream")
            self._stream = await self.getstream()
            log.error("Got stream")
            if self._stream:
                if not self._stream_return_running:
                    self.task = asyncio.create_task(self._stream_return())

@@ -852,9 +821,7 @@ class MessageClient:
        unpacker = salt.utils.msgpack.Unpacker()
        while not self._closing:
            try:
                log.error("Stream read bytes %r", self._stream.socket)
                wire_bytes = await self._stream.read_bytes(4096, partial=True)
                log.error("Stream got bytes")
                unpacker.feed(wire_bytes)
                for framed_msg in unpacker:
                    framed_msg = salt.transport.frame.decode_embedded_strs(framed_msg)

@@ -903,9 +870,6 @@ class MessageClient:
                        )
                    else:
                        raise SaltClientError
            # except OSError:
            #     log.error("OSERROR", exc_info=True)
            #     raise
            except Exception as e:  # pylint: disable=broad-except
                log.error("Exception parsing response", exc_info=True)
                for future in self.send_future_map.values():

@@ -958,11 +922,9 @@ class MessageClient:
            future.set_exception(SaltReqTimeoutError("Message timed out"))

    async def send(self, msg, timeout=None, callback=None, raw=False, reply=True):
        # log.error("stream send %r %r %r %r", self, self.url, self.port, reply)
        if self._closing:
            raise ClosingError()
        while not self._stream:
            # log.error("Wait stream %r %r", self, self._stream)
            await asyncio.sleep(0.03)
        message_id = self._message_id()
        header = {"mid": message_id}

@@ -1011,7 +973,6 @@ class MessageClient:
        try:
            if timeout == 0:
                for msg in self.unpacker:
                    log.error("RECV a")
                    framed_msg = salt.transport.frame.decode_embedded_strs(msg)
                    return framed_msg["body"]
                poller = select.poll()

@@ -1025,27 +986,21 @@ class MessageClient:
                    byts = await self._stream.read_bytes(4096, partial=True)
                    self.unpacker.feed(byts)
                    for msg in self.unpacker:
                        log.error("RECV b")
                        framed_msg = salt.transport.frame.decode_embedded_strs(msg)
                        return framed_msg["body"]
                else:
                    return
            elif timeout:
                log.error("RECV c")
                return await asyncio.wait_for(self.recv(), timeout=timeout)
            else:
                for msg in self.unpacker:
                    log.error("RECV d")
                    framed_msg = salt.transport.frame.decode_embedded_strs(msg)
                    return framed_msg["body"]
                while True:
                    log.error("RECV e")
                    byts = await self._stream.read_bytes(4096, partial=True)
                    self.unpacker.feed(byts)
                    for msg in self.unpacker:
                        log.error("RECV e %r", msg)
                        framed_msg = salt.transport.frame.decode_embedded_strs(msg)
                        log.error("RECV e %r", framed_msg)
                        return framed_msg["body"]
        finally:
            self._read_in_progress.release()

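The recv paths above follow msgpack's streaming idiom: feed whatever bytes the socket produced into an `Unpacker` and iterate the complete messages it has buffered; partial frames simply stay queued until more data arrives. A standalone sketch using plain `msgpack` (Salt additionally wraps messages with its own framing helpers):

    import msgpack

    unpacker = msgpack.Unpacker()
    wire = msgpack.packb({"mid": 1}) + msgpack.packb({"mid": 2})

    # Simulate data arriving in arbitrary chunks, as from read_bytes(4096, partial=True).
    unpacker.feed(wire[:5])  # partial frame: nothing to iterate yet
    assert list(unpacker) == []
    unpacker.feed(wire[5:])  # the rest arrives; both messages complete
    assert list(unpacker) == [{"mid": 1}, {"mid": 2}]
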
@@ -1211,39 +1166,38 @@ class TCPPublishServer(salt.transport.base.DaemonizedPublishServer):
        self.opts = opts
        self.pub_sock = None
        # Set up Salt IPC server
        #if self.opts.get("ipc_mode", "") == "tcp":
        # if self.opts.get("ipc_mode", "") == "tcp":
        #     self.pull_uri = int(self.opts.get("tcp_master_publish_pull", 4514))
        #else:
        # else:
        #     self.pull_uri = os.path.join(self.opts["sock_dir"], "publish_pull.ipc")
        #interface = self.opts.get("interface", "127.0.0.1")
        #self.publish_port = self.opts.get("publish_port", 4560)
        #self.pub_uri = f"tcp://{interface}:{self.publish_port}"
        # interface = self.opts.get("interface", "127.0.0.1")
        # self.publish_port = self.opts.get("publish_port", 4560)
        # self.pub_uri = f"tcp://{interface}:{self.publish_port}"
        self.pub_host = kwargs.get("pub_host", None)
        self.pub_port = kwargs.get("pub_port", None)
        self.pub_path = kwargs.get("pub_path", None)
        #if pub_path:
        # if pub_path:
        #     self.pub_path = pub_path
        #     self.pub_uri = f"ipc://{pub_path}"
        #else:
        # else:
        #     self.pub_uri = f"tcp://{pub_host}:{pub_port}"

        #self.publish_port = self.opts.get("publish_port", 4560)

        # self.publish_port = self.opts.get("publish_port", 4560)

        self.pull_host = kwargs.get("pull_host", None)
        self.pull_port = kwargs.get("pull_port", None)
        self.pull_path = kwargs.get("pull_path", None)
        #if pull_path:
        # if pull_path:
        #     self.pull_uri = f"ipc://{pull_path}"
        #else:
        # else:
        #     self.pull_uri = f"tcp://{pub_host}:{pub_port}"
        #log.error(
        # log.error(
        #     "TCPPubServer %r %s %s",
        #     self,
        #     self.pull_uri,
        #     #self.publish_port,
        #     self.pub_uri,
        #)
        # )

    @property
    def topic_support(self):

@@ -1265,13 +1219,13 @@ class TCPPublishServer(salt.transport.base.DaemonizedPublishServer):
        Bind to the interface specified in the configuration file
        """
        io_loop = tornado.ioloop.IOLoop()
        #log.error(
        # log.error(
        #     "TCPPubServer daemon %r %s %s %s",
        #     self,
        #     self.pull_uri,
        #     self.publish_port,
        #     self.pub_uri,
        #)
        # )

        # Spin up the publisher
        self.pub_server = pub_server = PubServer(

@@ -1298,13 +1252,13 @@ class TCPPublishServer(salt.transport.base.DaemonizedPublishServer):
        # else:
        #     pull_uri = os.path.join(self.opts["sock_dir"], "publish_pull.ipc")
        self.pub_server = pub_server
        #if "ipc://" in self.pull_uri:
        # if "ipc://" in self.pull_uri:
        #     pull_uri = pull_uri = self.pull_uri.replace("ipc://", "")
        #     log.error("WTF PULL URI %r", pull_uri)
        #elif "tcp://" in self.pull_uri:
        # elif "tcp://" in self.pull_uri:
        #     log.error("Fallback to publish port %r", self.pull_uri)
        #     pull_uri = self.publish_port
        #else:
        # else:
        #     pull_uri = self.pull_uri
        if self.pull_path:
            pull_uri = self.pull_path

@@ -1345,7 +1299,7 @@ class TCPPublishServer(salt.transport.base.DaemonizedPublishServer):
        raise tornado.gen.Return(ret)

    def connect(self):
        #path = self.pull_uri.replace("ipc://", "")
        # path = self.pull_uri.replace("ipc://", "")
        log.error("Connect pusher %s", self.pull_path)
        # self.pub_sock = salt.utils.asynchronous.SyncWrapper(
        #     salt.transport.ipc.IPCMessageClient,

@@ -1417,15 +1371,11 @@ class TCPReqClient(salt.transport.base.RequestClient):
        )

    async def connect(self):
        log.error("TCPReqClient Connect")
        await self.message_client.connect()
        log.error("TCPReqClient Connected")

    async def send(self, load, timeout=60):
        await self.connect()
        log.error("TCP Request %r", load)
        msg = await self.message_client.send(load, timeout=timeout)
        log.error("TCP Reply %r", msg)
        return msg

    def close(self):

@@ -222,7 +222,7 @@ class PublishClient(salt.transport.base.PublishClient):
            if self.path:
                raise Exception("A host and port or a path must be provided, not both")

    async def close(self):
    def close(self):
        if self._closing is True:
            return
        self._closing = True

@@ -238,21 +238,10 @@ class PublishClient(salt.transport.base.PublishClient):
        if self.callbacks:
            for cb in self.callbacks:
                running, task = self.callbacks[cb]
                task.cancel()

    def close(self):
        if self._closing is True:
            return
        self._closing = True
        if hasattr(self, "_monitor") and self._monitor is not None:
            self._monitor.stop()
            self._monitor = None
        if hasattr(self, "_stream"):
            self._stream.close(0)
        elif hasattr(self, "_socket"):
            self._socket.close(0)
        if hasattr(self, "context") and self.context.closed is False:
            self.context.term()
        try:
            task.cancel()
        except RuntimeError:
            log.warning("Tasks loop already closed")

    # pylint: enable=W1701
    def __enter__(self):

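The close path above stops the monitor, closes the stream or socket, and only then terminates the context. With pyzmq the ordering matters: `Context.term()` blocks until every socket created from the context has been closed (subject to LINGER), so sockets are closed first, with `close(0)` discarding pending messages. Reduced to its skeleton:

    import zmq

    ctx = zmq.Context()
    sock = ctx.socket(zmq.PUB)
    sock.bind("inproc://demo")

    # Close sockets before terminating the context; linger=0 discards
    # unsent messages so term() cannot block on them.
    sock.close(0)
    if not ctx.closed:
        ctx.term()
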
@@ -402,7 +391,7 @@ class RequestServer(salt.transport.base.DaemonizedRequestServer):
        self._closing = False
        self._monitor = None
        self._w_monitor = None
        self.task = None
        self.tasks = set()
        self._event = asyncio.Event()

    def zmq_device(self):

@@ -485,6 +474,11 @@ class RequestServer(salt.transport.base.DaemonizedRequestServer):
            self._socket.close()
        if hasattr(self, "context") and self.context.closed is False:
            self.context.term()
        for task in list(self.tasks):
            try:
                task.cancel()
            except RuntimeError:
                log.error("IOLoop closed when trying to cancel task")

    def pre_fork(self, process_manager):
        """

@@ -518,8 +512,8 @@ class RequestServer(salt.transport.base.DaemonizedRequestServer):
        :param IOLoop io_loop: An instance of a Tornado IOLoop, to handle event scheduling
        """
        # context = zmq.Context(1)
        context = zmq.asyncio.Context(1)
        self._socket = context.socket(zmq.REP)
        self.context = zmq.asyncio.Context(1)
        self._socket = self.context.socket(zmq.REP)
        # Linger -1 means we'll never discard messages.
        self._socket.setsockopt(zmq.LINGER, -1)
        self._start_zmq_monitor()

@@ -541,16 +535,23 @@ class RequestServer(salt.transport.base.DaemonizedRequestServer):
        self.message_handler = message_handler

        async def callback():
            self.task = asyncio.create_task(self.request_handler())
            await self.task
            task = asyncio.create_task(self.request_handler())
            task.add_done_callback(self.tasks.discard)
            self.tasks.add(task)

        io_loop.add_callback(callback)

    async def request_handler(self):
        while not self._event.is_set():
            request = await self._socket.recv()
            reply = await self.handle_message(None, request)
            await self._socket.send(self.encode_payload(reply))
            try:
                request = await asyncio.wait_for(self._socket.recv(), 1)
                reply = await self.handle_message(None, request)
                await self._socket.send(self.encode_payload(reply))
            except TimeoutError:
                continue
            except Exception:
                log.error("Exception in request handler", exc_info=True)
                break

    async def handle_message(self, stream, payload):
        payload = self.decode_payload(payload)

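Two patterns in the hunk above are worth calling out. First, asyncio keeps only a weak reference to running tasks, so the server now stores each handler task in `self.tasks` and discards it on completion; without a strong reference a task can be garbage-collected mid-flight. Second, the receive loop polls with `asyncio.wait_for(..., 1)` so it rechecks the shutdown event every second instead of blocking forever in `recv()` (note that `asyncio.TimeoutError` is an alias of the builtin `TimeoutError` only on Python 3.11+, which the bare `except TimeoutError` relies on). A condensed sketch of both, with illustrative names:

    import asyncio

    tasks = set()


    def spawn(coro):
        # Keep a strong reference until the task finishes, then drop it.
        task = asyncio.create_task(coro)
        tasks.add(task)
        task.add_done_callback(tasks.discard)
        return task


    async def serve(shutdown: asyncio.Event, recv, handle):
        while not shutdown.is_set():
            try:
                # Wake up periodically so a set() on shutdown is noticed.
                request = await asyncio.wait_for(recv(), 1)
            except asyncio.TimeoutError:
                continue
            await handle(request)
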
@@ -605,6 +606,7 @@ def _set_tcp_keepalive(zmq_socket, opts):

ctx = zmq.asyncio.Context()


# TODO: unit tests!
class AsyncReqMessageClient:
    """

@@ -643,25 +645,20 @@ class AsyncReqMessageClient:
        self.socket = None

    async def connect(self):
        # wire up sockets
        self._init_socket()
        if self.socket is None:
            # wire up sockets
            self._init_socket()

    # TODO: timeout all in-flight sessions, or error
    def close(self):
        try:
            if self._closing:
                return
        except AttributeError:
            # We must have been called from __del__
            # The python interpreter has nuked most attributes already
        if self._closing:
            return
        else:
            self._closing = True
            if self.socket:
                self.socket.close()
            if self.context.closed is False:
                # This hangs if closing the stream causes an import error
                self.context.term()
        self._closing = True
        if self.socket:
            self.socket.close()
        if self.context.closed is False:
            # This hangs if closing the stream causes an import error
            self.context.term()

    def _init_socket(self):
        if self.socket is not None:

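The old `close` above wrapped the `self._closing` check in `try/except AttributeError` because `close` can be reached from `__del__`, and during interpreter teardown instance state may already be partially destroyed; the rewrite drops that path. A `getattr` default is one compact way to keep the same defense (hypothetical class, not Salt's):

    class Client:
        def __init__(self):
            self._closing = False

        def close(self):
            # Guard attribute access: if close() is reached via __del__ during
            # interpreter shutdown, the attribute may already be gone, so treat
            # unknown state as "already closing".
            if getattr(self, "_closing", True):
                return
            self._closing = True
            # ... release sockets here ...

        def __del__(self):
            self.close()
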
@@ -809,17 +806,17 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer):

    def __init__(self, opts, **kwargs):
        self.opts = opts
        #if self.opts.get("ipc_mode", "") == "tcp":
        # if self.opts.get("ipc_mode", "") == "tcp":
        #     self.pull_uri = "tcp://127.0.0.1:{}".format(
        #         self.opts.get("tcp_master_publish_pull", 4514)
        #     )
        #else:
        # else:
        #     self.pull_uri = "ipc://{}".format(
        #         os.path.join(self.opts["sock_dir"], "publish_pull.ipc")
        #     )
        #interface = self.opts.get("interface", "127.0.0.1")
        #publish_port = self.opts.get("publish_port", 4560)
        #self.pub_uri = f"tcp://{interface}:{publish_port}"
        # interface = self.opts.get("interface", "127.0.0.1")
        # publish_port = self.opts.get("publish_port", 4560)
        # self.pub_uri = f"tcp://{interface}:{publish_port}"

        pub_host = kwargs.get("pub_host", None)
        pub_port = kwargs.get("pub_port", None)

@@ -829,7 +826,6 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer):
        else:
            self.pub_uri = f"tcp://{pub_host}:{pub_port}"


        pull_host = kwargs.get("pull_host", None)
        pull_port = kwargs.get("pull_port", None)
        pull_path = kwargs.get("pull_path", None)

@@ -838,7 +834,6 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer):
        else:
            self.pull_uri = f"tcp://{pull_host}:{pull_port}"


        self.ctx = None
        self.sock = None
        self.daemon_context = None

@@ -110,9 +110,6 @@ class SyncWrapper:
            log.error("No async method %s on object %r", method, self.obj)
        except Exception:  # pylint: disable=broad-except
            log.exception("Exception encountered while running stop method")
        # thread = threading.Thread(target=self._run_loop_final, args=(self.asyncio_loop,))
        # thread.start()
        # thread.join()
        io_loop = self.io_loop
        io_loop.stop()
        try:

@@ -143,18 +140,6 @@ class SyncWrapper:

        return wrap

    def _run_loop_final(self, asyncio_loop):
        asyncio.set_event_loop(asyncio_loop)
        io_loop = tornado.ioloop.IOLoop.current()
        try:

            async def noop():
                await asyncio.sleep(0.2)

            result = io_loop.run_sync(lambda: noop())
        except Exception:  # pylint: disable=broad-except
            log.error("Error on last loop run")

    def _target(self, key, args, kwargs, results, asyncio_loop):
        asyncio.set_event_loop(asyncio_loop)
        io_loop = tornado.ioloop.IOLoop.current()

@@ -173,10 +158,16 @@ class SyncWrapper:
                return self
            else:
                return ret
        elif hasattr(self.obj, "__enter__"):
            ret = self.obj.__enter__()
            if ret == self.obj:
                return self
            else:
                return ret
        return self

    def __exit__(self, exc_type, exc_val, tb):
        if hasattr(self.obj, "__aexit__"):
            return self._wrap("__aexit__")
            return self._wrap("__aexit__")(exc_type, exc_val, tb)
        else:
            self.close()

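The two adjacent `return` lines in `__exit__` above are the actual fix: `self._wrap("__aexit__")` only builds a callable that will run the coroutine on the wrapper's loop, and the old code returned that callable without invoking it, so the wrapped object's `__aexit__` never executed. The same bug in miniature:

    def make_closer(resource):
        def closer():
            resource["closed"] = True
        return closer


    r = {"closed": False}
    make_closer(r)      # bug: builds the closer but never calls it
    assert r["closed"] is False
    make_closer(r)()    # fix: call the returned function
    assert r["closed"] is True
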
@@ -631,12 +631,11 @@ def io_loop():
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    loop = tornado.ioloop.IOLoop.current()
    loop.make_current()
    try:
        yield loop
    finally:
        loop.clear_current()
        loop.close(all_fds=True)
        asyncio.set_event_loop(None)


# <---- Async Test Fixtures ------------------------------------------------------------------------------------------

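The fixture change above gives each test a fresh asyncio loop, points Tornado's IOLoop facade at it, and tears both down afterwards; `close(all_fds=True)` also closes any file descriptors still registered with the loop. Consolidating the interleaved lines, the resulting fixture plausibly looks like this (note `make_current`/`clear_current` are deprecated in newer Tornado releases):

    import asyncio

    import pytest
    import tornado.ioloop


    @pytest.fixture
    def io_loop():
        # Fresh asyncio loop per test, wrapped by Tornado's IOLoop facade.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        tloop = tornado.ioloop.IOLoop.current()  # wraps the loop we just set
        tloop.make_current()
        try:
            yield tloop
        finally:
            tloop.clear_current()
            tloop.close(all_fds=True)  # also closes fds the loop still owns
            asyncio.set_event_loop(None)
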
@@ -1,3 +1,4 @@
import asyncio
import logging
import time
import tracemalloc

@@ -64,7 +65,6 @@ def test_publish_to_pubserv_ipc(salt_master, salt_minion, transport):

    ZMQ's ipc transport not supported on Windows
    """
    import asyncio

    opts = dict(
        salt_master.config.copy(),

@@ -98,7 +98,6 @@ def test_issue_36469_tcp(salt_master, salt_minion, transport):
    """
    if transport == "tcp":
        pytest.skip("Test not applicable to the ZeroMQ transport.")
    import asyncio

    def _send_small(opts, sid, num=10):
        loop = asyncio.new_event_loop()

@@ -1,7 +1,7 @@
"""
:codeauthor: Thomas Jackson <jacksontj.89@gmail.com>
"""

import asyncio
import ctypes
import logging
import multiprocessing

@@ -337,7 +337,7 @@ def run_loop_in_thread(loop, evt):
    """
    Run the provided loop until an event is set
    """
    loop.make_current()
    asyncio.set_event_loop(loop.asyncio_loop)

    @tornado.gen.coroutine
    def stopper():

@@ -377,7 +377,7 @@ class MockSaltMinionMaster:
        master_opts.update({"transport": "zeromq"})
        self.server_channel = salt.channel.server.ReqServerChannel.factory(master_opts)
        self.server_channel.pre_fork(self.process_manager)
        self.io_loop = tornado.ioloop.IOLoop()
        self.io_loop = tornado.ioloop.IOLoop(make_current=False)
        self.evt = threading.Event()
        self.server_channel.post_fork(self._handle_payload, io_loop=self.io_loop)
        self.server_thread = threading.Thread(

@@ -403,6 +403,7 @@ class MockSaltMinionMaster:
    def __exit__(self, *args, **kwargs):
        self.channel.__exit__(*args, **kwargs)
        del self.channel
        self.server_channel.close()
        # Attempting to kill the children hangs the test suite.
        # Let the test suite handle this instead.
        self.process_manager.stop_restarting()

@@ -411,7 +412,6 @@ class MockSaltMinionMaster:
        self.server_thread.join()
        # Give the procs a chance to fully close before we stop the io_loop
        time.sleep(2)
        self.server_channel.close()
        SMaster.secrets.pop("aes")
        del self.server_channel
        del self.io_loop

@@ -481,28 +481,33 @@ def test_req_server_chan_encrypt_v2(pki_dir):
    dictkey = "pillar"
    nonce = "abcdefg"
    pillar_data = {"pillar1": "meh"}
    ret = server._encrypt_private(pillar_data, dictkey, "minion", nonce)
    assert "key" in ret
    assert dictkey in ret
    try:
        ret = server._encrypt_private(pillar_data, dictkey, "minion", nonce)
        assert "key" in ret
        assert dictkey in ret

    key = salt.crypt.get_rsa_key(str(pki_dir.joinpath("minion", "minion.pem")), None)
    if HAS_M2:
        aes = key.private_decrypt(ret["key"], RSA.pkcs1_oaep_padding)
    else:
        cipher = PKCS1_OAEP.new(key)
        aes = cipher.decrypt(ret["key"])
    pcrypt = salt.crypt.Crypticle(opts, aes)
    signed_msg = pcrypt.loads(ret[dictkey])
        key = salt.crypt.get_rsa_key(
            str(pki_dir.joinpath("minion", "minion.pem")), None
        )
        if HAS_M2:
            aes = key.private_decrypt(ret["key"], RSA.pkcs1_oaep_padding)
        else:
            cipher = PKCS1_OAEP.new(key)
            aes = cipher.decrypt(ret["key"])
        pcrypt = salt.crypt.Crypticle(opts, aes)
        signed_msg = pcrypt.loads(ret[dictkey])

    assert "sig" in signed_msg
    assert "data" in signed_msg
    data = salt.payload.loads(signed_msg["data"])
    assert "key" in data
    assert data["key"] == ret["key"]
    assert "key" in data
    assert data["nonce"] == nonce
    assert "pillar" in data
    assert data["pillar"] == pillar_data
        assert "sig" in signed_msg
        assert "data" in signed_msg
        data = salt.payload.loads(signed_msg["data"])
        assert "key" in data
        assert data["key"] == ret["key"]
        assert "key" in data
        assert data["nonce"] == nonce
        assert "pillar" in data
        assert data["pillar"] == pillar_data
    finally:
        server.close()


def test_req_server_chan_encrypt_v1(pki_dir):

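This hunk and most of the ones that follow apply the same fix: everything after the channel objects are constructed moves inside a `try` block, with the `close()` calls in `finally`, so a failing assertion no longer leaks server or client sockets into later tests. The skeleton, with hypothetical helpers standing in for Salt's ReqServerChannel/ReqChannel factories:

    def test_something(pki_dir):
        server = make_server(pki_dir)  # hypothetical helper, not Salt's API
        client = make_client(pki_dir)  # hypothetical helper, not Salt's API
        try:
            result = client.do_request()
            assert result == "expected"
        finally:
            # Runs whether the assertions pass or fail.
            client.close()
            server.close()
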
@@ -525,20 +530,27 @@ def test_req_server_chan_encrypt_v1(pki_dir):
    dictkey = "pillar"
    nonce = "abcdefg"
    pillar_data = {"pillar1": "meh"}
    ret = server._encrypt_private(pillar_data, dictkey, "minion", sign_messages=False)
    try:
        ret = server._encrypt_private(
            pillar_data, dictkey, "minion", sign_messages=False
        )

    assert "key" in ret
    assert dictkey in ret
        assert "key" in ret
        assert dictkey in ret

    key = salt.crypt.get_rsa_key(str(pki_dir.joinpath("minion", "minion.pem")), None)
    if HAS_M2:
        aes = key.private_decrypt(ret["key"], RSA.pkcs1_oaep_padding)
    else:
        cipher = PKCS1_OAEP.new(key)
        aes = cipher.decrypt(ret["key"])
    pcrypt = salt.crypt.Crypticle(opts, aes)
    data = pcrypt.loads(ret[dictkey])
    assert data == pillar_data
        key = salt.crypt.get_rsa_key(
            str(pki_dir.joinpath("minion", "minion.pem")), None
        )
        if HAS_M2:
            aes = key.private_decrypt(ret["key"], RSA.pkcs1_oaep_padding)
        else:
            cipher = PKCS1_OAEP.new(key)
            aes = cipher.decrypt(ret["key"])
        pcrypt = salt.crypt.Crypticle(opts, aes)
        data = pcrypt.loads(ret[dictkey])
        assert data == pillar_data
    finally:
        server.close()


def test_req_chan_decode_data_dict_entry_v1(pki_dir):

@@ -559,19 +571,23 @@ def test_req_chan_decode_data_dict_entry_v1(pki_dir):
    master_opts = dict(opts, pki_dir=str(pki_dir.joinpath("master")))
    server = salt.channel.server.ReqServerChannel.factory(master_opts)
    client = salt.channel.client.ReqChannel.factory(opts, io_loop=mockloop)
    dictkey = "pillar"
    target = "minion"
    pillar_data = {"pillar1": "meh"}
    ret = server._encrypt_private(pillar_data, dictkey, target, sign_messages=False)
    key = client.auth.get_keys()
    if HAS_M2:
        aes = key.private_decrypt(ret["key"], RSA.pkcs1_oaep_padding)
    else:
        cipher = PKCS1_OAEP.new(key)
        aes = cipher.decrypt(ret["key"])
    pcrypt = salt.crypt.Crypticle(client.opts, aes)
    ret_pillar_data = pcrypt.loads(ret[dictkey])
    assert ret_pillar_data == pillar_data
    try:
        dictkey = "pillar"
        target = "minion"
        pillar_data = {"pillar1": "meh"}
        ret = server._encrypt_private(pillar_data, dictkey, target, sign_messages=False)
        key = client.auth.get_keys()
        if HAS_M2:
            aes = key.private_decrypt(ret["key"], RSA.pkcs1_oaep_padding)
        else:
            cipher = PKCS1_OAEP.new(key)
            aes = cipher.decrypt(ret["key"])
        pcrypt = salt.crypt.Crypticle(client.opts, aes)
        ret_pillar_data = pcrypt.loads(ret[dictkey])
        assert ret_pillar_data == pillar_data
    finally:
        client.close()
        server.close()


async def test_req_chan_decode_data_dict_entry_v2(pki_dir):

@@ -606,7 +622,9 @@ async def test_req_chan_decode_data_dict_entry_v2(pki_dir):
    client.auth.get_keys = auth.get_keys
    client.auth.crypticle.dumps = auth.crypticle.dumps
    client.auth.crypticle.loads = auth.crypticle.loads
    real_transport = client.transport
    client.transport = MagicMock()
    real_transport.close()

    @tornado.gen.coroutine
    def mocksend(msg, timeout=60, tries=3):

@@ -631,13 +649,17 @@ async def test_req_chan_decode_data_dict_entry_v2(pki_dir):
        "ver": "2",
        "cmd": "_pillar",
    }
    ret = await client.crypted_transfer_decode_dictentry(
        load,
        dictkey="pillar",
    )
    assert "version" in client.transport.msg
    assert client.transport.msg["version"] == 2
    assert ret == {"pillar1": "meh"}
    try:
        ret = await client.crypted_transfer_decode_dictentry(
            load,
            dictkey="pillar",
        )
        assert "version" in client.transport.msg
        assert client.transport.msg["version"] == 2
        assert ret == {"pillar1": "meh"}
    finally:
        client.close()
        server.close()


async def test_req_chan_decode_data_dict_entry_v2_bad_nonce(pki_dir):

@@ -673,7 +695,9 @@ async def test_req_chan_decode_data_dict_entry_v2_bad_nonce(pki_dir):
    client.auth.get_keys = auth.get_keys
    client.auth.crypticle.dumps = auth.crypticle.dumps
    client.auth.crypticle.loads = auth.crypticle.loads
    real_transport = client.transport
    client.transport = MagicMock()
    real_transport.close()
    ret = server._encrypt_private(
        pillar_data, dictkey, target, nonce=badnonce, sign_messages=True
    )

@@ -698,12 +722,16 @@ async def test_req_chan_decode_data_dict_entry_v2_bad_nonce(pki_dir):
        "cmd": "_pillar",
    }

    with pytest.raises(salt.crypt.AuthenticationError) as excinfo:
        ret = await client.crypted_transfer_decode_dictentry(
            load,
            dictkey="pillar",
        )
    assert "Pillar nonce verification failed." == excinfo.value.message
    try:
        with pytest.raises(salt.crypt.AuthenticationError) as excinfo:
            ret = await client.crypted_transfer_decode_dictentry(
                load,
                dictkey="pillar",
            )
        assert "Pillar nonce verification failed." == excinfo.value.message
    finally:
        client.close()
        server.close()


async def test_req_chan_decode_data_dict_entry_v2_bad_signature(pki_dir):

@@ -739,7 +767,9 @@ async def test_req_chan_decode_data_dict_entry_v2_bad_signature(pki_dir):
    client.auth.get_keys = auth.get_keys
    client.auth.crypticle.dumps = auth.crypticle.dumps
    client.auth.crypticle.loads = auth.crypticle.loads
    real_transport = client.transport
    client.transport = MagicMock()
    real_transport.close()

    @tornado.gen.coroutine
    def mocksend(msg, timeout=60, tries=3):

@@ -780,12 +810,16 @@ async def test_req_chan_decode_data_dict_entry_v2_bad_signature(pki_dir):
        "cmd": "_pillar",
    }

    with pytest.raises(salt.crypt.AuthenticationError) as excinfo:
        ret = await client.crypted_transfer_decode_dictentry(
            load,
            dictkey="pillar",
        )
    assert "Pillar payload signature failed to validate." == excinfo.value.message
    try:
        with pytest.raises(salt.crypt.AuthenticationError) as excinfo:
            ret = await client.crypted_transfer_decode_dictentry(
                load,
                dictkey="pillar",
            )
        assert "Pillar payload signature failed to validate." == excinfo.value.message
    finally:
        client.close()
        server.close()


async def test_req_chan_decode_data_dict_entry_v2_bad_key(pki_dir):

@@ -821,7 +855,9 @@ async def test_req_chan_decode_data_dict_entry_v2_bad_key(pki_dir):
    client.auth.get_keys = auth.get_keys
    client.auth.crypticle.dumps = auth.crypticle.dumps
    client.auth.crypticle.loads = auth.crypticle.loads
    real_transport = client.transport
    client.transport = MagicMock()
    real_transport.close()

    @tornado.gen.coroutine
    def mocksend(msg, timeout=60, tries=3):

@@ -869,12 +905,16 @@ async def test_req_chan_decode_data_dict_entry_v2_bad_key(pki_dir):
        "cmd": "_pillar",
    }

    with pytest.raises(salt.crypt.AuthenticationError) as excinfo:
        await client.crypted_transfer_decode_dictentry(
            load,
            dictkey="pillar",
        )
    assert "Key verification failed." == excinfo.value.message
    try:
        with pytest.raises(salt.crypt.AuthenticationError) as excinfo:
            await client.crypted_transfer_decode_dictentry(
                load,
                dictkey="pillar",
            )
        assert "Key verification failed." == excinfo.value.message
    finally:
        client.close()
        server.close()


async def test_req_serv_auth_v1(pki_dir):

@@ -926,8 +966,11 @@ async def test_req_serv_auth_v1(pki_dir):
        "token": token,
        "pub": pub_key,
    }
    ret = server._auth(load, sign_messages=False)
    assert "load" not in ret
    try:
        ret = server._auth(load, sign_messages=False)
        assert "load" not in ret
    finally:
        server.close()


async def test_req_serv_auth_v2(pki_dir):

@@ -980,9 +1023,12 @@ async def test_req_serv_auth_v2(pki_dir):
        "token": token,
        "pub": pub_key,
    }
    ret = server._auth(load, sign_messages=True)
    assert "sig" in ret
    assert "load" in ret
    try:
        ret = server._auth(load, sign_messages=True)
        assert "sig" in ret
        assert "load" in ret
    finally:
        server.close()


async def test_req_chan_auth_v2(pki_dir, io_loop):

@@ -1023,15 +1069,19 @@ async def test_req_chan_auth_v2(pki_dir, io_loop):
    client = salt.channel.client.AsyncReqChannel.factory(opts, io_loop=io_loop)
    signin_payload = client.auth.minion_sign_in_payload()
    pload = client._package_load(signin_payload)
    assert "version" in pload
    assert pload["version"] == 2
    try:
        assert "version" in pload
        assert pload["version"] == 2

    ret = server._auth(pload["load"], sign_messages=True)
    assert "sig" in ret
    ret = client.auth.handle_signin_response(signin_payload, ret)
    assert "aes" in ret
    assert "master_uri" in ret
    assert "publish_port" in ret
        ret = server._auth(pload["load"], sign_messages=True)
        assert "sig" in ret
        ret = client.auth.handle_signin_response(signin_payload, ret)
        assert "aes" in ret
        assert "master_uri" in ret
        assert "publish_port" in ret
    finally:
        client.close()
        server.close()


async def test_req_chan_auth_v2_with_master_signing(pki_dir, io_loop):

@@ -1113,16 +1163,20 @@ async def test_req_chan_auth_v2_with_master_signing(pki_dir, io_loop):
    signin_payload = client.auth.minion_sign_in_payload()
    pload = client._package_load(signin_payload)
    server_reply = server._auth(pload["load"], sign_messages=True)
    ret = client.auth.handle_signin_response(signin_payload, server_reply)
    try:
        ret = client.auth.handle_signin_response(signin_payload, server_reply)

    assert "aes" in ret
    assert "master_uri" in ret
    assert "publish_port" in ret
        assert "aes" in ret
        assert "master_uri" in ret
        assert "publish_port" in ret

    assert (
        pki_dir.joinpath("minion", "minion_master.pub").read_text()
        == pki_dir.joinpath("master", "master.pub").read_text()
    )
        assert (
            pki_dir.joinpath("minion", "minion_master.pub").read_text()
            == pki_dir.joinpath("master", "master.pub").read_text()
        )
    finally:
        client.close()
        server.close()


async def test_req_chan_auth_v2_new_minion_with_master_pub(pki_dir, io_loop):

@@ -1165,13 +1219,17 @@ async def test_req_chan_auth_v2_new_minion_with_master_pub(pki_dir, io_loop):
    client = salt.channel.client.AsyncReqChannel.factory(opts, io_loop=io_loop)
    signin_payload = client.auth.minion_sign_in_payload()
    pload = client._package_load(signin_payload)
    assert "version" in pload
    assert pload["version"] == 2
    try:
        assert "version" in pload
        assert pload["version"] == 2

    ret = server._auth(pload["load"], sign_messages=True)
    assert "sig" in ret
    ret = client.auth.handle_signin_response(signin_payload, ret)
    assert ret == "retry"
        ret = server._auth(pload["load"], sign_messages=True)
        assert "sig" in ret
        ret = client.auth.handle_signin_response(signin_payload, ret)
        assert ret == "retry"
    finally:
        client.close()
        server.close()


async def test_req_chan_auth_v2_new_minion_with_master_pub_bad_sig(pki_dir, io_loop):

@@ -1223,13 +1281,17 @@ async def test_req_chan_auth_v2_new_minion_with_master_pub_bad_sig(pki_dir, io_loop):
    client = salt.channel.client.AsyncReqChannel.factory(opts, io_loop=io_loop)
    signin_payload = client.auth.minion_sign_in_payload()
    pload = client._package_load(signin_payload)
    assert "version" in pload
    assert pload["version"] == 2
    try:
        assert "version" in pload
        assert pload["version"] == 2

    ret = server._auth(pload["load"], sign_messages=True)
    assert "sig" in ret
    with pytest.raises(salt.crypt.SaltClientError, match="Invalid signature"):
        ret = client.auth.handle_signin_response(signin_payload, ret)
        ret = server._auth(pload["load"], sign_messages=True)
        assert "sig" in ret
        with pytest.raises(salt.crypt.SaltClientError, match="Invalid signature"):
            ret = client.auth.handle_signin_response(signin_payload, ret)
    finally:
        client.close()
        server.close()


async def test_req_chan_auth_v2_new_minion_without_master_pub(pki_dir, io_loop):

@@ -1273,10 +1335,14 @@ async def test_req_chan_auth_v2_new_minion_without_master_pub(pki_dir, io_loop):
    client = salt.channel.client.AsyncReqChannel.factory(opts, io_loop=io_loop)
    signin_payload = client.auth.minion_sign_in_payload()
    pload = client._package_load(signin_payload)
    assert "version" in pload
    assert pload["version"] == 2
    try:
        assert "version" in pload
        assert pload["version"] == 2

    ret = server._auth(pload["load"], sign_messages=True)
    assert "sig" in ret
    ret = client.auth.handle_signin_response(signin_payload, ret)
    assert ret == "retry"
        ret = server._auth(pload["load"], sign_messages=True)
        assert "sig" in ret
        ret = client.auth.handle_signin_response(signin_payload, ret)
        assert ret == "retry"
    finally:
        client.close()
        server.close()

@@ -174,6 +174,8 @@ async def test_publish_client_connect_server_down(transport, io_loop):
        await client.connect(background=True)
    except TimeoutError:
        pass
    except Exception:
        log.error("Got exception", exc_info=True)
    assert client._stream is None
    client.close()

@@ -188,7 +190,6 @@ async def test_publish_client_connect_server_comes_up(transport, io_loop):

    import zmq

    return
    ctx = zmq.asyncio.Context()
    uri = f"tcp://{opts['master_ip']}:{port}"
    msg = salt.payload.dumps({"meh": 123})

@@ -196,23 +197,21 @@ async def test_publish_client_connect_server_comes_up(transport, io_loop):
    client = salt.transport.zeromq.PublishClient(
        opts, io_loop, host=host, port=port
    )
    await client.connect(background=True)
    await client.connect()
    assert client._socket

    socket = ctx.socket(zmq.PUB)
    socket.setsockopt(zmq.BACKLOG, 1000)
    socket.setsockopt(zmq.LINGER, -1)
    socket.setsockopt(zmq.SNDHWM, 1000)
    print(f"bind {uri}")
    socket.bind(uri)
    await asyncio.sleep(10)
    await asyncio.sleep(20)

    async def recv():
        return await client.recv(timeout=1)

    task = asyncio.create_task(recv())
    # Sleep to allow zmq to do it's thing.
    await asyncio.sleep(0.03)
    await socket.send(msg)
    await task
    response = task.result()

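The sleeps in the test above work around ZeroMQ's "slow joiner" behavior: a PUB socket silently drops messages sent before a subscriber's connection and subscription have propagated, so the test has to give the SUB side time to join before publishing. A minimal illustration (the fixed port and sleep length are arbitrary choices):

    import asyncio

    import zmq
    import zmq.asyncio


    async def main():
        ctx = zmq.asyncio.Context()
        pub = ctx.socket(zmq.PUB)
        pub.bind("tcp://127.0.0.1:5999")
        sub = ctx.socket(zmq.SUB)
        sub.setsockopt(zmq.SUBSCRIBE, b"")
        sub.connect("tcp://127.0.0.1:5999")

        # Slow joiner: without this pause the first send is often dropped.
        await asyncio.sleep(0.2)

        await pub.send(b"hello")
        print(await sub.recv())

        sub.close(0)
        pub.close(0)
        ctx.term()


    asyncio.run(main())
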
@@ -112,7 +112,7 @@ async def test_message_client_cleanup_on_close(

    # The run_sync call will set stop_called, reset it
    # orig_loop.stop_called = False
    await client.close()
    client.close()

    # Stop should be called again, client's io_loop should be None
    # assert orig_loop.stop_called is True

@@ -5,9 +5,8 @@ import time
from pathlib import Path

import pytest
import tornado.ioloop
import tornado.iostream
import zmq.eventloop.ioloop
import zmq

import salt.config
import salt.utils.event

@@ -296,36 +295,36 @@ def test_send_master_event(sock_dir):
    )


#def test_connect_pull_should_debug_log_on_StreamClosedError():
#    event = SaltEvent(node=None)
#    with patch.object(event, "pusher") as mock_pusher:
#        with patch.object(
#            salt.utils.event.log, "debug", autospec=True
#        ) as mock_log_debug:
#            mock_pusher.connect.side_effect = tornado.iostream.StreamClosedError
#            event.connect_pull()
#            call = mock_log_debug.mock_calls[0]
#            assert call.args[0] == "Unable to connect pusher: %s"
#            assert isinstance(call.args[1], tornado.iostream.StreamClosedError)
#            assert call.args[1].args[0] == "Stream is closed"
#
#
#@pytest.mark.parametrize("error", [Exception, KeyError, IOError])
#def test_connect_pull_should_error_log_on_other_errors(error):
#    event = SaltEvent(node=None)
#    with patch.object(event, "pusher") as mock_pusher:
#        with patch.object(
#            salt.utils.event.log, "debug", autospec=True
#        ) as mock_log_debug:
#            with patch.object(
#                salt.utils.event.log, "error", autospec=True
#            ) as mock_log_error:
#                mock_pusher.connect.side_effect = error
#                event.connect_pull()
#                mock_log_debug.assert_not_called()
#                call = mock_log_error.mock_calls[0]
#                assert call.args[0] == "Unable to connect pusher: %s"
#                assert not isinstance(call.args[1], tornado.iostream.StreamClosedError)
def test_connect_pull_should_debug_log_on_StreamClosedError():
    event = SaltEvent(node=None)
    with patch.object(event, "pusher") as mock_pusher:
        with patch.object(
            salt.utils.event.log, "debug", autospec=True
        ) as mock_log_debug:
            mock_pusher.connect.side_effect = tornado.iostream.StreamClosedError
            event.connect_pull()
            call = mock_log_debug.mock_calls[0]
            assert call.args[0] == "Unable to connect pusher: %s"
            assert isinstance(call.args[1], tornado.iostream.StreamClosedError)
            assert call.args[1].args[0] == "Stream is closed"


@pytest.mark.parametrize("error", [Exception, KeyError, IOError])
def test_connect_pull_should_error_log_on_other_errors(error):
    event = SaltEvent(node=None)
    with patch.object(event, "pusher") as mock_pusher:
        with patch.object(
            salt.utils.event.log, "debug", autospec=True
        ) as mock_log_debug:
            with patch.object(
                salt.utils.event.log, "error", autospec=True
            ) as mock_log_error:
                mock_pusher.connect.side_effect = error
                event.connect_pull()
                mock_log_debug.assert_not_called()
                call = mock_log_error.mock_calls[0]
                assert call.args[0] == "Unable to connect pusher: %s"
                assert not isinstance(call.args[1], tornado.iostream.StreamClosedError)


@pytest.mark.slow_test

@@ -3,14 +3,13 @@
    ~~~~~~~~~~~~~~~~~~~~
"""


import multiprocessing
import os
import time
from contextlib import contextmanager

import salt.utils.event
from salt.utils.process import clean_proc, Process
from salt.utils.process import Process, clean_proc


@contextmanager

@@ -35,7 +34,7 @@ def eventpublisher_process(sock_dir):
            ipc_publisher.publish_payload,
        ],
    )
    #proc = salt.utils.event.EventPublisher({"sock_dir": sock_dir})
    # proc = salt.utils.event.EventPublisher({"sock_dir": sock_dir})
    proc.start()
    try:
        if os.environ.get("TRAVIS_PYTHON_VERSION", None) is not None: