Fix up tcp ssl and add ssl to ws

This commit is contained in:
Daniel A. Wozniak 2023-08-12 00:15:33 -07:00 committed by Daniel Wozniak
parent 2f5fe3bcb6
commit 2bf2936f73
3 changed files with 231 additions and 40 deletions

View file

@ -48,6 +48,7 @@ def request_client(opts, io_loop):
ttype = opts["transport"]
elif "transport" in opts.get("pillar", {}).get("master", {}):
ttype = opts["pillar"]["master"]["transport"]
if ttype == "zeromq":
import salt.transport.zeromq
@ -90,6 +91,9 @@ def publish_server(opts, **kwargs):
else:
kwargs["pull_path"] = os.path.join(opts["sock_dir"], "publish_pull.ipc")
if "ssl" not in kwargs and opts.get("ssl", None) is not None:
kwargs["ssl"] = opts["ssl"]
# switch on available ttypes
if ttype == "zeromq":
import salt.transport.zeromq
@ -110,7 +114,9 @@ def publish_server(opts, **kwargs):
raise Exception(f"Transport type not found: {ttype}")
def publish_client(opts, io_loop, host=None, port=None, path=None, transport=None):
def publish_client(
opts, io_loop, host=None, port=None, path=None, transport=None, **kwargs
):
# Default to ZeroMQ for now
ttype = "zeromq"
# determine the ttype
@ -121,10 +127,18 @@ def publish_client(opts, io_loop, host=None, port=None, path=None, transport=Non
elif "transport" in opts.get("pillar", {}).get("master", {}):
ttype = opts["pillar"]["master"]["transport"]
ssl = None
if "ssl" in kwargs:
ssl = kwargs["ssl"]
elif opts.get("ssl", None) is not None:
ssl = opts["ssl"]
# switch on available ttypes
if ttype == "zeromq":
import salt.transport.zeromq
if ssl:
log.warning("TLS not supported with zeromq transport")
return salt.transport.zeromq.PublishClient(
opts, io_loop, host=host, port=port, path=path
)
@ -132,13 +146,23 @@ def publish_client(opts, io_loop, host=None, port=None, path=None, transport=Non
import salt.transport.tcp
return salt.transport.tcp.PublishClient(
opts, io_loop, host=host, port=port, path=path
opts,
io_loop,
host=host,
port=port,
path=path,
ssl=ssl,
)
elif ttype == "ws":
import salt.transport.ws
return salt.transport.ws.PublishClient(
opts, io_loop, host=host, port=port, path=path
opts,
io_loop,
host=host,
port=port,
path=path,
ssl=ssl,
)
raise Exception(f"Transport type not found: {ttype}")
@ -154,7 +178,7 @@ def _minion_hash(hash_type, minion_id):
def ipc_publish_client(node, opts, io_loop):
# Default to TCP for now
kwargs = {"transport": "tcp"}
kwargs = {"transport": "tcp", "ssl": None}
if opts["ipc_mode"] == "tcp":
if node == "master":
kwargs.update(
@ -184,7 +208,7 @@ def ipc_publish_client(node, opts, io_loop):
def ipc_publish_server(node, opts):
# Default to TCP for now
kwargs = {"transport": "tcp"}
kwargs = {"transport": "tcp", "ssl": None}
if opts["ipc_mode"] == "tcp":
if node == "master":
kwargs.update(
@ -405,3 +429,54 @@ class PublishClient(Transport):
def __exit__(self, exc_type, exc_val, exc_tb):
pass
def ssl_context(ssl_options, server_side=False):
if isinstance(ssl_options, ssl.SSLContext):
return ssl_options
default_version = ssl.PROTOCOL_TLS
if server_side:
default_version = ssl.PROTOCOL_TLS_SERVER
elif server_side is not None:
default_version = ssl.PROTOCOL_TLS_CLIENT
context = ssl.SSLContext(ssl_options.get("ssl_version", default_version))
if "certfile" in ssl_options:
context.load_cert_chain(
ssl_options["certfile"], ssl_options.get("keyfile", None)
)
if "cert_reqs" in ssl_options:
if ssl_options["cert_reqs"] == ssl.CERT_NONE:
# This may have been set automatically by PROTOCOL_TLS_CLIENT but is
# incompatible with CERT_NONE so we must manually clear it.
context.check_hostname = False
context.verify_mode = getattr(ssl, VerifyMode, ssl_options["cert_reqs"])
if "ca_certs" in ssl_options:
context.load_verify_locations(ssl_options["ca_certs"])
if "verify_locations" in ssl_options:
for _ in ssl_options["verify_locations"]:
if _.lower().startswith("cafile:"):
cafile = _[7:]
context.load_verify_locations(cafile=cafile)
elif _.lower().startswith("capath:"):
capath = _[7:]
context.load_verify_locations(capath=capath)
elif _.lower().startswith("cadata:"):
cadata = _[7:]
context.load_verify_locations(cadata=cadata)
else:
cafile = _
context.load_verify_locations(cafile=cafile)
if "verify_flags" in ssl_options:
for flag in ssl_options["verify_flags"]:
context.verify_flags |= getattr(ssl.VerifyFlags, flag.upper())
if "ciphers" in ssl_options:
context.set_ciphers(ssl_options["ciphers"])
return context
def common_name(cert):
try:
name = dict([_[0] for _ in cert["subject"]])["commonName"]
except (ValueError, KeyError):
return None
return name

View file

@ -18,6 +18,7 @@ import time
import urllib
import uuid
import warnings
import ssl
import tornado
import tornado.concurrent
@ -31,6 +32,7 @@ import tornado.util
import salt.master
import salt.payload
import salt.transport.frame
import salt.transport.base
import salt.utils.asynchronous
import salt.utils.files
import salt.utils.msgpack
@ -234,6 +236,7 @@ class PublishClient(salt.transport.base.PublishClient):
self.host = kwargs.get("host", None)
self.port = kwargs.get("port", None)
self.path = kwargs.get("path", None)
self.ssl = kwargs.get("ssl", None)
self.source_ip = self.opts.get("source_ip")
self.source_port = self.opts.get("source_publish_port")
self.on_recv_task = None
@ -275,11 +278,16 @@ class PublishClient(salt.transport.base.PublishClient):
self._tcp_client = TCPClientKeepAlive(
self.opts, resolver=self.resolver
)
ctx = None
if self.ssl is not None:
ctx = salt.transport.base.ssl_context(
self.ssl, server_side=False
)
stream = await asyncio.wait_for(
self._tcp_client.connect(
ip_bracket(self.host, strip=True),
self.port,
ssl_options=self.opts.get("ssl"),
ssl_options=ctx,
**kwargs,
),
1,
@ -479,6 +487,7 @@ class RequestServer(salt.transport.base.DaemonizedRequestServer):
self.opts = opts
self._socket = None
self.req_server = None
self.ssl = self.opts.get("ssl", None)
@property
def socket(self):
@ -549,11 +558,14 @@ class RequestServer(salt.transport.base.DaemonizedRequestServer):
log.info("RequestServer workers %s", socket)
with salt.utils.asynchronous.current_ioloop(io_loop):
ctx = None
if self.ssl is not None:
ctx = salt.transport.base.ssl_context(self.ssl, server_side=True)
if USE_LOAD_BALANCER:
self.req_server = LoadBalancerWorker(
self.socket_queue,
self.handle_message,
ssl_options=self.opts.get("ssl"),
ssl_options=ctx,
)
else:
if salt.utils.platform.is_windows():
@ -564,13 +576,21 @@ class RequestServer(salt.transport.base.DaemonizedRequestServer):
self._socket.bind(_get_bind_addr(self.opts, "ret_port"))
self.req_server = SaltMessageServer(
self.handle_message,
ssl_options=self.opts.get("ssl"),
ssl_options=ctx,
io_loop=io_loop,
)
self.req_server.add_socket(self._socket)
self._socket.listen(self.backlog)
async def handle_message(self, stream, payload, header=None):
try:
cert = stream.socket.getpeercert()
except AttributeError:
pass
else:
if cert:
name = salt.transport.base.common_name(cert)
log.error("Request client cert %r", name)
payload = self.decode_payload(payload)
reply = await self.message_handler(payload)
# XXX Handle StreamClosedError
@ -1013,9 +1033,14 @@ class PubServer(tornado.tcpserver.TCPServer):
"""
def __init__(
self, opts, io_loop=None, presence_callback=None, remove_presence_callback=None
self,
opts,
io_loop=None,
presence_callback=None,
remove_presence_callback=None,
ssl=None,
):
super().__init__(ssl_options=opts.get("ssl"))
super().__init__(ssl_options=ssl)
self.io_loop = io_loop
self.opts = opts
self._closing = False
@ -1029,6 +1054,7 @@ class PubServer(tornado.tcpserver.TCPServer):
self.remove_presence_callback = remove_presence_callback
else:
self.remove_presence_callback = lambda subscriber: subscriber
self.ssl = ssl
def close(self):
if self._closing:
@ -1068,6 +1094,14 @@ class PubServer(tornado.tcpserver.TCPServer):
continue
def handle_stream(self, stream, address):
try:
cert = stream.socket.getpeercert()
except AttributeError:
pass
else:
if cert:
name = salt.transport.base.common_name(cert)
log.error("Request client cert %r", name)
log.debug("Subscriber at %s connected", address)
client = Subscriber(stream, address)
self.clients.add(client)
@ -1082,8 +1116,8 @@ class PubServer(tornado.tcpserver.TCPServer):
to_remove = []
if topic_list:
for topic in topic_list:
sent = False
for client in self.clients:
sent = Falses
for client in list(self.clients):
if topic == client.id_:
try:
# Write the packed str
@ -1095,7 +1129,7 @@ class PubServer(tornado.tcpserver.TCPServer):
if not sent:
log.debug("Publish target %s not connected %r", topic, self.clients)
else:
for client in self.clients:
for client in list(self.clients):
try:
# Write the packed str
await client.stream.write(payload)
@ -1296,6 +1330,7 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer):
pull_host=None,
pull_port=None,
pull_path=None,
ssl=None,
):
self.opts = opts
self.pub_sock = None
@ -1305,6 +1340,7 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer):
self.pull_host = pull_host
self.pull_port = pull_port
self.pull_path = pull_path
self.ssl = ssl
@property
def topic_support(self):
@ -1359,18 +1395,27 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer):
if io_loop is None:
io_loop = tornado.ioloop.IOLoop.current()
# Spin up the publisher
ctx = None
if self.ssl is not None:
ctx = salt.transport.base.ssl_context(self.ssl, server_side=True)
self.pub_server = pub_server = PubServer(
self.opts,
io_loop=io_loop,
presence_callback=presence_callback,
remove_presence_callback=remove_presence_callback,
ssl=ctx,
)
if self.pub_path:
log.debug("Publish server binding pub to %s", self.pub_path)
log.error(
"Publish server binding pub to %s ssl=%r", self.pub_path, self.ssl
)
sock = tornado.netutil.bind_unix_socket(self.pub_path)
else:
log.debug(
"Publish server binding pub to %s:%s", self.pub_host, self.pub_port
log.error(
"Publish server binding pub to %s:%s ssl=%r",
self.pub_host,
self.pub_port,
self.ssl,
)
sock = _get_socket(self.opts)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
@ -1668,6 +1713,7 @@ class RequestClient(salt.transport.base.RequestClient):
self.disconnect_callback = _null_callback
self.connect_callback = _null_callback
self.backoff = opts.get("tcp_reconnect_backoff", 1)
self.ssl = self.opts.get("ssl", None)
async def getstream(self, **kwargs):
if self.source_ip or self.source_port:
@ -1676,10 +1722,13 @@ class RequestClient(salt.transport.base.RequestClient):
while stream is None and (not self._closed and not self._closing):
try:
# XXX: Support ipc sockets too
ctx = None
if self.ssl is not None:
ctx = salt.transport.base.ssl_context(self.ssl, server_side=False)
stream = await self._tcp_client.connect(
ip_bracket(self.host, strip=True),
self.port,
ssl_options=self.opts.get("ssl"),
ssl_options=ctx,
**kwargs,
)
except Exception as exc: # pylint: disable=broad-except

View file

@ -4,6 +4,7 @@ import multiprocessing
import socket
import time
import warnings
import ssl
import aiohttp
import aiohttp.web
@ -12,6 +13,7 @@ from tornado.locks import Lock
import salt.payload
import salt.transport.base
import salt.transport.frame
from salt.transport.tcp import (
USE_LOAD_BALANCER,
LoadBalancerServer,
@ -60,6 +62,7 @@ class PublishClient(salt.transport.base.PublishClient):
self.host = kwargs.get("host", None)
self.port = kwargs.get("port", None)
self.path = kwargs.get("path", None)
self.ssl = kwargs.get("ssl", None)
self.source_ip = self.opts.get("source_ip")
self.source_port = self.opts.get("source_publish_port")
self.connect_callback = None
@ -107,15 +110,25 @@ class PublishClient(salt.transport.base.PublishClient):
timeout = kwargs.get("timeout", None)
while ws is None and (not self._closed and not self._closing):
try:
ctx = None
if self.ssl is not None:
ctx = salt.transport.base.ssl_context(self.ssl, server_side=False)
if self.host and self.port:
conn = aiohttp.TCPConnector()
session = aiohttp.ClientSession(connector=conn)
url = f"http://{self.host}:{self.port}"
if self.ssl:
url = f"https://{self.host}:{self.port}/ws"
else:
url = f"http://{self.host}:{self.port}/ws"
else:
conn = aiohttp.UnixConnector(path=self.path)
session = aiohttp.ClientSession(connector=conn)
url = f"http://ipc.saltproject.io/ws"
ws = await asyncio.wait_for(session.ws_connect(url), 1)
if self.ssl:
url = f"https://ipc.saltproject.io/ws"
else:
url = f"http://ipc.saltproject.io/ws"
log.error("pub client connect %r %r", url, ctx)
ws = await asyncio.wait_for(session.ws_connect(url, ssl=ctx), 3)
except Exception as exc: # pylint: disable=broad-except
log.warning(
"WS Message Client encountered an exception while connecting to"
@ -259,27 +272,49 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer):
"close",
]
def __init__(self, opts, **kwargs):
def __init__(
self,
opts,
pub_host=None,
pub_port=None,
pub_path=None,
pull_host=None,
pull_port=None,
pull_path=None,
ssl=None,
):
self.opts = opts
self.pub_sock = None
self.pub_host = kwargs.get("pub_host", None)
self.pub_port = kwargs.get("pub_port", None)
self.pub_path = kwargs.get("pub_path", None)
self.pull_host = kwargs.get("pull_host", None)
self.pull_port = kwargs.get("pull_port", None)
self.pull_path = kwargs.get("pull_path", None)
self.pub_host = pub_host
self.pub_port = pub_port
self.pub_path = pub_path
self.pull_host = pull_host
self.pull_port = pull_port
self.pull_path = pull_path
self.ssl = ssl
self.clients = set()
self._run = None
self.pub_sock = None
@property
def topic_support(self):
return not self.opts.get("order_masters", False)
def __setstate__(self, state):
self.__init__(**state)
def __setstate__(self, state):
self.__init__(state["opts"])
def __getstate__(self):
return {"opts": self.opts}
return {
"opts": self.opts,
"pub_host": self.pub_host,
"pub_port": self.pub_port,
"pub_path": self.pub_path,
"pull_host": self.pull_host,
"pull_port": self.pull_port,
"pull_path": self.pull_path,
}
def publish_daemon(
self,
@ -319,13 +354,15 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer):
self._run = asyncio.Event()
self._run.set()
ctx = None
if self.ssl is not None:
ctx = salt.transport.base.ssl_context(self.ssl, server_side=True)
if self.pub_path:
server = aiohttp.web.Server(self.handle_request)
runner = aiohttp.web.ServerRunner(server)
await runner.setup()
site = aiohttp.web.UnixSite(runner, self.pub_path)
log.info("Publisher binding to path %s", self.pub_path)
await site.start()
site = aiohttp.web.UnixSite(runner, self.pub_path, ssl_context=ctx)
log.info("Publisher binding to socket %s", self.pub_path)
else:
sock = _get_socket(self.opts)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
@ -336,9 +373,11 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer):
server = aiohttp.web.Server(self.handle_request)
runner = aiohttp.web.ServerRunner(server)
await runner.setup()
site = aiohttp.web.SockSite(runner, sock)
log.info("Publisher binding to socket %s", (self.pub_host, self.pub_port))
await site.start()
site = aiohttp.web.SockSite(runner, sock, ssl_context=ctx)
log.info(
"Publisher binding to socket %s:%s", (self.pub_host, self.pub_port)
)
await site.start()
if self.pull_path:
pull_uri = self.pull_path
@ -367,6 +406,14 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer):
process_manager.add_process(self.publish_daemon, name=self.__class__.__name__)
async def handle_request(self, request):
try:
cert = request.get_extra_info("peercert")
except AttributeError:
pass
else:
if cert:
name = salt.transport.base.common_name(cert)
log.error("Request client cert %r", name)
ws = aiohttp.web.WebSocketResponse()
await ws.prepare(request)
self.clients.add(ws)
@ -414,6 +461,7 @@ class RequestServer(salt.transport.base.DaemonizedRequestServer):
def __init__(self, opts): # pylint: disable=W0231
self.opts = opts
self.site = None
self.ssl = self.opts.get("ssl", None)
def pre_fork(self, process_manager):
"""
@ -448,7 +496,10 @@ class RequestServer(salt.transport.base.DaemonizedRequestServer):
server = aiohttp.web.Server(self.handle_message)
runner = aiohttp.web.ServerRunner(server)
await runner.setup()
self.site = aiohttp.web.SockSite(runner, self._socket)
ctx = None
if self.ssl is not None:
ctx = tornado.netutil.ssl_options_to_context(self.ssl, server_side=True)
self.site = aiohttp.web.SockSite(runner, self._socket, ssl_context=ctx)
log.info("Worker binding to socket %s", self._socket)
await self.site.start()
# pause here for very long time by serving HTTP requests and
@ -460,6 +511,14 @@ class RequestServer(salt.transport.base.DaemonizedRequestServer):
io_loop.spawn_callback(server)
async def handle_message(self, request):
try:
cert = request.get_extra_info("peercert")
except AttributeError:
pass
else:
if cert:
name = salt.transport.base.common_name(cert)
log.error("Request client cert %r", name)
ws = aiohttp.web.WebSocketResponse()
await ws.prepare(request)
async for msg in ws:
@ -489,12 +548,16 @@ class RequestClient(salt.transport.base.RequestClient):
self.io_loop = io_loop
self._closing = False
self._closed = False
self.ssl = self.opts("ssl", None)
async def connect(self):
# if self.session is None:
ctx = None
if self.ssl is not None:
ctx = tornado.netutil.ssl_options_to_context(self.ssl, server_side=False)
self.session = aiohttp.ClientSession()
URL = self.get_master_uri(self.opts)
self.ws = await self.session.ws_connect(URL)
log.error("Connect to %s %s", URL, ctx)
self.ws = await self.session.ws_connect(URL, ssl=ctx)
async def send(self, load, timeout=60):
if self.sending or self._closing:
@ -528,10 +591,13 @@ class RequestClient(salt.transport.base.RequestClient):
self._closing = True
self.close_task = asyncio.create_task(self._close())
@staticmethod
def get_master_uri(opts):
def get_master_uri(self, opts):
if "master_uri" in opts:
if self.opts.get("ssl", None):
return opts["master_uri"].replace("tcp:", "https:", 1)
return opts["master_uri"].replace("tcp:", "http:", 1)
if self.opts.get("ssl", None):
return f"https://{opts['master_ip']}:{opts['master_port']}/ws"
return f"http://{opts['master_ip']}:{opts['master_port']}/ws"
# pylint: disable=W1701
@ -540,4 +606,5 @@ class RequestClient(salt.transport.base.RequestClient):
warnings.warn(
"Unclosed publish client {self!r}", ResourceWarning, source=self
)
# pylint: enable=W1701