Fix tests

This commit is contained in:
Daniel A. Wozniak 2023-08-13 02:38:53 -07:00 committed by Daniel Wozniak
parent f62f6469ff
commit 3347d543f5
5 changed files with 69 additions and 86 deletions

View file

@ -184,23 +184,40 @@ class PublishClient(salt.transport.base.PublishClient):
await self.message_client.send(msg, reply=False)
async def recv(self, timeout=None):
try:
await self._read_in_progress.acquire(timeout=0.001)
except tornado.gen.TimeoutError:
log.error("Unable to acquire read lock")
return
try:
if timeout == 0:
if not self._ws:
await asyncio.sleep(0.001)
return
while self._ws is None:
await self.connect()
await asyncio.sleep(0.001)
if timeout == 0:
for msg in self.unpacker:
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
return framed_msg["body"]
try:
raw_msg = await asyncio.wait_for(self._ws.receive(), 0.0001)
except TimeoutError:
return
if raw_msg.type == aiohttp.WSMsgType.TEXT:
if raw_msg.data == "close":
await self._ws.close()
if raw_msg.type == aiohttp.WSMsgType.BINARY:
self.unpacker.feed(raw_msg.data)
for msg in self.unpacker:
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
return framed_msg["body"]
try:
raw_msg = await asyncio.wait_for(self._ws.receive(), 0.0001)
except TimeoutError:
return
elif raw_msg.type == aiohttp.WSMsgType.ERROR:
log.error(
"ws connection closed with exception %s", self._ws.exception()
)
elif timeout:
return await asyncio.wait_for(self.recv(), timeout=timeout)
else:
for msg in self.unpacker:
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
return framed_msg["body"]
while True:
for msg in self.unpacker:
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
return framed_msg["body"]
raw_msg = await self._ws.receive()
if raw_msg.type == aiohttp.WSMsgType.TEXT:
if raw_msg.data == "close":
await self._ws.close()
@ -211,34 +228,9 @@ class PublishClient(salt.transport.base.PublishClient):
return framed_msg["body"]
elif raw_msg.type == aiohttp.WSMsgType.ERROR:
log.error(
"ws connection closed with exception %s", self._ws.exception()
"ws connection closed with exception %s",
self._ws.exception(),
)
elif timeout:
return await asyncio.wait_for(self.recv(), timeout=timeout)
else:
for msg in self.unpacker:
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
return framed_msg["body"]
while True:
for msg in self.unpacker:
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
return framed_msg["body"]
raw_msg = await self._ws.receive()
if raw_msg.type == aiohttp.WSMsgType.TEXT:
if raw_msg.data == "close":
await self._ws.close()
if raw_msg.type == aiohttp.WSMsgType.BINARY:
self.unpacker.feed(raw_msg.data)
for msg in self.unpacker:
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
return framed_msg["body"]
elif raw_msg.type == aiohttp.WSMsgType.ERROR:
log.error(
"ws connection closed with exception %s",
self._ws.exception(),
)
finally:
self._read_in_progress.release()
async def handle_on_recv(self, callback):
while not self._ws:

View file

@ -170,7 +170,7 @@ def test_batch_issue_56273():
"extension_modules": "",
"failhard": True,
}
with patch("salt.transport.tcp.TCPPubClient", MockSubscriber):
with patch("salt.transport.tcp.PublishClient", MockSubscriber):
batch = salt.cli.batch.Batch(opts, quiet=True)
with patch.object(batch.local, "pub", Mock(side_effect=mock_pub)):
with patch.object(

View file

@ -56,7 +56,7 @@ def server(config):
@pytest.fixture
def client(io_loop, config):
client = salt.transport.tcp.TCPPubClient(
client = salt.transport.tcp.PublishClient(
config.copy(), io_loop, host=config["master_ip"], port=config["publish_port"]
)
try:

View file

@ -1,4 +1,3 @@
<<<<<<< HEAD
"""
Unit tests for salt.transport.base.
"""
@ -26,62 +25,54 @@ def test_unclosed_warning():
with pytest.warns(salt.transport.base.TransportWarning):
del transport
@patch('ssl.SSLContext')
@patch("ssl.SSLContext")
def test_ssl_context_legacy_opts(mock):
ctx = salt.transport.base.ssl_context({
'certfile': "server.crt",
'keyfile': "server.key",
'cert_reqs': "CERT_NONE",
"ca_certs": "ca.crt",
})
ctx = salt.transport.base.ssl_context(
{
"certfile": "server.crt",
"keyfile": "server.key",
"cert_reqs": "CERT_NONE",
"ca_certs": "ca.crt",
}
)
ctx.load_cert_chain.assert_called_with(
"server.crt",
"server.key",
)
ctx.load_verify_locations.assert_called_with(
"ca.crt"
)
ctx.load_verify_locations.assert_called_with("ca.crt")
assert ssl.VerifyMode.CERT_NONE == ctx.verify_mode
assert not ctx.check_hostname
@patch('ssl.SSLContext')
@patch("ssl.SSLContext")
def test_ssl_context_opts(mock):
mock.verify_flags = ssl.VerifyFlags.VERIFY_X509_TRUSTED_FIRST
ctx = salt.transport.base.ssl_context({
'certfile': "server.crt",
'keyfile': "server.key",
'cert_reqs': "CERT_OPTIONAL",
"verify_locations": [
"ca.crt",
{"cafile": "crl.pem"},
{"capath": "/tmp/mycapathsdf"},
{"cadata": "mycadataother"},
{"CADATA": "mycadatasdf"},
],
"verify_flags": [
"VERIFY_CRL_CHECK_CHAIN",
]
})
ctx = salt.transport.base.ssl_context(
{
"certfile": "server.crt",
"keyfile": "server.key",
"cert_reqs": "CERT_OPTIONAL",
"verify_locations": [
"ca.crt",
{"cafile": "crl.pem"},
{"capath": "/tmp/mycapathsdf"},
{"cadata": "mycadataother"},
{"CADATA": "mycadatasdf"},
],
"verify_flags": [
"VERIFY_CRL_CHECK_CHAIN",
],
}
)
ctx.load_cert_chain.assert_called_with(
"server.crt",
"server.key",
)
ctx.load_verify_locations.assert_any_call(
cafile="ca.crt"
)
ctx.load_verify_locations.assert_any_call(
cafile="crl.pem"
)
ctx.load_verify_locations.assert_any_call(
capath="/tmp/mycapathsdf"
)
ctx.load_verify_locations.assert_any_call(
cadata="mycadataother"
)
ctx.load_verify_locations.assert_called_with(
cadata="mycadatasdf"
)
ctx.load_verify_locations.assert_any_call(cafile="ca.crt")
ctx.load_verify_locations.assert_any_call(cafile="crl.pem")
ctx.load_verify_locations.assert_any_call(capath="/tmp/mycapathsdf")
ctx.load_verify_locations.assert_any_call(cadata="mycadataother")
ctx.load_verify_locations.assert_called_with(cadata="mycadatasdf")
assert ssl.VerifyMode.CERT_OPTIONAL == ctx.verify_mode
assert ctx.check_hostname
assert ssl.VerifyFlags.VERIFY_CRL_CHECK_CHAIN & ctx.verify_flags

View file

@ -152,7 +152,7 @@ async def test_async_tcp_pub_channel_connect_publish_port(
future.set_result(True)
with patch("salt.crypt.AsyncAuth.gen_token", patch_auth), patch(
"salt.crypt.AsyncAuth.authenticated", patch_auth
), patch("salt.transport.tcp.TCPPubClient", transport):
), patch("salt.transport.tcp.PublishClient", transport):
channel = salt.channel.client.AsyncPubChannel.factory(opts)
with channel:
# We won't be able to succeed the connection because we're not mocking the tornado coroutine