mirror of
https://github.com/saltstack/salt.git
synced 2025-04-17 10:10:20 +00:00
Test fixes
This commit is contained in:
parent
077c253954
commit
9683260d61
22 changed files with 2426 additions and 1750 deletions
|
@ -94,6 +94,9 @@ class DeferredStreamHandler(StreamHandler):
|
|||
super().__init__(stream)
|
||||
self.__messages = deque(maxlen=max_queue_size)
|
||||
self.__emitting = False
|
||||
import traceback
|
||||
|
||||
self.stack = "".join(traceback.format_stack())
|
||||
|
||||
def handle(self, record):
|
||||
self.acquire()
|
||||
|
@ -115,6 +118,7 @@ class DeferredStreamHandler(StreamHandler):
|
|||
super().handle(record)
|
||||
finally:
|
||||
self.__emitting = False
|
||||
# This will raise a ValueError if the file handle has been closed.
|
||||
super().flush()
|
||||
|
||||
def sync_with_handlers(self, handlers=()):
|
||||
|
|
|
@ -474,7 +474,15 @@ def setup_temp_handler(log_level=None):
|
|||
break
|
||||
else:
|
||||
handler = DeferredStreamHandler(sys.stderr)
|
||||
atexit.register(handler.flush)
|
||||
|
||||
def tryflush():
|
||||
try:
|
||||
handler.flush()
|
||||
except ValueError:
|
||||
# File handle has already been closed.
|
||||
pass
|
||||
|
||||
atexit.register(tryflush)
|
||||
handler.setLevel(log_level)
|
||||
|
||||
# Set the default temporary console formatter config
|
||||
|
|
|
@ -103,6 +103,7 @@ class AsyncReqChannel:
|
|||
"_uncrypted_transfer",
|
||||
"send",
|
||||
"connect",
|
||||
# "close",
|
||||
]
|
||||
close_methods = [
|
||||
"close",
|
||||
|
@ -314,9 +315,8 @@ class AsyncReqChannel:
|
|||
|
||||
raise tornado.gen.Return(ret)
|
||||
|
||||
@tornado.gen.coroutine
|
||||
def connect(self):
|
||||
yield self.transport.connect()
|
||||
async def connect(self):
|
||||
await self.transport.connect()
|
||||
|
||||
@tornado.gen.coroutine
|
||||
def send(self, load, tries=None, timeout=None, raw=False):
|
||||
|
@ -367,6 +367,14 @@ class AsyncReqChannel:
|
|||
def __exit__(self, *args):
|
||||
self.close()
|
||||
|
||||
async def __aenter__(self):
|
||||
await self.connect()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
# print("AEXIT")
|
||||
self.close()
|
||||
|
||||
|
||||
class AsyncPubChannel:
|
||||
"""
|
||||
|
@ -376,7 +384,7 @@ class AsyncPubChannel:
|
|||
async_methods = [
|
||||
"connect",
|
||||
"_decode_messages",
|
||||
# "close",
|
||||
# "close",
|
||||
]
|
||||
close_methods = [
|
||||
"close",
|
||||
|
@ -406,7 +414,9 @@ class AsyncPubChannel:
|
|||
io_loop = tornado.ioloop.IOLoop.current()
|
||||
|
||||
auth = salt.crypt.AsyncAuth(opts, io_loop=io_loop)
|
||||
transport = salt.transport.publish_client(opts, io_loop)
|
||||
host = opts.get("master_ip", "127.0.0.1")
|
||||
port = int(opts.get("publish_port", 4506))
|
||||
transport = salt.transport.publish_client(opts, io_loop, host=host, port=port)
|
||||
return cls(opts, transport, auth, io_loop)
|
||||
|
||||
def __init__(self, opts, transport, auth, io_loop=None):
|
||||
|
@ -432,6 +442,7 @@ class AsyncPubChannel:
|
|||
try:
|
||||
if not self.auth.authenticated:
|
||||
yield self.auth.authenticate()
|
||||
# log.error("*** Creds %r", self.auth.creds)
|
||||
# if this is changed from the default, we assume it was intentional
|
||||
if int(self.opts.get("publish_port", 4506)) != 4506:
|
||||
publish_port = self.opts.get("publish_port")
|
||||
|
@ -447,6 +458,8 @@ class AsyncPubChannel:
|
|||
except KeyboardInterrupt: # pylint: disable=try-except-raise
|
||||
raise
|
||||
except Exception as exc: # pylint: disable=broad-except
|
||||
# TODO: Basing re-try logic off exception messages is brittle and
|
||||
# prone to errors; use exception types or some other method.
|
||||
if "-|RETRY|-" not in str(exc):
|
||||
raise salt.exceptions.SaltClientError(
|
||||
f"Unable to sign_in to master: {exc}"
|
||||
|
@ -456,11 +469,8 @@ class AsyncPubChannel:
|
|||
"""
|
||||
Close the channel
|
||||
"""
|
||||
log.error("AsyncPubChannel.close called")
|
||||
self.transport.close()
|
||||
log.error("Transport closed")
|
||||
if self.event is not None:
|
||||
log.error("Event destroy called")
|
||||
self.event.destroy()
|
||||
self.event = None
|
||||
|
||||
|
@ -616,7 +626,13 @@ class AsyncPubChannel:
|
|||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
self.close()
|
||||
self.io_loop.spawn_callback(self.close)
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
await self.close()
|
||||
|
||||
|
||||
class AsyncPushChannel:
|
||||
|
|
|
@ -699,9 +699,6 @@ class PubServerChannel:
|
|||
Factory class to create subscription channels to the master's Publisher
|
||||
"""
|
||||
|
||||
def __repr__(self):
|
||||
return f"<PubServerChannel pub_uri={self.transport.pub_uri} pull_uri={self.transport.pull_uri} at {id(self)}>"
|
||||
|
||||
@classmethod
|
||||
def factory(cls, opts, **kwargs):
|
||||
if "master_uri" not in opts and "master_uri" in kwargs:
|
||||
|
@ -720,6 +717,9 @@ class PubServerChannel:
|
|||
transport = salt.transport.publish_server(opts, **kwargs)
|
||||
return cls(opts, transport, presence_events=presence_events)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<PubServerChannel pub_uri={self.transport.pub_uri} pull_uri={self.transport.pull_uri} at {id(self)}>"
|
||||
|
||||
def __init__(self, opts, transport, presence_events=False):
|
||||
self.opts = opts
|
||||
self.ckminions = salt.utils.minions.CkMinions(self.opts)
|
||||
|
@ -774,6 +774,7 @@ class PubServerChannel:
|
|||
secrets = kwargs.get("secrets", None)
|
||||
if secrets is not None:
|
||||
salt.master.SMaster.secrets = secrets
|
||||
log.error("RUN TRANSPORT PUBD")
|
||||
self.transport.publish_daemon(
|
||||
self.publish_payload, self.presence_callback, self.remove_presence_callback
|
||||
)
|
||||
|
|
|
@ -222,7 +222,7 @@ class SaltCMD(salt.utils.parsers.SaltCMDOptionParser):
|
|||
AuthorizationError,
|
||||
SaltInvocationError,
|
||||
EauthAuthenticationError,
|
||||
SaltClientError,
|
||||
# SaltClientError,
|
||||
) as exc:
|
||||
print(repr(exc))
|
||||
ret = str(exc)
|
||||
|
|
|
@ -1527,7 +1527,7 @@ class Crypticle:
|
|||
ret_nonce = data[:32].decode()
|
||||
data = data[32:]
|
||||
if ret_nonce != nonce:
|
||||
raise SaltClientError("Nonce verification error")
|
||||
raise SaltClientError(f"Nonce verification error {ret_nonce} {nonce}")
|
||||
payload = salt.payload.loads(data, raw=raw)
|
||||
if isinstance(payload, dict):
|
||||
if "serial" in payload:
|
||||
|
|
|
@ -1048,7 +1048,30 @@ class MinionManager(MinionBase):
|
|||
# self.opts,
|
||||
# io_loop=self.io_loop,
|
||||
# )
|
||||
|
||||
# import hashlib
|
||||
# ipc_publisher = salt.transport.publish_server(self.opts)
|
||||
# hash_type = getattr(hashlib, self.opts["hash_type"])
|
||||
# id_hash = hash_type(
|
||||
# salt.utils.stringutils.to_bytes(self.opts["id"])
|
||||
# ).hexdigest()[:10]
|
||||
# epub_sock_path = "ipc://{}".format(
|
||||
# os.path.join(
|
||||
# self.opts["sock_dir"], "minion_event_{}_pub.ipc".format(id_hash)
|
||||
# )
|
||||
# )
|
||||
# if os.path.exists(epub_sock_path):
|
||||
# os.unlink(epub_sock_path)
|
||||
# epull_sock_path = "ipc://{}".format(
|
||||
# os.path.join(
|
||||
# self.opts["sock_dir"], "minion_event_{}_pull.ipc".format(id_hash)
|
||||
# )
|
||||
# )
|
||||
# ipc_publisher.pub_uri = epub_sock_path
|
||||
# ipc_publisher.pull_uri = epull_sock_path
|
||||
# self.io_loop.add_callback(ipc_publisher.publisher, ipc_publisher.publish_payload, self.io_loop)
|
||||
import hashlib
|
||||
|
||||
ipc_publisher = salt.transport.publish_server(self.opts)
|
||||
hash_type = getattr(hashlib, self.opts["hash_type"])
|
||||
id_hash = hash_type(
|
||||
|
@ -1068,7 +1091,18 @@ class MinionManager(MinionBase):
|
|||
)
|
||||
ipc_publisher.pub_uri = epub_sock_path
|
||||
ipc_publisher.pull_uri = epull_sock_path
|
||||
self.io_loop.add_callback(ipc_publisher.publisher, ipc_publisher.publish_payload, self.io_loop)
|
||||
if self.opts["transport"] == "tcp":
|
||||
|
||||
def target():
|
||||
ipc_publisher.publish_daemon(ipc_publisher.publish_payload)
|
||||
|
||||
proc = salt.utils.process.Process(target=target, daemon=True)
|
||||
proc.start()
|
||||
else:
|
||||
self.io_loop.add_callback(
|
||||
ipc_publisher.publisher, ipc_publisher.publish_payload, self.io_loop
|
||||
)
|
||||
log.error("get event ")
|
||||
self.event = salt.utils.event.get_event(
|
||||
"minion", opts=self.opts, io_loop=self.io_loop
|
||||
)
|
||||
|
@ -1139,11 +1173,8 @@ class MinionManager(MinionBase):
|
|||
self.io_loop.spawn_callback(self._connect_minion, minion)
|
||||
self.io_loop.call_later(timeout, self._check_minions)
|
||||
|
||||
@tornado.gen.coroutine
|
||||
def _connect_minion(self, minion):
|
||||
"""
|
||||
Create a minion, and asynchronously connect it to a master
|
||||
"""
|
||||
async def _connect_minion(self, minion):
|
||||
"""Create a minion, and asynchronously connect it to a master"""
|
||||
last = 0 # never have we signed in
|
||||
auth_wait = minion.opts["acceptance_wait_time"]
|
||||
failed = False
|
||||
|
@ -1154,7 +1185,8 @@ class MinionManager(MinionBase):
|
|||
if minion.opts.get("scheduler_before_connect", False):
|
||||
minion.setup_scheduler(before_connect=True)
|
||||
if minion.opts.get("master_type", "str") != "disable":
|
||||
yield minion.connect_master(failed=failed)
|
||||
await minion.connect_master(failed=failed)
|
||||
log.error("RUN MINION TUNE IN")
|
||||
minion.tune_in(start=False)
|
||||
self.minions.append(minion)
|
||||
break
|
||||
|
@ -1164,11 +1196,12 @@ class MinionManager(MinionBase):
|
|||
"Error while bringing up minion for multi-master. Is "
|
||||
"master at %s responding?",
|
||||
minion.opts["master"],
|
||||
exc_info=True,
|
||||
)
|
||||
last = time.time()
|
||||
if auth_wait < self.max_auth_wait:
|
||||
auth_wait += self.auth_wait
|
||||
yield tornado.gen.sleep(auth_wait) # TODO: log?
|
||||
await tornado.gen.sleep(auth_wait) # TODO: log?
|
||||
except SaltMasterUnresolvableError:
|
||||
err = (
|
||||
"Master address: '{}' could not be resolved. Invalid or"
|
||||
|
@ -3269,16 +3302,16 @@ class Minion(MinionBase):
|
|||
self.pub_channel.on_recv(None)
|
||||
log.error("create pub_channel.close task %r", self)
|
||||
self.pub_channel.close()
|
||||
#self.io_loop.asyncio_loop.run_until_complete(self.pub_channel.close())
|
||||
#if hasattr(self.pub_channel, "close"):
|
||||
# self.io_loop.asyncio_loop.run_until_complete(self.pub_channel.close())
|
||||
# if hasattr(self.pub_channel, "close"):
|
||||
# asyncio.create_task(
|
||||
# self.pub_channel.close()
|
||||
# )
|
||||
# #self.pub_channel.close()
|
||||
#del self.pub_channel
|
||||
# del self.pub_channel
|
||||
if hasattr(self, "event"):
|
||||
log.error("HAS EVENT")
|
||||
#if hasattr(self, "periodic_callbacks"):
|
||||
# if hasattr(self, "periodic_callbacks"):
|
||||
# for cb in self.periodic_callbacks.values():
|
||||
# cb.stop()
|
||||
log.error("%r destroy method finished", self)
|
||||
|
|
|
@ -74,7 +74,7 @@ def publish_server(opts, **kwargs):
|
|||
raise Exception("Transport type not found: {}".format(ttype))
|
||||
|
||||
|
||||
def publish_client(opts, io_loop):
|
||||
def ipc_publish_client(opts, io_loop):
|
||||
# Default to ZeroMQ for now
|
||||
ttype = "zeromq"
|
||||
# determine the ttype
|
||||
|
@ -85,6 +85,7 @@ def publish_client(opts, io_loop):
|
|||
# switch on available ttypes
|
||||
if ttype == "zeromq":
|
||||
import salt.transport.zeromq
|
||||
|
||||
return salt.transport.zeromq.PublishClient(opts, io_loop)
|
||||
elif ttype == "tcp":
|
||||
import salt.transport.tcp
|
||||
|
@ -93,6 +94,30 @@ def publish_client(opts, io_loop):
|
|||
raise Exception("Transport type not found: {}".format(ttype))
|
||||
|
||||
|
||||
def publish_client(opts, io_loop, host=None, port=None, path=None):
|
||||
# Default to ZeroMQ for now
|
||||
ttype = "zeromq"
|
||||
# determine the ttype
|
||||
if "transport" in opts:
|
||||
ttype = opts["transport"]
|
||||
elif "transport" in opts.get("pillar", {}).get("master", {}):
|
||||
ttype = opts["pillar"]["master"]["transport"]
|
||||
# switch on available ttypes
|
||||
if ttype == "zeromq":
|
||||
import salt.transport.zeromq
|
||||
|
||||
return salt.transport.zeromq.PublishClient(
|
||||
opts, io_loop, host=host, port=port, path=path
|
||||
)
|
||||
elif ttype == "tcp":
|
||||
import salt.transport.tcp
|
||||
|
||||
return salt.transport.tcp.TCPPubClient(
|
||||
opts, io_loop, host=host, port=port, path=path
|
||||
)
|
||||
raise Exception("Transport type not found: {}".format(ttype))
|
||||
|
||||
|
||||
class RequestClient:
|
||||
"""
|
||||
The RequestClient transport is used to make requests and get corresponding
|
||||
|
|
|
@ -134,11 +134,10 @@ class IPCServer:
|
|||
else:
|
||||
self.sock = tornado.netutil.bind_unix_socket(self.socket_path)
|
||||
|
||||
with salt.utils.asynchronous.current_ioloop(self.io_loop):
|
||||
tornado.netutil.add_accept_handler(
|
||||
self.sock,
|
||||
self.handle_connection,
|
||||
)
|
||||
tornado.netutil.add_accept_handler(
|
||||
self.sock,
|
||||
self.handle_connection,
|
||||
)
|
||||
self._started = True
|
||||
|
||||
@tornado.gen.coroutine
|
||||
|
@ -208,7 +207,7 @@ class IPCServer:
|
|||
log.error("Exception occurred while handling stream: %s", exc)
|
||||
|
||||
def handle_connection(self, connection, address):
|
||||
log.trace(
|
||||
log.error(
|
||||
"IPCServer: Handling connection to address: %s",
|
||||
address if address else connection,
|
||||
)
|
||||
|
@ -338,8 +337,8 @@ class IPCClient:
|
|||
break
|
||||
|
||||
if self.stream is None:
|
||||
with salt.utils.asynchronous.current_ioloop(self.io_loop):
|
||||
self.stream = IOStream(socket.socket(sock_type, socket.SOCK_STREAM))
|
||||
# with salt.utils.asynchronous.current_ioloop(self.io_loop):
|
||||
self.stream = IOStream(socket.socket(sock_type, socket.SOCK_STREAM))
|
||||
try:
|
||||
log.trace("IPCClient: Connecting to socket: %s", self.socket_path)
|
||||
yield self.stream.connect(sock_addr)
|
||||
|
@ -440,8 +439,8 @@ class IPCMessageClient(IPCClient):
|
|||
|
||||
# FIXME timeout unimplemented
|
||||
# FIXME tries unimplemented
|
||||
@tornado.gen.coroutine
|
||||
def send(self, msg, timeout=None, tries=None):
|
||||
# @tornado.gen.coroutine
|
||||
async def send(self, msg, timeout=None, tries=None):
|
||||
"""
|
||||
Send a message to an IPC socket
|
||||
|
||||
|
@ -451,9 +450,9 @@ class IPCMessageClient(IPCClient):
|
|||
:param int timeout: Timeout when sending message (Currently unimplemented)
|
||||
"""
|
||||
if not self.connected():
|
||||
yield self.connect()
|
||||
await self.connect()
|
||||
pack = salt.transport.frame.frame_msg_ipc(msg, raw_body=True)
|
||||
yield self.stream.write(pack)
|
||||
await self.stream.write(pack)
|
||||
|
||||
|
||||
class IPCMessageServer(IPCServer):
|
||||
|
|
|
@ -6,12 +6,13 @@ Wire protocol: "len(payload) msgpack({'head': SOMEHEADER, 'body': SOMEBODY})"
|
|||
|
||||
"""
|
||||
|
||||
|
||||
import asyncio
|
||||
import errno
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
import queue
|
||||
import select
|
||||
import socket
|
||||
import threading
|
||||
import urllib
|
||||
|
@ -23,6 +24,7 @@ import tornado.iostream
|
|||
import tornado.netutil
|
||||
import tornado.tcpclient
|
||||
import tornado.tcpserver
|
||||
from tornado.locks import Lock
|
||||
|
||||
import salt.master
|
||||
import salt.payload
|
||||
|
@ -215,13 +217,45 @@ class TCPPubClient(salt.transport.base.PublishClient):
|
|||
|
||||
ttype = "tcp"
|
||||
|
||||
async_methods = [
|
||||
"connect",
|
||||
"connect_uri",
|
||||
"recv",
|
||||
# "close",
|
||||
]
|
||||
close_methods = [
|
||||
"close",
|
||||
]
|
||||
|
||||
def __init__(self, opts, io_loop, **kwargs): # pylint: disable=W0231
|
||||
self.opts = opts
|
||||
self.io_loop = io_loop
|
||||
self.message_client = None
|
||||
self.unpacker = salt.utils.msgpack.Unpacker()
|
||||
self.connected = False
|
||||
self._closing = False
|
||||
self.resolver = Resolver()
|
||||
self._stream = None
|
||||
self._closing = False
|
||||
self._closed = False
|
||||
self.backoff = opts.get("tcp_reconnect_backoff", 1)
|
||||
self.resolver = kwargs.get("resolver")
|
||||
self._read_in_progress = Lock()
|
||||
self.poller = None
|
||||
|
||||
self.host = kwargs.get("host", None)
|
||||
self.port = kwargs.get("port", None)
|
||||
self.path = kwargs.get("path", None)
|
||||
self.source_ip = self.opts.get("source_ip")
|
||||
self.source_port = self.opts.get("source_publish_port")
|
||||
self.connect_callback = None
|
||||
self.disconnect_callback = None
|
||||
if self.host is None and self.port is None:
|
||||
if self.path is None:
|
||||
raise Exception("A host and port or a path must be provided")
|
||||
elif self.host and self.port:
|
||||
if self.path:
|
||||
raise Exception("A host and port or a path must be provided, not both")
|
||||
|
||||
def close(self):
|
||||
if self._closing:
|
||||
|
@ -237,21 +271,80 @@ class TCPPubClient(salt.transport.base.PublishClient):
|
|||
|
||||
# pylint: enable=W1701
|
||||
|
||||
@tornado.gen.coroutine
|
||||
def connect(self, publish_port, connect_callback=None, disconnect_callback=None):
|
||||
self.publish_port = publish_port
|
||||
self.message_client = MessageClient(
|
||||
self.opts,
|
||||
self.opts["master_ip"],
|
||||
int(self.publish_port),
|
||||
io_loop=self.io_loop,
|
||||
connect_callback=connect_callback,
|
||||
disconnect_callback=disconnect_callback,
|
||||
source_ip=self.opts.get("source_ip"),
|
||||
source_port=self.opts.get("source_publish_port"),
|
||||
)
|
||||
yield self.message_client.connect() # wait for the client to be connected
|
||||
self.connected = True
|
||||
async def getstream(self, **kwargs):
|
||||
if self.source_ip or self.source_port:
|
||||
kwargs = {
|
||||
"source_ip": self.source_ip,
|
||||
"source_port": self.source_port,
|
||||
}
|
||||
stream = None
|
||||
while stream is None and (not self._closed and not self._closing):
|
||||
try:
|
||||
if self.host and self.port:
|
||||
# log.error("GET STREAM TCP %r %s %s %s", self.url, self.host, self.port, self.stack)
|
||||
self._tcp_client = TCPClientKeepAlive(
|
||||
self.opts, resolver=self.resolver
|
||||
)
|
||||
stream = await self._tcp_client.connect(
|
||||
ip_bracket(self.host, strip=True),
|
||||
self.port,
|
||||
ssl_options=self.opts.get("ssl"),
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
sock_type = socket.AF_UNIX
|
||||
stream = tornado.iostream.IOStream(
|
||||
socket.socket(sock_type, socket.SOCK_STREAM)
|
||||
)
|
||||
await stream.connect(self.path)
|
||||
self.poller = select.poll()
|
||||
self.poller.register(stream.socket, select.POLLIN)
|
||||
except Exception as exc: # pylint: disable=broad-except
|
||||
log.warning(
|
||||
"TCP Message Client encountered an exception while connecting to"
|
||||
" %s:%s %s: %r, will reconnect in %d seconds",
|
||||
self.host,
|
||||
self.port,
|
||||
self.path,
|
||||
exc,
|
||||
self.backoff,
|
||||
)
|
||||
await tornado.gen.sleep(self.backoff)
|
||||
return stream
|
||||
|
||||
async def _connect(self):
|
||||
# log.error("Connect %r %r", self, self._stream)
|
||||
# import traceback
|
||||
# stack = "".join(traceback.format_stack())
|
||||
# log.error(f"MessageClient Connect {name} {stack}")
|
||||
if self._stream is None:
|
||||
log.error("Get stream")
|
||||
self._stream = await self.getstream()
|
||||
log.error("Got stream")
|
||||
if self._stream:
|
||||
# if not self._stream_return_running:
|
||||
# self.io_loop.spawn_callback(self._stream_return)
|
||||
if self.connect_callback:
|
||||
self.connect_callback(True)
|
||||
self.connected = True
|
||||
|
||||
async def connect(
|
||||
self,
|
||||
port=None,
|
||||
connect_callback=None,
|
||||
disconnect_callback=None,
|
||||
background=False,
|
||||
):
|
||||
if port is not None:
|
||||
self.port = port
|
||||
if connect_callback:
|
||||
self.connect_callback = None
|
||||
if disconnect_callback:
|
||||
self.disconnect_callback = None
|
||||
if background:
|
||||
self.io_loop.spawn_callback(self._connect)
|
||||
else:
|
||||
await self._connect()
|
||||
|
||||
def _decode_messages(self, messages):
|
||||
if not isinstance(messages, dict):
|
||||
|
@ -263,15 +356,83 @@ class TCPPubClient(salt.transport.base.PublishClient):
|
|||
body = messages
|
||||
return body
|
||||
|
||||
@tornado.gen.coroutine
|
||||
def send(self, msg):
|
||||
yield self.message_client._stream.write(msg)
|
||||
async def send(self, msg):
|
||||
await self.message_client.send(msg, reply=False)
|
||||
|
||||
async def recv(self, timeout=None):
|
||||
try:
|
||||
await self._read_in_progress.acquire(timeout=0.00000001)
|
||||
except tornado.gen.TimeoutError:
|
||||
log.error("Timeout Error")
|
||||
return
|
||||
try:
|
||||
if timeout == 0:
|
||||
if not self._stream:
|
||||
await asyncio.sleep(0.001)
|
||||
return
|
||||
for msg in self.unpacker:
|
||||
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
|
||||
return framed_msg["body"]
|
||||
poller = select.poll()
|
||||
poller.register(self._stream.socket, select.POLLIN)
|
||||
try:
|
||||
events = poller.poll(0)
|
||||
except TimeoutError:
|
||||
events = []
|
||||
if events:
|
||||
while True:
|
||||
byts = await self._stream.read_bytes(4096, partial=True)
|
||||
self.unpacker.feed(byts)
|
||||
for msg in self.unpacker:
|
||||
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
|
||||
return framed_msg["body"]
|
||||
elif timeout:
|
||||
return await asyncio.wait_for(self.recv(), timeout=timeout)
|
||||
else:
|
||||
for msg in self.unpacker:
|
||||
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
|
||||
return framed_msg["body"]
|
||||
while True:
|
||||
byts = await self._stream.read_bytes(4096, partial=True)
|
||||
self.unpacker.feed(byts)
|
||||
for msg in self.unpacker:
|
||||
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
|
||||
return framed_msg["body"]
|
||||
finally:
|
||||
self._read_in_progress.release()
|
||||
|
||||
def on_recv(self, callback):
|
||||
"""
|
||||
Register an on_recv callback
|
||||
Register a callback for received messages (that we didn't initiate)
|
||||
"""
|
||||
return self.message_client.on_recv(callback)
|
||||
|
||||
async def setup_callback():
|
||||
logit = True
|
||||
while not self._stream:
|
||||
# log.error("On recv wait stream %r", self._stream)
|
||||
await asyncio.sleep(0.003)
|
||||
while True:
|
||||
try:
|
||||
msg = await self.recv()
|
||||
logit = True
|
||||
except tornado.iostream.StreamClosedError:
|
||||
self._stream.close()
|
||||
self._stream = None
|
||||
if logit:
|
||||
log.error("Stream Closed", exc_info=True)
|
||||
logit = False
|
||||
await self._connect()
|
||||
await asyncio.sleep(0.03)
|
||||
# if self.disconnect_callback:
|
||||
# self.disconnect_callback()
|
||||
self.unpacker = salt.utils.msgpack.Unpacker()
|
||||
continue
|
||||
except Exception:
|
||||
log.error("Other exception", exc_info=True)
|
||||
log.error("on recv got msg %r", msg)
|
||||
callback(msg)
|
||||
|
||||
self.io_loop.spawn_callback(setup_callback)
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
@ -389,6 +550,7 @@ class TCPReqServer(salt.transport.base.DaemonizedRequestServer):
|
|||
def handle_message(self, stream, payload, header=None):
|
||||
payload = self.decode_payload(payload)
|
||||
reply = yield self.message_handler(payload)
|
||||
# XXX Handle StreamClosedError
|
||||
stream.write(salt.transport.frame.frame_msg(reply, header=header))
|
||||
|
||||
def decode_payload(self, payload):
|
||||
|
@ -545,6 +707,7 @@ class MessageClient:
|
|||
opts,
|
||||
host,
|
||||
port,
|
||||
url=None,
|
||||
io_loop=None,
|
||||
resolver=None,
|
||||
connect_callback=None,
|
||||
|
@ -552,9 +715,13 @@ class MessageClient:
|
|||
source_ip=None,
|
||||
source_port=None,
|
||||
):
|
||||
# import traceback
|
||||
# stack = "".join(traceback.format_stack())
|
||||
# print(stack)
|
||||
self.opts = opts
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.url = url
|
||||
self.source_ip = source_ip
|
||||
self.source_port = source_port
|
||||
self.connect_callback = connect_callback
|
||||
|
@ -577,6 +744,14 @@ class MessageClient:
|
|||
self._stream = None
|
||||
|
||||
self.backoff = opts.get("tcp_reconnect_backoff", 1)
|
||||
self.callbacks = {}
|
||||
import traceback
|
||||
|
||||
self.stack = "".join(traceback.format_stack())
|
||||
self.unpacker = salt.utils.msgpack.Unpacker()
|
||||
self._read_in_progress = Lock()
|
||||
self.task = None
|
||||
self.tcp_client = None
|
||||
|
||||
def _stop_io_loop(self):
|
||||
if self.io_loop is not None:
|
||||
|
@ -587,8 +762,21 @@ class MessageClient:
|
|||
if self._closing:
|
||||
return
|
||||
self._closing = True
|
||||
self.stream.close()
|
||||
self.tcp_client.close()
|
||||
self.io_loop.add_timeout(1, self.check_close)
|
||||
|
||||
def close(self):
|
||||
if self._closing:
|
||||
return
|
||||
self._closing = True
|
||||
if self._stream is not None:
|
||||
self._stream.close()
|
||||
if self.tcp_client is not None:
|
||||
self.tcp_client.close()
|
||||
if self.task is not None:
|
||||
self.task.cancel()
|
||||
|
||||
@tornado.gen.coroutine
|
||||
def check_close(self):
|
||||
if not self.send_future_map:
|
||||
|
@ -601,12 +789,11 @@ class MessageClient:
|
|||
|
||||
# pylint: disable=W1701
|
||||
def __del__(self):
|
||||
self.close()
|
||||
self.io_loop.spawn_callback(self.close)
|
||||
|
||||
# pylint: enable=W1701
|
||||
|
||||
@tornado.gen.coroutine
|
||||
def getstream(self, **kwargs):
|
||||
async def getstream(self, **kwargs):
|
||||
if self.source_ip or self.source_port:
|
||||
kwargs = {
|
||||
"source_ip": self.source_ip,
|
||||
|
@ -615,12 +802,22 @@ class MessageClient:
|
|||
stream = None
|
||||
while stream is None and (not self._closed and not self._closing):
|
||||
try:
|
||||
stream = yield self._tcp_client.connect(
|
||||
ip_bracket(self.host, strip=True),
|
||||
self.port,
|
||||
ssl_options=self.opts.get("ssl"),
|
||||
**kwargs
|
||||
)
|
||||
if self.host and self.port:
|
||||
# log.error("GET STREAM TCP %r %s %s %s", self.url, self.host, self.port, self.stack)
|
||||
stream = await self._tcp_client.connect(
|
||||
ip_bracket(self.host, strip=True),
|
||||
self.port,
|
||||
ssl_options=self.opts.get("ssl"),
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
log.error("GET STREAM IPC")
|
||||
sock_type = socket.AF_UNIX
|
||||
path = self.url.replace("ipc://", "")
|
||||
stream = tornado.iostream.IOStream(
|
||||
socket.socket(sock_type, socket.SOCK_STREAM)
|
||||
)
|
||||
await stream.connect(path)
|
||||
except Exception as exc: # pylint: disable=broad-except
|
||||
log.warning(
|
||||
"TCP Message Client encountered an exception while connecting to"
|
||||
|
@ -630,26 +827,34 @@ class MessageClient:
|
|||
exc,
|
||||
self.backoff,
|
||||
)
|
||||
yield tornado.gen.sleep(self.backoff)
|
||||
raise tornado.gen.Return(stream)
|
||||
await tornado.gen.sleep(self.backoff)
|
||||
return stream
|
||||
|
||||
@tornado.gen.coroutine
|
||||
def connect(self):
|
||||
async def connect(self):
|
||||
log.error("Connect %r %r", self, self._stream)
|
||||
import traceback
|
||||
|
||||
# stack = "".join(traceback.format_stack())
|
||||
# log.error(f"MessageClient Connect {name} {stack}")
|
||||
if self._stream is None:
|
||||
self._stream = yield self.getstream()
|
||||
log.error("Get stream")
|
||||
self._stream = await self.getstream()
|
||||
log.error("Got stream")
|
||||
if self._stream:
|
||||
if not self._stream_return_running:
|
||||
self.io_loop.spawn_callback(self._stream_return)
|
||||
self.task = asyncio.create_task(self._stream_return())
|
||||
# self.io_loop.spawn_callback(self._stream_return)
|
||||
if self.connect_callback:
|
||||
self.connect_callback(True)
|
||||
|
||||
@tornado.gen.coroutine
|
||||
def _stream_return(self):
|
||||
async def _stream_return(self):
|
||||
self._stream_return_running = True
|
||||
unpacker = salt.utils.msgpack.Unpacker()
|
||||
while not self._closing:
|
||||
try:
|
||||
wire_bytes = yield self._stream.read_bytes(4096, partial=True)
|
||||
log.error("Stream read bytes %r", self._stream.socket)
|
||||
wire_bytes = await self._stream.read_bytes(4096, partial=True)
|
||||
log.error("Stream got bytes")
|
||||
unpacker.feed(wire_bytes)
|
||||
for framed_msg in unpacker:
|
||||
framed_msg = salt.transport.frame.decode_embedded_strs(framed_msg)
|
||||
|
@ -663,6 +868,7 @@ class MessageClient:
|
|||
else:
|
||||
if self._on_recv is not None:
|
||||
self.io_loop.spawn_callback(self._on_recv, header, body)
|
||||
# await self._on_recv(header, body)
|
||||
else:
|
||||
log.error(
|
||||
"Got response for message_id %s that we are not"
|
||||
|
@ -670,7 +876,7 @@ class MessageClient:
|
|||
message_id,
|
||||
)
|
||||
except tornado.iostream.StreamClosedError as e:
|
||||
log.debug(
|
||||
log.error(
|
||||
"tcp stream to %s:%s closed, unable to recv",
|
||||
self.host,
|
||||
self.port,
|
||||
|
@ -687,7 +893,7 @@ class MessageClient:
|
|||
if stream:
|
||||
stream.close()
|
||||
unpacker = salt.utils.msgpack.Unpacker()
|
||||
yield self.connect()
|
||||
await self.connect()
|
||||
except TypeError:
|
||||
# This is an invalid transport
|
||||
if "detect_mode" in self.opts:
|
||||
|
@ -697,6 +903,9 @@ class MessageClient:
|
|||
)
|
||||
else:
|
||||
raise SaltClientError
|
||||
# except OSError:
|
||||
# log.error("OSERROR", exc_info=True)
|
||||
# raise
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
log.error("Exception parsing response", exc_info=True)
|
||||
for future in self.send_future_map.values():
|
||||
|
@ -711,7 +920,7 @@ class MessageClient:
|
|||
if stream:
|
||||
stream.close()
|
||||
unpacker = salt.utils.msgpack.Unpacker()
|
||||
yield self.connect()
|
||||
await self.connect()
|
||||
self._stream_return_running = False
|
||||
|
||||
def _message_id(self):
|
||||
|
@ -733,14 +942,7 @@ class MessageClient:
|
|||
"""
|
||||
Register a callback for received messages (that we didn't initiate)
|
||||
"""
|
||||
if callback is None:
|
||||
self._on_recv = callback
|
||||
else:
|
||||
|
||||
def wrap_recv(header, body):
|
||||
callback(body)
|
||||
|
||||
self._on_recv = wrap_recv
|
||||
self._on_recv = callback
|
||||
|
||||
def remove_message_timeout(self, message_id):
|
||||
if message_id not in self.send_timeout_map:
|
||||
|
@ -755,13 +957,19 @@ class MessageClient:
|
|||
if future is not None:
|
||||
future.set_exception(SaltReqTimeoutError("Message timed out"))
|
||||
|
||||
@tornado.gen.coroutine
|
||||
def send(self, msg, timeout=None, callback=None, raw=False):
|
||||
async def send(self, msg, timeout=None, callback=None, raw=False, reply=True):
|
||||
# log.error("stream send %r %r %r %r", self, self.url, self.port, reply)
|
||||
if self._closing:
|
||||
raise ClosingError()
|
||||
while not self._stream:
|
||||
# log.error("Wait stream %r %r", self, self._stream)
|
||||
await asyncio.sleep(0.03)
|
||||
message_id = self._message_id()
|
||||
header = {"mid": message_id}
|
||||
|
||||
# item = salt.transport.frame.frame_msg(msg, header=header)
|
||||
# await self._stream.write(item)
|
||||
# if reply:
|
||||
# return await self.recv(timeout=None)
|
||||
future = tornado.concurrent.Future()
|
||||
|
||||
if callback is not None:
|
||||
|
@ -782,18 +990,66 @@ class MessageClient:
|
|||
|
||||
item = salt.transport.frame.frame_msg(msg, header=header)
|
||||
|
||||
@tornado.gen.coroutine
|
||||
def _do_send():
|
||||
yield self.connect()
|
||||
async def _do_send():
|
||||
await self.connect()
|
||||
# If the _stream is None, we failed to connect.
|
||||
if self._stream:
|
||||
yield self._stream.write(item)
|
||||
await self._stream.write(item)
|
||||
|
||||
# Run send in a callback so we can wait on the future, in case we time
|
||||
# out before we are able to connect.
|
||||
self.io_loop.add_callback(_do_send)
|
||||
recv = yield future
|
||||
raise tornado.gen.Return(recv)
|
||||
recv = await future
|
||||
return recv
|
||||
|
||||
async def recv(self, timeout=None):
|
||||
try:
|
||||
await self._read_in_progress.acquire(timeout=0.00000001)
|
||||
except tornado.gen.TimeoutError:
|
||||
log.error("Timeout Error")
|
||||
return
|
||||
try:
|
||||
if timeout == 0:
|
||||
for msg in self.unpacker:
|
||||
log.error("RECV a")
|
||||
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
|
||||
return framed_msg["body"]
|
||||
poller = select.poll()
|
||||
poller.register(self._stream.socket, select.POLLIN)
|
||||
try:
|
||||
events = poller.poll(0)
|
||||
except TimeoutError:
|
||||
events = []
|
||||
if events:
|
||||
while True:
|
||||
byts = await self._stream.read_bytes(4096, partial=True)
|
||||
self.unpacker.feed(byts)
|
||||
for msg in self.unpacker:
|
||||
log.error("RECV b")
|
||||
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
|
||||
return framed_msg["body"]
|
||||
else:
|
||||
return
|
||||
elif timeout:
|
||||
log.error("RECV c")
|
||||
return await asyncio.wait_for(self.recv(), timeout=timeout)
|
||||
else:
|
||||
for msg in self.unpacker:
|
||||
log.error("RECV d")
|
||||
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
|
||||
return framed_msg["body"]
|
||||
while True:
|
||||
log.error("RECV e")
|
||||
byts = await self._stream.read_bytes(4096, partial=True)
|
||||
self.unpacker.feed(byts)
|
||||
for msg in self.unpacker:
|
||||
log.error("RECV e %r", msg)
|
||||
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
|
||||
log.error("RECV e %r", framed_msg)
|
||||
return framed_msg["body"]
|
||||
finally:
|
||||
self._read_in_progress.release()
|
||||
# await asyncio.sleep(.003)
|
||||
|
||||
|
||||
class Subscriber:
|
||||
|
@ -898,9 +1154,10 @@ class PubServer(tornado.tcpserver.TCPServer):
|
|||
self.io_loop.spawn_callback(self._stream_read, client)
|
||||
|
||||
# TODO: ACK the publish through IPC
|
||||
@tornado.gen.coroutine
|
||||
def publish_payload(self, package, topic_list=None):
|
||||
log.trace("TCP PubServer sending payload: %s \n\n %r", package, topic_list)
|
||||
async def publish_payload(self, package, topic_list=None):
|
||||
log.trace(
|
||||
"TCP PubServer sending payload: topic_list=%r %r", topic_list, package
|
||||
)
|
||||
payload = salt.transport.frame.frame_msg(package)
|
||||
to_remove = []
|
||||
if topic_list:
|
||||
|
@ -910,7 +1167,7 @@ class PubServer(tornado.tcpserver.TCPServer):
|
|||
if topic == client.id_:
|
||||
try:
|
||||
# Write the packed str
|
||||
yield client.stream.write(payload)
|
||||
await client.stream.write(payload)
|
||||
sent = True
|
||||
# self.io_loop.add_future(f, lambda f: True)
|
||||
except tornado.iostream.StreamClosedError:
|
||||
|
@ -921,7 +1178,7 @@ class PubServer(tornado.tcpserver.TCPServer):
|
|||
for client in self.clients:
|
||||
try:
|
||||
# Write the packed str
|
||||
yield client.stream.write(payload)
|
||||
await client.stream.write(payload)
|
||||
except tornado.iostream.StreamClosedError:
|
||||
to_remove.append(client)
|
||||
for client in to_remove:
|
||||
|
@ -942,10 +1199,32 @@ class TCPPublishServer(salt.transport.base.DaemonizedPublishServer):
|
|||
# TODO: opts!
|
||||
# Based on default used in tornado.netutil.bind_sockets()
|
||||
backlog = 128
|
||||
async_methods = [
|
||||
"publish",
|
||||
# "close",
|
||||
]
|
||||
close_methods = [
|
||||
"close",
|
||||
]
|
||||
|
||||
def __init__(self, opts):
|
||||
self.opts = opts
|
||||
self.pub_sock = None
|
||||
# Set up Salt IPC server
|
||||
if self.opts.get("ipc_mode", "") == "tcp":
|
||||
self.pull_uri = int(self.opts.get("tcp_master_publish_pull", 4514))
|
||||
else:
|
||||
self.pull_uri = os.path.join(self.opts["sock_dir"], "publish_pull.ipc")
|
||||
interface = self.opts.get("interface", "127.0.0.1")
|
||||
self.publish_port = self.opts.get("publish_port", 4560)
|
||||
self.pub_uri = f"tcp://{interface}:{self.publish_port}"
|
||||
log.error(
|
||||
"TCPPubServer %r %s %s %s",
|
||||
self,
|
||||
self.pull_uri,
|
||||
self.publish_port,
|
||||
self.pub_uri,
|
||||
)
|
||||
|
||||
@property
|
||||
def topic_support(self):
|
||||
|
@ -967,6 +1246,13 @@ class TCPPublishServer(salt.transport.base.DaemonizedPublishServer):
|
|||
Bind to the interface specified in the configuration file
|
||||
"""
|
||||
io_loop = tornado.ioloop.IOLoop()
|
||||
log.error(
|
||||
"TCPPubServer daemon %r %s %s %s",
|
||||
self,
|
||||
self.pull_uri,
|
||||
self.publish_port,
|
||||
self.pub_uri,
|
||||
)
|
||||
|
||||
# Spin up the publisher
|
||||
self.pub_server = pub_server = PubServer(
|
||||
|
@ -975,21 +1261,34 @@ class TCPPublishServer(salt.transport.base.DaemonizedPublishServer):
|
|||
presence_callback=presence_callback,
|
||||
remove_presence_callback=remove_presence_callback,
|
||||
)
|
||||
sock = _get_socket(self.opts)
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
_set_tcp_keepalive(sock, self.opts)
|
||||
sock.setblocking(0)
|
||||
sock.bind(_get_bind_addr(self.opts, "publish_port"))
|
||||
if self.pub_uri.startswith("ipc://"):
|
||||
pub_path = self.pub_uri.replace("ipc://", "")
|
||||
sock = tornado.netutil.bind_unix_socket(pub_path)
|
||||
else:
|
||||
sock = _get_socket(self.opts)
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
_set_tcp_keepalive(sock, self.opts)
|
||||
sock.setblocking(0)
|
||||
sock.bind(_get_bind_addr(self.opts, "publish_port"))
|
||||
sock.listen(self.backlog)
|
||||
# pub_server will take ownership of the socket
|
||||
pub_server.add_socket(sock)
|
||||
|
||||
# Set up Salt IPC server
|
||||
if self.opts.get("ipc_mode", "") == "tcp":
|
||||
pull_uri = int(self.opts.get("tcp_master_publish_pull", 4514))
|
||||
else:
|
||||
pull_uri = os.path.join(self.opts["sock_dir"], "publish_pull.ipc")
|
||||
# if self.opts.get("ipc_mode", "") == "tcp":
|
||||
# pull_uri = int(self.opts.get("tcp_master_publish_pull", 4514))
|
||||
# else:
|
||||
# pull_uri = os.path.join(self.opts["sock_dir"], "publish_pull.ipc")
|
||||
self.pub_server = pub_server
|
||||
if "ipc://" in self.pull_uri:
|
||||
pull_uri = pull_uri = self.pull_uri.replace("ipc://", "")
|
||||
log.error("WTF PULL URI %r", pull_uri)
|
||||
elif "tcp://" in self.pull_uri:
|
||||
log.error("Fallback to publish port %r", self.pull_uri)
|
||||
pull_uri = self.publish_port
|
||||
else:
|
||||
pull_uri = self.pull_uri
|
||||
|
||||
pull_sock = salt.transport.ipc.IPCMessageServer(
|
||||
pull_uri,
|
||||
io_loop=io_loop,
|
||||
|
@ -997,7 +1296,7 @@ class TCPPublishServer(salt.transport.base.DaemonizedPublishServer):
|
|||
)
|
||||
|
||||
# Securely create socket
|
||||
log.warning("Starting the Salt Puller on %s", pull_uri)
|
||||
log.warning("Starting the Salt Puller on %s", self.pull_uri)
|
||||
with salt.utils.files.set_umask(0o177):
|
||||
pull_sock.start()
|
||||
|
||||
|
@ -1019,24 +1318,44 @@ class TCPPublishServer(salt.transport.base.DaemonizedPublishServer):
|
|||
|
||||
@tornado.gen.coroutine
|
||||
def publish_payload(self, payload, *args):
|
||||
ret = yield self.pub_server.publish_payload(payload, *args)
|
||||
# log.error("Publish paylaod %r %r", payload, args)
|
||||
ret = yield self.pub_server.publish_payload(payload) # , *args)
|
||||
raise tornado.gen.Return(ret)
|
||||
|
||||
def publish(self, payload, **kwargs):
|
||||
def connect(self):
|
||||
path = self.pull_uri.replace("ipc://", "")
|
||||
log.error("Connect pusher %s", path)
|
||||
# self.pub_sock = salt.utils.asynchronous.SyncWrapper(
|
||||
# salt.transport.ipc.IPCMessageClient,
|
||||
# (path,),
|
||||
# loop_kwarg="io_loop",
|
||||
# )
|
||||
self.pub_sock = salt.utils.asynchronous.SyncWrapper(
|
||||
salt.transport.ipc.IPCMessageClient,
|
||||
(path,),
|
||||
loop_kwarg="io_loop",
|
||||
)
|
||||
# self.pub_sock = salt.transport.ipc.IPCMessageClient(path)
|
||||
self.pub_sock.connect()
|
||||
|
||||
async def publish(self, payload, **kwargs):
|
||||
"""
|
||||
Publish "load" to minions
|
||||
"""
|
||||
if self.opts.get("ipc_mode", "") == "tcp":
|
||||
pull_uri = int(self.opts.get("tcp_master_publish_pull", 4514))
|
||||
else:
|
||||
pull_uri = os.path.join(self.opts["sock_dir"], "publish_pull.ipc")
|
||||
if not self.pub_sock:
|
||||
self.pub_sock = salt.utils.asynchronous.SyncWrapper(
|
||||
salt.transport.ipc.IPCMessageClient,
|
||||
(pull_uri,),
|
||||
loop_kwarg="io_loop",
|
||||
)
|
||||
self.pub_sock.connect()
|
||||
self.connect()
|
||||
# if self.opts.get("ipc_mode", "") == "tcp":
|
||||
# pull_uri = int(self.opts.get("tcp_master_publish_pull", 4514))
|
||||
# else:
|
||||
# pull_uri = os.path.join(self.opts["sock_dir"], "publish_pull.ipc")
|
||||
# if not self.pub_sock:
|
||||
# self.pub_sock = salt.utils.asynchronous.SyncWrapper(
|
||||
# salt.transport.ipc.IPCMessageClient,
|
||||
# (pull_uri,),
|
||||
# loop_kwarg="io_loop",
|
||||
# )
|
||||
# self.pub_sock.connect()
|
||||
# await self.pub_sock.send(payload)
|
||||
self.pub_sock.send(payload)
|
||||
|
||||
def close(self):
|
||||
|
@ -1045,6 +1364,11 @@ class TCPPublishServer(salt.transport.base.DaemonizedPublishServer):
|
|||
self.pub_sock = None
|
||||
|
||||
|
||||
import tracemalloc
|
||||
|
||||
tracemalloc.start()
|
||||
|
||||
|
||||
class TCPReqClient(salt.transport.base.RequestClient):
|
||||
"""
|
||||
Tornado based TCP RequestClient
|
||||
|
@ -1070,14 +1394,17 @@ class TCPReqClient(salt.transport.base.RequestClient):
|
|||
source_port=opts.get("source_ret_port"),
|
||||
)
|
||||
|
||||
@tornado.gen.coroutine
|
||||
def connect(self):
|
||||
yield self.message_client.connect()
|
||||
async def connect(self):
|
||||
log.error("TCPReqClient Connect")
|
||||
await self.message_client.connect()
|
||||
log.error("TCPReqClient Connected")
|
||||
|
||||
@tornado.gen.coroutine
|
||||
def send(self, load, timeout=60):
|
||||
ret = yield self.message_client.send(load, timeout=timeout)
|
||||
raise tornado.gen.Return(ret)
|
||||
async def send(self, load, timeout=60):
|
||||
await self.connect()
|
||||
log.error("TCP Request %r", load)
|
||||
msg = await self.message_client.send(load, timeout=timeout)
|
||||
log.error("TCP Reply %r", msg)
|
||||
return msg
|
||||
|
||||
def close(self):
|
||||
self.message_client.close()
|
||||
|
|
|
@ -113,7 +113,7 @@ class PublishClient(salt.transport.base.PublishClient):
|
|||
"connect",
|
||||
"connect_uri",
|
||||
"recv",
|
||||
#"close",
|
||||
# "close",
|
||||
]
|
||||
close_methods = [
|
||||
"close",
|
||||
|
@ -139,7 +139,6 @@ class PublishClient(salt.transport.base.PublishClient):
|
|||
self.hexid = hashlib.sha1(salt.utils.stringutils.to_bytes(_id)).hexdigest()
|
||||
self._closing = False
|
||||
self.context = zmq.asyncio.Context()
|
||||
log.error("ZMQ Context creat %r", self)
|
||||
self._socket = self.context.socket(zmq.SUB)
|
||||
self._socket.setsockopt(zmq.LINGER, -1)
|
||||
if zmq_filtering:
|
||||
|
@ -211,6 +210,18 @@ class PublishClient(salt.transport.base.PublishClient):
|
|||
self.connect_called = False
|
||||
self.callbacks = {}
|
||||
|
||||
self.host = kwargs.get("host", None)
|
||||
self.port = kwargs.get("port", None)
|
||||
self.path = kwargs.get("path", None)
|
||||
self.source_ip = self.opts.get("source_ip")
|
||||
self.source_port = self.opts.get("source_publish_port")
|
||||
if self.host is None and self.port is None:
|
||||
if self.path is None:
|
||||
raise Exception("A host and port or a path must be provided")
|
||||
elif self.host and self.port:
|
||||
if self.path:
|
||||
raise Exception("A host and port or a path must be provided, not both")
|
||||
|
||||
async def close(self):
|
||||
if self._closing is True:
|
||||
return
|
||||
|
@ -223,9 +234,7 @@ class PublishClient(salt.transport.base.PublishClient):
|
|||
elif hasattr(self, "_socket"):
|
||||
self._socket.close(0)
|
||||
if hasattr(self, "context") and self.context.closed is False:
|
||||
log.error("ZMQ Context term %r", self)
|
||||
self.context.term()
|
||||
log.error("ZMQ Context after term %r", self)
|
||||
if self.callbacks:
|
||||
for cb in self.callbacks:
|
||||
running, task = self.callbacks[cb]
|
||||
|
@ -243,9 +252,7 @@ class PublishClient(salt.transport.base.PublishClient):
|
|||
elif hasattr(self, "_socket"):
|
||||
self._socket.close(0)
|
||||
if hasattr(self, "context") and self.context.closed is False:
|
||||
log.error("ZMQ Context term %r", self)
|
||||
self.context.term()
|
||||
log.error("ZMQ Context after term %r", self)
|
||||
|
||||
# pylint: enable=W1701
|
||||
def __enter__(self):
|
||||
|
@ -255,18 +262,26 @@ class PublishClient(salt.transport.base.PublishClient):
|
|||
self.close()
|
||||
|
||||
# TODO: this is the time to see if we are connected, maybe use the req channel to guess?
|
||||
async def connect(
|
||||
self, publish_port, connect_callback=None, disconnect_callback=None
|
||||
):
|
||||
async def connect(self, port=None, connect_callback=None, disconnect_callback=None):
|
||||
self.connect_called = True
|
||||
self.publish_port = publish_port
|
||||
self.uri = self.master_pub
|
||||
log.debug(
|
||||
"Connecting the Minion to the Master publish port, using the URI: %s",
|
||||
self.master_pub,
|
||||
)
|
||||
self._socket.connect(self.master_pub)
|
||||
# await connect_callback(True)
|
||||
if self.path:
|
||||
pub_uri = f"ipc://{self.path}"
|
||||
log.debug("Connecting the publisher client to: %s", pub_uri)
|
||||
self._socket.connect(pub_uri)
|
||||
else:
|
||||
# host = self.opts["master_ip"],
|
||||
if port is not None:
|
||||
self.port = port
|
||||
master_pub_uri = _get_master_uri(
|
||||
self.host, self.port, self.source_ip, self.source_port
|
||||
)
|
||||
log.debug(
|
||||
"Connecting the Minion to the Master publish port, using the URI: %s",
|
||||
master_pub_uri,
|
||||
)
|
||||
self._socket.connect(master_pub_uri)
|
||||
if connect_callback:
|
||||
await connect_callback(True)
|
||||
|
||||
async def connect_uri(self, uri, connect_callback=None, disconnect_callback=None):
|
||||
self.connect_called = True
|
||||
|
@ -275,19 +290,19 @@ class PublishClient(salt.transport.base.PublishClient):
|
|||
self.uri = uri
|
||||
self._socket.connect(uri)
|
||||
if connect_callback:
|
||||
connect_callback(True)
|
||||
await connect_callback(True)
|
||||
|
||||
@property
|
||||
def master_pub(self):
|
||||
"""
|
||||
Return the master publish port
|
||||
"""
|
||||
return _get_master_uri(
|
||||
self.opts["master_ip"],
|
||||
self.publish_port,
|
||||
source_ip=self.opts.get("source_ip"),
|
||||
source_port=self.opts.get("source_publish_port"),
|
||||
)
|
||||
# @property
|
||||
# def master_pub(self):
|
||||
# """
|
||||
# Return the master publish port
|
||||
# """
|
||||
# return _get_master_uri(
|
||||
# self.opts["master_ip"],
|
||||
# self.publish_port,
|
||||
# source_ip=self.opts.get("source_ip"),
|
||||
# source_port=self.opts.get("source_publish_port"),
|
||||
# )
|
||||
|
||||
def _decode_messages(self, messages):
|
||||
"""
|
||||
|
@ -359,27 +374,23 @@ class PublishClient(salt.transport.base.PublishClient):
|
|||
try:
|
||||
while running.is_set():
|
||||
try:
|
||||
log.error("Waiting for pyaload from %r", self.uri)
|
||||
msg = await self.recv(timeout=None)
|
||||
log.error("Got for pyaload from %r", self.uri)
|
||||
except zmq.error.ZMQError as exc:
|
||||
log.error("ZMQERROR, %s", exc)
|
||||
# We've disconnected just die
|
||||
break
|
||||
except Exception: # pylint: disable=broad-except
|
||||
log.error("WTF", exc_info=True)
|
||||
break
|
||||
# except Exception: # pylint: disable=broad-except
|
||||
# log.error("WTF", exc_info=True)
|
||||
# break
|
||||
if msg:
|
||||
try:
|
||||
log.error("Running callback for pyaload from %r", self.uri)
|
||||
await callback(msg)
|
||||
log.error("Finished callback for pyaload from %r", self.uri)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
log.error("Exception while running callback", exc_info=True)
|
||||
#log.debug("Callback done %r", callback)
|
||||
# log.debug("Callback done %r", callback)
|
||||
except Exception as exc: # pylint: disable=broad-except
|
||||
log.error("CONSUME Exception %s %s", self.uri, exc, exc_info=True)
|
||||
log.error("CONSUME ENDING %s", self.uri)
|
||||
log.error(
|
||||
"Exception while consuming%s %s", self.uri, exc, exc_info=True
|
||||
)
|
||||
|
||||
task = self.io_loop.spawn_callback(consume, running)
|
||||
self.callbacks[callback] = running, task
|
||||
|
@ -400,7 +411,6 @@ class RequestServer(salt.transport.base.DaemonizedRequestServer):
|
|||
"""
|
||||
self.__setup_signals()
|
||||
context = zmq.Context(self.opts["worker_threads"])
|
||||
log.error("ZMQ Context create %r", self)
|
||||
# Prepare the zeromq sockets
|
||||
self.uri = "tcp://{interface}:{ret_port}".format(**self.opts)
|
||||
self.clients = context.socket(zmq.ROUTER)
|
||||
|
@ -448,9 +458,7 @@ class RequestServer(salt.transport.base.DaemonizedRequestServer):
|
|||
raise
|
||||
except (KeyboardInterrupt, SystemExit):
|
||||
break
|
||||
log.error("ZMQ Context term %r", self)
|
||||
context.term()
|
||||
log.error("ZMQ Context after term %r", self)
|
||||
|
||||
def close(self):
|
||||
"""
|
||||
|
@ -476,9 +484,7 @@ class RequestServer(salt.transport.base.DaemonizedRequestServer):
|
|||
if hasattr(self, "_socket") and self._socket.closed is False:
|
||||
self._socket.close()
|
||||
if hasattr(self, "context") and self.context.closed is False:
|
||||
log.error("ZMQ Context term %r", self)
|
||||
self.context.term()
|
||||
log.error("ZMQ Context after term %r", self)
|
||||
|
||||
def pre_fork(self, process_manager):
|
||||
"""
|
||||
|
@ -596,6 +602,7 @@ def _set_tcp_keepalive(zmq_socket, opts):
|
|||
if "tcp_keepalive_intvl" in opts:
|
||||
zmq_socket.setsockopt(zmq.TCP_KEEPALIVE_INTVL, opts["tcp_keepalive_intvl"])
|
||||
|
||||
|
||||
ctx = zmq.asyncio.Context()
|
||||
|
||||
# TODO: unit tests!
|
||||
|
@ -627,13 +634,13 @@ class AsyncReqMessageClient:
|
|||
self.io_loop = io_loop
|
||||
|
||||
self.context = zmq.asyncio.Context()
|
||||
log.error("ZMQ Context create %r", self)
|
||||
|
||||
self.send_queue = []
|
||||
# mapping of message -> future
|
||||
self.send_future_map = {}
|
||||
|
||||
self._closing = False
|
||||
self.socket = None
|
||||
|
||||
async def connect(self):
|
||||
# wire up sockets
|
||||
|
@ -650,15 +657,14 @@ class AsyncReqMessageClient:
|
|||
return
|
||||
else:
|
||||
self._closing = True
|
||||
self.socket.close()
|
||||
if self.socket:
|
||||
self.socket.close()
|
||||
if self.context.closed is False:
|
||||
# This hangs if closing the stream causes an import error
|
||||
log.error("ZMQ Context term %r", self)
|
||||
self.context.term()
|
||||
log.error("ZMQ Context after term %r", self)
|
||||
|
||||
def _init_socket(self):
|
||||
if hasattr(self, "socket"):
|
||||
if self.socket is not None:
|
||||
self.socket.close() # pylint: disable=E0203
|
||||
del self.socket
|
||||
|
||||
|
@ -693,6 +699,8 @@ class AsyncReqMessageClient:
|
|||
future.set_exception(SaltReqTimeoutError("Message timed out"))
|
||||
|
||||
async def _send_recv(self, message):
|
||||
if not self.socket:
|
||||
await self.connect()
|
||||
message = salt.payload.dumps(message)
|
||||
await self.socket.send(message)
|
||||
ret = await self.socket.recv()
|
||||
|
@ -793,7 +801,7 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer):
|
|||
# _sock_data = threading.local()
|
||||
async_methods = [
|
||||
"publish",
|
||||
#"close",
|
||||
# "close",
|
||||
]
|
||||
close_methods = [
|
||||
"close",
|
||||
|
@ -814,10 +822,10 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer):
|
|||
self.pub_uri = f"tcp://{interface}:{publish_port}"
|
||||
self.ctx = None
|
||||
self.sock = None
|
||||
self.deamon_context = None
|
||||
self.deamon_pub_sock = None
|
||||
self.deamon_pull_sock = None
|
||||
self.deamon_monitor = None
|
||||
self.daemon_context = None
|
||||
self.daemon_pub_sock = None
|
||||
self.daemon_pull_sock = None
|
||||
self.daemon_monitor = None
|
||||
|
||||
def __repr__(self):
|
||||
return f"<PublishServer pub_uri={self.pub_uri} pull_uri={self.pull_uri} at {hex(id(self))}>"
|
||||
|
@ -837,10 +845,9 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer):
|
|||
try:
|
||||
ioloop.start()
|
||||
finally:
|
||||
self.close()
|
||||
self.daemon_context.term()
|
||||
|
||||
def _get_sockets(self, context, ioloop):
|
||||
log.error("ZMQ Context create %r", self)
|
||||
pub_sock = context.socket(zmq.PUB)
|
||||
monitor = ZeroMQSocketMonitor(pub_sock)
|
||||
monitor.start_io_loop(ioloop)
|
||||
|
@ -893,12 +900,14 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer):
|
|||
ioloop = tornado.ioloop.IOLoop.current()
|
||||
ioloop.asyncio_loop.set_debug(True)
|
||||
self.daemon_context = zmq.asyncio.Context()
|
||||
self.daemon_pull_sock, self.daemon_pub_sock, self.deamon_monitor = self._get_sockets(self.daemon_context, ioloop)
|
||||
(
|
||||
self.daemon_pull_sock,
|
||||
self.daemon_pub_sock,
|
||||
self.daemon_monitor,
|
||||
) = self._get_sockets(self.daemon_context, ioloop)
|
||||
while True:
|
||||
try:
|
||||
log.error("Publisher wait package %s", self.pull_uri)
|
||||
package = await self.daemon_pull_sock.recv()
|
||||
log.error("Publisher got package %s %r", self.pull_uri, package)
|
||||
# payload = salt.payload.loads(package)
|
||||
await publish_payload(package)
|
||||
except Exception as exc: # pylint: disable=broad-except
|
||||
|
@ -907,8 +916,8 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer):
|
|||
)
|
||||
|
||||
async def publish_payload(self, payload, topic_list=None):
|
||||
log.error(f"Publish payload %s %r", self.pub_uri, payload)
|
||||
try:
|
||||
log.trace("Publish payload %r", payload)
|
||||
# payload = salt.payload.dumps(payload)
|
||||
if self.opts["zmq_filtering"]:
|
||||
if topic_list:
|
||||
|
@ -967,7 +976,7 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer):
|
|||
already exists "pub_close" is called before creating and connecting a
|
||||
new socket.
|
||||
"""
|
||||
log.debug("Connecting to pub server: %s", self.pull_uri)
|
||||
log.error("Connecting to pub server: %s", self.pull_uri)
|
||||
self.ctx = zmq.asyncio.Context()
|
||||
self.sock = self.ctx.socket(zmq.PUSH)
|
||||
self.sock.setsockopt(zmq.LINGER, 300)
|
||||
|
@ -982,23 +991,19 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer):
|
|||
if self.sock is not None:
|
||||
sock = self.sock
|
||||
self.sock = None
|
||||
log.error("Socket close %r", self)
|
||||
sock.close()
|
||||
log.error("Socket closed %r", self)
|
||||
if self.ctx and self.ctx.closed is False:
|
||||
ctx = self.ctx
|
||||
self.ctx = None
|
||||
log.error("Context term %r", self)
|
||||
ctx.term()
|
||||
log.error("After context term %r", self)
|
||||
if self.deamon_pub_sock:
|
||||
self.deamon_pub_sock.close()
|
||||
if self.deamon_pull_sock:
|
||||
self.deamon_pull_sock.close()
|
||||
if self.daemon_pub_sock:
|
||||
self.daemon_pub_sock.close()
|
||||
if self.daemon_pull_sock:
|
||||
self.daemon_pull_sock.close()
|
||||
if self.daemon_monitor:
|
||||
self.daemon_monitor.close()
|
||||
if self.deamon_context:
|
||||
self.deamon_context.term()
|
||||
self.daemon_monitor.stop()
|
||||
if self.daemon_context:
|
||||
self.daemon_context.term()
|
||||
|
||||
async def publish(self, payload, **kwargs):
|
||||
"""
|
||||
|
@ -1009,7 +1014,6 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer):
|
|||
"""
|
||||
if not self.sock:
|
||||
self.connect()
|
||||
log.error("%r send %r", self, payload)
|
||||
await self.sock.send(payload)
|
||||
|
||||
@property
|
||||
|
@ -1040,7 +1044,7 @@ class RequestClient(salt.transport.base.RequestClient):
|
|||
await self.message_client.connect()
|
||||
|
||||
async def send(self, load, timeout=60):
|
||||
self.connect()
|
||||
await self.connect()
|
||||
return await self.message_client.send(load, timeout=timeout)
|
||||
|
||||
def close(self):
|
||||
|
|
|
@ -59,7 +59,9 @@ class SyncWrapper:
|
|||
loop_kwarg=None,
|
||||
):
|
||||
self.asyncio_loop = asyncio.new_event_loop()
|
||||
self.io_loop = tornado.ioloop.IOLoop(asyncio_loop=self.asyncio_loop)
|
||||
self.io_loop = tornado.ioloop.IOLoop(
|
||||
asyncio_loop=self.asyncio_loop, make_current=False
|
||||
)
|
||||
if args is None:
|
||||
args = []
|
||||
if kwargs is None:
|
||||
|
@ -108,12 +110,16 @@ class SyncWrapper:
|
|||
log.error("No async method %s on object %r", method, self.obj)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
log.exception("Exception encountered while running stop method")
|
||||
# thread = threading.Thread(target=self._run_loop_final, args=(self.asyncio_loop,))
|
||||
# thread.start()
|
||||
# thread.join()
|
||||
io_loop = self.io_loop
|
||||
io_loop.stop()
|
||||
try:
|
||||
io_loop.close(all_fds=True)
|
||||
except KeyError:
|
||||
pass
|
||||
self.asyncio_loop.close()
|
||||
|
||||
def __getattr__(self, key):
|
||||
if key in self._async_methods:
|
||||
|
@ -137,6 +143,18 @@ class SyncWrapper:
|
|||
|
||||
return wrap
|
||||
|
||||
def _run_loop_final(self, asyncio_loop):
|
||||
asyncio.set_event_loop(asyncio_loop)
|
||||
io_loop = tornado.ioloop.IOLoop.current()
|
||||
try:
|
||||
|
||||
async def noop():
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
result = io_loop.run_sync(lambda: noop())
|
||||
except Exception: # pylint: disable=broad-except
|
||||
log.error("Error on last loop run")
|
||||
|
||||
def _target(self, key, args, kwargs, results, asyncio_loop):
|
||||
asyncio.set_event_loop(asyncio_loop)
|
||||
io_loop = tornado.ioloop.IOLoop.current()
|
||||
|
@ -149,7 +167,16 @@ class SyncWrapper:
|
|||
results.append(sys.exc_info())
|
||||
|
||||
def __enter__(self):
|
||||
if hasattr(self.obj, "__aenter__"):
|
||||
ret = self._wrap("__aenter__")()
|
||||
if ret == self.obj:
|
||||
return self
|
||||
else:
|
||||
return ret
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, tb):
|
||||
self.close()
|
||||
if hasattr(self.obj, "__aexit__"):
|
||||
return self._wrap("__aexit__")
|
||||
else:
|
||||
self.close()
|
||||
|
|
|
@ -230,7 +230,7 @@ class SaltEvent:
|
|||
self.io_loop = io_loop
|
||||
self._run_io_loop_sync = False
|
||||
else:
|
||||
self.io_loop = tornado.ioloop.IOLoop()
|
||||
# self.io_loop = tornado.ioloop.IOLoop()
|
||||
self._run_io_loop_sync = True
|
||||
self.cpub = False
|
||||
self.cpush = False
|
||||
|
@ -353,48 +353,43 @@ class SaltEvent:
|
|||
if self.cpub:
|
||||
return True
|
||||
|
||||
kwargs = {"io_loop": self.io_loop}
|
||||
if isinstance(self.puburi, int):
|
||||
kwargs.update(host="127.0.0.1", port=self.puburi)
|
||||
else:
|
||||
kwargs.update(path=self.puburi)
|
||||
if self._run_io_loop_sync:
|
||||
with salt.utils.asynchronous.current_ioloop(self.io_loop):
|
||||
if self.subscriber is None:
|
||||
# self.subscriber = salt.utils.asynchronous.SyncWrapper(
|
||||
# salt.transport.ipc.IPCMessageSubscriber,
|
||||
# args=(self.puburi,),
|
||||
# kwargs={"io_loop": self.io_loop},
|
||||
# loop_kwarg="io_loop",
|
||||
# )
|
||||
# self.subscriber = salt.transport.publish_client(self.opts)
|
||||
self.subscriber = salt.utils.asynchronous.SyncWrapper(
|
||||
salt.transport.publish_client,
|
||||
args=(self.opts,),
|
||||
kwargs={"io_loop": self.io_loop},
|
||||
loop_kwarg="io_loop",
|
||||
)
|
||||
try:
|
||||
# self.subscriber.connect(timeout=timeout)
|
||||
puburi = "ipc://{}".format(self.puburi)
|
||||
self.subscriber.connect_uri(puburi)
|
||||
self.cpub = True
|
||||
except tornado.iostream.StreamClosedError:
|
||||
log.error("Encountered StreamClosedException")
|
||||
except OSError as exc:
|
||||
if exc.errno != errno.ENOENT:
|
||||
raise
|
||||
log.error("Error opening stream, file does not exist")
|
||||
except Exception as exc: # pylint: disable=broad-except
|
||||
log.info(
|
||||
"An exception occurred connecting publisher: %s",
|
||||
exc,
|
||||
exc_info_on_loglevel=logging.DEBUG,
|
||||
)
|
||||
if self.subscriber is None:
|
||||
self.subscriber = salt.utils.asynchronous.SyncWrapper(
|
||||
salt.transport.publish_client,
|
||||
args=(self.opts,),
|
||||
kwargs=kwargs,
|
||||
loop_kwarg="io_loop",
|
||||
)
|
||||
try:
|
||||
# self.subscriber.connect(timeout=timeout)
|
||||
log.debug("Event connect subscriber %r", self.puburi)
|
||||
self.subscriber.connect()
|
||||
self.cpub = True
|
||||
except tornado.iostream.StreamClosedError:
|
||||
log.error("Encountered StreamClosedException")
|
||||
except OSError as exc:
|
||||
if exc.errno != errno.ENOENT:
|
||||
raise
|
||||
log.error("Error opening stream, file does not exist")
|
||||
except Exception as exc: # pylint: disable=broad-except
|
||||
log.info(
|
||||
"An exception occurred connecting publisher: %s",
|
||||
exc,
|
||||
exc_info_on_loglevel=logging.DEBUG,
|
||||
)
|
||||
else:
|
||||
if self.subscriber is None:
|
||||
if "master_ip" not in self.opts:
|
||||
self.opts["master_ip"] = ""
|
||||
self.subscriber = salt.transport.publish_client(self.opts, self.io_loop)
|
||||
puburi = "ipc://{}".format(self.puburi)
|
||||
# self.io_loop.run_sync(self.subscriber.connect_uri, puburi)
|
||||
self.io_loop.spawn_callback(self.subscriber.connect_uri, puburi)
|
||||
log.error("WTF")
|
||||
self.subscriber = salt.transport.publish_client(self.opts, **kwargs)
|
||||
log.debug("Event connect subscriber %r", self.puburi)
|
||||
self.io_loop.spawn_callback(self.subscriber.connect)
|
||||
# self.subscriber = salt.transport.ipc.IPCMessageSubscriber(
|
||||
# self.puburi, io_loop=self.io_loop
|
||||
# )
|
||||
|
@ -410,9 +405,9 @@ class SaltEvent:
|
|||
"""
|
||||
if not self.cpub:
|
||||
return
|
||||
#if isinstance(self.subscriber, salt.utils.asynchronous.SyncWrapper):
|
||||
# if isinstance(self.subscriber, salt.utils.asynchronous.SyncWrapper):
|
||||
# self.subscriber.close()
|
||||
#else:
|
||||
# else:
|
||||
# asyncio.create_task(self.subscriber.close())
|
||||
self.subscriber.close()
|
||||
self.subscriber = None
|
||||
|
@ -433,6 +428,7 @@ class SaltEvent:
|
|||
salt.transport.publish_server,
|
||||
args=(self.opts,),
|
||||
)
|
||||
log.error("PUSHER %r %r", self, self.pusher.io_loop.asyncio_loop)
|
||||
self.pusher.obj.pub_uri = "ipc://{}".format(self.puburi)
|
||||
self.pusher.obj.pull_uri = "ipc://{}".format(self.pulluri)
|
||||
# self.pusher = salt.utils.asynchronous.SyncWrapper(
|
||||
|
@ -716,7 +712,6 @@ class SaltEvent:
|
|||
if not self.cpub:
|
||||
if not self.connect_pub():
|
||||
return None
|
||||
log.error("GET EVENT NOBLOCK %r", self.subscriber)
|
||||
raw = self.subscriber.recv(timeout=0)
|
||||
if raw is None:
|
||||
return None
|
||||
|
@ -733,7 +728,6 @@ class SaltEvent:
|
|||
if not self.cpub:
|
||||
if not self.connect_pub():
|
||||
return None
|
||||
log.error("GET EVENT BLOCK %r", self.subscriber)
|
||||
raw = self.subscriber.recv(timeout=None)
|
||||
if raw is None:
|
||||
return None
|
||||
|
@ -792,7 +786,7 @@ class SaltEvent:
|
|||
is_msgpacked=True,
|
||||
use_bin_type=True,
|
||||
)
|
||||
log.error(
|
||||
log.debug(
|
||||
"Sending event(fire_event_async): tag = %s; data = %s %r",
|
||||
tag,
|
||||
data,
|
||||
|
@ -850,7 +844,7 @@ class SaltEvent:
|
|||
is_msgpacked=True,
|
||||
use_bin_type=True,
|
||||
)
|
||||
log.error(
|
||||
log.debug(
|
||||
"Sending event(fire_event): tag = %s; data = %s %s",
|
||||
tag,
|
||||
data,
|
||||
|
@ -865,7 +859,6 @@ class SaltEvent:
|
|||
)
|
||||
msg = salt.utils.stringutils.to_bytes(event, "utf-8")
|
||||
if self._run_io_loop_sync:
|
||||
log.error("FIRE EVENT A %r %r", msg, self.pusher.obj)
|
||||
try:
|
||||
# self.pusher.send(msg)
|
||||
self.pusher.publish(msg)
|
||||
|
@ -877,7 +870,6 @@ class SaltEvent:
|
|||
)
|
||||
raise
|
||||
else:
|
||||
log.error("FIRE EVENT B %r %r", msg, self.pusher)
|
||||
asyncio.create_task(self.pusher.publish(msg))
|
||||
# self.io_loop.spawn_callback(self.pusher.send, msg)
|
||||
return True
|
||||
|
@ -897,8 +889,8 @@ class SaltEvent:
|
|||
self.close_pub()
|
||||
if self.pusher is not None:
|
||||
self.close_pull()
|
||||
if self._run_io_loop_sync and not self.keep_loop:
|
||||
self.io_loop.close()
|
||||
# if self._run_io_loop_sync and not self.keep_loop:
|
||||
# self.io_loop.close()
|
||||
|
||||
def _fire_ret_load_specific_fun(self, load, fun_index=0):
|
||||
"""
|
||||
|
@ -989,12 +981,11 @@ class SaltEvent:
|
|||
Invoke the event_handler callback each time an event arrives.
|
||||
"""
|
||||
assert not self._run_io_loop_sync
|
||||
|
||||
if not self.cpub:
|
||||
self.connect_pub()
|
||||
# This will handle reconnects
|
||||
# return self.subscriber.read_async(event_handler)
|
||||
self.subscriber.on_recv(event_handler)
|
||||
self.io_loop.spawn_callback(self.subscriber.on_recv, event_handler)
|
||||
|
||||
# pylint: disable=W1701
|
||||
def __del__(self):
|
||||
|
|
|
@ -30,6 +30,10 @@ from tests.support.runtests import RUNTIME_VARS
|
|||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
import tracemalloc
|
||||
|
||||
tracemalloc.start()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def salt_auth_account_1_factory():
|
||||
|
@ -607,8 +611,14 @@ def pytest_pyfunc_call(pyfuncitem):
|
|||
|
||||
__tracebackhide__ = True
|
||||
|
||||
loop.run_sync(
|
||||
CoroTestFunction(pyfuncitem.obj, testargs), timeout=get_test_timeout(pyfuncitem)
|
||||
# loop.run_sync(
|
||||
# CoroTestFunction(pyfuncitem.obj, testargs), timeout=get_test_timeout(pyfuncitem)
|
||||
# )
|
||||
loop.asyncio_loop.run_until_complete(
|
||||
asyncio.wait_for(
|
||||
CoroTestFunction(pyfuncitem.obj, testargs)(),
|
||||
timeout=get_test_timeout(pyfuncitem),
|
||||
)
|
||||
)
|
||||
return True
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import logging
|
||||
import time
|
||||
import tracemalloc
|
||||
from concurrent.futures.thread import ThreadPoolExecutor
|
||||
|
||||
import pytest
|
||||
|
@ -9,6 +10,8 @@ import salt.channel.server
|
|||
import salt.master
|
||||
from tests.support.pytest.transport import PubServerChannelProcess
|
||||
|
||||
tracemalloc.start()
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -61,6 +64,8 @@ def test_publish_to_pubserv_ipc(salt_master, salt_minion, transport):
|
|||
|
||||
ZMQ's ipc transport not supported on Windows
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
opts = dict(
|
||||
salt_master.config.copy(),
|
||||
ipc_mode="ipc",
|
||||
|
@ -91,28 +96,51 @@ def test_issue_36469_tcp(salt_master, salt_minion, transport):
|
|||
"""
|
||||
if transport == "tcp":
|
||||
pytest.skip("Test not applicable to the ZeroMQ transport.")
|
||||
import asyncio
|
||||
|
||||
def _send_small(opts, sid, num=10):
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
server_channel = salt.channel.server.PubServerChannel.factory(opts)
|
||||
for idx in range(num):
|
||||
load = {"tgt_type": "glob", "tgt": "*", "jid": "{}-s{}".format(sid, idx)}
|
||||
server_channel.publish(load)
|
||||
|
||||
async def send():
|
||||
await asyncio.sleep(0.3)
|
||||
for idx in range(num):
|
||||
load = {
|
||||
"tgt_type": "glob",
|
||||
"tgt": "*",
|
||||
"jid": "{}-s{}".format(sid, idx),
|
||||
}
|
||||
await server_channel.publish(load)
|
||||
|
||||
asyncio.run(send())
|
||||
time.sleep(0.3)
|
||||
time.sleep(3)
|
||||
server_channel.close_pub()
|
||||
server_channel.close()
|
||||
loop.close()
|
||||
asyncio.set_event_loop(None)
|
||||
|
||||
def _send_large(opts, sid, num=10, size=250000 * 3):
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
server_channel = salt.channel.server.PubServerChannel.factory(opts)
|
||||
for idx in range(num):
|
||||
load = {
|
||||
"tgt_type": "glob",
|
||||
"tgt": "*",
|
||||
"jid": "{}-l{}".format(sid, idx),
|
||||
"xdata": "0" * size,
|
||||
}
|
||||
server_channel.publish(load)
|
||||
|
||||
async def send():
|
||||
await asyncio.sleep(0.3)
|
||||
for idx in range(num):
|
||||
load = {
|
||||
"tgt_type": "glob",
|
||||
"tgt": "*",
|
||||
"jid": "{}-l{}".format(sid, idx),
|
||||
"xdata": "0" * size,
|
||||
}
|
||||
await server_channel.publish(load)
|
||||
|
||||
asyncio.run(send())
|
||||
time.sleep(0.3)
|
||||
server_channel.close_pub()
|
||||
server_channel.close()
|
||||
loop.close()
|
||||
asyncio.set_event_loop(None)
|
||||
|
||||
opts = dict(salt_master.config.copy(), ipc_mode="tcp", pub_hwm=0)
|
||||
send_num = 10 * 4
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import asyncio
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
|
@ -26,16 +27,26 @@ def server(config):
|
|||
disconnect = False
|
||||
|
||||
async def handle_stream(self, stream, address):
|
||||
while self.disconnect is False:
|
||||
for msg in self.send[:]:
|
||||
msg = self.send.pop(0)
|
||||
try:
|
||||
await stream.write(msg)
|
||||
except tornado.iostream.StreamClosedError:
|
||||
break
|
||||
else:
|
||||
await tornado.gen.sleep(1)
|
||||
stream.close()
|
||||
try:
|
||||
log.error("Got stream")
|
||||
while self.disconnect is False:
|
||||
for msg in self.send[:]:
|
||||
msg = self.send.pop(0)
|
||||
try:
|
||||
log.error("Write %r", msg)
|
||||
await stream.write(msg)
|
||||
except tornado.iostream.StreamClosedError:
|
||||
log.error("Stream Closed Error From Test Server")
|
||||
break
|
||||
else:
|
||||
log.error("SLEEP")
|
||||
await asyncio.sleep(1)
|
||||
log.error("Close stream")
|
||||
log.error("After close stream")
|
||||
except:
|
||||
log.error("WTFSON", exc_info=True)
|
||||
finally:
|
||||
stream.close()
|
||||
|
||||
server = TestServer()
|
||||
try:
|
||||
|
@ -47,14 +58,16 @@ def server(config):
|
|||
|
||||
@pytest.fixture
|
||||
def client(io_loop, config):
|
||||
client = salt.transport.tcp.TCPPubClient(config.copy(), io_loop)
|
||||
client = salt.transport.tcp.TCPPubClient(
|
||||
config.copy(), io_loop, host=config["master_ip"], port=config["publish_port"]
|
||||
)
|
||||
try:
|
||||
yield client
|
||||
finally:
|
||||
client.close()
|
||||
|
||||
|
||||
async def test_message_client_reconnect(io_loop, config, client, server):
|
||||
async def test_message_client_reconnect(config, client, server):
|
||||
"""
|
||||
Verify that the tcp MessageClient class re-sets it's unpacker after a
|
||||
stream disconnect.
|
||||
|
@ -69,7 +82,7 @@ async def test_message_client_reconnect(io_loop, config, client, server):
|
|||
received.append(msg)
|
||||
|
||||
client.on_recv(handler)
|
||||
|
||||
await asyncio.sleep(0.03)
|
||||
# Prepare two packed messages
|
||||
msg = salt.utils.msgpack.dumps({"test": "test1"})
|
||||
pmsg = salt.utils.msgpack.dumps({"head": {}, "body": msg})
|
||||
|
@ -78,24 +91,40 @@ async def test_message_client_reconnect(io_loop, config, client, server):
|
|||
|
||||
# Send one full and one partial msg to the client.
|
||||
partial = pmsg[:40]
|
||||
log.error("Send partial %r", partial)
|
||||
server.send.append(partial)
|
||||
|
||||
while not received:
|
||||
await tornado.gen.sleep(1)
|
||||
log.error("wait received")
|
||||
await asyncio.sleep(1)
|
||||
log.error("assert received")
|
||||
assert received == [msg]
|
||||
# log.error("sleep")
|
||||
# await asyncio.sleep(1)
|
||||
|
||||
# The message client has unpacked one msg and there is a partial msg left in
|
||||
# the unpacker. Closing the stream now leaves the unpacker in a bad state
|
||||
# since the rest of the partil message will never be received.
|
||||
log.error("disconnect")
|
||||
server.disconnect = True
|
||||
await tornado.gen.sleep(1)
|
||||
log.error("sleep")
|
||||
await asyncio.sleep(1)
|
||||
log.error("after sleep")
|
||||
log.error("disconnect false")
|
||||
server.disconnect = False
|
||||
log.error("sleep")
|
||||
await asyncio.sleep(1)
|
||||
log.error("after sleep")
|
||||
log.error("Disconnect False")
|
||||
received = []
|
||||
|
||||
# Prior to the fix for #60831, the unpacker would be left in a broken state
|
||||
# resulting in either a TypeError or BufferFull error from msgpack. The
|
||||
# rest of this test would fail.
|
||||
log.error("Send pmsg %r", pmsg)
|
||||
server.send.append(pmsg)
|
||||
while not received:
|
||||
await tornado.gen.sleep(1)
|
||||
assert received == [msg, msg]
|
||||
server.disconnect = True
|
||||
await tornado.gen.sleep(1)
|
||||
|
|
|
@ -22,6 +22,7 @@ def test_zeromq_filtering(salt_master, salt_minion):
|
|||
"""
|
||||
Test sending messages to publisher using UDP with zeromq_filtering enabled
|
||||
"""
|
||||
log.error("TEST START")
|
||||
opts = dict(
|
||||
salt_master.config.copy(),
|
||||
ipc_mode="ipc",
|
||||
|
@ -31,6 +32,7 @@ def test_zeromq_filtering(salt_master, salt_minion):
|
|||
)
|
||||
send_num = 1
|
||||
expect = []
|
||||
log.error("TEST START 2 ")
|
||||
with patch(
|
||||
"salt.utils.minions.CkMinions.check_minions",
|
||||
MagicMock(
|
||||
|
@ -41,9 +43,12 @@ def test_zeromq_filtering(salt_master, salt_minion):
|
|||
}
|
||||
),
|
||||
):
|
||||
# log.error("Get Server Channel")
|
||||
log.error("TEST START 3")
|
||||
with PubServerChannelProcess(
|
||||
opts, salt_minion.config.copy(), zmq_filtering=True
|
||||
) as server_channel:
|
||||
log.error("pub chan started")
|
||||
expect.append(send_num)
|
||||
load = {"tgt_type": "glob", "tgt": "*", "jid": send_num}
|
||||
server_channel.publish(load)
|
||||
|
|
1282
tests/pytests/unit/channel/test_request_channel.py
Normal file
1282
tests/pytests/unit/channel/test_request_channel.py
Normal file
File diff suppressed because it is too large
Load diff
249
tests/pytests/unit/transport/test_publish_client.py
Normal file
249
tests/pytests/unit/transport/test_publish_client.py
Normal file
|
@ -0,0 +1,249 @@
|
|||
"""
|
||||
:codeauthor: Thomas Jackson <jacksontj.89@gmail.com>
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
import tornado.ioloop
|
||||
|
||||
import salt.crypt
|
||||
import salt.transport.tcp
|
||||
import salt.transport.zeromq
|
||||
import salt.utils.stringutils
|
||||
from tests.support.mock import MagicMock, patch
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
pytestmark = [
|
||||
pytest.mark.core_test,
|
||||
]
|
||||
|
||||
|
||||
def transport_ids(value):
|
||||
return "Transport({})".format(value)
|
||||
|
||||
|
||||
@pytest.fixture(params=("zeromq", "tcp"), ids=transport_ids)
|
||||
def transport(request):
|
||||
return request.param
|
||||
|
||||
|
||||
async def test_zeromq_async_pub_channel_publish_port(temp_salt_master):
|
||||
"""
|
||||
test when connecting that we use the publish_port set in opts when its not 4506
|
||||
"""
|
||||
opts = dict(
|
||||
temp_salt_master.config.copy(),
|
||||
ipc_mode="ipc",
|
||||
pub_hwm=0,
|
||||
recon_randomize=False,
|
||||
publish_port=455505,
|
||||
recon_default=1,
|
||||
recon_max=2,
|
||||
master_ip="127.0.0.1",
|
||||
acceptance_wait_time=5,
|
||||
acceptance_wait_time_max=5,
|
||||
sign_pub_messages=False,
|
||||
)
|
||||
opts["master_uri"] = "tcp://{interface}:{publish_port}".format(**opts)
|
||||
ioloop = tornado.ioloop.IOLoop()
|
||||
# Transport will connect to port given to connect method.
|
||||
transport = salt.transport.zeromq.PublishClient(
|
||||
opts, ioloop, host=opts["master_ip"], port=121212
|
||||
)
|
||||
with transport:
|
||||
patch_socket = MagicMock(return_value=True)
|
||||
patch_auth = MagicMock(return_value=True)
|
||||
with patch.object(transport, "_socket", patch_socket):
|
||||
await transport.connect(opts["publish_port"])
|
||||
assert str(opts["publish_port"]) in patch_socket.mock_calls[0][1][0]
|
||||
|
||||
|
||||
def test_zeromq_async_pub_channel_filtering_decode_message_no_match(
|
||||
temp_salt_master,
|
||||
):
|
||||
"""
|
||||
test zeromq PublishClient _decode_messages when
|
||||
zmq_filtering enabled and minion does not match
|
||||
"""
|
||||
message = [
|
||||
b"4f26aeafdb2367620a393c973eddbe8f8b846eb",
|
||||
b"\x82\xa3enc\xa3aes\xa4load\xda\x00`\xeeR\xcf"
|
||||
b"\x0eaI#V\x17if\xcf\xae\x05\xa7\xb3bN\xf7\xb2\xe2"
|
||||
b'\xd0sF\xd1\xd4\xecB\xe8\xaf"/*ml\x80Q3\xdb\xaexg'
|
||||
b"\x8e\x8a\x8c\xd3l\x03\\,J\xa7\x01i\xd1:]\xe3\x8d"
|
||||
b"\xf4\x03\x88K\x84\n`\xe8\x9a\xad\xad\xc6\x8ea\x15>"
|
||||
b"\x92m\x9e\xc7aM\x11?\x18;\xbd\x04c\x07\x85\x99\xa3\xea[\x00D",
|
||||
]
|
||||
|
||||
opts = dict(
|
||||
temp_salt_master.config.copy(),
|
||||
ipc_mode="ipc",
|
||||
pub_hwm=0,
|
||||
zmq_filtering=True,
|
||||
recon_randomize=False,
|
||||
recon_default=1,
|
||||
recon_max=2,
|
||||
master_ip="127.0.0.1",
|
||||
acceptance_wait_time=5,
|
||||
acceptance_wait_time_max=5,
|
||||
sign_pub_messages=False,
|
||||
)
|
||||
opts["master_uri"] = "tcp://{interface}:{publish_port}".format(**opts)
|
||||
|
||||
ioloop = tornado.ioloop.IOLoop()
|
||||
transport = salt.transport.zeromq.PublishClient(
|
||||
opts, ioloop, host=opts["master_ip"], port=121212
|
||||
)
|
||||
with transport:
|
||||
with patch(
|
||||
"salt.crypt.AsyncAuth.crypticle",
|
||||
MagicMock(return_value={"tgt_type": "glob", "tgt": "*", "jid": 1}),
|
||||
):
|
||||
res = transport._decode_messages(message)
|
||||
assert res is None
|
||||
|
||||
|
||||
def test_zeromq_async_pub_channel_filtering_decode_message(
|
||||
temp_salt_master, temp_salt_minion
|
||||
):
|
||||
"""
|
||||
test AsyncZeroMQPublishClient _decode_messages when zmq_filtered enabled
|
||||
"""
|
||||
minion_hexid = salt.utils.stringutils.to_bytes(
|
||||
hashlib.sha1(salt.utils.stringutils.to_bytes(temp_salt_minion.id)).hexdigest()
|
||||
)
|
||||
|
||||
message = [
|
||||
minion_hexid,
|
||||
b"\x82\xa3enc\xa3aes\xa4load\xda\x00`\xeeR\xcf"
|
||||
b"\x0eaI#V\x17if\xcf\xae\x05\xa7\xb3bN\xf7\xb2\xe2"
|
||||
b'\xd0sF\xd1\xd4\xecB\xe8\xaf"/*ml\x80Q3\xdb\xaexg'
|
||||
b"\x8e\x8a\x8c\xd3l\x03\\,J\xa7\x01i\xd1:]\xe3\x8d"
|
||||
b"\xf4\x03\x88K\x84\n`\xe8\x9a\xad\xad\xc6\x8ea\x15>"
|
||||
b"\x92m\x9e\xc7aM\x11?\x18;\xbd\x04c\x07\x85\x99\xa3\xea[\x00D",
|
||||
]
|
||||
|
||||
opts = dict(
|
||||
temp_salt_master.config.copy(),
|
||||
id=temp_salt_minion.id,
|
||||
ipc_mode="ipc",
|
||||
pub_hwm=0,
|
||||
zmq_filtering=True,
|
||||
recon_randomize=False,
|
||||
recon_default=1,
|
||||
recon_max=2,
|
||||
master_ip="127.0.0.1",
|
||||
acceptance_wait_time=5,
|
||||
acceptance_wait_time_max=5,
|
||||
sign_pub_messages=False,
|
||||
)
|
||||
opts["master_uri"] = "tcp://{interface}:{publish_port}".format(**opts)
|
||||
|
||||
ioloop = tornado.ioloop.IOLoop()
|
||||
transport = salt.transport.zeromq.PublishClient(
|
||||
opts, ioloop, host=opts["master_ip"], port=121212
|
||||
)
|
||||
with transport:
|
||||
with patch(
|
||||
"salt.crypt.AsyncAuth.crypticle",
|
||||
MagicMock(return_value={"tgt_type": "glob", "tgt": "*", "jid": 1}),
|
||||
) as mock_test:
|
||||
res = transport._decode_messages(message)
|
||||
|
||||
assert res["enc"] == "aes"
|
||||
|
||||
|
||||
async def test_publish_client_connect_server_down(transport, io_loop):
|
||||
opts = {"master_ip": "127.0.0.1"}
|
||||
host = "127.0.0.1"
|
||||
port = 111222
|
||||
print(transport)
|
||||
if transport == "zeromq":
|
||||
client = salt.transport.zeromq.PublishClient(
|
||||
opts, io_loop, host=host, port=port
|
||||
)
|
||||
await client.connect()
|
||||
assert client._socket
|
||||
elif transport == "tcp":
|
||||
client = salt.transport.tcp.TCPPubClient(opts, io_loop, host=host, port=port)
|
||||
try:
|
||||
await client.connect(background=True)
|
||||
except TimeoutError:
|
||||
pass
|
||||
assert client._stream is None
|
||||
client.close()
|
||||
|
||||
|
||||
async def test_publish_client_connect_server_comes_up(transport, io_loop):
|
||||
print(io_loop)
|
||||
opts = {"master_ip": "127.0.0.1"}
|
||||
host = "127.0.0.1"
|
||||
port = 11122
|
||||
if transport == "zeromq":
|
||||
import asyncio
|
||||
|
||||
import zmq
|
||||
|
||||
return
|
||||
ctx = zmq.asyncio.Context()
|
||||
uri = f"tcp://{opts['master_ip']}:{port}"
|
||||
msg = salt.payload.dumps({"meh": 123})
|
||||
print(f"send {msg}")
|
||||
client = salt.transport.zeromq.PublishClient(
|
||||
opts, io_loop, host=host, port=port
|
||||
)
|
||||
await client.connect(background=True)
|
||||
assert client._socket
|
||||
|
||||
socket = ctx.socket(zmq.PUB)
|
||||
socket.setsockopt(zmq.BACKLOG, 1000)
|
||||
socket.setsockopt(zmq.LINGER, -1)
|
||||
socket.setsockopt(zmq.SNDHWM, 1000)
|
||||
print(f"bind {uri}")
|
||||
socket.bind(uri)
|
||||
await asyncio.sleep(10)
|
||||
|
||||
async def recv():
|
||||
return await client.recv(timeout=1)
|
||||
|
||||
task = asyncio.create_task(recv())
|
||||
# Sleep to allow zmq to do it's thing.
|
||||
await asyncio.sleep(0.03)
|
||||
await socket.send(msg)
|
||||
await task
|
||||
response = task.result()
|
||||
assert response
|
||||
client.close()
|
||||
socket.close()
|
||||
await asyncio.sleep(0.03)
|
||||
ctx.term()
|
||||
elif transport == "tcp":
|
||||
import asyncio
|
||||
import socket
|
||||
|
||||
import tornado
|
||||
|
||||
client = salt.transport.tcp.TCPPubClient(opts, io_loop, host=host, port=port)
|
||||
await client.connect(port, background=True)
|
||||
assert client._stream is None
|
||||
await asyncio.sleep(2)
|
||||
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
sock.setblocking(0)
|
||||
sock.bind((opts["master_ip"], port))
|
||||
sock.listen(128)
|
||||
await asyncio.sleep(0.03)
|
||||
|
||||
msg = salt.payload.dumps({"meh": 123})
|
||||
msg = salt.transport.frame.frame_msg(msg, header=None)
|
||||
conn, addr = sock.accept()
|
||||
conn.send(msg)
|
||||
response = await client.recv()
|
||||
assert response
|
||||
else:
|
||||
raise Exception(f"Unknown transport {transport}")
|
|
@ -80,12 +80,14 @@ def client_socket():
|
|||
yield _client_socket
|
||||
|
||||
|
||||
def test_message_client_cleanup_on_close(client_socket, temp_salt_master):
|
||||
async def test_message_client_cleanup_on_close(
|
||||
client_socket, temp_salt_master, io_loop
|
||||
):
|
||||
"""
|
||||
test message client cleanup on close
|
||||
"""
|
||||
orig_loop = tornado.ioloop.IOLoop()
|
||||
orig_loop.make_current()
|
||||
|
||||
orig_loop = io_loop
|
||||
|
||||
opts = dict(temp_salt_master.config.copy(), transport="tcp")
|
||||
client = salt.transport.tcp.MessageClient(
|
||||
|
@ -103,15 +105,14 @@ def test_message_client_cleanup_on_close(client_socket, temp_salt_master):
|
|||
orig_loop.stop = stop
|
||||
try:
|
||||
assert client.io_loop == orig_loop
|
||||
client.io_loop.run_sync(client.connect)
|
||||
await client.connect()
|
||||
|
||||
# Ensure we are testing the _read_until_future and io_loop teardown
|
||||
assert client._stream is not None
|
||||
assert orig_loop.stop_called is True
|
||||
|
||||
# The run_sync call will set stop_called, reset it
|
||||
# orig_loop.stop_called = False
|
||||
client.close()
|
||||
await client.close()
|
||||
|
||||
# Stop should be called again, client's io_loop should be None
|
||||
# assert orig_loop.stop_called is True
|
||||
|
@ -120,8 +121,10 @@ def test_message_client_cleanup_on_close(client_socket, temp_salt_master):
|
|||
orig_loop.stop = orig_loop.real_stop
|
||||
del orig_loop.real_stop
|
||||
del orig_loop.stop_called
|
||||
orig_loop.clear_current()
|
||||
orig_loop.close(all_fds=True)
|
||||
|
||||
|
||||
# orig_loop.clear_current()
|
||||
# orig_loop.close(all_fds=True)
|
||||
|
||||
|
||||
async def test_async_tcp_pub_channel_connect_publish_port(
|
||||
|
@ -130,6 +133,8 @@ async def test_async_tcp_pub_channel_connect_publish_port(
|
|||
"""
|
||||
test when publish_port is not 4506
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
opts = dict(
|
||||
temp_salt_master.config.copy(),
|
||||
master_uri="tcp://127.0.0.1:1234",
|
||||
|
@ -141,6 +146,10 @@ async def test_async_tcp_pub_channel_connect_publish_port(
|
|||
)
|
||||
patch_auth = MagicMock(return_value=True)
|
||||
transport = MagicMock(spec=salt.transport.tcp.TCPPubClient)
|
||||
transport.connect = MagicMock()
|
||||
future = asyncio.Future()
|
||||
transport.connect.return_value = future
|
||||
future.set_result(True)
|
||||
with patch("salt.crypt.AsyncAuth.gen_token", patch_auth), patch(
|
||||
"salt.crypt.AsyncAuth.authenticated", patch_auth
|
||||
), patch("salt.transport.tcp.TCPPubClient", transport):
|
||||
|
@ -150,6 +159,7 @@ async def test_async_tcp_pub_channel_connect_publish_port(
|
|||
with pytest.raises(salt.exceptions.SaltClientError):
|
||||
await channel.connect()
|
||||
# The first call to the mock is the instance's __init__, and the first argument to those calls is the opts dict
|
||||
await asyncio.sleep(0.3)
|
||||
assert channel.transport.connect.call_args[0][0] == opts["publish_port"]
|
||||
|
||||
|
||||
|
@ -248,10 +258,7 @@ def salt_message_client():
|
|||
{}, "127.0.0.1", ports.get_unused_localhost_port(), io_loop=io_loop_mock
|
||||
)
|
||||
|
||||
try:
|
||||
yield client
|
||||
finally:
|
||||
client.close()
|
||||
yield client
|
||||
|
||||
|
||||
# XXX we don't reutnr a future anymore, this needs a different way of testing.
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -73,13 +73,21 @@ class Collector(salt.utils.process.SignalHandlingProcess):
|
|||
"rotate_master_key": self._rotate_secrets,
|
||||
}
|
||||
|
||||
def _teardown_listener(self):
|
||||
if self.transport == "zeromq":
|
||||
self.sock.close()
|
||||
self.ctx.term()
|
||||
else:
|
||||
self.sock.close()
|
||||
|
||||
def _setup_listener(self):
|
||||
if self.transport == "zeromq":
|
||||
ctx = zmq.Context()
|
||||
self.sock = ctx.socket(zmq.SUB)
|
||||
self.ctx = zmq.Context()
|
||||
self.sock = self.ctx.socket(zmq.SUB)
|
||||
self.sock.setsockopt(zmq.LINGER, -1)
|
||||
self.sock.setsockopt(zmq.SUBSCRIBE, b"")
|
||||
pub_uri = "tcp://{}:{}".format(self.interface, self.port)
|
||||
log.error("Collector listen %s", pub_uri)
|
||||
self.sock.connect(pub_uri)
|
||||
else:
|
||||
end = time.time() + 120
|
||||
|
@ -97,22 +105,28 @@ class Collector(salt.utils.process.SignalHandlingProcess):
|
|||
|
||||
@tornado.gen.coroutine
|
||||
def _recv(self):
|
||||
# log.error("RECV %s", self.transport)
|
||||
if self.transport == "zeromq":
|
||||
# test_zeromq_filtering requires catching the
|
||||
# SaltDeserializationError in order to pass.
|
||||
try:
|
||||
payload = self.sock.recv(zmq.NOBLOCK)
|
||||
# log.error("ZMQ Payload is %r", payload)
|
||||
serial_payload = salt.payload.loads(payload)
|
||||
raise tornado.gen.Return(serial_payload)
|
||||
except (zmq.ZMQError, salt.exceptions.SaltDeserializationError):
|
||||
raise RecvError("ZMQ Error")
|
||||
else:
|
||||
for msg in self.unpacker:
|
||||
raise tornado.gen.Return(msg["body"])
|
||||
# log.error("TCP Payload is %r", msg)
|
||||
serial_payload = salt.payload.loads(msg["body"])
|
||||
# raise tornado.gen.Return(msg["body"])
|
||||
raise tornado.gen.Return(serial_payload)
|
||||
byts = yield self.sock.read_bytes(8096, partial=True)
|
||||
self.unpacker.feed(byts)
|
||||
for msg in self.unpacker:
|
||||
raise tornado.gen.Return(msg["body"])
|
||||
serial_payload = salt.payload.loads(msg["body"])
|
||||
raise tornado.gen.Return(serial_payload)
|
||||
raise RecvError("TCP Error")
|
||||
|
||||
@tornado.gen.coroutine
|
||||
|
@ -128,40 +142,44 @@ class Collector(salt.utils.process.SignalHandlingProcess):
|
|||
self.start = last_msg
|
||||
serial = salt.payload.Serial(self.minion_config)
|
||||
crypticle = salt.crypt.Crypticle(self.minion_config, self.aes_key)
|
||||
while True:
|
||||
curr_time = time.time()
|
||||
if time.time() > self.hard_timeout:
|
||||
log.error("Hard timeout reaced in test collector!")
|
||||
break
|
||||
if curr_time - last_msg >= self.timeout:
|
||||
log.error("Receive timeout reaced in test collector!")
|
||||
break
|
||||
try:
|
||||
payload = yield self._recv()
|
||||
except RecvError:
|
||||
time.sleep(0.01)
|
||||
else:
|
||||
try:
|
||||
while True:
|
||||
curr_time = time.time()
|
||||
if time.time() > self.hard_timeout:
|
||||
log.error("Hard timeout reaced in test collector!")
|
||||
break
|
||||
if curr_time - last_msg >= self.timeout:
|
||||
log.error("Receive timeout reaced in test collector!")
|
||||
break
|
||||
try:
|
||||
payload = crypticle.loads(payload["load"])
|
||||
if not payload:
|
||||
continue
|
||||
if "start" in payload:
|
||||
log.info("Collector started")
|
||||
self.running.set()
|
||||
continue
|
||||
if "stop" in payload:
|
||||
log.info("Collector stopped")
|
||||
break
|
||||
last_msg = time.time()
|
||||
self.results.append(payload["jid"])
|
||||
except salt.exceptions.SaltDeserializationError:
|
||||
log.error("Deserializer Error")
|
||||
if not self.zmq_filtering:
|
||||
log.exception("Failed to deserialize...")
|
||||
break
|
||||
self.end = time.time()
|
||||
print(f"Total time {self.end - self.start}")
|
||||
loop.stop()
|
||||
payload = yield self._recv()
|
||||
except RecvError:
|
||||
time.sleep(0.03)
|
||||
else:
|
||||
try:
|
||||
log.trace("Colleted payload %r", payload)
|
||||
payload = crypticle.loads(payload["load"])
|
||||
if not payload:
|
||||
continue
|
||||
if "start" in payload:
|
||||
log.info("Collector started")
|
||||
self.running.set()
|
||||
continue
|
||||
if "stop" in payload:
|
||||
log.info("Collector stopped")
|
||||
break
|
||||
last_msg = time.time()
|
||||
self.results.append(payload["jid"])
|
||||
except salt.exceptions.SaltDeserializationError:
|
||||
log.error("Deserializer Error")
|
||||
if not self.zmq_filtering:
|
||||
log.exception("Failed to deserialize...")
|
||||
break
|
||||
self.end = time.time()
|
||||
print(f"Total time {self.end - self.start}")
|
||||
finally:
|
||||
self._teardown_listener()
|
||||
loop.stop()
|
||||
|
||||
def run(self):
|
||||
"""
|
||||
|
@ -173,11 +191,17 @@ class Collector(salt.utils.process.SignalHandlingProcess):
|
|||
loop.start()
|
||||
|
||||
def __enter__(self):
|
||||
import sys
|
||||
|
||||
print("COL ENTER")
|
||||
sys.stdout.flush()
|
||||
self.manager.__enter__()
|
||||
self.start()
|
||||
# Wait until we can start receiving events
|
||||
self.started.wait()
|
||||
self.started.clear()
|
||||
print("COL ENTER - Done")
|
||||
sys.stdout.flush()
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
|
@ -197,6 +221,7 @@ class Collector(salt.utils.process.SignalHandlingProcess):
|
|||
class PubServerChannelProcess(salt.utils.process.SignalHandlingProcess):
|
||||
def __init__(self, master_config, minion_config, **collector_kwargs):
|
||||
super().__init__()
|
||||
self.name = "PubServerChannelProcess"
|
||||
self._closing = False
|
||||
self.master_config = master_config
|
||||
self.minion_config = minion_config
|
||||
|
@ -226,17 +251,24 @@ class PubServerChannelProcess(salt.utils.process.SignalHandlingProcess):
|
|||
self.master_config["interface"],
|
||||
self.master_config["publish_port"],
|
||||
self.aes_key,
|
||||
**self.collector_kwargs
|
||||
**self.collector_kwargs,
|
||||
)
|
||||
|
||||
def run(self):
|
||||
import queue
|
||||
|
||||
ioloop = tornado.ioloop.IOLoop()
|
||||
try:
|
||||
while True:
|
||||
payload = self.queue.get()
|
||||
try:
|
||||
payload = self.queue.get(False)
|
||||
except queue.Empty:
|
||||
time.sleep(0.03)
|
||||
continue
|
||||
if payload is None:
|
||||
log.debug("We received the stop sentinel")
|
||||
break
|
||||
self.pub_server_channel.publish(payload)
|
||||
ioloop.run_sync(lambda: self.pub_server_channel.publish(payload))
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
|
@ -253,8 +285,8 @@ class PubServerChannelProcess(salt.utils.process.SignalHandlingProcess):
|
|||
if self.process_manager is None:
|
||||
return
|
||||
self.process_manager.terminate()
|
||||
if hasattr(self.pub_server_channel, "pub_close"):
|
||||
self.pub_server_channel.pub_close()
|
||||
if hasattr(self.pub_server_channel, "close"):
|
||||
self.pub_server_channel.close()
|
||||
# Really terminate any process still left behind
|
||||
for pid in self.process_manager._process_map:
|
||||
terminate_process(pid=pid, kill_children=True, slow_stop=False)
|
||||
|
@ -264,9 +296,12 @@ class PubServerChannelProcess(salt.utils.process.SignalHandlingProcess):
|
|||
self.queue.put(payload)
|
||||
|
||||
def __enter__(self):
|
||||
log.error("Proc start")
|
||||
self.start()
|
||||
log.error("Col enter")
|
||||
self.collector.__enter__()
|
||||
attempts = 300
|
||||
log.error("Wait collector")
|
||||
while attempts > 0:
|
||||
self.publish({"tgt_type": "glob", "tgt": "*", "jid": -1, "start": True})
|
||||
if self.collector.running.wait(1) is True:
|
||||
|
@ -274,6 +309,7 @@ class PubServerChannelProcess(salt.utils.process.SignalHandlingProcess):
|
|||
attempts -= 1
|
||||
else:
|
||||
pytest.fail("Failed to confirm the collector has started")
|
||||
log.error("Collector started")
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
|
@ -290,4 +326,4 @@ class PubServerChannelProcess(salt.utils.process.SignalHandlingProcess):
|
|||
self.stopped.wait(10)
|
||||
self.close()
|
||||
self.terminate()
|
||||
log.info("The PubServerChannelProcess has terminated")
|
||||
log.error("The PubServerChannelProcess has terminated")
|
||||
|
|
Loading…
Add table
Reference in a new issue