Fix transport merge warts and tests

This commit is contained in:
Daniel A. Wozniak 2023-11-07 00:04:25 -07:00
parent 5aba669eb4
commit c97a6dafde
5 changed files with 122 additions and 91 deletions

View file

@ -177,34 +177,22 @@ class LoadBalancerServer(SignalHandlingProcess):
self._socket.setblocking(1)
self._socket.bind(_get_bind_addr(self.opts, "ret_port"))
self._socket.listen(self.backlog)
def run(self):
"""
Start the load balancer
"""
self._socket = _get_socket(self.opts)
self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
_set_tcp_keepalive(self._socket, self.opts)
self._socket.setblocking(1)
self._socket.bind(_get_bind_addr(self.opts, "ret_port"))
self._socket.listen(self.backlog)
while True:
try:
# Wait for a connection to occur since the socket is
# blocking.
connection, address = self._socket.accept()
# Wait for a free slot to be available to put
# the connection into.
# Sockets are picklable on Windows in Python 3.
self.socket_queue.put((connection, address), True, None)
except OSError as e:
# ECONNABORTED indicates that there was a connection
# but it was closed while still in the accept queue.
# (observed on FreeBSD).
if tornado.util.errno_from_exception(e) == errno.ECONNABORTED:
continue
raise
while True:
try:
# Wait for a connection to occur since the socket is
# blocking.
connection, address = self._socket.accept()
# Wait for a free slot to be available to put
# the connection into.
# Sockets are picklable on Windows in Python 3.
self.socket_queue.put((connection, address), True, None)
except OSError as e:
# ECONNABORTED indicates that there was a connection
# but it was closed while still in the accept queue.
# (observed on FreeBSD).
if tornado.util.errno_from_exception(e) == errno.ECONNABORTED:
continue
raise
class Resolver(tornado.netutil.DefaultLoopResolver):
@ -343,8 +331,6 @@ class TCPPubClient(salt.transport.base.PublishClient):
self._closed = False
self._stream = await self.getstream(timeout=timeout)
if self._stream:
# if not self._stream_return_running:
# self.io_loop.spawn_callback(self._stream_return)
if self.connect_callback:
self.connect_callback(True)
self.connected = True
@ -1039,7 +1025,7 @@ class PubServer(tornado.tcpserver.TCPServer):
return
self._closing = True
for client in self.clients:
client.stream.disconnect()
client.stream.close()
# pylint: disable=W1701
def __del__(self):

View file

@ -1000,13 +1000,14 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer):
ctx = self.ctx
self.ctx = None
ctx.term()
if self.daemon_monitor:
self.daemon_monitor.stop()
if self.daemon_pub_sock:
self.daemon_pub_sock.close()
if self.daemon_pull_sock:
self.daemon_pull_sock.close()
if self.daemon_monitor:
self.daemon_monitor.stop()
if self.daemon_context:
self.daemon_context.destroy(1)
self.daemon_context.term()
async def publish(self, payload, **kwargs):
@ -1019,6 +1020,7 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer):
if not self.sock:
self.connect()
await self.sock.send(payload)
# await self.sock.send(salt.payload.dumps(payload))
@property
def topic_support(self):

View file

@ -1,3 +1,4 @@
import asyncio
import multiprocessing
import socket
import threading
@ -24,10 +25,10 @@ def test_tcp_load_balancer_server(master_opts, io_loop):
worker = salt.transport.tcp.LoadBalancerWorker(queue, handler, io_loop=io_loop)
def run_loop():
io_loop.start()
loop_thread = threading.Thread(target=run_loop)
loop_thread.start()
try:
io_loop.start()
except Exception as exc:
print(f"Caught exeption {exc}")
thread = threading.Thread(target=server.run)
thread.start()
@ -41,15 +42,22 @@ def test_tcp_load_balancer_server(master_opts, io_loop):
sock.connect(("127.0.0.1", master_opts["ret_port"]))
sock.send(payload)
try:
start = time.monotonic()
start = time.monotonic()
async def check_test():
while not messages:
time.sleep(0.3)
await asyncio.sleep(0.3)
if time.monotonic() - start > 30:
assert False, "Took longer than 30 seconds to receive message"
break
io_loop.run_sync(lambda: check_test())
try:
if time.monotonic() - start > 30:
assert False, "Took longer than 30 seconds to receive message"
assert [package] == messages
finally:
server.close()
thread.join()
io_loop.stop()
worker.close()

View file

@ -1,3 +1,4 @@
import os
import threading
import time
@ -16,36 +17,42 @@ async def test_pub_channel(master_opts, minion_opts, io_loop):
master_opts["transport"] = "tcp"
minion_opts.update(master_ip="127.0.0.1", transport="tcp")
server = salt.transport.tcp.TCPPublishServer(master_opts)
server = salt.transport.tcp.TCPPublishServer(
master_opts,
pub_host="127.0.0.1",
pub_port=master_opts["publish_port"],
pull_path=os.path.join(master_opts["sock_dir"], "publish_pull.ipc"),
)
client = salt.transport.tcp.TCPPubClient(minion_opts, io_loop)
client = salt.transport.tcp.TCPPubClient(
minion_opts,
io_loop,
host="127.0.0.1",
port=master_opts["publish_port"],
)
payloads = []
publishes = []
def publish_payload(payload, callback):
server.publish_payload(payload)
async def publish_payload(payload, callback):
await server.publish_payload(payload)
payloads.append(payload)
def on_recv(message):
print("ON RECV")
async def on_recv(message):
publishes.append(message)
thread = threading.Thread(
target=server.publish_daemon,
args=(publish_payload, presence_callback, remove_presence_callback),
io_loop.add_callback(
server.publisher, publish_payload, presence_callback, remove_presence_callback
)
thread.start()
# Wait for socket to bind.
time.sleep(3)
await tornado.gen.sleep(3)
await client.connect(master_opts["publish_port"])
client.on_recv(on_recv)
print("Publish message")
server.publish({"meh": "bah"})
await server.publish({"meh": "bah"})
start = time.monotonic()
try:
@ -54,6 +61,6 @@ async def test_pub_channel(master_opts, minion_opts, io_loop):
if time.monotonic() - start > 30:
assert False, "Message not published after 30 seconds"
finally:
server.io_loop.stop()
thread.join()
server.io_loop.close(all_fds=True)
server.close()
server.pub_server.close()
client.close()

View file

@ -1,4 +1,6 @@
import asyncio
import logging
import os
import threading
import time
@ -56,84 +58,110 @@ def test_zeromq_filtering(salt_master, salt_minion):
)
def test_pub_channel(master_opts):
server = salt.transport.zeromq.PublishServer(master_opts)
async def test_pub_channel(master_opts, io_loop):
server = salt.transport.zeromq.PublishServer(
master_opts,
pub_host="127.0.0.1",
pub_port=4506,
pull_path=os.path.join(master_opts["sock_dir"], "publish_pull.ipc"),
)
payloads = []
def publish_payload(payload):
server.publish_payload(payload)
async def publish_payload(payload):
await server.publish_payload(payload)
payloads.append(payload)
thread = threading.Thread(target=server.publish_daemon, args=(publish_payload,))
thread.start()
io_loop.add_callback(
server.publisher,
publish_payload,
ioloop=io_loop,
)
server.publish({"meh": "bah"})
await asyncio.sleep(3)
await server.publish(salt.payload.dumps({"meh": "bah"}))
start = time.monotonic()
try:
while not payloads:
time.sleep(0.3)
await asyncio.sleep(0.3)
if time.monotonic() - start > 30:
assert False, "No message received after 30 seconds"
assert payloads
finally:
server.close()
server.io_loop.stop()
thread.join()
server.io_loop.close(all_fds=True)
def test_pub_channel_filtering(master_opts):
async def test_pub_channel_filtering(master_opts, io_loop):
master_opts["zmq_filtering"] = True
server = salt.transport.zeromq.PublishServer(master_opts)
server = salt.transport.zeromq.PublishServer(
master_opts,
pub_host="127.0.0.1",
pub_port=4506,
pull_path=os.path.join(master_opts["sock_dir"], "publish_pull.ipc"),
)
payloads = []
def publish_payload(payload):
server.publish_payload(payload)
async def publish_payload(payload):
await server.publish_payload(payload)
payloads.append(payload)
thread = threading.Thread(target=server.publish_daemon, args=(publish_payload,))
thread.start()
io_loop.add_callback(
server.publisher,
publish_payload,
ioloop=io_loop,
)
server.publish({"meh": "bah"})
await asyncio.sleep(3)
await server.publish(salt.payload.dumps({"meh": "bah"}))
start = time.monotonic()
try:
while not payloads:
time.sleep(0.3)
await asyncio.sleep(0.3)
if time.monotonic() - start > 30:
assert False, "No message received after 30 seconds"
finally:
server.close()
server.io_loop.stop()
thread.join()
server.io_loop.close(all_fds=True)
def test_pub_channel_filtering_topic(master_opts):
async def test_pub_channel_filtering_topic(master_opts, io_loop):
master_opts["zmq_filtering"] = True
server = salt.transport.zeromq.PublishServer(master_opts)
server = salt.transport.zeromq.PublishServer(
master_opts,
pub_host="127.0.0.1",
pub_port=4506,
pull_path=os.path.join(master_opts["sock_dir"], "publish_pull.ipc"),
)
payloads = []
def publish_payload(payload):
server.publish_payload(payload, topic_list=["meh"])
async def publish_payload(payload):
await server.publish_payload(payload, topic_list=["meh"])
payloads.append(payload)
thread = threading.Thread(target=server.publish_daemon, args=(publish_payload,))
thread.start()
io_loop.add_callback(
server.publisher,
publish_payload,
ioloop=io_loop,
)
server.publish({"meh": "bah"})
await asyncio.sleep(3)
await server.publish(salt.payload.dumps({"meh": "bah"}))
start = time.monotonic()
try:
while not payloads:
time.sleep(0.3)
await asyncio.sleep(0.3)
if time.monotonic() - start > 30:
assert False, "No message received after 30 seconds"
finally:
server.close()
server.io_loop.stop()
thread.join()
server.io_loop.close(all_fds=True)