Better testing of ssl opts and ws transport

This commit is contained in:
Daniel A. Wozniak 2023-08-12 03:06:27 -07:00 committed by Daniel Wozniak
parent 2bf2936f73
commit f62f6469ff
6 changed files with 147 additions and 43 deletions

View file

@ -1,10 +1,14 @@
import hashlib
import logging
import os
import ssl
import traceback
import warnings
import salt.utils.stringutils
log = logging.getLogger(__name__)
TRANSPORTS = (
"zeromq",
"tcp",
@ -127,17 +131,17 @@ def publish_client(
elif "transport" in opts.get("pillar", {}).get("master", {}):
ttype = opts["pillar"]["master"]["transport"]
ssl = None
ssl_opts = None
if "ssl" in kwargs:
ssl = kwargs["ssl"]
ssl_opts = kwargs["ssl"]
elif opts.get("ssl", None) is not None:
ssl = opts["ssl"]
ssl_opts = opts["ssl"]
# switch on available ttypes
if ttype == "zeromq":
import salt.transport.zeromq
if ssl:
if ssl_opts:
log.warning("TLS not supported with zeromq transport")
return salt.transport.zeromq.PublishClient(
opts, io_loop, host=host, port=port, path=path
@ -151,7 +155,7 @@ def publish_client(
host=host,
port=port,
path=path,
ssl=ssl,
ssl=ssl_opts,
)
elif ttype == "ws":
import salt.transport.ws
@ -162,7 +166,7 @@ def publish_client(
host=host,
port=port,
path=path,
ssl=ssl,
ssl=ssl_opts,
)
raise Exception(f"Transport type not found: {ttype}")
@ -432,8 +436,6 @@ class PublishClient(Transport):
def ssl_context(ssl_options, server_side=False):
if isinstance(ssl_options, ssl.SSLContext):
return ssl_options
default_version = ssl.PROTOCOL_TLS
if server_side:
default_version = ssl.PROTOCOL_TLS_SERVER
@ -445,27 +447,28 @@ def ssl_context(ssl_options, server_side=False):
ssl_options["certfile"], ssl_options.get("keyfile", None)
)
if "cert_reqs" in ssl_options:
if ssl_options["cert_reqs"] == ssl.CERT_NONE:
if ssl_options["cert_reqs"].upper() == "CERT_NONE":
# This may have been set automatically by PROTOCOL_TLS_CLIENT but is
# incompatible with CERT_NONE so we must manually clear it.
context.check_hostname = False
context.verify_mode = getattr(ssl, VerifyMode, ssl_options["cert_reqs"])
context.verify_mode = getattr(ssl.VerifyMode, ssl_options["cert_reqs"])
if "ca_certs" in ssl_options:
context.load_verify_locations(ssl_options["ca_certs"])
if "verify_locations" in ssl_options:
for _ in ssl_options["verify_locations"]:
if _.lower().startswith("cafile:"):
cafile = _[7:]
context.load_verify_locations(cafile=cafile)
elif _.lower().startswith("capath:"):
capath = _[7:]
context.load_verify_locations(capath=capath)
elif _.lower().startswith("cadata:"):
cadata = _[7:]
context.load_verify_locations(cadata=cadata)
if isinstance(_, dict):
for key in _:
if key.lower() == "cafile":
context.load_verify_locations(cafile=_[key])
elif key.lower() == "capath":
context.load_verify_locations(capath=_[key])
elif key.lower() == "cadata":
context.load_verify_locations(cadata=_[key])
else:
log.warning("Unkown verify location type: %s", key)
else:
cafile = _
context.load_verify_locations(cafile=cafile)
context.load_verify_locations(cafile=_)
if "verify_flags" in ssl_options:
for flag in ssl_options["verify_flags"]:
context.verify_flags |= getattr(ssl.VerifyFlags, flag.upper())

View file

@ -18,7 +18,6 @@ import time
import urllib
import uuid
import warnings
import ssl
import tornado
import tornado.concurrent
@ -31,8 +30,8 @@ import tornado.util
import salt.master
import salt.payload
import salt.transport.frame
import salt.transport.base
import salt.transport.frame
import salt.utils.asynchronous
import salt.utils.files
import salt.utils.msgpack
@ -1116,7 +1115,7 @@ class PubServer(tornado.tcpserver.TCPServer):
to_remove = []
if topic_list:
for topic in topic_list:
sent = Falses
sent = False
for client in list(self.clients):
if topic == client.id_:
try:

View file

@ -4,7 +4,6 @@ import multiprocessing
import socket
import time
import warnings
import ssl
import aiohttp
import aiohttp.web
@ -109,6 +108,7 @@ class PublishClient(salt.transport.base.PublishClient):
start = time.monotonic()
timeout = kwargs.get("timeout", None)
while ws is None and (not self._closed and not self._closing):
session = None
try:
ctx = None
if self.ssl is not None:
@ -139,6 +139,8 @@ class PublishClient(salt.transport.base.PublishClient):
exc,
self.backoff,
)
if session:
await session.close()
if timeout and time.monotonic() - start > timeout:
break
await asyncio.sleep(self.backoff)
@ -374,9 +376,7 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer):
runner = aiohttp.web.ServerRunner(server)
await runner.setup()
site = aiohttp.web.SockSite(runner, sock, ssl_context=ctx)
log.info(
"Publisher binding to socket %s:%s", (self.pub_host, self.pub_port)
)
log.info("Publisher binding to socket %s:%s", self.pub_host, self.pub_port)
await site.start()
if self.pull_path:
@ -548,7 +548,7 @@ class RequestClient(salt.transport.base.RequestClient):
self.io_loop = io_loop
self._closing = False
self._closed = False
self.ssl = self.opts("ssl", None)
self.ssl = self.opts.get("ssl", None)
async def connect(self):
ctx = None

View file

@ -53,7 +53,7 @@ def transport_ids(value):
return f"transport({value})"
@pytest.fixture(params=["tcp", "zeromq"], ids=transport_ids)
@pytest.fixture(params=["ws", "tcp", "zeromq"], ids=transport_ids)
def transport(request):
return request.param

View file

@ -1,9 +1,15 @@
<<<<<<< HEAD
"""
Unit tests for salt.transport.base.
"""
import contextlib
import ssl
import pytest
import salt.transport.base
from tests.support.helpers import dedent
from tests.support.mock import Mock, patch
pytestmark = [
pytest.mark.core_test,
@ -19,3 +25,63 @@ def test_unclosed_warning():
assert transport._connect_called is True
with pytest.warns(salt.transport.base.TransportWarning):
del transport
@patch('ssl.SSLContext')
def test_ssl_context_legacy_opts(mock):
ctx = salt.transport.base.ssl_context({
'certfile': "server.crt",
'keyfile': "server.key",
'cert_reqs': "CERT_NONE",
"ca_certs": "ca.crt",
})
ctx.load_cert_chain.assert_called_with(
"server.crt",
"server.key",
)
ctx.load_verify_locations.assert_called_with(
"ca.crt"
)
assert ssl.VerifyMode.CERT_NONE == ctx.verify_mode
assert not ctx.check_hostname
@patch('ssl.SSLContext')
def test_ssl_context_opts(mock):
mock.verify_flags = ssl.VerifyFlags.VERIFY_X509_TRUSTED_FIRST
ctx = salt.transport.base.ssl_context({
'certfile': "server.crt",
'keyfile': "server.key",
'cert_reqs': "CERT_OPTIONAL",
"verify_locations": [
"ca.crt",
{"cafile": "crl.pem"},
{"capath": "/tmp/mycapathsdf"},
{"cadata": "mycadataother"},
{"CADATA": "mycadatasdf"},
],
"verify_flags": [
"VERIFY_CRL_CHECK_CHAIN",
]
})
ctx.load_cert_chain.assert_called_with(
"server.crt",
"server.key",
)
ctx.load_verify_locations.assert_any_call(
cafile="ca.crt"
)
ctx.load_verify_locations.assert_any_call(
cafile="crl.pem"
)
ctx.load_verify_locations.assert_any_call(
capath="/tmp/mycapathsdf"
)
ctx.load_verify_locations.assert_any_call(
cadata="mycadataother"
)
ctx.load_verify_locations.assert_called_with(
cadata="mycadatasdf"
)
assert ssl.VerifyMode.CERT_OPTIONAL == ctx.verify_mode
assert ctx.check_hostname
assert ssl.VerifyFlags.VERIFY_CRL_CHECK_CHAIN & ctx.verify_flags

View file

@ -13,6 +13,7 @@ import tornado.ioloop
import salt.crypt
import salt.transport.tcp
import salt.transport.ws
import salt.transport.zeromq
import salt.utils.stringutils
from tests.support.mock import MagicMock, patch
@ -29,7 +30,7 @@ def transport_ids(value):
return f"Transport({value})"
@pytest.fixture(params=("zeromq", "tcp"), ids=transport_ids)
@pytest.fixture(params=("zeromq", "tcp", "ws"), ids=transport_ids)
def transport(request):
return request.param
@ -171,29 +172,28 @@ async def test_publish_client_connect_server_down(transport, io_loop):
await client.connect()
assert client._socket
elif transport == "tcp":
client = salt.transport.tcp.TCPPubClient(opts, io_loop, host=host, port=port)
try:
# XXX: This is an implimentation detail of the tcp transport.
# await client.connect(port)
io_loop.spawn_callback(client.connect)
except TimeoutError:
pass
except Exception: # pylint: disable=broad-except
log.error("Got exception", exc_info=True)
client = salt.transport.tcp.PublishClient(opts, io_loop, host=host, port=port)
io_loop.spawn_callback(client.connect)
assert client._stream is None
elif transport == "ws":
client = salt.transport.ws.PublishClient(opts, io_loop, host=host, port=port)
io_loop.spawn_callback(client.connect)
assert client._ws is None
assert client._session is None
client.close()
await asyncio.sleep(0.03)
async def test_publish_client_connect_server_comes_up(transport, io_loop):
opts = {"master_ip": "127.0.0.1"}
host = "127.0.0.1"
port = 11122
msg = salt.payload.dumps({"meh": 123})
if transport == "zeromq":
import zmq
ctx = zmq.asyncio.Context()
uri = f"tcp://{opts['master_ip']}:{port}"
msg = salt.payload.dumps({"meh": 123})
log.debug("TEST - Senging %r", msg)
client = salt.transport.zeromq.PublishClient(
opts, io_loop, host=host, port=port
@ -213,7 +213,8 @@ async def test_publish_client_connect_server_comes_up(transport, io_loop):
task = asyncio.create_task(recv())
# Sleep to allow zmq to do it's thing.
await sock.send(msg)
await socket.send(msg)
await asyncio.sleep(0.03)
await task
response = task.result()
assert response
@ -223,7 +224,7 @@ async def test_publish_client_connect_server_comes_up(transport, io_loop):
ctx.term()
elif transport == "tcp":
client = salt.transport.tcp.TCPPubClient(opts, io_loop, host=host, port=port)
client = salt.transport.tcp.PublishClient(opts, io_loop, host=host, port=port)
# XXX: This is an implimentation detail of the tcp transport.
# await client.connect(port)
io_loop.spawn_callback(client.connect)
@ -254,6 +255,41 @@ async def test_publish_client_connect_server_comes_up(transport, io_loop):
conn.send(msg)
response = await client.recv()
assert response
assert msg == response
elif transport == "ws":
import socket
import aiohttp
client = salt.transport.ws.PublishClient(opts, io_loop, host=host, port=port)
io_loop.spawn_callback(client.connect)
assert client._ws is None
assert client._session is None
await asyncio.sleep(2)
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.setblocking(0)
sock.bind((opts["master_ip"], port))
sock.listen(128)
async def handler(request):
ws = aiohttp.web.WebSocketResponse()
await ws.prepare(request)
data = salt.transport.frame.frame_msg(msg, header=None)
await ws.send_bytes(data)
server = aiohttp.web.Server(handler)
runner = aiohttp.web.ServerRunner(server)
await runner.setup()
site = aiohttp.web.SockSite(runner, sock)
await site.start()
await asyncio.sleep(0.03)
response = await client.recv()
assert msg == response
else:
raise Exception(f"Unknown transport {transport}")
client.close()
await asyncio.sleep(0.03)