Migrate the ZeroMQ and TCP transport unit tests to PyTest

Pedro Algarvio 2020-12-18 17:04:40 +00:00
parent 719e010f1b
commit 34dfd90f74
14 changed files with 1168 additions and 1275 deletions

View file

@@ -2,6 +2,8 @@
tests.pytests.conftest
~~~~~~~~~~~~~~~~~~~~~~
"""
import functools
import inspect
import logging
import os
import shutil
@@ -9,6 +11,7 @@ import stat
import attr
import pytest
import salt.ext.tornado.ioloop
import salt.utils.files
import salt.utils.platform
from salt.serializers import yaml
@@ -316,8 +319,116 @@ def salt_proxy_factory(salt_factories, salt_master_factory):
return factory
@pytest.fixture
def temp_salt_master(
request, salt_factories,
):
config_defaults = {
"open_mode": True,
"transport": request.config.getoption("--transport"),
}
factory = salt_factories.get_salt_master_daemon(
random_string("temp-master-"),
config_defaults=config_defaults,
extra_cli_arguments_after_first_start_failure=["--log-level=debug"],
)
return factory
@pytest.fixture
def temp_salt_minion(temp_salt_master):
config_defaults = {
"open_mode": True,
"transport": temp_salt_master.config["transport"],
}
factory = temp_salt_master.get_salt_minion_daemon(
random_string("temp-minion-"),
config_defaults=config_defaults,
extra_cli_arguments_after_first_start_failure=["--log-level=debug"],
)
factory.register_after_terminate_callback(
pytest.helpers.remove_stale_minion_key, temp_salt_master, factory.id
)
return factory
@pytest.fixture(scope="session")
def bridge_pytest_and_runtests():
"""
We're basically overriding the same fixture defined in tests/conftest.py
"""
# ----- Async Test Fixtures ----------------------------------------------------------------------------------------->
# This is based on https://github.com/eukaryote/pytest-tornasync
# We don't use that pytest plugin directly because it depends on tornado, and
# we need to use the tornado we ship with salt. See the usage sketch after
# this section.
def get_test_timeout(pyfuncitem):
default_timeout = 30
marker = pyfuncitem.get_closest_marker("timeout")
if marker:
return marker.kwargs.get("seconds") or default_timeout
return default_timeout
@pytest.mark.tryfirst
def pytest_pycollect_makeitem(collector, name, obj):
if collector.funcnamefilter(name) and inspect.iscoroutinefunction(obj):
return list(collector._genfunctions(name, obj))
@pytest.hookimpl(tryfirst=True)
def pytest_runtest_setup(item):
if inspect.iscoroutinefunction(item.obj):
if "io_loop" not in item.fixturenames:
# Append the io_loop fixture for the async functions
item.fixturenames.append("io_loop")
class CoroTestFunction:
def __init__(self, func, kwargs):
self.func = func
self.kwargs = kwargs
functools.update_wrapper(self, func)
async def __call__(self):
ret = await self.func(**self.kwargs)
return ret
@pytest.mark.tryfirst
def pytest_pyfunc_call(pyfuncitem):
if not inspect.iscoroutinefunction(pyfuncitem.obj):
return
funcargs = pyfuncitem.funcargs
testargs = {arg: funcargs[arg] for arg in pyfuncitem._fixtureinfo.argnames}
try:
loop = funcargs["io_loop"]
except KeyError:
loop = salt.ext.tornado.ioloop.IOLoop.current()
loop.run_sync(
CoroTestFunction(pyfuncitem.obj, testargs), timeout=get_test_timeout(pyfuncitem)
)
return True
@pytest.fixture
def io_loop():
"""
Create new io loop for each test, and tear it down after.
"""
loop = salt.ext.tornado.ioloop.IOLoop()
loop.make_current()
try:
yield loop
finally:
loop.clear_current()
loop.close(all_fds=True)
# <---- Async Test Fixtures ------------------------------------------------------------------------------------------
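A minimal sketch of how a test consumes the hooks above (illustrative only, not part of this commit): pytest_pycollect_makeitem collects the coroutine function, pytest_runtest_setup injects the io_loop fixture, and pytest_pyfunc_call drives it with run_sync(), honouring the timeout marker read by get_test_timeout.

import pytest
import salt.ext.tornado.gen


@pytest.mark.timeout(seconds=10)
async def test_async_example(io_loop):
    # io_loop is the per-test tornado IOLoop created by the fixture above;
    # run_sync() in pytest_pyfunc_call drives this coroutine with a 10s cap.
    start = io_loop.time()
    await salt.ext.tornado.gen.sleep(0.01)
    assert io_loop.time() >= start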

View file

@@ -0,0 +1,42 @@
import pytest
from saltfactories.utils import random_string
def transport_ids(value):
return "Transport({})".format(value)
@pytest.fixture(params=("zeromq", "tcp"), ids=transport_ids)
def transport(request):
return request.param
@pytest.fixture
def salt_master(salt_factories, transport):
config_defaults = {
"transport": transport,
"auto_accept": True,
"sign_pub_messages": False,
}
factory = salt_factories.get_salt_master_daemon(
random_string("server-{}-master-".format(transport)),
config_defaults=config_defaults,
)
return factory
@pytest.fixture
def salt_minion(salt_master, transport):
config_defaults = {
"transport": transport,
"master_ip": "127.0.0.1",
"master_port": salt_master.config["ret_port"],
"auth_timeout": 5,
"auth_tries": 1,
"master_uri": "tcp://127.0.0.1:{}".format(salt_master.config["ret_port"]),
}
factory = salt_master.get_salt_minion_daemon(
random_string("server-{}-minion-".format(transport)),
config_defaults=config_defaults,
)
return factory
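
A rough usage sketch (not part of the diff): any test requesting these fixtures runs once per transport, with pytest IDs rendered by transport_ids(), e.g. "Transport(zeromq)" and "Transport(tcp)". The assertion assumes the daemon factories expose their merged configuration via .config, as the other tests in this commit do.

def test_transport_is_propagated(salt_master, salt_minion, transport):
    # Both factories above are built with config_defaults["transport"], so the
    # rendered configs should carry the parametrized transport value.
    assert salt_master.config["transport"] == transport
    assert salt_minion.config["transport"] == transport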

View file

@@ -0,0 +1,145 @@
import logging
import signal
import pytest
import salt.config
import salt.exceptions
import salt.ext.tornado.gen
import salt.ext.tornado.ioloop
import salt.log.setup
import salt.transport.client
import salt.transport.server
import salt.utils.platform
import salt.utils.process
import salt.utils.stringutils
log = logging.getLogger(__name__)
class ReqServerChannelProcess(salt.utils.process.SignalHandlingProcess):
def __init__(self, config, req_channel_crypt):
super().__init__()
self._closing = False
self.config = config
self.req_channel_crypt = req_channel_crypt
self.process_manager = salt.utils.process.ProcessManager(
name="ReqServer-ProcessManager"
)
self.req_server_channel = salt.transport.server.ReqServerChannel.factory(
self.config
)
self.req_server_channel.pre_fork(self.process_manager)
self.io_loop = None
def run(self):
self.io_loop = salt.ext.tornado.ioloop.IOLoop()
self.io_loop.make_current()
self.req_server_channel.post_fork(self._handle_payload, io_loop=self.io_loop)
try:
self.io_loop.start()
except KeyboardInterrupt:
pass
finally:
self.req_server_channel.close()
self.io_loop.clear_current()
self.io_loop.close(all_fds=True)
def _handle_signals(self, signum, sigframe):
self.close()
super()._handle_signals(signum, sigframe)
def __enter__(self):
self.start()
return self
def __exit__(self, *args):
self.terminate()
def close(self):
if self._closing:
return
self._closing = True
if self.process_manager is None:
return
self.process_manager.stop_restarting()
self.process_manager.send_signal_to_processes(signal.SIGTERM)
self.process_manager.kill_children()
@salt.ext.tornado.gen.coroutine
def _handle_payload(self, payload):
if self.req_channel_crypt == "clear":
raise salt.ext.tornado.gen.Return((payload, {"fun": "send_clear"}))
raise salt.ext.tornado.gen.Return((payload, {"fun": "send"}))
@pytest.fixture
def req_server_channel(salt_master, req_channel_crypt):
req_server_channel_process = ReqServerChannelProcess(
salt_master.config.copy(), req_channel_crypt
)
with req_server_channel_process:
yield
def req_channel_crypt_ids(value):
return "ReqChannel(crypt='{}')".format(value)
@pytest.fixture(params=["clear", "aes"], ids=req_channel_crypt_ids)
def req_channel_crypt(request):
return request.param
@pytest.fixture
def req_channel(req_server_channel, salt_minion, req_channel_crypt):
with salt.transport.client.ReqChannel.factory(
salt_minion.config, crypt=req_channel_crypt
) as _req_channel:
try:
yield _req_channel
finally:
_req_channel.force_close_all_instances()
def test_basic(req_channel):
"""
Test a variety of messages, make sure we get the expected responses
"""
msgs = [
{"foo": "bar"},
{"bar": "baz"},
{"baz": "qux", "list": [1, 2, 3]},
]
for msg in msgs:
ret = req_channel.send(msg, timeout=5, tries=1)
assert ret["load"] == msg
def test_normalization(req_channel):
"""
Since we use msgpack, we need to test that list types are converted to lists
"""
types = {
"list": list,
}
msgs = [
{"list": tuple([1, 2, 3])},
]
for msg in msgs:
ret = req_channel.send(msg, timeout=5, tries=1)
for key, value in ret["load"].items():
assert types[key] == type(value)
def test_badload(req_channel, req_channel_crypt):
"""
Test a variety of bad requests, make sure that we get some sort of error
"""
msgs = ["", [], tuple()]
if req_channel_crypt == "clear":
for msg in msgs:
ret = req_channel.send(msg, timeout=5, tries=1)
assert ret == "payload and load must be a dict"
else:
for msg in msgs:
with pytest.raises(salt.exceptions.AuthenticationError):
req_channel.send(msg, timeout=5, tries=1)
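
The normalization behaviour exercised by test_normalization above comes from msgpack itself; a standalone sketch using the plain msgpack package (an assumption for illustration, independent of salt's serializers):

import msgpack

# Tuples are packed as msgpack arrays and come back as lists on unpack, which
# is why the test above compares the type of the returned value.
packed = msgpack.packb({"list": (1, 2, 3)})
assert msgpack.unpackb(packed, raw=False) == {"list": [1, 2, 3]}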

View file

@@ -0,0 +1,31 @@
import pytest
from saltfactories.utils import random_string
@pytest.fixture
def salt_master(salt_factories):
config_defaults = {
"transport": "zeromq",
"auto_accept": True,
"sign_pub_messages": False,
}
factory = salt_factories.get_salt_master_daemon(
random_string("zeromq-master-"), config_defaults=config_defaults
)
return factory
@pytest.fixture
def salt_minion(salt_master):
config_defaults = {
"transport": "zeromq",
"master_ip": "127.0.0.1",
"master_port": salt_master.config["ret_port"],
"auth_timeout": 5,
"auth_tries": 1,
"master_uri": "tcp://127.0.0.1:{}".format(salt_master.config["ret_port"]),
}
factory = salt_master.get_salt_minion_daemon(
random_string("zeromq-minion-"), config_defaults=config_defaults
)
return factory

View file

@@ -0,0 +1,296 @@
import ctypes
import logging
import multiprocessing
import signal
import time
from concurrent.futures.thread import ThreadPoolExecutor
import pytest
import salt.config
import salt.crypt
import salt.exceptions
import salt.ext.tornado.gen
import salt.ext.tornado.ioloop
import salt.log.setup
import salt.master
import salt.payload
import salt.transport.client
import salt.transport.server
import salt.transport.zeromq
import salt.utils.platform
import salt.utils.process
import salt.utils.stringutils
import zmq.eventloop.ioloop
from tests.support.helpers import slowTest
from tests.support.mock import MagicMock, patch
log = logging.getLogger(__name__)
class Collector(salt.utils.process.SignalHandlingProcess):
def __init__(
self, minion_config, pub_uri, aes_key, timeout=30, zmq_filtering=False
):
super().__init__()
self.minion_config = minion_config
self.pub_uri = pub_uri
self.aes_key = aes_key
self.timeout = timeout
self.hard_timeout = time.time() + timeout + 30
self.manager = multiprocessing.Manager()
self.results = self.manager.list()
self.zmq_filtering = zmq_filtering
self.stopped = multiprocessing.Event()
self.running = multiprocessing.Event()
def run(self):
"""
Gather results until the number of seconds specified by timeout passes
without receiving a message
"""
ctx = zmq.Context()
sock = ctx.socket(zmq.SUB)
sock.setsockopt(zmq.LINGER, -1)
sock.setsockopt(zmq.SUBSCRIBE, b"")
sock.connect(self.pub_uri)
last_msg = time.time()
serial = salt.payload.Serial(self.minion_config)
crypticle = salt.crypt.Crypticle(self.minion_config, self.aes_key)
self.running.set()
while True:
curr_time = time.time()
if time.time() > self.hard_timeout:
break
if curr_time - last_msg >= self.timeout:
break
try:
payload = sock.recv(zmq.NOBLOCK)
except zmq.ZMQError:
time.sleep(0.01)
else:
try:
serial_payload = serial.loads(payload)
payload = crypticle.loads(serial_payload["load"])
if "start" in payload:
self.running.set()
continue
if "stop" in payload:
break
last_msg = time.time()
self.results.append(payload["jid"])
except salt.exceptions.SaltDeserializationError:
if not self.zmq_filtering:
log.exception("Failed to deserialize...")
break
def __enter__(self):
self.manager.__enter__()
self.start()
# Wait until we can start receiving events
self.running.wait()
self.running.clear()
return self
def __exit__(self, *args):
# Wait until we either processed all expected messages or we reach the hard timeout
join_secs = self.hard_timeout - time.time()
log.info("Waiting at most %s seconds before exiting the collector", join_secs)
self.join(join_secs)
self.terminate()
# Cast our manager.list into a plain list
self.results = list(self.results)
# Terminate our multiprocessing manager
self.manager.__exit__(*args)
log.debug("The collector has exited")
self.stopped.set()
class PubServerChannelProcess(salt.utils.process.SignalHandlingProcess):
def __init__(self, master_config, minion_config, **collector_kwargs):
super().__init__()
self._closing = False
self.master_config = master_config
self.minion_config = minion_config
self.collector_kwargs = collector_kwargs
self.aes_key = multiprocessing.Array(
ctypes.c_char,
salt.utils.stringutils.to_bytes(salt.crypt.Crypticle.generate_key_string()),
)
self.process_manager = salt.utils.process.ProcessManager(
name="ZMQ-PubServer-ProcessManager"
)
self.pub_server_channel = salt.transport.zeromq.ZeroMQPubServerChannel(
self.master_config
)
self.pub_server_channel.pre_fork(
self.process_manager,
kwargs={"log_queue": salt.log.setup.get_multiprocessing_logging_queue()},
)
self.pub_uri = "tcp://{interface}:{publish_port}".format(**self.master_config)
self.queue = multiprocessing.Queue()
self.stopped = multiprocessing.Event()
self.collector = Collector(
self.minion_config,
self.pub_uri,
self.aes_key.value,
**self.collector_kwargs
)
def run(self):
salt.master.SMaster.secrets["aes"] = {"secret": self.aes_key}
try:
while True:
payload = self.queue.get()
if payload is None:
log.debug("We received the stop sentinal")
break
self.pub_server_channel.publish(payload)
except KeyboardInterrupt:
pass
finally:
self.stopped.set()
def _handle_signals(self, signum, sigframe):
self.close()
super()._handle_signals(signum, sigframe)
def close(self):
if self._closing:
return
self._closing = True
if self.process_manager is None:
return
self.process_manager.stop_restarting()
self.process_manager.send_signal_to_processes(signal.SIGTERM)
self.pub_server_channel.pub_close()
self.process_manager.kill_children()
self.process_manager = None
def publish(self, payload):
self.queue.put(payload)
def __enter__(self):
self.start()
self.collector.__enter__()
attempts = 10
while attempts > 0:
self.publish({"tgt_type": "glob", "tgt": "*", "jid": -1, "start": True})
if self.collector.running.wait(1) is True:
break
attempts -= 1
else:
pytest.fail("Failed to confirm the collector has started")
return self
def __exit__(self, *args):
# Publish a payload to tell the collection it's done processing
self.publish({"tgt_type": "glob", "tgt": "*", "jid": -1, "stop": True})
# Now trigger the collector to also exit
self.collector.__exit__(*args)
# We can safely wait here without a timeout because the Collector instance has a
# hard timeout set, so eventually Collector.stopped will be set
self.collector.stopped.wait()
# Stop our own processing
self.queue.put(None)
# Wait at most 10 secs for the above `None` in the queue to be processed
self.stopped.wait(10)
self.close()
self.terminate()
log.info("The PubServerChannelProcess has terminated")
@slowTest
@pytest.mark.skip_on_windows
def test_publish_to_pubserv_ipc(salt_master, salt_minion):
"""
Test sending 10K messages to ZeroMQPubServerChannel using IPC transport.
ZMQ's ipc transport is not supported on Windows.
"""
opts = dict(salt_master.config.copy(), ipc_mode="ipc", pub_hwm=0)
with PubServerChannelProcess(opts, salt_minion.config.copy()) as server_channel:
send_num = 10000
expect = []
for idx in range(send_num):
expect.append(idx)
load = {"tgt_type": "glob", "tgt": "*", "jid": idx}
server_channel.publish(load)
results = server_channel.collector.results
assert len(results) == send_num, "{} != {}, difference: {}".format(
len(results), send_num, set(expect).difference(results)
)
@slowTest
@pytest.mark.skip_on_freebsd
def test_issue_36469_tcp(salt_master, salt_minion):
"""
Test sending both large and small messages to the publisher using TCP
https://github.com/saltstack/salt/issues/36469
"""
def _send_small(server_channel, sid, num=10):
for idx in range(num):
load = {"tgt_type": "glob", "tgt": "*", "jid": "{}-s{}".format(sid, idx)}
server_channel.publish(load)
def _send_large(server_channel, sid, num=10, size=250000 * 3):
for idx in range(num):
load = {
"tgt_type": "glob",
"tgt": "*",
"jid": "{}-l{}".format(sid, idx),
"xdata": "0" * size,
}
server_channel.publish(load)
opts = dict(salt_master.config.copy(), ipc_mode="tcp", pub_hwm=0)
send_num = 10 * 4
expect = []
with PubServerChannelProcess(opts, salt_minion.config.copy()) as server_channel:
with ThreadPoolExecutor(max_workers=4) as executor:
executor.submit(_send_small, server_channel, 1)
executor.submit(_send_large, server_channel, 2)
executor.submit(_send_small, server_channel, 3)
executor.submit(_send_large, server_channel, 4)
expect.extend(["{}-s{}".format(a, b) for a in range(10) for b in (1, 3)])
expect.extend(["{}-l{}".format(a, b) for a in range(10) for b in (2, 4)])
results = server_channel.collector.results
assert len(results) == send_num, "{} != {}, difference: {}".format(
len(results), send_num, set(expect).difference(results)
)
@slowTest
@pytest.mark.skip_on_windows
def test_zeromq_filtering(salt_master, salt_minion):
"""
Test sending messages to the publisher over IPC with zmq_filtering enabled
"""
opts = dict(
salt_master.config.copy(),
ipc_mode="ipc",
pub_hwm=0,
zmq_filtering=True,
acceptance_wait_time=5,
)
send_num = 1
expect = []
with patch(
"salt.utils.minions.CkMinions.check_minions",
MagicMock(
return_value={
"minions": [salt_minion.id],
"missing": [],
"ssh_minions": False,
}
),
):
with PubServerChannelProcess(
opts, salt_minion.config.copy(), zmq_filtering=True
) as server_channel:
expect.append(send_num)
load = {"tgt_type": "glob", "tgt": "*", "jid": send_num}
server_channel.publish(load)
results = server_channel.collector.results
assert len(results) == send_num, "{} != {}, difference: {}".format(
len(results), send_num, set(expect).difference(results)
)

View file

View file

@@ -0,0 +1,251 @@
import socket
import attr
import pytest
import salt.exceptions
import salt.transport.tcp
from salt.ext.tornado import concurrent, gen, ioloop
from saltfactories.utils.ports import get_unused_localhost_port
from tests.support.mock import MagicMock, patch
@pytest.fixture
def message_client_pool():
sock_pool_size = 5
opts = {"sock_pool_size": sock_pool_size}
message_client_args = (
{}, # opts,
"", # host
0, # port
)
with patch(
"salt.transport.tcp.SaltMessageClient.__init__", MagicMock(return_value=None),
):
message_client_pool = salt.transport.tcp.SaltMessageClientPool(
opts, args=message_client_args
)
original_message_clients = message_client_pool.message_clients[:]
message_client_pool.message_clients = [MagicMock() for _ in range(sock_pool_size)]
try:
yield message_client_pool
finally:
with patch(
"salt.transport.tcp.SaltMessageClient.close", MagicMock(return_value=None)
):
del original_message_clients
class TestSaltMessageClientPool:
def test_send(self, message_client_pool):
for message_client_mock in message_client_pool.message_clients:
message_client_mock.send_queue = [0, 0, 0]
message_client_mock.send.return_value = []
assert message_client_pool.send() == []
message_client_pool.message_clients[2].send_queue = [0]
message_client_pool.message_clients[2].send.return_value = [1]
assert message_client_pool.send() == [1]
def test_write_to_stream(self, message_client_pool):
for message_client_mock in message_client_pool.message_clients:
message_client_mock.send_queue = [0, 0, 0]
message_client_mock._stream.write.return_value = []
assert message_client_pool.write_to_stream("") == []
message_client_pool.message_clients[2].send_queue = [0]
message_client_pool.message_clients[2]._stream.write.return_value = [1]
assert message_client_pool.write_to_stream("") == [1]
def test_close(self, message_client_pool):
message_client_pool.close()
assert message_client_pool.message_clients == []
def test_on_recv(self, message_client_pool):
for message_client_mock in message_client_pool.message_clients:
message_client_mock.on_recv.return_value = None
message_client_pool.on_recv()
for message_client_mock in message_client_pool.message_clients:
assert message_client_mock.on_recv.called
async def test_connect_all(self, message_client_pool):
for message_client_mock in message_client_pool.message_clients:
future = concurrent.Future()
future.set_result("foo")
message_client_mock.connect.return_value = future
connected = await message_client_pool.connect()
assert connected is None
async def test_connect_partial(self, io_loop, message_client_pool):
for idx, message_client_mock in enumerate(message_client_pool.message_clients):
future = concurrent.Future()
if idx % 2 == 0:
future.set_result("foo")
message_client_mock.connect.return_value = future
with pytest.raises(gen.TimeoutError):
future = message_client_pool.connect()
await gen.with_timeout(io_loop.time() + 0.1, future)
@attr.s(frozen=True, slots=True)
class ClientSocket:
listen_on = attr.ib(init=False, default="127.0.0.1")
port = attr.ib(init=False, default=attr.Factory(get_unused_localhost_port))
sock = attr.ib(init=False, repr=False)
@sock.default
def _sock_default(self):
return socket.socket(socket.AF_INET, socket.SOCK_STREAM)
def __enter__(self):
self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.sock.bind((self.listen_on, self.port))
self.sock.listen(1)
return self
def __exit__(self, *args):
self.sock.close()
@pytest.fixture
def client_socket():
with ClientSocket() as _client_socket:
yield _client_socket
def test_message_client_cleanup_on_close(client_socket, temp_salt_master):
"""
test message client cleanup on close
"""
orig_loop = ioloop.IOLoop()
orig_loop.make_current()
opts = temp_salt_master.config.copy()
client = salt.transport.tcp.SaltMessageClient(
opts, client_socket.listen_on, client_socket.port
)
# Mock the io_loop's stop method so we know when it has been called.
orig_loop.real_stop = orig_loop.stop
orig_loop.stop_called = False
def stop(*args, **kwargs):
orig_loop.stop_called = True
orig_loop.real_stop()
orig_loop.stop = stop
try:
assert client.io_loop == orig_loop
client.io_loop.run_sync(client.connect)
# Ensure we are testing the _read_until_future and io_loop teardown
assert client._stream is not None
assert client._read_until_future is not None
assert orig_loop.stop_called is True
# The run_sync call will set stop_called, reset it
orig_loop.stop_called = False
client.close()
# Stop should be called again, client's io_loop should be None
assert orig_loop.stop_called is True
assert client.io_loop is None
finally:
orig_loop.stop = orig_loop.real_stop
del orig_loop.real_stop
del orig_loop.stop_called
orig_loop.clear_current()
orig_loop.close(all_fds=True)
async def test_async_tcp_pub_channel_connect_publish_port(
temp_salt_master, client_socket
):
"""
test when publish_port is not 4506
"""
opts = dict(
temp_salt_master.config.copy(),
master_uri="",
master_ip="127.0.0.1",
publish_port=1234,
)
channel = salt.transport.tcp.AsyncTCPPubChannel(opts)
patch_auth = MagicMock(return_value=True)
patch_client_pool = MagicMock(spec=salt.transport.tcp.SaltMessageClientPool)
with patch("salt.crypt.AsyncAuth.gen_token", patch_auth), patch(
"salt.crypt.AsyncAuth.authenticated", patch_auth
), patch("salt.transport.tcp.SaltMessageClientPool", patch_client_pool):
with channel:
# The connection cannot succeed because we're not mocking the tornado coroutine
with pytest.raises(salt.exceptions.SaltClientError):
await channel.connect()
# The first call to the mock is the instance's __init__, and the first argument to those calls is the opts dict
assert patch_client_pool.call_args[0][0]["publish_port"] == opts["publish_port"]
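
The call_args inspection used above follows standard unittest.mock semantics; a tiny standalone sketch (hypothetical names, stdlib mock rather than tests.support.mock):

from unittest.mock import MagicMock

# call_args holds the arguments of the most recent call; with a single
# instantiation that is the constructor call, so [0][0] is the opts dict.
factory = MagicMock()
factory({"publish_port": 1234}, "127.0.0.1")
assert factory.call_args[0][0]["publish_port"] == 1234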
def test_tcp_pub_server_channel_publish_filtering(temp_salt_master):
opts = dict(temp_salt_master.config.copy(), sign_pub_messages=False)
with patch("salt.master.SMaster.secrets") as secrets, patch(
"salt.crypt.Crypticle"
) as crypticle, patch("salt.utils.asynchronous.SyncWrapper") as SyncWrapper:
channel = salt.transport.tcp.TCPPubServerChannel(opts)
wrap = MagicMock()
crypt = MagicMock()
crypt.dumps.return_value = {"test": "value"}
secrets.return_value = {"aes": {"secret": None}}
crypticle.return_value = crypt
SyncWrapper.return_value = wrap
# try simple publish with glob tgt_type
channel.publish({"test": "value", "tgt_type": "glob", "tgt": "*"})
payload = wrap.send.call_args[0][0]
# verify we send it without any specific topic
assert "topic_lst" not in payload
# try simple publish with list tgt_type
channel.publish({"test": "value", "tgt_type": "list", "tgt": ["minion01"]})
payload = wrap.send.call_args[0][0]
# verify we send it with correct topic
assert "topic_lst" in payload
assert payload["topic_lst"] == ["minion01"]
# try with syndic settings
opts["order_masters"] = True
channel.publish({"test": "value", "tgt_type": "list", "tgt": ["minion01"]})
payload = wrap.send.call_args[0][0]
# verify we send it without topic for syndics
assert "topic_lst" not in payload
def test_tcp_pub_server_channel_publish_filtering_str_list(temp_salt_master):
opts = dict(temp_salt_master.config.copy(), sign_pub_messages=False)
with patch("salt.master.SMaster.secrets") as secrets, patch(
"salt.crypt.Crypticle"
) as crypticle, patch("salt.utils.asynchronous.SyncWrapper") as SyncWrapper, patch(
"salt.utils.minions.CkMinions.check_minions"
) as check_minions:
channel = salt.transport.tcp.TCPPubServerChannel(opts)
wrap = MagicMock()
crypt = MagicMock()
crypt.dumps.return_value = {"test": "value"}
secrets.return_value = {"aes": {"secret": None}}
crypticle.return_value = crypt
SyncWrapper.return_value = wrap
check_minions.return_value = {"minions": ["minion02"]}
# try simple publish with list tgt_type
channel.publish({"test": "value", "tgt_type": "list", "tgt": "minion02"})
payload = wrap.send.call_args[0][0]
# verify we send it with correct topic
assert "topic_lst" in payload
assert payload["topic_lst"] == ["minion02"]
# verify it was correctly calling check_minions
check_minions.assert_called_with("minion02", tgt_type="list")

View file

@@ -0,0 +1,253 @@
"""
:codeauthor: Thomas Jackson <jacksontj.89@gmail.com>
"""
import hashlib
import salt.config
import salt.exceptions
import salt.ext.tornado.gen
import salt.ext.tornado.ioloop
import salt.log.setup
import salt.transport.client
import salt.transport.server
import salt.utils.platform
import salt.utils.process
import salt.utils.stringutils
from salt.transport.zeromq import AsyncReqMessageClientPool
from tests.support.mock import MagicMock, call, patch
def test_master_uri():
"""
test _get_master_uri method
"""
m_ip = "127.0.0.1"
m_port = 4505
s_ip = "111.1.0.1"
s_port = 4058
m_ip6 = "1234:5678::9abc"
s_ip6 = "1234:5678::1:9abc"
with patch("salt.transport.zeromq.LIBZMQ_VERSION_INFO", (4, 1, 6)), patch(
"salt.transport.zeromq.ZMQ_VERSION_INFO", (16, 0, 1)
):
# pass in both source_ip and source_port
assert salt.transport.zeromq._get_master_uri(
master_ip=m_ip, master_port=m_port, source_ip=s_ip, source_port=s_port
) == "tcp://{}:{};{}:{}".format(s_ip, s_port, m_ip, m_port)
assert salt.transport.zeromq._get_master_uri(
master_ip=m_ip6, master_port=m_port, source_ip=s_ip6, source_port=s_port
) == "tcp://[{}]:{};[{}]:{}".format(s_ip6, s_port, m_ip6, m_port)
# source ip and source_port empty
assert salt.transport.zeromq._get_master_uri(
master_ip=m_ip, master_port=m_port
) == "tcp://{}:{}".format(m_ip, m_port)
assert salt.transport.zeromq._get_master_uri(
master_ip=m_ip6, master_port=m_port
) == "tcp://[{}]:{}".format(m_ip6, m_port)
# pass in only source_ip
assert salt.transport.zeromq._get_master_uri(
master_ip=m_ip, master_port=m_port, source_ip=s_ip
) == "tcp://{}:0;{}:{}".format(s_ip, m_ip, m_port)
assert salt.transport.zeromq._get_master_uri(
master_ip=m_ip6, master_port=m_port, source_ip=s_ip6
) == "tcp://[{}]:0;[{}]:{}".format(s_ip6, m_ip6, m_port)
# pass in only source_port
assert salt.transport.zeromq._get_master_uri(
master_ip=m_ip, master_port=m_port, source_port=s_port
) == "tcp://0.0.0.0:{};{}:{}".format(s_port, m_ip, m_port)
def test_force_close_all_instances():
zmq1 = MagicMock()
zmq2 = MagicMock()
zmq3 = MagicMock()
zmq_objects = {"zmq": {"1": zmq1, "2": zmq2}, "other_zmq": {"3": zmq3}}
with patch("salt.transport.zeromq.AsyncZeroMQReqChannel.instance_map", zmq_objects):
salt.transport.zeromq.AsyncZeroMQReqChannel.force_close_all_instances()
assert zmq1.mock_calls == [call.close()]
assert zmq2.mock_calls == [call.close()]
assert zmq3.mock_calls == [call.close()]
# check if instance map changed
assert zmq_objects is salt.transport.zeromq.AsyncZeroMQReqChannel.instance_map
def test_async_req_message_client_pool_send():
sock_pool_size = 5
with patch(
"salt.transport.zeromq.AsyncReqMessageClient.__init__",
MagicMock(return_value=None),
):
message_client_pool = AsyncReqMessageClientPool(
{"sock_pool_size": sock_pool_size}, args=({}, "")
)
message_client_pool.message_clients = [
MagicMock() for _ in range(sock_pool_size)
]
for message_client_mock in message_client_pool.message_clients:
message_client_mock.send_queue = [0, 0, 0]
message_client_mock.send.return_value = []
with message_client_pool:
assert message_client_pool.send() == []
message_client_pool.message_clients[2].send_queue = [0]
message_client_pool.message_clients[2].send.return_value = [1]
assert message_client_pool.send() == [1]
def test_clear_req_channel_master_uri_override(temp_salt_minion, temp_salt_master):
"""
ensure master_uri kwarg is respected
"""
opts = temp_salt_minion.config.copy()
# minion_config should be 127.0.0.1, we want a different uri that still connects
opts.update(
{
"id": "root",
"transport": "zeromq",
"auth_tries": 1,
"auth_timeout": 5,
"master_ip": "127.0.0.1",
"master_port": temp_salt_master.config["ret_port"],
"master_uri": "tcp://127.0.0.1:{}".format(
temp_salt_master.config["ret_port"]
),
}
)
master_uri = "tcp://{master_ip}:{master_port}".format(
master_ip="localhost", master_port=opts["master_port"]
)
with salt.transport.client.ReqChannel.factory(
opts, master_uri=master_uri
) as channel:
assert "localhost" in channel.master_uri
def test_zeromq_async_pub_channel_publish_port(temp_salt_master):
"""
test that, when connecting, we use the publish_port set in opts when it's not 4506
"""
opts = dict(
temp_salt_master.config.copy(),
ipc_mode="ipc",
pub_hwm=0,
recon_randomize=False,
publish_port=455505,
recon_default=1,
recon_max=2,
master_ip="127.0.0.1",
acceptance_wait_time=5,
acceptance_wait_time_max=5,
sign_pub_messages=False,
)
opts["master_uri"] = "tcp://{interface}:{publish_port}".format(**opts)
channel = salt.transport.zeromq.AsyncZeroMQPubChannel(opts)
with channel:
patch_socket = MagicMock(return_value=True)
patch_auth = MagicMock(return_value=True)
with patch.object(channel, "_socket", patch_socket), patch.object(
channel, "auth", patch_auth
):
channel.connect()
assert str(opts["publish_port"]) in patch_socket.mock_calls[0][1][0]
def test_zeromq_async_pub_channel_filtering_decode_message_no_match(temp_salt_master,):
"""
test AsyncZeroMQPubChannel._decode_messages when
zmq_filtering is enabled and the minion does not match
"""
message = [
b"4f26aeafdb2367620a393c973eddbe8f8b846eb",
b"\x82\xa3enc\xa3aes\xa4load\xda\x00`\xeeR\xcf"
b"\x0eaI#V\x17if\xcf\xae\x05\xa7\xb3bN\xf7\xb2\xe2"
b'\xd0sF\xd1\xd4\xecB\xe8\xaf"/*ml\x80Q3\xdb\xaexg'
b"\x8e\x8a\x8c\xd3l\x03\\,J\xa7\x01i\xd1:]\xe3\x8d"
b"\xf4\x03\x88K\x84\n`\xe8\x9a\xad\xad\xc6\x8ea\x15>"
b"\x92m\x9e\xc7aM\x11?\x18;\xbd\x04c\x07\x85\x99\xa3\xea[\x00D",
]
opts = dict(
temp_salt_master.config.copy(),
ipc_mode="ipc",
pub_hwm=0,
zmq_filtering=True,
recon_randomize=False,
recon_default=1,
recon_max=2,
master_ip="127.0.0.1",
acceptance_wait_time=5,
acceptance_wait_time_max=5,
sign_pub_messages=False,
)
opts["master_uri"] = "tcp://{interface}:{publish_port}".format(**opts)
channel = salt.transport.zeromq.AsyncZeroMQPubChannel(opts)
with channel:
with patch(
"salt.crypt.AsyncAuth.crypticle",
MagicMock(return_value={"tgt_type": "glob", "tgt": "*", "jid": 1}),
):
res = channel._decode_messages(message)
assert res.result() is None
def test_zeromq_async_pub_channel_filtering_decode_message(
temp_salt_master, temp_salt_minion
):
"""
test AsyncZeroMQPubChannel._decode_messages when zmq_filtering is enabled
"""
minion_hexid = salt.utils.stringutils.to_bytes(
hashlib.sha1(salt.utils.stringutils.to_bytes(temp_salt_minion.id)).hexdigest()
)
message = [
minion_hexid,
b"\x82\xa3enc\xa3aes\xa4load\xda\x00`\xeeR\xcf"
b"\x0eaI#V\x17if\xcf\xae\x05\xa7\xb3bN\xf7\xb2\xe2"
b'\xd0sF\xd1\xd4\xecB\xe8\xaf"/*ml\x80Q3\xdb\xaexg'
b"\x8e\x8a\x8c\xd3l\x03\\,J\xa7\x01i\xd1:]\xe3\x8d"
b"\xf4\x03\x88K\x84\n`\xe8\x9a\xad\xad\xc6\x8ea\x15>"
b"\x92m\x9e\xc7aM\x11?\x18;\xbd\x04c\x07\x85\x99\xa3\xea[\x00D",
]
opts = dict(
temp_salt_master.config.copy(),
id=temp_salt_minion.id,
ipc_mode="ipc",
pub_hwm=0,
zmq_filtering=True,
recon_randomize=False,
recon_default=1,
recon_max=2,
master_ip="127.0.0.1",
acceptance_wait_time=5,
acceptance_wait_time_max=5,
sign_pub_messages=False,
)
opts["master_uri"] = "tcp://{interface}:{publish_port}".format(**opts)
channel = salt.transport.zeromq.AsyncZeroMQPubChannel(opts)
with channel:
with patch(
"salt.crypt.AsyncAuth.crypticle",
MagicMock(return_value={"tgt_type": "glob", "tgt": "*", "jid": 1}),
) as mock_test:
res = channel._decode_messages(message)
assert res.result()["enc"] == "aes"
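
For reference, the topic that the filtering tests above rely on is simply the SHA1 hex digest of the minion id, as computed for minion_hexid; a small standalone sketch (the id value is hypothetical):

import hashlib

minion_id = "test-minion"  # stand-in for temp_salt_minion.id
topic = hashlib.sha1(minion_id.encode()).hexdigest().encode()
# With zmq_filtering enabled, only frames whose first element matches this
# digest (or the broadcast topic) are decrypted; the "no match" test above
# uses a digest matching no minion, so the decoded result is None.
assert len(topic) == 40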

View file

@@ -1,16 +1,5 @@
# -*- coding: utf-8 -*-
# Import Python libs
from __future__ import absolute_import, print_function, unicode_literals
import salt.ext.tornado.gen
# Import Salt Libs
import salt.transport.client
# Import 3rd-party libs
from salt.ext import six
def run_loop_in_thread(loop, evt):
"""
@@ -31,79 +20,3 @@ def run_loop_in_thread(loop, evt):
loop.start()
finally:
loop.close()
class ReqChannelMixin(object):
def test_basic(self):
"""
Test a variety of messages, make sure we get the expected responses
"""
msgs = [
{"foo": "bar"},
{"bar": "baz"},
{"baz": "qux", "list": [1, 2, 3]},
]
for msg in msgs:
ret = self.channel.send(msg, timeout=2, tries=1)
self.assertEqual(ret["load"], msg)
def test_normalization(self):
"""
Since we use msgpack, we need to test that list types are converted to lists
"""
types = {
"list": list,
}
msgs = [
{"list": tuple([1, 2, 3])},
]
for msg in msgs:
ret = self.channel.send(msg, timeout=2, tries=1)
for k, v in six.iteritems(ret["load"]):
self.assertEqual(types[k], type(v))
def test_badload(self):
"""
Test a variety of bad requests, make sure that we get some sort of error
"""
msgs = ["", [], tuple()]
for msg in msgs:
ret = self.channel.send(msg, timeout=2, tries=1)
self.assertEqual(ret, "payload and load must be a dict")
class PubChannelMixin(object):
def test_basic(self):
self.pub = None
def handle_pub(ret):
self.pub = ret
self.stop()
self.pub_channel = salt.transport.client.AsyncPubChannel.factory(
self.minion_opts, io_loop=self.io_loop
)
connect_future = self.pub_channel.connect()
connect_future.add_done_callback(lambda f: self.stop())
self.wait()
connect_future.result()
self.pub_channel.on_recv(handle_pub)
load = {
"fun": "f",
"arg": "a",
"tgt": "t",
"jid": "j",
"ret": "r",
"tgt_type": "glob",
}
self.server_channel.publish(load)
self.wait()
self.assertEqual(self.pub["load"], load)
self.pub_channel.on_recv(None)
self.server_channel.publish(load)
with self.assertRaises(self.failureException):
self.wait(timeout=0.5)
# close our pub_channel, to pass our FD checks
self.pub_channel.close()
del self.pub_channel

View file

@@ -3,7 +3,6 @@
"""
import logging
import socket
import threading
import pytest
@@ -16,22 +15,11 @@ import salt.transport.client
import salt.transport.server
import salt.utils.platform
import salt.utils.process
from salt.ext.tornado.testing import AsyncTestCase, gen_test
from salt.transport.tcp import (
SaltMessageClient,
SaltMessageClientPool,
TCPPubServerChannel,
)
from salt.ext.tornado.testing import AsyncTestCase
from saltfactories.utils.ports import get_unused_localhost_port
from tests.support.helpers import flaky, slowTest
from tests.support.mixins import AdaptedConfigurationTestCaseMixin
from tests.support.mock import MagicMock, patch
from tests.support.unit import TestCase, skipIf
from tests.unit.transport.mixins import (
PubChannelMixin,
ReqChannelMixin,
run_loop_in_thread,
)
from tests.support.unit import skipIf
from tests.unit.transport.mixins import run_loop_in_thread
pytestmark = [
pytest.mark.skip_on_darwin,
@@ -41,136 +29,10 @@ pytestmark = [
log = logging.getLogger(__name__)
class BaseTCPReqCase(TestCase, AdaptedConfigurationTestCaseMixin):
@skipIf(True, "Skip until we can devote time to fix this test")
class AsyncPubChannelTest(AsyncTestCase, AdaptedConfigurationTestCaseMixin):
"""
Test the req server/client pair
"""
@classmethod
def setUpClass(cls):
if not hasattr(cls, "_handle_payload"):
return
ret_port = get_unused_localhost_port()
publish_port = get_unused_localhost_port()
tcp_master_pub_port = get_unused_localhost_port()
tcp_master_pull_port = get_unused_localhost_port()
tcp_master_publish_pull = get_unused_localhost_port()
tcp_master_workers = get_unused_localhost_port()
cls.master_config = cls.get_temp_config(
"master",
**{
"transport": "tcp",
"auto_accept": True,
"ret_port": ret_port,
"publish_port": publish_port,
"tcp_master_pub_port": tcp_master_pub_port,
"tcp_master_pull_port": tcp_master_pull_port,
"tcp_master_publish_pull": tcp_master_publish_pull,
"tcp_master_workers": tcp_master_workers,
}
)
cls.minion_config = cls.get_temp_config(
"minion",
**{
"transport": "tcp",
"master_ip": "127.0.0.1",
"master_port": ret_port,
"master_uri": "tcp://127.0.0.1:{}".format(ret_port),
}
)
cls.process_manager = salt.utils.process.ProcessManager(
name="ReqServer_ProcessManager"
)
cls.server_channel = salt.transport.server.ReqServerChannel.factory(
cls.master_config
)
cls.server_channel.pre_fork(cls.process_manager)
cls.io_loop = salt.ext.tornado.ioloop.IOLoop()
cls.stop = threading.Event()
cls.server_channel.post_fork(cls._handle_payload, io_loop=cls.io_loop)
cls.server_thread = threading.Thread(
target=run_loop_in_thread, args=(cls.io_loop, cls.stop,),
)
cls.server_thread.start()
@classmethod
def tearDownClass(cls):
cls.server_channel.close()
cls.stop.set()
cls.server_thread.join()
cls.process_manager.kill_children()
del cls.server_channel
@classmethod
@salt.ext.tornado.gen.coroutine
def _handle_payload(cls, payload):
"""
TODO: something besides echo
"""
raise salt.ext.tornado.gen.Return((payload, {"fun": "send_clear"}))
@skipIf(salt.utils.platform.is_darwin(), "hanging test suite on MacOS")
class ClearReqTestCases(BaseTCPReqCase, ReqChannelMixin):
"""
Test all of the clear msg stuff
"""
def setUp(self):
self.channel = salt.transport.client.ReqChannel.factory(
self.minion_config, crypt="clear"
)
def tearDown(self):
self.channel.close()
del self.channel
@classmethod
@salt.ext.tornado.gen.coroutine
def _handle_payload(cls, payload):
"""
TODO: something besides echo
"""
raise salt.ext.tornado.gen.Return((payload, {"fun": "send_clear"}))
@skipIf(salt.utils.platform.is_darwin(), "hanging test suite on MacOS")
class AESReqTestCases(BaseTCPReqCase, ReqChannelMixin):
def setUp(self):
self.channel = salt.transport.client.ReqChannel.factory(self.minion_config)
def tearDown(self):
self.channel.close()
del self.channel
@classmethod
@salt.ext.tornado.gen.coroutine
def _handle_payload(cls, payload):
"""
TODO: something besides echo
"""
raise salt.ext.tornado.gen.Return((payload, {"fun": "send"}))
# TODO: make failed returns have a specific framing so we can raise the same exception
# on encrypted channels
@flaky
@slowTest
def test_badload(self):
"""
Test a variety of bad requests, make sure that we get some sort of error
"""
msgs = ["", [], tuple()]
for msg in msgs:
with self.assertRaises(salt.exceptions.AuthenticationError):
ret = self.channel.send(msg)
class BaseTCPPubCase(AsyncTestCase, AdaptedConfigurationTestCaseMixin):
"""
Test the req server/client pair
Tests around the publish system
"""
@classmethod
@@ -259,235 +121,39 @@ class BaseTCPPubCase(AsyncTestCase, AdaptedConfigurationTestCaseMixin):
del self.channel
del self._start_handlers
def test_basic(self):
self.pub = None
class AsyncTCPPubChannelTest(AsyncTestCase, AdaptedConfigurationTestCaseMixin):
@slowTest
def test_connect_publish_port(self):
"""
test when publish_port is not 4506
"""
opts = self.get_temp_config("master")
opts["master_uri"] = ""
opts["master_ip"] = "127.0.0.1"
opts["publish_port"] = 1234
channel = salt.transport.tcp.AsyncTCPPubChannel(opts)
patch_auth = MagicMock(return_value=True)
patch_client = MagicMock(spec=SaltMessageClientPool)
with patch("salt.crypt.AsyncAuth.gen_token", patch_auth), patch(
"salt.crypt.AsyncAuth.authenticated", patch_auth
), patch("salt.transport.tcp.SaltMessageClientPool", patch_client):
channel.connect()
assert patch_client.call_args[0][0]["publish_port"] == opts["publish_port"]
def handle_pub(ret):
self.pub = ret
self.stop() # pylint: disable=not-callable
self.pub_channel = salt.transport.client.AsyncPubChannel.factory(
self.minion_opts, io_loop=self.io_loop
)
connect_future = self.pub_channel.connect()
connect_future.add_done_callback(
lambda f: self.stop() # pylint: disable=not-callable
)
self.wait()
connect_future.result()
self.pub_channel.on_recv(handle_pub)
load = {
"fun": "f",
"arg": "a",
"tgt": "t",
"jid": "j",
"ret": "r",
"tgt_type": "glob",
}
self.server_channel.publish(load)
self.wait()
self.assertEqual(self.pub["load"], load)
self.pub_channel.on_recv(None)
self.server_channel.publish(load)
with self.assertRaises(self.failureException):
self.wait(timeout=0.5)
@skipIf(True, "Skip until we can devote time to fix this test")
class AsyncPubChannelTest(BaseTCPPubCase, PubChannelMixin):
"""
Tests around the publish system
"""
class SaltMessageClientPoolTest(AsyncTestCase):
def setUp(self):
super().setUp()
sock_pool_size = 5
with patch(
"salt.transport.tcp.SaltMessageClient.__init__",
MagicMock(return_value=None),
):
self.message_client_pool = SaltMessageClientPool(
{"sock_pool_size": sock_pool_size}, args=({}, "", 0)
)
self.original_message_clients = self.message_client_pool.message_clients
self.message_client_pool.message_clients = [
MagicMock() for _ in range(sock_pool_size)
]
def tearDown(self):
with patch(
"salt.transport.tcp.SaltMessageClient.close", MagicMock(return_value=None)
):
del self.original_message_clients
super().tearDown()
def test_send(self):
for message_client_mock in self.message_client_pool.message_clients:
message_client_mock.send_queue = [0, 0, 0]
message_client_mock.send.return_value = []
self.assertEqual([], self.message_client_pool.send())
self.message_client_pool.message_clients[2].send_queue = [0]
self.message_client_pool.message_clients[2].send.return_value = [1]
self.assertEqual([1], self.message_client_pool.send())
def test_write_to_stream(self):
for message_client_mock in self.message_client_pool.message_clients:
message_client_mock.send_queue = [0, 0, 0]
message_client_mock._stream.write.return_value = []
self.assertEqual([], self.message_client_pool.write_to_stream(""))
self.message_client_pool.message_clients[2].send_queue = [0]
self.message_client_pool.message_clients[2]._stream.write.return_value = [1]
self.assertEqual([1], self.message_client_pool.write_to_stream(""))
def test_close(self):
self.message_client_pool.close()
self.assertEqual([], self.message_client_pool.message_clients)
def test_on_recv(self):
for message_client_mock in self.message_client_pool.message_clients:
message_client_mock.on_recv.return_value = None
self.message_client_pool.on_recv()
for message_client_mock in self.message_client_pool.message_clients:
self.assertTrue(message_client_mock.on_recv.called)
def test_connect_all(self):
@gen_test
def test_connect(self):
yield self.message_client_pool.connect()
for message_client_mock in self.message_client_pool.message_clients:
future = salt.ext.tornado.concurrent.Future()
future.set_result("foo")
message_client_mock.connect.return_value = future
self.assertIsNone(test_connect(self))
def test_connect_partial(self):
@gen_test(timeout=0.1)
def test_connect(self):
yield self.message_client_pool.connect()
for idx, message_client_mock in enumerate(
self.message_client_pool.message_clients
):
future = salt.ext.tornado.concurrent.Future()
if idx % 2 == 0:
future.set_result("foo")
message_client_mock.connect.return_value = future
with self.assertRaises(salt.ext.tornado.ioloop.TimeoutError):
test_connect(self)
class SaltMessageClientCleanupTest(TestCase, AdaptedConfigurationTestCaseMixin):
def setUp(self):
self.listen_on = "127.0.0.1"
self.port = get_unused_localhost_port()
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.sock.bind((self.listen_on, self.port))
self.sock.listen(1)
def tearDown(self):
self.sock.close()
del self.sock
def test_message_client(self):
"""
test message client cleanup on close
"""
orig_loop = salt.ext.tornado.ioloop.IOLoop()
orig_loop.make_current()
opts = self.get_temp_config("master")
client = SaltMessageClient(opts, self.listen_on, self.port)
# Mock the io_loop's stop method so we know when it has been called.
orig_loop.real_stop = orig_loop.stop
orig_loop.stop_called = False
def stop(*args, **kwargs):
orig_loop.stop_called = True
orig_loop.real_stop()
orig_loop.stop = stop
try:
assert client.io_loop == orig_loop
client.io_loop.run_sync(client.connect)
# Ensure we are testing the _read_until_future and io_loop teardown
assert client._stream is not None
assert client._read_until_future is not None
assert orig_loop.stop_called is True
# The run_sync call will set stop_called, reset it
orig_loop.stop_called = False
client.close()
# Stop should be called again, client's io_loop should be None
assert orig_loop.stop_called is True
assert client.io_loop is None
finally:
orig_loop.stop = orig_loop.real_stop
del orig_loop.real_stop
del orig_loop.stop_called
class TCPPubServerChannelTest(TestCase, AdaptedConfigurationTestCaseMixin):
@patch("salt.master.SMaster.secrets")
@patch("salt.crypt.Crypticle")
@patch("salt.utils.asynchronous.SyncWrapper")
def test_publish_filtering(self, sync_wrapper, crypticle, secrets):
opts = self.get_temp_config("master")
opts["sign_pub_messages"] = False
channel = TCPPubServerChannel(opts)
wrap = MagicMock()
crypt = MagicMock()
crypt.dumps.return_value = {"test": "value"}
secrets.return_value = {"aes": {"secret": None}}
crypticle.return_value = crypt
sync_wrapper.return_value = wrap
# try simple publish with glob tgt_type
channel.publish({"test": "value", "tgt_type": "glob", "tgt": "*"})
payload = wrap.send.call_args[0][0]
# verify we send it without any specific topic
assert "topic_lst" not in payload
# try simple publish with list tgt_type
channel.publish({"test": "value", "tgt_type": "list", "tgt": ["minion01"]})
payload = wrap.send.call_args[0][0]
# verify we send it with correct topic
assert "topic_lst" in payload
self.assertEqual(payload["topic_lst"], ["minion01"])
# try with syndic settings
opts["order_masters"] = True
channel.publish({"test": "value", "tgt_type": "list", "tgt": ["minion01"]})
payload = wrap.send.call_args[0][0]
# verify we send it without topic for syndics
assert "topic_lst" not in payload
@patch("salt.utils.minions.CkMinions.check_minions")
@patch("salt.master.SMaster.secrets")
@patch("salt.crypt.Crypticle")
@patch("salt.utils.asynchronous.SyncWrapper")
def test_publish_filtering_str_list(
self, sync_wrapper, crypticle, secrets, check_minions
):
opts = self.get_temp_config("master")
opts["sign_pub_messages"] = False
channel = TCPPubServerChannel(opts)
wrap = MagicMock()
crypt = MagicMock()
crypt.dumps.return_value = {"test": "value"}
secrets.return_value = {"aes": {"secret": None}}
crypticle.return_value = crypt
sync_wrapper.return_value = wrap
check_minions.return_value = {"minions": ["minion02"]}
# try simple publish with list tgt_type
channel.publish({"test": "value", "tgt_type": "list", "tgt": "minion02"})
payload = wrap.send.call_args[0][0]
# verify we send it with correct topic
assert "topic_lst" in payload
self.assertEqual(payload["topic_lst"], ["minion02"])
# verify it was correctly calling check_minions
check_minions.assert_called_with("minion02", tgt_type="list")
# close our pub_channel, to pass our FD checks
self.pub_channel.close()
del self.pub_channel

View file

@@ -1,815 +0,0 @@
"""
:codeauthor: Thomas Jackson <jacksontj.89@gmail.com>
"""
import ctypes
import multiprocessing
import os
import threading
import time
from concurrent.futures.thread import ThreadPoolExecutor
import pytest
import salt.config
import salt.exceptions
import salt.ext.tornado.gen
import salt.ext.tornado.ioloop
import salt.log.setup
import salt.transport.client
import salt.transport.server
import salt.utils.platform
import salt.utils.process
import salt.utils.stringutils
import zmq.eventloop.ioloop
from salt.ext.tornado.testing import AsyncTestCase
from salt.transport.zeromq import AsyncReqMessageClientPool
from saltfactories.utils.ports import get_unused_localhost_port
from tests.support.helpers import flaky, not_runs_on, slowTest
from tests.support.mixins import AdaptedConfigurationTestCaseMixin
from tests.support.mock import MagicMock, call, patch
from tests.support.runtests import RUNTIME_VARS
from tests.support.unit import TestCase, skipIf
from tests.unit.transport.mixins import (
PubChannelMixin,
ReqChannelMixin,
run_loop_in_thread,
)
pytestmark = [
pytest.mark.skip_on_darwin,
pytest.mark.skip_on_freebsd,
]
x = "fix pre"
# support pyzmq 13.0.x, TODO: remove once we force people to 14.0.x
if not hasattr(zmq.eventloop.ioloop, "ZMQIOLoop"):
zmq.eventloop.ioloop.ZMQIOLoop = zmq.eventloop.ioloop.IOLoop
class BaseZMQReqCase(TestCase, AdaptedConfigurationTestCaseMixin):
"""
Test the req server/client pair
"""
@classmethod
def setUpClass(cls):
if not hasattr(cls, "_handle_payload"):
return
ret_port = get_unused_localhost_port()
publish_port = get_unused_localhost_port()
tcp_master_pub_port = get_unused_localhost_port()
tcp_master_pull_port = get_unused_localhost_port()
tcp_master_publish_pull = get_unused_localhost_port()
tcp_master_workers = get_unused_localhost_port()
cls.master_config = cls.get_temp_config(
"master",
**{
"transport": "zeromq",
"auto_accept": True,
"ret_port": ret_port,
"publish_port": publish_port,
"tcp_master_pub_port": tcp_master_pub_port,
"tcp_master_pull_port": tcp_master_pull_port,
"tcp_master_publish_pull": tcp_master_publish_pull,
"tcp_master_workers": tcp_master_workers,
}
)
cls.minion_config = cls.get_temp_config(
"minion",
**{
"transport": "zeromq",
"master_ip": "127.0.0.1",
"master_port": ret_port,
"auth_timeout": 5,
"auth_tries": 1,
"master_uri": "tcp://127.0.0.1:{}".format(ret_port),
}
)
cls.process_manager = salt.utils.process.ProcessManager(
name="ReqServer_ProcessManager"
)
cls.server_channel = salt.transport.server.ReqServerChannel.factory(
cls.master_config
)
cls.server_channel.pre_fork(cls.process_manager)
cls.io_loop = salt.ext.tornado.ioloop.IOLoop()
cls.evt = threading.Event()
cls.server_channel.post_fork(cls._handle_payload, io_loop=cls.io_loop)
cls.server_thread = threading.Thread(
target=run_loop_in_thread, args=(cls.io_loop, cls.evt)
)
cls.server_thread.start()
@classmethod
def tearDownClass(cls):
if not hasattr(cls, "_handle_payload"):
return
# Attempting to kill the children hangs the test suite.
# Let the test suite handle this instead.
cls.process_manager.stop_restarting()
cls.process_manager.kill_children()
cls.evt.set()
cls.server_thread.join()
time.sleep(
2
) # Give the procs a chance to fully close before we stop the io_loop
cls.server_channel.close()
del cls.server_channel
del cls.io_loop
del cls.process_manager
del cls.server_thread
del cls.master_config
del cls.minion_config
@classmethod
def _handle_payload(cls, payload):
"""
TODO: something besides echo
"""
return payload, {"fun": "send_clear"}
class ClearReqTestCases(BaseZMQReqCase, ReqChannelMixin):
"""
Test all of the clear msg stuff
"""
def setUp(self):
self.channel = salt.transport.client.ReqChannel.factory(
self.minion_config, crypt="clear"
)
def tearDown(self):
self.channel.close()
del self.channel
@classmethod
@salt.ext.tornado.gen.coroutine
def _handle_payload(cls, payload):
"""
TODO: something besides echo
"""
raise salt.ext.tornado.gen.Return((payload, {"fun": "send_clear"}))
@slowTest
def test_master_uri_override(self):
"""
ensure master_uri kwarg is respected
"""
# minion_config should be 127.0.0.1, we want a different uri that still connects
uri = "tcp://{master_ip}:{master_port}".format(
master_ip="localhost", master_port=self.minion_config["master_port"]
)
channel = salt.transport.client.ReqChannel.factory(
self.minion_config, master_uri=uri
)
self.assertIn("localhost", channel.master_uri)
del channel
@flaky
@not_runs_on(
kernel="linux",
os_familiy="Suse",
reason="Skipping until https://github.com/saltstack/salt/issues/32902 gets fixed",
)
class AESReqTestCases(BaseZMQReqCase, ReqChannelMixin):
def setUp(self):
self.channel = salt.transport.client.ReqChannel.factory(self.minion_config)
def tearDown(self):
self.channel.close()
del self.channel
@classmethod
@salt.ext.tornado.gen.coroutine
def _handle_payload(cls, payload):
"""
TODO: something besides echo
"""
raise salt.ext.tornado.gen.Return((payload, {"fun": "send"}))
# TODO: make failed returns have a specific framing so we can raise the same exception
# on encrypted channels
#
#!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
#
# WARNING: This test will fail randomly on any system with > 1 CPU core!!!
#
#!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
@slowTest
def test_badload(self):
"""
Test a variety of bad requests, make sure that we get some sort of error
"""
# TODO: This test should be re-enabled when Jenkins moves to C7.
# Once the version of salt-testing is increased to something newer than the September
# release of salt-testing, the @flaky decorator should be applied to this test.
msgs = ["", [], tuple()]
for msg in msgs:
with self.assertRaises(salt.exceptions.AuthenticationError):
ret = self.channel.send(msg, timeout=5)
class BaseZMQPubCase(AsyncTestCase, AdaptedConfigurationTestCaseMixin):
"""
Test the req server/client pair
"""
@classmethod
def setUpClass(cls):
ret_port = get_unused_localhost_port()
publish_port = get_unused_localhost_port()
tcp_master_pub_port = get_unused_localhost_port()
tcp_master_pull_port = get_unused_localhost_port()
tcp_master_publish_pull = get_unused_localhost_port()
tcp_master_workers = get_unused_localhost_port()
cls.master_config = cls.get_temp_config(
"master",
**{
"transport": "zeromq",
"auto_accept": True,
"ret_port": ret_port,
"publish_port": publish_port,
"tcp_master_pub_port": tcp_master_pub_port,
"tcp_master_pull_port": tcp_master_pull_port,
"tcp_master_publish_pull": tcp_master_publish_pull,
"tcp_master_workers": tcp_master_workers,
}
)
cls.minion_config = salt.config.minion_config(
os.path.join(RUNTIME_VARS.TMP_CONF_DIR, "minion")
)
cls.minion_config = cls.get_temp_config(
"minion",
**{
"transport": "zeromq",
"master_ip": "127.0.0.1",
"master_port": ret_port,
"master_uri": "tcp://127.0.0.1:{}".format(ret_port),
}
)
cls.process_manager = salt.utils.process.ProcessManager(
name="ReqServer_ProcessManager"
)
cls.server_channel = salt.transport.server.PubServerChannel.factory(
cls.master_config
)
cls.server_channel.pre_fork(cls.process_manager)
# we also require req server for auth
cls.req_server_channel = salt.transport.server.ReqServerChannel.factory(
cls.master_config
)
cls.req_server_channel.pre_fork(cls.process_manager)
cls._server_io_loop = salt.ext.tornado.ioloop.IOLoop()
cls.evt = threading.Event()
cls.req_server_channel.post_fork(
cls._handle_payload, io_loop=cls._server_io_loop
)
cls.server_thread = threading.Thread(
target=run_loop_in_thread, args=(cls._server_io_loop, cls.evt)
)
cls.server_thread.start()
@classmethod
def tearDownClass(cls):
cls.process_manager.kill_children()
cls.process_manager.stop_restarting()
time.sleep(
2
) # Give the procs a chance to fully close before we stop the io_loop
cls.evt.set()
cls.server_thread.join()
cls.req_server_channel.close()
cls.server_channel.close()
cls._server_io_loop.stop()
del cls.server_channel
del cls._server_io_loop
del cls.process_manager
del cls.server_thread
del cls.master_config
del cls.minion_config
@classmethod
def _handle_payload(cls, payload):
"""
TODO: something besides echo
"""
return payload, {"fun": "send_clear"}
def setUp(self):
super().setUp()
self._start_handlers = dict(self.io_loop._handlers)
def tearDown(self):
super().tearDown()
failures = []
for k, v in self.io_loop._handlers.items():
if self._start_handlers.get(k) != v:
failures.append((k, v))
del self._start_handlers
if len(failures) > 0:
raise Exception("FDs still attached to the IOLoop: {}".format(failures))
@skipIf(True, "Skip until we can devote time to fix this test")
class AsyncPubChannelTest(BaseZMQPubCase, PubChannelMixin):
"""
Tests around the publish system
"""
def get_new_ioloop(self):
return salt.ext.tornado.ioloop.IOLoop()
class AsyncReqMessageClientPoolTest(TestCase):
def setUp(self):
super().setUp()
sock_pool_size = 5
with patch(
"salt.transport.zeromq.AsyncReqMessageClient.__init__",
MagicMock(return_value=None),
):
self.message_client_pool = AsyncReqMessageClientPool(
{"sock_pool_size": sock_pool_size}, args=({}, "")
)
self.original_message_clients = self.message_client_pool.message_clients
self.message_client_pool.message_clients = [
MagicMock() for _ in range(sock_pool_size)
]
def tearDown(self):
del self.original_message_clients
super().tearDown()
def test_send(self):
for message_client_mock in self.message_client_pool.message_clients:
message_client_mock.send_queue = [0, 0, 0]
message_client_mock.send.return_value = []
self.assertEqual([], self.message_client_pool.send())
self.message_client_pool.message_clients[2].send_queue = [0]
self.message_client_pool.message_clients[2].send.return_value = [1]
self.assertEqual([1], self.message_client_pool.send())


class ZMQConfigTest(TestCase):
def test_master_uri(self):
"""
        Test the _get_master_uri method.
"""
m_ip = "127.0.0.1"
m_port = 4505
s_ip = "111.1.0.1"
s_port = 4058
m_ip6 = "1234:5678::9abc"
s_ip6 = "1234:5678::1:9abc"
with patch("salt.transport.zeromq.LIBZMQ_VERSION_INFO", (4, 1, 6)), patch(
"salt.transport.zeromq.ZMQ_VERSION_INFO", (16, 0, 1)
):
# pass in both source_ip and source_port
assert salt.transport.zeromq._get_master_uri(
master_ip=m_ip, master_port=m_port, source_ip=s_ip, source_port=s_port
) == "tcp://{}:{};{}:{}".format(s_ip, s_port, m_ip, m_port)
assert salt.transport.zeromq._get_master_uri(
master_ip=m_ip6, master_port=m_port, source_ip=s_ip6, source_port=s_port
) == "tcp://[{}]:{};[{}]:{}".format(s_ip6, s_port, m_ip6, m_port)
# source ip and source_port empty
assert salt.transport.zeromq._get_master_uri(
master_ip=m_ip, master_port=m_port
) == "tcp://{}:{}".format(m_ip, m_port)
assert salt.transport.zeromq._get_master_uri(
master_ip=m_ip6, master_port=m_port
) == "tcp://[{}]:{}".format(m_ip6, m_port)
# pass in only source_ip
assert salt.transport.zeromq._get_master_uri(
master_ip=m_ip, master_port=m_port, source_ip=s_ip
) == "tcp://{}:0;{}:{}".format(s_ip, m_ip, m_port)
assert salt.transport.zeromq._get_master_uri(
master_ip=m_ip6, master_port=m_port, source_ip=s_ip6
) == "tcp://[{}]:0;[{}]:{}".format(s_ip6, m_ip6, m_port)
# pass in only source_port
assert salt.transport.zeromq._get_master_uri(
master_ip=m_ip, master_port=m_port, source_port=s_port
) == "tcp://0.0.0.0:{};{}:{}".format(s_port, m_ip, m_port)


class PubServerChannel(TestCase, AdaptedConfigurationTestCaseMixin):
@classmethod
def setUpClass(cls):
ret_port = get_unused_localhost_port()
publish_port = get_unused_localhost_port()
tcp_master_pub_port = get_unused_localhost_port()
tcp_master_pull_port = get_unused_localhost_port()
tcp_master_publish_pull = get_unused_localhost_port()
tcp_master_workers = get_unused_localhost_port()
cls.master_config = cls.get_temp_config(
"master",
**{
"transport": "zeromq",
"auto_accept": True,
"ret_port": ret_port,
"publish_port": publish_port,
"tcp_master_pub_port": tcp_master_pub_port,
"tcp_master_pull_port": tcp_master_pull_port,
"tcp_master_publish_pull": tcp_master_publish_pull,
"tcp_master_workers": tcp_master_workers,
"sign_pub_messages": False,
}
)
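        # Pre-seed the master AES key so _gather_results can build a Crypticle
        # and decrypt the payloads published during the tests.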
salt.master.SMaster.secrets["aes"] = {
"secret": multiprocessing.Array(
ctypes.c_char,
salt.utils.stringutils.to_bytes(
salt.crypt.Crypticle.generate_key_string()
),
),
}
cls.minion_config = cls.get_temp_config(
"minion",
**{
"transport": "zeromq",
"master_ip": "127.0.0.1",
"master_port": ret_port,
"auth_timeout": 5,
"auth_tries": 1,
"master_uri": "tcp://127.0.0.1:{}".format(ret_port),
}
)
@classmethod
def tearDownClass(cls):
del cls.minion_config
del cls.master_config
def setUp(self):
        # Start the event loop, even though we don't directly use this with
        # ZeroMQPubServerChannel, having it running seems to increase the
        # likelihood of dropped messages.
self.io_loop = salt.ext.tornado.ioloop.IOLoop()
self.io_loop.make_current()
self.io_loop_thread = threading.Thread(target=self.io_loop.start)
self.io_loop_thread.start()
self.process_manager = salt.utils.process.ProcessManager(
name="PubServer_ProcessManager"
)
def tearDown(self):
self.io_loop.add_callback(self.io_loop.stop)
self.io_loop_thread.join()
self.process_manager.stop_restarting()
self.process_manager.kill_children()
del self.io_loop
del self.io_loop_thread
del self.process_manager
@staticmethod
def _gather_results(opts, pub_uri, results, timeout=120, messages=None):
"""
        Gather results until the number of seconds specified by timeout passes
        without receiving a message
"""
ctx = zmq.Context()
sock = ctx.socket(zmq.SUB)
sock.setsockopt(zmq.LINGER, -1)
sock.setsockopt(zmq.SUBSCRIBE, b"")
sock.connect(pub_uri)
last_msg = time.time()
serial = salt.payload.Serial(opts)
crypticle = salt.crypt.Crypticle(
opts, salt.master.SMaster.secrets["aes"]["secret"].value
)
while time.time() - last_msg < timeout:
try:
payload = sock.recv(zmq.NOBLOCK)
except zmq.ZMQError:
time.sleep(0.01)
else:
if messages:
if messages != 1:
messages -= 1
continue
payload = crypticle.loads(serial.loads(payload)["load"])
if "stop" in payload:
break
last_msg = time.time()
results.append(payload["jid"])
@skipIf(salt.utils.platform.is_windows(), "Skip on Windows OS")
@slowTest
def test_publish_to_pubserv_ipc(self):
"""
        Test sending 10K messages to ZeroMQPubServerChannel using the IPC transport.
        ZMQ's ipc transport is not supported on Windows.
"""
opts = dict(self.master_config, ipc_mode="ipc", pub_hwm=0)
server_channel = salt.transport.zeromq.ZeroMQPubServerChannel(opts)
server_channel.pre_fork(
self.process_manager,
kwargs={"log_queue": salt.log.setup.get_multiprocessing_logging_queue()},
)
pub_uri = "tcp://{interface}:{publish_port}".format(**server_channel.opts)
send_num = 10000
expect = []
results = []
gather = threading.Thread(
target=self._gather_results, args=(self.minion_config, pub_uri, results,)
)
gather.start()
# Allow time for server channel to start, especially on windows
time.sleep(2)
for i in range(send_num):
expect.append(i)
load = {"tgt_type": "glob", "tgt": "*", "jid": i}
server_channel.publish(load)
server_channel.publish({"tgt_type": "glob", "tgt": "*", "stop": True})
gather.join()
server_channel.pub_close()
assert len(results) == send_num, (len(results), set(expect).difference(results))
@skipIf(salt.utils.platform.is_linux(), "Skip on Linux")
@slowTest
def test_zeromq_publish_port(self):
"""
        Test that, when connecting, we use the publish_port set in opts
        when it is not 4506.
"""
opts = dict(
self.master_config,
ipc_mode="ipc",
pub_hwm=0,
recon_randomize=False,
publish_port=455505,
recon_default=1,
recon_max=2,
master_ip="127.0.0.1",
acceptance_wait_time=5,
acceptance_wait_time_max=5,
)
opts["master_uri"] = "tcp://{interface}:{publish_port}".format(**opts)
channel = salt.transport.zeromq.AsyncZeroMQPubChannel(opts)
patch_socket = MagicMock(return_value=True)
patch_auth = MagicMock(return_value=True)
with patch.object(channel, "_socket", patch_socket), patch.object(
channel, "auth", patch_auth
):
channel.connect()
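            # The first call recorded on the mocked socket (assumed to be
            # connect(uri)) should carry the configured publish_port in its URI.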
assert str(opts["publish_port"]) in patch_socket.mock_calls[0][1][0]
@skipIf(salt.utils.platform.is_linux(), "Skip on Linux")
def test_zeromq_zeromq_filtering_decode_message_no_match(self):
"""
        Test AsyncZeroMQPubChannel._decode_messages when zmq_filtering is
        enabled and the minion does not match the message topic.
"""
message = [
b"4f26aeafdb2367620a393c973eddbe8f8b846eb",
b"\x82\xa3enc\xa3aes\xa4load\xda\x00`\xeeR\xcf"
b"\x0eaI#V\x17if\xcf\xae\x05\xa7\xb3bN\xf7\xb2\xe2"
b'\xd0sF\xd1\xd4\xecB\xe8\xaf"/*ml\x80Q3\xdb\xaexg'
b"\x8e\x8a\x8c\xd3l\x03\\,J\xa7\x01i\xd1:]\xe3\x8d"
b"\xf4\x03\x88K\x84\n`\xe8\x9a\xad\xad\xc6\x8ea\x15>"
b"\x92m\x9e\xc7aM\x11?\x18;\xbd\x04c\x07\x85\x99\xa3\xea[\x00D",
]
opts = dict(
self.master_config,
ipc_mode="ipc",
pub_hwm=0,
zmq_filtering=True,
recon_randomize=False,
recon_default=1,
recon_max=2,
master_ip="127.0.0.1",
acceptance_wait_time=5,
acceptance_wait_time_max=5,
)
opts["master_uri"] = "tcp://{interface}:{publish_port}".format(**opts)
server_channel = salt.transport.zeromq.AsyncZeroMQPubChannel(opts)
with patch(
"salt.crypt.AsyncAuth.crypticle",
MagicMock(return_value={"tgt_type": "glob", "tgt": "*", "jid": 1}),
) as mock_test:
res = server_channel._decode_messages(message)
assert res.result() is None
@skipIf(salt.utils.platform.is_linux(), "Skip on Linux")
def test_zeromq_zeromq_filtering_decode_message(self):
"""
        Test AsyncZeroMQPubChannel._decode_messages when zmq_filtering is
        enabled and the minion matches the message topic.
"""
message = [
b"4f26aeafdb2367620a393c973eddbe8f8b846ebd",
b"\x82\xa3enc\xa3aes\xa4load\xda\x00`\xeeR\xcf"
b"\x0eaI#V\x17if\xcf\xae\x05\xa7\xb3bN\xf7\xb2\xe2"
b'\xd0sF\xd1\xd4\xecB\xe8\xaf"/*ml\x80Q3\xdb\xaexg'
b"\x8e\x8a\x8c\xd3l\x03\\,J\xa7\x01i\xd1:]\xe3\x8d"
b"\xf4\x03\x88K\x84\n`\xe8\x9a\xad\xad\xc6\x8ea\x15>"
b"\x92m\x9e\xc7aM\x11?\x18;\xbd\x04c\x07\x85\x99\xa3\xea[\x00D",
]
opts = dict(
self.master_config,
ipc_mode="ipc",
pub_hwm=0,
zmq_filtering=True,
recon_randomize=False,
recon_default=1,
recon_max=2,
master_ip="127.0.0.1",
acceptance_wait_time=5,
acceptance_wait_time_max=5,
)
opts["master_uri"] = "tcp://{interface}:{publish_port}".format(**opts)
server_channel = salt.transport.zeromq.AsyncZeroMQPubChannel(opts)
with patch(
"salt.crypt.AsyncAuth.crypticle",
MagicMock(return_value={"tgt_type": "glob", "tgt": "*", "jid": 1}),
) as mock_test:
res = server_channel._decode_messages(message)
assert res.result()["enc"] == "aes"
@skipIf(salt.utils.platform.is_windows(), "Skip on Windows OS")
@slowTest
def test_zeromq_filtering(self):
"""
        Test sending messages to the publisher with zmq_filtering enabled
"""
opts = dict(
self.master_config,
ipc_mode="ipc",
pub_hwm=0,
zmq_filtering=True,
acceptance_wait_time=5,
)
server_channel = salt.transport.zeromq.ZeroMQPubServerChannel(opts)
server_channel.pre_fork(
self.process_manager,
kwargs={"log_queue": salt.log.setup.get_multiprocessing_logging_queue()},
)
pub_uri = "tcp://{interface}:{publish_port}".format(**server_channel.opts)
send_num = 1
expect = []
results = []
gather = threading.Thread(
target=self._gather_results,
args=(self.minion_config, pub_uri, results,),
kwargs={"messages": 2},
)
gather.start()
# Allow time for server channel to start, especially on windows
time.sleep(2)
expect.append(send_num)
load = {"tgt_type": "glob", "tgt": "*", "jid": send_num}
with patch(
"salt.utils.minions.CkMinions.check_minions",
MagicMock(
return_value={
"minions": ["minion"],
"missing": [],
"ssh_minions": False,
}
),
):
server_channel.publish(load)
server_channel.publish({"tgt_type": "glob", "tgt": "*", "stop": True})
gather.join()
server_channel.pub_close()
assert len(results) == send_num, (len(results), set(expect).difference(results))
@slowTest
def test_publish_to_pubserv_tcp(self):
"""
        Test sending 10K messages to ZeroMQPubServerChannel using the TCP transport
"""
opts = dict(self.master_config, ipc_mode="tcp", pub_hwm=0)
server_channel = salt.transport.zeromq.ZeroMQPubServerChannel(opts)
server_channel.pre_fork(
self.process_manager,
kwargs={"log_queue": salt.log.setup.get_multiprocessing_logging_queue()},
)
pub_uri = "tcp://{interface}:{publish_port}".format(**server_channel.opts)
send_num = 10000
expect = []
results = []
gather = threading.Thread(
target=self._gather_results, args=(self.minion_config, pub_uri, results,)
)
gather.start()
# Allow time for server channel to start, especially on windows
time.sleep(2)
for i in range(send_num):
expect.append(i)
load = {"tgt_type": "glob", "tgt": "*", "jid": i}
server_channel.publish(load)
gather.join()
server_channel.pub_close()
assert len(results) == send_num, (len(results), set(expect).difference(results))
@staticmethod
def _send_small(opts, sid, num=10):
server_channel = salt.transport.zeromq.ZeroMQPubServerChannel(opts)
for i in range(num):
load = {"tgt_type": "glob", "tgt": "*", "jid": "{}-{}".format(sid, i)}
server_channel.publish(load)
server_channel.pub_close()
@staticmethod
def _send_large(opts, sid, num=10, size=250000 * 3):
server_channel = salt.transport.zeromq.ZeroMQPubServerChannel(opts)
for i in range(num):
load = {
"tgt_type": "glob",
"tgt": "*",
"jid": "{}-{}".format(sid, i),
"xdata": "0" * size,
}
server_channel.publish(load)
server_channel.pub_close()
@skipIf(salt.utils.platform.is_freebsd(), "Skip on FreeBSD")
@slowTest
def test_issue_36469_tcp(self):
"""
        Test sending both large and small messages to the publisher using TCP
https://github.com/saltstack/salt/issues/36469
"""
opts = dict(self.master_config, ipc_mode="tcp", pub_hwm=0)
server_channel = salt.transport.zeromq.ZeroMQPubServerChannel(opts)
server_channel.pre_fork(
self.process_manager,
kwargs={"log_queue": salt.log.setup.get_multiprocessing_logging_queue()},
)
send_num = 10 * 4
expect = []
results = []
pub_uri = "tcp://{interface}:{publish_port}".format(**opts)
# Allow time for server channel to start, especially on windows
time.sleep(2)
gather = threading.Thread(
target=self._gather_results, args=(self.minion_config, pub_uri, results,)
)
gather.start()
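        # Publish concurrently from four threads: three senders of small
        # payloads and one sender of large (~750 kB) payloads, 10 messages each
        # (40 in total).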
with ThreadPoolExecutor(max_workers=4) as executor:
executor.submit(self._send_small, opts, 1)
executor.submit(self._send_small, opts, 2)
executor.submit(self._send_small, opts, 3)
executor.submit(self._send_large, opts, 4)
expect = ["{}-{}".format(a, b) for a in range(10) for b in (1, 2, 3, 4)]
time.sleep(0.1)
server_channel.publish({"tgt_type": "glob", "tgt": "*", "stop": True})
gather.join()
server_channel.pub_close()
assert len(results) == send_num, (len(results), set(expect).difference(results))


class AsyncZeroMQReqChannelTests(TestCase):
def test_force_close_all_instances(self):
zmq1 = MagicMock()
zmq2 = MagicMock()
zmq3 = MagicMock()
zmq_objects = {"zmq": {"1": zmq1, "2": zmq2}, "other_zmq": {"3": zmq3}}
with patch(
"salt.transport.zeromq.AsyncZeroMQReqChannel.instance_map", zmq_objects
):
salt.transport.zeromq.AsyncZeroMQReqChannel.force_close_all_instances()
self.assertEqual(zmq1.mock_calls, [call.close()])
self.assertEqual(zmq2.mock_calls, [call.close()])
self.assertEqual(zmq3.mock_calls, [call.close()])
            # ensure the instance_map object itself was not replaced
self.assertIs(
zmq_objects, salt.transport.zeromq.AsyncZeroMQReqChannel.instance_map
)