Make Client context manager, add test

This commit is contained in:
Twangboy 2022-11-14 19:11:48 -07:00 committed by Megan Wilhite
parent a7bca34a67
commit c0c8241bb4
2 changed files with 34 additions and 9 deletions

View file

@ -897,6 +897,15 @@ class Client:
self._client = PsExecClient(server, username, password, port, encrypt)
self._client._service = ScmrService(self.service_name, self._client.session)
def __enter__(self):
self.connect()
self.create_service()
return self
def __exit__(self, tb_type, tb_value, tb):
self.remove_service()
self.disconnect()
def connect(self):
return self._client.connect()
@ -971,17 +980,10 @@ def run_psexec_command(cmd, args, host, username, password, port=445):
Run a command remotely using the psexec protocol
"""
service_name = "PS-Exec-{}".format(uuid.uuid4())
stdout, stderr, ret_code = "", "", None
client = Client(
with Client(
host, username, password, port=port, encrypt=False, service_name=service_name
)
client.connect()
try:
client.create_service()
) as client:
stdout, stderr, ret_code = client.run_executable(cmd, args)
finally:
client.remove_service()
client.disconnect()
return stdout, stderr, ret_code

View file

@ -206,6 +206,29 @@ def test_deploy_windows_custom_port():
mock.assert_called_once_with("test", "Administrator", None, 1234)
def test_run_psexec_command_cleanup_lingering_paexec():
pytest.importorskip("pypsexec.client", reason="Requires PyPsExec")
mock_psexec = patch("salt.utils.cloud.PsExecClient", autospec=True)
mock_scmr = patch("salt.utils.cloud.ScmrService", autospec=True)
mock_rm_svc = patch("salt.utils.cloud.Client.remove_service", autospec=True)
with mock_psexec as mock_client, mock_scmr, mock_rm_svc:
mock_client.return_value.session = MagicMock(username="Gary")
mock_client.return_value.connection = MagicMock(server_name="Krabbs")
mock_client.return_value.run_executable.return_value = (
"Sandy",
"MermaidMan",
"BarnicleBoy",
)
cloud.run_psexec_command(
"spongebob",
"squarepants",
"patrick",
"squidward",
"plankton",
)
mock_client.return_value.cleanup.assert_called_once()
@pytest.mark.skip_unless_on_windows(reason="Only applicable for Windows.")
def test_deploy_windows_programdata():
"""