mirror of
https://github.com/saltstack/salt.git
synced 2025-04-17 10:10:20 +00:00
Better testing of ssl opts and ws transport
This commit is contained in:
parent
2bf2936f73
commit
f62f6469ff
6 changed files with 147 additions and 43 deletions
|
@ -1,10 +1,14 @@
|
|||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
import ssl
|
||||
import traceback
|
||||
import warnings
|
||||
|
||||
import salt.utils.stringutils
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
TRANSPORTS = (
|
||||
"zeromq",
|
||||
"tcp",
|
||||
|
@ -127,17 +131,17 @@ def publish_client(
|
|||
elif "transport" in opts.get("pillar", {}).get("master", {}):
|
||||
ttype = opts["pillar"]["master"]["transport"]
|
||||
|
||||
ssl = None
|
||||
ssl_opts = None
|
||||
if "ssl" in kwargs:
|
||||
ssl = kwargs["ssl"]
|
||||
ssl_opts = kwargs["ssl"]
|
||||
elif opts.get("ssl", None) is not None:
|
||||
ssl = opts["ssl"]
|
||||
ssl_opts = opts["ssl"]
|
||||
|
||||
# switch on available ttypes
|
||||
if ttype == "zeromq":
|
||||
import salt.transport.zeromq
|
||||
|
||||
if ssl:
|
||||
if ssl_opts:
|
||||
log.warning("TLS not supported with zeromq transport")
|
||||
return salt.transport.zeromq.PublishClient(
|
||||
opts, io_loop, host=host, port=port, path=path
|
||||
|
@ -151,7 +155,7 @@ def publish_client(
|
|||
host=host,
|
||||
port=port,
|
||||
path=path,
|
||||
ssl=ssl,
|
||||
ssl=ssl_opts,
|
||||
)
|
||||
elif ttype == "ws":
|
||||
import salt.transport.ws
|
||||
|
@ -162,7 +166,7 @@ def publish_client(
|
|||
host=host,
|
||||
port=port,
|
||||
path=path,
|
||||
ssl=ssl,
|
||||
ssl=ssl_opts,
|
||||
)
|
||||
|
||||
raise Exception(f"Transport type not found: {ttype}")
|
||||
|
@ -432,8 +436,6 @@ class PublishClient(Transport):
|
|||
|
||||
|
||||
def ssl_context(ssl_options, server_side=False):
|
||||
if isinstance(ssl_options, ssl.SSLContext):
|
||||
return ssl_options
|
||||
default_version = ssl.PROTOCOL_TLS
|
||||
if server_side:
|
||||
default_version = ssl.PROTOCOL_TLS_SERVER
|
||||
|
@ -445,27 +447,28 @@ def ssl_context(ssl_options, server_side=False):
|
|||
ssl_options["certfile"], ssl_options.get("keyfile", None)
|
||||
)
|
||||
if "cert_reqs" in ssl_options:
|
||||
if ssl_options["cert_reqs"] == ssl.CERT_NONE:
|
||||
if ssl_options["cert_reqs"].upper() == "CERT_NONE":
|
||||
# This may have been set automatically by PROTOCOL_TLS_CLIENT but is
|
||||
# incompatible with CERT_NONE so we must manually clear it.
|
||||
context.check_hostname = False
|
||||
context.verify_mode = getattr(ssl, VerifyMode, ssl_options["cert_reqs"])
|
||||
context.verify_mode = getattr(ssl.VerifyMode, ssl_options["cert_reqs"])
|
||||
if "ca_certs" in ssl_options:
|
||||
context.load_verify_locations(ssl_options["ca_certs"])
|
||||
if "verify_locations" in ssl_options:
|
||||
for _ in ssl_options["verify_locations"]:
|
||||
if _.lower().startswith("cafile:"):
|
||||
cafile = _[7:]
|
||||
context.load_verify_locations(cafile=cafile)
|
||||
elif _.lower().startswith("capath:"):
|
||||
capath = _[7:]
|
||||
context.load_verify_locations(capath=capath)
|
||||
elif _.lower().startswith("cadata:"):
|
||||
cadata = _[7:]
|
||||
context.load_verify_locations(cadata=cadata)
|
||||
if isinstance(_, dict):
|
||||
for key in _:
|
||||
if key.lower() == "cafile":
|
||||
context.load_verify_locations(cafile=_[key])
|
||||
elif key.lower() == "capath":
|
||||
context.load_verify_locations(capath=_[key])
|
||||
elif key.lower() == "cadata":
|
||||
context.load_verify_locations(cadata=_[key])
|
||||
else:
|
||||
log.warning("Unkown verify location type: %s", key)
|
||||
else:
|
||||
cafile = _
|
||||
context.load_verify_locations(cafile=cafile)
|
||||
context.load_verify_locations(cafile=_)
|
||||
if "verify_flags" in ssl_options:
|
||||
for flag in ssl_options["verify_flags"]:
|
||||
context.verify_flags |= getattr(ssl.VerifyFlags, flag.upper())
|
||||
|
|
|
@ -18,7 +18,6 @@ import time
|
|||
import urllib
|
||||
import uuid
|
||||
import warnings
|
||||
import ssl
|
||||
|
||||
import tornado
|
||||
import tornado.concurrent
|
||||
|
@ -31,8 +30,8 @@ import tornado.util
|
|||
|
||||
import salt.master
|
||||
import salt.payload
|
||||
import salt.transport.frame
|
||||
import salt.transport.base
|
||||
import salt.transport.frame
|
||||
import salt.utils.asynchronous
|
||||
import salt.utils.files
|
||||
import salt.utils.msgpack
|
||||
|
@ -1116,7 +1115,7 @@ class PubServer(tornado.tcpserver.TCPServer):
|
|||
to_remove = []
|
||||
if topic_list:
|
||||
for topic in topic_list:
|
||||
sent = Falses
|
||||
sent = False
|
||||
for client in list(self.clients):
|
||||
if topic == client.id_:
|
||||
try:
|
||||
|
|
|
@ -4,7 +4,6 @@ import multiprocessing
|
|||
import socket
|
||||
import time
|
||||
import warnings
|
||||
import ssl
|
||||
|
||||
import aiohttp
|
||||
import aiohttp.web
|
||||
|
@ -109,6 +108,7 @@ class PublishClient(salt.transport.base.PublishClient):
|
|||
start = time.monotonic()
|
||||
timeout = kwargs.get("timeout", None)
|
||||
while ws is None and (not self._closed and not self._closing):
|
||||
session = None
|
||||
try:
|
||||
ctx = None
|
||||
if self.ssl is not None:
|
||||
|
@ -139,6 +139,8 @@ class PublishClient(salt.transport.base.PublishClient):
|
|||
exc,
|
||||
self.backoff,
|
||||
)
|
||||
if session:
|
||||
await session.close()
|
||||
if timeout and time.monotonic() - start > timeout:
|
||||
break
|
||||
await asyncio.sleep(self.backoff)
|
||||
|
@ -374,9 +376,7 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer):
|
|||
runner = aiohttp.web.ServerRunner(server)
|
||||
await runner.setup()
|
||||
site = aiohttp.web.SockSite(runner, sock, ssl_context=ctx)
|
||||
log.info(
|
||||
"Publisher binding to socket %s:%s", (self.pub_host, self.pub_port)
|
||||
)
|
||||
log.info("Publisher binding to socket %s:%s", self.pub_host, self.pub_port)
|
||||
await site.start()
|
||||
|
||||
if self.pull_path:
|
||||
|
@ -548,7 +548,7 @@ class RequestClient(salt.transport.base.RequestClient):
|
|||
self.io_loop = io_loop
|
||||
self._closing = False
|
||||
self._closed = False
|
||||
self.ssl = self.opts("ssl", None)
|
||||
self.ssl = self.opts.get("ssl", None)
|
||||
|
||||
async def connect(self):
|
||||
ctx = None
|
||||
|
|
|
@ -53,7 +53,7 @@ def transport_ids(value):
|
|||
return f"transport({value})"
|
||||
|
||||
|
||||
@pytest.fixture(params=["tcp", "zeromq"], ids=transport_ids)
|
||||
@pytest.fixture(params=["ws", "tcp", "zeromq"], ids=transport_ids)
|
||||
def transport(request):
|
||||
return request.param
|
||||
|
||||
|
|
|
@ -1,9 +1,15 @@
|
|||
<<<<<<< HEAD
|
||||
"""
|
||||
Unit tests for salt.transport.base.
|
||||
"""
|
||||
import contextlib
|
||||
import ssl
|
||||
|
||||
import pytest
|
||||
|
||||
import salt.transport.base
|
||||
from tests.support.helpers import dedent
|
||||
from tests.support.mock import Mock, patch
|
||||
|
||||
pytestmark = [
|
||||
pytest.mark.core_test,
|
||||
|
@ -19,3 +25,63 @@ def test_unclosed_warning():
|
|||
assert transport._connect_called is True
|
||||
with pytest.warns(salt.transport.base.TransportWarning):
|
||||
del transport
|
||||
|
||||
@patch('ssl.SSLContext')
|
||||
def test_ssl_context_legacy_opts(mock):
|
||||
ctx = salt.transport.base.ssl_context({
|
||||
'certfile': "server.crt",
|
||||
'keyfile': "server.key",
|
||||
'cert_reqs': "CERT_NONE",
|
||||
"ca_certs": "ca.crt",
|
||||
})
|
||||
ctx.load_cert_chain.assert_called_with(
|
||||
"server.crt",
|
||||
"server.key",
|
||||
)
|
||||
ctx.load_verify_locations.assert_called_with(
|
||||
"ca.crt"
|
||||
)
|
||||
assert ssl.VerifyMode.CERT_NONE == ctx.verify_mode
|
||||
assert not ctx.check_hostname
|
||||
|
||||
|
||||
@patch('ssl.SSLContext')
|
||||
def test_ssl_context_opts(mock):
|
||||
mock.verify_flags = ssl.VerifyFlags.VERIFY_X509_TRUSTED_FIRST
|
||||
ctx = salt.transport.base.ssl_context({
|
||||
'certfile': "server.crt",
|
||||
'keyfile': "server.key",
|
||||
'cert_reqs': "CERT_OPTIONAL",
|
||||
"verify_locations": [
|
||||
"ca.crt",
|
||||
{"cafile": "crl.pem"},
|
||||
{"capath": "/tmp/mycapathsdf"},
|
||||
{"cadata": "mycadataother"},
|
||||
{"CADATA": "mycadatasdf"},
|
||||
],
|
||||
"verify_flags": [
|
||||
"VERIFY_CRL_CHECK_CHAIN",
|
||||
]
|
||||
})
|
||||
ctx.load_cert_chain.assert_called_with(
|
||||
"server.crt",
|
||||
"server.key",
|
||||
)
|
||||
ctx.load_verify_locations.assert_any_call(
|
||||
cafile="ca.crt"
|
||||
)
|
||||
ctx.load_verify_locations.assert_any_call(
|
||||
cafile="crl.pem"
|
||||
)
|
||||
ctx.load_verify_locations.assert_any_call(
|
||||
capath="/tmp/mycapathsdf"
|
||||
)
|
||||
ctx.load_verify_locations.assert_any_call(
|
||||
cadata="mycadataother"
|
||||
)
|
||||
ctx.load_verify_locations.assert_called_with(
|
||||
cadata="mycadatasdf"
|
||||
)
|
||||
assert ssl.VerifyMode.CERT_OPTIONAL == ctx.verify_mode
|
||||
assert ctx.check_hostname
|
||||
assert ssl.VerifyFlags.VERIFY_CRL_CHECK_CHAIN & ctx.verify_flags
|
||||
|
|
|
@ -13,6 +13,7 @@ import tornado.ioloop
|
|||
|
||||
import salt.crypt
|
||||
import salt.transport.tcp
|
||||
import salt.transport.ws
|
||||
import salt.transport.zeromq
|
||||
import salt.utils.stringutils
|
||||
from tests.support.mock import MagicMock, patch
|
||||
|
@ -29,7 +30,7 @@ def transport_ids(value):
|
|||
return f"Transport({value})"
|
||||
|
||||
|
||||
@pytest.fixture(params=("zeromq", "tcp"), ids=transport_ids)
|
||||
@pytest.fixture(params=("zeromq", "tcp", "ws"), ids=transport_ids)
|
||||
def transport(request):
|
||||
return request.param
|
||||
|
||||
|
@ -171,29 +172,28 @@ async def test_publish_client_connect_server_down(transport, io_loop):
|
|||
await client.connect()
|
||||
assert client._socket
|
||||
elif transport == "tcp":
|
||||
client = salt.transport.tcp.TCPPubClient(opts, io_loop, host=host, port=port)
|
||||
try:
|
||||
# XXX: This is an implimentation detail of the tcp transport.
|
||||
# await client.connect(port)
|
||||
io_loop.spawn_callback(client.connect)
|
||||
except TimeoutError:
|
||||
pass
|
||||
except Exception: # pylint: disable=broad-except
|
||||
log.error("Got exception", exc_info=True)
|
||||
client = salt.transport.tcp.PublishClient(opts, io_loop, host=host, port=port)
|
||||
io_loop.spawn_callback(client.connect)
|
||||
assert client._stream is None
|
||||
elif transport == "ws":
|
||||
client = salt.transport.ws.PublishClient(opts, io_loop, host=host, port=port)
|
||||
io_loop.spawn_callback(client.connect)
|
||||
assert client._ws is None
|
||||
assert client._session is None
|
||||
client.close()
|
||||
await asyncio.sleep(0.03)
|
||||
|
||||
|
||||
async def test_publish_client_connect_server_comes_up(transport, io_loop):
|
||||
opts = {"master_ip": "127.0.0.1"}
|
||||
host = "127.0.0.1"
|
||||
port = 11122
|
||||
msg = salt.payload.dumps({"meh": 123})
|
||||
if transport == "zeromq":
|
||||
import zmq
|
||||
|
||||
ctx = zmq.asyncio.Context()
|
||||
uri = f"tcp://{opts['master_ip']}:{port}"
|
||||
msg = salt.payload.dumps({"meh": 123})
|
||||
log.debug("TEST - Senging %r", msg)
|
||||
client = salt.transport.zeromq.PublishClient(
|
||||
opts, io_loop, host=host, port=port
|
||||
|
@ -213,7 +213,8 @@ async def test_publish_client_connect_server_comes_up(transport, io_loop):
|
|||
|
||||
task = asyncio.create_task(recv())
|
||||
# Sleep to allow zmq to do it's thing.
|
||||
await sock.send(msg)
|
||||
await socket.send(msg)
|
||||
await asyncio.sleep(0.03)
|
||||
await task
|
||||
response = task.result()
|
||||
assert response
|
||||
|
@ -223,7 +224,7 @@ async def test_publish_client_connect_server_comes_up(transport, io_loop):
|
|||
ctx.term()
|
||||
elif transport == "tcp":
|
||||
|
||||
client = salt.transport.tcp.TCPPubClient(opts, io_loop, host=host, port=port)
|
||||
client = salt.transport.tcp.PublishClient(opts, io_loop, host=host, port=port)
|
||||
# XXX: This is an implimentation detail of the tcp transport.
|
||||
# await client.connect(port)
|
||||
io_loop.spawn_callback(client.connect)
|
||||
|
@ -254,6 +255,41 @@ async def test_publish_client_connect_server_comes_up(transport, io_loop):
|
|||
|
||||
conn.send(msg)
|
||||
response = await client.recv()
|
||||
assert response
|
||||
assert msg == response
|
||||
elif transport == "ws":
|
||||
import socket
|
||||
|
||||
import aiohttp
|
||||
|
||||
client = salt.transport.ws.PublishClient(opts, io_loop, host=host, port=port)
|
||||
io_loop.spawn_callback(client.connect)
|
||||
assert client._ws is None
|
||||
assert client._session is None
|
||||
await asyncio.sleep(2)
|
||||
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
sock.setblocking(0)
|
||||
sock.bind((opts["master_ip"], port))
|
||||
sock.listen(128)
|
||||
|
||||
async def handler(request):
|
||||
ws = aiohttp.web.WebSocketResponse()
|
||||
await ws.prepare(request)
|
||||
data = salt.transport.frame.frame_msg(msg, header=None)
|
||||
await ws.send_bytes(data)
|
||||
|
||||
server = aiohttp.web.Server(handler)
|
||||
runner = aiohttp.web.ServerRunner(server)
|
||||
await runner.setup()
|
||||
site = aiohttp.web.SockSite(runner, sock)
|
||||
await site.start()
|
||||
|
||||
await asyncio.sleep(0.03)
|
||||
|
||||
response = await client.recv()
|
||||
assert msg == response
|
||||
else:
|
||||
raise Exception(f"Unknown transport {transport}")
|
||||
client.close()
|
||||
await asyncio.sleep(0.03)
|
||||
|
|
Loading…
Add table
Reference in a new issue