mirror of
https://github.com/saltstack/salt.git
synced 2025-04-17 10:10:20 +00:00
Wean of tcp transport bits in ws transport
This commit is contained in:
parent
85c282e51a
commit
9adfd29c54
5 changed files with 139 additions and 94 deletions
|
@ -368,8 +368,7 @@ class PublishClient(salt.transport.base.PublishClient):
|
|||
await asyncio.sleep(0.001)
|
||||
if timeout == 0:
|
||||
for msg in self.unpacker:
|
||||
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
|
||||
return framed_msg["body"]
|
||||
return msg[b"body"]
|
||||
try:
|
||||
events, _, _ = select.select([self._stream.socket], [], [], 0)
|
||||
except TimeoutError:
|
||||
|
@ -390,8 +389,7 @@ class PublishClient(salt.transport.base.PublishClient):
|
|||
return
|
||||
self.unpacker.feed(byts)
|
||||
for msg in self.unpacker:
|
||||
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
|
||||
return framed_msg["body"]
|
||||
return msg[b"body"]
|
||||
elif timeout:
|
||||
try:
|
||||
return await asyncio.wait_for(self.recv(), timeout=timeout)
|
||||
|
@ -405,8 +403,7 @@ class PublishClient(salt.transport.base.PublishClient):
|
|||
return
|
||||
else:
|
||||
for msg in self.unpacker:
|
||||
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
|
||||
return framed_msg["body"]
|
||||
return msg[b"body"]
|
||||
while not self._closing:
|
||||
async with self._read_in_progress:
|
||||
try:
|
||||
|
@ -423,8 +420,7 @@ class PublishClient(salt.transport.base.PublishClient):
|
|||
continue
|
||||
self.unpacker.feed(byts)
|
||||
for msg in self.unpacker:
|
||||
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
|
||||
return framed_msg["body"]
|
||||
return msg[b"body"]
|
||||
|
||||
async def on_recv_handler(self, callback):
|
||||
while not self._stream:
|
||||
|
@ -1455,7 +1451,10 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer):
|
|||
primarily be used to create IPC channels and create our daemon process to
|
||||
do the actual publishing
|
||||
"""
|
||||
process_manager.add_process(self.publish_daemon, name=self.__class__.__name__)
|
||||
process_manager.add_process(
|
||||
self.publish_daemon,
|
||||
args=[self.publish_payload],
|
||||
name=self.__class__.__name__)
|
||||
|
||||
async def publish_payload(self, payload, *args):
|
||||
return await self.pub_server.publish_payload(payload)
|
||||
|
|
|
@ -4,6 +4,7 @@ import multiprocessing
|
|||
import socket
|
||||
import time
|
||||
import warnings
|
||||
import functools
|
||||
|
||||
import aiohttp
|
||||
import aiohttp.web
|
||||
|
@ -16,11 +17,9 @@ import salt.transport.frame
|
|||
from salt.transport.tcp import (
|
||||
USE_LOAD_BALANCER,
|
||||
LoadBalancerServer,
|
||||
TCPPuller,
|
||||
_get_bind_addr,
|
||||
_get_socket,
|
||||
_set_tcp_keepalive,
|
||||
_TCPPubServerPublisher,
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
@ -46,13 +45,13 @@ class PublishClient(salt.transport.base.PublishClient):
|
|||
def __init__(self, opts, io_loop, **kwargs): # pylint: disable=W0231
|
||||
self.opts = opts
|
||||
self.io_loop = io_loop
|
||||
self.message_client = None
|
||||
self.unpacker = salt.utils.msgpack.Unpacker()
|
||||
|
||||
self.connected = False
|
||||
self._closing = False
|
||||
self._stream = None
|
||||
self._closing = False
|
||||
self._closed = False
|
||||
|
||||
self.backoff = opts.get("tcp_reconnect_backoff", 1)
|
||||
self.resolver = kwargs.get("resolver")
|
||||
self._read_in_progress = Lock()
|
||||
|
@ -75,15 +74,19 @@ class PublishClient(salt.transport.base.PublishClient):
|
|||
self._ws = None
|
||||
self._session = None
|
||||
self._closing = False
|
||||
self._closed = False
|
||||
self.on_recv_task = None
|
||||
|
||||
async def _close(self):
|
||||
if self._ws is not None:
|
||||
await self._ws.close()
|
||||
self._ws = None
|
||||
if self._session is not None:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
if self.on_recv_task:
|
||||
self.on_recv_task.cancel()
|
||||
await self.on_recv_task
|
||||
self.on_recv_task = None
|
||||
if self._ws is not None:
|
||||
await self._ws.close()
|
||||
self._ws = None
|
||||
self._closed = True
|
||||
|
||||
def close(self):
|
||||
|
@ -149,8 +152,6 @@ class PublishClient(salt.transport.base.PublishClient):
|
|||
async def _connect(self, timeout=None):
|
||||
if self._ws is None:
|
||||
self._ws, self._session = await self.getstream(timeout=timeout)
|
||||
# if not self._stream_return_running:
|
||||
# self.io_loop.spawn_callback(self._stream_return)
|
||||
if self.connect_callback:
|
||||
self.connect_callback(True)
|
||||
self.connected = True
|
||||
|
@ -170,16 +171,6 @@ class PublishClient(salt.transport.base.PublishClient):
|
|||
self.disconnect_callback = None
|
||||
await self._connect(timeout=timeout)
|
||||
|
||||
def _decode_messages(self, messages):
|
||||
if not isinstance(messages, dict):
|
||||
# TODO: For some reason we need to decode here for things
|
||||
# to work. Fix this.
|
||||
body = salt.utils.msgpack.loads(messages)
|
||||
body = salt.transport.frame.decode_embedded_strs(body)
|
||||
else:
|
||||
body = messages
|
||||
return body
|
||||
|
||||
async def send(self, msg):
|
||||
await self.message_client.send(msg, reply=False)
|
||||
|
||||
|
@ -189,8 +180,7 @@ class PublishClient(salt.transport.base.PublishClient):
|
|||
await asyncio.sleep(0.001)
|
||||
if timeout == 0:
|
||||
for msg in self.unpacker:
|
||||
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
|
||||
return framed_msg["body"]
|
||||
return msg
|
||||
try:
|
||||
raw_msg = await asyncio.wait_for(self._ws.receive(), 0.0001)
|
||||
except TimeoutError:
|
||||
|
@ -201,8 +191,7 @@ class PublishClient(salt.transport.base.PublishClient):
|
|||
if raw_msg.type == aiohttp.WSMsgType.BINARY:
|
||||
self.unpacker.feed(raw_msg.data)
|
||||
for msg in self.unpacker:
|
||||
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
|
||||
return framed_msg["body"]
|
||||
return msg
|
||||
elif raw_msg.type == aiohttp.WSMsgType.ERROR:
|
||||
log.error(
|
||||
"ws connection closed with exception %s", self._ws.exception()
|
||||
|
@ -211,12 +200,10 @@ class PublishClient(salt.transport.base.PublishClient):
|
|||
return await asyncio.wait_for(self.recv(), timeout=timeout)
|
||||
else:
|
||||
for msg in self.unpacker:
|
||||
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
|
||||
return framed_msg["body"]
|
||||
return msg
|
||||
while True:
|
||||
for msg in self.unpacker:
|
||||
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
|
||||
return framed_msg["body"]
|
||||
return msg
|
||||
raw_msg = await self._ws.receive()
|
||||
if raw_msg.type == aiohttp.WSMsgType.TEXT:
|
||||
if raw_msg.data == "close":
|
||||
|
@ -224,26 +211,32 @@ class PublishClient(salt.transport.base.PublishClient):
|
|||
if raw_msg.type == aiohttp.WSMsgType.BINARY:
|
||||
self.unpacker.feed(raw_msg.data)
|
||||
for msg in self.unpacker:
|
||||
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
|
||||
return framed_msg["body"]
|
||||
return msg
|
||||
elif raw_msg.type == aiohttp.WSMsgType.ERROR:
|
||||
log.error(
|
||||
"ws connection closed with exception %s",
|
||||
self._ws.exception(),
|
||||
)
|
||||
|
||||
async def handle_on_recv(self, callback):
|
||||
async def on_recv_handler(self, callback):
|
||||
while not self._ws:
|
||||
await asyncio.sleep(0.003)
|
||||
while True:
|
||||
msg = await self.recv()
|
||||
callback(msg)
|
||||
await callback(msg)
|
||||
|
||||
def on_recv(self, callback):
|
||||
"""
|
||||
Register a callback for received messages (that we didn't initiate)
|
||||
"""
|
||||
self.io_loop.spawn_callback(self.handle_on_recv, callback)
|
||||
if self.on_recv_task:
|
||||
# XXX: We are not awaiting this canceled task. This still needs to
|
||||
# be addressed.
|
||||
self.on_recv_task.cancel()
|
||||
if callback is None:
|
||||
self.on_recv_task = None
|
||||
else:
|
||||
self.on_recv_task = asyncio.create_task(self.on_recv_handler(callback))
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
@ -287,7 +280,9 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer):
|
|||
self.ssl = ssl
|
||||
self.clients = set()
|
||||
self._run = None
|
||||
self.pub_sock = None
|
||||
self.pub_writer = None
|
||||
self.pub_reader = None
|
||||
self._connecting = None
|
||||
|
||||
@property
|
||||
def topic_support(self):
|
||||
|
@ -333,6 +328,7 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer):
|
|||
except (KeyboardInterrupt, SystemExit):
|
||||
pass
|
||||
finally:
|
||||
print("CLOSE")
|
||||
self.close()
|
||||
|
||||
async def publisher(
|
||||
|
@ -369,25 +365,30 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer):
|
|||
await runner.setup()
|
||||
site = aiohttp.web.SockSite(runner, sock, ssl_context=ctx)
|
||||
log.info("Publisher binding to socket %s:%s", self.pub_host, self.pub_port)
|
||||
print('start site')
|
||||
await site.start()
|
||||
print('start puller')
|
||||
|
||||
self._pub_payload = publish_payload
|
||||
if self.pull_path:
|
||||
pull_uri = self.pull_path
|
||||
with salt.utils.files.set_umask(0o177):
|
||||
self.puller = await asyncio.start_unix_server(self.pull_handler, self.pull_path)
|
||||
else:
|
||||
pull_uri = self.pull_port
|
||||
|
||||
self.pull_sock = TCPPuller(
|
||||
pull_uri,
|
||||
io_loop=io_loop,
|
||||
payload_handler=publish_payload,
|
||||
)
|
||||
# Securely create socket
|
||||
log.warning("Starting the Salt Puller on %s", pull_uri)
|
||||
with salt.utils.files.set_umask(0o177):
|
||||
self.pull_sock.start()
|
||||
self.puller = await asyncio.start_server(self.pull_handler, self.pull_host, self.pull_port)
|
||||
print('puller started')
|
||||
while self._run.is_set():
|
||||
await asyncio.sleep(0.3)
|
||||
await server.stop()
|
||||
await self.server.stop()
|
||||
await self.puller.wait_closed()
|
||||
|
||||
async def pull_handler(self, reader, writer):
|
||||
print("puller got connection")
|
||||
unpacker = salt.utils.msgpack.Unpacker()
|
||||
while True:
|
||||
data = await reader.read(1024)
|
||||
unpacker.feed(data)
|
||||
for msg in unpacker:
|
||||
await self._pub_payload(msg)
|
||||
|
||||
def pre_fork(self, process_manager):
|
||||
"""
|
||||
|
@ -395,7 +396,10 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer):
|
|||
primarily be used to create IPC channels and create our daemon process to
|
||||
do the actual publishing
|
||||
"""
|
||||
process_manager.add_process(self.publish_daemon, name=self.__class__.__name__)
|
||||
process_manager.add_process(
|
||||
self.publish_daemon,
|
||||
args=[self.publish_payload],
|
||||
name=self.__class__.__name__)
|
||||
|
||||
async def handle_request(self, request):
|
||||
try:
|
||||
|
@ -412,29 +416,30 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer):
|
|||
while True:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
async def _connect(self):
|
||||
if self.pull_path:
|
||||
self.pub_reader, self.pub_writer = await asyncio.open_unix_connection(self.pull_path)
|
||||
else:
|
||||
self.pub_reader, self.pub_writer = await asyncio.open_connection(self.pull_host, self.pull_port)
|
||||
self._connecting = None
|
||||
|
||||
def connect(self):
|
||||
log.debug("Connect pusher %s", self.pull_path)
|
||||
self.pub_sock = salt.utils.asynchronous.SyncWrapper(
|
||||
_TCPPubServerPublisher,
|
||||
(
|
||||
self.pull_host,
|
||||
self.pull_port,
|
||||
self.pull_path,
|
||||
),
|
||||
loop_kwarg="io_loop",
|
||||
)
|
||||
self.pub_sock.connect()
|
||||
if self._connecting is None:
|
||||
self._connecting = asyncio.create_task(self._connect())
|
||||
return self._connecting
|
||||
|
||||
async def publish(self, payload, **kwargs):
|
||||
"""
|
||||
Publish "load" to minions
|
||||
"""
|
||||
if not self.pub_sock:
|
||||
self.connect()
|
||||
self.pub_sock.send(payload)
|
||||
if not self.pub_writer:
|
||||
await self.connect()
|
||||
self.pub_writer.write(salt.payload.dumps(payload))
|
||||
await self.pub_writer.drain()
|
||||
|
||||
async def publish_payload(self, package, *args):
|
||||
payload = salt.transport.frame.frame_msg(package)
|
||||
payload = salt.payload.dumps(package)
|
||||
for ws in list(self.clients):
|
||||
try:
|
||||
await ws.send_bytes(payload)
|
||||
|
@ -442,11 +447,14 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer):
|
|||
self.clients.discard(ws)
|
||||
|
||||
def close(self):
|
||||
if self.pub_sock:
|
||||
self.pub_sock.close()
|
||||
self.pub_sock = None
|
||||
if self.pub_writer:
|
||||
self.pub_writer.close()
|
||||
self.pub_writer = None
|
||||
self.pub_reader = None
|
||||
if self._run is not None:
|
||||
self._run.clear()
|
||||
if self._connecting:
|
||||
self._connecting.cancel()
|
||||
|
||||
|
||||
class RequestServer(salt.transport.base.DaemonizedRequestServer):
|
||||
|
|
20
tests/pytests/functional/transport/server/conftest.py
Normal file
20
tests/pytests/functional/transport/server/conftest.py
Normal file
|
@ -0,0 +1,20 @@
|
|||
import salt.utils.process
|
||||
|
||||
import pytest
|
||||
|
||||
def transport_ids(value):
|
||||
return "Transport({})".format(value)
|
||||
|
||||
|
||||
@pytest.fixture(params=("zeromq", "tcp", "ws"), ids=transport_ids)
|
||||
def transport(request):
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def process_manager():
|
||||
pm = salt.utils.process.ProcessManager()
|
||||
try:
|
||||
yield pm
|
||||
finally:
|
||||
pm.terminate()
|
|
@ -0,0 +1,37 @@
|
|||
import asyncio
|
||||
import salt.transport
|
||||
|
||||
|
||||
async def test_publsh_server(
|
||||
io_loop, minion_opts, master_opts, transport, process_manager
|
||||
):
|
||||
minion_opts["transport"] = master_opts["transport"] = transport
|
||||
|
||||
pub_server = salt.transport.publish_server(master_opts)
|
||||
pub_server.pre_fork(process_manager)
|
||||
await asyncio.sleep(3)
|
||||
|
||||
pub_client = salt.transport.publish_client(minion_opts, io_loop, master_opts["interface"], master_opts["publish_port"])
|
||||
await pub_client.connect()
|
||||
|
||||
# Yield to loop in order to allow pub client to connect.
|
||||
event = asyncio.Event()
|
||||
|
||||
messages = []
|
||||
|
||||
async def handle_msg(msg):
|
||||
messages.append(msg)
|
||||
event.set()
|
||||
|
||||
try:
|
||||
pub_client.on_recv(handle_msg)
|
||||
msg = b"meh"
|
||||
await pub_server.publish(msg)
|
||||
await asyncio.wait_for(event.wait(), 1)
|
||||
assert [msg] == messages
|
||||
finally:
|
||||
pub_server.close()
|
||||
pub_client.close()
|
||||
|
||||
# Yield to loop in order to allow background close methods to finish.
|
||||
await asyncio.sleep(.3)
|
|
@ -1,29 +1,10 @@
|
|||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
import salt.transport
|
||||
import salt.utils.process
|
||||
|
||||
|
||||
def transport_ids(value):
|
||||
return "Transport({})".format(value)
|
||||
|
||||
|
||||
@pytest.fixture(params=("zeromq", "tcp", "ws"), ids=transport_ids)
|
||||
def transport(request):
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def process_manager():
|
||||
pm = salt.utils.process.ProcessManager()
|
||||
try:
|
||||
yield pm
|
||||
finally:
|
||||
pm.terminate()
|
||||
|
||||
|
||||
async def test_request_server(
|
||||
io_loop, minion_opts, master_opts, transport, process_manager
|
||||
):
|
||||
|
@ -57,5 +38,5 @@ async def test_request_server(
|
|||
req_client.close()
|
||||
req_server.close()
|
||||
|
||||
# Yield to loop in order to allow cleanup methods to finish.
|
||||
# Yield to loop in order to allow background close methods to finish.
|
||||
await asyncio.sleep(0.3)
|
||||
|
|
Loading…
Add table
Reference in a new issue