Make sure the file client is destroyed upon used

Signed-off-by: Pedro Algarvio <palgarvio@vmware.com>
This commit is contained in:
Pedro Algarvio 2023-04-19 17:25:49 +01:00 committed by Pedro Algarvio
parent 1a63beb506
commit 25821ae58a
9 changed files with 250 additions and 243 deletions

View file

@ -9,6 +9,7 @@ import tarfile
import tempfile
from contextlib import closing
import salt.fileclient
import salt.utils.files
import salt.utils.json
import salt.utils.url
@ -28,65 +29,66 @@ def update_master_cache(states, saltenv="base"):
# Setup for copying states to gendir
gendir = tempfile.mkdtemp()
trans_tar = salt.utils.files.mkstemp()
if "cp.fileclient_{}".format(id(__opts__)) not in __context__:
__context__[
"cp.fileclient_{}".format(id(__opts__))
] = salt.fileclient.get_file_client(__opts__)
cp_fileclient_ctx_key = "cp.fileclient_{}".format(id(__opts__))
if cp_fileclient_ctx_key not in __context__:
__context__[cp_fileclient_ctx_key] = salt.fileclient.get_file_client(__opts__)
# generate cp.list_states output and save to gendir
cp_output = salt.utils.json.dumps(__salt__["cp.list_states"]())
cp_output_file = os.path.join(gendir, "cp_output.txt")
with salt.utils.files.fopen(cp_output_file, "w") as fp:
fp.write(cp_output)
with __context__[cp_fileclient_ctx_key] as cp_fileclient:
# cp state directories to gendir
already_processed = []
sls_list = salt.utils.args.split_input(states)
for state_name in sls_list:
# generate low data for each state and save to gendir
state_low_file = os.path.join(gendir, state_name + ".low")
state_low_output = salt.utils.json.dumps(
__salt__["state.show_low_sls"](state_name)
)
with salt.utils.files.fopen(state_low_file, "w") as fp:
fp.write(state_low_output)
# generate cp.list_states output and save to gendir
cp_output = salt.utils.json.dumps(__salt__["cp.list_states"]())
cp_output_file = os.path.join(gendir, "cp_output.txt")
with salt.utils.files.fopen(cp_output_file, "w") as fp:
fp.write(cp_output)
state_name = state_name.replace(".", os.sep)
if state_name in already_processed:
log.debug("Already cached state for %s", state_name)
else:
file_copy_file = os.path.join(gendir, state_name + ".copy")
log.debug("copying %s to %s", state_name, gendir)
qualified_name = salt.utils.url.create(state_name, saltenv)
# Duplicate cp.get_dir to gendir
copy_result = __context__["cp.fileclient_{}".format(id(__opts__))].get_dir(
qualified_name, gendir, saltenv
# cp state directories to gendir
already_processed = []
sls_list = salt.utils.args.split_input(states)
for state_name in sls_list:
# generate low data for each state and save to gendir
state_low_file = os.path.join(gendir, state_name + ".low")
state_low_output = salt.utils.json.dumps(
__salt__["state.show_low_sls"](state_name)
)
if copy_result:
copy_result = [dir.replace(gendir, state_cache) for dir in copy_result]
copy_result_output = salt.utils.json.dumps(copy_result)
with salt.utils.files.fopen(file_copy_file, "w") as fp:
fp.write(copy_result_output)
already_processed.append(state_name)
with salt.utils.files.fopen(state_low_file, "w") as fp:
fp.write(state_low_output)
state_name = state_name.replace(".", os.sep)
if state_name in already_processed:
log.debug("Already cached state for %s", state_name)
else:
# If files were not copied, assume state.file.sls was given and just copy state
state_name = os.path.dirname(state_name)
file_copy_file = os.path.join(gendir, state_name + ".copy")
if state_name in already_processed:
log.debug("Already cached state for %s", state_name)
log.debug("copying %s to %s", state_name, gendir)
qualified_name = salt.utils.url.create(state_name, saltenv)
# Duplicate cp.get_dir to gendir
copy_result = cp_fileclient.get_dir(qualified_name, gendir, saltenv)
if copy_result:
copy_result = [
dir.replace(gendir, state_cache) for dir in copy_result
]
copy_result_output = salt.utils.json.dumps(copy_result)
with salt.utils.files.fopen(file_copy_file, "w") as fp:
fp.write(copy_result_output)
already_processed.append(state_name)
else:
qualified_name = salt.utils.url.create(state_name, saltenv)
copy_result = __context__[
"cp.fileclient_{}".format(id(__opts__))
].get_dir(qualified_name, gendir, saltenv)
if copy_result:
copy_result = [
dir.replace(gendir, state_cache) for dir in copy_result
]
copy_result_output = salt.utils.json.dumps(copy_result)
with salt.utils.files.fopen(file_copy_file, "w") as fp:
fp.write(copy_result_output)
already_processed.append(state_name)
# If files were not copied, assume state.file.sls was given and just copy state
state_name = os.path.dirname(state_name)
file_copy_file = os.path.join(gendir, state_name + ".copy")
if state_name in already_processed:
log.debug("Already cached state for %s", state_name)
else:
qualified_name = salt.utils.url.create(state_name, saltenv)
copy_result = cp_fileclient.get_dir(
qualified_name, gendir, saltenv
)
if copy_result:
copy_result = [
dir.replace(gendir, state_cache) for dir in copy_result
]
copy_result_output = salt.utils.json.dumps(copy_result)
with salt.utils.files.fopen(file_copy_file, "w") as fp:
fp.write(copy_result_output)
already_processed.append(state_name)
# turn gendir into tarball and remove gendir
try:

View file

@ -849,7 +849,6 @@ class Client:
kwargs.pop("env")
kwargs["saltenv"] = saltenv
url_data = urllib.parse.urlparse(url)
sfn = self.cache_file(url, saltenv, cachedir=cachedir)
if not sfn or not os.path.exists(sfn):
return ""
@ -1165,13 +1164,8 @@ class RemoteClient(Client):
if not salt.utils.platform.is_windows():
hash_server, stat_server = self.hash_and_stat_file(path, saltenv)
try:
mode_server = stat_server[0]
except (IndexError, TypeError):
mode_server = None
else:
hash_server = self.hash_file(path, saltenv)
mode_server = None
# Check if file exists on server, before creating files and
# directories
@ -1214,13 +1208,8 @@ class RemoteClient(Client):
if dest2check and os.path.isfile(dest2check):
if not salt.utils.platform.is_windows():
hash_local, stat_local = self.hash_and_stat_file(dest2check, saltenv)
try:
mode_local = stat_local[0]
except (IndexError, TypeError):
mode_local = None
else:
hash_local = self.hash_file(dest2check, saltenv)
mode_local = None
if hash_local == hash_server:
return dest2check

View file

@ -6644,14 +6644,6 @@ def script_retcode(
)["retcode"]
def _mk_fileclient():
"""
Create a file client and add it to the context.
"""
if "cp.fileclient" not in __context__:
__context__["cp.fileclient"] = salt.fileclient.get_file_client(__opts__)
def _generate_tmp_path():
return os.path.join("/tmp", "salt.docker.{}".format(uuid.uuid4().hex[:6]))
@ -6665,11 +6657,10 @@ def _prepare_trans_tar(name, sls_opts, mods=None, pillar=None, extra_filerefs=""
# reuse it from salt.ssh, however this function should
# be somewhere else
refs = salt.client.ssh.state.lowstate_file_refs(chunks, extra_filerefs)
_mk_fileclient()
trans_tar = salt.client.ssh.state.prep_trans_tar(
__context__["cp.fileclient"], chunks, refs, pillar, name
)
return trans_tar
with salt.fileclient.get_file_client(__opts__) as fileclient:
return salt.client.ssh.state.prep_trans_tar(
fileclient, chunks, refs, pillar, name
)
def _compile_state(sls_opts, mods=None):

View file

@ -9,7 +9,6 @@ import logging
import os
import sys
import traceback
import uuid
import salt.channel.client
import salt.ext.tornado.gen
@ -1341,6 +1340,11 @@ class Pillar:
if self._closing:
return
self._closing = True
if self.client:
try:
self.client.destroy()
except AttributeError:
pass
# pylint: disable=W1701
def __del__(self):

View file

@ -32,12 +32,10 @@ state:
- state: installed
"""
import logging
import os
import sys
# Import salt modules
import salt.fileclient
import salt.utils.decorators.path
from salt.utils.decorators import depends
@ -108,13 +106,6 @@ def __virtual__():
return __virtualname__
def _client():
"""
Get a fileclient
"""
return salt.fileclient.get_file_client(__opts__)
def _changes(plays):
"""
Find changes in ansible return data
@ -171,7 +162,7 @@ def playbooks(name, rundir=None, git_repo=None, git_kwargs=None, ansible_kwargs=
}
if git_repo:
if not isinstance(rundir, str) or not os.path.isdir(rundir):
with _client() as client:
with salt.fileclient.get_file_client(__opts__) as client:
rundir = client._extrn_path(git_repo, "base")
log.trace("rundir set to %s", rundir)
if not isinstance(git_kwargs, dict):

View file

@ -131,7 +131,7 @@ class SyncWrapper:
result = io_loop.run_sync(lambda: getattr(self.obj, key)(*args, **kwargs))
results.append(True)
results.append(result)
except Exception as exc: # pylint: disable=broad-except
except Exception: # pylint: disable=broad-except
results.append(False)
results.append(sys.exc_info())

View file

@ -221,6 +221,24 @@ class SaltCacheLoader(BaseLoader):
# there is no template file within searchpaths
raise TemplateNotFound(template)
def destroy(self):
for attr in ("_cached_client", "_cached_pillar_client"):
client = getattr(self, attr, None)
if client is not None:
try:
client.destroy()
except AttributeError:
# PillarClient and LocalClient objects do not have a destroy method
pass
setattr(self, attr, None)
def __enter__(self):
self.file_client()
return self
def __exit__(self, *args):
self.destroy()
class PrintableDict(OrderedDict):
"""

View file

@ -97,3 +97,10 @@ if HAS_MAKO:
self.cache[fpath] = self.file_client().get_file(
fpath, "", True, self.saltenv
)
def destroy(self):
if self.client:
try:
self.client.destroy()
except AttributeError:
pass

View file

@ -362,163 +362,169 @@ def render_jinja_tmpl(tmplstr, context, tmplpath=None):
elif tmplstr.endswith("\n"):
newline = "\n"
if not saltenv:
if tmplpath:
loader = jinja2.FileSystemLoader(os.path.dirname(tmplpath))
else:
loader = salt.utils.jinja.SaltCacheLoader(
opts,
saltenv,
pillar_rend=context.get("_pillar_rend", False),
_file_client=file_client,
)
env_args = {"extensions": [], "loader": loader}
if hasattr(jinja2.ext, "with_"):
env_args["extensions"].append("jinja2.ext.with_")
if hasattr(jinja2.ext, "do"):
env_args["extensions"].append("jinja2.ext.do")
if hasattr(jinja2.ext, "loopcontrols"):
env_args["extensions"].append("jinja2.ext.loopcontrols")
env_args["extensions"].append(salt.utils.jinja.SerializerExtension)
opt_jinja_env = opts.get("jinja_env", {})
opt_jinja_sls_env = opts.get("jinja_sls_env", {})
opt_jinja_env = opt_jinja_env if isinstance(opt_jinja_env, dict) else {}
opt_jinja_sls_env = opt_jinja_sls_env if isinstance(opt_jinja_sls_env, dict) else {}
# Pass through trim_blocks and lstrip_blocks Jinja parameters
# trim_blocks removes newlines around Jinja blocks
# lstrip_blocks strips tabs and spaces from the beginning of
# line to the start of a block.
if opts.get("jinja_trim_blocks", False):
log.debug("Jinja2 trim_blocks is enabled")
log.warning(
"jinja_trim_blocks is deprecated and will be removed in a future release,"
" please use jinja_env and/or jinja_sls_env instead"
)
opt_jinja_env["trim_blocks"] = True
opt_jinja_sls_env["trim_blocks"] = True
if opts.get("jinja_lstrip_blocks", False):
log.debug("Jinja2 lstrip_blocks is enabled")
log.warning(
"jinja_lstrip_blocks is deprecated and will be removed in a future release,"
" please use jinja_env and/or jinja_sls_env instead"
)
opt_jinja_env["lstrip_blocks"] = True
opt_jinja_sls_env["lstrip_blocks"] = True
def opt_jinja_env_helper(opts, optname):
for k, v in opts.items():
k = k.lower()
if hasattr(jinja2.defaults, k.upper()):
log.debug("Jinja2 environment %s was set to %s by %s", k, v, optname)
env_args[k] = v
else:
log.warning("Jinja2 environment %s is not recognized", k)
if "sls" in context and context["sls"] != "":
opt_jinja_env_helper(opt_jinja_sls_env, "jinja_sls_env")
else:
opt_jinja_env_helper(opt_jinja_env, "jinja_env")
if opts.get("allow_undefined", False):
jinja_env = jinja2.sandbox.SandboxedEnvironment(**env_args)
else:
jinja_env = jinja2.sandbox.SandboxedEnvironment(
undefined=jinja2.StrictUndefined, **env_args
)
indent_filter = jinja_env.filters.get("indent")
jinja_env.tests.update(JinjaTest.salt_jinja_tests)
jinja_env.filters.update(JinjaFilter.salt_jinja_filters)
if salt.utils.jinja.JINJA_VERSION >= Version("2.11"):
# Use the existing indent filter on Jinja versions where it's not broken
jinja_env.filters["indent"] = indent_filter
jinja_env.globals.update(JinjaGlobal.salt_jinja_globals)
# globals
jinja_env.globals["odict"] = OrderedDict
jinja_env.globals["show_full_context"] = salt.utils.jinja.show_full_context
jinja_env.tests["list"] = salt.utils.data.is_list
decoded_context = {}
for key, value in context.items():
if not isinstance(value, str):
if isinstance(value, NamedLoaderContext):
decoded_context[key] = value.value()
else:
decoded_context[key] = value
continue
try:
decoded_context[key] = salt.utils.stringutils.to_unicode(
value, encoding=SLS_ENCODING
)
except UnicodeDecodeError as ex:
log.debug(
"Failed to decode using default encoding (%s), trying system encoding",
SLS_ENCODING,
)
decoded_context[key] = salt.utils.data.decode(value)
jinja_env.globals.update(decoded_context)
try:
template = jinja_env.from_string(tmplstr)
output = template.render(**decoded_context)
except jinja2.exceptions.UndefinedError as exc:
trace = traceback.extract_tb(sys.exc_info()[2])
line, out = _get_jinja_error(trace, context=decoded_context)
if not line:
tmplstr = ""
raise SaltRenderError("Jinja variable {}{}".format(exc, out), line, tmplstr)
except (
jinja2.exceptions.TemplateRuntimeError,
jinja2.exceptions.TemplateSyntaxError,
jinja2.exceptions.SecurityError,
) as exc:
trace = traceback.extract_tb(sys.exc_info()[2])
line, out = _get_jinja_error(trace, context=decoded_context)
if not line:
tmplstr = ""
raise SaltRenderError(
"Jinja syntax error: {}{}".format(exc, out), line, tmplstr
)
except (SaltInvocationError, CommandExecutionError) as exc:
trace = traceback.extract_tb(sys.exc_info()[2])
line, out = _get_jinja_error(trace, context=decoded_context)
if not line:
tmplstr = ""
raise SaltRenderError(
"Problem running salt function in Jinja template: {}{}".format(exc, out),
line,
tmplstr,
)
except Exception as exc: # pylint: disable=broad-except
tracestr = traceback.format_exc()
trace = traceback.extract_tb(sys.exc_info()[2])
line, out = _get_jinja_error(trace, context=decoded_context)
if not line:
tmplstr = ""
if not saltenv:
if tmplpath:
loader = jinja2.FileSystemLoader(os.path.dirname(tmplpath))
else:
tmplstr += "\n{}".format(tracestr)
log.debug("Jinja Error")
log.debug("Exception:", exc_info=True)
log.debug("Out: %s", out)
log.debug("Line: %s", line)
log.debug("TmplStr: %s", tmplstr)
log.debug("TraceStr: %s", tracestr)
loader = salt.utils.jinja.SaltCacheLoader(
opts,
saltenv,
pillar_rend=context.get("_pillar_rend", False),
_file_client=file_client,
)
raise SaltRenderError(
"Jinja error: {}{}".format(exc, out), line, tmplstr, trace=tracestr
env_args = {"extensions": [], "loader": loader}
if hasattr(jinja2.ext, "with_"):
env_args["extensions"].append("jinja2.ext.with_")
if hasattr(jinja2.ext, "do"):
env_args["extensions"].append("jinja2.ext.do")
if hasattr(jinja2.ext, "loopcontrols"):
env_args["extensions"].append("jinja2.ext.loopcontrols")
env_args["extensions"].append(salt.utils.jinja.SerializerExtension)
opt_jinja_env = opts.get("jinja_env", {})
opt_jinja_sls_env = opts.get("jinja_sls_env", {})
opt_jinja_env = opt_jinja_env if isinstance(opt_jinja_env, dict) else {}
opt_jinja_sls_env = (
opt_jinja_sls_env if isinstance(opt_jinja_sls_env, dict) else {}
)
# Pass through trim_blocks and lstrip_blocks Jinja parameters
# trim_blocks removes newlines around Jinja blocks
# lstrip_blocks strips tabs and spaces from the beginning of
# line to the start of a block.
if opts.get("jinja_trim_blocks", False):
log.debug("Jinja2 trim_blocks is enabled")
log.warning(
"jinja_trim_blocks is deprecated and will be removed in a future release,"
" please use jinja_env and/or jinja_sls_env instead"
)
opt_jinja_env["trim_blocks"] = True
opt_jinja_sls_env["trim_blocks"] = True
if opts.get("jinja_lstrip_blocks", False):
log.debug("Jinja2 lstrip_blocks is enabled")
log.warning(
"jinja_lstrip_blocks is deprecated and will be removed in a future release,"
" please use jinja_env and/or jinja_sls_env instead"
)
opt_jinja_env["lstrip_blocks"] = True
opt_jinja_sls_env["lstrip_blocks"] = True
def opt_jinja_env_helper(opts, optname):
for k, v in opts.items():
k = k.lower()
if hasattr(jinja2.defaults, k.upper()):
log.debug(
"Jinja2 environment %s was set to %s by %s", k, v, optname
)
env_args[k] = v
else:
log.warning("Jinja2 environment %s is not recognized", k)
if "sls" in context and context["sls"] != "":
opt_jinja_env_helper(opt_jinja_sls_env, "jinja_sls_env")
else:
opt_jinja_env_helper(opt_jinja_env, "jinja_env")
if opts.get("allow_undefined", False):
jinja_env = jinja2.sandbox.SandboxedEnvironment(**env_args)
else:
jinja_env = jinja2.sandbox.SandboxedEnvironment(
undefined=jinja2.StrictUndefined, **env_args
)
indent_filter = jinja_env.filters.get("indent")
jinja_env.tests.update(JinjaTest.salt_jinja_tests)
jinja_env.filters.update(JinjaFilter.salt_jinja_filters)
if salt.utils.jinja.JINJA_VERSION >= Version("2.11"):
# Use the existing indent filter on Jinja versions where it's not broken
jinja_env.filters["indent"] = indent_filter
jinja_env.globals.update(JinjaGlobal.salt_jinja_globals)
# globals
jinja_env.globals["odict"] = OrderedDict
jinja_env.globals["show_full_context"] = salt.utils.jinja.show_full_context
jinja_env.tests["list"] = salt.utils.data.is_list
decoded_context = {}
for key, value in context.items():
if not isinstance(value, str):
if isinstance(value, NamedLoaderContext):
decoded_context[key] = value.value()
else:
decoded_context[key] = value
continue
try:
decoded_context[key] = salt.utils.stringutils.to_unicode(
value, encoding=SLS_ENCODING
)
except UnicodeDecodeError:
log.debug(
"Failed to decode using default encoding (%s), trying system encoding",
SLS_ENCODING,
)
decoded_context[key] = salt.utils.data.decode(value)
jinja_env.globals.update(decoded_context)
try:
template = jinja_env.from_string(tmplstr)
output = template.render(**decoded_context)
except jinja2.exceptions.UndefinedError as exc:
trace = traceback.extract_tb(sys.exc_info()[2])
line, out = _get_jinja_error(trace, context=decoded_context)
if not line:
tmplstr = ""
raise SaltRenderError("Jinja variable {}{}".format(exc, out), line, tmplstr)
except (
jinja2.exceptions.TemplateRuntimeError,
jinja2.exceptions.TemplateSyntaxError,
jinja2.exceptions.SecurityError,
) as exc:
trace = traceback.extract_tb(sys.exc_info()[2])
line, out = _get_jinja_error(trace, context=decoded_context)
if not line:
tmplstr = ""
raise SaltRenderError(
"Jinja syntax error: {}{}".format(exc, out), line, tmplstr
)
except (SaltInvocationError, CommandExecutionError) as exc:
trace = traceback.extract_tb(sys.exc_info()[2])
line, out = _get_jinja_error(trace, context=decoded_context)
if not line:
tmplstr = ""
raise SaltRenderError(
"Problem running salt function in Jinja template: {}{}".format(
exc, out
),
line,
tmplstr,
)
except Exception as exc: # pylint: disable=broad-except
tracestr = traceback.format_exc()
trace = traceback.extract_tb(sys.exc_info()[2])
line, out = _get_jinja_error(trace, context=decoded_context)
if not line:
tmplstr = ""
else:
tmplstr += "\n{}".format(tracestr)
log.debug("Jinja Error")
log.debug("Exception:", exc_info=True)
log.debug("Out: %s", out)
log.debug("Line: %s", line)
log.debug("TmplStr: %s", tmplstr)
log.debug("TraceStr: %s", tracestr)
raise SaltRenderError(
"Jinja error: {}{}".format(exc, out), line, tmplstr, trace=tracestr
)
finally:
if loader and hasattr(loader, "_file_client"):
if hasattr(loader._file_client, "destroy"):
loader._file_client.destroy()
if loader and isinstance(loader, salt.utils.jinja.SaltCacheLoader):
loader.destroy()
# Workaround a bug in Jinja that removes the final newline
# (https://github.com/mitsuhiko/jinja2/issues/75)
@ -569,9 +575,8 @@ def render_mako_tmpl(tmplstr, context, tmplpath=None):
except Exception: # pylint: disable=broad-except
raise SaltRenderError(mako.exceptions.text_error_template().render())
finally:
if lookup and hasattr(lookup, "_file_client"):
if hasattr(lookup._file_client, "destroy"):
lookup._file_client.destroy()
if lookup and isinstance(lookup, SaltMakoTemplateLookup):
lookup.destroy()
def render_wempy_tmpl(tmplstr, context, tmplpath=None):