feat: added allow_one_of() and require_one_of() decorators

- `allow_one_of()` raises an error if more than one of the allowed
  parameters is supplied.
- `require_one_of()` operates the same as `allow_one_of()` but also
  raises an error if none of the allowed parameters are supplied.
This commit is contained in:
Mark Ferrell 2020-10-15 14:20:18 -07:00 committed by Megan Wilhite
parent 0e95dbb2bb
commit 5d0b7de8b8
3 changed files with 183 additions and 36 deletions

1
changelog/58742.added Normal file
View file

@ -0,0 +1 @@
New decorators `allow_one_of()` and `require_one_of()`

View file

@ -1,10 +1,7 @@
# -*- coding: utf-8 -*-
"""
Helpful decorators for module writing
"""
# Import python libs
from __future__ import absolute_import, print_function, unicode_literals
import errno
import inspect
@ -15,13 +12,14 @@ import time
from collections import defaultdict
from functools import wraps
# Import salt libs
import salt.utils.args
import salt.utils.data
import salt.utils.versions
from salt.exceptions import CommandExecutionError, SaltConfigurationError
# Import 3rd-party libs
from salt.exceptions import (
CommandExecutionError,
SaltConfigurationError,
SaltInvocationError,
)
from salt.ext import six
from salt.log import LOG_LEVELS
@ -32,7 +30,7 @@ if getattr(sys, "getwindowsversion", False):
log = logging.getLogger(__name__)
class Depends(object):
class Depends:
"""
This decorator will check the module when it is loaded and check that the
dependencies passed in are in the globals of the module. If not, it will
@ -121,7 +119,7 @@ class Depends(object):
@staticmethod
def run_command(dependency, mod_name, func_name):
full_name = "{0}.{1}".format(mod_name, func_name)
full_name = "{}.{}".format(mod_name, func_name)
log.trace("Running '%s' for '%s'", dependency, full_name)
if IS_WINDOWS:
args = salt.utils.args.shlex_split(dependency, posix=False)
@ -145,8 +143,8 @@ class Depends(object):
It will modify the "functions" dict and remove/replace modules that
are missing dependencies.
"""
for dependency, dependent_dict in six.iteritems(cls.dependency_dict[kind]):
for (mod_name, func_name), (frame, params) in six.iteritems(dependent_dict):
for dependency, dependent_dict in cls.dependency_dict[kind].items():
for (mod_name, func_name), (frame, params) in dependent_dict.items():
if mod_name != tgt_mod:
continue
# Imports from local context take presedence over those from the global context.
@ -232,7 +230,7 @@ class Depends(object):
except (AttributeError, KeyError):
pass
mod_key = "{0}.{1}".format(mod_name, func_name)
mod_key = "{}.{}".format(mod_name, func_name)
# if we don't have this module loaded, skip it!
if mod_key not in functions:
@ -267,9 +265,7 @@ def timing(function):
mod_name = function.__module__[16:]
else:
mod_name = function.__module__
fstr = "Function %s.%s took %.{0}f seconds to execute".format(
sys.float_info.dig
)
fstr = "Function %s.%s took %.{}f seconds to execute".format(sys.float_info.dig)
log.profile(fstr, mod_name, function.__name__, end_time - start_time)
return ret
@ -291,13 +287,13 @@ def memoize(func):
def _memoize(*args, **kwargs):
str_args = []
for arg in args:
if not isinstance(arg, six.string_types):
str_args.append(six.text_type(arg))
if not isinstance(arg, str):
str_args.append(str(arg))
else:
str_args.append(arg)
args_ = ",".join(
list(str_args) + ["{0}={1}".format(k, kwargs[k]) for k in sorted(kwargs)]
list(str_args) + ["{}={}".format(k, kwargs[k]) for k in sorted(kwargs)]
)
if args_ not in cache:
cache[args_] = func(*args, **kwargs)
@ -306,7 +302,7 @@ def memoize(func):
return _memoize
class _DeprecationDecorator(object):
class _DeprecationDecorator:
"""
Base mix-in class for the deprecation decorator.
Takes care of a common functionality, used in its derivatives.
@ -359,7 +355,7 @@ class _DeprecationDecorator(object):
try:
return self._function(*args, **kwargs)
except TypeError as error:
error = six.text_type(error).replace(
error = str(error).replace(
self._function, self._orig_f_name
) # Hide hidden functions
log.error(
@ -374,7 +370,7 @@ class _DeprecationDecorator(object):
self._function.__name__,
error,
)
six.reraise(*sys.exc_info())
raise
else:
raise CommandExecutionError(
"Function is deprecated, but the successor function was not found."
@ -626,11 +622,11 @@ class _WithDeprecated(_DeprecationDecorator):
if use_deprecated and use_superseded:
raise SaltConfigurationError(
"Function '{0}' is mentioned both in deprecated "
"Function '{}' is mentioned both in deprecated "
"and superseded sections. Please remove any of that.".format(full_name)
)
old_function = self._globals.get(
self._with_name or "_{0}".format(function.__name__)
self._with_name or "_{}".format(function.__name__)
)
if self._policy == self.OPT_IN:
self._function = function if use_superseded else old_function
@ -753,6 +749,92 @@ class _WithDeprecated(_DeprecationDecorator):
with_deprecated = _WithDeprecated
def require_one_of(*kwarg_names):
"""
Decorator to filter out exclusive arguments from the call.
kwarg_names:
Limit which combination of arguments may be passed to the call.
Example:
# Require one of the following arguments to be supplied to foo()
@require_one_of('arg1', 'arg2', 'arg3')
def foo(arg1, arg2, arg3):
"""
def wrapper(f):
@wraps(f)
def func(*args, **kwargs):
names = [key for key in kwargs if kwargs[key] and key in kwarg_names]
names.extend(
[
args[i]
for i, arg in enumerate(args)
if args[i] and f.__code__.co_varnames[i] in kwarg_names
]
)
if len(names) > 1:
raise SaltInvocationError(
"Only one of the following is allowed: {}".format(
", ".join(kwarg_names)
)
)
if not names:
raise SaltInvocationError(
"One of the following must be provided: {}".format(
", ".join(kwarg_names)
)
)
return f(*args, **kwargs)
return func
return wrapper
def allow_one_of(*kwarg_names):
"""
Decorator to filter out exclusive arguments from the call.
kwarg_names:
Limit which combination of arguments may be passed to the call.
Example:
# Allow only one of the following arguments to be supplied to foo()
@allow_one_of('arg1', 'arg2', 'arg3')
def foo(arg1, arg2, arg3):
"""
def wrapper(f):
@wraps(f)
def func(*args, **kwargs):
names = [key for key in kwargs if kwargs[key] and key in kwarg_names]
names.extend(
[
args[i]
for i, arg in enumerate(args)
if args[i] and f.__code__.co_varnames[i] in kwarg_names
]
)
if len(names) > 1:
raise SaltInvocationError(
"Only of the following is allowed: {}".format(
", ".join(kwarg_names)
)
)
return f(*args, **kwargs)
return func
return wrapper
def ignores_kwargs(*kwarg_names):
"""
Decorator to filter out unexpected keyword arguments from the call
@ -782,12 +864,6 @@ def ensure_unicode_args(function):
@wraps(function)
def wrapped(*args, **kwargs):
if six.PY2:
return function(
*salt.utils.data.decode_list(args),
**salt.utils.data.decode_dict(kwargs)
)
else:
return function(*args, **kwargs)
return function(*args, **kwargs)
return wrapped

View file

@ -1,23 +1,23 @@
# -*- coding: utf-8 -*-
"""
:codeauthor: Bo Maryniuk (bo@suse.de)
unit.utils.decorators_test
"""
# Import Python libs
from __future__ import absolute_import, print_function, unicode_literals
import inspect
# Import Salt libs
import salt.utils.decorators as decorators
from salt.exceptions import CommandExecutionError, SaltConfigurationError
from salt.exceptions import (
CommandExecutionError,
SaltConfigurationError,
SaltInvocationError,
)
from salt.version import SaltStackVersion
from tests.support.mock import MagicMock, patch
from tests.support.unit import TestCase
class DummyLogger(object):
class DummyLogger:
"""
Dummy logger accepts everything and simply logs
"""
@ -54,6 +54,9 @@ class DecoratorsTest(TestCase):
"""
return name, SaltStackVersion.from_name(name)
def arg_function(self, arg1=None, arg2=None, arg3=None):
return "old"
def setUp(self):
"""
Setup a test
@ -413,6 +416,73 @@ class DecoratorsTest(TestCase):
with self.assertRaises(SaltConfigurationError):
assert depr(self.new_function)() == self.new_function()
def test_allow_one_of(self):
"""
Test allow_one_of properly does not error when only one of the
required arguments is passed.
:return:
"""
allow_one_of = decorators.allow_one_of("arg1", "arg2", "arg3")
assert allow_one_of(self.arg_function)(arg1="good") == self.arg_function(
arg1="good"
)
def test_allow_one_of_succeeds_when_no_arguments_supplied(self):
"""
Test allow_one_of properly does not error when none of the allowed
arguments are supplied.
:return:
"""
allow_one_of = decorators.allow_one_of("arg1", "arg2", "arg3")
assert allow_one_of(self.arg_function)() == self.arg_function()
def test_allow_one_of_raises_error_when_multiple_allowed_arguments_supplied(self):
"""
Test allow_one_of properly does not error when only one of the
required arguments is passed.
:return:
"""
allow_one_of = decorators.allow_one_of("arg1", "arg2", "arg3")
with self.assertRaises(SaltInvocationError):
allow_one_of(self.arg_function)(arg1="good", arg2="bad")
def test_require_one_of(self):
"""
Test require_one_of properly does not error when only one of the
required arguments is passed.
:return:
"""
require_one_of = decorators.require_one_of("arg1", "arg2", "arg3")
assert require_one_of(self.arg_function)(arg1="good") == self.arg_function(
arg1="good"
)
def test_require_one_of_raises_error_when_none_of_allowed_arguments_supplied(self):
"""
Test require_one_of properly raises an error when none of the required
arguments are supplied.
:return:
"""
require_one_of = decorators.require_one_of("arg1", "arg2", "arg3")
with self.assertRaises(SaltInvocationError):
require_one_of(self.arg_function)()
def test_require_one_of_raises_error_when_multiple_allowed_arguments_supplied(self):
"""
Test require_one_of properly raises an error when multiples of the
allowed arguments are supplied.
:return:
"""
require_one_of = decorators.require_one_of("arg1", "arg2", "arg3")
with self.assertRaises(SaltInvocationError):
require_one_of(self.new_function)(arg1="good", arg2="bad")
def test_with_depreciated_should_wrap_function(self):
wrapped = decorators.with_deprecated({}, "Beryllium")(self.old_function)
assert wrapped.__module__ == self.old_function.__module__