Clean up transports

This commit is contained in:
Daniel A. Wozniak 2021-09-23 20:28:39 -07:00 committed by Gareth J. Greenaway
parent d4e6111086
commit 9db1af7147
15 changed files with 623 additions and 1071 deletions

View file

@ -4,6 +4,7 @@ Encapsulate the different transports available to Salt.
This includes client side transport, for the ReqServer and the Publisher
"""
import logging
import os
import time
@ -499,7 +500,6 @@ class AsyncPubChannel:
def _do_transfer():
msg = self._package_load(self.auth.crypticle.dumps(load))
package = salt.transport.frame.frame_msg(msg, header=None)
# yield self.message_client.write_to_stream(package)
yield self.transport.send(package)
raise salt.ext.tornado.gen.Return(True)
@ -555,9 +555,9 @@ class AsyncPubChannel:
"data": data,
"tag": tag,
}
req_channel = ReqChannel(self.opts)
req_channel = AsyncReqChannel.factory(self.opts)
try:
req_channel.send(load, timeout=60)
yield req_channel.send(load, timeout=60)
except salt.exceptions.SaltReqTimeoutError:
log.info(
"fire_master failed: master could not be contacted. Request timed"

View file

@ -4,6 +4,7 @@ Encapsulate the different transports available to Salt.
This includes server side transport, for the ReqServer and the Publisher
"""
import binascii
import ctypes
import hashlib
@ -11,13 +12,13 @@ import logging
import multiprocessing
import os
import shutil
import threading
import salt.crypt
import salt.ext.tornado.gen
import salt.master
import salt.payload
import salt.transport.frame
import salt.utils.channel
import salt.utils.event
import salt.utils.files
import salt.utils.minions
@ -115,6 +116,7 @@ class ReqServerChannel:
self.io_loop = io_loop
self.transport.post_fork(self.handle_message, io_loop)
import salt.master
self.serial = salt.payload.Serial(self.opts)
self.crypticle = salt.crypt.Crypticle(
self.opts, salt.master.SMaster.secrets["aes"]["secret"].value
@ -655,8 +657,6 @@ class PubServerChannel:
Factory class to create subscription channels to the master's Publisher
"""
_sock_data = threading.local()
@classmethod
def factory(cls, opts, **kwargs):
# Default to ZeroMQ for now
@ -668,6 +668,18 @@ class PubServerChannel:
elif "transport" in opts.get("pillar", {}).get("master", {}):
ttype = opts["pillar"]["master"]["transport"]
presence_events = False
if opts.get("presence_events", False):
tcp_only = True
for transport, _ in salt.utils.channel.iter_transport_opts(opts):
if transport != "tcp":
tcp_only = False
if tcp_only:
# Only when the transport is TCP only, the presence events will
# be handled here. Otherwise, it will be handled in the
# 'Maintenance' process.
presence_events = True
# switch on available ttypes
if ttype == "zeromq":
import salt.transport.zeromq
@ -683,16 +695,16 @@ class PubServerChannel:
transport = salt.transport.local.LocalPubServerChannel(opts, **kwargs)
else:
raise Exception("Channels are only defined for ZeroMQ and TCP")
# return NewKindOfChannel(opts, **kwargs)
return cls(opts, transport)
return cls(opts, transport, presence_events=presence_events)
def __init__(self, opts, transport):
def __init__(self, opts, transport, presence_events=False):
self.opts = opts
self.serial = salt.payload.Serial(self.opts) # TODO: in init?
self.ckminions = salt.utils.minions.CkMinions(self.opts)
self.transport = transport
self.aes_funcs = salt.master.AESFuncs(self.opts)
self.present = {}
self.presence_events = presence_events
self.event = salt.utils.event.get_event("master", opts=self.opts, listen=False)
def close(self):
@ -796,18 +808,20 @@ class PubServerChannel:
)
@salt.ext.tornado.gen.coroutine
def publish_payload(self, package, *args):
# unpacked_package = salt.payload.unpackage(package)
# unpacked_package = salt.transport.frame.decode_embedded_strs(
# unpacked_package
# )
ret = yield self.transport.publish_payload(package)
def publish_payload(self, unpacked_package, *args):
try:
payload = self.serial.loads(unpacked_package["payload"])
except KeyError:
log.error("Invalid package %r", unpacked_package)
raise
if "topic_lst" in unpacked_package:
topic_list = unpacked_package["topic_lst"]
ret = yield self.transport.publish_payload(payload, topic_list)
else:
ret = yield self.transport.publish_payload(payload)
raise salt.ext.tornado.gen.Return(ret)
def publish(self, load):
"""
Publish "load" to minions
"""
def wrap_payload(self, load):
payload = {"enc": "aes"}
crypticle = salt.crypt.Crypticle(
self.opts, salt.master.SMaster.secrets["aes"]["secret"].value
@ -817,14 +831,33 @@ class PubServerChannel:
master_pem_path = os.path.join(self.opts["pki_dir"], "master.pem")
log.debug("Signing data packet")
payload["sig"] = salt.crypt.sign_message(master_pem_path, payload["load"])
int_payload = {"payload": self.serial.dumps(payload)}
# add some targeting stuff for lists only (for now)
if load["tgt_type"] == "list":
int_payload["topic_lst"] = load["tgt"]
# If topics are upported, target matching has to happen master side
match_targets = ["pcre", "glob", "list"]
if self.transport.topic_support and load["tgt_type"] in match_targets:
if isinstance(load["tgt"], str):
# Fetch a list of minions that match
_res = self.ckminions.check_minions(
load["tgt"], tgt_type=load["tgt_type"]
)
match_ids = _res["minions"]
log.debug("Publish Side Match: %s", match_ids)
# Send list of miions thru so zmq can target them
int_payload["topic_lst"] = match_ids
else:
int_payload["topic_lst"] = load["tgt"]
return int_payload
# XXX: These implimentations vary slightly, condense them and add tests
int_payload = self.transport.publish_filters(
payload, load["tgt_type"], load["tgt"], self.ckminions
)
def publish(self, load):
"""
Publish "load" to minions
"""
payload = self.wrap_payload(load)
log.debug(
"Sending payload to publish daemon. jid=%s size=%d",
"Sending payload to publish daemon. jid=%s",
load.get("jid", None),
len(payload),
)
self.transport.publish(int_payload)
self.transport.publish(payload)

View file

@ -58,7 +58,7 @@ import salt.wheel
from salt.config import DEFAULT_INTERVAL
from salt.defaults import DEFAULT_TARGET_DELIM
from salt.ext.tornado.stack_context import StackContext
from salt.transport import iter_transport_opts
from salt.utils.channel import iter_transport_opts
from salt.utils.ctx import RequestContext
from salt.utils.debug import (
enable_sigusr1_handler,

View file

@ -10,11 +10,11 @@ import os
import sys
import traceback
import salt.channel.client
import salt.ext.tornado.gen
import salt.fileclient
import salt.loader
import salt.minion
import salt.transport.client
import salt.utils.args
import salt.utils.cache
import salt.utils.crypt
@ -218,7 +218,7 @@ class AsyncRemotePillar(RemotePillarMixin):
self.ext = ext
self.grains = grains
self.minion_id = minion_id
self.channel = salt.transport.client.AsyncReqChannel.factory(opts)
self.channel = salt.channel.client.AsyncReqChannel.factory(opts)
if pillarenv is not None:
self.opts["pillarenv"] = pillarenv
self.pillar_override = pillar_override or {}
@ -311,7 +311,7 @@ class RemotePillar(RemotePillarMixin):
self.ext = ext
self.grains = grains
self.minion_id = minion_id
self.channel = salt.transport.client.ReqChannel.factory(opts)
self.channel = salt.channel.client.ReqChannel.factory(opts)
if pillarenv is not None:
self.opts["pillarenv"] = pillarenv
self.pillar_override = pillar_override or {}

View file

@ -13,39 +13,3 @@ log = logging.getLogger(__name__)
warnings.filterwarnings(
"ignore", message="IOLoop.current expected instance.*", category=RuntimeWarning
)
def iter_transport_opts(opts):
"""
Yield transport, opts for all master configured transports
"""
transports = set()
for transport, opts_overrides in opts.get("transport_opts", {}).items():
t_opts = dict(opts)
t_opts.update(opts_overrides)
t_opts["transport"] = transport
transports.add(transport)
yield transport, t_opts
if opts["transport"] not in transports:
yield opts["transport"], opts
class MessageClientPool:
def __init__(self, tgt, opts, args=None, kwargs=None):
sock_pool_size = opts["sock_pool_size"] if "sock_pool_size" in opts else 1
if sock_pool_size < 1:
log.warning(
"sock_pool_size is not correctly set, the option should be "
"greater than 0 but is instead %s",
sock_pool_size,
)
sock_pool_size = 1
if args is None:
args = ()
if kwargs is None:
kwargs = {}
self.message_clients = [tgt(*args, **kwargs) for _ in range(sock_pool_size)]

View file

@ -14,8 +14,6 @@ import socket
import threading
import urllib
import salt.crypt
import salt.exceptions
import salt.ext.tornado
import salt.ext.tornado.concurrent
import salt.ext.tornado.gen
@ -30,15 +28,11 @@ import salt.transport.frame
import salt.transport.ipc
import salt.transport.server
import salt.utils.asynchronous
import salt.utils.event
import salt.utils.files
import salt.utils.msgpack
import salt.utils.platform
import salt.utils.process
import salt.utils.verify
import salt.utils.versions
from salt.exceptions import SaltClientError, SaltReqTimeoutError
from salt.transport import iter_transport_opts
if salt.utils.platform.is_windows():
USE_LOAD_BALANCER = True
@ -54,6 +48,10 @@ if USE_LOAD_BALANCER:
log = logging.getLogger(__name__)
class ClosingError(Exception):
""" """
def _set_tcp_keepalive(sock, opts):
"""
Ensure that TCP keepalives are set for the socket.
@ -215,15 +213,10 @@ class AsyncTCPPubChannel(ResolverMixin):
def __init__(self, opts, io_loop, **kwargs):
super().__init__()
self.opts = opts
self.crypt = kwargs.get("crypt", "aes")
self.io_loop = io_loop
self.message_client = None
self.connected = False
self._closing = False
self._reconnected = False
self.message_client = None
self._closing = False
# self.event = salt.utils.event.get_event("minion", opts=self.opts, listen=False)
# self.tok = self.auth.gen_token(b"salt")
def close(self):
if self._closing:
@ -242,16 +235,15 @@ class AsyncTCPPubChannel(ResolverMixin):
@salt.ext.tornado.gen.coroutine
def connect(self, publish_port, connect_callback=None, disconnect_callback=None):
self.publish_port = publish_port
self.message_client = SaltMessageClientPool(
self.message_client = MessageClient(
self.opts,
args=(self.opts, self.opts["master_ip"], int(self.publish_port)),
kwargs={
"io_loop": self.io_loop,
"connect_callback": connect_callback,
"disconnect_callback": disconnect_callback,
"source_ip": self.opts.get("source_ip"),
"source_port": self.opts.get("source_publish_port"),
},
self.opts["master_ip"],
int(self.publish_port),
io_loop=self.io_loop,
connect_callback=connect_callback,
disconnect_callback=disconnect_callback,
source_ip=self.opts.get("source_ip"),
source_port=self.opts.get("source_publish_port"),
)
yield self.message_client.connect() # wait for the client to be connected
self.connected = True
@ -269,7 +261,7 @@ class AsyncTCPPubChannel(ResolverMixin):
@salt.ext.tornado.gen.coroutine
def send(self, msg):
yield self.message_client.write_to_stream(msg)
yield self.message_client._stream.write(msg)
def on_recv(self, callback):
"""
@ -566,58 +558,10 @@ class TCPClientKeepAlive(salt.ext.tornado.tcpclient.TCPClient):
return stream, stream.connect(addr)
class SaltMessageClientPool(salt.transport.MessageClientPool):
"""
Wrapper class of SaltMessageClient to avoid blocking waiting while writing data to socket.
"""
ttype = "tcp"
def __init__(self, opts, args=None, kwargs=None):
super().__init__(SaltMessageClient, opts, args=args, kwargs=kwargs)
def __enter__(self):
return self
def __exit__(self, *args):
self.close()
# pylint: disable=W1701
def __del__(self):
self.close()
# pylint: enable=W1701
def close(self):
for message_client in self.message_clients:
message_client.close()
self.message_clients = []
@salt.ext.tornado.gen.coroutine
def connect(self):
futures = []
for message_client in self.message_clients:
futures.append(message_client.connect())
yield futures
raise salt.ext.tornado.gen.Return(None)
def on_recv(self, *args, **kwargs):
for message_client in self.message_clients:
message_client.on_recv(*args, **kwargs)
def send(self, *args, **kwargs):
message_clients = sorted(self.message_clients, key=lambda x: len(x.send_queue))
return message_clients[0].send(*args, **kwargs)
def write_to_stream(self, *args, **kwargs):
message_clients = sorted(self.message_clients, key=lambda x: len(x.send_queue))
return message_clients[0]._stream.write(*args, **kwargs)
# TODO consolidate with IPCClient
# TODO: limit in-flight messages.
# TODO: singleton? Something to not re-create the tcp connection so much
class SaltMessageClient:
class MessageClient:
"""
Low-level message sending client
"""
@ -641,26 +585,21 @@ class SaltMessageClient:
self.source_port = source_port
self.connect_callback = connect_callback
self.disconnect_callback = disconnect_callback
self.io_loop = io_loop or salt.ext.tornado.ioloop.IOLoop.current()
with salt.utils.asynchronous.current_ioloop(self.io_loop):
self._tcp_client = TCPClientKeepAlive(opts, resolver=resolver)
self._mid = 1
self._max_messages = int((1 << 31) - 2) # number of IDs before we wrap
# TODO: max queue size
self.send_queue = [] # queue of messages to be sent
self.send_future_map = {} # mapping of request_id -> Future
self.send_timeout_map = {} # request_id -> timeout_callback
self._read_until_future = None
self._on_recv = None
self._closing = False
self._connecting_future = self.connect()
self._stream_return_future = salt.ext.tornado.concurrent.Future()
self.io_loop.spawn_callback(self._stream_return)
self._closed = False
self._connecting_future = salt.ext.tornado.concurrent.Future()
self._stream = None
self.backoff = opts.get("tcp_reconnect_backoff", 1)
@ -673,46 +612,16 @@ class SaltMessageClient:
if self._closing:
return
self._closing = True
if hasattr(self, "_stream") and not self._stream.closed():
# If _stream_return() hasn't completed, it means the IO
# Loop is stopped (such as when using
# 'salt.utils.asynchronous.SyncWrapper'). Ensure that
# _stream_return() completes by restarting the IO Loop.
# This will prevent potential errors on shutdown.
try:
orig_loop = salt.ext.tornado.ioloop.IOLoop.current()
self.io_loop.make_current()
self._stream.close()
if self._read_until_future is not None:
# This will prevent this message from showing up:
# '[ERROR ] Future exception was never retrieved:
# StreamClosedError'
# This happens because the logic is always waiting to read
# the next message and the associated read future is marked
# 'StreamClosedError' when the stream is closed.
if self._read_until_future.done():
self._read_until_future.exception()
if (
self.io_loop
!= salt.ext.tornado.ioloop.IOLoop.current(instance=False)
or not self._stream_return_future.done()
):
self.io_loop.add_future(
self._stream_return_future,
lambda future: self._stop_io_loop(),
)
self.io_loop.start()
except Exception as e: # pylint: disable=broad-except
log.info("Exception caught in SaltMessageClient.close: %s", str(e))
finally:
orig_loop.make_current()
self._tcp_client.close()
self.io_loop = None
self._read_until_future = None
# Clear callback references to allow the object that they belong to
# to be deleted.
self.connect_callback = None
self.disconnect_callback = None
try:
for msg_id in list(self.send_future_map):
future = self.send_future_map.pop(msg_id)
future.set_exception(ClosingError())
self._tcp_client.close()
# self._stream.close()
finally:
self._stream = None
self._closing = False
self._closed = True
# pylint: disable=W1701
def __del__(self):
@ -720,58 +629,28 @@ class SaltMessageClient:
# pylint: enable=W1701
def connect(self):
"""
Ask for this client to reconnect to the origin
"""
if hasattr(self, "_connecting_future") and not self._connecting_future.done():
future = self._connecting_future
else:
future = salt.ext.tornado.concurrent.Future()
self._connecting_future = future
self.io_loop.add_callback(self._connect)
# Add the callback only when a new future is created
if self.connect_callback is not None:
def handle_future(future):
response = future.result()
self.io_loop.add_callback(self.connect_callback, response)
future.add_done_callback(handle_future)
return future
@salt.ext.tornado.gen.coroutine
def _connect(self):
"""
Try to connect for the rest of time!
"""
while True:
if self._closing:
break
def getstream(self, **kwargs):
if self.source_ip or self.source_port:
if salt.ext.tornado.version_info >= (4, 5):
### source_ip and source_port are supported only in Tornado >= 4.5
# See http://www.tornadoweb.org/en/stable/releases/v4.5.0.html
# Otherwise will just ignore these args
kwargs = {
"source_ip": self.source_ip,
"source_port": self.source_port,
}
else:
log.warning(
"If you need a certain source IP/port, consider upgrading"
" Tornado >= 4.5"
)
stream = None
while stream is None and not self._closed:
try:
kwargs = {}
if self.source_ip or self.source_port:
if salt.ext.tornado.version_info >= (4, 5):
### source_ip and source_port are supported only in Tornado >= 4.5
# See http://www.tornadoweb.org/en/stable/releases/v4.5.0.html
# Otherwise will just ignore these args
kwargs = {
"source_ip": self.source_ip,
"source_port": self.source_port,
}
else:
log.warning(
"If you need a certain source IP/port, consider upgrading"
" Tornado >= 4.5"
)
with salt.utils.asynchronous.current_ioloop(self.io_loop):
self._stream = yield self._tcp_client.connect(
self.host, self.port, ssl_options=self.opts.get("ssl"), **kwargs
)
self._connecting_future.set_result(True)
break
stream = yield self._tcp_client.connect(
self.host, self.port, ssl_options=self.opts.get("ssl"), **kwargs
)
except Exception as exc: # pylint: disable=broad-except
log.warning(
"TCP Message Client encountered an exception while connecting to"
@ -783,104 +662,73 @@ class SaltMessageClient:
)
yield salt.ext.tornado.gen.sleep(self.backoff)
# self._connecting_future.set_exception(exc)
raise salt.ext.tornado.gen.Return(stream)
@salt.ext.tornado.gen.coroutine
def connect(self):
if not self._stream:
self._stream = yield self.getstream()
if self.connect_callback:
self.connect_callback(True)
self.io_loop.spawn_callback(self._stream_return)
@salt.ext.tornado.gen.coroutine
def _stream_return(self):
try:
while not self._closing and (
not self._connecting_future.done()
or self._connecting_future.result() is not True
):
yield self._connecting_future
unpacker = salt.utils.msgpack.Unpacker()
while not self._closing:
try:
self._read_until_future = self._stream.read_bytes(
4096, partial=True
)
wire_bytes = yield self._read_until_future
unpacker.feed(wire_bytes)
for framed_msg in unpacker:
framed_msg = salt.transport.frame.decode_embedded_strs(
framed_msg
)
header = framed_msg["head"]
body = framed_msg["body"]
message_id = header.get("mid")
if message_id in self.send_future_map:
self.send_future_map.pop(message_id).set_result(body)
self.remove_message_timeout(message_id)
else:
if self._on_recv is not None:
self.io_loop.spawn_callback(self._on_recv, header, body)
else:
log.error(
"Got response for message_id %s that we are not"
" tracking",
message_id,
)
except salt.ext.tornado.iostream.StreamClosedError as e:
log.debug(
"tcp stream to %s:%s closed, unable to recv",
self.host,
self.port,
)
for future in self.send_future_map.values():
future.set_exception(e)
self.send_future_map = {}
if self._closing:
return
if self.disconnect_callback:
self.disconnect_callback()
# if the last connect finished, then we need to make a new one
if self._connecting_future.done():
self._connecting_future = self.connect()
yield self._connecting_future
except TypeError:
# This is an invalid transport
if "detect_mode" in self.opts:
log.info(
"There was an error trying to use TCP transport; "
"attempting to fallback to another transport"
)
else:
raise SaltClientError
except Exception as e: # pylint: disable=broad-except
log.error("Exception parsing response", exc_info=True)
for future in self.send_future_map.values():
future.set_exception(e)
self.send_future_map = {}
if self._closing:
return
if self.disconnect_callback:
self.disconnect_callback()
# if the last connect finished, then we need to make a new one
if self._connecting_future.done():
self._connecting_future = self.connect()
yield self._connecting_future
finally:
self._stream_return_future.set_result(True)
@salt.ext.tornado.gen.coroutine
def _stream_send(self):
while (
not self._connecting_future.done()
or self._connecting_future.result() is not True
):
yield self._connecting_future
while len(self.send_queue) > 0:
message_id, item = self.send_queue[0]
unpacker = salt.utils.msgpack.Unpacker()
while not self._closing:
try:
yield self._stream.write(item)
del self.send_queue[0]
# if the connection is dead, lets fail this send, and make sure we
# attempt to reconnect
self._read_until_future = self._stream.read_bytes(4096, partial=True)
wire_bytes = yield self._read_until_future
unpacker.feed(wire_bytes)
for framed_msg in unpacker:
framed_msg = salt.transport.frame.decode_embedded_strs(framed_msg)
header = framed_msg["head"]
body = framed_msg["body"]
message_id = header.get("mid")
if message_id in self.send_future_map:
self.send_future_map.pop(message_id).set_result(body)
# self.remove_message_timeout(message_id)
else:
if self._on_recv is not None:
self.io_loop.spawn_callback(self._on_recv, header, body)
else:
log.error(
"Got response for message_id %s that we are not"
" tracking",
message_id,
)
except salt.ext.tornado.iostream.StreamClosedError as e:
if message_id in self.send_future_map:
self.send_future_map.pop(message_id).set_exception(e)
self.remove_message_timeout(message_id)
del self.send_queue[0]
log.debug(
"tcp stream to %s:%s closed, unable to recv",
self.host,
self.port,
)
for future in self.send_future_map.values():
future.set_exception(e)
self.send_future_map = {}
if self._closing:
return
if self.disconnect_callback:
self.disconnect_callback()
# if the last connect finished, then we need to make a new one
if self._connecting_future.done():
self._connecting_future = self.connect()
yield self._connecting_future
except TypeError:
# This is an invalid transport
if "detect_mode" in self.opts:
log.info(
"There was an error trying to use TCP transport; "
"attempting to fallback to another transport"
)
else:
raise SaltClientError
except Exception as e: # pylint: disable=broad-except
log.error("Exception parsing response", exc_info=True)
for future in self.send_future_map.values():
future.set_exception(e)
self.send_future_map = {}
if self._closing:
return
if self.disconnect_callback:
@ -925,35 +773,15 @@ class SaltMessageClient:
self.io_loop.remove_timeout(timeout)
def timeout_message(self, message_id, msg):
if message_id in self.send_timeout_map:
del self.send_timeout_map[message_id]
if message_id in self.send_future_map:
future = self.send_future_map.pop(message_id)
# In a race condition the message might have been sent by the time
# we're timing it out. Make sure the future is not None
if future is not None:
if future.attempts < future.tries:
future.attempts += 1
log.debug(
"SaltReqTimeoutError, retrying. (%s/%s)",
future.attempts,
future.tries,
)
self.send(
msg,
timeout=future.timeout,
tries=future.tries,
future=future,
)
else:
future.set_exception(SaltReqTimeoutError("Message timed out"))
if message_id not in self.send_future_map:
return
future = self.send_future_map.pop(message_id)
future.set_exception(SaltReqTimeoutError("Message timed out"))
@salt.ext.tornado.gen.coroutine
def send(self, msg, timeout=None, callback=None, raw=False, future=None, tries=3):
"""
Send given message, and return a future
"""
if self._closing:
raise ClosingError()
message_id = self._message_id()
header = {"mid": message_id}
@ -977,18 +805,12 @@ class SaltMessageClient:
timeout = 1
if timeout is not None:
send_timeout = self.io_loop.call_later(
timeout, self.timeout_message, message_id, msg
)
self.send_timeout_map[message_id] = send_timeout
# if we don't have a send queue, we need to spawn the callback to do the sending
if len(self.send_queue) == 0:
self.io_loop.spawn_callback(self._stream_send)
self.send_queue.append(
(message_id, salt.transport.frame.frame_msg(msg, header=header))
)
return future
self.io_loop.call_later(timeout, self.timeout_message, message_id, msg)
item = salt.transport.frame.frame_msg(msg, header=header)
yield self.connect()
yield self._stream.write(item)
recv = yield future
raise salt.ext.tornado.gen.Return(recv)
class Subscriber:
@ -1039,16 +861,7 @@ class PubServer(salt.ext.tornado.tcpserver.TCPServer):
self._closing = False
self.clients = set()
self.presence_events = False
if self.opts.get("presence_events", False):
tcp_only = True
for transport, _ in iter_transport_opts(self.opts):
if transport != "tcp":
tcp_only = False
if tcp_only:
# Only when the transport is TCP only, the presence events will
# be handled here. Otherwise, it will be handled in the
# 'Maintenance' process.
self.presence_events = True
self.serial = salt.payload.Serial({})
if presence_callback:
self.presence_callback = presence_callback
else:
@ -1076,14 +889,15 @@ class PubServer(salt.ext.tornado.tcpserver.TCPServer):
unpacker = salt.utils.msgpack.Unpacker()
while not self._closing:
try:
# client._read_until_future = client.stream.read_bytes(4096, partial=True)
# wire_bytes = yield client._read_until_future
wire_bytes = yield client.stream.read_bytes(4096, partial=True)
client._read_until_future = client.stream.read_bytes(4096, partial=True)
wire_bytes = yield client._read_until_future
# wire_bytes = yield client.stream.read_bytes(4096, partial=True)
unpacker.feed(wire_bytes)
for framed_msg in unpacker:
framed_msg = salt.transport.frame.decode_embedded_strs(framed_msg)
body = framed_msg["body"]
self.presence_callback(client, body)
if self.presence_callback:
self.presence_callback(client, body)
except salt.ext.tornado.iostream.StreamClosedError as e:
log.debug("tcp stream to %s closed, unable to recv", client.address)
client.close()
@ -1104,29 +918,24 @@ class PubServer(salt.ext.tornado.tcpserver.TCPServer):
# TODO: ACK the publish through IPC
@salt.ext.tornado.gen.coroutine
def publish_payload(self, package, *args):
def publish_payload(self, package, topic_list=None):
log.debug("TCP PubServer sending payload: %s", package)
payload = salt.transport.frame.frame_msg(package["payload"])
payload = salt.transport.frame.frame_msg(package)
to_remove = []
if "topic_lst" in package:
topic_lst = package["topic_lst"]
for topic in topic_lst:
if topic in self.present:
# This will rarely be a list of more than 1 item. It will
# be more than 1 item if the minion disconnects from the
# master in an unclean manner (eg cable yank), then
# restarts and the master is yet to detect the disconnect
# via TCP keep-alive.
for client in self.present[topic]:
if topic_list and False:
for topic in topic_list:
sent = False
for client in list(self.clients):
if topic == client.id_:
try:
# Write the packed str
yield client.stream.write(payload)
sent = True
# self.io_loop.add_future(f, lambda f: True)
except salt.ext.tornado.iostream.StreamClosedError:
to_remove.append(client)
else:
log.debug("Publish target %s not connected", topic)
self.clients.remove(client)
if not sent:
log.debug("Publish target %s not connected %r", topic, self.clients)
else:
for client in self.clients:
try:
@ -1152,15 +961,16 @@ class TCPPubServerChannel:
def __init__(self, opts):
self.opts = opts
self.ckminions = salt.utils.minions.CkMinions(opts)
self.io_loop = None
@property
def topic_support(self):
return not self.opts.get("order_masters", False)
def __setstate__(self, state):
salt.master.SMaster.secrets = state["secrets"]
self.__init__(state["opts"])
def __getstate__(self):
return {"opts": self.opts, "secrets": salt.master.SMaster.secrets}
return {"opts": self.opts}
def publish_daemon(
self,
@ -1172,60 +982,55 @@ class TCPPubServerChannel:
"""
Bind to the interface specified in the configuration file
"""
log_queue = kwargs.get("log_queue")
if log_queue is not None:
salt.log.setup.set_multiprocessing_logging_queue(log_queue)
log_queue_level = kwargs.get("log_queue_level")
if log_queue_level is not None:
salt.log.setup.set_multiprocessing_logging_level(log_queue_level)
io_loop = salt.ext.tornado.ioloop.IOLoop()
io_loop.make_current()
# Spin up the publisher
self.pub_server = pub_server = PubServer(
self.opts,
io_loop=io_loop,
presence_callback=presence_callback,
remove_presence_callback=remove_presence_callback,
)
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
_set_tcp_keepalive(sock, self.opts)
sock.setblocking(0)
sock.bind((self.opts["interface"], int(self.opts["publish_port"])))
sock.listen(self.backlog)
# pub_server will take ownership of the socket
pub_server.add_socket(sock)
# Set up Salt IPC server
if self.opts.get("ipc_mode", "") == "tcp":
pull_uri = int(self.opts.get("tcp_master_publish_pull", 4514))
else:
pull_uri = os.path.join(self.opts["sock_dir"], "publish_pull.ipc")
self.pub_server = pub_server
pull_sock = salt.transport.ipc.IPCMessageServer(
pull_uri,
io_loop=io_loop,
payload_handler=publish_payload,
)
# Securely create socket
log.info("Starting the Salt Puller on %s", pull_uri)
with salt.utils.files.set_umask(0o177):
pull_sock.start()
# run forever
try:
log_queue = kwargs.get("log_queue")
if log_queue is not None:
salt.log.setup.set_multiprocessing_logging_queue(log_queue)
log_queue_level = kwargs.get("log_queue_level")
if log_queue_level is not None:
salt.log.setup.set_multiprocessing_logging_level(log_queue_level)
salt.log.setup.setup_multiprocessing_logging(log_queue)
# Check if io_loop was set outside
if self.io_loop is None:
self.io_loop = salt.ext.tornado.ioloop.IOLoop.current()
# Spin up the publisher
self.pub_server = pub_server = PubServer(
self.opts,
io_loop=self.io_loop,
presence_callback=presence_callback,
remove_presence_callback=remove_presence_callback,
)
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
_set_tcp_keepalive(sock, self.opts)
sock.setblocking(0)
sock.bind((self.opts["interface"], int(self.opts["publish_port"])))
sock.listen(self.backlog)
# pub_server will take ownership of the socket
pub_server.add_socket(sock)
# Set up Salt IPC server
if self.opts.get("ipc_mode", "") == "tcp":
pull_uri = int(self.opts.get("tcp_master_publish_pull", 4514))
else:
pull_uri = os.path.join(self.opts["sock_dir"], "publish_pull.ipc")
self.pub_server = pub_server
pull_sock = salt.transport.ipc.IPCMessageServer(
pull_uri,
io_loop=self.io_loop,
payload_handler=publish_payload,
)
# Securely create socket
log.info("Starting the Salt Puller on %s", pull_uri)
with salt.utils.files.set_umask(0o177):
pull_sock.start()
# run forever
try:
self.io_loop.start()
except (KeyboardInterrupt, SystemExit):
pass
finally:
pull_sock.close()
except Exception as exc: # pylint: disable=broad-except
log.error("Caught exception in publish daemon", exc_info=True)
io_loop.start()
except (KeyboardInterrupt, SystemExit):
pass
finally:
pull_sock.close()
def pre_fork(self, process_manager, kwargs=None):
"""
@ -1242,7 +1047,7 @@ class TCPPubServerChannel:
ret = yield self.pub_server.publish_payload(payload, *args)
raise salt.ext.tornado.gen.Return(ret)
def publish(self, int_payload):
def publish(self, payload):
"""
Publish "load" to minions
"""
@ -1256,22 +1061,7 @@ class TCPPubServerChannel:
loop_kwarg="io_loop",
)
pub_sock.connect()
pub_sock.send(int_payload)
def publish_filters(self, payload, tgt_type, tgt, ckminions):
payload = {"payload": self.serial.dumps(payload)}
# add some targeting stuff for lists only (for now)
if tgt_type == "list" and not self.opts.get("order_masters", False):
if isinstance(tgt, str):
# Fetch a list of minions that match
_res = self.ckminions.check_minions(tgt, tgt_type=tgt_type)
match_ids = _res["minions"]
log.debug("Publish Side Match: %s", match_ids)
# Send list of miions thru so zmq can target them
payload["topic_lst"] = match_ids
else:
payload["topic_lst"] = tgt
return payload
pub_sock.send(payload)
class TCPReqChannel(ResolverMixin):
@ -1287,19 +1077,14 @@ class TCPReqChannel(ResolverMixin):
master_host, master_port = parse.netloc.rsplit(":", 1)
master_addr = (master_host, int(master_port))
resolver = kwargs.get("resolver")
self.message_client = salt.transport.tcp.SaltMessageClientPool(
self.message_client = salt.transport.tcp.MessageClient(
opts,
args=(
opts,
master_host,
int(master_port),
),
kwargs={
"io_loop": io_loop,
"resolver": resolver,
"source_ip": opts.get("source_ip"),
"source_port": opts.get("source_ret_port"),
},
master_host,
int(master_port),
io_loop=io_loop,
resolver=resolver,
source_ip=opts.get("source_ip"),
source_port=opts.get("source_ret_port"),
)
@salt.ext.tornado.gen.coroutine

View file

@ -16,16 +16,11 @@ import salt.ext.tornado.gen
import salt.ext.tornado.ioloop
import salt.log.setup
import salt.payload
import salt.transport.client
import salt.transport.server
import salt.utils.files
import salt.utils.minions
import salt.utils.process
import salt.utils.stringutils
import salt.utils.zeromq
import zmq.asyncio
import zmq.error
import zmq.eventloop.ioloop
import zmq.eventloop.zmqstream
from salt._compat import ipaddress
from salt.exceptions import SaltReqTimeoutError
@ -479,37 +474,6 @@ def _set_tcp_keepalive(zmq_socket, opts):
zmq_socket.setsockopt(zmq.TCP_KEEPALIVE_INTVL, opts["tcp_keepalive_intvl"])
class AsyncReqMessageClientPool(salt.transport.MessageClientPool):
"""
Wrapper class of AsyncReqMessageClientPool to avoid blocking waiting while writing data to socket.
"""
ttype = "zeromq"
def __init__(self, opts, args=None, kwargs=None):
self._closing = False
super().__init__(AsyncReqMessageClient, opts, args=args, kwargs=kwargs)
def close(self):
if self._closing:
return
self._closing = True
for message_client in self.message_clients:
message_client.close()
self.message_clients = []
def send(self, *args, **kwargs):
message_clients = sorted(self.message_clients, key=lambda x: len(x.send_queue))
return message_clients[0].send(*args, **kwargs)
def __enter__(self):
return self
def __exit__(self, *args):
self.close()
# TODO: unit tests!
class AsyncReqMessageClient:
"""
@ -677,6 +641,7 @@ class AsyncReqMessageClient:
else:
future.set_exception(SaltReqTimeoutError("Message timed out"))
@salt.ext.tornado.gen.coroutine
def send(
self, message, timeout=None, tries=3, future=None, callback=None, raw=False
):
@ -707,14 +672,20 @@ class AsyncReqMessageClient:
send_timeout = self.io_loop.call_later(
timeout, self.timeout_message, message
)
self.send_timeout_map[message] = send_timeout
if len(self.send_queue) == 0:
self.io_loop.spawn_callback(self._internal_send_recv)
self.send_queue.append(message)
def mark_future(msg):
if not future.done():
data = self.serial.loads(msg[0])
future.set_result(data)
self.send_future_map.pop(message)
return future
self.stream.on_recv(mark_future)
yield self.stream.send(message)
recv = yield future
raise salt.ext.tornado.gen.Return(recv)
class ZeroMQSocketMonitor:
@ -788,120 +759,87 @@ class ZeroMQPubServerChannel:
def __init__(self, opts):
self.opts = opts
self.serial = salt.payload.Serial(self.opts) # TODO: in init?
self.ckminions = salt.utils.minions.CkMinions(self.opts)
self.serial = salt.payload.Serial(self.opts)
def connect(self):
return salt.ext.tornado.gen.sleep(5)
def publish_daemon(self, publish_payload, *args, **kwargs):
"""
Bind to the interface specified in the configuration file
This method represents the Publish Daemon process. It is intended to be
run inn a thread or process as it creates and runs an it's own ioloop.
"""
context = zmq.Context(1)
ioloop = salt.ext.tornado.ioloop.IOLoop()
ioloop.make_current()
# Set up the context
context = zmq.Context(1)
pub_sock = context.socket(zmq.PUB)
monitor = ZeroMQSocketMonitor(pub_sock)
monitor.start_io_loop(ioloop)
_set_tcp_keepalive(pub_sock, self.opts)
self.dpub_sock = pub_sock = zmq.eventloop.zmqstream.ZMQStream(pub_sock)
# if 2.1 >= zmq < 3.0, we only have one HWM setting
try:
pub_sock.setsockopt(zmq.HWM, self.opts.get("pub_hwm", 1000))
# in zmq >= 3.0, there are separate send and receive HWM settings
except AttributeError:
# Set the High Water Marks. For more information on HWM, see:
# http://api.zeromq.org/4-1:zmq-setsockopt
pub_sock.setsockopt(zmq.SNDHWM, self.opts.get("pub_hwm", 1000))
pub_sock.setsockopt(zmq.RCVHWM, self.opts.get("pub_hwm", 1000))
if self.opts["ipv6"] is True and hasattr(zmq, "IPV4ONLY"):
# IPv6 sockets work for both IPv6 and IPv4 addresses
pub_sock.setsockopt(zmq.IPV4ONLY, 0)
pub_sock.setsockopt(zmq.BACKLOG, self.opts.get("zmq_backlog", 1000))
pub_sock.setsockopt(zmq.LINGER, -1)
# Prepare minion pull socket
pull_sock = context.socket(zmq.PULL)
pull_sock = zmq.eventloop.zmqstream.ZMQStream(pull_sock)
pull_sock.setsockopt(zmq.LINGER, -1)
salt.utils.zeromq.check_ipc_path_max_len(self.pull_uri)
# Start the minion command publisher
log.info("Starting the Salt Publisher on %s", self.pub_uri)
pub_sock.bind(self.pub_uri)
# Securely create socket
log.info("Starting the Salt Puller on %s", self.pull_uri)
with salt.utils.files.set_umask(0o177):
pull_sock.bind(self.pull_uri)
@salt.ext.tornado.gen.coroutine
def daemon():
pub_sock = context.socket(zmq.PUB)
monitor = ZeroMQSocketMonitor(pub_sock)
monitor.start_io_loop(ioloop)
_set_tcp_keepalive(pub_sock, self.opts)
self.dpub_sock = pub_sock = zmq.eventloop.zmqstream.ZMQStream(pub_sock)
# if 2.1 >= zmq < 3.0, we only have one HWM setting
try:
pub_sock.setsockopt(zmq.HWM, self.opts.get("pub_hwm", 1000))
# in zmq >= 3.0, there are separate send and receive HWM settings
except AttributeError:
# Set the High Water Marks. For more information on HWM, see:
# http://api.zeromq.org/4-1:zmq-setsockopt
pub_sock.setsockopt(zmq.SNDHWM, self.opts.get("pub_hwm", 1000))
pub_sock.setsockopt(zmq.RCVHWM, self.opts.get("pub_hwm", 1000))
if self.opts["ipv6"] is True and hasattr(zmq, "IPV4ONLY"):
# IPv6 sockets work for both IPv6 and IPv4 addresses
pub_sock.setsockopt(zmq.IPV4ONLY, 0)
pub_sock.setsockopt(zmq.BACKLOG, self.opts.get("zmq_backlog", 1000))
pub_sock.setsockopt(zmq.LINGER, -1)
pub_uri = "tcp://{interface}:{publish_port}".format(**self.opts)
# Prepare minion pull socket
pull_sock = context.socket(zmq.PULL)
pull_sock = zmq.eventloop.zmqstream.ZMQStream(pull_sock)
pull_sock.setsockopt(zmq.LINGER, -1)
def on_recv(packages):
for package in packages:
payload = self.serial.loads(package)
yield publish_payload(payload)
if self.opts.get("ipc_mode", "") == "tcp":
pull_uri = "tcp://127.0.0.1:{}".format(
self.opts.get("tcp_master_publish_pull", 4514)
)
else:
pull_uri = "ipc://{}".format(
os.path.join(self.opts["sock_dir"], "publish_pull.ipc")
)
salt.utils.zeromq.check_ipc_path_max_len(pull_uri)
# Start the minion command publisher
log.info("Starting the Salt Publisher on %s", pub_uri)
log.error("PubSever pub_sock %s", pub_uri)
pub_sock.bind(pub_uri)
log.error("PubSever pull_sock %s", pull_uri)
pull_sock.on_recv(on_recv)
try:
ioloop.start()
finally:
context.term()
# Securely create socket
log.info("Starting the Salt Puller on %s", pull_uri)
with salt.utils.files.set_umask(0o177):
pull_sock.bind(pull_uri)
@property
def pull_uri(self):
if self.opts.get("ipc_mode", "") == "tcp":
pull_uri = "tcp://127.0.0.1:{}".format(
self.opts.get("tcp_master_publish_pull", 4514)
)
else:
pull_uri = "ipc://{}".format(
os.path.join(self.opts["sock_dir"], "publish_pull.ipc")
)
return pull_uri
@salt.ext.tornado.gen.coroutine
def on_recv(packages):
for package in packages:
payload = self.serial.loads(package)
yield publish_payload(payload)
pull_sock.on_recv(on_recv)
# try:
# while True:
# # Catch and handle EINTR from when this process is sent
# # SIGUSR1 gracefully so we don't choke and die horribly
# try:
# log.debug("Publish daemon getting data from puller %s", pull_uri)
# package = yield pull_sock.recv()
# payload = self.serial.loads(package)
# #unpacked_package = salt.payload.unpackage(package)
# #unpacked_package = salt.transport.frame.decode_embedded_strs(
# # unpacked_package
# #)
# yield publish_payload(payload)
# except zmq.ZMQError as exc:
# if exc.errno == errno.EINTR:
# continue
# raise
# except KeyboardInterrupt:
# log.trace("Publish daemon caught Keyboard interupt, tearing down")
## Cleanly close the sockets if we're shutting down
# if pub_sock.closed is False:
# pub_sock.close()
# if pull_sock.closed is False:
# pull_sock.close()
# if context.closed is False:
# context.term()
ioloop.add_callback(daemon)
ioloop.start()
context.term()
@property
def pub_uri(self):
return "tcp://{interface}:{publish_port}".format(**self.opts)
@salt.ext.tornado.gen.coroutine
def publish_payload(self, unpacked_package):
# XXX: Sort this out
if "payload" in unpacked_package:
payload = unpacked_package["payload"]
else:
payload = self.serial.dumps(unpacked_package)
def publish_payload(self, payload, topic_list=None):
payload = self.serial.dumps(payload)
if self.opts["zmq_filtering"]:
# if you have a specific topic list, use that
if "topic_lst" in unpacked_package:
for topic in unpacked_package["topic_lst"]:
# log.trace(
# "Sending filtered data over publisher %s", pub_uri
# )
if topic_list:
for topic in topic_list:
log.trace("Sending filtered data over publisher %s", self.pub_uri)
# zmq filters are substring match, hash the topic
# to avoid collisions
htopic = salt.utils.stringutils.to_bytes(
@ -910,7 +848,6 @@ class ZeroMQPubServerChannel:
yield self.dpub_sock.send(htopic, flags=zmq.SNDMORE)
yield self.dpub_sock.send(payload)
log.trace("Filtered data has been sent")
# Syndic broadcast
if self.opts.get("order_masters"):
log.trace("Sending filtered data to syndic")
@ -920,19 +857,13 @@ class ZeroMQPubServerChannel:
# otherwise its a broadcast
else:
# TODO: constants file for "broadcast"
# log.trace(
# "Sending broadcasted data over publisher %s", pub_uri
# )
log.trace("Sending broadcasted data over publisher %s", self.pub_uri)
yield self.dpub_sock.send(b"broadcast", flags=zmq.SNDMORE)
yield self.dpub_sock.send(payload)
log.trace("Broadcasted data has been sent")
else:
# log.trace(
# "Sending ZMQ-unfiltered data over publisher %s", pub_uri
# )
# yield self.dpub_sock.send(b"", flags=zmq.SNDMORE)
log.trace("Sending ZMQ-unfiltered data over publisher %s", self.pub_uri)
yield self.dpub_sock.send(payload)
print("unfiltered sent {}".format(repr(payload)))
log.trace("Unfiltered data has been sent")
def pre_fork(self, process_manager, kwargs=None):
@ -991,33 +922,22 @@ class ZeroMQPubServerChannel:
self._sock_data.sock.close()
delattr(self._sock_data, "sock")
def publish(self, int_payload):
def publish(self, payload):
"""
Publish "load" to minions. This send the load to the publisher daemon
process with does the actual sending to minions.
:param dict load: A load to be sent across the wire to minions
"""
payload = self.serial.dumps(int_payload)
if not self.pub_sock:
self.pub_connect()
self.pub_sock.send(payload)
serialized = self.serial.dumps(payload)
self.pub_sock.send(serialized)
log.debug("Sent payload to publish daemon.")
def publish_filters(self, payload, tgt_type, tgt, ckminions):
payload = {"payload": self.serial.dumps(payload)}
match_targets = ["pcre", "glob", "list"]
# TODO: Some of this is happening in the server's pubish method. This
# needs to be sorted out between TCP and ZMQ, maybe make a method on
# each class to handle this the way it's needed for the implimentation.
if self.opts["zmq_filtering"] and tgt_type in match_targets:
# Fetch a list of minions that match
_res = self.ckminions.check_minions(tgt, tgt_type=tgt_type)
match_ids = _res["minions"]
log.debug("Publish Side Match: %s", match_ids)
# Send list of miions thru so zmq can target them
payload["topic_lst"] = match_ids
return payload
@property
def topic_support(self):
return self.opts.get("zmq_filtering", False)
class ZeroMQReqChannel:
@ -1026,13 +946,10 @@ class ZeroMQReqChannel:
def __init__(self, opts, master_uri, io_loop):
self.opts = opts
self.master_uri = master_uri
self.message_client = AsyncReqMessageClientPool(
self.message_client = AsyncReqMessageClient(
self.opts,
args=(
self.opts,
self.master_uri,
),
kwargs={"io_loop": io_loop},
self.master_uri,
io_loop=io_loop,
)
@salt.ext.tornado.gen.coroutine

15
salt/utils/channel.py Normal file
View file

@ -0,0 +1,15 @@
def iter_transport_opts(opts):
"""
Yield transport, opts for all master configured transports
"""
transports = set()
for transport, opts_overrides in opts.get("transport_opts", {}).items():
t_opts = dict(opts)
t_opts.update(opts_overrides)
t_opts["transport"] = transport
transports.add(transport)
yield transport, t_opts
if opts["transport"] not in transports:
yield opts["transport"], opts

View file

@ -39,9 +39,12 @@ except ImportError:
crypt = Cryptodome
except ImportError:
import Crypto
try:
import Crypto
crypt = Crypto
crypt = Crypto
except ImportError:
crypt = None
# This is needed until we drop support for python 3.6
@ -440,8 +443,9 @@ def get_tops(extra_mods="", so_mods=""):
ssl_match_hostname,
markupsafe,
backports_abc,
crypt,
]
if crypt:
mods.append(crypt)
modules = find_site_modules("contextvars")
if modules:
contextvars = modules[0]

View file

@ -251,9 +251,6 @@ salt/(cli/spm\.py|spm/.+):
- integration.spm.test_remove
- integration.spm.test_repo
salt/transport/*:
- unit.test_transport
salt/utils/docker/*:
- unit.utils.test_dockermod

View file

@ -1,94 +1,14 @@
import contextlib
import socket
import attr
import pytest
import salt.exceptions
import salt.ext.tornado
import salt.transport.tcp
from salt.ext.tornado import concurrent, gen, ioloop
from saltfactories.utils.ports import get_unused_localhost_port
from tests.support.mock import MagicMock, patch
@pytest.fixture
def message_client_pool():
sock_pool_size = 5
opts = {"sock_pool_size": sock_pool_size, "transport": "tcp"}
message_client_args = (
opts.copy(), # opts,
"", # host
0, # port
)
with patch(
"salt.transport.tcp.SaltMessageClient.__init__",
MagicMock(return_value=None),
):
message_client_pool = salt.transport.tcp.SaltMessageClientPool(
opts, args=message_client_args
)
original_message_clients = message_client_pool.message_clients[:]
message_client_pool.message_clients = [MagicMock() for _ in range(sock_pool_size)]
try:
yield message_client_pool
finally:
with patch(
"salt.transport.tcp.SaltMessageClient.close", MagicMock(return_value=None)
):
del original_message_clients
class TestSaltMessageClientPool:
def test_send(self, message_client_pool):
for message_client_mock in message_client_pool.message_clients:
message_client_mock.send_queue = [0, 0, 0]
message_client_mock.send.return_value = []
assert message_client_pool.send() == []
message_client_pool.message_clients[2].send_queue = [0]
message_client_pool.message_clients[2].send.return_value = [1]
assert message_client_pool.send() == [1]
def test_write_to_stream(self, message_client_pool):
for message_client_mock in message_client_pool.message_clients:
message_client_mock.send_queue = [0, 0, 0]
message_client_mock._stream.write.return_value = []
assert message_client_pool.write_to_stream("") == []
message_client_pool.message_clients[2].send_queue = [0]
message_client_pool.message_clients[2]._stream.write.return_value = [1]
assert message_client_pool.write_to_stream("") == [1]
def test_close(self, message_client_pool):
message_client_pool.close()
assert message_client_pool.message_clients == []
def test_on_recv(self, message_client_pool):
for message_client_mock in message_client_pool.message_clients:
message_client_mock.on_recv.return_value = None
message_client_pool.on_recv()
for message_client_mock in message_client_pool.message_clients:
assert message_client_mock.on_recv.called
async def test_connect_all(self, message_client_pool):
for message_client_mock in message_client_pool.message_clients:
future = concurrent.Future()
future.set_result("foo")
message_client_mock.connect.return_value = future
connected = await message_client_pool.connect()
assert connected is None
async def test_connect_partial(self, io_loop, message_client_pool):
for idx, message_client_mock in enumerate(message_client_pool.message_clients):
future = concurrent.Future()
if idx % 2 == 0:
future.set_result("foo")
message_client_mock.connect.return_value = future
with pytest.raises(gen.TimeoutError):
future = message_client_pool.connect()
await gen.with_timeout(io_loop.time() + 0.1, future)
@attr.s(frozen=True, slots=True)
class ClientSocket:
listen_on = attr.ib(init=False, default="127.0.0.1")
@ -119,11 +39,11 @@ def test_message_client_cleanup_on_close(client_socket, temp_salt_master):
"""
test message client cleanup on close
"""
orig_loop = ioloop.IOLoop()
orig_loop = salt.ext.tornado.ioloop.IOLoop()
orig_loop.make_current()
opts = dict(temp_salt_master.config.copy(), transport="tcp")
client = salt.transport.tcp.SaltMessageClient(
client = salt.transport.tcp.MessageClient(
opts, client_socket.listen_on, client_socket.port
)
@ -146,12 +66,12 @@ def test_message_client_cleanup_on_close(client_socket, temp_salt_master):
assert orig_loop.stop_called is True
# The run_sync call will set stop_called, reset it
orig_loop.stop_called = False
# orig_loop.stop_called = False
client.close()
# Stop should be called again, client's io_loop should be None
assert orig_loop.stop_called is True
assert client.io_loop is None
# assert orig_loop.stop_called is True
# assert client.io_loop is None
finally:
orig_loop.stop = orig_loop.real_stop
del orig_loop.real_stop
@ -160,121 +80,92 @@ def test_message_client_cleanup_on_close(client_socket, temp_salt_master):
orig_loop.close(all_fds=True)
async def test_async_tcp_pub_channel_connect_publish_port(
temp_salt_master, client_socket
):
"""
test when publish_port is not 4506
"""
opts = dict(
temp_salt_master.config.copy(),
master_uri="",
master_ip="127.0.0.1",
publish_port=1234,
transport="tcp",
acceptance_wait_time=5,
acceptance_wait_time_max=5,
)
ioloop = salt.ext.tornado.ioloop.IOLoop.current()
channel = salt.transport.tcp.AsyncTCPPubChannel(opts, ioloop)
patch_auth = MagicMock(return_value=True)
patch_client_pool = MagicMock(spec=salt.transport.tcp.SaltMessageClientPool)
with patch("salt.crypt.AsyncAuth.gen_token", patch_auth), patch(
"salt.crypt.AsyncAuth.authenticated", patch_auth
), patch("salt.transport.tcp.SaltMessageClientPool", patch_client_pool):
with channel:
# We won't be able to succeed the connection because we're not mocking the tornado coroutine
with pytest.raises(salt.exceptions.SaltClientError):
await channel.connect()
# The first call to the mock is the instance's __init__, and the first argument to those calls is the opts dict
assert patch_client_pool.call_args[0][0]["publish_port"] == opts["publish_port"]
# XXX: Test channel for this
# def test_tcp_pub_server_channel_publish_filtering(temp_salt_master):
# opts = dict(
# temp_salt_master.config.copy(),
# sign_pub_messages=False,
# transport="tcp",
# acceptance_wait_time=5,
# acceptance_wait_time_max=5,
# )
# with patch("salt.master.SMaster.secrets") as secrets, patch(
# "salt.crypt.Crypticle"
# ) as crypticle, patch("salt.utils.asynchronous.SyncWrapper") as SyncWrapper:
# channel = salt.transport.tcp.TCPPubServerChannel(opts)
# wrap = MagicMock()
# crypt = MagicMock()
# crypt.dumps.return_value = {"test": "value"}
#
# secrets.return_value = {"aes": {"secret": None}}
# crypticle.return_value = crypt
# SyncWrapper.return_value = wrap
#
# # try simple publish with glob tgt_type
# channel.publish({"test": "value", "tgt_type": "glob", "tgt": "*"})
# payload = wrap.send.call_args[0][0]
#
# # verify we send it without any specific topic
# assert "topic_lst" not in payload
#
# # try simple publish with list tgt_type
# channel.publish({"test": "value", "tgt_type": "list", "tgt": ["minion01"]})
# payload = wrap.send.call_args[0][0]
#
# # verify we send it with correct topic
# assert "topic_lst" in payload
# assert payload["topic_lst"] == ["minion01"]
#
# # try with syndic settings
# opts["order_masters"] = True
# channel.publish({"test": "value", "tgt_type": "list", "tgt": ["minion01"]})
# payload = wrap.send.call_args[0][0]
#
# # verify we send it without topic for syndics
# assert "topic_lst" not in payload
def test_tcp_pub_server_channel_publish_filtering(temp_salt_master):
opts = dict(
temp_salt_master.config.copy(),
sign_pub_messages=False,
transport="tcp",
acceptance_wait_time=5,
acceptance_wait_time_max=5,
)
with patch("salt.master.SMaster.secrets") as secrets, patch(
"salt.crypt.Crypticle"
) as crypticle, patch("salt.utils.asynchronous.SyncWrapper") as SyncWrapper:
channel = salt.transport.tcp.TCPPubServerChannel(opts)
wrap = MagicMock()
crypt = MagicMock()
crypt.dumps.return_value = {"test": "value"}
secrets.return_value = {"aes": {"secret": None}}
crypticle.return_value = crypt
SyncWrapper.return_value = wrap
# try simple publish with glob tgt_type
channel.publish({"test": "value", "tgt_type": "glob", "tgt": "*"})
payload = wrap.send.call_args[0][0]
# verify we send it without any specific topic
assert "topic_lst" not in payload
# try simple publish with list tgt_type
channel.publish({"test": "value", "tgt_type": "list", "tgt": ["minion01"]})
payload = wrap.send.call_args[0][0]
# verify we send it with correct topic
assert "topic_lst" in payload
assert payload["topic_lst"] == ["minion01"]
# try with syndic settings
opts["order_masters"] = True
channel.publish({"test": "value", "tgt_type": "list", "tgt": ["minion01"]})
payload = wrap.send.call_args[0][0]
# verify we send it without topic for syndics
assert "topic_lst" not in payload
def test_tcp_pub_server_channel_publish_filtering_str_list(temp_salt_master):
opts = dict(
temp_salt_master.config.copy(),
transport="tcp",
sign_pub_messages=False,
acceptance_wait_time=5,
acceptance_wait_time_max=5,
)
with patch("salt.master.SMaster.secrets") as secrets, patch(
"salt.crypt.Crypticle"
) as crypticle, patch("salt.utils.asynchronous.SyncWrapper") as SyncWrapper, patch(
"salt.utils.minions.CkMinions.check_minions"
) as check_minions:
channel = salt.transport.tcp.TCPPubServerChannel(opts)
wrap = MagicMock()
crypt = MagicMock()
crypt.dumps.return_value = {"test": "value"}
secrets.return_value = {"aes": {"secret": None}}
crypticle.return_value = crypt
SyncWrapper.return_value = wrap
check_minions.return_value = {"minions": ["minion02"]}
# try simple publish with list tgt_type
channel.publish({"test": "value", "tgt_type": "list", "tgt": "minion02"})
payload = wrap.send.call_args[0][0]
# verify we send it with correct topic
assert "topic_lst" in payload
assert payload["topic_lst"] == ["minion02"]
# verify it was correctly calling check_minions
check_minions.assert_called_with("minion02", tgt_type="list")
# def test_tcp_pub_server_channel_publish_filtering_str_list(temp_salt_master):
# opts = dict(
# temp_salt_master.config.copy(),
# transport="tcp",
# sign_pub_messages=False,
# acceptance_wait_time=5,
# acceptance_wait_time_max=5,
# )
# with patch("salt.master.SMaster.secrets") as secrets, patch(
# "salt.crypt.Crypticle"
# ) as crypticle, patch("salt.utils.asynchronous.SyncWrapper") as SyncWrapper, patch(
# "salt.utils.minions.CkMinions.check_minions"
# ) as check_minions:
# channel = salt.transport.tcp.TCPPubServerChannel(opts)
# wrap = MagicMock()
# crypt = MagicMock()
# crypt.dumps.return_value = {"test": "value"}
#
# secrets.return_value = {"aes": {"secret": None}}
# crypticle.return_value = crypt
# SyncWrapper.return_value = wrap
# check_minions.return_value = {"minions": ["minion02"]}
#
# # try simple publish with list tgt_type
# channel.publish({"test": "value", "tgt_type": "list", "tgt": "minion02"})
# payload = wrap.send.call_args[0][0]
#
# # verify we send it with correct topic
# assert "topic_lst" in payload
# assert payload["topic_lst"] == ["minion02"]
#
# # verify it was correctly calling check_minions
# check_minions.assert_called_with("minion02", tgt_type="list")
@pytest.fixture(scope="function")
def salt_message_client():
io_loop_mock = MagicMock(spec=ioloop.IOLoop)
io_loop_mock = MagicMock(spec=salt.ext.tornado.ioloop.IOLoop)
io_loop_mock.call_later.side_effect = lambda *args, **kwargs: (args, kwargs)
client = salt.transport.tcp.SaltMessageClient(
client = salt.transport.tcp.MessageClient(
{}, "127.0.0.1", get_unused_localhost_port(), io_loop=io_loop_mock
)
@ -284,98 +175,100 @@ def salt_message_client():
client.close()
def test_send_future_set_retry(salt_message_client):
future = salt_message_client.send({"some": "message"}, tries=10, timeout=30)
# assert we have proper props in future
assert future.tries == 10
assert future.timeout == 30
assert future.attempts == 0
# assert the timeout callback was created
assert len(salt_message_client.send_queue) == 1
message_id = salt_message_client.send_queue.pop()[0]
assert message_id in salt_message_client.send_timeout_map
timeout = salt_message_client.send_timeout_map[message_id]
assert timeout[0][0] == 30
assert timeout[0][2] == message_id
assert timeout[0][3] == {"some": "message"}
# try again, now with set future
future.attempts = 1
future = salt_message_client.send(
{"some": "message"}, tries=10, timeout=30, future=future
)
# assert we have proper props in future
assert future.tries == 10
assert future.timeout == 30
assert future.attempts == 1
# assert the timeout callback was created
assert len(salt_message_client.send_queue) == 1
message_id_new = salt_message_client.send_queue.pop()[0]
# check new message id is generated
assert message_id != message_id_new
assert message_id_new in salt_message_client.send_timeout_map
timeout = salt_message_client.send_timeout_map[message_id_new]
assert timeout[0][0] == 30
assert timeout[0][2] == message_id_new
assert timeout[0][3] == {"some": "message"}
# XXX we don't reutnr a future anymore, this needs a different way of testing.
# def test_send_future_set_retry(salt_message_client):
# future = salt_message_client.send({"some": "message"}, tries=10, timeout=30)
#
# # assert we have proper props in future
# assert future.tries == 10
# assert future.timeout == 30
# assert future.attempts == 0
#
# # assert the timeout callback was created
# assert len(salt_message_client.send_queue) == 1
# message_id = salt_message_client.send_queue.pop()[0]
#
# assert message_id in salt_message_client.send_timeout_map
#
# timeout = salt_message_client.send_timeout_map[message_id]
# assert timeout[0][0] == 30
# assert timeout[0][2] == message_id
# assert timeout[0][3] == {"some": "message"}
#
# # try again, now with set future
# future.attempts = 1
#
# future = salt_message_client.send(
# {"some": "message"}, tries=10, timeout=30, future=future
# )
#
# # assert we have proper props in future
# assert future.tries == 10
# assert future.timeout == 30
# assert future.attempts == 1
#
# # assert the timeout callback was created
# assert len(salt_message_client.send_queue) == 1
# message_id_new = salt_message_client.send_queue.pop()[0]
#
# # check new message id is generated
# assert message_id != message_id_new
#
# assert message_id_new in salt_message_client.send_timeout_map
#
# timeout = salt_message_client.send_timeout_map[message_id_new]
# assert timeout[0][0] == 30
# assert timeout[0][2] == message_id_new
# assert timeout[0][3] == {"some": "message"}
def test_timeout_message_retry(salt_message_client):
# verify send is triggered with first retry
msg = {"some": "message"}
future = salt_message_client.send(msg, tries=1, timeout=30)
assert future.attempts == 0
timeout = next(iter(salt_message_client.send_timeout_map.values()))
message_id_1 = timeout[0][2]
message_body_1 = timeout[0][3]
assert message_body_1 == msg
# trigger timeout callback
salt_message_client.timeout_message(message_id_1, message_body_1)
# assert send got called, yielding potentially new message id, but same message
future_new = next(iter(salt_message_client.send_future_map.values()))
timeout_new = next(iter(salt_message_client.send_timeout_map.values()))
message_id_2 = timeout_new[0][2]
message_body_2 = timeout_new[0][3]
assert future_new.attempts == 1
assert future.tries == future_new.tries
assert future.timeout == future_new.timeout
assert message_body_1 == message_body_2
# now try again, should not call send
with contextlib.suppress(salt.exceptions.SaltReqTimeoutError):
salt_message_client.timeout_message(message_id_2, message_body_2)
raise future_new.exception()
# assert it's really "consumed"
assert message_id_2 not in salt_message_client.send_future_map
assert message_id_2 not in salt_message_client.send_timeout_map
# def test_timeout_message_retry(salt_message_client):
# # verify send is triggered with first retry
# msg = {"some": "message"}
# future = salt_message_client.send(msg, tries=1, timeout=30)
# assert future.attempts == 0
#
# timeout = next(iter(salt_message_client.send_timeout_map.values()))
# message_id_1 = timeout[0][2]
# message_body_1 = timeout[0][3]
#
# assert message_body_1 == msg
#
# # trigger timeout callback
# salt_message_client.timeout_message(message_id_1, message_body_1)
#
# # assert send got called, yielding potentially new message id, but same message
# future_new = next(iter(salt_message_client.send_future_map.values()))
# timeout_new = next(iter(salt_message_client.send_timeout_map.values()))
#
# message_id_2 = timeout_new[0][2]
# message_body_2 = timeout_new[0][3]
#
# assert future_new.attempts == 1
# assert future.tries == future_new.tries
# assert future.timeout == future_new.timeout
#
# assert message_body_1 == message_body_2
#
# # now try again, should not call send
# with contextlib.suppress(salt.exceptions.SaltReqTimeoutError):
# salt_message_client.timeout_message(message_id_2, message_body_2)
# raise future_new.exception()
#
# # assert it's really "consumed"
# assert message_id_2 not in salt_message_client.send_future_map
# assert message_id_2 not in salt_message_client.send_timeout_map
def test_timeout_message_unknown_future(salt_message_client):
# test we don't fail on unknown message_id
salt_message_client.timeout_message(-1, "message")
# # test we don't fail on unknown message_id
# salt_message_client.timeout_message(-1, "message")
# if we do have the actual future stored under the id, but it's none
# we shouldn't fail as well
message_id = 1
salt_message_client.send_future_map[message_id] = None
future = salt.ext.tornado.concurrent.Future()
salt_message_client.send_future_map[message_id] = future
salt_message_client.timeout_message(message_id, "message")
@ -383,22 +276,26 @@ def test_timeout_message_unknown_future(salt_message_client):
def test_client_reconnect_backoff(client_socket):
opts = {"tcp_reconnect_backoff": 20.3}
opts = {"tcp_reconnect_backoff": 5}
client = salt.transport.tcp.SaltMessageClient(
client = salt.transport.tcp.MessageClient(
opts, client_socket.listen_on, client_socket.port
)
def _sleep(t):
client.close()
assert t == 20.3
assert t == 5
return
# return salt.ext.tornado.gen.sleep()
@salt.ext.tornado.gen.coroutine
def connect(*args, **kwargs):
raise Exception("err")
client._tcp_client.connect = connect
try:
with patch("salt.ext.tornado.gen.sleep", side_effect=_sleep), patch(
"salt.transport.tcp.TCPClientKeepAlive.connect",
side_effect=Exception("err"),
):
client.io_loop.run_sync(client._connect)
with patch("salt.ext.tornado.gen.sleep", side_effect=_sleep):
client.io_loop.run_sync(client.connect)
finally:
client.close()

View file

@ -11,10 +11,10 @@ import salt.ext.tornado.ioloop
import salt.log.setup
import salt.transport.client
import salt.transport.server
import salt.transport.zeromq
import salt.utils.platform
import salt.utils.process
import salt.utils.stringutils
from salt.transport.zeromq import AsyncReqMessageClientPool
from tests.support.mock import MagicMock, patch
@ -67,30 +67,6 @@ def test_master_uri():
) == "tcp://0.0.0.0:{};{}:{}".format(s_port, m_ip, m_port)
def test_async_req_message_client_pool_send():
sock_pool_size = 5
with patch(
"salt.transport.zeromq.AsyncReqMessageClient.__init__",
MagicMock(return_value=None),
):
message_client_pool = AsyncReqMessageClientPool(
{"sock_pool_size": sock_pool_size}, args=({}, "")
)
message_client_pool.message_clients = [
MagicMock() for _ in range(sock_pool_size)
]
for message_client_mock in message_client_pool.message_clients:
message_client_mock.send_queue = [0, 0, 0]
message_client_mock.send.return_value = []
with message_client_pool:
assert message_client_pool.send() == []
message_client_pool.message_clients[2].send_queue = [0]
message_client_pool.message_clients[2].send.return_value = [1]
assert message_client_pool.send() == [1]
def test_clear_req_channel_master_uri_override(temp_salt_minion, temp_salt_master):
"""
ensure master_uri kwarg is respected
@ -137,15 +113,13 @@ def test_zeromq_async_pub_channel_publish_port(temp_salt_master):
sign_pub_messages=False,
)
opts["master_uri"] = "tcp://{interface}:{publish_port}".format(**opts)
channel = salt.transport.zeromq.AsyncZeroMQPubChannel(opts)
ioloop = salt.ext.tornado.ioloop.IOLoop()
channel = salt.transport.zeromq.AsyncZeroMQPubChannel(opts, ioloop)
with channel:
patch_socket = MagicMock(return_value=True)
patch_auth = MagicMock(return_value=True)
with patch.object(channel, "_socket", patch_socket), patch.object(
channel, "auth", patch_auth
):
channel.connect()
with patch.object(channel, "_socket", patch_socket):
channel.connect(455505)
assert str(opts["publish_port"]) in patch_socket.mock_calls[0][1][0]
@ -181,7 +155,8 @@ def test_zeromq_async_pub_channel_filtering_decode_message_no_match(
)
opts["master_uri"] = "tcp://{interface}:{publish_port}".format(**opts)
channel = salt.transport.zeromq.AsyncZeroMQPubChannel(opts)
ioloop = salt.ext.tornado.ioloop.IOLoop()
channel = salt.transport.zeromq.AsyncZeroMQPubChannel(opts, ioloop)
with channel:
with patch(
"salt.crypt.AsyncAuth.crypticle",
@ -227,7 +202,8 @@ def test_zeromq_async_pub_channel_filtering_decode_message(
)
opts["master_uri"] = "tcp://{interface}:{publish_port}".format(**opts)
channel = salt.transport.zeromq.AsyncZeroMQPubChannel(opts)
ioloop = salt.ext.tornado.ioloop.IOLoop()
channel = salt.transport.zeromq.AsyncZeroMQPubChannel(opts, ioloop)
with channel:
with patch(
"salt.crypt.AsyncAuth.crypticle",

View file

@ -1098,9 +1098,20 @@ class RemotePillarTestCase(TestCase):
def setUp(self):
self.grains = {}
self.opts = {
"pki_dir": "/tmp",
"id": "minion",
"master_uri": "tcp://127.0.0.1:4505",
"__role": "minion",
"keysize": 2048,
"renderer": "json",
"path_to_add": "fake_data",
"path_to_add2": {"fake_data2": ["fake_data3", "fake_data4"]},
"pass_to_ext_pillars": ["path_to_add", "path_to_add2"],
}
def tearDown(self):
for attr in ("grains",):
for attr in ("grains", "opts"):
try:
delattr(self, attr)
except AttributeError:
@ -1113,17 +1124,25 @@ class RemotePillarTestCase(TestCase):
mock_get_extra_minion_data,
):
salt.pillar.RemotePillar({}, self.grains, "mocked-minion", "dev")
mock_get_extra_minion_data.assert_called_once_with({"saltenv": "dev"})
salt.pillar.RemotePillar(self.opts, self.grains, "mocked-minion", "dev")
call_opts = dict(self.opts, saltenv="dev")
mock_get_extra_minion_data.assert_called_once_with(call_opts)
def test_multiple_keys_in_opts_added_to_pillar(self):
opts = {
"pki_dir": "/tmp",
"id": "minion",
"master_uri": "tcp://127.0.0.1:4505",
"__role": "minion",
"keysize": 2048,
"renderer": "json",
"path_to_add": "fake_data",
"path_to_add2": {"fake_data2": ["fake_data3", "fake_data4"]},
"pass_to_ext_pillars": ["path_to_add", "path_to_add2"],
}
pillar = salt.pillar.RemotePillar(opts, self.grains, "mocked-minion", "dev")
pillar = salt.pillar.RemotePillar(
self.opts, self.grains, "mocked-minion", "dev"
)
self.assertEqual(
pillar.extra_minion_data,
{
@ -1133,56 +1152,33 @@ class RemotePillarTestCase(TestCase):
)
def test_subkey_in_opts_added_to_pillar(self):
opts = {
"renderer": "json",
"path_to_add": "fake_data",
"path_to_add2": {
opts = dict(
self.opts,
path_to_add2={
"fake_data5": "fake_data6",
"fake_data2": ["fake_data3", "fake_data4"],
},
"pass_to_ext_pillars": ["path_to_add2:fake_data5"],
}
pass_to_ext_pillars=["path_to_add2:fake_data5"],
)
pillar = salt.pillar.RemotePillar(opts, self.grains, "mocked-minion", "dev")
self.assertEqual(
pillar.extra_minion_data, {"path_to_add2": {"fake_data5": "fake_data6"}}
)
def test_non_existent_leaf_opt_in_add_to_pillar(self):
opts = {
"renderer": "json",
"path_to_add": "fake_data",
"path_to_add2": {
"fake_data5": "fake_data6",
"fake_data2": ["fake_data3", "fake_data4"],
},
"pass_to_ext_pillars": ["path_to_add2:fake_data_non_exist"],
}
pillar = salt.pillar.RemotePillar(opts, self.grains, "mocked-minion", "dev")
pillar = salt.pillar.RemotePillar(
self.opts, self.grains, "mocked-minion", "dev"
)
self.assertEqual(pillar.pillar_override, {})
def test_non_existent_intermediate_opt_in_add_to_pillar(self):
opts = {
"renderer": "json",
"path_to_add": "fake_data",
"path_to_add2": {
"fake_data5": "fake_data6",
"fake_data2": ["fake_data3", "fake_data4"],
},
"pass_to_ext_pillars": ["path_to_add_no_exist"],
}
pillar = salt.pillar.RemotePillar(opts, self.grains, "mocked-minion", "dev")
pillar = salt.pillar.RemotePillar(
self.opts, self.grains, "mocked-minion", "dev"
)
self.assertEqual(pillar.pillar_override, {})
def test_malformed_add_to_pillar(self):
opts = {
"renderer": "json",
"path_to_add": "fake_data",
"path_to_add2": {
"fake_data5": "fake_data6",
"fake_data2": ["fake_data3", "fake_data4"],
},
"pass_to_ext_pillars": MagicMock(),
}
opts = dict(self.opts, pass_to_ext_pillars=MagicMock())
with self.assertRaises(salt.exceptions.SaltClientError) as excinfo:
salt.pillar.RemotePillar(opts, self.grains, "mocked-minion", "dev")
self.assertEqual(
@ -1191,6 +1187,11 @@ class RemotePillarTestCase(TestCase):
def test_pillar_send_extra_minion_data_from_config(self):
opts = {
"pki_dir": "/tmp",
"id": "minion",
"master_uri": "tcp://127.0.0.1:4505",
"__role": "minion",
"keysize": 2048,
"renderer": "json",
"pillarenv": "fake_pillar_env",
"path_to_add": "fake_data",
@ -1204,7 +1205,7 @@ class RemotePillarTestCase(TestCase):
crypted_transfer_decode_dictentry=MagicMock(return_value={})
)
with patch(
"salt.transport.client.ReqChannel.factory",
"salt.channel.client.ReqChannel.factory",
MagicMock(return_value=mock_channel),
):
pillar = salt.pillar.RemotePillar(
@ -1234,6 +1235,11 @@ class RemotePillarTestCase(TestCase):
"""
mocked_minion = MagicMock()
opts = {
"pki_dir": "/tmp",
"id": "minion",
"master_uri": "tcp://127.0.0.1:4505",
"__role": "minion",
"keysize": 2048,
"file_client": "local",
"use_master_when_local": True,
"pillar_cache": None,
@ -1243,7 +1249,7 @@ class RemotePillarTestCase(TestCase):
self.assertNotEqual(type(pillar), salt.pillar.PillarCache)
@patch("salt.transport.client.AsyncReqChannel.factory", MagicMock())
@patch("salt.channel.client.AsyncReqChannel.factory", MagicMock())
class AsyncRemotePillarTestCase(TestCase):
"""
Tests for instantiating a AsyncRemotePillar in salt.pillar
@ -1271,6 +1277,11 @@ class AsyncRemotePillarTestCase(TestCase):
def test_pillar_send_extra_minion_data_from_config(self):
opts = {
"pki_dir": "/tmp",
"id": "minion",
"master_uri": "tcp://127.0.0.1:4505",
"__role": "minion",
"keysize": 2048,
"renderer": "json",
"pillarenv": "fake_pillar_env",
"path_to_add": "fake_data",
@ -1284,7 +1295,7 @@ class AsyncRemotePillarTestCase(TestCase):
crypted_transfer_decode_dictentry=MagicMock(return_value={})
)
with patch(
"salt.transport.client.AsyncReqChannel.factory",
"salt.channel.client.AsyncReqChannel.factory",
MagicMock(return_value=mock_channel),
):
pillar = salt.pillar.RemotePillar(
@ -1307,7 +1318,7 @@ class AsyncRemotePillarTestCase(TestCase):
)
@patch("salt.transport.client.ReqChannel.factory", MagicMock())
@patch("salt.channel.client.ReqChannel.factory", MagicMock())
class PillarCacheTestCase(TestCase):
"""
Tests for instantiating a PillarCache in salt.pillar

View file

@ -1,53 +0,0 @@
import logging
from salt.transport import MessageClientPool
from tests.support.unit import TestCase
log = logging.getLogger(__name__)
class MessageClientPoolTest(TestCase):
class MockClass:
def __init__(self, *args, **kwargs):
self.args = args
self.kwargs = kwargs
def test_init(self):
opts = {"sock_pool_size": 10}
args = (0,)
kwargs = {"kwarg": 1}
message_client_pool = MessageClientPool(
self.MockClass, opts, args=args, kwargs=kwargs
)
self.assertEqual(
opts["sock_pool_size"], len(message_client_pool.message_clients)
)
for message_client in message_client_pool.message_clients:
self.assertEqual(message_client.args, args)
self.assertEqual(message_client.kwargs, kwargs)
def test_init_without_config(self):
opts = {}
args = (0,)
kwargs = {"kwarg": 1}
message_client_pool = MessageClientPool(
self.MockClass, opts, args=args, kwargs=kwargs
)
# The size of pool is set as 1 by the MessageClientPool init method.
self.assertEqual(1, len(message_client_pool.message_clients))
for message_client in message_client_pool.message_clients:
self.assertEqual(message_client.args, args)
self.assertEqual(message_client.kwargs, kwargs)
def test_init_less_than_one(self):
opts = {"sock_pool_size": -1}
args = (0,)
kwargs = {"kwarg": 1}
message_client_pool = MessageClientPool(
self.MockClass, opts, args=args, kwargs=kwargs
)
# The size of pool is set as 1 by the MessageClientPool init method.
self.assertEqual(1, len(message_client_pool.message_clients))
for message_client in message_client_pool.message_clients:
self.assertEqual(message_client.args, args)
self.assertEqual(message_client.kwargs, kwargs)

View file

@ -470,6 +470,8 @@ class SSHThinTestCase(TestCase):
]
if salt.utils.thin.has_immutables:
base_tops.extend(["immutables"])
if thin.crypt:
base_tops.append(thin.crypt.__name__)
tops = []
for top in thin.get_tops(extra_mods="foo,bar"):
if top.find("/") != -1:
@ -567,6 +569,8 @@ class SSHThinTestCase(TestCase):
]
if salt.utils.thin.has_immutables:
base_tops.extend(["immutables"])
if thin.crypt:
base_tops.append(thin.crypt.__name__)
libs = salt.utils.thin.find_site_modules("contextvars")
foo = {"__file__": os.sep + os.path.join("custom", "foo", "__init__.py")}
bar = {"__file__": os.sep + os.path.join("custom", "bar")}
@ -672,6 +676,8 @@ class SSHThinTestCase(TestCase):
]
if salt.utils.thin.has_immutables:
base_tops.extend(["immutables"])
if thin.crypt:
base_tops.append(thin.crypt.__name__)
libs = salt.utils.thin.find_site_modules("contextvars")
with patch("salt.utils.thin.find_site_modules", MagicMock(side_effect=[libs])):
with patch(