This commit is contained in:
Daniel A. Wozniak 2023-08-18 14:34:48 -07:00 committed by Daniel Wozniak
parent fdbb4ed333
commit 24257072bb
10 changed files with 67 additions and 56 deletions

View file

@ -478,13 +478,12 @@ class AsyncPubChannel:
if callback is None:
return self.transport.on_recv(None)
@tornado.gen.coroutine
def wrap_callback(messages):
async def wrap_callback(messages):
payload = self.transport._decode_messages(messages)
decoded = yield self._decode_payload(payload)
log.debug("PubChannel received: %r", decoded)
if decoded is not None:
callback(decoded)
decoded = await self._decode_payload(payload)
log.debug("PubChannel received: %r %r", decoded, callback)
if decoded is not None and callback is not None:
await callback(decoded)
return self.transport.on_recv(wrap_callback)

View file

@ -1100,7 +1100,9 @@ class SaltAPIHandler(BaseSaltAPIHandler): # pylint: disable=W0223
minions,
is_finished,
)
print("$" * 80)
print(f"Get minion returns {events!r}")
print("$" * 80)
result = yield self.get_minion_returns(
events=events,
is_finished=is_finished,

View file

@ -353,8 +353,9 @@ class PublishClient(salt.transport.base.PublishClient):
if not isinstance(messages, dict):
# TODO: For some reason we need to decode here for things
# to work. Fix this.
body = salt.utils.msgpack.loads(messages)
body = salt.transport.frame.decode_embedded_strs(body)
body = salt.payload.loads(messages)
#body = salt.utils.msgpack.loads(messages)
#body = salt.transport.frame.decode_embedded_strs(body)
else:
body = messages
return body
@ -368,6 +369,9 @@ class PublishClient(salt.transport.base.PublishClient):
await asyncio.sleep(0.001)
if timeout == 0:
for msg in self.unpacker:
print("^" * 80)
print(f"RECV {msg!r}")
print("^" * 80)
return msg[b"body"]
try:
events, _, _ = select.select([self._stream.socket], [], [], 0)
@ -389,6 +393,9 @@ class PublishClient(salt.transport.base.PublishClient):
return
self.unpacker.feed(byts)
for msg in self.unpacker:
print("^" * 80)
print(f"RECV {msg!r}")
print("^" * 80)
return msg[b"body"]
elif timeout:
try:
@ -403,6 +410,9 @@ class PublishClient(salt.transport.base.PublishClient):
return
else:
for msg in self.unpacker:
print("^" * 80)
print(f"RECV {msg!r}")
print("^" * 80)
return msg[b"body"]
while not self._closing:
async with self._read_in_progress:
@ -420,6 +430,9 @@ class PublishClient(salt.transport.base.PublishClient):
continue
self.unpacker.feed(byts)
for msg in self.unpacker:
print("^" * 80)
print(f"RECV {msg!r}")
print("^" * 80)
return msg[b"body"]
async def on_recv_handler(self, callback):
@ -427,6 +440,7 @@ class PublishClient(salt.transport.base.PublishClient):
# Retry quickly, we may want to increase this if it's hogging cpu.
await asyncio.sleep(0.003)
while True:
print("On RECV READ")
msg = await self.recv()
if msg:
try:

View file

@ -102,6 +102,13 @@ class PublishClient(salt.transport.base.PublishClient):
)
# pylint: enable=W1701
def _decode_messages(self, messages):
if not isinstance(messages, dict):
body =salt.payload.loads(messages)
else:
body = messages
return body
async def getstream(self, **kwargs):
if self.source_ip or self.source_port:
@ -327,7 +334,6 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer):
except (KeyboardInterrupt, SystemExit):
pass
finally:
print("CLOSE")
self.close()
async def publisher(
@ -364,9 +370,7 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer):
await runner.setup()
site = aiohttp.web.SockSite(runner, sock, ssl_context=ctx)
log.info("Publisher binding to socket %s:%s", self.pub_host, self.pub_port)
print("start site")
await site.start()
print("start puller")
self._pub_payload = publish_payload
if self.pull_path:
@ -378,14 +382,12 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer):
self.puller = await asyncio.start_server(
self.pull_handler, self.pull_host, self.pull_port
)
print("puller started")
while self._run.is_set():
await asyncio.sleep(0.3)
await self.server.stop()
await self.puller.wait_closed()
async def pull_handler(self, reader, writer):
print("puller got connection")
unpacker = salt.utils.msgpack.Unpacker()
while True:
data = await reader.read(1024)

View file

@ -223,6 +223,7 @@ class PublishClient(salt.transport.base.PublishClient):
elif self.host and self.port:
if self.path:
raise Exception("A host and port or a path must be provided, not both")
self.on_recv_task = None
def close(self):
if self._closing is True:
@ -341,44 +342,28 @@ class PublishClient(salt.transport.base.PublishClient):
# raise Exception("Send not supported")
# await self._socket.send(msg)
def on_recv(self, callback):
async def on_recv_handler(self, callback):
while not self._socket:
# Retry quickly, we may want to increase this if it's hogging cpu.
await asyncio.sleep(0.003)
while True:
msg = await self.recv()
if msg:
await callback(msg)
def on_recv(self, callback):
"""
Register a callback for received messages (that we didn't initiate)
:param func callback: A function which should be called when data is received
"""
if self.on_recv_task:
# XXX: We are not awaiting this canceled task. This still needs to
# be addressed.
self.on_recv_task.cancel()
if callback is None:
callbacks = self.callbacks
self.callbacks = {}
for callback, (running, task) in callbacks.items():
running.clear()
return
self.on_recv_task = None
else:
self.on_recv_task = asyncio.create_task(self.on_recv_handler(callback))
running = asyncio.Event()
running.set()
async def consume(running):
try:
while running.is_set():
try:
msg = await self.recv(timeout=None)
except zmq.error.ZMQError as exc:
# We've disconnected just die
break
if msg:
try:
await callback(msg)
except Exception: # pylint: disable=broad-except
log.error("Exception while running callback", exc_info=True)
# log.debug("Callback done %r", callback)
except Exception as exc: # pylint: disable=broad-except
log.error(
"Exception while consuming%s %s", self.uri, exc, exc_info=True
)
task = self.io_loop.spawn_callback(consume, running)
self.callbacks[callback] = running, task
class RequestServer(salt.transport.base.DaemonizedRequestServer):

View file

@ -550,6 +550,9 @@ class SaltEvent:
try:
if not self.cpub and not self.connect_pub(timeout=wait):
break
print("%" * 80)
print(f"get event {wait}")
print("%" * 80)
raw = self.subscriber.recv(timeout=wait)
if raw is None:
break
@ -636,6 +639,9 @@ class SaltEvent:
request, it MUST subscribe the result to ensure the response is not lost
should other regions of code call get_event for other purposes.
"""
print("%" * 80)
print("GET EVENT CALLED")
print("%" * 80)
log.trace("Get event. tag: %s", tag)
assert self._run_io_loop_sync

View file

@ -1,3 +1,4 @@
import asyncio
import ctypes
import logging
import multiprocessing
@ -53,7 +54,8 @@ def transport_ids(value):
return f"transport({value})"
@pytest.fixture(params=["ws", "tcp", "zeromq"], ids=transport_ids)
#@pytest.fixture(params=["ws", "tcp", "zeromq"], ids=transport_ids)
@pytest.fixture(params=["ws",], ids=transport_ids)
def transport(request):
return request.param
@ -123,13 +125,12 @@ def master_secrets():
salt.master.SMaster.secrets.pop("aes")
@tornado.gen.coroutine
def _connect_and_publish(
async def _connect_and_publish(
io_loop, channel_minion_id, channel, server, received, timeout=60
):
yield channel.connect()
await channel.connect()
def cb(payload):
async def cb(payload):
received.append(payload)
io_loop.stop()
@ -139,7 +140,7 @@ def _connect_and_publish(
)
start = time.time()
while time.time() - start < timeout:
yield tornado.gen.sleep(1)
await asyncio.sleep(1)
io_loop.stop()
@ -158,7 +159,7 @@ def test_pub_server_channel(
req_server_channel = salt.channel.server.ReqServerChannel.factory(master_config)
req_server_channel.pre_fork(process_manager)
def handle_payload(payload):
async def handle_payload(payload):
log.debug("Payload handler got %r", payload)
req_server_channel.post_fork(handle_payload, io_loop=io_loop)

View file

@ -76,7 +76,7 @@ async def test_message_client_reconnect(config, client, server):
received = []
def handler(msg):
async def handler(msg):
received.append(msg)
client.on_recv(handler)
@ -119,5 +119,6 @@ async def test_message_client_reconnect(config, client, server):
# Close the client
client.close()
# Provide time for the on_recv task to complete
await tornado.gen.sleep(1)
await asyncio.sleep(.3)

View file

@ -19,6 +19,7 @@ async def test_get_no_mid(http_client, salt_minion, salt_sub_minion):
method="GET",
follow_redirects=False,
)
print(f"{response!r}")
response_obj = salt.utils.json.loads(response.body)
assert len(response_obj["return"]) == 1
assert isinstance(response_obj["return"][0], dict)

View file

@ -276,7 +276,7 @@ async def test_publish_client_connect_server_comes_up(transport, io_loop):
async def handler(request):
ws = aiohttp.web.WebSocketResponse()
await ws.prepare(request)
data = salt.transport.frame.frame_msg(msg, header=None)
data = salt.transport.dumps(msg)
await ws.send_bytes(data)
server = aiohttp.web.Server(handler)