diff --git a/salt/transport/ws.py b/salt/transport/ws.py index 5ddaea894eb..a142fafa931 100644 --- a/salt/transport/ws.py +++ b/salt/transport/ws.py @@ -184,23 +184,40 @@ class PublishClient(salt.transport.base.PublishClient): await self.message_client.send(msg, reply=False) async def recv(self, timeout=None): - try: - await self._read_in_progress.acquire(timeout=0.001) - except tornado.gen.TimeoutError: - log.error("Unable to acquire read lock") - return - try: - if timeout == 0: - if not self._ws: - await asyncio.sleep(0.001) - return + while self._ws is None: + await self.connect() + await asyncio.sleep(0.001) + if timeout == 0: + for msg in self.unpacker: + framed_msg = salt.transport.frame.decode_embedded_strs(msg) + return framed_msg["body"] + try: + raw_msg = await asyncio.wait_for(self._ws.receive(), 0.0001) + except TimeoutError: + return + if raw_msg.type == aiohttp.WSMsgType.TEXT: + if raw_msg.data == "close": + await self._ws.close() + if raw_msg.type == aiohttp.WSMsgType.BINARY: + self.unpacker.feed(raw_msg.data) for msg in self.unpacker: framed_msg = salt.transport.frame.decode_embedded_strs(msg) return framed_msg["body"] - try: - raw_msg = await asyncio.wait_for(self._ws.receive(), 0.0001) - except TimeoutError: - return + elif raw_msg.type == aiohttp.WSMsgType.ERROR: + log.error( + "ws connection closed with exception %s", self._ws.exception() + ) + elif timeout: + return await asyncio.wait_for(self.recv(), timeout=timeout) + else: + for msg in self.unpacker: + framed_msg = salt.transport.frame.decode_embedded_strs(msg) + return framed_msg["body"] + while True: + for msg in self.unpacker: + framed_msg = salt.transport.frame.decode_embedded_strs(msg) + return framed_msg["body"] + raw_msg = await self._ws.receive() if raw_msg.type == aiohttp.WSMsgType.TEXT: if raw_msg.data == "close": await self._ws.close() @@ -211,34 +228,9 @@ class PublishClient(salt.transport.base.PublishClient): return framed_msg["body"] elif raw_msg.type == aiohttp.WSMsgType.ERROR: log.error( - "ws connection closed with exception %s", self._ws.exception() + "ws connection closed with exception %s", + self._ws.exception(), ) - elif timeout: - return await asyncio.wait_for(self.recv(), timeout=timeout) - else: - for msg in self.unpacker: - framed_msg = salt.transport.frame.decode_embedded_strs(msg) - return framed_msg["body"] - while True: - for msg in self.unpacker: - framed_msg = salt.transport.frame.decode_embedded_strs(msg) - return framed_msg["body"] - raw_msg = await self._ws.receive() - if raw_msg.type == aiohttp.WSMsgType.TEXT: - if raw_msg.data == "close": - await self._ws.close() - if raw_msg.type == aiohttp.WSMsgType.BINARY: - self.unpacker.feed(raw_msg.data) - for msg in self.unpacker: - framed_msg = salt.transport.frame.decode_embedded_strs(msg) - return framed_msg["body"] - elif raw_msg.type == aiohttp.WSMsgType.ERROR: - log.error( - "ws connection closed with exception %s", - self._ws.exception(), - ) - finally: - self._read_in_progress.release() async def handle_on_recv(self, callback): while not self._ws: diff --git a/tests/pytests/functional/cli/test_batch.py b/tests/pytests/functional/cli/test_batch.py index adc5406737a..7b602c9b825 100644 --- a/tests/pytests/functional/cli/test_batch.py +++ b/tests/pytests/functional/cli/test_batch.py @@ -170,7 +170,7 @@ def test_batch_issue_56273(): "extension_modules": "", "failhard": True, } - with patch("salt.transport.tcp.TCPPubClient", MockSubscriber): + with patch("salt.transport.tcp.PublishClient", MockSubscriber): batch = salt.cli.batch.Batch(opts, quiet=True) with patch.object(batch.local, "pub", Mock(side_effect=mock_pub)): with patch.object( diff --git a/tests/pytests/functional/transport/tcp/test_message_client.py b/tests/pytests/functional/transport/tcp/test_message_client.py index 2ddf308531c..7dd8dbe1961 100644 --- a/tests/pytests/functional/transport/tcp/test_message_client.py +++ b/tests/pytests/functional/transport/tcp/test_message_client.py @@ -56,7 +56,7 @@ def server(config): @pytest.fixture def client(io_loop, config): - client = salt.transport.tcp.TCPPubClient( + client = salt.transport.tcp.PublishClient( config.copy(), io_loop, host=config["master_ip"], port=config["publish_port"] ) try: diff --git a/tests/pytests/unit/transport/test_base.py b/tests/pytests/unit/transport/test_base.py index aac2d84502a..b01d0b92319 100644 --- a/tests/pytests/unit/transport/test_base.py +++ b/tests/pytests/unit/transport/test_base.py @@ -1,4 +1,3 @@ -<<<<<<< HEAD """ Unit tests for salt.transport.base. """ @@ -26,62 +25,54 @@ def test_unclosed_warning(): with pytest.warns(salt.transport.base.TransportWarning): del transport -@patch('ssl.SSLContext') +@patch("ssl.SSLContext") def test_ssl_context_legacy_opts(mock): - ctx = salt.transport.base.ssl_context({ - 'certfile': "server.crt", - 'keyfile': "server.key", - 'cert_reqs': "CERT_NONE", - "ca_certs": "ca.crt", - }) + ctx = salt.transport.base.ssl_context( + { + "certfile": "server.crt", + "keyfile": "server.key", + "cert_reqs": "CERT_NONE", + "ca_certs": "ca.crt", + } + ) ctx.load_cert_chain.assert_called_with( "server.crt", "server.key", ) - ctx.load_verify_locations.assert_called_with( - "ca.crt" - ) + ctx.load_verify_locations.assert_called_with("ca.crt") assert ssl.VerifyMode.CERT_NONE == ctx.verify_mode assert not ctx.check_hostname -@patch('ssl.SSLContext') +@patch("ssl.SSLContext") def test_ssl_context_opts(mock): mock.verify_flags = ssl.VerifyFlags.VERIFY_X509_TRUSTED_FIRST - ctx = salt.transport.base.ssl_context({ - 'certfile': "server.crt", - 'keyfile': "server.key", - 'cert_reqs': "CERT_OPTIONAL", - "verify_locations": [ - "ca.crt", - {"cafile": "crl.pem"}, - {"capath": "/tmp/mycapathsdf"}, - {"cadata": "mycadataother"}, - {"CADATA": "mycadatasdf"}, - ], - "verify_flags": [ - "VERIFY_CRL_CHECK_CHAIN", - ] - }) + ctx = salt.transport.base.ssl_context( + { + "certfile": "server.crt", + "keyfile": "server.key", + "cert_reqs": "CERT_OPTIONAL", + "verify_locations": [ + "ca.crt", + {"cafile": "crl.pem"}, + {"capath": "/tmp/mycapathsdf"}, + {"cadata": "mycadataother"}, + {"CADATA": "mycadatasdf"}, + ], + "verify_flags": [ + "VERIFY_CRL_CHECK_CHAIN", + ], + } + ) ctx.load_cert_chain.assert_called_with( "server.crt", "server.key", ) - ctx.load_verify_locations.assert_any_call( - cafile="ca.crt" - ) - ctx.load_verify_locations.assert_any_call( - cafile="crl.pem" - ) - ctx.load_verify_locations.assert_any_call( - capath="/tmp/mycapathsdf" - ) - ctx.load_verify_locations.assert_any_call( - cadata="mycadataother" - ) - ctx.load_verify_locations.assert_called_with( - cadata="mycadatasdf" - ) + ctx.load_verify_locations.assert_any_call(cafile="ca.crt") + ctx.load_verify_locations.assert_any_call(cafile="crl.pem") + ctx.load_verify_locations.assert_any_call(capath="/tmp/mycapathsdf") + ctx.load_verify_locations.assert_any_call(cadata="mycadataother") + ctx.load_verify_locations.assert_called_with(cadata="mycadatasdf") assert ssl.VerifyMode.CERT_OPTIONAL == ctx.verify_mode assert ctx.check_hostname assert ssl.VerifyFlags.VERIFY_CRL_CHECK_CHAIN & ctx.verify_flags diff --git a/tests/pytests/unit/transport/test_tcp.py b/tests/pytests/unit/transport/test_tcp.py index 0dcf772f560..b7466742ae6 100644 --- a/tests/pytests/unit/transport/test_tcp.py +++ b/tests/pytests/unit/transport/test_tcp.py @@ -152,7 +152,7 @@ async def test_async_tcp_pub_channel_connect_publish_port( future.set_result(True) with patch("salt.crypt.AsyncAuth.gen_token", patch_auth), patch( "salt.crypt.AsyncAuth.authenticated", patch_auth - ), patch("salt.transport.tcp.TCPPubClient", transport): + ), patch("salt.transport.tcp.PublishClient", transport): channel = salt.channel.client.AsyncPubChannel.factory(opts) with channel: # We won't be able to succeed the connection because we're not mocking the tornado coroutine