mirror of
https://github.com/saltstack/salt.git
synced 2025-04-17 10:10:20 +00:00
feat: added allow_one_of()
and require_one_of()
decorators
- `allow_one_of()` raises an error if more than one of the allowed parameters is supplied. - `require_one_of()` operates the same as `allow_one_of()` but also raises an error if none of the allowed parameters are supplied.
This commit is contained in:
parent
0e95dbb2bb
commit
5d0b7de8b8
3 changed files with 183 additions and 36 deletions
1
changelog/58742.added
Normal file
1
changelog/58742.added
Normal file
|
@ -0,0 +1 @@
|
|||
New decorators `allow_one_of()` and `require_one_of()`
|
|
@ -1,10 +1,7 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Helpful decorators for module writing
|
||||
"""
|
||||
|
||||
# Import python libs
|
||||
from __future__ import absolute_import, print_function, unicode_literals
|
||||
|
||||
import errno
|
||||
import inspect
|
||||
|
@ -15,13 +12,14 @@ import time
|
|||
from collections import defaultdict
|
||||
from functools import wraps
|
||||
|
||||
# Import salt libs
|
||||
import salt.utils.args
|
||||
import salt.utils.data
|
||||
import salt.utils.versions
|
||||
from salt.exceptions import CommandExecutionError, SaltConfigurationError
|
||||
|
||||
# Import 3rd-party libs
|
||||
from salt.exceptions import (
|
||||
CommandExecutionError,
|
||||
SaltConfigurationError,
|
||||
SaltInvocationError,
|
||||
)
|
||||
from salt.ext import six
|
||||
from salt.log import LOG_LEVELS
|
||||
|
||||
|
@ -32,7 +30,7 @@ if getattr(sys, "getwindowsversion", False):
|
|||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Depends(object):
|
||||
class Depends:
|
||||
"""
|
||||
This decorator will check the module when it is loaded and check that the
|
||||
dependencies passed in are in the globals of the module. If not, it will
|
||||
|
@ -121,7 +119,7 @@ class Depends(object):
|
|||
|
||||
@staticmethod
|
||||
def run_command(dependency, mod_name, func_name):
|
||||
full_name = "{0}.{1}".format(mod_name, func_name)
|
||||
full_name = "{}.{}".format(mod_name, func_name)
|
||||
log.trace("Running '%s' for '%s'", dependency, full_name)
|
||||
if IS_WINDOWS:
|
||||
args = salt.utils.args.shlex_split(dependency, posix=False)
|
||||
|
@ -145,8 +143,8 @@ class Depends(object):
|
|||
It will modify the "functions" dict and remove/replace modules that
|
||||
are missing dependencies.
|
||||
"""
|
||||
for dependency, dependent_dict in six.iteritems(cls.dependency_dict[kind]):
|
||||
for (mod_name, func_name), (frame, params) in six.iteritems(dependent_dict):
|
||||
for dependency, dependent_dict in cls.dependency_dict[kind].items():
|
||||
for (mod_name, func_name), (frame, params) in dependent_dict.items():
|
||||
if mod_name != tgt_mod:
|
||||
continue
|
||||
# Imports from local context take presedence over those from the global context.
|
||||
|
@ -232,7 +230,7 @@ class Depends(object):
|
|||
except (AttributeError, KeyError):
|
||||
pass
|
||||
|
||||
mod_key = "{0}.{1}".format(mod_name, func_name)
|
||||
mod_key = "{}.{}".format(mod_name, func_name)
|
||||
|
||||
# if we don't have this module loaded, skip it!
|
||||
if mod_key not in functions:
|
||||
|
@ -267,9 +265,7 @@ def timing(function):
|
|||
mod_name = function.__module__[16:]
|
||||
else:
|
||||
mod_name = function.__module__
|
||||
fstr = "Function %s.%s took %.{0}f seconds to execute".format(
|
||||
sys.float_info.dig
|
||||
)
|
||||
fstr = "Function %s.%s took %.{}f seconds to execute".format(sys.float_info.dig)
|
||||
log.profile(fstr, mod_name, function.__name__, end_time - start_time)
|
||||
return ret
|
||||
|
||||
|
@ -291,13 +287,13 @@ def memoize(func):
|
|||
def _memoize(*args, **kwargs):
|
||||
str_args = []
|
||||
for arg in args:
|
||||
if not isinstance(arg, six.string_types):
|
||||
str_args.append(six.text_type(arg))
|
||||
if not isinstance(arg, str):
|
||||
str_args.append(str(arg))
|
||||
else:
|
||||
str_args.append(arg)
|
||||
|
||||
args_ = ",".join(
|
||||
list(str_args) + ["{0}={1}".format(k, kwargs[k]) for k in sorted(kwargs)]
|
||||
list(str_args) + ["{}={}".format(k, kwargs[k]) for k in sorted(kwargs)]
|
||||
)
|
||||
if args_ not in cache:
|
||||
cache[args_] = func(*args, **kwargs)
|
||||
|
@ -306,7 +302,7 @@ def memoize(func):
|
|||
return _memoize
|
||||
|
||||
|
||||
class _DeprecationDecorator(object):
|
||||
class _DeprecationDecorator:
|
||||
"""
|
||||
Base mix-in class for the deprecation decorator.
|
||||
Takes care of a common functionality, used in its derivatives.
|
||||
|
@ -359,7 +355,7 @@ class _DeprecationDecorator(object):
|
|||
try:
|
||||
return self._function(*args, **kwargs)
|
||||
except TypeError as error:
|
||||
error = six.text_type(error).replace(
|
||||
error = str(error).replace(
|
||||
self._function, self._orig_f_name
|
||||
) # Hide hidden functions
|
||||
log.error(
|
||||
|
@ -374,7 +370,7 @@ class _DeprecationDecorator(object):
|
|||
self._function.__name__,
|
||||
error,
|
||||
)
|
||||
six.reraise(*sys.exc_info())
|
||||
raise
|
||||
else:
|
||||
raise CommandExecutionError(
|
||||
"Function is deprecated, but the successor function was not found."
|
||||
|
@ -626,11 +622,11 @@ class _WithDeprecated(_DeprecationDecorator):
|
|||
|
||||
if use_deprecated and use_superseded:
|
||||
raise SaltConfigurationError(
|
||||
"Function '{0}' is mentioned both in deprecated "
|
||||
"Function '{}' is mentioned both in deprecated "
|
||||
"and superseded sections. Please remove any of that.".format(full_name)
|
||||
)
|
||||
old_function = self._globals.get(
|
||||
self._with_name or "_{0}".format(function.__name__)
|
||||
self._with_name or "_{}".format(function.__name__)
|
||||
)
|
||||
if self._policy == self.OPT_IN:
|
||||
self._function = function if use_superseded else old_function
|
||||
|
@ -753,6 +749,92 @@ class _WithDeprecated(_DeprecationDecorator):
|
|||
with_deprecated = _WithDeprecated
|
||||
|
||||
|
||||
def require_one_of(*kwarg_names):
|
||||
"""
|
||||
Decorator to filter out exclusive arguments from the call.
|
||||
|
||||
kwarg_names:
|
||||
Limit which combination of arguments may be passed to the call.
|
||||
|
||||
Example:
|
||||
|
||||
|
||||
# Require one of the following arguments to be supplied to foo()
|
||||
@require_one_of('arg1', 'arg2', 'arg3')
|
||||
def foo(arg1, arg2, arg3):
|
||||
|
||||
"""
|
||||
|
||||
def wrapper(f):
|
||||
@wraps(f)
|
||||
def func(*args, **kwargs):
|
||||
names = [key for key in kwargs if kwargs[key] and key in kwarg_names]
|
||||
names.extend(
|
||||
[
|
||||
args[i]
|
||||
for i, arg in enumerate(args)
|
||||
if args[i] and f.__code__.co_varnames[i] in kwarg_names
|
||||
]
|
||||
)
|
||||
if len(names) > 1:
|
||||
raise SaltInvocationError(
|
||||
"Only one of the following is allowed: {}".format(
|
||||
", ".join(kwarg_names)
|
||||
)
|
||||
)
|
||||
if not names:
|
||||
raise SaltInvocationError(
|
||||
"One of the following must be provided: {}".format(
|
||||
", ".join(kwarg_names)
|
||||
)
|
||||
)
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return func
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def allow_one_of(*kwarg_names):
|
||||
"""
|
||||
Decorator to filter out exclusive arguments from the call.
|
||||
|
||||
kwarg_names:
|
||||
Limit which combination of arguments may be passed to the call.
|
||||
|
||||
Example:
|
||||
|
||||
|
||||
# Allow only one of the following arguments to be supplied to foo()
|
||||
@allow_one_of('arg1', 'arg2', 'arg3')
|
||||
def foo(arg1, arg2, arg3):
|
||||
|
||||
"""
|
||||
|
||||
def wrapper(f):
|
||||
@wraps(f)
|
||||
def func(*args, **kwargs):
|
||||
names = [key for key in kwargs if kwargs[key] and key in kwarg_names]
|
||||
names.extend(
|
||||
[
|
||||
args[i]
|
||||
for i, arg in enumerate(args)
|
||||
if args[i] and f.__code__.co_varnames[i] in kwarg_names
|
||||
]
|
||||
)
|
||||
if len(names) > 1:
|
||||
raise SaltInvocationError(
|
||||
"Only of the following is allowed: {}".format(
|
||||
", ".join(kwarg_names)
|
||||
)
|
||||
)
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return func
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def ignores_kwargs(*kwarg_names):
|
||||
"""
|
||||
Decorator to filter out unexpected keyword arguments from the call
|
||||
|
@ -782,12 +864,6 @@ def ensure_unicode_args(function):
|
|||
|
||||
@wraps(function)
|
||||
def wrapped(*args, **kwargs):
|
||||
if six.PY2:
|
||||
return function(
|
||||
*salt.utils.data.decode_list(args),
|
||||
**salt.utils.data.decode_dict(kwargs)
|
||||
)
|
||||
else:
|
||||
return function(*args, **kwargs)
|
||||
return function(*args, **kwargs)
|
||||
|
||||
return wrapped
|
||||
|
|
|
@ -1,23 +1,23 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
:codeauthor: Bo Maryniuk (bo@suse.de)
|
||||
unit.utils.decorators_test
|
||||
"""
|
||||
|
||||
# Import Python libs
|
||||
from __future__ import absolute_import, print_function, unicode_literals
|
||||
|
||||
import inspect
|
||||
|
||||
# Import Salt libs
|
||||
import salt.utils.decorators as decorators
|
||||
from salt.exceptions import CommandExecutionError, SaltConfigurationError
|
||||
from salt.exceptions import (
|
||||
CommandExecutionError,
|
||||
SaltConfigurationError,
|
||||
SaltInvocationError,
|
||||
)
|
||||
from salt.version import SaltStackVersion
|
||||
from tests.support.mock import MagicMock, patch
|
||||
from tests.support.unit import TestCase
|
||||
|
||||
|
||||
class DummyLogger(object):
|
||||
class DummyLogger:
|
||||
"""
|
||||
Dummy logger accepts everything and simply logs
|
||||
"""
|
||||
|
@ -54,6 +54,9 @@ class DecoratorsTest(TestCase):
|
|||
"""
|
||||
return name, SaltStackVersion.from_name(name)
|
||||
|
||||
def arg_function(self, arg1=None, arg2=None, arg3=None):
|
||||
return "old"
|
||||
|
||||
def setUp(self):
|
||||
"""
|
||||
Setup a test
|
||||
|
@ -413,6 +416,73 @@ class DecoratorsTest(TestCase):
|
|||
with self.assertRaises(SaltConfigurationError):
|
||||
assert depr(self.new_function)() == self.new_function()
|
||||
|
||||
def test_allow_one_of(self):
|
||||
"""
|
||||
Test allow_one_of properly does not error when only one of the
|
||||
required arguments is passed.
|
||||
|
||||
:return:
|
||||
"""
|
||||
allow_one_of = decorators.allow_one_of("arg1", "arg2", "arg3")
|
||||
assert allow_one_of(self.arg_function)(arg1="good") == self.arg_function(
|
||||
arg1="good"
|
||||
)
|
||||
|
||||
def test_allow_one_of_succeeds_when_no_arguments_supplied(self):
|
||||
"""
|
||||
Test allow_one_of properly does not error when none of the allowed
|
||||
arguments are supplied.
|
||||
|
||||
:return:
|
||||
"""
|
||||
allow_one_of = decorators.allow_one_of("arg1", "arg2", "arg3")
|
||||
assert allow_one_of(self.arg_function)() == self.arg_function()
|
||||
|
||||
def test_allow_one_of_raises_error_when_multiple_allowed_arguments_supplied(self):
|
||||
"""
|
||||
Test allow_one_of properly does not error when only one of the
|
||||
required arguments is passed.
|
||||
|
||||
:return:
|
||||
"""
|
||||
allow_one_of = decorators.allow_one_of("arg1", "arg2", "arg3")
|
||||
with self.assertRaises(SaltInvocationError):
|
||||
allow_one_of(self.arg_function)(arg1="good", arg2="bad")
|
||||
|
||||
def test_require_one_of(self):
|
||||
"""
|
||||
Test require_one_of properly does not error when only one of the
|
||||
required arguments is passed.
|
||||
|
||||
:return:
|
||||
"""
|
||||
require_one_of = decorators.require_one_of("arg1", "arg2", "arg3")
|
||||
assert require_one_of(self.arg_function)(arg1="good") == self.arg_function(
|
||||
arg1="good"
|
||||
)
|
||||
|
||||
def test_require_one_of_raises_error_when_none_of_allowed_arguments_supplied(self):
|
||||
"""
|
||||
Test require_one_of properly raises an error when none of the required
|
||||
arguments are supplied.
|
||||
|
||||
:return:
|
||||
"""
|
||||
require_one_of = decorators.require_one_of("arg1", "arg2", "arg3")
|
||||
with self.assertRaises(SaltInvocationError):
|
||||
require_one_of(self.arg_function)()
|
||||
|
||||
def test_require_one_of_raises_error_when_multiple_allowed_arguments_supplied(self):
|
||||
"""
|
||||
Test require_one_of properly raises an error when multiples of the
|
||||
allowed arguments are supplied.
|
||||
|
||||
:return:
|
||||
"""
|
||||
require_one_of = decorators.require_one_of("arg1", "arg2", "arg3")
|
||||
with self.assertRaises(SaltInvocationError):
|
||||
require_one_of(self.new_function)(arg1="good", arg2="bad")
|
||||
|
||||
def test_with_depreciated_should_wrap_function(self):
|
||||
wrapped = decorators.with_deprecated({}, "Beryllium")(self.old_function)
|
||||
assert wrapped.__module__ == self.old_function.__module__
|
||||
|
|
Loading…
Add table
Reference in a new issue