Merge pull request #65228 from dwoz/issue/master/65226

[master] Fix cluster key rotation
This commit is contained in:
Daniel Wozniak 2023-12-09 22:41:31 -07:00 committed by GitHub
commit afdb17b125
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 333 additions and 188 deletions

View file

@ -16,7 +16,7 @@ on:
env:
COLUMNS: 190
CACHE_SEED: SEED-1 # Bump the number to invalidate all caches
CACHE_SEED: SEED-2 # Bump the number to invalidate all caches
RELENV_DATA: "${{ github.workspace }}/.relenv"
permissions:

View file

@ -22,7 +22,7 @@ on:
env:
COLUMNS: 190
CACHE_SEED: SEED-1 # Bump the number to invalidate all caches
CACHE_SEED: SEED-2 # Bump the number to invalidate all caches
RELENV_DATA: "${{ github.workspace }}/.relenv"
permissions:

View file

@ -21,7 +21,7 @@ on:
env:
COLUMNS: 190
CACHE_SEED: SEED-1 # Bump the number to invalidate all caches
CACHE_SEED: SEED-2 # Bump the number to invalidate all caches
RELENV_DATA: "${{ github.workspace }}/.relenv"
permissions:

View file

@ -12,7 +12,7 @@ on:
env:
COLUMNS: 190
CACHE_SEED: SEED-1 # Bump the number to invalidate all caches
CACHE_SEED: SEED-2 # Bump the number to invalidate all caches
RELENV_DATA: "${{ github.workspace }}/.relenv"
permissions:

View file

@ -37,7 +37,7 @@ on:
env:
COLUMNS: 190
CACHE_SEED: SEED-1 # Bump the number to invalidate all caches
CACHE_SEED: SEED-2 # Bump the number to invalidate all caches
RELENV_DATA: "${{ github.workspace }}/.relenv"
permissions:

View file

@ -34,7 +34,7 @@ on:
env:
COLUMNS: 190
CACHE_SEED: SEED-1 # Bump the number to invalidate all caches
CACHE_SEED: SEED-2 # Bump the number to invalidate all caches
RELENV_DATA: "${{ github.workspace }}/.relenv"
<%- endblock env %>

View file

@ -1045,9 +1045,7 @@ class MasterPubServerChannel:
"""
try:
tag, data = salt.utils.event.SaltEvent.unpack(payload)
log.error("recieved event from peer %s %r", tag, data)
if tag.startswith("cluster/peer"):
log.error("Got peer join %r", data)
peer = data["peer_id"]
aes = data["peers"][self.opts["id"]]["aes"]
sig = data["peers"][self.opts["id"]]["sig"]

View file

@ -16,6 +16,7 @@ import pathlib
import random
import stat
import sys
import tempfile
import time
import traceback
import uuid
@ -1620,19 +1621,22 @@ class Crypticle:
return b64key.replace("\n", "")
@classmethod
def read_or_generate_key(cls, path, key_size=192, remove=False):
if remove:
os.remove(path)
def write_key(cls, path, key_size=192):
directory = pathlib.Path(path).parent
with salt.utils.files.set_umask(0o177):
try:
with salt.utils.files.fopen(path, "r") as fp:
return fp.read()
except FileNotFoundError:
pass
key = cls.generate_key_string(key_size)
with salt.utils.files.fopen(path, "w") as fp:
fp.write(key)
return key
fd, tmp = tempfile.mkstemp(dir=directory, prefix="aes")
os.close(fd)
with salt.utils.files.fopen(tmp, "w") as fp:
fp.write(cls.generate_key_string(key_size))
os.rename(tmp, path)
@classmethod
def read_key(cls, path):
    """
    Return the contents of the key file at ``path``.

    Returns ``None`` when the file does not exist yet.
    """
    try:
        with salt.utils.files.fopen(path, "r") as keyfile:
            return keyfile.read()
    except FileNotFoundError:
        # No key has been written yet; callers treat None as "generate one".
        return None
@classmethod
def extract_keys(cls, key_string, key_size):

View file

@ -6,7 +6,6 @@ import asyncio
import collections
import copy
import ctypes
import functools
import logging
import multiprocessing
import os
@ -142,7 +141,6 @@ class SMaster:
def rotate_secrets(
cls, opts=None, event=None, use_lock=True, owner=False, publisher=None
):
log.info("Rotating master AES key")
if opts is None:
opts = {}
@ -173,6 +171,41 @@ class SMaster:
log.debug("Pinging all connected minions due to key rotation")
salt.utils.master.ping_all_connected_minions(opts)
@classmethod
def rotate_cluster_secret(
    cls, opts=None, event=None, use_lock=True, owner=False, publisher=None
):
    """
    Rotate the cluster-wide AES session key shared by the cluster masters.

    :param dict opts: Master config options; ``ping_on_rotate`` is honored.
    :param event: Optional event bus; a ``rotate_cluster_aes_key`` event is
        fired so other local processes pick up the new key.
    :param bool use_lock: Guard the shared secret with its multiprocessing
        lock (disable only when the caller already holds it).
    :param bool owner: True on the master that owns the on-disk key file;
        the reload callback removes and regenerates the file in that case.
    :param publisher: Optional IPC publisher used to push the new key to
        cluster peers.
    """
    log.debug("Rotating cluster AES key")
    if opts is None:
        opts = {}

    def _reload_secret():
        # Re-read (or, for the owner, re-create) the on-disk cluster key
        # and publish it into the shared-memory secret.
        cls.secrets["cluster_aes"]["secret"].value = salt.utils.stringutils.to_bytes(
            cls.secrets["cluster_aes"]["reload"](remove=owner)
        )

    if use_lock:
        with cls.secrets["cluster_aes"]["secret"].get_lock():
            _reload_secret()
    else:
        _reload_secret()

    if event:
        # Plain string key/tag; the original used an f-string with no
        # placeholders (flake8 F541).
        event.fire_event(
            {"rotate_cluster_aes_key": True}, tag="rotate_cluster_aes_key"
        )
    if publisher:
        publisher.send_aes_key_event()
    if opts.get("ping_on_rotate"):
        # Ping all minions to get them to pick up the new key
        log.debug("Pinging all connected minions due to key rotation")
        salt.utils.master.ping_all_connected_minions(opts)
class Maintenance(salt.utils.process.SignalHandlingProcess):
"""
@ -358,7 +391,7 @@ class Maintenance(salt.utils.process.SignalHandlingProcess):
if to_rotate:
if self.opts.get("cluster_id", None):
SMaster.rotate_secrets(
SMaster.rotate_cluster_secret(
self.opts, self.event, owner=True, publisher=self.ipc_publisher
)
else:
@ -714,6 +747,20 @@ class Master(SMaster):
log.critical("Master failed pre flight checks, exiting\n")
sys.exit(salt.defaults.exitcodes.EX_GENERIC)
def read_or_generate_key(self, remove=False, fs_wait=0.1):
    """
    Manage the cluster AES session key file.

    Reads the key from ``<cluster_pki_dir>/.aes``; when it is absent (or
    ``remove`` is True, which deletes it first), a new key is written and
    read back.

    :param bool remove: Delete the existing key file before reading, forcing
        a fresh key to be generated.
    :param float fs_wait: Seconds to wait after writing before re-reading,
        giving a (possibly shared) filesystem time to settle.
    :return: The AES key string.
    """
    keyfile = os.path.join(self.opts["cluster_pki_dir"], ".aes")
    if remove:
        os.remove(keyfile)
    existing = salt.crypt.Crypticle.read_key(keyfile)
    if existing:
        return existing
    # No usable key on disk; write one, pause briefly, then read it back so
    # every caller path returns the on-disk contents.
    salt.crypt.Crypticle.write_key(keyfile)
    time.sleep(fs_wait)
    return salt.crypt.Crypticle.read_key(keyfile)
def start(self):
"""
Turn on the master server components
@ -731,22 +778,18 @@ class Master(SMaster):
# signal handlers
with salt.utils.process.default_signals(signal.SIGINT, signal.SIGTERM):
if self.opts["cluster_id"]:
keypath = os.path.join(self.opts["cluster_pki_dir"], ".aes")
cluster_keygen = functools.partial(
salt.crypt.Crypticle.read_or_generate_key,
keypath,
)
# Setup the secrets here because the PubServerChannel may need
# them as well.
SMaster.secrets["cluster_aes"] = {
"secret": multiprocessing.Array(
ctypes.c_char, salt.utils.stringutils.to_bytes(cluster_keygen())
ctypes.c_char,
salt.utils.stringutils.to_bytes(self.read_or_generate_key()),
),
"serial": multiprocessing.Value(
ctypes.c_longlong,
lock=False, # We'll use the lock from 'secret'
),
"reload": cluster_keygen,
"reload": self.read_or_generate_key,
}
SMaster.secrets["aes"] = {
@ -779,7 +822,7 @@ class Master(SMaster):
ipc_publisher.pre_fork(self.process_manager)
self.process_manager.add_process(
EventMonitor,
args=[self.opts],
args=[self.opts, ipc_publisher],
name="EventMonitor",
)
@ -908,19 +951,19 @@ class EventMonitor(salt.utils.process.SignalHandlingProcess):
- Handle key rotate events.
"""
def __init__(self, opts, channels=None, name="EventMonitor"):
def __init__(self, opts, ipc_publisher, channels=None, name="EventMonitor"):
super().__init__(name=name)
self.opts = opts
if channels is None:
channels = []
self.channels = channels
self.ipc_publisher = ipc_publisher
async def handle_event(self, package):
"""
Event handler for publish forwarder
"""
tag, data = salt.utils.event.SaltEvent.unpack(package)
log.debug("Event monitor got event %s %r", tag, data)
if tag.startswith("salt/job") and tag.endswith("/publish"):
peer_id = data.pop("__peer_id", None)
if peer_id:
@ -937,9 +980,15 @@ class EventMonitor(salt.utils.process.SignalHandlingProcess):
for chan in self.channels:
tasks.append(asyncio.create_task(chan.publish(data)))
await asyncio.gather(*tasks)
elif tag == "rotate_aes_key":
log.debug("Event monitor recieved rotate aes key event, rotating key.")
SMaster.rotate_secrets(self.opts, owner=False)
elif tag == "rotate_cluster_aes_key":
peer_id = data.pop("__peer_id", None)
if peer_id:
log.debug("Rotating AES session key")
SMaster.rotate_cluster_secret(
self.opts, owner=False, publisher=self.ipc_publisher
)
else:
log.trace("Ignore tag %s", tag)
def run(self):
io_loop = tornado.ioloop.IOLoop()

View file

@ -0,0 +1,150 @@
import logging
import subprocess
import pytest
import salt.utils.platform
log = logging.getLogger(__name__)
@pytest.fixture
def cluster_shared_path(tmp_path):
    """Top-level temporary directory shared by all cluster masters."""
    shared = tmp_path / "cluster"
    shared.mkdir()
    return shared
@pytest.fixture
def cluster_pki_path(cluster_shared_path):
    """Shared PKI directory (with its ``peers`` subdirectory) for the cluster."""
    pki = cluster_shared_path / "pki"
    pki.mkdir()
    (pki / "peers").mkdir()
    return pki
@pytest.fixture
def cluster_cache_path(cluster_shared_path):
    """Shared cache directory for the cluster masters."""
    cache = cluster_shared_path / "cache"
    cache.mkdir()
    return cache
@pytest.fixture
def cluster_master_1(request, salt_factories, cluster_pki_path, cluster_cache_path):
    """First cluster master, bound to 127.0.0.1; owns the shared pki/cache dirs."""
    defaults = {
        "open_mode": True,
        "transport": request.config.getoption("--transport"),
    }
    overrides = {
        "interface": "127.0.0.1",
        "cluster_id": "master_cluster",
        "cluster_peers": ["127.0.0.2", "127.0.0.3"],
        "cluster_pki_dir": str(cluster_pki_path),
        "cache_dir": str(cluster_cache_path),
    }
    daemon = salt_factories.salt_master_daemon(
        "127.0.0.1",
        defaults=defaults,
        overrides=overrides,
        extra_cli_arguments_after_first_start_failure=["--log-level=info"],
    )
    with daemon.started(start_timeout=120):
        yield daemon
@pytest.fixture
def cluster_master_2(salt_factories, cluster_master_1):
    """Second cluster master, bound to 127.0.0.2, sharing master 1's pki/cache."""
    # macOS/FreeBSD loopback only has 127.0.0.1 by default; alias the address.
    if salt.utils.platform.is_darwin() or salt.utils.platform.is_freebsd():
        subprocess.check_output(["ifconfig", "lo0", "alias", "127.0.0.2", "up"])
    defaults = {
        "open_mode": True,
        "transport": cluster_master_1.config["transport"],
    }
    overrides = {
        "interface": "127.0.0.2",
        "cluster_id": "master_cluster",
        "cluster_peers": ["127.0.0.1", "127.0.0.3"],
        "cluster_pki_dir": cluster_master_1.config["cluster_pki_dir"],
        "cache_dir": cluster_master_1.config["cache_dir"],
        # Reuse master 1's ports; the masters bind to different interfaces.
        "ret_port": cluster_master_1.config["ret_port"],
        "publish_port": cluster_master_1.config["publish_port"],
    }
    daemon = salt_factories.salt_master_daemon(
        "127.0.0.2",
        defaults=defaults,
        overrides=overrides,
        extra_cli_arguments_after_first_start_failure=["--log-level=info"],
    )
    with daemon.started(start_timeout=120):
        yield daemon
@pytest.fixture
def cluster_master_3(salt_factories, cluster_master_1):
    """Third cluster master, bound to 127.0.0.3, sharing master 1's pki/cache."""
    # macOS/FreeBSD loopback only has 127.0.0.1 by default; alias the address.
    if salt.utils.platform.is_darwin() or salt.utils.platform.is_freebsd():
        subprocess.check_output(["ifconfig", "lo0", "alias", "127.0.0.3", "up"])
    defaults = {
        "open_mode": True,
        "transport": cluster_master_1.config["transport"],
    }
    overrides = {
        "interface": "127.0.0.3",
        "cluster_id": "master_cluster",
        "cluster_peers": ["127.0.0.1", "127.0.0.2"],
        "cluster_pki_dir": cluster_master_1.config["cluster_pki_dir"],
        "cache_dir": cluster_master_1.config["cache_dir"],
        # Reuse master 1's ports; the masters bind to different interfaces.
        "ret_port": cluster_master_1.config["ret_port"],
        "publish_port": cluster_master_1.config["publish_port"],
    }
    daemon = salt_factories.salt_master_daemon(
        "127.0.0.3",
        defaults=defaults,
        overrides=overrides,
        extra_cli_arguments_after_first_start_failure=["--log-level=info"],
    )
    with daemon.started(start_timeout=120):
        yield daemon
@pytest.fixture
def cluster_minion_1(cluster_master_1):
    """Minion connected to the first cluster master."""
    master_addr = cluster_master_1.config["interface"]
    master_port = cluster_master_1.config["ret_port"]
    daemon = cluster_master_1.salt_minion_daemon(
        "cluster-minion-1",
        defaults={"transport": cluster_master_1.config["transport"]},
        overrides={
            "master": f"{master_addr}:{master_port}",
            "test.foo": "baz",
        },
        extra_cli_arguments_after_first_start_failure=["--log-level=info"],
    )
    with daemon.started(start_timeout=120):
        yield daemon

View file

@ -1,150 +1,17 @@
import logging
import subprocess
import pytest
# pylint: disable=unused-import
from tests.pytests.integration.cluster.conftest import (
cluster_cache_path,
cluster_master_1,
cluster_master_2,
cluster_master_3,
cluster_minion_1,
cluster_pki_path,
cluster_shared_path,
)
# pylint: enable=unused-import
import salt.utils.platform
log = logging.getLogger(__name__)
@pytest.fixture
def cluster_shared_path(tmp_path):
path = tmp_path / "cluster"
path.mkdir()
return path
@pytest.fixture
def cluster_pki_path(cluster_shared_path):
path = cluster_shared_path / "pki"
path.mkdir()
(path / "peers").mkdir()
return path
@pytest.fixture
def cluster_cache_path(cluster_shared_path):
path = cluster_shared_path / "cache"
path.mkdir()
return path
@pytest.fixture
def cluster_master_1(request, salt_factories, cluster_pki_path, cluster_cache_path):
config_defaults = {
"open_mode": True,
"transport": request.config.getoption("--transport"),
}
config_overrides = {
"interface": "127.0.0.1",
"cluster_id": "master_cluster",
"cluster_peers": [
"127.0.0.2",
"127.0.0.3",
],
"cluster_pki_dir": str(cluster_pki_path),
"cache_dir": str(cluster_cache_path),
}
factory = salt_factories.salt_master_daemon(
"127.0.0.1",
defaults=config_defaults,
overrides=config_overrides,
extra_cli_arguments_after_first_start_failure=["--log-level=info"],
)
with factory.started(start_timeout=120):
yield factory
@pytest.fixture
def cluster_master_2(salt_factories, cluster_master_1):
if salt.utils.platform.is_darwin() or salt.utils.platform.is_freebsd():
subprocess.check_output(["ifconfig", "lo0", "alias", "127.0.0.2", "up"])
config_defaults = {
"open_mode": True,
"transport": cluster_master_1.config["transport"],
}
config_overrides = {
"interface": "127.0.0.2",
"cluster_id": "master_cluster",
"cluster_peers": [
"127.0.0.1",
"127.0.0.3",
],
"cluster_pki_dir": cluster_master_1.config["cluster_pki_dir"],
"cache_dir": cluster_master_1.config["cache_dir"],
}
# Use the same ports for both masters, they are binding to different interfaces
for key in (
"ret_port",
"publish_port",
):
config_overrides[key] = cluster_master_1.config[key]
factory = salt_factories.salt_master_daemon(
"127.0.0.2",
defaults=config_defaults,
overrides=config_overrides,
extra_cli_arguments_after_first_start_failure=["--log-level=info"],
)
with factory.started(start_timeout=120):
yield factory
@pytest.fixture
def cluster_master_3(salt_factories, cluster_master_1):
if salt.utils.platform.is_darwin() or salt.utils.platform.is_freebsd():
subprocess.check_output(["ifconfig", "lo0", "alias", "127.0.0.3", "up"])
config_defaults = {
"open_mode": True,
"transport": cluster_master_1.config["transport"],
}
config_overrides = {
"interface": "127.0.0.3",
"cluster_id": "master_cluster",
"cluster_peers": [
"127.0.0.1",
"127.0.0.2",
],
"cluster_pki_dir": cluster_master_1.config["cluster_pki_dir"],
"cache_dir": cluster_master_1.config["cache_dir"],
}
# Use the same ports for both masters, they are binding to different interfaces
for key in (
"ret_port",
"publish_port",
):
config_overrides[key] = cluster_master_1.config[key]
factory = salt_factories.salt_master_daemon(
"127.0.0.3",
defaults=config_defaults,
overrides=config_overrides,
extra_cli_arguments_after_first_start_failure=["--log-level=info"],
)
with factory.started(start_timeout=120):
yield factory
@pytest.fixture
def cluster_minion_1(cluster_master_1):
config_defaults = {
"transport": cluster_master_1.config["transport"],
}
port = cluster_master_1.config["ret_port"]
addr = cluster_master_1.config["interface"]
config_overrides = {
"master": f"{addr}:{port}",
"test.foo": "baz",
}
factory = cluster_master_1.salt_minion_daemon(
"cluster-minion-1",
defaults=config_defaults,
overrides=config_overrides,
extra_cli_arguments_after_first_start_failure=["--log-level=info"],
)
with factory.started(start_timeout=120):
yield factory

View file

@ -0,0 +1,74 @@
import os
import pathlib
import time
import salt.crypt
def _gather_session_keys(minion_config, masters):
    """Authenticate against each master and return the set of AES session keys."""
    keys = set()
    for master in masters:
        config = minion_config.copy()
        config[
            "master_uri"
        ] = f"tcp://{master.config['interface']}:{master.config['ret_port']}"
        auth = salt.crypt.SAuth(config)
        auth.authenticate()
        assert "aes" in auth._creds
        keys.add(auth._creds["aes"])
    return keys


def test_cluster_key_rotation(
    cluster_master_1,
    cluster_master_2,
    cluster_master_3,
    cluster_minion_1,
    cluster_cache_path,
):
    """
    Every cluster master must hand out the same AES session key, and dropping
    a key-rotation file must cause all of them to agree on a NEW key.
    """
    cli = cluster_master_2.salt_cli(timeout=120)
    ret = cli.run("test.ping", minion_tgt="cluster-minion-1")
    assert ret.data is True

    masters = (cluster_master_1, cluster_master_2, cluster_master_3)

    # Validate the aes session key for all masters match
    keys = _gather_session_keys(cluster_minion_1.config, masters)
    assert len(keys) == 1
    orig_aes = keys.pop()

    # Create a drop file and wait for the master to do a key rotation.
    dfpath = pathlib.Path(cluster_master_1.config["cachedir"]) / ".dfn"
    assert not dfpath.exists()
    salt.crypt.dropfile(
        cluster_master_1.config["cachedir"],
        user=os.getlogin(),
        master_id=cluster_master_1.config["id"],
    )
    assert dfpath.exists()

    timeout = 2 * cluster_master_1.config["loop_interval"]
    start = time.monotonic()
    while dfpath.exists():
        if time.monotonic() - start > timeout:
            assert False, f"Drop file never removed {dfpath}"
        # Poll instead of busy-spinning while the maintenance loop rotates.
        time.sleep(0.1)

    # Validate the aes session key for all masters match
    keys = _gather_session_keys(cluster_minion_1.config, masters)
    assert len(keys) == 1
    # Validate the aes session key actually changed
    assert orig_aes != keys.pop()

View file

@ -177,7 +177,10 @@ def test_refresh_matchers():
assert ret is False
@pytest.mark.skip_on_windows
def test_refresh_modules_async_false():
    # XXX: This test adds coverage, but it is unclear what behavior it really
    # exercises; a functional test would be a better fit here.
    ret = saltutil.refresh_modules(**{"async": False})
    assert ret is False

View file

@ -284,12 +284,12 @@ def test_verify_signature_bad_sig(tmp_path):
def test_read_or_generate_key_string(tmp_path):
keyfile = tmp_path / ".aes"
assert not keyfile.exists()
first_key = salt.crypt.Crypticle.read_or_generate_key(keyfile)
assert keyfile.exists()
second_key = salt.crypt.Crypticle.read_or_generate_key(keyfile)
assert first_key == second_key
third_key = salt.crypt.Crypticle.read_or_generate_key(keyfile, remove=True)
assert second_key != third_key
first_key = salt.crypt.Crypticle.read_key(keyfile)
assert first_key is None
assert not keyfile.exists()
salt.crypt.Crypticle.write_key(keyfile)
second_key = salt.crypt.Crypticle.read_key(keyfile)
assert second_key is not None
def test_dropfile_contents(tmp_path, master_opts):

View file

@ -990,7 +990,7 @@ def test_key_rotate_no_master_match(maintenance):
def test_key_dfn_wait(cluster_maintenance):
now = time.monotonic()
key = pathlib.Path(cluster_maintenance.opts["cluster_pki_dir"]) / ".aes"
salt.crypt.Crypticle.read_or_generate_key(str(key))
salt.crypt.Crypticle.write_key(str(key))
rotate_time = time.monotonic() - (cluster_maintenance.opts["publish_session"] + 1)
os.utime(str(key), (rotate_time, rotate_time))