Test fixes

Daniel A. Wozniak 2023-06-26 00:51:31 -07:00 committed by Gareth J. Greenaway
parent 077c253954
commit 9683260d61
22 changed files with 2426 additions and 1750 deletions

View file

@ -94,6 +94,9 @@ class DeferredStreamHandler(StreamHandler):
super().__init__(stream)
self.__messages = deque(maxlen=max_queue_size)
self.__emitting = False
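# Record the call stack at construction time to help trace where this handler was created.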
import traceback
self.stack = "".join(traceback.format_stack())
def handle(self, record):
self.acquire()
@ -115,6 +118,7 @@ class DeferredStreamHandler(StreamHandler):
super().handle(record)
finally:
self.__emitting = False
# This will raise a ValueError if the file handle has been closed.
super().flush()
def sync_with_handlers(self, handlers=()):

View file

@ -474,7 +474,15 @@ def setup_temp_handler(log_level=None):
break
else:
handler = DeferredStreamHandler(sys.stderr)
atexit.register(handler.flush)
def tryflush():
try:
handler.flush()
except ValueError:
# File handle has already been closed.
pass
atexit.register(tryflush)
handler.setLevel(log_level)
# Set the default temporary console formatter config

View file

@ -103,6 +103,7 @@ class AsyncReqChannel:
"_uncrypted_transfer",
"send",
"connect",
# "close",
]
close_methods = [
"close",
@ -314,9 +315,8 @@ class AsyncReqChannel:
raise tornado.gen.Return(ret)
@tornado.gen.coroutine
def connect(self):
yield self.transport.connect()
async def connect(self):
await self.transport.connect()
@tornado.gen.coroutine
def send(self, load, tries=None, timeout=None, raw=False):
@ -367,6 +367,14 @@ class AsyncReqChannel:
def __exit__(self, *args):
self.close()
async def __aenter__(self):
await self.connect()
return self
async def __aexit__(self, exc_type, exc, tb):
# print("AEXIT")
self.close()
class AsyncPubChannel:
"""
@ -376,7 +384,7 @@ class AsyncPubChannel:
async_methods = [
"connect",
"_decode_messages",
# "close",
# "close",
]
close_methods = [
"close",
@ -406,7 +414,9 @@ class AsyncPubChannel:
io_loop = tornado.ioloop.IOLoop.current()
auth = salt.crypt.AsyncAuth(opts, io_loop=io_loop)
transport = salt.transport.publish_client(opts, io_loop)
host = opts.get("master_ip", "127.0.0.1")
port = int(opts.get("publish_port", 4506))
transport = salt.transport.publish_client(opts, io_loop, host=host, port=port)
return cls(opts, transport, auth, io_loop)
def __init__(self, opts, transport, auth, io_loop=None):
@ -432,6 +442,7 @@ class AsyncPubChannel:
try:
if not self.auth.authenticated:
yield self.auth.authenticate()
# log.error("*** Creds %r", self.auth.creds)
# if this is changed from the default, we assume it was intentional
if int(self.opts.get("publish_port", 4506)) != 4506:
publish_port = self.opts.get("publish_port")
@ -447,6 +458,8 @@ class AsyncPubChannel:
except KeyboardInterrupt: # pylint: disable=try-except-raise
raise
except Exception as exc: # pylint: disable=broad-except
# TODO: Basing re-try logic off exception messages is brittle and
# prone to errors; use exception types or some other method.
if "-|RETRY|-" not in str(exc):
raise salt.exceptions.SaltClientError(
f"Unable to sign_in to master: {exc}"
@ -456,11 +469,8 @@ class AsyncPubChannel:
"""
Close the channel
"""
log.error("AsyncPubChannel.close called")
self.transport.close()
log.error("Transport closed")
if self.event is not None:
log.error("Event destroy called")
self.event.destroy()
self.event = None
@ -616,7 +626,13 @@ class AsyncPubChannel:
return self
def __exit__(self, *args):
self.close()
self.io_loop.spawn_callback(self.close)
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
await self.close()
class AsyncPushChannel:

View file

@ -699,9 +699,6 @@ class PubServerChannel:
Factory class to create subscription channels to the master's Publisher
"""
def __repr__(self):
return f"<PubServerChannel pub_uri={self.transport.pub_uri} pull_uri={self.transport.pull_uri} at {id(self)}>"
@classmethod
def factory(cls, opts, **kwargs):
if "master_uri" not in opts and "master_uri" in kwargs:
@ -720,6 +717,9 @@ class PubServerChannel:
transport = salt.transport.publish_server(opts, **kwargs)
return cls(opts, transport, presence_events=presence_events)
def __repr__(self):
return f"<PubServerChannel pub_uri={self.transport.pub_uri} pull_uri={self.transport.pull_uri} at {id(self)}>"
def __init__(self, opts, transport, presence_events=False):
self.opts = opts
self.ckminions = salt.utils.minions.CkMinions(self.opts)
@ -774,6 +774,7 @@ class PubServerChannel:
secrets = kwargs.get("secrets", None)
if secrets is not None:
salt.master.SMaster.secrets = secrets
log.error("RUN TRANSPORT PUBD")
self.transport.publish_daemon(
self.publish_payload, self.presence_callback, self.remove_presence_callback
)

View file

@ -222,7 +222,7 @@ class SaltCMD(salt.utils.parsers.SaltCMDOptionParser):
AuthorizationError,
SaltInvocationError,
EauthAuthenticationError,
SaltClientError,
# SaltClientError,
) as exc:
print(repr(exc))
ret = str(exc)

View file

@ -1527,7 +1527,7 @@ class Crypticle:
ret_nonce = data[:32].decode()
data = data[32:]
if ret_nonce != nonce:
raise SaltClientError("Nonce verification error")
raise SaltClientError(f"Nonce verification error {ret_nonce} {nonce}")
payload = salt.payload.loads(data, raw=raw)
if isinstance(payload, dict):
if "serial" in payload:

View file

@ -1048,7 +1048,30 @@ class MinionManager(MinionBase):
# self.opts,
# io_loop=self.io_loop,
# )
# import hashlib
# ipc_publisher = salt.transport.publish_server(self.opts)
# hash_type = getattr(hashlib, self.opts["hash_type"])
# id_hash = hash_type(
# salt.utils.stringutils.to_bytes(self.opts["id"])
# ).hexdigest()[:10]
# epub_sock_path = "ipc://{}".format(
# os.path.join(
# self.opts["sock_dir"], "minion_event_{}_pub.ipc".format(id_hash)
# )
# )
# if os.path.exists(epub_sock_path):
# os.unlink(epub_sock_path)
# epull_sock_path = "ipc://{}".format(
# os.path.join(
# self.opts["sock_dir"], "minion_event_{}_pull.ipc".format(id_hash)
# )
# )
# ipc_publisher.pub_uri = epub_sock_path
# ipc_publisher.pull_uri = epull_sock_path
# self.io_loop.add_callback(ipc_publisher.publisher, ipc_publisher.publish_payload, self.io_loop)
import hashlib
ipc_publisher = salt.transport.publish_server(self.opts)
hash_type = getattr(hashlib, self.opts["hash_type"])
id_hash = hash_type(
@ -1068,7 +1091,18 @@ class MinionManager(MinionBase):
)
ipc_publisher.pub_uri = epub_sock_path
ipc_publisher.pull_uri = epull_sock_path
self.io_loop.add_callback(ipc_publisher.publisher, ipc_publisher.publish_payload, self.io_loop)
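# With the TCP transport the local IPC publisher runs in its own daemonized process; other transports schedule it on this io_loop.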
if self.opts["transport"] == "tcp":
def target():
ipc_publisher.publish_daemon(ipc_publisher.publish_payload)
proc = salt.utils.process.Process(target=target, daemon=True)
proc.start()
else:
self.io_loop.add_callback(
ipc_publisher.publisher, ipc_publisher.publish_payload, self.io_loop
)
log.error("get event ")
self.event = salt.utils.event.get_event(
"minion", opts=self.opts, io_loop=self.io_loop
)
@ -1139,11 +1173,8 @@ class MinionManager(MinionBase):
self.io_loop.spawn_callback(self._connect_minion, minion)
self.io_loop.call_later(timeout, self._check_minions)
@tornado.gen.coroutine
def _connect_minion(self, minion):
"""
Create a minion, and asynchronously connect it to a master
"""
async def _connect_minion(self, minion):
"""Create a minion, and asynchronously connect it to a master"""
last = 0 # never have we signed in
auth_wait = minion.opts["acceptance_wait_time"]
failed = False
@ -1154,7 +1185,8 @@ class MinionManager(MinionBase):
if minion.opts.get("scheduler_before_connect", False):
minion.setup_scheduler(before_connect=True)
if minion.opts.get("master_type", "str") != "disable":
yield minion.connect_master(failed=failed)
await minion.connect_master(failed=failed)
log.error("RUN MINION TUNE IN")
minion.tune_in(start=False)
self.minions.append(minion)
break
@ -1164,11 +1196,12 @@ class MinionManager(MinionBase):
"Error while bringing up minion for multi-master. Is "
"master at %s responding?",
minion.opts["master"],
exc_info=True,
)
last = time.time()
if auth_wait < self.max_auth_wait:
auth_wait += self.auth_wait
yield tornado.gen.sleep(auth_wait) # TODO: log?
await tornado.gen.sleep(auth_wait) # TODO: log?
except SaltMasterUnresolvableError:
err = (
"Master address: '{}' could not be resolved. Invalid or"
@ -3269,16 +3302,16 @@ class Minion(MinionBase):
self.pub_channel.on_recv(None)
log.error("create pub_channel.close task %r", self)
self.pub_channel.close()
#self.io_loop.asyncio_loop.run_until_complete(self.pub_channel.close())
#if hasattr(self.pub_channel, "close"):
# self.io_loop.asyncio_loop.run_until_complete(self.pub_channel.close())
# if hasattr(self.pub_channel, "close"):
# asyncio.create_task(
# self.pub_channel.close()
# )
# #self.pub_channel.close()
#del self.pub_channel
# del self.pub_channel
if hasattr(self, "event"):
log.error("HAS EVENT")
#if hasattr(self, "periodic_callbacks"):
# if hasattr(self, "periodic_callbacks"):
# for cb in self.periodic_callbacks.values():
# cb.stop()
log.error("%r destroy method finished", self)

View file

@ -74,7 +74,7 @@ def publish_server(opts, **kwargs):
raise Exception("Transport type not found: {}".format(ttype))
def publish_client(opts, io_loop):
def ipc_publish_client(opts, io_loop):
# Default to ZeroMQ for now
ttype = "zeromq"
# determine the ttype
@ -85,6 +85,7 @@ def publish_client(opts, io_loop):
# switch on available ttypes
if ttype == "zeromq":
import salt.transport.zeromq
return salt.transport.zeromq.PublishClient(opts, io_loop)
elif ttype == "tcp":
import salt.transport.tcp
@ -93,6 +94,30 @@ def publish_client(opts, io_loop):
raise Exception("Transport type not found: {}".format(ttype))
def publish_client(opts, io_loop, host=None, port=None, path=None):
# Default to ZeroMQ for now
ttype = "zeromq"
# determine the ttype
if "transport" in opts:
ttype = opts["transport"]
elif "transport" in opts.get("pillar", {}).get("master", {}):
ttype = opts["pillar"]["master"]["transport"]
# switch on available ttypes
if ttype == "zeromq":
import salt.transport.zeromq
return salt.transport.zeromq.PublishClient(
opts, io_loop, host=host, port=port, path=path
)
elif ttype == "tcp":
import salt.transport.tcp
return salt.transport.tcp.TCPPubClient(
opts, io_loop, host=host, port=port, path=path
)
raise Exception("Transport type not found: {}".format(ttype))
class RequestClient:
"""
The RequestClient transport is used to make requests and get corresponding

View file

@ -134,11 +134,10 @@ class IPCServer:
else:
self.sock = tornado.netutil.bind_unix_socket(self.socket_path)
with salt.utils.asynchronous.current_ioloop(self.io_loop):
tornado.netutil.add_accept_handler(
self.sock,
self.handle_connection,
)
tornado.netutil.add_accept_handler(
self.sock,
self.handle_connection,
)
self._started = True
@tornado.gen.coroutine
@ -208,7 +207,7 @@ class IPCServer:
log.error("Exception occurred while handling stream: %s", exc)
def handle_connection(self, connection, address):
log.trace(
log.error(
"IPCServer: Handling connection to address: %s",
address if address else connection,
)
@ -338,8 +337,8 @@ class IPCClient:
break
if self.stream is None:
with salt.utils.asynchronous.current_ioloop(self.io_loop):
self.stream = IOStream(socket.socket(sock_type, socket.SOCK_STREAM))
# with salt.utils.asynchronous.current_ioloop(self.io_loop):
self.stream = IOStream(socket.socket(sock_type, socket.SOCK_STREAM))
try:
log.trace("IPCClient: Connecting to socket: %s", self.socket_path)
yield self.stream.connect(sock_addr)
@ -440,8 +439,8 @@ class IPCMessageClient(IPCClient):
# FIXME timeout unimplemented
# FIXME tries unimplemented
@tornado.gen.coroutine
def send(self, msg, timeout=None, tries=None):
# @tornado.gen.coroutine
async def send(self, msg, timeout=None, tries=None):
"""
Send a message to an IPC socket
@ -451,9 +450,9 @@ class IPCMessageClient(IPCClient):
:param int timeout: Timeout when sending message (Currently unimplemented)
"""
if not self.connected():
yield self.connect()
await self.connect()
pack = salt.transport.frame.frame_msg_ipc(msg, raw_body=True)
yield self.stream.write(pack)
await self.stream.write(pack)
class IPCMessageServer(IPCServer):

View file

@ -6,12 +6,13 @@ Wire protocol: "len(payload) msgpack({'head': SOMEHEADER, 'body': SOMEBODY})"
"""
import asyncio
import errno
import logging
import multiprocessing
import os
import queue
import select
import socket
import threading
import urllib
@ -23,6 +24,7 @@ import tornado.iostream
import tornado.netutil
import tornado.tcpclient
import tornado.tcpserver
from tornado.locks import Lock
import salt.master
import salt.payload
@ -215,13 +217,45 @@ class TCPPubClient(salt.transport.base.PublishClient):
ttype = "tcp"
async_methods = [
"connect",
"connect_uri",
"recv",
# "close",
]
close_methods = [
"close",
]
def __init__(self, opts, io_loop, **kwargs): # pylint: disable=W0231
self.opts = opts
self.io_loop = io_loop
self.message_client = None
self.unpacker = salt.utils.msgpack.Unpacker()
self.connected = False
self._closing = False
self.resolver = Resolver()
self._stream = None
self._closing = False
self._closed = False
self.backoff = opts.get("tcp_reconnect_backoff", 1)
self.resolver = kwargs.get("resolver")
self._read_in_progress = Lock()
self.poller = None
self.host = kwargs.get("host", None)
self.port = kwargs.get("port", None)
self.path = kwargs.get("path", None)
self.source_ip = self.opts.get("source_ip")
self.source_port = self.opts.get("source_publish_port")
self.connect_callback = None
self.disconnect_callback = None
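# Either a host/port pair or a filesystem socket path must be given, but not both.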
if self.host is None and self.port is None:
if self.path is None:
raise Exception("A host and port or a path must be provided")
elif self.host and self.port:
if self.path:
raise Exception("A host and port or a path must be provided, not both")
def close(self):
if self._closing:
@ -237,21 +271,80 @@ class TCPPubClient(salt.transport.base.PublishClient):
# pylint: enable=W1701
@tornado.gen.coroutine
def connect(self, publish_port, connect_callback=None, disconnect_callback=None):
self.publish_port = publish_port
self.message_client = MessageClient(
self.opts,
self.opts["master_ip"],
int(self.publish_port),
io_loop=self.io_loop,
connect_callback=connect_callback,
disconnect_callback=disconnect_callback,
source_ip=self.opts.get("source_ip"),
source_port=self.opts.get("source_publish_port"),
)
yield self.message_client.connect() # wait for the client to be connected
self.connected = True
async def getstream(self, **kwargs):
if self.source_ip or self.source_port:
kwargs = {
"source_ip": self.source_ip,
"source_port": self.source_port,
}
stream = None
while stream is None and (not self._closed and not self._closing):
try:
if self.host and self.port:
# log.error("GET STREAM TCP %r %s %s %s", self.url, self.host, self.port, self.stack)
self._tcp_client = TCPClientKeepAlive(
self.opts, resolver=self.resolver
)
stream = await self._tcp_client.connect(
ip_bracket(self.host, strip=True),
self.port,
ssl_options=self.opts.get("ssl"),
**kwargs,
)
else:
sock_type = socket.AF_UNIX
stream = tornado.iostream.IOStream(
socket.socket(sock_type, socket.SOCK_STREAM)
)
await stream.connect(self.path)
self.poller = select.poll()
self.poller.register(stream.socket, select.POLLIN)
except Exception as exc: # pylint: disable=broad-except
log.warning(
"TCP Message Client encountered an exception while connecting to"
" %s:%s %s: %r, will reconnect in %d seconds",
self.host,
self.port,
self.path,
exc,
self.backoff,
)
await tornado.gen.sleep(self.backoff)
return stream
async def _connect(self):
# log.error("Connect %r %r", self, self._stream)
# import traceback
# stack = "".join(traceback.format_stack())
# log.error(f"MessageClient Connect {name} {stack}")
if self._stream is None:
log.error("Get stream")
self._stream = await self.getstream()
log.error("Got stream")
if self._stream:
# if not self._stream_return_running:
# self.io_loop.spawn_callback(self._stream_return)
if self.connect_callback:
self.connect_callback(True)
self.connected = True
async def connect(
self,
port=None,
connect_callback=None,
disconnect_callback=None,
background=False,
):
if port is not None:
self.port = port
if connect_callback:
self.connect_callback = connect_callback
if disconnect_callback:
self.disconnect_callback = disconnect_callback
if background:
self.io_loop.spawn_callback(self._connect)
else:
await self._connect()
def _decode_messages(self, messages):
if not isinstance(messages, dict):
@ -263,15 +356,83 @@ class TCPPubClient(salt.transport.base.PublishClient):
body = messages
return body
@tornado.gen.coroutine
def send(self, msg):
yield self.message_client._stream.write(msg)
async def send(self, msg):
await self.message_client.send(msg, reply=False)
async def recv(self, timeout=None):
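# timeout == 0 polls once without blocking, a positive timeout waits up to that many seconds, and None blocks until a full message is framed.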
try:
await self._read_in_progress.acquire(timeout=0.00000001)
except tornado.gen.TimeoutError:
log.error("Timeout Error")
return
try:
if timeout == 0:
if not self._stream:
await asyncio.sleep(0.001)
return
for msg in self.unpacker:
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
return framed_msg["body"]
poller = select.poll()
poller.register(self._stream.socket, select.POLLIN)
try:
events = poller.poll(0)
except TimeoutError:
events = []
if events:
while True:
byts = await self._stream.read_bytes(4096, partial=True)
self.unpacker.feed(byts)
for msg in self.unpacker:
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
return framed_msg["body"]
elif timeout:
return await asyncio.wait_for(self.recv(), timeout=timeout)
else:
for msg in self.unpacker:
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
return framed_msg["body"]
while True:
byts = await self._stream.read_bytes(4096, partial=True)
self.unpacker.feed(byts)
for msg in self.unpacker:
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
return framed_msg["body"]
finally:
self._read_in_progress.release()
def on_recv(self, callback):
"""
Register an on_recv callback
Register a callback for received messages (that we didn't initiate)
"""
return self.message_client.on_recv(callback)
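# setup_callback waits for the stream, then reads messages in a loop, reconnecting on StreamClosedError before handing each message to the callback.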
async def setup_callback():
logit = True
while not self._stream:
# log.error("On recv wait stream %r", self._stream)
await asyncio.sleep(0.003)
while True:
try:
msg = await self.recv()
logit = True
except tornado.iostream.StreamClosedError:
self._stream.close()
self._stream = None
if logit:
log.error("Stream Closed", exc_info=True)
logit = False
await self._connect()
await asyncio.sleep(0.03)
# if self.disconnect_callback:
# self.disconnect_callback()
self.unpacker = salt.utils.msgpack.Unpacker()
continue
except Exception:
log.error("Other exception", exc_info=True)
log.error("on recv got msg %r", msg)
callback(msg)
self.io_loop.spawn_callback(setup_callback)
def __enter__(self):
return self
@ -389,6 +550,7 @@ class TCPReqServer(salt.transport.base.DaemonizedRequestServer):
def handle_message(self, stream, payload, header=None):
payload = self.decode_payload(payload)
reply = yield self.message_handler(payload)
# XXX Handle StreamClosedError
stream.write(salt.transport.frame.frame_msg(reply, header=header))
def decode_payload(self, payload):
@ -545,6 +707,7 @@ class MessageClient:
opts,
host,
port,
url=None,
io_loop=None,
resolver=None,
connect_callback=None,
@ -552,9 +715,13 @@ class MessageClient:
source_ip=None,
source_port=None,
):
# import traceback
# stack = "".join(traceback.format_stack())
# print(stack)
self.opts = opts
self.host = host
self.port = port
self.url = url
self.source_ip = source_ip
self.source_port = source_port
self.connect_callback = connect_callback
@ -577,6 +744,14 @@ class MessageClient:
self._stream = None
self.backoff = opts.get("tcp_reconnect_backoff", 1)
self.callbacks = {}
import traceback
self.stack = "".join(traceback.format_stack())
self.unpacker = salt.utils.msgpack.Unpacker()
self._read_in_progress = Lock()
self.task = None
self.tcp_client = None
def _stop_io_loop(self):
if self.io_loop is not None:
@ -587,8 +762,21 @@ class MessageClient:
if self._closing:
return
self._closing = True
self.stream.close()
self.tcp_client.close()
self.io_loop.add_timeout(1, self.check_close)
def close(self):
if self._closing:
return
self._closing = True
if self._stream is not None:
self._stream.close()
if self.tcp_client is not None:
self.tcp_client.close()
if self.task is not None:
self.task.cancel()
@tornado.gen.coroutine
def check_close(self):
if not self.send_future_map:
@ -601,12 +789,11 @@ class MessageClient:
# pylint: disable=W1701
def __del__(self):
self.close()
self.io_loop.spawn_callback(self.close)
# pylint: enable=W1701
@tornado.gen.coroutine
def getstream(self, **kwargs):
async def getstream(self, **kwargs):
if self.source_ip or self.source_port:
kwargs = {
"source_ip": self.source_ip,
@ -615,12 +802,22 @@ class MessageClient:
stream = None
while stream is None and (not self._closed and not self._closing):
try:
stream = yield self._tcp_client.connect(
ip_bracket(self.host, strip=True),
self.port,
ssl_options=self.opts.get("ssl"),
**kwargs
)
if self.host and self.port:
# log.error("GET STREAM TCP %r %s %s %s", self.url, self.host, self.port, self.stack)
stream = await self._tcp_client.connect(
ip_bracket(self.host, strip=True),
self.port,
ssl_options=self.opts.get("ssl"),
**kwargs,
)
else:
log.error("GET STREAM IPC")
sock_type = socket.AF_UNIX
path = self.url.replace("ipc://", "")
stream = tornado.iostream.IOStream(
socket.socket(sock_type, socket.SOCK_STREAM)
)
await stream.connect(path)
except Exception as exc: # pylint: disable=broad-except
log.warning(
"TCP Message Client encountered an exception while connecting to"
@ -630,26 +827,34 @@ class MessageClient:
exc,
self.backoff,
)
yield tornado.gen.sleep(self.backoff)
raise tornado.gen.Return(stream)
await tornado.gen.sleep(self.backoff)
return stream
@tornado.gen.coroutine
def connect(self):
async def connect(self):
log.error("Connect %r %r", self, self._stream)
import traceback
# stack = "".join(traceback.format_stack())
# log.error(f"MessageClient Connect {name} {stack}")
if self._stream is None:
self._stream = yield self.getstream()
log.error("Get stream")
self._stream = await self.getstream()
log.error("Got stream")
if self._stream:
if not self._stream_return_running:
self.io_loop.spawn_callback(self._stream_return)
self.task = asyncio.create_task(self._stream_return())
# self.io_loop.spawn_callback(self._stream_return)
if self.connect_callback:
self.connect_callback(True)
@tornado.gen.coroutine
def _stream_return(self):
async def _stream_return(self):
self._stream_return_running = True
unpacker = salt.utils.msgpack.Unpacker()
while not self._closing:
try:
wire_bytes = yield self._stream.read_bytes(4096, partial=True)
log.error("Stream read bytes %r", self._stream.socket)
wire_bytes = await self._stream.read_bytes(4096, partial=True)
log.error("Stream got bytes")
unpacker.feed(wire_bytes)
for framed_msg in unpacker:
framed_msg = salt.transport.frame.decode_embedded_strs(framed_msg)
@ -663,6 +868,7 @@ class MessageClient:
else:
if self._on_recv is not None:
self.io_loop.spawn_callback(self._on_recv, header, body)
# await self._on_recv(header, body)
else:
log.error(
"Got response for message_id %s that we are not"
@ -670,7 +876,7 @@ class MessageClient:
message_id,
)
except tornado.iostream.StreamClosedError as e:
log.debug(
log.error(
"tcp stream to %s:%s closed, unable to recv",
self.host,
self.port,
@ -687,7 +893,7 @@ class MessageClient:
if stream:
stream.close()
unpacker = salt.utils.msgpack.Unpacker()
yield self.connect()
await self.connect()
except TypeError:
# This is an invalid transport
if "detect_mode" in self.opts:
@ -697,6 +903,9 @@ class MessageClient:
)
else:
raise SaltClientError
# except OSError:
# log.error("OSERROR", exc_info=True)
# raise
except Exception as e: # pylint: disable=broad-except
log.error("Exception parsing response", exc_info=True)
for future in self.send_future_map.values():
@ -711,7 +920,7 @@ class MessageClient:
if stream:
stream.close()
unpacker = salt.utils.msgpack.Unpacker()
yield self.connect()
await self.connect()
self._stream_return_running = False
def _message_id(self):
@ -733,14 +942,7 @@ class MessageClient:
"""
Register a callback for received messages (that we didn't initiate)
"""
if callback is None:
self._on_recv = callback
else:
def wrap_recv(header, body):
callback(body)
self._on_recv = wrap_recv
self._on_recv = callback
def remove_message_timeout(self, message_id):
if message_id not in self.send_timeout_map:
@ -755,13 +957,19 @@ class MessageClient:
if future is not None:
future.set_exception(SaltReqTimeoutError("Message timed out"))
@tornado.gen.coroutine
def send(self, msg, timeout=None, callback=None, raw=False):
async def send(self, msg, timeout=None, callback=None, raw=False, reply=True):
# log.error("stream send %r %r %r %r", self, self.url, self.port, reply)
if self._closing:
raise ClosingError()
while not self._stream:
# log.error("Wait stream %r %r", self, self._stream)
await asyncio.sleep(0.03)
message_id = self._message_id()
header = {"mid": message_id}
# item = salt.transport.frame.frame_msg(msg, header=header)
# await self._stream.write(item)
# if reply:
# return await self.recv(timeout=None)
future = tornado.concurrent.Future()
if callback is not None:
@ -782,18 +990,66 @@ class MessageClient:
item = salt.transport.frame.frame_msg(msg, header=header)
@tornado.gen.coroutine
def _do_send():
yield self.connect()
async def _do_send():
await self.connect()
# If the _stream is None, we failed to connect.
if self._stream:
yield self._stream.write(item)
await self._stream.write(item)
# Run send in a callback so we can wait on the future, in case we time
# out before we are able to connect.
self.io_loop.add_callback(_do_send)
recv = yield future
raise tornado.gen.Return(recv)
recv = await future
return recv
async def recv(self, timeout=None):
try:
await self._read_in_progress.acquire(timeout=0.00000001)
except tornado.gen.TimeoutError:
log.error("Timeout Error")
return
try:
if timeout == 0:
for msg in self.unpacker:
log.error("RECV a")
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
return framed_msg["body"]
poller = select.poll()
poller.register(self._stream.socket, select.POLLIN)
try:
events = poller.poll(0)
except TimeoutError:
events = []
if events:
while True:
byts = await self._stream.read_bytes(4096, partial=True)
self.unpacker.feed(byts)
for msg in self.unpacker:
log.error("RECV b")
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
return framed_msg["body"]
else:
return
elif timeout:
log.error("RECV c")
return await asyncio.wait_for(self.recv(), timeout=timeout)
else:
for msg in self.unpacker:
log.error("RECV d")
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
return framed_msg["body"]
while True:
log.error("RECV e")
byts = await self._stream.read_bytes(4096, partial=True)
self.unpacker.feed(byts)
for msg in self.unpacker:
log.error("RECV e %r", msg)
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
log.error("RECV e %r", framed_msg)
return framed_msg["body"]
finally:
self._read_in_progress.release()
# await asyncio.sleep(.003)
class Subscriber:
@ -898,9 +1154,10 @@ class PubServer(tornado.tcpserver.TCPServer):
self.io_loop.spawn_callback(self._stream_read, client)
# TODO: ACK the publish through IPC
@tornado.gen.coroutine
def publish_payload(self, package, topic_list=None):
log.trace("TCP PubServer sending payload: %s \n\n %r", package, topic_list)
async def publish_payload(self, package, topic_list=None):
log.trace(
"TCP PubServer sending payload: topic_list=%r %r", topic_list, package
)
payload = salt.transport.frame.frame_msg(package)
to_remove = []
if topic_list:
@ -910,7 +1167,7 @@ class PubServer(tornado.tcpserver.TCPServer):
if topic == client.id_:
try:
# Write the packed str
yield client.stream.write(payload)
await client.stream.write(payload)
sent = True
# self.io_loop.add_future(f, lambda f: True)
except tornado.iostream.StreamClosedError:
@ -921,7 +1178,7 @@ class PubServer(tornado.tcpserver.TCPServer):
for client in self.clients:
try:
# Write the packed str
yield client.stream.write(payload)
await client.stream.write(payload)
except tornado.iostream.StreamClosedError:
to_remove.append(client)
for client in to_remove:
@ -942,10 +1199,32 @@ class TCPPublishServer(salt.transport.base.DaemonizedPublishServer):
# TODO: opts!
# Based on default used in tornado.netutil.bind_sockets()
backlog = 128
async_methods = [
"publish",
# "close",
]
close_methods = [
"close",
]
def __init__(self, opts):
self.opts = opts
self.pub_sock = None
# Set up Salt IPC server
if self.opts.get("ipc_mode", "") == "tcp":
self.pull_uri = int(self.opts.get("tcp_master_publish_pull", 4514))
else:
self.pull_uri = os.path.join(self.opts["sock_dir"], "publish_pull.ipc")
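# pull_uri is the local channel publish() pushes payloads through; pub_uri below is the TCP endpoint minions subscribe to.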
interface = self.opts.get("interface", "127.0.0.1")
self.publish_port = self.opts.get("publish_port", 4560)
self.pub_uri = f"tcp://{interface}:{self.publish_port}"
log.error(
"TCPPubServer %r %s %s %s",
self,
self.pull_uri,
self.publish_port,
self.pub_uri,
)
@property
def topic_support(self):
@ -967,6 +1246,13 @@ class TCPPublishServer(salt.transport.base.DaemonizedPublishServer):
Bind to the interface specified in the configuration file
"""
io_loop = tornado.ioloop.IOLoop()
log.error(
"TCPPubServer daemon %r %s %s %s",
self,
self.pull_uri,
self.publish_port,
self.pub_uri,
)
# Spin up the publisher
self.pub_server = pub_server = PubServer(
@ -975,21 +1261,34 @@ class TCPPublishServer(salt.transport.base.DaemonizedPublishServer):
presence_callback=presence_callback,
remove_presence_callback=remove_presence_callback,
)
sock = _get_socket(self.opts)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
_set_tcp_keepalive(sock, self.opts)
sock.setblocking(0)
sock.bind(_get_bind_addr(self.opts, "publish_port"))
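# Bind the publish socket: a unix socket when pub_uri is ipc://, otherwise a TCP socket on the configured publish address.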
if self.pub_uri.startswith("ipc://"):
pub_path = self.pub_uri.replace("ipc://", "")
sock = tornado.netutil.bind_unix_socket(pub_path)
else:
sock = _get_socket(self.opts)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
_set_tcp_keepalive(sock, self.opts)
sock.setblocking(0)
sock.bind(_get_bind_addr(self.opts, "publish_port"))
sock.listen(self.backlog)
# pub_server will take ownership of the socket
pub_server.add_socket(sock)
# Set up Salt IPC server
if self.opts.get("ipc_mode", "") == "tcp":
pull_uri = int(self.opts.get("tcp_master_publish_pull", 4514))
else:
pull_uri = os.path.join(self.opts["sock_dir"], "publish_pull.ipc")
# if self.opts.get("ipc_mode", "") == "tcp":
# pull_uri = int(self.opts.get("tcp_master_publish_pull", 4514))
# else:
# pull_uri = os.path.join(self.opts["sock_dir"], "publish_pull.ipc")
self.pub_server = pub_server
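# Normalize pull_uri for the pull server: strip the ipc:// prefix, fall back to the publish port for tcp://, otherwise use it unchanged.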
if "ipc://" in self.pull_uri:
pull_uri = self.pull_uri.replace("ipc://", "")
log.error("WTF PULL URI %r", pull_uri)
elif "tcp://" in self.pull_uri:
log.error("Fallback to publish port %r", self.pull_uri)
pull_uri = self.publish_port
else:
pull_uri = self.pull_uri
pull_sock = salt.transport.ipc.IPCMessageServer(
pull_uri,
io_loop=io_loop,
@ -997,7 +1296,7 @@ class TCPPublishServer(salt.transport.base.DaemonizedPublishServer):
)
# Securely create socket
log.warning("Starting the Salt Puller on %s", pull_uri)
log.warning("Starting the Salt Puller on %s", self.pull_uri)
with salt.utils.files.set_umask(0o177):
pull_sock.start()
@ -1019,24 +1318,44 @@ class TCPPublishServer(salt.transport.base.DaemonizedPublishServer):
@tornado.gen.coroutine
def publish_payload(self, payload, *args):
ret = yield self.pub_server.publish_payload(payload, *args)
# log.error("Publish paylaod %r %r", payload, args)
ret = yield self.pub_server.publish_payload(payload) # , *args)
raise tornado.gen.Return(ret)
def publish(self, payload, **kwargs):
def connect(self):
path = self.pull_uri.replace("ipc://", "")
log.error("Connect pusher %s", path)
# self.pub_sock = salt.utils.asynchronous.SyncWrapper(
# salt.transport.ipc.IPCMessageClient,
# (path,),
# loop_kwarg="io_loop",
# )
self.pub_sock = salt.utils.asynchronous.SyncWrapper(
salt.transport.ipc.IPCMessageClient,
(path,),
loop_kwarg="io_loop",
)
# self.pub_sock = salt.transport.ipc.IPCMessageClient(path)
self.pub_sock.connect()
async def publish(self, payload, **kwargs):
"""
Publish "load" to minions
"""
if self.opts.get("ipc_mode", "") == "tcp":
pull_uri = int(self.opts.get("tcp_master_publish_pull", 4514))
else:
pull_uri = os.path.join(self.opts["sock_dir"], "publish_pull.ipc")
if not self.pub_sock:
self.pub_sock = salt.utils.asynchronous.SyncWrapper(
salt.transport.ipc.IPCMessageClient,
(pull_uri,),
loop_kwarg="io_loop",
)
self.pub_sock.connect()
self.connect()
# if self.opts.get("ipc_mode", "") == "tcp":
# pull_uri = int(self.opts.get("tcp_master_publish_pull", 4514))
# else:
# pull_uri = os.path.join(self.opts["sock_dir"], "publish_pull.ipc")
# if not self.pub_sock:
# self.pub_sock = salt.utils.asynchronous.SyncWrapper(
# salt.transport.ipc.IPCMessageClient,
# (pull_uri,),
# loop_kwarg="io_loop",
# )
# self.pub_sock.connect()
# await self.pub_sock.send(payload)
self.pub_sock.send(payload)
def close(self):
@ -1045,6 +1364,11 @@ class TCPPublishServer(salt.transport.base.DaemonizedPublishServer):
self.pub_sock = None
import tracemalloc
tracemalloc.start()
class TCPReqClient(salt.transport.base.RequestClient):
"""
Tornado based TCP RequestClient
@ -1070,14 +1394,17 @@ class TCPReqClient(salt.transport.base.RequestClient):
source_port=opts.get("source_ret_port"),
)
@tornado.gen.coroutine
def connect(self):
yield self.message_client.connect()
async def connect(self):
log.error("TCPReqClient Connect")
await self.message_client.connect()
log.error("TCPReqClient Connected")
@tornado.gen.coroutine
def send(self, load, timeout=60):
ret = yield self.message_client.send(load, timeout=timeout)
raise tornado.gen.Return(ret)
async def send(self, load, timeout=60):
await self.connect()
log.error("TCP Request %r", load)
msg = await self.message_client.send(load, timeout=timeout)
log.error("TCP Reply %r", msg)
return msg
def close(self):
self.message_client.close()

View file

@ -113,7 +113,7 @@ class PublishClient(salt.transport.base.PublishClient):
"connect",
"connect_uri",
"recv",
#"close",
# "close",
]
close_methods = [
"close",
@ -139,7 +139,6 @@ class PublishClient(salt.transport.base.PublishClient):
self.hexid = hashlib.sha1(salt.utils.stringutils.to_bytes(_id)).hexdigest()
self._closing = False
self.context = zmq.asyncio.Context()
log.error("ZMQ Context creat %r", self)
self._socket = self.context.socket(zmq.SUB)
self._socket.setsockopt(zmq.LINGER, -1)
if zmq_filtering:
@ -211,6 +210,18 @@ class PublishClient(salt.transport.base.PublishClient):
self.connect_called = False
self.callbacks = {}
self.host = kwargs.get("host", None)
self.port = kwargs.get("port", None)
self.path = kwargs.get("path", None)
self.source_ip = self.opts.get("source_ip")
self.source_port = self.opts.get("source_publish_port")
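# Like the TCP client, this requires either a host/port pair or a socket path, never both.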
if self.host is None and self.port is None:
if self.path is None:
raise Exception("A host and port or a path must be provided")
elif self.host and self.port:
if self.path:
raise Exception("A host and port or a path must be provided, not both")
async def close(self):
if self._closing is True:
return
@ -223,9 +234,7 @@ class PublishClient(salt.transport.base.PublishClient):
elif hasattr(self, "_socket"):
self._socket.close(0)
if hasattr(self, "context") and self.context.closed is False:
log.error("ZMQ Context term %r", self)
self.context.term()
log.error("ZMQ Context after term %r", self)
if self.callbacks:
for cb in self.callbacks:
running, task = self.callbacks[cb]
@ -243,9 +252,7 @@ class PublishClient(salt.transport.base.PublishClient):
elif hasattr(self, "_socket"):
self._socket.close(0)
if hasattr(self, "context") and self.context.closed is False:
log.error("ZMQ Context term %r", self)
self.context.term()
log.error("ZMQ Context after term %r", self)
# pylint: enable=W1701
def __enter__(self):
@ -255,18 +262,26 @@ class PublishClient(salt.transport.base.PublishClient):
self.close()
# TODO: this is the time to see if we are connected, maybe use the req channel to guess?
async def connect(
self, publish_port, connect_callback=None, disconnect_callback=None
):
async def connect(self, port=None, connect_callback=None, disconnect_callback=None):
self.connect_called = True
self.publish_port = publish_port
self.uri = self.master_pub
log.debug(
"Connecting the Minion to the Master publish port, using the URI: %s",
self.master_pub,
)
self._socket.connect(self.master_pub)
# await connect_callback(True)
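# A socket path connects over IPC; otherwise the master publish URI is built from host, port and any configured source address.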
if self.path:
pub_uri = f"ipc://{self.path}"
log.debug("Connecting the publisher client to: %s", pub_uri)
self._socket.connect(pub_uri)
else:
# host = self.opts["master_ip"],
if port is not None:
self.port = port
master_pub_uri = _get_master_uri(
self.host, self.port, self.source_ip, self.source_port
)
log.debug(
"Connecting the Minion to the Master publish port, using the URI: %s",
master_pub_uri,
)
self._socket.connect(master_pub_uri)
if connect_callback:
await connect_callback(True)
async def connect_uri(self, uri, connect_callback=None, disconnect_callback=None):
self.connect_called = True
@ -275,19 +290,19 @@ class PublishClient(salt.transport.base.PublishClient):
self.uri = uri
self._socket.connect(uri)
if connect_callback:
connect_callback(True)
await connect_callback(True)
@property
def master_pub(self):
"""
Return the master publish port
"""
return _get_master_uri(
self.opts["master_ip"],
self.publish_port,
source_ip=self.opts.get("source_ip"),
source_port=self.opts.get("source_publish_port"),
)
# @property
# def master_pub(self):
# """
# Return the master publish port
# """
# return _get_master_uri(
# self.opts["master_ip"],
# self.publish_port,
# source_ip=self.opts.get("source_ip"),
# source_port=self.opts.get("source_publish_port"),
# )
def _decode_messages(self, messages):
"""
@ -359,27 +374,23 @@ class PublishClient(salt.transport.base.PublishClient):
try:
while running.is_set():
try:
log.error("Waiting for pyaload from %r", self.uri)
msg = await self.recv(timeout=None)
log.error("Got for pyaload from %r", self.uri)
except zmq.error.ZMQError as exc:
log.error("ZMQERROR, %s", exc)
# We've disconnected just die
break
except Exception: # pylint: disable=broad-except
log.error("WTF", exc_info=True)
break
# except Exception: # pylint: disable=broad-except
# log.error("WTF", exc_info=True)
# break
if msg:
try:
log.error("Running callback for pyaload from %r", self.uri)
await callback(msg)
log.error("Finished callback for pyaload from %r", self.uri)
except Exception: # pylint: disable=broad-except
log.error("Exception while running callback", exc_info=True)
#log.debug("Callback done %r", callback)
# log.debug("Callback done %r", callback)
except Exception as exc: # pylint: disable=broad-except
log.error("CONSUME Exception %s %s", self.uri, exc, exc_info=True)
log.error("CONSUME ENDING %s", self.uri)
log.error(
"Exception while consuming%s %s", self.uri, exc, exc_info=True
)
task = self.io_loop.spawn_callback(consume, running)
self.callbacks[callback] = running, task
@ -400,7 +411,6 @@ class RequestServer(salt.transport.base.DaemonizedRequestServer):
"""
self.__setup_signals()
context = zmq.Context(self.opts["worker_threads"])
log.error("ZMQ Context create %r", self)
# Prepare the zeromq sockets
self.uri = "tcp://{interface}:{ret_port}".format(**self.opts)
self.clients = context.socket(zmq.ROUTER)
@ -448,9 +458,7 @@ class RequestServer(salt.transport.base.DaemonizedRequestServer):
raise
except (KeyboardInterrupt, SystemExit):
break
log.error("ZMQ Context term %r", self)
context.term()
log.error("ZMQ Context after term %r", self)
def close(self):
"""
@ -476,9 +484,7 @@ class RequestServer(salt.transport.base.DaemonizedRequestServer):
if hasattr(self, "_socket") and self._socket.closed is False:
self._socket.close()
if hasattr(self, "context") and self.context.closed is False:
log.error("ZMQ Context term %r", self)
self.context.term()
log.error("ZMQ Context after term %r", self)
def pre_fork(self, process_manager):
"""
@ -596,6 +602,7 @@ def _set_tcp_keepalive(zmq_socket, opts):
if "tcp_keepalive_intvl" in opts:
zmq_socket.setsockopt(zmq.TCP_KEEPALIVE_INTVL, opts["tcp_keepalive_intvl"])
ctx = zmq.asyncio.Context()
# TODO: unit tests!
@ -627,13 +634,13 @@ class AsyncReqMessageClient:
self.io_loop = io_loop
self.context = zmq.asyncio.Context()
log.error("ZMQ Context create %r", self)
self.send_queue = []
# mapping of message -> future
self.send_future_map = {}
self._closing = False
self.socket = None
async def connect(self):
# wire up sockets
@ -650,15 +657,14 @@ class AsyncReqMessageClient:
return
else:
self._closing = True
self.socket.close()
if self.socket:
self.socket.close()
if self.context.closed is False:
# This hangs if closing the stream causes an import error
log.error("ZMQ Context term %r", self)
self.context.term()
log.error("ZMQ Context after term %r", self)
def _init_socket(self):
if hasattr(self, "socket"):
if self.socket is not None:
self.socket.close() # pylint: disable=E0203
del self.socket
@ -693,6 +699,8 @@ class AsyncReqMessageClient:
future.set_exception(SaltReqTimeoutError("Message timed out"))
async def _send_recv(self, message):
if not self.socket:
await self.connect()
message = salt.payload.dumps(message)
await self.socket.send(message)
ret = await self.socket.recv()
@ -793,7 +801,7 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer):
# _sock_data = threading.local()
async_methods = [
"publish",
#"close",
# "close",
]
close_methods = [
"close",
@ -814,10 +822,10 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer):
self.pub_uri = f"tcp://{interface}:{publish_port}"
self.ctx = None
self.sock = None
self.deamon_context = None
self.deamon_pub_sock = None
self.deamon_pull_sock = None
self.deamon_monitor = None
self.daemon_context = None
self.daemon_pub_sock = None
self.daemon_pull_sock = None
self.daemon_monitor = None
def __repr__(self):
return f"<PublishServer pub_uri={self.pub_uri} pull_uri={self.pull_uri} at {hex(id(self))}>"
@ -837,10 +845,9 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer):
try:
ioloop.start()
finally:
self.close()
self.daemon_context.term()
def _get_sockets(self, context, ioloop):
log.error("ZMQ Context create %r", self)
pub_sock = context.socket(zmq.PUB)
monitor = ZeroMQSocketMonitor(pub_sock)
monitor.start_io_loop(ioloop)
@ -893,12 +900,14 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer):
ioloop = tornado.ioloop.IOLoop.current()
ioloop.asyncio_loop.set_debug(True)
self.daemon_context = zmq.asyncio.Context()
self.daemon_pull_sock, self.daemon_pub_sock, self.deamon_monitor = self._get_sockets(self.daemon_context, ioloop)
(
self.daemon_pull_sock,
self.daemon_pub_sock,
self.daemon_monitor,
) = self._get_sockets(self.daemon_context, ioloop)
while True:
try:
log.error("Publisher wait package %s", self.pull_uri)
package = await self.daemon_pull_sock.recv()
log.error("Publisher got package %s %r", self.pull_uri, package)
# payload = salt.payload.loads(package)
await publish_payload(package)
except Exception as exc: # pylint: disable=broad-except
@ -907,8 +916,8 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer):
)
async def publish_payload(self, payload, topic_list=None):
log.error(f"Publish payload %s %r", self.pub_uri, payload)
try:
log.trace("Publish payload %r", payload)
# payload = salt.payload.dumps(payload)
if self.opts["zmq_filtering"]:
if topic_list:
@ -967,7 +976,7 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer):
already exists "pub_close" is called before creating and connecting a
new socket.
"""
log.debug("Connecting to pub server: %s", self.pull_uri)
log.error("Connecting to pub server: %s", self.pull_uri)
self.ctx = zmq.asyncio.Context()
self.sock = self.ctx.socket(zmq.PUSH)
self.sock.setsockopt(zmq.LINGER, 300)
@ -982,23 +991,19 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer):
if self.sock is not None:
sock = self.sock
self.sock = None
log.error("Socket close %r", self)
sock.close()
log.error("Socket closed %r", self)
if self.ctx and self.ctx.closed is False:
ctx = self.ctx
self.ctx = None
log.error("Context term %r", self)
ctx.term()
log.error("After context term %r", self)
if self.deamon_pub_sock:
self.deamon_pub_sock.close()
if self.deamon_pull_sock:
self.deamon_pull_sock.close()
if self.daemon_pub_sock:
self.daemon_pub_sock.close()
if self.daemon_pull_sock:
self.daemon_pull_sock.close()
if self.daemon_monitor:
self.daemon_monitor.close()
if self.deamon_context:
self.deamon_context.term()
self.daemon_monitor.stop()
if self.daemon_context:
self.daemon_context.term()
async def publish(self, payload, **kwargs):
"""
@ -1009,7 +1014,6 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer):
"""
if not self.sock:
self.connect()
log.error("%r send %r", self, payload)
await self.sock.send(payload)
@property
@ -1040,7 +1044,7 @@ class RequestClient(salt.transport.base.RequestClient):
await self.message_client.connect()
async def send(self, load, timeout=60):
self.connect()
await self.connect()
return await self.message_client.send(load, timeout=timeout)
def close(self):

View file

@ -59,7 +59,9 @@ class SyncWrapper:
loop_kwarg=None,
):
self.asyncio_loop = asyncio.new_event_loop()
self.io_loop = tornado.ioloop.IOLoop(asyncio_loop=self.asyncio_loop)
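# make_current=False keeps this private IOLoop from being installed as the thread's current loop.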
self.io_loop = tornado.ioloop.IOLoop(
asyncio_loop=self.asyncio_loop, make_current=False
)
if args is None:
args = []
if kwargs is None:
@ -108,12 +110,16 @@ class SyncWrapper:
log.error("No async method %s on object %r", method, self.obj)
except Exception: # pylint: disable=broad-except
log.exception("Exception encountered while running stop method")
# thread = threading.Thread(target=self._run_loop_final, args=(self.asyncio_loop,))
# thread.start()
# thread.join()
io_loop = self.io_loop
io_loop.stop()
try:
io_loop.close(all_fds=True)
except KeyError:
pass
self.asyncio_loop.close()
def __getattr__(self, key):
if key in self._async_methods:
@ -137,6 +143,18 @@ class SyncWrapper:
return wrap
def _run_loop_final(self, asyncio_loop):
asyncio.set_event_loop(asyncio_loop)
io_loop = tornado.ioloop.IOLoop.current()
try:
async def noop():
await asyncio.sleep(0.2)
result = io_loop.run_sync(lambda: noop())
except Exception: # pylint: disable=broad-except
log.error("Error on last loop run")
def _target(self, key, args, kwargs, results, asyncio_loop):
asyncio.set_event_loop(asyncio_loop)
io_loop = tornado.ioloop.IOLoop.current()
@ -149,7 +167,16 @@ class SyncWrapper:
results.append(sys.exc_info())
def __enter__(self):
if hasattr(self.obj, "__aenter__"):
ret = self._wrap("__aenter__")()
if ret == self.obj:
return self
else:
return ret
return self
def __exit__(self, exc_type, exc_val, tb):
self.close()
if hasattr(self.obj, "__aexit__"):
return self._wrap("__aexit__")
else:
self.close()

View file

@ -230,7 +230,7 @@ class SaltEvent:
self.io_loop = io_loop
self._run_io_loop_sync = False
else:
self.io_loop = tornado.ioloop.IOLoop()
# self.io_loop = tornado.ioloop.IOLoop()
self._run_io_loop_sync = True
self.cpub = False
self.cpush = False
@ -353,48 +353,43 @@ class SaltEvent:
if self.cpub:
return True
kwargs = {"io_loop": self.io_loop}
if isinstance(self.puburi, int):
kwargs.update(host="127.0.0.1", port=self.puburi)
else:
kwargs.update(path=self.puburi)
if self._run_io_loop_sync:
with salt.utils.asynchronous.current_ioloop(self.io_loop):
if self.subscriber is None:
# self.subscriber = salt.utils.asynchronous.SyncWrapper(
# salt.transport.ipc.IPCMessageSubscriber,
# args=(self.puburi,),
# kwargs={"io_loop": self.io_loop},
# loop_kwarg="io_loop",
# )
# self.subscriber = salt.transport.publish_client(self.opts)
self.subscriber = salt.utils.asynchronous.SyncWrapper(
salt.transport.publish_client,
args=(self.opts,),
kwargs={"io_loop": self.io_loop},
loop_kwarg="io_loop",
)
try:
# self.subscriber.connect(timeout=timeout)
puburi = "ipc://{}".format(self.puburi)
self.subscriber.connect_uri(puburi)
self.cpub = True
except tornado.iostream.StreamClosedError:
log.error("Encountered StreamClosedException")
except OSError as exc:
if exc.errno != errno.ENOENT:
raise
log.error("Error opening stream, file does not exist")
except Exception as exc: # pylint: disable=broad-except
log.info(
"An exception occurred connecting publisher: %s",
exc,
exc_info_on_loglevel=logging.DEBUG,
)
if self.subscriber is None:
self.subscriber = salt.utils.asynchronous.SyncWrapper(
salt.transport.publish_client,
args=(self.opts,),
kwargs=kwargs,
loop_kwarg="io_loop",
)
try:
# self.subscriber.connect(timeout=timeout)
log.debug("Event connect subscriber %r", self.puburi)
self.subscriber.connect()
self.cpub = True
except tornado.iostream.StreamClosedError:
log.error("Encountered StreamClosedException")
except OSError as exc:
if exc.errno != errno.ENOENT:
raise
log.error("Error opening stream, file does not exist")
except Exception as exc: # pylint: disable=broad-except
log.info(
"An exception occurred connecting publisher: %s",
exc,
exc_info_on_loglevel=logging.DEBUG,
)
else:
if self.subscriber is None:
if "master_ip" not in self.opts:
self.opts["master_ip"] = ""
self.subscriber = salt.transport.publish_client(self.opts, self.io_loop)
puburi = "ipc://{}".format(self.puburi)
# self.io_loop.run_sync(self.subscriber.connect_uri, puburi)
self.io_loop.spawn_callback(self.subscriber.connect_uri, puburi)
log.error("WTF")
self.subscriber = salt.transport.publish_client(self.opts, **kwargs)
log.debug("Event connect subscriber %r", self.puburi)
self.io_loop.spawn_callback(self.subscriber.connect)
# self.subscriber = salt.transport.ipc.IPCMessageSubscriber(
# self.puburi, io_loop=self.io_loop
# )
@ -410,9 +405,9 @@ class SaltEvent:
"""
if not self.cpub:
return
#if isinstance(self.subscriber, salt.utils.asynchronous.SyncWrapper):
# if isinstance(self.subscriber, salt.utils.asynchronous.SyncWrapper):
# self.subscriber.close()
#else:
# else:
# asyncio.create_task(self.subscriber.close())
self.subscriber.close()
self.subscriber = None
@ -433,6 +428,7 @@ class SaltEvent:
salt.transport.publish_server,
args=(self.opts,),
)
log.error("PUSHER %r %r", self, self.pusher.io_loop.asyncio_loop)
self.pusher.obj.pub_uri = "ipc://{}".format(self.puburi)
self.pusher.obj.pull_uri = "ipc://{}".format(self.pulluri)
# self.pusher = salt.utils.asynchronous.SyncWrapper(
@ -716,7 +712,6 @@ class SaltEvent:
if not self.cpub:
if not self.connect_pub():
return None
log.error("GET EVENT NOBLOCK %r", self.subscriber)
raw = self.subscriber.recv(timeout=0)
if raw is None:
return None
@ -733,7 +728,6 @@ class SaltEvent:
if not self.cpub:
if not self.connect_pub():
return None
log.error("GET EVENT BLOCK %r", self.subscriber)
raw = self.subscriber.recv(timeout=None)
if raw is None:
return None
@ -792,7 +786,7 @@ class SaltEvent:
is_msgpacked=True,
use_bin_type=True,
)
log.error(
log.debug(
"Sending event(fire_event_async): tag = %s; data = %s %r",
tag,
data,
@ -850,7 +844,7 @@ class SaltEvent:
is_msgpacked=True,
use_bin_type=True,
)
log.error(
log.debug(
"Sending event(fire_event): tag = %s; data = %s %s",
tag,
data,
@ -865,7 +859,6 @@ class SaltEvent:
)
msg = salt.utils.stringutils.to_bytes(event, "utf-8")
if self._run_io_loop_sync:
log.error("FIRE EVENT A %r %r", msg, self.pusher.obj)
try:
# self.pusher.send(msg)
self.pusher.publish(msg)
@ -877,7 +870,6 @@ class SaltEvent:
)
raise
else:
log.error("FIRE EVENT B %r %r", msg, self.pusher)
asyncio.create_task(self.pusher.publish(msg))
# self.io_loop.spawn_callback(self.pusher.send, msg)
return True
@ -897,8 +889,8 @@ class SaltEvent:
self.close_pub()
if self.pusher is not None:
self.close_pull()
if self._run_io_loop_sync and not self.keep_loop:
self.io_loop.close()
# if self._run_io_loop_sync and not self.keep_loop:
# self.io_loop.close()
def _fire_ret_load_specific_fun(self, load, fun_index=0):
"""
@ -989,12 +981,11 @@ class SaltEvent:
Invoke the event_handler callback each time an event arrives.
"""
assert not self._run_io_loop_sync
if not self.cpub:
self.connect_pub()
# This will handle reconnects
# return self.subscriber.read_async(event_handler)
self.subscriber.on_recv(event_handler)
self.io_loop.spawn_callback(self.subscriber.on_recv, event_handler)
# pylint: disable=W1701
def __del__(self):

View file

@ -30,6 +30,10 @@ from tests.support.runtests import RUNTIME_VARS
log = logging.getLogger(__name__)
import tracemalloc
tracemalloc.start()
@pytest.fixture(scope="session")
def salt_auth_account_1_factory():
@ -607,8 +611,14 @@ def pytest_pyfunc_call(pyfuncitem):
__tracebackhide__ = True
loop.run_sync(
CoroTestFunction(pyfuncitem.obj, testargs), timeout=get_test_timeout(pyfuncitem)
# loop.run_sync(
# CoroTestFunction(pyfuncitem.obj, testargs), timeout=get_test_timeout(pyfuncitem)
# )
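# Drive the test coroutine on the underlying asyncio loop so asyncio.wait_for can enforce the per-test timeout.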
loop.asyncio_loop.run_until_complete(
asyncio.wait_for(
CoroTestFunction(pyfuncitem.obj, testargs)(),
timeout=get_test_timeout(pyfuncitem),
)
)
return True

View file

@ -1,5 +1,6 @@
import logging
import time
import tracemalloc
from concurrent.futures.thread import ThreadPoolExecutor
import pytest
@ -9,6 +10,8 @@ import salt.channel.server
import salt.master
from tests.support.pytest.transport import PubServerChannelProcess
tracemalloc.start()
log = logging.getLogger(__name__)
@ -61,6 +64,8 @@ def test_publish_to_pubserv_ipc(salt_master, salt_minion, transport):
ZMQ's ipc transport not supported on Windows
"""
import asyncio
opts = dict(
salt_master.config.copy(),
ipc_mode="ipc",
@ -91,28 +96,51 @@ def test_issue_36469_tcp(salt_master, salt_minion, transport):
"""
if transport == "tcp":
pytest.skip("Test not applicable to the ZeroMQ transport.")
import asyncio
def _send_small(opts, sid, num=10):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
server_channel = salt.channel.server.PubServerChannel.factory(opts)
for idx in range(num):
load = {"tgt_type": "glob", "tgt": "*", "jid": "{}-s{}".format(sid, idx)}
server_channel.publish(load)
async def send():
await asyncio.sleep(0.3)
for idx in range(num):
load = {
"tgt_type": "glob",
"tgt": "*",
"jid": "{}-s{}".format(sid, idx),
}
await server_channel.publish(load)
asyncio.run(send())
time.sleep(0.3)
time.sleep(3)
server_channel.close_pub()
server_channel.close()
loop.close()
asyncio.set_event_loop(None)
def _send_large(opts, sid, num=10, size=250000 * 3):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
server_channel = salt.channel.server.PubServerChannel.factory(opts)
for idx in range(num):
load = {
"tgt_type": "glob",
"tgt": "*",
"jid": "{}-l{}".format(sid, idx),
"xdata": "0" * size,
}
server_channel.publish(load)
async def send():
await asyncio.sleep(0.3)
for idx in range(num):
load = {
"tgt_type": "glob",
"tgt": "*",
"jid": "{}-l{}".format(sid, idx),
"xdata": "0" * size,
}
await server_channel.publish(load)
asyncio.run(send())
time.sleep(0.3)
server_channel.close_pub()
server_channel.close()
loop.close()
asyncio.set_event_loop(None)
opts = dict(salt_master.config.copy(), ipc_mode="tcp", pub_hwm=0)
send_num = 10 * 4

View file

@ -1,3 +1,4 @@
import asyncio
import logging
import pytest
@ -26,16 +27,26 @@ def server(config):
disconnect = False
async def handle_stream(self, stream, address):
while self.disconnect is False:
for msg in self.send[:]:
msg = self.send.pop(0)
try:
await stream.write(msg)
except tornado.iostream.StreamClosedError:
break
else:
await tornado.gen.sleep(1)
stream.close()
try:
log.error("Got stream")
while self.disconnect is False:
for msg in self.send[:]:
msg = self.send.pop(0)
try:
log.error("Write %r", msg)
await stream.write(msg)
except tornado.iostream.StreamClosedError:
log.error("Stream Closed Error From Test Server")
break
else:
log.error("SLEEP")
await asyncio.sleep(1)
log.error("Close stream")
log.error("After close stream")
except Exception:  # pylint: disable=broad-except
log.error("WTFSON", exc_info=True)
finally:
stream.close()
server = TestServer()
try:
@ -47,14 +58,16 @@ def server(config):
@pytest.fixture
def client(io_loop, config):
client = salt.transport.tcp.TCPPubClient(config.copy(), io_loop)
client = salt.transport.tcp.TCPPubClient(
config.copy(), io_loop, host=config["master_ip"], port=config["publish_port"]
)
try:
yield client
finally:
client.close()
async def test_message_client_reconnect(io_loop, config, client, server):
async def test_message_client_reconnect(config, client, server):
"""
Verify that the tcp MessageClient class re-sets it's unpacker after a
stream disconnect.
@ -69,7 +82,7 @@ async def test_message_client_reconnect(io_loop, config, client, server):
received.append(msg)
client.on_recv(handler)
await asyncio.sleep(0.03)
# Prepare two packed messages
msg = salt.utils.msgpack.dumps({"test": "test1"})
pmsg = salt.utils.msgpack.dumps({"head": {}, "body": msg})
@ -78,24 +91,40 @@ async def test_message_client_reconnect(io_loop, config, client, server):
# Send one full and one partial msg to the client.
partial = pmsg[:40]
log.error("Send partial %r", partial)
server.send.append(partial)
while not received:
await tornado.gen.sleep(1)
log.error("wait received")
await asyncio.sleep(1)
log.error("assert received")
assert received == [msg]
# log.error("sleep")
# await asyncio.sleep(1)
# The message client has unpacked one msg and there is a partial msg left in
# the unpacker. Closing the stream now leaves the unpacker in a bad state
# since the rest of the partial message will never be received.
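# Illustrative sketch (not part of this diff) of why the unpacker has to be
# replaced after such a disconnect; it uses the upstream msgpack package directly:
import msgpack

unpacker = msgpack.Unpacker()
full = msgpack.packb({"head": {}, "body": b"one"})
part = msgpack.packb({"head": {}, "body": b"two"})[:10]
unpacker.feed(full + part)
assert len(list(unpacker)) == 1  # only the complete message is yielded
# The trailing partial bytes stay buffered; any fresh message fed next would be
# parsed against that stale buffer, so the client swaps in a new Unpacker when
# the stream is re-established.
unpacker = msgpack.Unpacker()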
log.error("disconnect")
server.disconnect = True
await tornado.gen.sleep(1)
log.error("sleep")
await asyncio.sleep(1)
log.error("after sleep")
log.error("disconnect false")
server.disconnect = False
log.error("sleep")
await asyncio.sleep(1)
log.error("after sleep")
log.error("Disconnect False")
received = []
# Prior to the fix for #60831, the unpacker would be left in a broken state
# resulting in either a TypeError or BufferFull error from msgpack. The
# rest of this test would fail.
log.error("Send pmsg %r", pmsg)
server.send.append(pmsg)
while not received:
await tornado.gen.sleep(1)
assert received == [msg, msg]
server.disconnect = True
await tornado.gen.sleep(1)

View file

@@ -22,6 +22,7 @@ def test_zeromq_filtering(salt_master, salt_minion):
"""
Test sending messages to the publisher over zeromq with zeromq_filtering enabled
"""
log.error("TEST START")
opts = dict(
salt_master.config.copy(),
ipc_mode="ipc",
@@ -31,6 +32,7 @@ def test_zeromq_filtering(salt_master, salt_minion):
)
send_num = 1
expect = []
log.error("TEST START 2 ")
with patch(
"salt.utils.minions.CkMinions.check_minions",
MagicMock(
@@ -41,9 +43,12 @@ def test_zeromq_filtering(salt_master, salt_minion):
}
),
):
# log.error("Get Server Channel")
log.error("TEST START 3")
with PubServerChannelProcess(
opts, salt_minion.config.copy(), zmq_filtering=True
) as server_channel:
log.error("pub chan started")
expect.append(send_num)
load = {"tgt_type": "glob", "tgt": "*", "jid": send_num}
server_channel.publish(load)

File diff suppressed because it is too large

View file

@@ -0,0 +1,249 @@
"""
:codeauthor: Thomas Jackson <jacksontj.89@gmail.com>
"""
import hashlib
import logging
import pytest
import tornado.ioloop
import salt.crypt
import salt.transport.tcp
import salt.transport.zeromq
import salt.utils.stringutils
from tests.support.mock import MagicMock, patch
log = logging.getLogger(__name__)
pytestmark = [
pytest.mark.core_test,
]
def transport_ids(value):
return "Transport({})".format(value)
@pytest.fixture(params=("zeromq", "tcp"), ids=transport_ids)
def transport(request):
return request.param
async def test_zeromq_async_pub_channel_publish_port(temp_salt_master):
"""
test that, when connecting, we use the publish_port set in opts when it is not 4506
"""
opts = dict(
temp_salt_master.config.copy(),
ipc_mode="ipc",
pub_hwm=0,
recon_randomize=False,
publish_port=455505,
recon_default=1,
recon_max=2,
master_ip="127.0.0.1",
acceptance_wait_time=5,
acceptance_wait_time_max=5,
sign_pub_messages=False,
)
opts["master_uri"] = "tcp://{interface}:{publish_port}".format(**opts)
ioloop = tornado.ioloop.IOLoop()
# Transport will connect to port given to connect method.
transport = salt.transport.zeromq.PublishClient(
opts, ioloop, host=opts["master_ip"], port=121212
)
with transport:
patch_socket = MagicMock(return_value=True)
patch_auth = MagicMock(return_value=True)
with patch.object(transport, "_socket", patch_socket):
await transport.connect(opts["publish_port"])
assert str(opts["publish_port"]) in patch_socket.mock_calls[0][1][0]
def test_zeromq_async_pub_channel_filtering_decode_message_no_match(
temp_salt_master,
):
"""
test zeromq PublishClient _decode_messages when
zmq_filtering is enabled and the minion does not match
"""
message = [
b"4f26aeafdb2367620a393c973eddbe8f8b846eb",
b"\x82\xa3enc\xa3aes\xa4load\xda\x00`\xeeR\xcf"
b"\x0eaI#V\x17if\xcf\xae\x05\xa7\xb3bN\xf7\xb2\xe2"
b'\xd0sF\xd1\xd4\xecB\xe8\xaf"/*ml\x80Q3\xdb\xaexg'
b"\x8e\x8a\x8c\xd3l\x03\\,J\xa7\x01i\xd1:]\xe3\x8d"
b"\xf4\x03\x88K\x84\n`\xe8\x9a\xad\xad\xc6\x8ea\x15>"
b"\x92m\x9e\xc7aM\x11?\x18;\xbd\x04c\x07\x85\x99\xa3\xea[\x00D",
]
opts = dict(
temp_salt_master.config.copy(),
ipc_mode="ipc",
pub_hwm=0,
zmq_filtering=True,
recon_randomize=False,
recon_default=1,
recon_max=2,
master_ip="127.0.0.1",
acceptance_wait_time=5,
acceptance_wait_time_max=5,
sign_pub_messages=False,
)
opts["master_uri"] = "tcp://{interface}:{publish_port}".format(**opts)
ioloop = tornado.ioloop.IOLoop()
transport = salt.transport.zeromq.PublishClient(
opts, ioloop, host=opts["master_ip"], port=121212
)
with transport:
with patch(
"salt.crypt.AsyncAuth.crypticle",
MagicMock(return_value={"tgt_type": "glob", "tgt": "*", "jid": 1}),
):
res = transport._decode_messages(message)
assert res is None
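# A sketch of the filtering idea the two tests around here exercise (an assumed
# simplification, not the real PublishClient code): the first frame of a
# published message is the sha1 hex digest of a minion id, and a minion only
# keeps messages whose topic matches its own digest (or a broadcast topic).
import hashlib

def topic_matches(minion_id, topic, broadcast=b"broadcast"):
    my_topic = hashlib.sha1(minion_id.encode()).hexdigest().encode()
    return topic in (my_topic, broadcast)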
def test_zeromq_async_pub_channel_filtering_decode_message(
temp_salt_master, temp_salt_minion
):
"""
test zeromq PublishClient _decode_messages when zmq_filtering is enabled and the minion matches
"""
minion_hexid = salt.utils.stringutils.to_bytes(
hashlib.sha1(salt.utils.stringutils.to_bytes(temp_salt_minion.id)).hexdigest()
)
message = [
minion_hexid,
b"\x82\xa3enc\xa3aes\xa4load\xda\x00`\xeeR\xcf"
b"\x0eaI#V\x17if\xcf\xae\x05\xa7\xb3bN\xf7\xb2\xe2"
b'\xd0sF\xd1\xd4\xecB\xe8\xaf"/*ml\x80Q3\xdb\xaexg'
b"\x8e\x8a\x8c\xd3l\x03\\,J\xa7\x01i\xd1:]\xe3\x8d"
b"\xf4\x03\x88K\x84\n`\xe8\x9a\xad\xad\xc6\x8ea\x15>"
b"\x92m\x9e\xc7aM\x11?\x18;\xbd\x04c\x07\x85\x99\xa3\xea[\x00D",
]
opts = dict(
temp_salt_master.config.copy(),
id=temp_salt_minion.id,
ipc_mode="ipc",
pub_hwm=0,
zmq_filtering=True,
recon_randomize=False,
recon_default=1,
recon_max=2,
master_ip="127.0.0.1",
acceptance_wait_time=5,
acceptance_wait_time_max=5,
sign_pub_messages=False,
)
opts["master_uri"] = "tcp://{interface}:{publish_port}".format(**opts)
ioloop = tornado.ioloop.IOLoop()
transport = salt.transport.zeromq.PublishClient(
opts, ioloop, host=opts["master_ip"], port=121212
)
with transport:
with patch(
"salt.crypt.AsyncAuth.crypticle",
MagicMock(return_value={"tgt_type": "glob", "tgt": "*", "jid": 1}),
) as mock_test:
res = transport._decode_messages(message)
assert res["enc"] == "aes"
async def test_publish_client_connect_server_down(transport, io_loop):
opts = {"master_ip": "127.0.0.1"}
host = "127.0.0.1"
port = 111222
print(transport)
if transport == "zeromq":
client = salt.transport.zeromq.PublishClient(
opts, io_loop, host=host, port=port
)
await client.connect()
assert client._socket
elif transport == "tcp":
client = salt.transport.tcp.TCPPubClient(opts, io_loop, host=host, port=port)
try:
await client.connect(background=True)
except TimeoutError:
pass
assert client._stream is None
client.close()
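# Why the zeromq branch above can assert a socket even though no server is
# listening: ZeroMQ's connect() is lazy and simply queues until a peer binds.
# A minimal standalone illustration (port chosen arbitrarily):
import zmq

ctx = zmq.Context()
sock = ctx.socket(zmq.SUB)
sock.connect("tcp://127.0.0.1:55555")  # does not fail even if nothing is bound
sock.close()
ctx.term()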
async def test_publish_client_connect_server_comes_up(transport, io_loop):
print(io_loop)
opts = {"master_ip": "127.0.0.1"}
host = "127.0.0.1"
port = 11122
if transport == "zeromq":
import asyncio
import zmq
import zmq.asyncio  # zmq.asyncio.Context() below needs the explicit submodule import
return
ctx = zmq.asyncio.Context()
uri = f"tcp://{opts['master_ip']}:{port}"
msg = salt.payload.dumps({"meh": 123})
print(f"send {msg}")
client = salt.transport.zeromq.PublishClient(
opts, io_loop, host=host, port=port
)
await client.connect(background=True)
assert client._socket
socket = ctx.socket(zmq.PUB)
socket.setsockopt(zmq.BACKLOG, 1000)
socket.setsockopt(zmq.LINGER, -1)
socket.setsockopt(zmq.SNDHWM, 1000)
print(f"bind {uri}")
socket.bind(uri)
await asyncio.sleep(10)
async def recv():
return await client.recv(timeout=1)
task = asyncio.create_task(recv())
# Sleep to allow zmq to do its thing.
await asyncio.sleep(0.03)
await socket.send(msg)
await task
response = task.result()
assert response
client.close()
socket.close()
await asyncio.sleep(0.03)
ctx.term()
elif transport == "tcp":
import asyncio
import socket
import tornado
client = salt.transport.tcp.TCPPubClient(opts, io_loop, host=host, port=port)
await client.connect(port, background=True)
assert client._stream is None
await asyncio.sleep(2)
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.setblocking(0)
sock.bind((opts["master_ip"], port))
sock.listen(128)
await asyncio.sleep(0.03)
msg = salt.payload.dumps({"meh": 123})
msg = salt.transport.frame.frame_msg(msg, header=None)
conn, addr = sock.accept()
conn.send(msg)
response = await client.recv()
assert response
else:
raise Exception(f"Unknown transport {transport}")

View file

@@ -80,12 +80,14 @@ def client_socket():
yield _client_socket
def test_message_client_cleanup_on_close(client_socket, temp_salt_master):
async def test_message_client_cleanup_on_close(
client_socket, temp_salt_master, io_loop
):
"""
test message client cleanup on close
"""
orig_loop = tornado.ioloop.IOLoop()
orig_loop.make_current()
orig_loop = io_loop
opts = dict(temp_salt_master.config.copy(), transport="tcp")
client = salt.transport.tcp.MessageClient(
@@ -103,15 +105,14 @@ def test_message_client_cleanup_on_close(client_socket, temp_salt_master):
orig_loop.stop = stop
try:
assert client.io_loop == orig_loop
client.io_loop.run_sync(client.connect)
await client.connect()
# Ensure we are testing the _read_until_future and io_loop teardown
assert client._stream is not None
assert orig_loop.stop_called is True
# The run_sync call will set stop_called, reset it
# orig_loop.stop_called = False
client.close()
await client.close()
# Stop should be called again, client's io_loop should be None
# assert orig_loop.stop_called is True
@@ -120,8 +121,10 @@ def test_message_client_cleanup_on_close(client_socket, temp_salt_master):
orig_loop.stop = orig_loop.real_stop
del orig_loop.real_stop
del orig_loop.stop_called
orig_loop.clear_current()
orig_loop.close(all_fds=True)
# orig_loop.clear_current()
# orig_loop.close(all_fds=True)
async def test_async_tcp_pub_channel_connect_publish_port(
@@ -130,6 +133,8 @@ async def test_async_tcp_pub_channel_connect_publish_port(
"""
test when publish_port is not 4506
"""
import asyncio
opts = dict(
temp_salt_master.config.copy(),
master_uri="tcp://127.0.0.1:1234",
@@ -141,6 +146,10 @@ async def test_async_tcp_pub_channel_connect_publish_port(
)
patch_auth = MagicMock(return_value=True)
transport = MagicMock(spec=salt.transport.tcp.TCPPubClient)
transport.connect = MagicMock()
future = asyncio.Future()
transport.connect.return_value = future
future.set_result(True)
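# Standalone sketch of the trick used just above (an assumed equivalent;
# unittest.mock.AsyncMock is another option on Python 3.8+): a MagicMock
# attribute becomes awaitable once its return value is an already-resolved Future.
import asyncio
from unittest.mock import MagicMock

mocked = MagicMock()
resolved = asyncio.Future()
resolved.set_result(True)
mocked.connect.return_value = resolved  # `await mocked.connect()` -> True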
with patch("salt.crypt.AsyncAuth.gen_token", patch_auth), patch(
"salt.crypt.AsyncAuth.authenticated", patch_auth
), patch("salt.transport.tcp.TCPPubClient", transport):
@@ -150,6 +159,7 @@ async def test_async_tcp_pub_channel_connect_publish_port(
with pytest.raises(salt.exceptions.SaltClientError):
await channel.connect()
# The first call to the mock is the instance's __init__, and the first argument to those calls is the opts dict
await asyncio.sleep(0.3)
assert channel.transport.connect.call_args[0][0] == opts["publish_port"]
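# Minimal illustration of the call_args indexing used in the assert above
# (standalone example, not part of this commit):
from unittest.mock import MagicMock

m = MagicMock()
m(4506, timeout=5)
assert m.call_args[0] == (4506,)         # positional arguments
assert m.call_args[0][0] == 4506         # first positional argument
assert m.call_args[1] == {"timeout": 5}  # keyword arguments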
@@ -248,10 +258,7 @@ def salt_message_client():
{}, "127.0.0.1", ports.get_unused_localhost_port(), io_loop=io_loop_mock
)
try:
yield client
finally:
client.close()
yield client
# XXX: we don't return a future anymore; this needs a different way of testing.

File diff suppressed because it is too large

View file

@@ -73,13 +73,21 @@ class Collector(salt.utils.process.SignalHandlingProcess):
"rotate_master_key": self._rotate_secrets,
}
def _teardown_listener(self):
if self.transport == "zeromq":
self.sock.close()
self.ctx.term()
else:
self.sock.close()
def _setup_listener(self):
if self.transport == "zeromq":
ctx = zmq.Context()
self.sock = ctx.socket(zmq.SUB)
self.ctx = zmq.Context()
self.sock = self.ctx.socket(zmq.SUB)
self.sock.setsockopt(zmq.LINGER, -1)
self.sock.setsockopt(zmq.SUBSCRIBE, b"")
pub_uri = "tcp://{}:{}".format(self.interface, self.port)
log.error("Collector listen %s", pub_uri)
self.sock.connect(pub_uri)
else:
end = time.time() + 120
@@ -97,22 +105,28 @@ class Collector(salt.utils.process.SignalHandlingProcess):
@tornado.gen.coroutine
def _recv(self):
# log.error("RECV %s", self.transport)
if self.transport == "zeromq":
# test_zeromq_filtering requires catching the
# SaltDeserializationError in order to pass.
try:
payload = self.sock.recv(zmq.NOBLOCK)
# log.error("ZMQ Payload is %r", payload)
serial_payload = salt.payload.loads(payload)
raise tornado.gen.Return(serial_payload)
except (zmq.ZMQError, salt.exceptions.SaltDeserializationError):
raise RecvError("ZMQ Error")
else:
for msg in self.unpacker:
raise tornado.gen.Return(msg["body"])
# log.error("TCP Payload is %r", msg)
serial_payload = salt.payload.loads(msg["body"])
# raise tornado.gen.Return(msg["body"])
raise tornado.gen.Return(serial_payload)
byts = yield self.sock.read_bytes(8096, partial=True)
self.unpacker.feed(byts)
for msg in self.unpacker:
raise tornado.gen.Return(msg["body"])
serial_payload = salt.payload.loads(msg["body"])
raise tornado.gen.Return(serial_payload)
raise RecvError("TCP Error")
@tornado.gen.coroutine
@@ -128,40 +142,44 @@ class Collector(salt.utils.process.SignalHandlingProcess):
self.start = last_msg
serial = salt.payload.Serial(self.minion_config)
crypticle = salt.crypt.Crypticle(self.minion_config, self.aes_key)
while True:
curr_time = time.time()
if time.time() > self.hard_timeout:
log.error("Hard timeout reaced in test collector!")
break
if curr_time - last_msg >= self.timeout:
log.error("Receive timeout reaced in test collector!")
break
try:
payload = yield self._recv()
except RecvError:
time.sleep(0.01)
else:
try:
while True:
curr_time = time.time()
if time.time() > self.hard_timeout:
log.error("Hard timeout reaced in test collector!")
break
if curr_time - last_msg >= self.timeout:
log.error("Receive timeout reaced in test collector!")
break
try:
payload = crypticle.loads(payload["load"])
if not payload:
continue
if "start" in payload:
log.info("Collector started")
self.running.set()
continue
if "stop" in payload:
log.info("Collector stopped")
break
last_msg = time.time()
self.results.append(payload["jid"])
except salt.exceptions.SaltDeserializationError:
log.error("Deserializer Error")
if not self.zmq_filtering:
log.exception("Failed to deserialize...")
break
self.end = time.time()
print(f"Total time {self.end - self.start}")
loop.stop()
payload = yield self._recv()
except RecvError:
time.sleep(0.03)
else:
try:
log.trace("Colleted payload %r", payload)
payload = crypticle.loads(payload["load"])
if not payload:
continue
if "start" in payload:
log.info("Collector started")
self.running.set()
continue
if "stop" in payload:
log.info("Collector stopped")
break
last_msg = time.time()
self.results.append(payload["jid"])
except salt.exceptions.SaltDeserializationError:
log.error("Deserializer Error")
if not self.zmq_filtering:
log.exception("Failed to deserialize...")
break
self.end = time.time()
print(f"Total time {self.end - self.start}")
finally:
self._teardown_listener()
loop.stop()
def run(self):
"""
@@ -173,11 +191,17 @@ class Collector(salt.utils.process.SignalHandlingProcess):
loop.start()
def __enter__(self):
import sys
print("COL ENTER")
sys.stdout.flush()
self.manager.__enter__()
self.start()
# Wait until we can start receiving events
self.started.wait()
self.started.clear()
print("COL ENTER - Done")
sys.stdout.flush()
return self
def __exit__(self, *args):
@@ -197,6 +221,7 @@ class Collector(salt.utils.process.SignalHandlingProcess):
class PubServerChannelProcess(salt.utils.process.SignalHandlingProcess):
def __init__(self, master_config, minion_config, **collector_kwargs):
super().__init__()
self.name = "PubServerChannelProcess"
self._closing = False
self.master_config = master_config
self.minion_config = minion_config
@@ -226,17 +251,24 @@ class PubServerChannelProcess(salt.utils.process.SignalHandlingProcess):
self.master_config["interface"],
self.master_config["publish_port"],
self.aes_key,
**self.collector_kwargs
**self.collector_kwargs,
)
def run(self):
import queue
ioloop = tornado.ioloop.IOLoop()
try:
while True:
payload = self.queue.get()
try:
payload = self.queue.get(False)
except queue.Empty:
time.sleep(0.03)
continue
if payload is None:
log.debug("We received the stop sentinel")
break
self.pub_server_channel.publish(payload)
ioloop.run_sync(lambda: self.pub_server_channel.publish(payload))
except KeyboardInterrupt:
pass
finally:
@@ -253,8 +285,8 @@ class PubServerChannelProcess(salt.utils.process.SignalHandlingProcess):
if self.process_manager is None:
return
self.process_manager.terminate()
if hasattr(self.pub_server_channel, "pub_close"):
self.pub_server_channel.pub_close()
if hasattr(self.pub_server_channel, "close"):
self.pub_server_channel.close()
# Really terminate any process still left behind
for pid in self.process_manager._process_map:
terminate_process(pid=pid, kill_children=True, slow_stop=False)
@@ -264,9 +296,12 @@ class PubServerChannelProcess(salt.utils.process.SignalHandlingProcess):
self.queue.put(payload)
def __enter__(self):
log.error("Proc start")
self.start()
log.error("Col enter")
self.collector.__enter__()
attempts = 300
log.error("Wait collector")
while attempts > 0:
self.publish({"tgt_type": "glob", "tgt": "*", "jid": -1, "start": True})
if self.collector.running.wait(1) is True:
@@ -274,6 +309,7 @@ class PubServerChannelProcess(salt.utils.process.SignalHandlingProcess):
attempts -= 1
else:
pytest.fail("Failed to confirm the collector has started")
log.error("Collector started")
return self
def __exit__(self, *args):
@@ -290,4 +326,4 @@ class PubServerChannelProcess(salt.utils.process.SignalHandlingProcess):
self.stopped.wait(10)
self.close()
self.terminate()
log.info("The PubServerChannelProcess has terminated")
log.error("The PubServerChannelProcess has terminated")