Fix payload and defaults (#62458)

* Allow default to be none

* Additional tests

* Switch away from using Elipsis

```
Python 3.7.12 (default, Mar 16 2022, 11:48:18)
[GCC 11.2.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import pickle
>>> import salt.defaults
>>> salt.defaults.NOT_SET
<NOT_SET>
>>> d = pickle.dumps(salt.defaults.NOT_SET)
>>> ld = pickle.loads(d)
>>> ld
<NOT_SET>
>>> ld == salt.defaults.NOT_SET
True
>>>
```

Signed-off-by: Pedro Algarvio <palgarvio@vmware.com>

Signed-off-by: Pedro Algarvio <palgarvio@vmware.com>
Co-authored-by: Pedro Algarvio <palgarvio@vmware.com>
This commit is contained in:
Shane Lee 2022-08-17 17:24:35 -07:00 committed by GitHub
parent 427e85fa52
commit 6ff7116195
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 171 additions and 28 deletions

View file

@ -5,5 +5,58 @@ Do NOT, import any salt modules (salt.utils, salt.config, etc.) into this file,
as this may result in circular imports.
"""
class _Constant:
"""
This class implements a way to create constants in python.
NOTE:
- This is not really a constant, ie, the `is` check will not work, you'll
have to use `==`.
- This class SHALL NOT be considered public API and might change or even
go away at any given time.
"""
__slots__ = ("name", "value")
def __init__(self, name, value=None):
self.name = name
self.value = value
def __hash__(self):
return hash((self.name, self.value))
def __eq__(self, other):
if not isinstance(other, self.__class__):
return False
if self.name != other.name:
return False
return self.value == other.value
def __get_state__(self):
return {
"name": self.name,
"value": self.value,
}
def __set_state__(self, state):
return self.__class__(state["name"], state["value"])
def __repr__(self):
if self.value:
return "<Constant.{} value={}>".format(self.name, self.value)
return "<Constant.{}>".format(self.name)
# Default delimiter for multi-level traversal in targeting
DEFAULT_TARGET_DELIM = ":"
"""
Used in functions to define that a keyword default is not set.
It's used to differentiate from `None`, `True`, `False` which, in some
cases are proper defaults and are also proper values to pass.
"""
NOT_SET = _Constant("NOT_SET")

View file

@ -14,7 +14,7 @@ import salt.utils.dictupdate
import salt.utils.functools
import salt.utils.odict
import salt.utils.yaml
from salt.defaults import DEFAULT_TARGET_DELIM
from salt.defaults import DEFAULT_TARGET_DELIM, NOT_SET
from salt.exceptions import CommandExecutionError
__proxyenabled__ = ["*"]
@ -24,7 +24,7 @@ log = logging.getLogger(__name__)
def get(
key,
default=None,
default=NOT_SET,
merge=False,
merge_nested_lists=None,
delimiter=DEFAULT_TARGET_DELIM,
@ -122,7 +122,7 @@ def get(
salt '*' pillar.get pkg:apache
salt '*' pillar.get abc::def|ghi delimiter='|'
"""
if default is None:
if default == NOT_SET:
default = KeyError
if not __opts__.get("pillar_raise_on_missing"):
if default is KeyError:

View file

@ -178,12 +178,13 @@ Functions to interact with Hashicorp Vault.
import logging
import os
from salt.defaults import NOT_SET
from salt.exceptions import CommandExecutionError
log = logging.getLogger(__name__)
def read_secret(path, key=None, metadata=False, default=None):
def read_secret(path, key=None, metadata=False, default=NOT_SET):
"""
.. versionchanged:: 3001
The ``default`` argument has been added. When the path or path/key
@ -211,7 +212,7 @@ def read_secret(path, key=None, metadata=False, default=None):
first: {{ supersecret.first }}
second: {{ supersecret.second }}
"""
if default is None:
if default == NOT_SET:
default = CommandExecutionError
version2 = __utils__["vault.is_v2"](path)
if version2["v2"]:
@ -358,7 +359,7 @@ def destroy_secret(path, *args):
return False
def list_secrets(path, default=None):
def list_secrets(path, default=NOT_SET):
"""
.. versionchanged:: 3001
The ``default`` argument has been added. When the path or path/key
@ -374,7 +375,7 @@ def list_secrets(path, default=None):
salt '*' vault.list_secrets "secret/my/"
"""
if default is None:
if default == NOT_SET:
default = CommandExecutionError
log.debug("Listing vault secret keys for %s in %s", __grains__["id"], path)
version2 = __utils__["vault.is_v2"](path)

View file

@ -14,6 +14,7 @@ import salt.transport.frame
import salt.utils.immutabletypes as immutabletypes
import salt.utils.msgpack
import salt.utils.stringutils
from salt.defaults import _Constant
from salt.exceptions import SaltDeserializationError, SaltReqTimeoutError
from salt.utils.data import CaseInsensitiveDict
@ -77,6 +78,9 @@ def loads(msg, encoding=None, raw=False):
if code == 78:
data = salt.utils.stringutils.to_unicode(data)
return datetime.datetime.strptime(data, "%Y%m%dT%H:%M:%S.%f")
if code == 79:
name, value = salt.utils.msgpack.loads(data, raw=False)
return _Constant(name, value)
return data
gc.disable() # performance optimization for msgpack
@ -144,6 +148,12 @@ def dumps(msg, use_bin_type=False):
78,
salt.utils.stringutils.to_bytes(obj.strftime("%Y%m%dT%H:%M:%S.%f")),
)
elif isinstance(obj, _Constant):
# Special case our constants.
return salt.utils.msgpack.ExtType(
79,
salt.utils.msgpack.dumps((obj.name, obj.value), use_bin_type=True),
)
# The same for immutable types
elif isinstance(obj, immutabletypes.ImmutableDict):
return dict(obj)

View file

@ -0,0 +1,34 @@
# This tests the pillar module with `pillar_raise_on_missing` set to True in the
# minion config. This effects all tests in this file
import pytest
pytestmark = [
pytest.mark.windows_whitelisted,
]
@pytest.fixture(scope="module")
def pillar(modules):
return modules.pillar
@pytest.fixture(scope="module")
def minion_config_overrides():
yield {"pillar_raise_on_missing": True}
def test_get_non_existing(pillar):
"""
Test pillar.get when the item does not exist. Should raise a KeyError when
`pillar_raise_on_missing` is True in the minion config
"""
with pytest.raises(KeyError):
pillar.get("non-existing-pillar-item")
def test_get_default_none(pillar):
"""
Tests pillar.get when default is set to `None`. Should return `None`
"""
result = pillar.get("non-existing-pillar-item", default=None)
assert result is None

View file

@ -0,0 +1,42 @@
import pytest
pytestmark = [
pytest.mark.windows_whitelisted,
]
@pytest.fixture(scope="module")
def sys_mod(modules):
return modules.sys
@pytest.fixture(scope="module")
def pillar(modules):
return modules.pillar
def test_pillar_get_issue_61084(sys_mod):
"""
Test issue 61084. `sys.argspec` should return valid data and not throw a
TypeError due to pickling
This should probably be a pre-commit check or something
"""
result = sys_mod.argspec("pillar.get")
assert isinstance(result, dict)
assert isinstance(result.get("pillar.get"), dict)
def test_get_non_existing(pillar):
"""
Tests pillar.get when the item does not exist. Should return an empty string
"""
result = pillar.get("non-existing-pillar-item")
assert result == ""
def test_get_default_none(pillar):
"""
Tests pillar.get when default is set to `None`. Should return `None`
"""
result = pillar.get("non-existing-pillar-item", default=None)
assert result is None

View file

@ -1,21 +0,0 @@
import pytest
pytestmark = [
pytest.mark.windows_whitelisted,
]
@pytest.fixture(scope="module")
def sys_mod(modules):
return modules.sys
def test_pillar_get_issue_61084(sys_mod):
"""
Test issue 61084. `sys.argspec` should return valid data and not throw a
TypeError due to pickling
This should probably be a pre-commit check or something
"""
result = sys_mod.argspec("pillar.get")
assert isinstance(result, dict)
assert isinstance(result.get("pillar.get"), dict)

View file

@ -0,0 +1,13 @@
import pickle
from salt.defaults import _Constant
def test_pickle_constants():
"""
That that we can pickle and unpickle constants.
"""
constant = _Constant("Foo", 123)
sdata = pickle.dumps(constant)
odata = pickle.loads(sdata)
assert odata == constant

View file

@ -9,6 +9,7 @@ import logging
import salt.exceptions
import salt.payload
from salt.defaults import _Constant
from salt.utils import immutabletypes
from salt.utils.odict import OrderedDict
@ -199,3 +200,13 @@ def test_raw_vs_encoding_utf8():
sdata = salt.payload.dumps(idata.copy())
odata = salt.payload.loads(sdata, encoding="utf-8")
assert isinstance(odata[dtvalue], str)
def test_constants():
"""
That that we handle encoding and decoding of constants.
"""
constant = _Constant("Foo", "bar")
sdata = salt.payload.dumps(constant)
odata = salt.payload.loads(sdata)
assert odata == constant