Wean of tcp transport bits in ws transport

This commit is contained in:
Daniel A. Wozniak 2023-08-13 21:34:24 -07:00 committed by Daniel Wozniak
parent 85c282e51a
commit 9adfd29c54
5 changed files with 139 additions and 94 deletions

View file

@ -368,8 +368,7 @@ class PublishClient(salt.transport.base.PublishClient):
await asyncio.sleep(0.001)
if timeout == 0:
for msg in self.unpacker:
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
return framed_msg["body"]
return msg[b"body"]
try:
events, _, _ = select.select([self._stream.socket], [], [], 0)
except TimeoutError:
@ -390,8 +389,7 @@ class PublishClient(salt.transport.base.PublishClient):
return
self.unpacker.feed(byts)
for msg in self.unpacker:
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
return framed_msg["body"]
return msg[b"body"]
elif timeout:
try:
return await asyncio.wait_for(self.recv(), timeout=timeout)
@ -405,8 +403,7 @@ class PublishClient(salt.transport.base.PublishClient):
return
else:
for msg in self.unpacker:
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
return framed_msg["body"]
return msg[b"body"]
while not self._closing:
async with self._read_in_progress:
try:
@ -423,8 +420,7 @@ class PublishClient(salt.transport.base.PublishClient):
continue
self.unpacker.feed(byts)
for msg in self.unpacker:
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
return framed_msg["body"]
return msg[b"body"]
async def on_recv_handler(self, callback):
while not self._stream:
@ -1455,7 +1451,10 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer):
primarily be used to create IPC channels and create our daemon process to
do the actual publishing
"""
process_manager.add_process(self.publish_daemon, name=self.__class__.__name__)
process_manager.add_process(
self.publish_daemon,
args=[self.publish_payload],
name=self.__class__.__name__)
async def publish_payload(self, payload, *args):
return await self.pub_server.publish_payload(payload)

View file

@ -4,6 +4,7 @@ import multiprocessing
import socket
import time
import warnings
import functools
import aiohttp
import aiohttp.web
@ -16,11 +17,9 @@ import salt.transport.frame
from salt.transport.tcp import (
USE_LOAD_BALANCER,
LoadBalancerServer,
TCPPuller,
_get_bind_addr,
_get_socket,
_set_tcp_keepalive,
_TCPPubServerPublisher,
)
log = logging.getLogger(__name__)
@ -46,13 +45,13 @@ class PublishClient(salt.transport.base.PublishClient):
def __init__(self, opts, io_loop, **kwargs): # pylint: disable=W0231
self.opts = opts
self.io_loop = io_loop
self.message_client = None
self.unpacker = salt.utils.msgpack.Unpacker()
self.connected = False
self._closing = False
self._stream = None
self._closing = False
self._closed = False
self.backoff = opts.get("tcp_reconnect_backoff", 1)
self.resolver = kwargs.get("resolver")
self._read_in_progress = Lock()
@ -75,15 +74,19 @@ class PublishClient(salt.transport.base.PublishClient):
self._ws = None
self._session = None
self._closing = False
self._closed = False
self.on_recv_task = None
async def _close(self):
if self._ws is not None:
await self._ws.close()
self._ws = None
if self._session is not None:
await self._session.close()
self._session = None
if self.on_recv_task:
self.on_recv_task.cancel()
await self.on_recv_task
self.on_recv_task = None
if self._ws is not None:
await self._ws.close()
self._ws = None
self._closed = True
def close(self):
@ -149,8 +152,6 @@ class PublishClient(salt.transport.base.PublishClient):
async def _connect(self, timeout=None):
if self._ws is None:
self._ws, self._session = await self.getstream(timeout=timeout)
# if not self._stream_return_running:
# self.io_loop.spawn_callback(self._stream_return)
if self.connect_callback:
self.connect_callback(True)
self.connected = True
@ -170,16 +171,6 @@ class PublishClient(salt.transport.base.PublishClient):
self.disconnect_callback = None
await self._connect(timeout=timeout)
def _decode_messages(self, messages):
if not isinstance(messages, dict):
# TODO: For some reason we need to decode here for things
# to work. Fix this.
body = salt.utils.msgpack.loads(messages)
body = salt.transport.frame.decode_embedded_strs(body)
else:
body = messages
return body
async def send(self, msg):
await self.message_client.send(msg, reply=False)
@ -189,8 +180,7 @@ class PublishClient(salt.transport.base.PublishClient):
await asyncio.sleep(0.001)
if timeout == 0:
for msg in self.unpacker:
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
return framed_msg["body"]
return msg
try:
raw_msg = await asyncio.wait_for(self._ws.receive(), 0.0001)
except TimeoutError:
@ -201,8 +191,7 @@ class PublishClient(salt.transport.base.PublishClient):
if raw_msg.type == aiohttp.WSMsgType.BINARY:
self.unpacker.feed(raw_msg.data)
for msg in self.unpacker:
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
return framed_msg["body"]
return msg
elif raw_msg.type == aiohttp.WSMsgType.ERROR:
log.error(
"ws connection closed with exception %s", self._ws.exception()
@ -211,12 +200,10 @@ class PublishClient(salt.transport.base.PublishClient):
return await asyncio.wait_for(self.recv(), timeout=timeout)
else:
for msg in self.unpacker:
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
return framed_msg["body"]
return msg
while True:
for msg in self.unpacker:
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
return framed_msg["body"]
return msg
raw_msg = await self._ws.receive()
if raw_msg.type == aiohttp.WSMsgType.TEXT:
if raw_msg.data == "close":
@ -224,26 +211,32 @@ class PublishClient(salt.transport.base.PublishClient):
if raw_msg.type == aiohttp.WSMsgType.BINARY:
self.unpacker.feed(raw_msg.data)
for msg in self.unpacker:
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
return framed_msg["body"]
return msg
elif raw_msg.type == aiohttp.WSMsgType.ERROR:
log.error(
"ws connection closed with exception %s",
self._ws.exception(),
)
async def handle_on_recv(self, callback):
async def on_recv_handler(self, callback):
while not self._ws:
await asyncio.sleep(0.003)
while True:
msg = await self.recv()
callback(msg)
await callback(msg)
def on_recv(self, callback):
"""
Register a callback for received messages (that we didn't initiate)
"""
self.io_loop.spawn_callback(self.handle_on_recv, callback)
if self.on_recv_task:
# XXX: We are not awaiting this canceled task. This still needs to
# be addressed.
self.on_recv_task.cancel()
if callback is None:
self.on_recv_task = None
else:
self.on_recv_task = asyncio.create_task(self.on_recv_handler(callback))
def __enter__(self):
return self
@ -287,7 +280,9 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer):
self.ssl = ssl
self.clients = set()
self._run = None
self.pub_sock = None
self.pub_writer = None
self.pub_reader = None
self._connecting = None
@property
def topic_support(self):
@ -333,6 +328,7 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer):
except (KeyboardInterrupt, SystemExit):
pass
finally:
print("CLOSE")
self.close()
async def publisher(
@ -369,25 +365,30 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer):
await runner.setup()
site = aiohttp.web.SockSite(runner, sock, ssl_context=ctx)
log.info("Publisher binding to socket %s:%s", self.pub_host, self.pub_port)
print('start site')
await site.start()
print('start puller')
self._pub_payload = publish_payload
if self.pull_path:
pull_uri = self.pull_path
with salt.utils.files.set_umask(0o177):
self.puller = await asyncio.start_unix_server(self.pull_handler, self.pull_path)
else:
pull_uri = self.pull_port
self.pull_sock = TCPPuller(
pull_uri,
io_loop=io_loop,
payload_handler=publish_payload,
)
# Securely create socket
log.warning("Starting the Salt Puller on %s", pull_uri)
with salt.utils.files.set_umask(0o177):
self.pull_sock.start()
self.puller = await asyncio.start_server(self.pull_handler, self.pull_host, self.pull_port)
print('puller started')
while self._run.is_set():
await asyncio.sleep(0.3)
await server.stop()
await self.server.stop()
await self.puller.wait_closed()
async def pull_handler(self, reader, writer):
print("puller got connection")
unpacker = salt.utils.msgpack.Unpacker()
while True:
data = await reader.read(1024)
unpacker.feed(data)
for msg in unpacker:
await self._pub_payload(msg)
def pre_fork(self, process_manager):
"""
@ -395,7 +396,10 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer):
primarily be used to create IPC channels and create our daemon process to
do the actual publishing
"""
process_manager.add_process(self.publish_daemon, name=self.__class__.__name__)
process_manager.add_process(
self.publish_daemon,
args=[self.publish_payload],
name=self.__class__.__name__)
async def handle_request(self, request):
try:
@ -412,29 +416,30 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer):
while True:
await asyncio.sleep(1)
async def _connect(self):
if self.pull_path:
self.pub_reader, self.pub_writer = await asyncio.open_unix_connection(self.pull_path)
else:
self.pub_reader, self.pub_writer = await asyncio.open_connection(self.pull_host, self.pull_port)
self._connecting = None
def connect(self):
log.debug("Connect pusher %s", self.pull_path)
self.pub_sock = salt.utils.asynchronous.SyncWrapper(
_TCPPubServerPublisher,
(
self.pull_host,
self.pull_port,
self.pull_path,
),
loop_kwarg="io_loop",
)
self.pub_sock.connect()
if self._connecting is None:
self._connecting = asyncio.create_task(self._connect())
return self._connecting
async def publish(self, payload, **kwargs):
"""
Publish "load" to minions
"""
if not self.pub_sock:
self.connect()
self.pub_sock.send(payload)
if not self.pub_writer:
await self.connect()
self.pub_writer.write(salt.payload.dumps(payload))
await self.pub_writer.drain()
async def publish_payload(self, package, *args):
payload = salt.transport.frame.frame_msg(package)
payload = salt.payload.dumps(package)
for ws in list(self.clients):
try:
await ws.send_bytes(payload)
@ -442,11 +447,14 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer):
self.clients.discard(ws)
def close(self):
if self.pub_sock:
self.pub_sock.close()
self.pub_sock = None
if self.pub_writer:
self.pub_writer.close()
self.pub_writer = None
self.pub_reader = None
if self._run is not None:
self._run.clear()
if self._connecting:
self._connecting.cancel()
class RequestServer(salt.transport.base.DaemonizedRequestServer):

View file

@ -0,0 +1,20 @@
import salt.utils.process
import pytest
def transport_ids(value):
return "Transport({})".format(value)
@pytest.fixture(params=("zeromq", "tcp", "ws"), ids=transport_ids)
def transport(request):
return request.param
@pytest.fixture
def process_manager():
pm = salt.utils.process.ProcessManager()
try:
yield pm
finally:
pm.terminate()

View file

@ -0,0 +1,37 @@
import asyncio
import salt.transport
async def test_publsh_server(
io_loop, minion_opts, master_opts, transport, process_manager
):
minion_opts["transport"] = master_opts["transport"] = transport
pub_server = salt.transport.publish_server(master_opts)
pub_server.pre_fork(process_manager)
await asyncio.sleep(3)
pub_client = salt.transport.publish_client(minion_opts, io_loop, master_opts["interface"], master_opts["publish_port"])
await pub_client.connect()
# Yield to loop in order to allow pub client to connect.
event = asyncio.Event()
messages = []
async def handle_msg(msg):
messages.append(msg)
event.set()
try:
pub_client.on_recv(handle_msg)
msg = b"meh"
await pub_server.publish(msg)
await asyncio.wait_for(event.wait(), 1)
assert [msg] == messages
finally:
pub_server.close()
pub_client.close()
# Yield to loop in order to allow background close methods to finish.
await asyncio.sleep(.3)

View file

@ -1,29 +1,10 @@
import asyncio
import pytest
import salt.transport
import salt.utils.process
def transport_ids(value):
return "Transport({})".format(value)
@pytest.fixture(params=("zeromq", "tcp", "ws"), ids=transport_ids)
def transport(request):
return request.param
@pytest.fixture
def process_manager():
pm = salt.utils.process.ProcessManager()
try:
yield pm
finally:
pm.terminate()
async def test_request_server(
io_loop, minion_opts, master_opts, transport, process_manager
):
@ -57,5 +38,5 @@ async def test_request_server(
req_client.close()
req_server.close()
# Yield to loop in order to allow cleanup methods to finish.
# Yield to loop in order to allow background close methods to finish.
await asyncio.sleep(0.3)