Mirror of https://github.com/saltstack/salt.git (synced 2025-04-17 10:10:20 +00:00)
Move tests into their proper transport. Relax timeouts on the IPC tests.
Signed-off-by: Pedro Algarvio <palgarvio@vmware.com>
commit 458ab605ed (parent d2bdc57b50)
4 changed files with 427 additions and 370 deletions
@@ -326,6 +326,10 @@ tests/support/pytest/mysql.py:
  - pytests.functional.states.test_mysql
  - pytests.functional.modules.test_mysql

tests/support/pytest/transport.py:
  - pytests.functional.transport.ipc.test_pub_server_channel
  - pytests.functional.transport.zeromq.test_pub_server_channel

tests/pytests/scenarios/multimaster:
  - pytests.scenarios.multimaster.test_multimaster
  - pytests.scenarios.multimaster.test_offline_master
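The hunk above extends the test filename map, which ties source files to the test modules that exercise them so tests can be selected from changed filenames. A minimal sketch, assuming PyYAML and fnmatch-style keys, of how such a map can be consumed (the helper name is hypothetical):

import fnmatch

import yaml


def tests_for_changed_files(map_path, changed_files):
    """Return the test modules mapped to any of the changed files."""
    with open(map_path) as fh:
        filename_map = yaml.safe_load(fh)
    selected = set()
    for pattern, test_modules in filename_map.items():
        for changed in changed_files:
            if fnmatch.fnmatch(changed, pattern):
                selected.update(test_modules)
    return sorted(selected)

# e.g. tests_for_changed_files("tests/filename_map.yml", ["tests/support/pytest/transport.py"])
# would now select both pub_server_channel test modules added above.
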
@@ -0,0 +1,132 @@
import logging
import time
from concurrent.futures.thread import ThreadPoolExecutor

import pytest
from saltfactories.utils import random_string

import salt.channel.server
import salt.master
from tests.support.pytest.transport import PubServerChannelProcess

log = logging.getLogger(__name__)


pytestmark = [
    pytest.mark.skip_on_spawning_platform(
        reason="These tests are currently broken on spawning platforms. Need to be rewritten.",
    )
]


@pytest.fixture(scope="module", params=["tcp", "zeromq"])
def transport(request):
    yield request.param


@pytest.fixture(scope="module")
def salt_master(salt_factories, transport):
    config_defaults = {
        "transport": transport,
        "auto_accept": True,
        "sign_pub_messages": False,
    }
    factory = salt_factories.salt_master_daemon(
        random_string("ipc-master-"), defaults=config_defaults
    )
    return factory


@pytest.fixture(scope="module")
def salt_minion(salt_master):
    config_defaults = {
        "transport": salt_master.config["transport"],
        "master_ip": "127.0.0.1",
        "master_port": salt_master.config["ret_port"],
        "auth_timeout": 5,
        "auth_tries": 1,
        "master_uri": "tcp://127.0.0.1:{}".format(salt_master.config["ret_port"]),
    }
    factory = salt_master.salt_minion_daemon(
        random_string("zeromq-minion-"), defaults=config_defaults
    )
    return factory


@pytest.mark.skip_on_windows
@pytest.mark.slow_test
def test_publish_to_pubserv_ipc(salt_master, salt_minion, transport):
    """
    Test sending 10K messages to ZeroMQPubServerChannel using IPC transport

    ZMQ's ipc transport is not supported on Windows
    """
    opts = dict(
        salt_master.config.copy(),
        ipc_mode="ipc",
        pub_hwm=0,
        transport=transport,
    )
    minion_opts = dict(salt_minion.config.copy(), transport=transport)
    with PubServerChannelProcess(opts, minion_opts) as server_channel:
        send_num = 10000
        expect = []
        for idx in range(send_num):
            expect.append(idx)
            load = {"tgt_type": "glob", "tgt": "*", "jid": idx}
            server_channel.publish(load)
    results = server_channel.collector.results
    assert len(results) == send_num, "{} != {}, difference: {}".format(
        len(results), send_num, set(expect).difference(results)
    )


@pytest.mark.skip_on_freebsd
@pytest.mark.slow_test
def test_issue_36469_tcp(salt_master, salt_minion, transport):
    """
    Test sending both large and small messages to the publisher using TCP

    https://github.com/saltstack/salt/issues/36469
    """
    if transport == "tcp":
        pytest.skip("Test not applicable to the TCP transport.")

    def _send_small(opts, sid, num=10):
        server_channel = salt.channel.server.PubServerChannel.factory(opts)
        for idx in range(num):
            load = {"tgt_type": "glob", "tgt": "*", "jid": "{}-s{}".format(sid, idx)}
            server_channel.publish(load)
            time.sleep(0.3)
        time.sleep(3)
        server_channel.close_pub()

    def _send_large(opts, sid, num=10, size=250000 * 3):
        server_channel = salt.channel.server.PubServerChannel.factory(opts)
        for idx in range(num):
            load = {
                "tgt_type": "glob",
                "tgt": "*",
                "jid": "{}-l{}".format(sid, idx),
                "xdata": "0" * size,
            }
            server_channel.publish(load)
            time.sleep(0.3)
        server_channel.close_pub()

    opts = dict(salt_master.config.copy(), ipc_mode="tcp", pub_hwm=0)
    send_num = 10 * 4
    expect = []
    with PubServerChannelProcess(opts, salt_minion.config.copy()) as server_channel:
        assert "aes" in salt.master.SMaster.secrets
        with ThreadPoolExecutor(max_workers=4) as executor:
            executor.submit(_send_small, opts, 1)
            executor.submit(_send_large, opts, 2)
            executor.submit(_send_small, opts, 3)
            executor.submit(_send_large, opts, 4)
        # jids are "<sid>-s<idx>" / "<sid>-l<idx>", so sid comes first
        expect.extend(["{}-s{}".format(b, a) for a in range(10) for b in (1, 3)])
        expect.extend(["{}-l{}".format(b, a) for a in range(10) for b in (2, 4)])
    results = server_channel.collector.results
    assert len(results) == send_num, "{} != {}, difference: {}".format(
        len(results), send_num, set(expect).difference(results)
    )
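Both tests in the new module above run once per transport because they depend on the module-scoped, parametrized `transport` fixture, so the master and minion daemons are built once per transport value. A self-contained sketch of the pattern (the test name is hypothetical):

import pytest


@pytest.fixture(scope="module", params=["tcp", "zeromq"])
def transport(request):
    # Each param produces a separate module-scoped fixture instance,
    # so everything depending on it is rebuilt once per transport.
    yield request.param


def test_runs_once_per_transport(transport):
    assert transport in ("tcp", "zeromq")
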
@@ -1,31 +1,9 @@
import ctypes
import logging
import multiprocessing
import socket
import time
from concurrent.futures.thread import ThreadPoolExecutor

import pytest
import zmq
from pytestshellutils.utils.processes import terminate_process

import salt.channel.client
import salt.channel.server
import salt.config
import salt.exceptions
import salt.ext.tornado.gen
import salt.ext.tornado.ioloop
import salt.log.setup
import salt.master
import salt.transport.client
import salt.transport.server
import salt.transport.tcp
import salt.transport.zeromq
import salt.utils.msgpack
import salt.utils.platform
import salt.utils.process
import salt.utils.stringutils
from tests.support.mock import MagicMock, patch
from tests.support.pytest.transport import PubServerChannelProcess

log = logging.getLogger(__name__)
@@ -37,353 +15,6 @@ pytestmark = [
]


class RecvError(Exception):
    """
    Raised by the Collector's _recv method when there is a problem
    getting publishes from the publisher.
    """


class Collector(salt.utils.process.SignalHandlingProcess):
    def __init__(
        self, minion_config, interface, port, aes_key, timeout=300, zmq_filtering=False
    ):
        super().__init__()
        self.minion_config = minion_config
        self.interface = interface
        self.port = port
        self.aes_key = aes_key
        self.timeout = timeout
        self.hard_timeout = time.time() + timeout + 30
        self.manager = multiprocessing.Manager()
        self.results = self.manager.list()
        self.zmq_filtering = zmq_filtering
        self.stopped = multiprocessing.Event()
        self.started = multiprocessing.Event()
        self.running = multiprocessing.Event()
        if salt.utils.msgpack.version >= (0, 5, 2):
            # msgpack 0.5.2 and later replaced the 'encoding' argument with 'raw'
            msgpack_kwargs = {"raw": False}
        else:
            msgpack_kwargs = {"encoding": "utf-8"}
        self.unpacker = salt.utils.msgpack.Unpacker(**msgpack_kwargs)

    @property
    def transport(self):
        return self.minion_config["transport"]

    def _rotate_secrets(self, now=None):
        salt.master.SMaster.secrets["aes"] = {
            "secret": multiprocessing.Array(
                ctypes.c_char,
                salt.utils.stringutils.to_bytes(
                    salt.crypt.Crypticle.generate_key_string()
                ),
            ),
            "serial": multiprocessing.Value(
                ctypes.c_longlong, lock=False  # We'll use the lock from 'secret'
            ),
            "reload": salt.crypt.Crypticle.generate_key_string,
            "rotate_master_key": self._rotate_secrets,
        }

    def _setup_listener(self):
        if self.transport == "zeromq":
            ctx = zmq.Context()
            self.sock = ctx.socket(zmq.SUB)
            self.sock.setsockopt(zmq.LINGER, -1)
            self.sock.setsockopt(zmq.SUBSCRIBE, b"")
            pub_uri = "tcp://{}:{}".format(self.interface, self.port)
            self.sock.connect(pub_uri)
        else:
            end = time.time() + 60
            while True:
                sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                try:
                    sock.connect((self.interface, self.port))
                except ConnectionRefusedError:
                    if time.time() >= end:
                        raise
                    time.sleep(1)
                else:
                    break
            self.sock = salt.ext.tornado.iostream.IOStream(sock)

    @salt.ext.tornado.gen.coroutine
    def _recv(self):
        if self.transport == "zeromq":
            # test_zeromq_filtering requires catching the
            # SaltDeserializationError in order to pass.
            try:
                payload = self.sock.recv(zmq.NOBLOCK)
                serial_payload = salt.payload.loads(payload)
                raise salt.ext.tornado.gen.Return(serial_payload)
            except (zmq.ZMQError, salt.exceptions.SaltDeserializationError):
                raise RecvError("ZMQ Error")
        else:
            for msg in self.unpacker:
                raise salt.ext.tornado.gen.Return(msg["body"])
            byts = yield self.sock.read_bytes(8096, partial=True)
            self.unpacker.feed(byts)
            for msg in self.unpacker:
                raise salt.ext.tornado.gen.Return(msg["body"])
            raise RecvError("TCP Error")

    @salt.ext.tornado.gen.coroutine
    def _run(self, loop):
        try:
            self._setup_listener()
        except Exception:  # pylint: disable=broad-except
            self.started.set()
            log.exception("Failed to start listening")
            return
        self.started.set()
        last_msg = time.time()
        serial = salt.payload.Serial(self.minion_config)
        crypticle = salt.crypt.Crypticle(self.minion_config, self.aes_key)
        while True:
            curr_time = time.time()
            if time.time() > self.hard_timeout:
                log.error("Hard timeout reached in test collector!")
                break
            if curr_time - last_msg >= self.timeout:
                log.error("Receive timeout reached in test collector!")
                break
            try:
                payload = yield self._recv()
            except RecvError:
                time.sleep(0.01)
            else:
                try:
                    payload = crypticle.loads(payload["load"])
                    if not payload:
                        continue
                    if "start" in payload:
                        log.info("Collector started")
                        self.running.set()
                        continue
                    if "stop" in payload:
                        log.info("Collector stopped")
                        break
                    last_msg = time.time()
                    self.results.append(payload["jid"])
                except salt.exceptions.SaltDeserializationError:
                    log.error("Deserializer Error")
                    if not self.zmq_filtering:
                        log.exception("Failed to deserialize...")
                        break
        loop.stop()

    def run(self):
        """
        Gather results until the number of seconds specified by timeout passes
        without receiving a message
        """
        loop = salt.ext.tornado.ioloop.IOLoop()
        loop.add_callback(self._run, loop)
        loop.start()

    def __enter__(self):
        self.manager.__enter__()
        self.start()
        # Wait until we can start receiving events
        self.started.wait()
        self.started.clear()
        return self

    def __exit__(self, *args):
        # Wait until we either processed all expected messages or we reach the hard timeout
        join_secs = self.hard_timeout - time.time()
        log.info("Waiting at most %s seconds before exiting the collector", join_secs)
        self.join(join_secs)
        self.terminate()
        # Cast our manager.list into a plain list
        self.results = list(self.results)
        # Terminate our multiprocessing manager
        self.manager.__exit__(*args)
        log.debug("The collector has exited")
        self.stopped.set()


class PubServerChannelProcess(salt.utils.process.SignalHandlingProcess):
    def __init__(self, master_config, minion_config, **collector_kwargs):
        super().__init__()
        self._closing = False
        self.master_config = master_config
        self.minion_config = minion_config
        self.collector_kwargs = collector_kwargs
        self.aes_key = salt.crypt.Crypticle.generate_key_string()
        salt.master.SMaster.secrets["aes"] = {
            "secret": multiprocessing.Array(
                ctypes.c_char,
                salt.utils.stringutils.to_bytes(self.aes_key),
            ),
            "serial": multiprocessing.Value(
                ctypes.c_longlong, lock=False  # We'll use the lock from 'secret'
            ),
        }
        self.process_manager = salt.utils.process.ProcessManager(
            name="ZMQ-PubServer-ProcessManager"
        )
        self.pub_server_channel = salt.channel.server.PubServerChannel.factory(
            self.master_config
        )
        self.pub_server_channel.pre_fork(self.process_manager)
        self.pub_uri = "tcp://{interface}:{publish_port}".format(**self.master_config)
        self.queue = multiprocessing.Queue()
        self.stopped = multiprocessing.Event()
        self.collector = Collector(
            self.minion_config,
            self.master_config["interface"],
            self.master_config["publish_port"],
            self.aes_key,
            **self.collector_kwargs
        )

    def run(self):
        try:
            while True:
                payload = self.queue.get()
                if payload is None:
                    log.debug("We received the stop sentinel")
                    break
                self.pub_server_channel.publish(payload)
        except KeyboardInterrupt:
            pass
        finally:
            self.stopped.set()

    def _handle_signals(self, signum, sigframe):
        self.close()
        super()._handle_signals(signum, sigframe)

    def close(self):
        if self._closing:
            return
        self._closing = True
        if self.process_manager is None:
            return
        self.process_manager.terminate()
        if hasattr(self.pub_server_channel, "pub_close"):
            self.pub_server_channel.pub_close()
        # Really terminate any process still left behind
        for pid in self.process_manager._process_map:
            terminate_process(pid=pid, kill_children=True, slow_stop=False)
        self.process_manager = None

    def publish(self, payload):
        self.queue.put(payload)

    def __enter__(self):
        self.start()
        self.collector.__enter__()
        attempts = 300
        while attempts > 0:
            self.publish({"tgt_type": "glob", "tgt": "*", "jid": -1, "start": True})
            if self.collector.running.wait(1) is True:
                break
            attempts -= 1
        else:
            pytest.fail("Failed to confirm the collector has started")
        return self

    def __exit__(self, *args):
        # Publish a payload to tell the collector it's done processing
        self.publish({"tgt_type": "glob", "tgt": "*", "jid": -1, "stop": True})
        # Now trigger the collector to also exit
        self.collector.__exit__(*args)
        # We can safely wait here without a timeout because the Collector instance has a
        # hard timeout set, so eventually Collector.stopped will be set
        self.collector.stopped.wait()
        # Stop our own processing
        self.queue.put(None)
        # Wait at most 10 secs for the above `None` in the queue to be processed
        self.stopped.wait(10)
        self.close()
        self.terminate()
        log.info("The PubServerChannelProcess has terminated")


@pytest.fixture(params=["tcp", "zeromq"])
def transport(request):
    yield request.param


@pytest.mark.skip_on_windows
@pytest.mark.slow_test
def test_publish_to_pubserv_ipc(salt_master, salt_minion, transport):
    """
    Test sending 10K messages to ZeroMQPubServerChannel using IPC transport

    ZMQ's ipc transport is not supported on Windows
    """
    opts = dict(
        salt_master.config.copy(), ipc_mode="ipc", pub_hwm=0, transport=transport
    )
    minion_opts = dict(salt_minion.config.copy(), transport=transport)
    with PubServerChannelProcess(opts, minion_opts) as server_channel:
        send_num = 10000
        expect = []
        for idx in range(send_num):
            expect.append(idx)
            load = {"tgt_type": "glob", "tgt": "*", "jid": idx}
            server_channel.publish(load)
    results = server_channel.collector.results
    assert len(results) == send_num, "{} != {}, difference: {}".format(
        len(results), send_num, set(expect).difference(results)
    )


@pytest.mark.skip_on_freebsd
@pytest.mark.slow_test
def test_issue_36469_tcp(salt_master, salt_minion):
    """
    Test sending both large and small messages to the publisher using TCP

    https://github.com/saltstack/salt/issues/36469
    """

    def _send_small(opts, sid, num=10):
        server_channel = salt.channel.server.PubServerChannel.factory(opts)
        for idx in range(num):
            load = {"tgt_type": "glob", "tgt": "*", "jid": "{}-s{}".format(sid, idx)}
            server_channel.publish(load)
            time.sleep(0.3)
        time.sleep(3)
        server_channel.close_pub()

    def _send_large(opts, sid, num=10, size=250000 * 3):
        server_channel = salt.channel.server.PubServerChannel.factory(opts)
        for idx in range(num):
            load = {
                "tgt_type": "glob",
                "tgt": "*",
                "jid": "{}-l{}".format(sid, idx),
                "xdata": "0" * size,
            }
            server_channel.publish(load)
            time.sleep(0.3)
        server_channel.close_pub()

    opts = dict(salt_master.config.copy(), ipc_mode="tcp", pub_hwm=0)
    send_num = 10 * 4
    expect = []
    with PubServerChannelProcess(opts, salt_minion.config.copy()) as server_channel:
        assert "aes" in salt.master.SMaster.secrets
        with ThreadPoolExecutor(max_workers=4) as executor:
            executor.submit(_send_small, opts, 1)
            executor.submit(_send_large, opts, 2)
            executor.submit(_send_small, opts, 3)
            executor.submit(_send_large, opts, 4)
        # jids are "<sid>-s<idx>" / "<sid>-l<idx>", so sid comes first
        expect.extend(["{}-s{}".format(b, a) for a in range(10) for b in (1, 3)])
        expect.extend(["{}-l{}".format(b, a) for a in range(10) for b in (2, 4)])
    results = server_channel.collector.results
    assert len(results) == send_num, "{} != {}, difference: {}".format(
        len(results), send_num, set(expect).difference(results)
    )


@pytest.mark.skip_on_windows
@pytest.mark.slow_test
def test_zeromq_filtering(salt_master, salt_minion):
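The diff view truncates the remainder of this deleted hunk at test_zeromq_filtering. The Collector it removes (and re-adds in tests/support/pytest/transport.py below) listens on the zeromq transport with a plain SUB socket. A minimal pyzmq sketch of that listener, assuming a publisher on Salt's default publish port 4505:

import zmq

ctx = zmq.Context()
sock = ctx.socket(zmq.SUB)
sock.setsockopt(zmq.SUBSCRIBE, b"")   # subscribe to every topic
sock.connect("tcp://127.0.0.1:4505")  # assumed publisher address
try:
    raw = sock.recv(zmq.NOBLOCK)      # non-blocking, like Collector._recv
except zmq.ZMQError:
    raw = None                        # nothing queued yet; the caller retries
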
tests/support/pytest/transport.py (new file, 290 lines)

@@ -0,0 +1,290 @@
import ctypes
import logging
import multiprocessing
import socket
import time

import pytest
import zmq
from pytestshellutils.utils.processes import terminate_process

import salt.channel.server
import salt.crypt
import salt.exceptions
import salt.ext.tornado.gen
import salt.ext.tornado.ioloop
import salt.ext.tornado.iostream
import salt.master
import salt.payload
import salt.utils.msgpack
import salt.utils.process
import salt.utils.stringutils

log = logging.getLogger(__name__)


class RecvError(Exception):
    """
    Raised by the Collector's _recv method when there is a problem
    getting publishes from the publisher.
    """


class Collector(salt.utils.process.SignalHandlingProcess):
    def __init__(
        self,
        minion_config,
        interface,
        port,
        aes_key,
        timeout=300,
        zmq_filtering=False,
    ):
        super().__init__()
        self.minion_config = minion_config
        self.interface = interface
        self.port = port
        self.aes_key = aes_key
        self.timeout = timeout
        self.hard_timeout = time.time() + timeout + 120
        self.manager = multiprocessing.Manager()
        self.results = self.manager.list()
        self.zmq_filtering = zmq_filtering
        self.stopped = multiprocessing.Event()
        self.started = multiprocessing.Event()
        self.running = multiprocessing.Event()
        self.unpacker = salt.utils.msgpack.Unpacker(raw=False)

    @property
    def transport(self):
        return self.minion_config["transport"]

    def _rotate_secrets(self, now=None):
        salt.master.SMaster.secrets["aes"] = {
            "secret": multiprocessing.Array(
                ctypes.c_char,
                salt.utils.stringutils.to_bytes(
                    salt.crypt.Crypticle.generate_key_string()
                ),
            ),
            "serial": multiprocessing.Value(
                ctypes.c_longlong, lock=False  # We'll use the lock from 'secret'
            ),
            "reload": salt.crypt.Crypticle.generate_key_string,
            "rotate_master_key": self._rotate_secrets,
        }

    def _setup_listener(self):
        if self.transport == "zeromq":
            ctx = zmq.Context()
            self.sock = ctx.socket(zmq.SUB)
            self.sock.setsockopt(zmq.LINGER, -1)
            self.sock.setsockopt(zmq.SUBSCRIBE, b"")
            pub_uri = "tcp://{}:{}".format(self.interface, self.port)
            self.sock.connect(pub_uri)
        else:
            end = time.time() + 120
            while True:
                sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                try:
                    sock.connect((self.interface, self.port))
                except ConnectionRefusedError:
                    if time.time() >= end:
                        raise
                    time.sleep(1)
                else:
                    break
            self.sock = salt.ext.tornado.iostream.IOStream(sock)

    @salt.ext.tornado.gen.coroutine
    def _recv(self):
        if self.transport == "zeromq":
            # test_zeromq_filtering requires catching the
            # SaltDeserializationError in order to pass.
            try:
                payload = self.sock.recv(zmq.NOBLOCK)
                serial_payload = salt.payload.loads(payload)
                raise salt.ext.tornado.gen.Return(serial_payload)
            except (zmq.ZMQError, salt.exceptions.SaltDeserializationError):
                raise RecvError("ZMQ Error")
        else:
            for msg in self.unpacker:
                raise salt.ext.tornado.gen.Return(msg["body"])
            byts = yield self.sock.read_bytes(8096, partial=True)
            self.unpacker.feed(byts)
            for msg in self.unpacker:
                raise salt.ext.tornado.gen.Return(msg["body"])
            raise RecvError("TCP Error")

    @salt.ext.tornado.gen.coroutine
    def _run(self, loop):
        try:
            self._setup_listener()
        except Exception:  # pylint: disable=broad-except
            self.started.set()
            log.exception("Failed to start listening")
            return
        self.started.set()
        last_msg = time.time()
        serial = salt.payload.Serial(self.minion_config)
        crypticle = salt.crypt.Crypticle(self.minion_config, self.aes_key)
        while True:
            curr_time = time.time()
            if time.time() > self.hard_timeout:
                log.error("Hard timeout reached in test collector!")
                break
            if curr_time - last_msg >= self.timeout:
                log.error("Receive timeout reached in test collector!")
                break
            try:
                payload = yield self._recv()
            except RecvError:
                time.sleep(0.01)
            else:
                try:
                    payload = crypticle.loads(payload["load"])
                    if not payload:
                        continue
                    if "start" in payload:
                        log.info("Collector started")
                        self.running.set()
                        continue
                    if "stop" in payload:
                        log.info("Collector stopped")
                        break
                    last_msg = time.time()
                    self.results.append(payload["jid"])
                except salt.exceptions.SaltDeserializationError:
                    log.error("Deserializer Error")
                    if not self.zmq_filtering:
                        log.exception("Failed to deserialize...")
                        break
        loop.stop()

    def run(self):
        """
        Gather results until the number of seconds specified by timeout passes
        without receiving a message
        """
        loop = salt.ext.tornado.ioloop.IOLoop()
        loop.add_callback(self._run, loop)
        loop.start()

    def __enter__(self):
        self.manager.__enter__()
        self.start()
        # Wait until we can start receiving events
        self.started.wait()
        self.started.clear()
        return self

    def __exit__(self, *args):
        # Wait until we either processed all expected messages or we reach the hard timeout
        join_secs = self.hard_timeout - time.time()
        log.info("Waiting at most %s seconds before exiting the collector", join_secs)
        self.join(join_secs)
        self.terminate()
        # Cast our manager.list into a plain list
        self.results = list(self.results)
        # Terminate our multiprocessing manager
        self.manager.__exit__(*args)
        log.debug("The collector has exited")
        self.stopped.set()


class PubServerChannelProcess(salt.utils.process.SignalHandlingProcess):
    def __init__(self, master_config, minion_config, **collector_kwargs):
        super().__init__()
        self._closing = False
        self.master_config = master_config
        self.minion_config = minion_config
        self.collector_kwargs = collector_kwargs
        self.aes_key = salt.crypt.Crypticle.generate_key_string()
        salt.master.SMaster.secrets["aes"] = {
            "secret": multiprocessing.Array(
                ctypes.c_char,
                salt.utils.stringutils.to_bytes(self.aes_key),
            ),
            "serial": multiprocessing.Value(
                ctypes.c_longlong, lock=False  # We'll use the lock from 'secret'
            ),
        }
        self.process_manager = salt.utils.process.ProcessManager(
            name="ZMQ-PubServer-ProcessManager"
        )
        self.pub_server_channel = salt.channel.server.PubServerChannel.factory(
            self.master_config
        )
        self.pub_server_channel.pre_fork(self.process_manager)
        self.pub_uri = "tcp://{interface}:{publish_port}".format(**self.master_config)
        self.queue = multiprocessing.Queue()
        self.stopped = multiprocessing.Event()
        self.collector = Collector(
            self.minion_config,
            self.master_config["interface"],
            self.master_config["publish_port"],
            self.aes_key,
            **self.collector_kwargs
        )

    def run(self):
        try:
            while True:
                payload = self.queue.get()
                if payload is None:
                    log.debug("We received the stop sentinel")
                    break
                self.pub_server_channel.publish(payload)
        except KeyboardInterrupt:
            pass
        finally:
            self.stopped.set()

    def _handle_signals(self, signum, sigframe):
        self.close()
        super()._handle_signals(signum, sigframe)

    def close(self):
        if self._closing:
            return
        self._closing = True
        if self.process_manager is None:
            return
        self.process_manager.terminate()
        if hasattr(self.pub_server_channel, "pub_close"):
            self.pub_server_channel.pub_close()
        # Really terminate any process still left behind
        for pid in self.process_manager._process_map:
            terminate_process(pid=pid, kill_children=True, slow_stop=False)
        self.process_manager = None

    def publish(self, payload):
        self.queue.put(payload)

    def __enter__(self):
        self.start()
        self.collector.__enter__()
        attempts = 300
        while attempts > 0:
            self.publish({"tgt_type": "glob", "tgt": "*", "jid": -1, "start": True})
            if self.collector.running.wait(1) is True:
                break
            attempts -= 1
        else:
            pytest.fail("Failed to confirm the collector has started")
        return self

    def __exit__(self, *args):
        # Publish a payload to tell the collector it's done processing
        self.publish({"tgt_type": "glob", "tgt": "*", "jid": -1, "stop": True})
        # Now trigger the collector to also exit
        self.collector.__exit__(*args)
        # We can safely wait here without a timeout because the Collector instance has a
        # hard timeout set, so eventually Collector.stopped will be set
        self.collector.stopped.wait()
        # Stop our own processing
        self.queue.put(None)
        # Wait at most 10 secs for the above `None` in the queue to be processed
        self.stopped.wait(10)
        self.close()
        self.terminate()
        log.info("The PubServerChannelProcess has terminated")
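PubServerChannelProcess confirms liveness with a start/stop handshake: it republishes a "start" payload until the collector sets its running event, and publishes a "stop" sentinel on teardown. A stdlib-only sketch of that pattern (simplified, hypothetical names):

import multiprocessing


def consumer(q, running):
    while True:
        payload = q.get()
        if payload.get("start"):
            running.set()          # tell the producer we are receiving
            continue
        if payload.get("stop"):
            break


if __name__ == "__main__":
    q = multiprocessing.Queue()
    running = multiprocessing.Event()
    proc = multiprocessing.Process(target=consumer, args=(q, running))
    proc.start()
    attempts = 300
    while attempts > 0:
        q.put({"start": True})     # keep publishing until confirmed
        if running.wait(1):
            break
        attempts -= 1
    q.put({"stop": True})          # teardown sentinel
    proc.join()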