Refactor and add some tests

Added the check to a few other places in channel server
This commit is contained in:
Shane Lee 2024-02-28 09:27:39 -07:00 committed by Daniel Wozniak
parent 539ad0f888
commit b7f44da849
3 changed files with 56 additions and 3 deletions

View file

@ -56,6 +56,16 @@ class ReqServerChannel:
transport = salt.transport.request_server(opts, **kwargs)
return cls(opts, transport)
@classmethod
def compare_keys(cls, key1, key2):
"""
Normalize and compare two keys
Returns:
bool: ``True`` if the keys match, otherwise ``False``
"""
return salt.crypt.clean_key(key1) == salt.crypt.clean_key(key2)
def __init__(self, opts, transport):
self.opts = opts
self.transport = transport
@ -381,7 +391,7 @@ class ReqServerChannel:
elif os.path.isfile(pubfn):
# The key has been accepted, check it
with salt.utils.files.fopen(pubfn, "r") as pubfn_handle:
if salt.crypt.clean_key(pubfn_handle.read()) != load["pub"]:
if not self.compare_keys(pubfn_handle.read(), load["pub"]):
log.error(
"Authentication attempt from %s failed, the public "
"keys did not match. This may be an attempt to compromise "
@ -490,7 +500,7 @@ class ReqServerChannel:
# case. Otherwise log the fact that the minion is still
# pending.
with salt.utils.files.fopen(pubfn_pend, "r") as pubfn_handle:
if salt.crypt.clean_key(pubfn_handle.read()) != load["pub"]:
if not self.compare_keys(pubfn_handle.read(), load["pub"]):
log.error(
"Authentication attempt from %s failed, the public "
"key in pending did not match. This may be an "
@ -546,7 +556,7 @@ class ReqServerChannel:
# so, pass on doing anything here, and let it get automatically
# accepted below.
with salt.utils.files.fopen(pubfn_pend, "r") as pubfn_handle:
if salt.crypt.clean_key(pubfn_handle.read()) != load["pub"]:
if not self.compare_keys(pubfn_handle.read(), load["pub"]):
log.error(
"Authentication attempt from %s failed, the public "
"keys in pending did not match. This may be an "

View file

View file

@ -0,0 +1,43 @@
import pytest
import salt.channel.server as server
@pytest.fixture
def key_data():
return [
"-----BEGIN PUBLIC KEY-----",
"MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAoe5QSDYRWKyknbVyRrIj",
"rm1ht5HgKzAVUber0x54+b/UgxTd1cqI6I+eDlx53LqZSH3G8Rd5cUh8LHoGedSa",
"E62vEiLAjgXa+RdgcGiQpYS8+Z2RvQJ8oIcZgO+2AzgBRHboNWHTYRRmJXCd3dKs",
"9tcwK6wxChR06HzGqaOTixAuQlegWbOTU+X4dXIbW7AnuQBt9MCib7SxHlscrqcS",
"cBrRvq51YP6cxPm/rZJdBqZhVrlghBvIpa45NApP5PherGi4AbEGYte4l+gC+fOA",
"osEBis1V27djPpIyQS4qk3XAPQg6CYQMDltHqA4Fdo0Nt7SMScxJhfH0r6zmBFAe",
"BQIDAQAB",
"-----END PUBLIC KEY-----",
]
@pytest.mark.parametrize("linesep", ["\r\n", "\r", "\n"])
def test_compare_keys(key_data, linesep):
src_key = linesep.join(key_data)
tgt_key = "\n".join(key_data)
assert server.ReqServerChannel.compare_keys(src_key, tgt_key) is True
@pytest.mark.parametrize("linesep", ["\r\n", "\r", "\n"])
def test_compare_keys_newline_src(key_data, linesep):
src_key = linesep.join(key_data) + linesep
tgt_key = "\n".join(key_data)
assert src_key.endswith(linesep)
assert not tgt_key.endswith("\n")
assert server.ReqServerChannel.compare_keys(src_key, tgt_key) is True
@pytest.mark.parametrize("linesep", ["\r\n", "\r", "\n"])
def test_compare_keys_newline_tgt(key_data, linesep):
src_key = linesep.join(key_data)
tgt_key = "\n".join(key_data) + "\n"
assert not src_key.endswith(linesep)
assert tgt_key.endswith("\n")
assert server.ReqServerChannel.compare_keys(src_key, tgt_key) is True