Add test coverage

This commit is contained in:
Daniel A. Wozniak 2023-10-11 00:10:22 -07:00 committed by Pedro Algarvio
parent 8a872eff08
commit ae1bf35ce8
9 changed files with 486 additions and 101 deletions

View file

@ -35,6 +35,7 @@ import salt.utils.platform
import salt.utils.versions
from salt.exceptions import SaltClientError, SaltReqTimeoutError
from salt.utils.network import ip_bracket
from salt.utils.process import SignalHandlingProcess
if salt.utils.platform.is_windows():
USE_LOAD_BALANCER = True
@ -43,7 +44,6 @@ else:
if USE_LOAD_BALANCER:
import salt.ext.tornado.util
from salt.utils.process import SignalHandlingProcess
log = logging.getLogger(__name__)
@ -128,69 +128,64 @@ def _set_tcp_keepalive(sock, opts):
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 0)
if USE_LOAD_BALANCER:
class LoadBalancerServer(SignalHandlingProcess):
"""
Raw TCP server which runs in its own process and will listen
for incoming connections. Each incoming connection will be
sent via multiprocessing queue to the workers.
Since the queue is shared amongst workers, only one worker will
handle a given connection.
"""
class LoadBalancerServer(SignalHandlingProcess):
"""
Raw TCP server which runs in its own process and will listen
for incoming connections. Each incoming connection will be
sent via multiprocessing queue to the workers.
Since the queue is shared amongst workers, only one worker will
handle a given connection.
"""
# TODO: opts!
# Based on default used in salt.ext.tornado.netutil.bind_sockets()
backlog = 128
# TODO: opts!
# Based on default used in salt.ext.tornado.netutil.bind_sockets()
backlog = 128
def __init__(self, opts, socket_queue, **kwargs):
super().__init__(**kwargs)
self.opts = opts
self.socket_queue = socket_queue
self._socket = None
def __init__(self, opts, socket_queue, **kwargs):
super().__init__(**kwargs)
self.opts = opts
self.socket_queue = socket_queue
def close(self):
if self._socket is not None:
self._socket.shutdown(socket.SHUT_RDWR)
self._socket.close()
self._socket = None
def close(self):
if self._socket is not None:
self._socket.shutdown(socket.SHUT_RDWR)
self._socket.close()
self._socket = None
# pylint: disable=W1701
def __del__(self):
self.close()
# pylint: disable=W1701
def __del__(self):
self.close()
# pylint: enable=W1701
# pylint: enable=W1701
def run(self):
"""
Start the load balancer
"""
self._socket = _get_socket(self.opts)
self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
_set_tcp_keepalive(self._socket, self.opts)
self._socket.setblocking(1)
self._socket.bind(_get_bind_addr(self.opts, "ret_port"))
self._socket.listen(self.backlog)
def run(self):
"""
Start the load balancer
"""
self._socket = _get_socket(self.opts)
self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
_set_tcp_keepalive(self._socket, self.opts)
self._socket.setblocking(1)
self._socket.bind(_get_bind_addr(self.opts, "ret_port"))
self._socket.listen(self.backlog)
while True:
try:
# Wait for a connection to occur since the socket is
# blocking.
connection, address = self._socket.accept()
# Wait for a free slot to be available to put
# the connection into.
# Sockets are picklable on Windows in Python 3.
self.socket_queue.put((connection, address), True, None)
except OSError as e:
# ECONNABORTED indicates that there was a connection
# but it was closed while still in the accept queue.
# (observed on FreeBSD).
if (
salt.ext.tornado.util.errno_from_exception(e)
== errno.ECONNABORTED
):
continue
raise
while True:
try:
# Wait for a connection to occur since the socket is
# blocking.
connection, address = self._socket.accept()
# Wait for a free slot to be available to put
# the connection into.
# Sockets are picklable on Windows in Python 3.
self.socket_queue.put((connection, address), True, None)
except OSError as e:
# ECONNABORTED indicates that there was a connection
# but it was closed while still in the accept queue.
# (observed on FreeBSD).
if salt.ext.tornado.util.errno_from_exception(e) == errno.ECONNABORTED:
continue
raise
class Resolver:
@ -468,45 +463,43 @@ class SaltMessageServer(salt.ext.tornado.tcpserver.TCPServer):
raise
if USE_LOAD_BALANCER:
class LoadBalancerWorker(SaltMessageServer):
"""
This will receive TCP connections from 'LoadBalancerServer' via
a multiprocessing queue.
Since the queue is shared amongst workers, only one worker will handle
a given connection.
"""
class LoadBalancerWorker(SaltMessageServer):
"""
This will receive TCP connections from 'LoadBalancerServer' via
a multiprocessing queue.
Since the queue is shared amongst workers, only one worker will handle
a given connection.
"""
def __init__(self, socket_queue, message_handler, *args, **kwargs):
super().__init__(message_handler, *args, **kwargs)
self.socket_queue = socket_queue
self._stop = threading.Event()
self.thread = threading.Thread(target=self.socket_queue_thread)
self.thread.start()
def __init__(self, socket_queue, message_handler, *args, **kwargs):
super().__init__(message_handler, *args, **kwargs)
self.socket_queue = socket_queue
self._stop = threading.Event()
self.thread = threading.Thread(target=self.socket_queue_thread)
self.thread.start()
def close(self):
self._stop.set()
self.thread.join()
super().close()
def close(self):
self._stop.set()
self.thread.join()
super().close()
def socket_queue_thread(self):
try:
while True:
try:
client_socket, address = self.socket_queue.get(True, 1)
except queue.Empty:
if self._stop.is_set():
break
continue
# 'self.io_loop' initialized in super class
# 'salt.ext.tornado.tcpserver.TCPServer'.
# 'self._handle_connection' defined in same super class.
self.io_loop.spawn_callback(
self._handle_connection, client_socket, address
)
except (KeyboardInterrupt, SystemExit):
pass
def socket_queue_thread(self):
try:
while True:
try:
client_socket, address = self.socket_queue.get(True, 1)
except queue.Empty:
if self._stop.is_set():
break
continue
# 'self.io_loop' initialized in super class
# 'salt.ext.tornado.tcpserver.TCPServer'.
# 'self._handle_connection' defined in same super class.
self.io_loop.spawn_callback(
self._handle_connection, client_socket, address
)
except (KeyboardInterrupt, SystemExit):
pass
class TCPClientKeepAlive(salt.ext.tornado.tcpclient.TCPClient):
@ -583,10 +576,6 @@ class MessageClient:
self.backoff = opts.get("tcp_reconnect_backoff", 1)
def _stop_io_loop(self):
if self.io_loop is not None:
self.io_loop.stop()
# TODO: timeout inflight sessions
def close(self):
if self._closing:
@ -962,6 +951,7 @@ class TCPPublishServer(salt.transport.base.DaemonizedPublishServer):
"""
io_loop = salt.ext.tornado.ioloop.IOLoop()
io_loop.make_current()
self.io_loop = io_loop
# Spin up the publisher
self.pub_server = pub_server = PubServer(

View file

@ -11,8 +11,8 @@ import threading
from random import randint
import zmq.error
import zmq.eventloop.zmqstream
import zmq.eventloop.future
import zmq.eventloop.zmqstream
import salt.ext.tornado
import salt.ext.tornado.concurrent

View file

@ -164,6 +164,7 @@ def test_pub_server_channel(
log.info("TEST - Req Server handle payload %r", payload)
req_server_channel.post_fork(handle_payload, io_loop=io_loop)
if master_config["transport"] == "zeromq":
time.sleep(1)
attempts = 5

View file

@ -0,0 +1,55 @@
import multiprocessing
import socket
import threading
import time
import pytest
import salt.transport.tcp
pytestmark = [
pytest.mark.core_test,
]
def test_tcp_load_balancer_server(master_opts, io_loop):
messages = []
def handler(stream, message, header):
messages.append(message)
queue = multiprocessing.Queue()
server = salt.transport.tcp.LoadBalancerServer(master_opts, queue)
worker = salt.transport.tcp.LoadBalancerWorker(queue, handler, io_loop=io_loop)
def run_loop():
io_loop.start()
loop_thread = threading.Thread(target=run_loop)
loop_thread.start()
thread = threading.Thread(target=server.run)
thread.start()
# Wait for bind to happen.
time.sleep(0.5)
package = {"foo": "bar"}
payload = salt.transport.frame.frame_msg(package)
sock = socket.socket()
sock.connect(("127.0.0.1", master_opts["ret_port"]))
sock.send(payload)
try:
start = time.monotonic()
while not messages:
time.sleep(0.3)
if time.monotonic() - start > 30:
assert False, "Took longer than 30 seconds to receive message"
assert [package] == messages
finally:
server.close()
thread.join()
io_loop.stop()
worker.close()

View file

@ -0,0 +1,58 @@
import threading
import time
import salt.ext.tornado.gen
import salt.transport.tcp
async def test_pub_channel(master_opts, minion_opts, io_loop):
def presence_callback(client):
pass
def remove_presence_callback(client):
pass
master_opts["transport"] = "tcp"
minion_opts.update(master_ip="127.0.0.1", transport="tcp")
server = salt.transport.tcp.TCPPublishServer(master_opts)
client = salt.transport.tcp.TCPPubClient(minion_opts, io_loop)
payloads = []
publishes = []
def publish_payload(payload, callback):
server.publish_payload(payload)
payloads.append(payload)
def on_recv(message):
print("ON RECV")
publishes.append(message)
thread = threading.Thread(
target=server.publish_daemon,
args=(publish_payload, presence_callback, remove_presence_callback),
)
thread.start()
# Wait for socket to bind.
time.sleep(3)
await client.connect(master_opts["publish_port"])
client.on_recv(on_recv)
print("Publish message")
server.publish({"meh": "bah"})
start = time.monotonic()
try:
while not publishes:
await salt.ext.tornado.gen.sleep(0.3)
if time.monotonic() - start > 30:
assert False, "Message not published after 30 seconds"
finally:
server.io_loop.stop()
thread.join()
server.io_loop.close(all_fds=True)

View file

@ -1,7 +1,10 @@
import logging
import threading
import time
import pytest
import salt.transport.zeromq
from tests.support.mock import MagicMock, patch
from tests.support.pytest.transport import PubServerChannelProcess
@ -51,3 +54,86 @@ def test_zeromq_filtering(salt_master, salt_minion):
assert len(results) == send_num, "{} != {}, difference: {}".format(
len(results), send_num, set(expect).difference(results)
)
def test_pub_channel(master_opts):
server = salt.transport.zeromq.PublishServer(master_opts)
payloads = []
def publish_payload(payload):
server.publish_payload(payload)
payloads.append(payload)
thread = threading.Thread(target=server.publish_daemon, args=(publish_payload,))
thread.start()
server.publish({"meh": "bah"})
start = time.monotonic()
try:
while not payloads:
time.sleep(0.3)
if time.monotonic() - start > 30:
assert False, "No message received after 30 seconds"
finally:
server.close()
server.io_loop.stop()
thread.join()
server.io_loop.close(all_fds=True)
def test_pub_channel_filtering(master_opts):
master_opts["zmq_filtering"] = True
server = salt.transport.zeromq.PublishServer(master_opts)
payloads = []
def publish_payload(payload):
server.publish_payload(payload)
payloads.append(payload)
thread = threading.Thread(target=server.publish_daemon, args=(publish_payload,))
thread.start()
server.publish({"meh": "bah"})
start = time.monotonic()
try:
while not payloads:
time.sleep(0.3)
if time.monotonic() - start > 30:
assert False, "No message received after 30 seconds"
finally:
server.close()
server.io_loop.stop()
thread.join()
server.io_loop.close(all_fds=True)
def test_pub_channel_filtering_topic(master_opts):
master_opts["zmq_filtering"] = True
server = salt.transport.zeromq.PublishServer(master_opts)
payloads = []
def publish_payload(payload):
server.publish_payload(payload, topic_list=["meh"])
payloads.append(payload)
thread = threading.Thread(target=server.publish_daemon, args=(publish_payload,))
thread.start()
server.publish({"meh": "bah"})
start = time.monotonic()
try:
while not payloads:
time.sleep(0.3)
if time.monotonic() - start > 30:
assert False, "No message received after 30 seconds"
finally:
server.close()
server.io_loop.stop()
thread.join()
server.io_loop.close(all_fds=True)

View file

@ -32,8 +32,8 @@ async def test_request_channel_issue_64627(io_loop, minion_opts, port):
request_client = salt.transport.zeromq.RequestClient(minion_opts, io_loop)
rep = await request_client.send(b"foo")
req_socket = request_client.message_client.stream.socket
req_socket = request_client.message_client.socket
rep = await request_client.send(b"foo")
assert req_socket is request_client.message_client.stream.socket
assert req_socket is request_client.message_client.socket
request_client.close()
assert request_client.message_client.stream is None
assert request_client.message_client.socket is None

View file

@ -9,10 +9,11 @@ from pytestshellutils.utils import ports
import salt.channel.server
import salt.exceptions
import salt.ext.tornado
import salt.ext.tornado.concurrent
import salt.transport.tcp
from tests.support.mock import MagicMock, PropertyMock, patch
pytestmark = [
tpytestmark = [
pytest.mark.core_test,
]
@ -483,3 +484,185 @@ def test_presence_removed_on_stream_closed():
io_loop.run_sync(functools.partial(server.publish_payload, package, None))
server.remove_presence_callback.assert_called_with(client)
async def test_tcp_pub_client_decode_dict(minion_opts, io_loop):
dmsg = {"meh": "bah"}
client = salt.transport.tcp.TCPPubClient(minion_opts, io_loop)
assert dmsg == await client._decode_messages(dmsg)
async def test_tcp_pub_client_decode_msgpack(minion_opts, io_loop):
dmsg = {"meh": "bah"}
msg = salt.payload.dumps(dmsg)
client = salt.transport.tcp.TCPPubClient(minion_opts, io_loop)
assert dmsg == await client._decode_messages(msg)
def test_tcp_pub_client_close(minion_opts, io_loop):
client = salt.transport.tcp.TCPPubClient(minion_opts, io_loop)
message_client = MagicMock()
client.message_client = message_client
client.close()
assert client._closing is True
assert client.message_client is None
client.close()
message_client.close.assert_called_once_with()
async def test_pub_server__stream_read(master_opts, io_loop):
messages = [salt.transport.frame.frame_msg({"foo": "bar"})]
class Stream:
def __init__(self, messages):
self.messages = messages
def read_bytes(self, *args, **kwargs):
if self.messages:
msg = self.messages.pop(0)
future = salt.ext.tornado.concurrent.Future()
future.set_result(msg)
return future
raise salt.ext.tornado.iostream.StreamClosedError()
client = MagicMock()
client.stream = Stream(messages)
client.address = "client address"
server = salt.transport.tcp.PubServer(master_opts, io_loop)
await server._stream_read(client)
client.close.assert_called_once()
async def test_pub_server__stream_read_exception(master_opts, io_loop):
client = MagicMock()
client.stream = MagicMock()
client.stream.read_bytes = MagicMock(
side_effect=[
Exception("Something went wrong"),
salt.ext.tornado.iostream.StreamClosedError(),
]
)
client.address = "client address"
server = salt.transport.tcp.PubServer(master_opts, io_loop)
await server._stream_read(client)
client.close.assert_called_once()
async def test_salt_message_server(master_opts):
received = []
def handler(stream, body, header):
received.append(body)
server = salt.transport.tcp.SaltMessageServer(handler)
msg = {"foo": "bar"}
messages = [salt.transport.frame.frame_msg(msg)]
class Stream:
def __init__(self, messages):
self.messages = messages
def read_bytes(self, *args, **kwargs):
if self.messages:
msg = self.messages.pop(0)
future = salt.ext.tornado.concurrent.Future()
future.set_result(msg)
return future
raise salt.ext.tornado.iostream.StreamClosedError()
stream = Stream(messages)
address = "client address"
await server.handle_stream(stream, address)
# Let loop iterate so callback gets called
await salt.ext.tornado.gen.sleep(0.01)
assert received
assert [msg] == received
async def test_salt_message_server_exception(master_opts, io_loop):
received = []
def handler(stream, body, header):
received.append(body)
stream = MagicMock()
stream.read_bytes = MagicMock(
side_effect=[
Exception("Something went wrong"),
]
)
address = "client address"
server = salt.transport.tcp.SaltMessageServer(handler)
await server.handle_stream(stream, address)
stream.close.assert_called_once()
async def test_message_client_stream_return_exception(minion_opts, io_loop):
msg = {"foo": "bar"}
payload = salt.transport.frame.frame_msg(msg)
future = salt.ext.tornado.concurrent.Future()
future.set_result(payload)
client = salt.transport.tcp.MessageClient(
minion_opts,
"127.0.0.1",
12345,
connect_callback=MagicMock(),
disconnect_callback=MagicMock(),
)
client._stream = MagicMock()
client._stream.read_bytes.side_effect = [
future,
]
try:
io_loop.add_callback(client._stream_return)
await salt.ext.tornado.gen.sleep(0.01)
client.close()
await salt.ext.tornado.gen.sleep(0.01)
assert client._stream is None
finally:
client.close()
def test_tcp_pub_server_pre_fork(master_opts):
process_manager = MagicMock()
server = salt.transport.tcp.TCPPublishServer(master_opts)
server.pre_fork(process_manager)
async def test_pub_server_publish_payload(master_opts, io_loop):
server = salt.transport.tcp.PubServer(master_opts, io_loop=io_loop)
package = {"foo": "bar"}
topic_list = ["meh"]
future = salt.ext.tornado.concurrent.Future()
future.set_result(None)
client = MagicMock()
client.stream = MagicMock()
client.stream.write.side_effect = [future]
client.id_ = "meh"
server.clients = [client]
await server.publish_payload(package, topic_list)
client.stream.write.assert_called_once()
async def test_pub_server_publish_payload_closed_stream(master_opts, io_loop):
server = salt.transport.tcp.PubServer(master_opts, io_loop=io_loop)
package = {"foo": "bar"}
topic_list = ["meh"]
client = MagicMock()
client.stream = MagicMock()
client.stream.write.side_effect = [
salt.ext.tornado.iostream.StreamClosedError("mock")
]
client.id_ = "meh"
server.clients = {client}
await server.publish_payload(package, topic_list)
assert server.clients == set()

View file

@ -1414,6 +1414,7 @@ async def test_req_server_garbage_request(io_loop):
RequestServers's message handler.
"""
opts = salt.config.master_config("")
opts["zmq_monitor"] = True
request_server = salt.transport.zeromq.RequestServer(opts)
def message_handler(payload):
@ -1486,3 +1487,14 @@ async def test_client_timeout_msg(minion_opts):
await client.send({"meh": "bah"}, 1)
finally:
client.close()
def test_pub_client_init(minion_opts, io_loop):
minion_opts["id"] = "minion"
minion_opts["__role"] = "syndic"
minion_opts["master_ip"] = "127.0.0.1"
minion_opts["zmq_filtering"] = True
minion_opts["zmq_monitor"] = True
client = salt.transport.zeromq.PublishClient(minion_opts, io_loop)
client.send(b"asf")
client.close()