Allow recursive salt:// imports

Fixes #30465
This commit is contained in:
Evan Borgstrom 2016-01-20 11:20:50 -08:00
parent 51bfa16173
commit b46df0e4b5
2 changed files with 84 additions and 48 deletions

View file

@ -154,9 +154,6 @@ See the Map Data section for a more practical use.
Caveats:
* You cannot use the ``as`` syntax, you can only import objects using their
existing name.
* Imported objects are ALWAYS put into the global scope of your template,
regardless of where your import statement is.
@ -262,6 +259,7 @@ TODO
from __future__ import absolute_import
import logging
import os
import re
from salt.ext.six import exec_
@ -272,9 +270,9 @@ from salt.fileclient import get_file_client
from salt.utils.pyobjects import Registry, StateFactory, SaltObject, Map
# our import regexes
FROM_RE = r'^\s*from\s+(salt:\/\/.*)\s+import (.*)$'
IMPORT_RE = r'^\s*import\s+(salt:\/\/.*)$'
FROM_AS_RE = r'^(.*) as (.*)$'
FROM_RE = re.compile(r'^\s*from\s+(salt:\/\/.*)\s+import (.*)$')
IMPORT_RE = re.compile(r'^\s*import\s+(salt:\/\/.*)$')
FROM_AS_RE = re.compile(r'^(.*) as (.*)$')
log = logging.getLogger(__name__)
@ -284,6 +282,15 @@ except NameError:
__context__ = {}
class PyobjectsModule(object):
'''This projects a wrapper for bare imports.'''
def __init__(self, name, attrs):
self.name = name
self.__dict__ = attrs
def __repr__(self):
return "<module '%s' (pyobjects)>" % self.name
def load_states():
'''
This loads our states into the salt __context__
@ -374,59 +381,69 @@ def render(template, saltenv='base', sls='', salt_data=True, **kwargs):
# so that they may bring in objects from other files. while we do this we
# disable the registry since all we're looking for here is python objects,
# not salt state data
template_data = []
Registry.enabled = False
for line in template.readlines():
line = line.rstrip('\r\n')
matched = False
for RE in (IMPORT_RE, FROM_RE):
matches = re.match(RE, line)
if not matches:
continue
import_file = matches.group(1).strip()
try:
imports = matches.group(2).split(',')
except IndexError:
# if we don't have a third group in the matches object it means
# that we're importing everything
imports = None
def process_template(template, template_globals):
template_data = []
state_globals = {}
for line in template.readlines():
line = line.rstrip('\r\n')
matched = False
for RE in (IMPORT_RE, FROM_RE):
matches = RE.match(line)
if not matches:
continue
state_file = client.cache_file(import_file, saltenv)
if not state_file:
raise ImportError("Could not find the file {0!r}".format(import_file))
import_file = matches.group(1).strip()
try:
imports = matches.group(2).split(',')
except IndexError:
# if we don't have a third group in the matches object it means
# that we're importing everything
imports = None
with salt.utils.fopen(state_file) as f:
state_contents = f.read()
state_file = client.cache_file(import_file, saltenv)
if not state_file:
raise ImportError("Could not find the file {0!r}".format(import_file))
state_locals = {}
exec_(state_contents, _globals, state_locals)
state_locals = {}
with salt.utils.fopen(state_file) as state_fh:
state_contents, state_locals = process_template(state_fh, template_globals)
exec_(state_contents, template_globals, state_locals)
if imports is None:
imports = list(state_locals.keys())
# if no imports have been specified then we are being imported as: import salt://foo.sls
# so we want to stick all of the locals from our state file into the template globals
# under the name of the module -> i.e. foo.MapClass
if imports is None:
import_name = os.path.splitext(os.path.basename(state_file))[0]
state_globals[import_name] = PyobjectsModule(import_name, state_locals)
else:
for name in imports:
name = alias = name.strip()
for name in imports:
name = alias = name.strip()
matches = FROM_AS_RE.match(name)
if matches is not None:
name = matches.group(1).strip()
alias = matches.group(2).strip()
matches = re.match(FROM_AS_RE, name)
if matches is not None:
name = matches.group(1).strip()
alias = matches.group(2).strip()
if name not in state_locals:
raise ImportError("{0!r} was not found in {1!r}".format(
name,
import_file
))
state_globals[alias] = state_locals[name]
if name not in state_locals:
raise ImportError("{0!r} was not found in {1!r}".format(
name,
import_file
))
_globals[alias] = state_locals[name]
matched = True
break
matched = True
break
if not matched:
template_data.append(line)
if not matched:
template_data.append(line)
return "\n".join(template_data), state_globals
final_template = "\n".join(template_data)
# process the template that triggered the render
final_template, final_locals = process_template(template, _globals)
_globals.update(final_locals)
# re-enable the registry
Registry.enabled = True

View file

@ -79,7 +79,7 @@ with Pkg.installed("samba", names=[Samba.server, Samba.client]):
import_template = '''#!pyobjects
import salt://map.sls
Pkg.removed("samba-imported", names=[Samba.server, Samba.client])
Pkg.removed("samba-imported", names=[map.Samba.server, map.Samba.client])
'''
recursive_map_template = '''#!pyobjects
@ -94,6 +94,12 @@ from salt://recursive_map.sls import CustomSamba
Pkg.removed("samba-imported", names=[CustomSamba.server, CustomSamba.client])'''
scope_test_import_template = '''#!pyobjects
from salt://recursive_map.sls import CustomSamba
# since we import CustomSamba we should shouldn't be able to see Samba
Pkg.removed("samba-imported", names=[Samba.server, Samba.client])'''
from_import_template = '''#!pyobjects
# this spacing is like this on purpose to ensure it's stripped properly
from salt://map.sls import Samba
@ -331,6 +337,19 @@ class RendererTests(RendererMixin, StateTests):
self.write_template_file("recursive_map.sls", recursive_map_template)
render_and_assert(recursive_import_template)
def test_import_scope(self):
self.write_template_file("map.sls", map_template)
self.write_template_file("recursive_map.sls", recursive_map_template)
def do_render():
ret = self.render(scope_test_import_template,
{'grains': {
'os_family': 'Debian',
'os': 'Debian'
}})
self.assertRaises(NameError, do_render)
def test_random_password(self):
'''Test for https://github.com/saltstack/salt/issues/21796'''
ret = self.render(random_password_template)