Merge pull request #54277 from dwoz/win_runas_plus

Win runas plus
This commit is contained in:
Daniel Wozniak 2019-08-21 16:59:15 -07:00 committed by GitHub
commit de7776225d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 12 additions and 25 deletions

View file

@ -680,23 +680,16 @@ class ProcessManager(object):
class MultiprocessingProcess(multiprocessing.Process, NewStyleClassMixIn):
def __new__(cls, *args, **kwargs):
instance = super(MultiprocessingProcess, cls).__new__(cls)
# Patch the run method at runtime because decorating the run method
# with a function with a similar behavior would be ignored once this
# class'es run method is overridden.
instance._original_run = instance.run
instance.run = instance._run
return instance
def __init__(self, *args, **kwargs):
log_queue = kwargs.pop('log_queue', None)
log_queue_level = kwargs.pop('log_queue_level', None)
super(MultiprocessingProcess, self).__init__(*args, **kwargs)
if salt.utils.platform.is_windows():
# On Windows, subclasses should call super if they define
# __setstate__ and/or __getstate__
self._args_for_getstate = copy.copy(args)
self._kwargs_for_getstate = copy.copy(kwargs)
self.log_queue = kwargs.pop('log_queue', None)
self.log_queue = log_queue
if self.log_queue is None:
self.log_queue = salt.log.setup.get_multiprocessing_logging_queue()
else:
@ -704,16 +697,12 @@ class MultiprocessingProcess(multiprocessing.Process, NewStyleClassMixIn):
# salt.log.setup.get_multiprocessing_logging_queue().
salt.log.setup.set_multiprocessing_logging_queue(self.log_queue)
self.log_queue_level = kwargs.pop('log_queue_level', None)
self.log_queue_level = log_queue_level
if self.log_queue_level is None:
self.log_queue_level = salt.log.setup.get_multiprocessing_logging_level()
else:
salt.log.setup.set_multiprocessing_logging_level(self.log_queue_level)
# Call __init__ from 'multiprocessing.Process' only after removing
# 'log_queue' and 'log_queue_level' from kwargs.
super(MultiprocessingProcess, self).__init__(*args, **kwargs)
self._after_fork_methods = [
(MultiprocessingProcess._setup_process_logging, [self], {}),
]
@ -737,10 +726,6 @@ class MultiprocessingProcess(multiprocessing.Process, NewStyleClassMixIn):
kwargs['log_queue'] = self.log_queue
if 'log_queue_level' not in kwargs:
kwargs['log_queue_level'] = self.log_queue_level
# Remove the version of these in the parent process since
# they are no longer needed.
del self._args_for_getstate
del self._kwargs_for_getstate
return {'args': args,
'kwargs': kwargs,
'_after_fork_methods': self._after_fork_methods,
@ -750,11 +735,11 @@ class MultiprocessingProcess(multiprocessing.Process, NewStyleClassMixIn):
def _setup_process_logging(self):
salt.log.setup.setup_multiprocessing_logging(self.log_queue)
def _run(self):
def run(self):
for method, args, kwargs in self._after_fork_methods:
method(*args, **kwargs)
try:
return self._original_run()
return super(MultiprocessingProcess, self).run()
except SystemExit:
# These are handled by multiprocessing.Process._bootstrap()
raise

View file

@ -172,6 +172,7 @@ def runas(cmdLine, username, password=None, cwd=None):
# Create the environment for the user
env = win32profile.CreateEnvironmentBlock(user_token, False)
hProcess = None
try:
# Start the process in a suspended state.
process_info = salt.platform.win.CreateProcessWithTokenW(
@ -216,7 +217,8 @@ def runas(cmdLine, username, password=None, cwd=None):
stderr = f_err.read()
ret['stderr'] = stderr
finally:
salt.platform.win.kernel32.CloseHandle(hProcess)
if hProcess is not None:
salt.platform.win.kernel32.CloseHandle(hProcess)
win32api.CloseHandle(th)
win32api.CloseHandle(user_token)
if impersonation_token:

View file

@ -327,7 +327,7 @@ class TestSignalHandlingMultiprocessingProcess(TestCase):
log_to_mock = 'salt.utils.process.MultiprocessingProcess._setup_process_logging'
with patch(sig_to_mock) as ma, patch(log_to_mock) as mb:
self.sh_proc = salt.utils.process.SignalHandlingMultiprocessingProcess(target=self.no_op_target)
self.sh_proc._run()
self.sh_proc.run()
ma.assert_called()
mb.assert_called()
@ -342,7 +342,7 @@ class TestSignalHandlingMultiprocessingProcess(TestCase):
with patch(sig_to_mock):
with patch(teardown_to_mock) as ma, patch(log_to_mock) as mb:
self.sh_proc = salt.utils.process.SignalHandlingMultiprocessingProcess(target=self.no_op_target)
self.sh_proc._run()
self.sh_proc.run()
ma.assert_called()
mb.assert_called()