From 5991f14a8c9aa1e6b63e20a32f344aa8803f90e4 Mon Sep 17 00:00:00 2001 From: Tyler Levy Conde Date: Tue, 30 Apr 2024 10:04:00 -0600 Subject: [PATCH] Allow ssh pre connection hook --- changelog/66210.added.md | 1 + salt/client/ssh/__init__.py | 14 ++++ salt/client/ssh/client.py | 1 + tests/pytests/unit/client/ssh/test_single.py | 69 ++++++++++++++++++++ 4 files changed, 85 insertions(+) create mode 100644 changelog/66210.added.md diff --git a/changelog/66210.added.md b/changelog/66210.added.md new file mode 100644 index 00000000000..caff725c7ec --- /dev/null +++ b/changelog/66210.added.md @@ -0,0 +1 @@ +Allow pre-connection scripts to be run on host before any ssh commands diff --git a/salt/client/ssh/__init__.py b/salt/client/ssh/__init__.py index b8cf40f0f51..c20bbd88719 100644 --- a/salt/client/ssh/__init__.py +++ b/salt/client/ssh/__init__.py @@ -1008,6 +1008,7 @@ class Single: self.fsclient = fsclient self.context = {"master_opts": self.opts, "fileclient": self.fsclient} + self.ssh_pre_hook = kwargs.get("ssh_pre_hook", None) self.ssh_pre_flight = kwargs.get("ssh_pre_flight", None) self.ssh_pre_flight_args = kwargs.get("ssh_pre_flight_args", None) @@ -1093,6 +1094,12 @@ class Single: return arg return "".join(["\\" + char if re.match(r"\W", char) else char for char in arg]) + def run_ssh_pre_hook(self): + """ + Run a pre_hook script on the host machine before running any ssh commands + """ + return self.shell.exec_cmd(self.ssh_pre_hook) + def run_ssh_pre_flight(self): """ Run our pre_flight script before running any ssh commands @@ -1168,6 +1175,13 @@ class Single: stdout = stderr = "" retcode = salt.defaults.exitcodes.EX_OK + if self.ssh_pre_hook: + stdout, stderr, retcode = self.run_ssh_pre_hook() + if retcode != salt.defaults.exitcodes.EX_OK: + log.error("Error running ssh_pre_hook script %s", self.ssh_pre_hook) + return stdout, stderr, retcode + log.info("Successfully ran the ssh_pre_hook script: %s", self.ssh_pre_hook) + if self.ssh_pre_flight: if not self.opts.get("ssh_run_pre_flight", False) and self.check_thin_dir(): log.info( diff --git a/salt/client/ssh/client.py b/salt/client/ssh/client.py index 8727ce23c3c..f3f678bfff1 100644 --- a/salt/client/ssh/client.py +++ b/salt/client/ssh/client.py @@ -65,6 +65,7 @@ class SSHClient: ("ssh_scan_timeout", int), ("ssh_timeout", int), ("ssh_log_file", str), + ("ssh_pre_hook", str), ("raw_shell", bool), ("refresh_cache", bool), ("roster", str), diff --git a/tests/pytests/unit/client/ssh/test_single.py b/tests/pytests/unit/client/ssh/test_single.py index 91b67a250b7..5e5357bf222 100644 --- a/tests/pytests/unit/client/ssh/test_single.py +++ b/tests/pytests/unit/client/ssh/test_single.py @@ -834,3 +834,72 @@ def test_ssh_single__cmd_str_sudo_passwd_user(opts): ) assert expected in cmd + + +def test_run_ssh_pre_hook_success(opts, target, tmp_path): + """ + Test run_ssh_pre_hook when ssh_pre_hook is successful. + """ + target["ssh_pre_hook"] = "echo 'Pre-hook success'" + single_instance = ssh.Single(opts, opts["argv"], "localhost", **target) + mock_exec_cmd = MagicMock(return_value=("Output", "No errors", 0)) + with patch.object(single_instance.shell, "exec_cmd", mock_exec_cmd): + result = single_instance.run_ssh_pre_hook() + assert result == ("Output", "No errors", 0) + + +def test_run_ssh_pre_hook_failure(opts, target): + """ + Test run_ssh_pre_hook when ssh_pre_hook fails. + """ + target["ssh_pre_hook"] = "echo 'Pre-hook failure'" + single_instance = ssh.Single(opts, opts["argv"], "localhost", **target) + mock_exec_cmd = MagicMock(return_value=("Error output", "Failed to execute", 1)) + with patch.object(single_instance.shell, "exec_cmd", mock_exec_cmd): + result = single_instance.run_ssh_pre_hook() + assert result == ("Error output", "Failed to execute", 1) + + +def test_run_integration_with_pre_hook_success(opts, target): + """ + Test the run method integrates run_ssh_pre_hook and proceeds on success. + """ + target["ssh_pre_hook"] = "echo 'Pre-hook success'" + target["ssh_pre_flight"] = None + single_instance = ssh.Single(opts, opts["argv"], "localhost", **target) + mock_pre_hook = MagicMock(return_value=("", "", 0)) + mock_cmd_block = MagicMock(return_value=("", "", 0)) + with patch.object(single_instance, "run_ssh_pre_hook", mock_pre_hook), patch.object( + single_instance, "cmd_block", mock_cmd_block + ): + stdout, stderr, retcode = single_instance.run() + assert retcode == 0 + mock_pre_hook.assert_called_once() + + +def test_run_integration_with_pre_hook_failure(opts, target): + """ + Test the run method handles pre_hook failure correctly and skips further steps. + """ + target["ssh_pre_hook"] = "echo 'Pre-hook failure'" + target["ssh_pre_flight"] = None + single_instance = ssh.Single(opts, opts["argv"], "localhost", **target) + mock_pre_hook = MagicMock(return_value=("Error output", "Failed to execute", 1)) + with patch.object(single_instance, "run_ssh_pre_hook", mock_pre_hook): + stdout, stderr, retcode = single_instance.run() + assert retcode == 1 + assert "Failed to execute" in stderr + mock_pre_hook.assert_called_once() + + +def test_run_integration_with_no_pre_hook(opts, target): + """ + Test the run method succeeds with no ssh_pre_hook + """ + target["ssh_pre_hook"] = None + target["ssh_pre_flight"] = None + single_instance = ssh.Single(opts, opts["argv"], "localhost", **target) + mock_cmd_block = MagicMock(return_value=("", "", 0)) + with patch.object(single_instance, "cmd_block", mock_cmd_block): + stdout, stderr, retcode = single_instance.run() + assert retcode == 0