From 5d68c1156085bd113081ca9938bfec66cd6bb414 Mon Sep 17 00:00:00 2001 From: Tyler Levy Conde Date: Mon, 15 Apr 2024 17:02:34 -0600 Subject: [PATCH] Added tests and verified functionality --- salt/client/ssh/shell.py | 9 +-- tests/pytests/unit/client/ssh/test_shell.py | 73 +++++++++++++++++++++ 2 files changed, 78 insertions(+), 4 deletions(-) diff --git a/salt/client/ssh/shell.py b/salt/client/ssh/shell.py index 0327bea2ccb..b4636b4ca52 100644 --- a/salt/client/ssh/shell.py +++ b/salt/client/ssh/shell.py @@ -9,11 +9,11 @@ import shlex import subprocess import sys import time -import shutil import salt.defaults.exitcodes import salt.utils.json import salt.utils.nb_popen +import salt.utils.path import salt.utils.vt log = logging.getLogger(__name__) @@ -34,9 +34,10 @@ SUDO_PROMPT_RE = re.compile( RSTR = "_edbc7885e4f9aac9b83b35999b68d015148caf467b78fa39c05f669c0ff89878" RSTR_RE = re.compile(r"(?:^|\r?\n)" + RSTR + r"(?:\r?\n|$)") -SSH_KEYGEN_PATH = shutil.which('ssh-keygen') or 'ssh-keygen' -SSH_PATH = shutil.which('ssh') or 'ssh' -SCP_PATH = shutil.which('scp') or 'scp' +SSH_KEYGEN_PATH = salt.utils.path.which("ssh-keygen") or "ssh-keygen" +SSH_PATH = salt.utils.path.which("ssh") or "ssh" +SCP_PATH = salt.utils.path.which("scp") or "scp" + def gen_key(path): """ diff --git a/tests/pytests/unit/client/ssh/test_shell.py b/tests/pytests/unit/client/ssh/test_shell.py index 0b87ec1082a..54106217fbc 100644 --- a/tests/pytests/unit/client/ssh/test_shell.py +++ b/tests/pytests/unit/client/ssh/test_shell.py @@ -1,3 +1,5 @@ +import importlib +import logging import subprocess import types @@ -98,3 +100,74 @@ def test_ssh_shell_exec_cmd_returns_status_code_with_highest_bit_set_if_process_ assert stdout == "" assert stderr == "leave me alone please" assert retcode == 137 + + +def exec_cmd(cmd): + if cmd.startswith("mkdir -p"): + return "", "Not a directory", 1 + return "OK", "", 0 + + +def test_ssh_shell_send_makedirs_failure_returns_immediately(): + with patch("salt.client.ssh.shell.Shell.exec_cmd", side_effect=exec_cmd): + shl = shell.Shell({}, "localhost") + stdout, stderr, retcode = shl.send("/tmp/file", "/tmp/file", True) + assert retcode == 1 + assert "Not a directory" in stderr + + +def test_ssh_shell_send_makedirs_on_relative_filename_skips_exec(caplog): + with patch("salt.client.ssh.shell.Shell.exec_cmd", side_effect=exec_cmd) as cmd: + with patch("salt.client.ssh.shell.Shell._run_cmd", return_value=("", "", 0)): + shl = shell.Shell({}, "localhost") + with caplog.at_level(logging.WARNING): + stdout, stderr, retcode = shl.send("/tmp/file", "targetfile", True) + assert retcode == 0 + assert "Not a directory" not in stderr + assert call("mkdir -p ''") not in cmd.mock_calls + assert "Makedirs called on relative filename" in caplog.text + + +@pytest.fixture() +def mock_bin_paths(): + """Automatically apply fixture to all tests that need it.""" + with patch("salt.utils.path.which") as mock_which: + mock_which.side_effect = lambda x: { + "ssh-keygen": "/custom/ssh-keygen", + "ssh": "/custom/ssh", + "scp": "/custom/scp", + }.get(x, None) + importlib.reload(shell) + yield + importlib.reload(shell) + + +def test_gen_key_uses_custom_ssh_keygen_path(mock_bin_paths): + """Test that gen_key function uses the correct ssh-keygen path.""" + with patch("subprocess.call") as mock_call: + shell.gen_key("/dev/null") + + # Extract the first argument of the first call to subprocess.call + args, _ = mock_call.call_args + + # Assert that the first part of the command is the custom ssh-keygen path + assert args[0][0] == "/custom/ssh-keygen" + + +def test_ssh_command_execution_uses_custom_path(mock_bin_paths): + options = {"_ssh_version": (4, 9)} + _shell = shell.Shell(opts=options, host="example.com") + cmd_string = _shell._cmd_str("ls -la") + assert "/custom/ssh" in cmd_string + + +def test_scp_command_execution_uses_custom_path(mock_bin_paths): + _shell = shell.Shell(opts={}, host="example.com") + with patch.object( + _shell, "_run_cmd", return_value=(None, None, None) + ) as mock_run_cmd: + _shell.send("source_file.txt", "/path/dest_file.txt") + # The command string passed to _run_cmd should include the custom scp path + args, _ = mock_run_cmd.call_args + assert "/custom/scp" in args[0] + assert "source_file.txt example.com:/path/dest_file.txt" in args[0]