Allow ssh pre connection hook

This commit is contained in:
Tyler Levy Conde 2024-04-30 10:04:00 -06:00 committed by Thomas Phipps
parent 21d5cca071
commit 5991f14a8c
4 changed files with 85 additions and 0 deletions

1
changelog/66210.added.md Normal file
View file

@ -0,0 +1 @@
Allow pre-connection scripts to be run on host before any ssh commands

View file

@ -1008,6 +1008,7 @@ class Single:
self.fsclient = fsclient
self.context = {"master_opts": self.opts, "fileclient": self.fsclient}
self.ssh_pre_hook = kwargs.get("ssh_pre_hook", None)
self.ssh_pre_flight = kwargs.get("ssh_pre_flight", None)
self.ssh_pre_flight_args = kwargs.get("ssh_pre_flight_args", None)
@ -1093,6 +1094,12 @@ class Single:
return arg
return "".join(["\\" + char if re.match(r"\W", char) else char for char in arg])
def run_ssh_pre_hook(self):
"""
Run a pre_hook script on the host machine before running any ssh commands
"""
return self.shell.exec_cmd(self.ssh_pre_hook)
def run_ssh_pre_flight(self):
"""
Run our pre_flight script before running any ssh commands
@ -1168,6 +1175,13 @@ class Single:
stdout = stderr = ""
retcode = salt.defaults.exitcodes.EX_OK
if self.ssh_pre_hook:
stdout, stderr, retcode = self.run_ssh_pre_hook()
if retcode != salt.defaults.exitcodes.EX_OK:
log.error("Error running ssh_pre_hook script %s", self.ssh_pre_hook)
return stdout, stderr, retcode
log.info("Successfully ran the ssh_pre_hook script: %s", self.ssh_pre_hook)
if self.ssh_pre_flight:
if not self.opts.get("ssh_run_pre_flight", False) and self.check_thin_dir():
log.info(

View file

@ -65,6 +65,7 @@ class SSHClient:
("ssh_scan_timeout", int),
("ssh_timeout", int),
("ssh_log_file", str),
("ssh_pre_hook", str),
("raw_shell", bool),
("refresh_cache", bool),
("roster", str),

View file

@ -834,3 +834,72 @@ def test_ssh_single__cmd_str_sudo_passwd_user(opts):
)
assert expected in cmd
def test_run_ssh_pre_hook_success(opts, target, tmp_path):
"""
Test run_ssh_pre_hook when ssh_pre_hook is successful.
"""
target["ssh_pre_hook"] = "echo 'Pre-hook success'"
single_instance = ssh.Single(opts, opts["argv"], "localhost", **target)
mock_exec_cmd = MagicMock(return_value=("Output", "No errors", 0))
with patch.object(single_instance.shell, "exec_cmd", mock_exec_cmd):
result = single_instance.run_ssh_pre_hook()
assert result == ("Output", "No errors", 0)
def test_run_ssh_pre_hook_failure(opts, target):
"""
Test run_ssh_pre_hook when ssh_pre_hook fails.
"""
target["ssh_pre_hook"] = "echo 'Pre-hook failure'"
single_instance = ssh.Single(opts, opts["argv"], "localhost", **target)
mock_exec_cmd = MagicMock(return_value=("Error output", "Failed to execute", 1))
with patch.object(single_instance.shell, "exec_cmd", mock_exec_cmd):
result = single_instance.run_ssh_pre_hook()
assert result == ("Error output", "Failed to execute", 1)
def test_run_integration_with_pre_hook_success(opts, target):
"""
Test the run method integrates run_ssh_pre_hook and proceeds on success.
"""
target["ssh_pre_hook"] = "echo 'Pre-hook success'"
target["ssh_pre_flight"] = None
single_instance = ssh.Single(opts, opts["argv"], "localhost", **target)
mock_pre_hook = MagicMock(return_value=("", "", 0))
mock_cmd_block = MagicMock(return_value=("", "", 0))
with patch.object(single_instance, "run_ssh_pre_hook", mock_pre_hook), patch.object(
single_instance, "cmd_block", mock_cmd_block
):
stdout, stderr, retcode = single_instance.run()
assert retcode == 0
mock_pre_hook.assert_called_once()
def test_run_integration_with_pre_hook_failure(opts, target):
"""
Test the run method handles pre_hook failure correctly and skips further steps.
"""
target["ssh_pre_hook"] = "echo 'Pre-hook failure'"
target["ssh_pre_flight"] = None
single_instance = ssh.Single(opts, opts["argv"], "localhost", **target)
mock_pre_hook = MagicMock(return_value=("Error output", "Failed to execute", 1))
with patch.object(single_instance, "run_ssh_pre_hook", mock_pre_hook):
stdout, stderr, retcode = single_instance.run()
assert retcode == 1
assert "Failed to execute" in stderr
mock_pre_hook.assert_called_once()
def test_run_integration_with_no_pre_hook(opts, target):
"""
Test the run method succeeds with no ssh_pre_hook
"""
target["ssh_pre_hook"] = None
target["ssh_pre_flight"] = None
single_instance = ssh.Single(opts, opts["argv"], "localhost", **target)
mock_cmd_block = MagicMock(return_value=("", "", 0))
with patch.object(single_instance, "cmd_block", mock_cmd_block):
stdout, stderr, retcode = single_instance.run()
assert retcode == 0