Wrap the file client in a context manager

Megan Wilhite 2023-03-28 12:58:13 -06:00
parent 7ff9a899d9
commit 17dcb42ece
11 changed files with 189 additions and 181 deletions
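Every file in this commit applies the same pattern: instead of building a salt.fileclient client once and caching it in __context__, each call site now creates the client and drives it inside a with block, so the client's transport is released as soon as the block exits, even when an exception is raised. A minimal sketch of the call-site pattern (the cache_one() helper is only illustrative, and it assumes the client returned by salt.fileclient.get_file_client() performs its cleanup in __exit__, for example via its destroy() hook):

import salt.fileclient

def cache_one(opts, path, saltenv="base"):
    # Build a fresh file client for these opts; the ``with`` block releases
    # its connection when the call finishes, whether or not it succeeded.
    with salt.fileclient.get_file_client(opts) as client:
        return client.cache_file(path, saltenv)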

View file

@@ -223,36 +223,37 @@ def prep_trans_tar(
         env_root = os.path.join(gendir, saltenv)
         if not os.path.isdir(env_root):
             os.makedirs(env_root)
-        for ref in file_refs[saltenv]:
-            for name in ref:
-                short = salt.utils.url.parse(name)[0].lstrip("/")
-                cache_dest = os.path.join(cache_dest_root, short)
-                try:
-                    path = file_client.cache_file(name, saltenv, cachedir=cachedir)
-                except OSError:
-                    path = ""
-                if path:
-                    tgt = os.path.join(env_root, short)
-                    tgt_dir = os.path.dirname(tgt)
-                    if not os.path.isdir(tgt_dir):
-                        os.makedirs(tgt_dir)
-                    shutil.copy(path, tgt)
-                    continue
-                try:
-                    files = file_client.cache_dir(name, saltenv, cachedir=cachedir)
-                except OSError:
-                    files = ""
-                if files:
-                    for filename in files:
-                        fn = filename[
-                            len(file_client.get_cachedir(cache_dest)) :
-                        ].strip("/")
-                        tgt = os.path.join(env_root, short, fn)
-                        tgt_dir = os.path.dirname(tgt)
-                        if not os.path.isdir(tgt_dir):
-                            os.makedirs(tgt_dir)
-                        shutil.copy(filename, tgt)
-                        continue
+        with file_client:
+            for ref in file_refs[saltenv]:
+                for name in ref:
+                    short = salt.utils.url.parse(name)[0].lstrip("/")
+                    cache_dest = os.path.join(cache_dest_root, short)
+                    try:
+                        path = file_client.cache_file(name, saltenv, cachedir=cachedir)
+                    except OSError:
+                        path = ""
+                    if path:
+                        tgt = os.path.join(env_root, short)
+                        tgt_dir = os.path.dirname(tgt)
+                        if not os.path.isdir(tgt_dir):
+                            os.makedirs(tgt_dir)
+                        shutil.copy(path, tgt)
+                        continue
+                    try:
+                        files = file_client.cache_dir(name, saltenv, cachedir=cachedir)
+                    except OSError:
+                        files = ""
+                    if files:
+                        for filename in files:
+                            fn = filename[
+                                len(file_client.get_cachedir(cache_dest)) :
+                            ].strip("/")
+                            tgt = os.path.join(env_root, short, fn)
+                            tgt_dir = os.path.dirname(tgt)
+                            if not os.path.isdir(tgt_dir):
+                                os.makedirs(tgt_dir)
+                            shutil.copy(filename, tgt)
+                            continue
     try:
         # cwd may not exist if it was removed but salt was run from it
         cwd = os.getcwd()

View file

@@ -28,10 +28,7 @@ def update_master_cache(states, saltenv="base"):
# Setup for copying states to gendir
gendir = tempfile.mkdtemp()
trans_tar = salt.utils.files.mkstemp()
if "cp.fileclient_{}".format(id(__opts__)) not in __context__:
__context__[
"cp.fileclient_{}".format(id(__opts__))
] = salt.fileclient.get_file_client(__opts__)
client = salt.fileclient.get_file_client(__opts__)
# generate cp.list_states output and save to gendir
cp_output = salt.utils.json.dumps(__salt__["cp.list_states"]())
@@ -59,34 +56,33 @@ def update_master_cache(states, saltenv="base"):
             log.debug("copying %s to %s", state_name, gendir)
             qualified_name = salt.utils.url.create(state_name, saltenv)
             # Duplicate cp.get_dir to gendir
-            copy_result = __context__["cp.fileclient_{}".format(id(__opts__))].get_dir(
-                qualified_name, gendir, saltenv
-            )
-            if copy_result:
-                copy_result = [dir.replace(gendir, state_cache) for dir in copy_result]
-                copy_result_output = salt.utils.json.dumps(copy_result)
-                with salt.utils.files.fopen(file_copy_file, "w") as fp:
-                    fp.write(copy_result_output)
-                already_processed.append(state_name)
-            else:
-                # If files were not copied, assume state.file.sls was given and just copy state
-                state_name = os.path.dirname(state_name)
-                file_copy_file = os.path.join(gendir, state_name + ".copy")
-                if state_name in already_processed:
-                    log.debug("Already cached state for %s", state_name)
-                else:
-                    qualified_name = salt.utils.url.create(state_name, saltenv)
-                    copy_result = __context__[
-                        "cp.fileclient_{}".format(id(__opts__))
-                    ].get_dir(qualified_name, gendir, saltenv)
-                    if copy_result:
-                        copy_result = [
-                            dir.replace(gendir, state_cache) for dir in copy_result
-                        ]
-                        copy_result_output = salt.utils.json.dumps(copy_result)
-                        with salt.utils.files.fopen(file_copy_file, "w") as fp:
-                            fp.write(copy_result_output)
-                        already_processed.append(state_name)
+            with client:
+                copy_result = client.get_dir(qualified_name, gendir, saltenv)
+                if copy_result:
+                    copy_result = [
+                        dir.replace(gendir, state_cache) for dir in copy_result
+                    ]
+                    copy_result_output = salt.utils.json.dumps(copy_result)
+                    with salt.utils.files.fopen(file_copy_file, "w") as fp:
+                        fp.write(copy_result_output)
+                    already_processed.append(state_name)
+                else:
+                    # If files were not copied, assume state.file.sls was given and just copy state
+                    state_name = os.path.dirname(state_name)
+                    file_copy_file = os.path.join(gendir, state_name + ".copy")
+                    if state_name in already_processed:
+                        log.debug("Already cached state for %s", state_name)
+                    else:
+                        qualified_name = salt.utils.url.create(state_name, saltenv)
+                        copy_result = client.get_dir(qualified_name, gendir, saltenv)
+                        if copy_result:
+                            copy_result = [
+                                dir.replace(gendir, state_cache) for dir in copy_result
+                            ]
+                            copy_result_output = salt.utils.json.dumps(copy_result)
+                            with salt.utils.files.fopen(file_copy_file, "w") as fp:
+                                fp.write(copy_result_output)
+                            already_processed.append(state_name)
     # turn gendir into tarball and remove gendir
     try:

View file

@@ -321,30 +321,31 @@ def sls(root, mods, saltenv="base", test=None, exclude=None, **kwargs):
     if isinstance(mods, str):
         mods = mods.split(",")
-    high_data, errors = st_.render_highstate({saltenv: mods})
-    if exclude:
-        if isinstance(exclude, str):
-            exclude = exclude.split(",")
-        if "__exclude__" in high_data:
-            high_data["__exclude__"].extend(exclude)
-        else:
-            high_data["__exclude__"] = exclude
-    high_data, ext_errors = st_.state.reconcile_extend(high_data)
-    errors += ext_errors
-    errors += st_.state.verify_high(high_data)
-    if errors:
-        return errors
-    high_data, req_in_errors = st_.state.requisite_in(high_data)
-    errors += req_in_errors
-    if errors:
-        return errors
-    high_data = st_.state.apply_exclude(high_data)
-    # Compile and verify the raw chunks
-    chunks = st_.state.compile_high_data(high_data)
+    with st_:
+        high_data, errors = st_.render_highstate({saltenv: mods})
+        if exclude:
+            if isinstance(exclude, str):
+                exclude = exclude.split(",")
+            if "__exclude__" in high_data:
+                high_data["__exclude__"].extend(exclude)
+            else:
+                high_data["__exclude__"] = exclude
+        high_data, ext_errors = st_.state.reconcile_extend(high_data)
+        errors += ext_errors
+        errors += st_.state.verify_high(high_data)
+        if errors:
+            return errors
+        high_data, req_in_errors = st_.state.requisite_in(high_data)
+        errors += req_in_errors
+        if errors:
+            return errors
+        high_data = st_.state.apply_exclude(high_data)
+        # Compile and verify the raw chunks
+        chunks = st_.state.compile_high_data(high_data)
     file_refs = salt.client.ssh.state.lowstate_file_refs(
         chunks,
         salt.client.ssh.wrapper.state._merge_extra_filerefs(

View file

@@ -157,26 +157,11 @@ def recv_chunked(dest, chunk, append=False, compressed=True, mode=None):
pass
def _mk_client():
"""
Create a file client and add it to the context.
Each file client needs to correspond to a unique copy
of the opts dictionary, therefore it's hashed by the
id of the __opts__ dict
"""
if "cp.fileclient_{}".format(id(__opts__)) not in __context__:
__context__[
"cp.fileclient_{}".format(id(__opts__))
] = salt.fileclient.get_file_client(__opts__)
def _client():
"""
Return a client, hashed by the list of masters
"""
_mk_client()
return __context__["cp.fileclient_{}".format(id(__opts__))]
return salt.fileclient.get_file_client(__opts__)
def _render_filenames(path, dest, saltenv, template, **kw):
@@ -294,7 +279,8 @@ def get_file(
if not hash_file(path, saltenv):
return ""
else:
return _client().get_file(path, dest, makedirs, saltenv, gzip)
with _client() as client:
return client.get_file(path, dest, makedirs, saltenv, gzip)
def envs():
@@ -307,7 +293,8 @@ def envs():
salt '*' cp.envs
"""
return _client().envs()
with _client() as client:
return client.envs()
def get_template(path, dest, template="jinja", saltenv=None, makedirs=False, **kwargs):
@@ -336,7 +323,8 @@ def get_template(path, dest, template="jinja", saltenv=None, makedirs=False, **k
kwargs["grains"] = __grains__
if "opts" not in kwargs:
kwargs["opts"] = __opts__
return _client().get_template(path, dest, template, makedirs, saltenv, **kwargs)
with _client() as client:
return client.get_template(path, dest, template, makedirs, saltenv, **kwargs)
def get_dir(path, dest, saltenv=None, template=None, gzip=None, **kwargs):
@@ -359,7 +347,8 @@ def get_dir(path, dest, saltenv=None, template=None, gzip=None, **kwargs):
(path, dest) = _render_filenames(path, dest, saltenv, template, **kwargs)
return _client().get_dir(path, dest, saltenv, gzip)
with _client() as client:
return client.get_dir(path, dest, saltenv, gzip)
def get_url(path, dest="", saltenv=None, makedirs=False, source_hash=None):
@@ -417,13 +406,16 @@ def get_url(path, dest="", saltenv=None, makedirs=False, source_hash=None):
saltenv = __opts__["saltenv"] or "base"
if isinstance(dest, str):
result = _client().get_url(
path, dest, makedirs, saltenv, source_hash=source_hash
)
with _client() as client:
result = client.get_url(
path, dest, makedirs, saltenv, source_hash=source_hash
)
else:
result = _client().get_url(
path, None, makedirs, saltenv, no_cache=True, source_hash=source_hash
)
with _client() as client:
result = client.get_url(
path, None, makedirs, saltenv, no_cache=True, source_hash=source_hash
)
if not result:
log.error(
"Unable to fetch file %s from saltenv %s.",
@@ -550,9 +542,14 @@ def cache_file(path, saltenv=None, source_hash=None, verify_ssl=True, use_etag=F
if senv:
saltenv = senv
result = _client().cache_file(
path, saltenv, source_hash=source_hash, verify_ssl=verify_ssl, use_etag=use_etag
)
with _client() as client:
result = client.cache_file(
path,
saltenv,
source_hash=source_hash,
verify_ssl=verify_ssl,
use_etag=use_etag,
)
if not result and not use_etag:
log.error("Unable to cache file '%s' from saltenv '%s'.", path, saltenv)
if path_is_remote:
@@ -587,7 +584,8 @@ def cache_dest(url, saltenv=None):
"""
if not saltenv:
saltenv = __opts__["saltenv"] or "base"
return _client().cache_dest(url, saltenv)
with _client() as client:
return client.cache_dest(url, saltenv)
def cache_files(paths, saltenv=None):
@@ -631,7 +629,8 @@ def cache_files(paths, saltenv=None):
"""
if not saltenv:
saltenv = __opts__["saltenv"] or "base"
return _client().cache_files(paths, saltenv)
with _client() as client:
return client.cache_files(paths, saltenv)
def cache_dir(
@@ -672,7 +671,8 @@ def cache_dir(
"""
if not saltenv:
saltenv = __opts__["saltenv"] or "base"
return _client().cache_dir(path, saltenv, include_empty, include_pat, exclude_pat)
with _client() as client:
return client.cache_dir(path, saltenv, include_empty, include_pat, exclude_pat)
def cache_master(saltenv=None):
@@ -690,7 +690,8 @@ def cache_master(saltenv=None):
"""
if not saltenv:
saltenv = __opts__["saltenv"] or "base"
return _client().cache_master(saltenv)
with _client() as client:
return client.cache_master(saltenv)
def cache_local_file(path):
@@ -717,7 +718,8 @@ def cache_local_file(path):
return path_cached
# The file hasn't been cached or has changed; cache it
return _client().cache_local_file(path)
with _client() as client:
return client.cache_local_file(path)
def list_states(saltenv=None):
@@ -735,7 +737,8 @@ def list_states(saltenv=None):
"""
if not saltenv:
saltenv = __opts__["saltenv"] or "base"
return _client().list_states(saltenv)
with _client() as client:
return client.list_states(saltenv)
def list_master(saltenv=None, prefix=""):
@@ -753,7 +756,8 @@ def list_master(saltenv=None, prefix=""):
"""
if not saltenv:
saltenv = __opts__["saltenv"] or "base"
return _client().file_list(saltenv, prefix)
with _client() as client:
return client.file_list(saltenv, prefix)
def list_master_dirs(saltenv=None, prefix=""):
@@ -771,7 +775,8 @@ def list_master_dirs(saltenv=None, prefix=""):
"""
if not saltenv:
saltenv = __opts__["saltenv"] or "base"
return _client().dir_list(saltenv, prefix)
with _client() as client:
return client.dir_list(saltenv, prefix)
def list_master_symlinks(saltenv=None, prefix=""):
@@ -789,7 +794,8 @@ def list_master_symlinks(saltenv=None, prefix=""):
"""
if not saltenv:
saltenv = __opts__["saltenv"] or "base"
return _client().symlink_list(saltenv, prefix)
with _client() as client:
return client.symlink_list(saltenv, prefix)
def list_minion(saltenv=None):
@@ -807,7 +813,8 @@ def list_minion(saltenv=None):
"""
if not saltenv:
saltenv = __opts__["saltenv"] or "base"
return _client().file_local_list(saltenv)
with _client() as client:
return client.file_local_list(saltenv)
def is_cached(path, saltenv=None):
@@ -831,7 +838,8 @@ def is_cached(path, saltenv=None):
if senv:
saltenv = senv
return _client().is_cached(path, saltenv)
with _client() as client:
return client.is_cached(path, saltenv)
def hash_file(path, saltenv=None):
@@ -856,7 +864,8 @@ def hash_file(path, saltenv=None):
if senv:
saltenv = senv
return _client().hash_file(path, saltenv)
with _client() as client:
return client.hash_file(path, saltenv)
def stat_file(path, saltenv=None, octal=True):
@@ -881,7 +890,8 @@ def stat_file(path, saltenv=None, octal=True):
if senv:
saltenv = senv
stat = _client().hash_and_stat_file(path, saltenv)[1]
with _client() as client:
stat = client.hash_and_stat_file(path, saltenv)[1]
if stat is None:
return stat
return salt.utils.files.st_mode_to_octal(stat[0]) if octal is True else stat[0]

View file

@@ -25,8 +25,7 @@ def _mk_client():
"""
Create a file client and add it to the context
"""
if "cp.fileclient" not in __context__:
__context__["cp.fileclient"] = salt.fileclient.get_file_client(__opts__)
return salt.fileclient.get_file_client(__opts__)
def _load(formula):
@@ -44,7 +43,8 @@ def _load(formula):
source_url = salt.utils.url.create(formula + "/defaults." + ext)
paths.append(source_url)
# Fetch files from master
defaults_files = __context__["cp.fileclient"].cache_files(paths)
with _mk_client() as client:
defaults_files = client.cache_files(paths)
for file_ in defaults_files:
if not file_:

View file

@@ -6648,8 +6648,7 @@ def _mk_fileclient():
"""
Create a file client and add it to the context.
"""
if "cp.fileclient" not in __context__:
__context__["cp.fileclient"] = salt.fileclient.get_file_client(__opts__)
return salt.fileclient.get_file_client(__opts__)
def _generate_tmp_path():
@@ -6665,9 +6664,8 @@ def _prepare_trans_tar(name, sls_opts, mods=None, pillar=None, extra_filerefs=""
# reuse it from salt.ssh, however this function should
# be somewhere else
refs = salt.client.ssh.state.lowstate_file_refs(chunks, extra_filerefs)
_mk_fileclient()
trans_tar = salt.client.ssh.state.prep_trans_tar(
__context__["cp.fileclient"], chunks, refs, pillar, name
_mk_fileclient(), chunks, refs, pillar, name
)
return trans_tar

View file

@@ -461,15 +461,16 @@ def render(template, saltenv="base", sls="", salt_data=True, **kwargs):
                 # that we're importing everything
                 imports = None
-                state_file = client.cache_file(import_file, saltenv)
-                if not state_file:
-                    raise ImportError(
-                        "Could not find the file '{}'".format(import_file)
-                    )
-                with salt.utils.files.fopen(state_file) as state_fh:
-                    state_contents, state_globals = process_template(state_fh)
-                exec(state_contents, state_globals)
+                with client:
+                    state_file = client.cache_file(import_file, saltenv)
+                    if not state_file:
+                        raise ImportError(
+                            "Could not find the file '{}'".format(import_file)
+                        )
+                    with salt.utils.files.fopen(state_file) as state_fh:
+                        state_contents, state_globals = process_template(state_fh)
+                    exec(state_contents, state_globals)
                 # if no imports have been specified then we are being imported as: import salt://foo.sls
                 # so we want to stick all of the locals from our state file into the template globals

View file

@@ -171,7 +171,8 @@ def playbooks(name, rundir=None, git_repo=None, git_kwargs=None, ansible_kwargs=
}
if git_repo:
if not isinstance(rundir, str) or not os.path.isdir(rundir):
rundir = _client()._extrn_path(git_repo, "base")
with _client() as client:
rundir = client._extrn_path(git_repo, "base")
log.trace("rundir set to %s", rundir)
if not isinstance(git_kwargs, dict):
log.debug("Setting git_kwargs to empty dict: %s", git_kwargs)

View file

@@ -76,12 +76,12 @@ def sync(opts, form, saltenv=None, extmod_whitelist=None, extmod_blacklist=None)
                 mod_dir,
             )
     fileclient = salt.fileclient.get_file_client(opts)
-    for sub_env in saltenv:
-        log.info("Syncing %s for environment '%s'", form, sub_env)
-        cache = []
-        log.info("Loading cache from %s, for %s", source, sub_env)
-        # Grab only the desired files (.py, .pyx, .so)
-        with fileclient:
+    with fileclient:
+        for sub_env in saltenv:
+            log.info("Syncing %s for environment '%s'", form, sub_env)
+            cache = []
+            log.info("Loading cache from %s, for %s", source, sub_env)
+            # Grab only the desired files (.py, .pyx, .so)
             cache.extend(
                 fileclient.cache_dir(
                     source,
@@ -91,43 +91,43 @@ def sync(opts, form, saltenv=None, extmod_whitelist=None, extmod_blacklist=None)
                     exclude_pat=None,
                 )
             )
-        local_cache_dir = os.path.join(
-            opts["cachedir"], "files", sub_env, "_{}".format(form)
-        )
-        log.debug("Local cache dir: '%s'", local_cache_dir)
-        for fn_ in cache:
-            relpath = os.path.relpath(fn_, local_cache_dir)
-            relname = os.path.splitext(relpath)[0].replace(os.sep, ".")
-            if (
-                extmod_whitelist
-                and form in extmod_whitelist
-                and relname not in extmod_whitelist[form]
-            ):
-                continue
-            if (
-                extmod_blacklist
-                and form in extmod_blacklist
-                and relname in extmod_blacklist[form]
-            ):
-                continue
-            remote.add(relpath)
-            dest = os.path.join(mod_dir, relpath)
-            log.info("Copying '%s' to '%s'", fn_, dest)
-            if os.path.isfile(dest):
-                # The file is present, if the sum differs replace it
-                hash_type = opts.get("hash_type", "md5")
-                src_digest = salt.utils.hashutils.get_hash(fn_, hash_type)
-                dst_digest = salt.utils.hashutils.get_hash(dest, hash_type)
-                if src_digest != dst_digest:
-                    # The downloaded file differs, replace!
-                    shutil.copyfile(fn_, dest)
-                    ret.append("{}.{}".format(form, relname))
-            else:
-                dest_dir = os.path.dirname(dest)
-                if not os.path.isdir(dest_dir):
-                    os.makedirs(dest_dir)
-                shutil.copyfile(fn_, dest)
-                ret.append("{}.{}".format(form, relname))
+            local_cache_dir = os.path.join(
+                opts["cachedir"], "files", sub_env, "_{}".format(form)
+            )
+            log.debug("Local cache dir: '%s'", local_cache_dir)
+            for fn_ in cache:
+                relpath = os.path.relpath(fn_, local_cache_dir)
+                relname = os.path.splitext(relpath)[0].replace(os.sep, ".")
+                if (
+                    extmod_whitelist
+                    and form in extmod_whitelist
+                    and relname not in extmod_whitelist[form]
+                ):
+                    continue
+                if (
+                    extmod_blacklist
+                    and form in extmod_blacklist
+                    and relname in extmod_blacklist[form]
+                ):
+                    continue
+                remote.add(relpath)
+                dest = os.path.join(mod_dir, relpath)
+                log.info("Copying '%s' to '%s'", fn_, dest)
+                if os.path.isfile(dest):
+                    # The file is present, if the sum differs replace it
+                    hash_type = opts.get("hash_type", "md5")
+                    src_digest = salt.utils.hashutils.get_hash(fn_, hash_type)
+                    dst_digest = salt.utils.hashutils.get_hash(dest, hash_type)
+                    if src_digest != dst_digest:
+                        # The downloaded file differs, replace!
+                        shutil.copyfile(fn_, dest)
+                        ret.append("{}.{}".format(form, relname))
+                else:
+                    dest_dir = os.path.dirname(dest)
+                    if not os.path.isdir(dest_dir):
+                        os.makedirs(dest_dir)
+                    shutil.copyfile(fn_, dest)
+                    ret.append("{}.{}".format(form, relname))
     touched = bool(ret)
     if opts["clean_dynamic_modules"] is True:

View file

@@ -128,7 +128,8 @@ class SaltCacheLoader(BaseLoader):
"""
saltpath = salt.utils.url.create(template)
fcl = self.file_client()
return fcl.get_file(saltpath, "", True, self.saltenv)
with fcl:
return fcl.get_file(saltpath, "", True, self.saltenv)
def check_cache(self, template):
"""

View file

@@ -94,6 +94,5 @@ if HAS_MAKO:
def cache_file(self, fpath):
if fpath not in self.cache:
self.cache[fpath] = self.file_client().get_file(
fpath, "", True, self.saltenv
)
with self.file_client() as client:
self.cache[fpath] = client.get_file(fpath, "", True, self.saltenv)
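For reference, the context-manager protocol that every with block above relies on looks roughly like this. This is a toy illustration, not Salt's implementation; the channel attribute and the destroy() cleanup name are assumptions standing in for whatever transport the real file client holds.

class ToyFileClient:
    def __init__(self, opts):
        self.opts = opts
        self.channel = object()  # stand-in for the client's transport/request channel

    def __enter__(self):
        # Hand the client itself to the ``with`` block.
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Runs on normal exit and on exceptions alike, so cleanup always happens.
        self.destroy()

    def destroy(self):
        self.channel = None  # drop the connection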