From fea99b1335303234901b474b1bc426338384b1b2 Mon Sep 17 00:00:00 2001 From: "Daniel A. Wozniak" Date: Mon, 26 Jun 2023 02:09:06 -0700 Subject: [PATCH] Fix wart in pub_connect --- salt/transport/base.py | 16 +++- salt/transport/tcp.py | 96 ++++++++++++------- salt/transport/zeromq.py | 42 +++++--- salt/utils/event.py | 43 +++++---- .../transport/ipc/test_pub_server_channel.py | 2 + tests/support/pytest/transport.py | 5 + 6 files changed, 137 insertions(+), 67 deletions(-) diff --git a/salt/transport/base.py b/salt/transport/base.py index e75ef1b1415..9938656d37f 100644 --- a/salt/transport/base.py +++ b/salt/transport/base.py @@ -1,3 +1,4 @@ +import os import tornado.gen TRANSPORTS = ( @@ -58,6 +59,19 @@ def publish_server(opts, **kwargs): ttype = opts["transport"] elif "transport" in opts.get("pillar", {}).get("master", {}): ttype = opts["pillar"]["master"]["transport"] + + if "pub_host" not in kwargs and "pub_path" not in kwargs: + kwargs["pub_host"] = opts["interface"] + if "pub_port" not in kwargs and "pub_path" not in kwargs: + kwargs["pub_port"] = opts["publish_port"] + + if "pull_host" not in kwargs and "pull_path" not in kwargs: + if opts.get("ipc_mode", "") == "tcp": + kwargs["pull_host"] = "127.0.0.1" + kwargs["pull_port"] = opts.get("tcp_master_publish_pull", 4514) + else: + kwargs["pull_path"] = os.path.join(opts["sock_dir"], "publish_pull.ipc") + # switch on available ttypes if ttype == "zeromq": import salt.transport.zeromq @@ -66,7 +80,7 @@ def publish_server(opts, **kwargs): elif ttype == "tcp": import salt.transport.tcp - return salt.transport.tcp.TCPPublishServer(opts) + return salt.transport.tcp.TCPPublishServer(opts, **kwargs) elif ttype == "local": # TODO: import salt.transport.local diff --git a/salt/transport/tcp.py b/salt/transport/tcp.py index e1cc057fe03..71a271f8d46 100644 --- a/salt/transport/tcp.py +++ b/salt/transport/tcp.py @@ -1207,24 +1207,43 @@ class TCPPublishServer(salt.transport.base.DaemonizedPublishServer): "close", ] - def __init__(self, opts): + def __init__(self, opts, **kwargs): self.opts = opts self.pub_sock = None # Set up Salt IPC server - if self.opts.get("ipc_mode", "") == "tcp": - self.pull_uri = int(self.opts.get("tcp_master_publish_pull", 4514)) - else: - self.pull_uri = os.path.join(self.opts["sock_dir"], "publish_pull.ipc") - interface = self.opts.get("interface", "127.0.0.1") - self.publish_port = self.opts.get("publish_port", 4560) - self.pub_uri = f"tcp://{interface}:{self.publish_port}" - log.error( - "TCPPubServer %r %s %s %s", - self, - self.pull_uri, - self.publish_port, - self.pub_uri, - ) + #if self.opts.get("ipc_mode", "") == "tcp": + # self.pull_uri = int(self.opts.get("tcp_master_publish_pull", 4514)) + #else: + # self.pull_uri = os.path.join(self.opts["sock_dir"], "publish_pull.ipc") + #interface = self.opts.get("interface", "127.0.0.1") + #self.publish_port = self.opts.get("publish_port", 4560) + #self.pub_uri = f"tcp://{interface}:{self.publish_port}" + self.pub_host = kwargs.get("pub_host", None) + self.pub_port = kwargs.get("pub_port", None) + self.pub_path = kwargs.get("pub_path", None) + #if pub_path: + # self.pub_path = pub_path + # self.pub_uri = f"ipc://{pub_path}" + #else: + # self.pub_uri = f"tcp://{pub_host}:{pub_port}" + + #self.publish_port = self.opts.get("publish_port", 4560) + + + self.pull_host = kwargs.get("pull_host", None) + self.pull_port = kwargs.get("pull_port", None) + self.pull_path = kwargs.get("pull_path", None) + #if pull_path: + # self.pull_uri = f"ipc://{pull_path}" + #else: + # self.pull_uri = f"tcp://{pub_host}:{pub_port}" + #log.error( + # "TCPPubServer %r %s %s", + # self, + # self.pull_uri, + # #self.publish_port, + # self.pub_uri, + #) @property def topic_support(self): @@ -1246,13 +1265,13 @@ class TCPPublishServer(salt.transport.base.DaemonizedPublishServer): Bind to the interface specified in the configuration file """ io_loop = tornado.ioloop.IOLoop() - log.error( - "TCPPubServer daemon %r %s %s %s", - self, - self.pull_uri, - self.publish_port, - self.pub_uri, - ) + #log.error( + # "TCPPubServer daemon %r %s %s %s", + # self, + # self.pull_uri, + # self.publish_port, + # self.pub_uri, + #) # Spin up the publisher self.pub_server = pub_server = PubServer( @@ -1261,15 +1280,14 @@ class TCPPublishServer(salt.transport.base.DaemonizedPublishServer): presence_callback=presence_callback, remove_presence_callback=remove_presence_callback, ) - if self.pub_uri.startswith("ipc://"): - pub_path = self.pub_uri.replace("ipc://", "") - sock = tornado.netutil.bind_unix_socket(pub_path) + if self.pub_path: + sock = tornado.netutil.bind_unix_socket(self.pub_path) else: sock = _get_socket(self.opts) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) _set_tcp_keepalive(sock, self.opts) sock.setblocking(0) - sock.bind(_get_bind_addr(self.opts, "publish_port")) + sock.bind((self.pub_host, self.pub_port)) sock.listen(self.backlog) # pub_server will take ownership of the socket pub_server.add_socket(sock) @@ -1280,14 +1298,18 @@ class TCPPublishServer(salt.transport.base.DaemonizedPublishServer): # else: # pull_uri = os.path.join(self.opts["sock_dir"], "publish_pull.ipc") self.pub_server = pub_server - if "ipc://" in self.pull_uri: - pull_uri = pull_uri = self.pull_uri.replace("ipc://", "") - log.error("WTF PULL URI %r", pull_uri) - elif "tcp://" in self.pull_uri: - log.error("Fallback to publish port %r", self.pull_uri) - pull_uri = self.publish_port + #if "ipc://" in self.pull_uri: + # pull_uri = pull_uri = self.pull_uri.replace("ipc://", "") + # log.error("WTF PULL URI %r", pull_uri) + #elif "tcp://" in self.pull_uri: + # log.error("Fallback to publish port %r", self.pull_uri) + # pull_uri = self.publish_port + #else: + # pull_uri = self.pull_uri + if self.pull_path: + pull_uri = self.pull_path else: - pull_uri = self.pull_uri + pull_uri = self.pull_port pull_sock = salt.transport.ipc.IPCMessageServer( pull_uri, @@ -1296,7 +1318,7 @@ class TCPPublishServer(salt.transport.base.DaemonizedPublishServer): ) # Securely create socket - log.warning("Starting the Salt Puller on %s", self.pull_uri) + log.warning("Starting the Salt Puller on %s", pull_uri) with salt.utils.files.set_umask(0o177): pull_sock.start() @@ -1323,8 +1345,8 @@ class TCPPublishServer(salt.transport.base.DaemonizedPublishServer): raise tornado.gen.Return(ret) def connect(self): - path = self.pull_uri.replace("ipc://", "") - log.error("Connect pusher %s", path) + #path = self.pull_uri.replace("ipc://", "") + log.error("Connect pusher %s", self.pull_path) # self.pub_sock = salt.utils.asynchronous.SyncWrapper( # salt.transport.ipc.IPCMessageClient, # (path,), @@ -1332,7 +1354,7 @@ class TCPPublishServer(salt.transport.base.DaemonizedPublishServer): # ) self.pub_sock = salt.utils.asynchronous.SyncWrapper( salt.transport.ipc.IPCMessageClient, - (path,), + (self.pull_path,), loop_kwarg="io_loop", ) # self.pub_sock = salt.transport.ipc.IPCMessageClient(path) diff --git a/salt/transport/zeromq.py b/salt/transport/zeromq.py index 347c0190497..c3ddf74c806 100644 --- a/salt/transport/zeromq.py +++ b/salt/transport/zeromq.py @@ -807,19 +807,38 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer): "close", ] - def __init__(self, opts): + def __init__(self, opts, **kwargs): self.opts = opts - if self.opts.get("ipc_mode", "") == "tcp": - self.pull_uri = "tcp://127.0.0.1:{}".format( - self.opts.get("tcp_master_publish_pull", 4514) - ) + #if self.opts.get("ipc_mode", "") == "tcp": + # self.pull_uri = "tcp://127.0.0.1:{}".format( + # self.opts.get("tcp_master_publish_pull", 4514) + # ) + #else: + # self.pull_uri = "ipc://{}".format( + # os.path.join(self.opts["sock_dir"], "publish_pull.ipc") + # ) + #interface = self.opts.get("interface", "127.0.0.1") + #publish_port = self.opts.get("publish_port", 4560) + #self.pub_uri = f"tcp://{interface}:{publish_port}" + + pub_host = kwargs.get("pub_host", None) + pub_port = kwargs.get("pub_port", None) + pub_path = kwargs.get("pub_path", None) + if pub_path: + self.pub_uri = f"ipc://{pub_path}" else: - self.pull_uri = "ipc://{}".format( - os.path.join(self.opts["sock_dir"], "publish_pull.ipc") - ) - interface = self.opts.get("interface", "127.0.0.1") - publish_port = self.opts.get("publish_port", 4560) - self.pub_uri = f"tcp://{interface}:{publish_port}" + self.pub_uri = f"tcp://{pub_host}:{pub_port}" + + + pull_host = kwargs.get("pull_host", None) + pull_port = kwargs.get("pull_port", None) + pull_path = kwargs.get("pull_path", None) + if pull_path: + self.pull_uri = f"ipc://{pull_path}" + else: + self.pull_uri = f"tcp://{pull_host}:{pull_port}" + + self.ctx = None self.sock = None self.daemon_context = None @@ -876,6 +895,7 @@ class PublishServer(salt.transport.base.DaemonizedPublishServer): salt.utils.zeromq.check_ipc_path_max_len(self.pull_uri) # Start the minion command publisher # Securely create socket + log.error("PULL URI %r PUB URI %r", self.pull_uri, self.pub_uri) with salt.utils.files.set_umask(0o177): log.info("Starting the Salt Publisher on %s", self.pub_uri) pub_sock.bind(self.pub_uri) diff --git a/salt/utils/event.py b/salt/utils/event.py index 84ee6289c40..8c9dd4e78fc 100644 --- a/salt/utils/event.py +++ b/salt/utils/event.py @@ -353,7 +353,7 @@ class SaltEvent: if self.cpub: return True - kwargs = {"io_loop": self.io_loop} + kwargs = {} if isinstance(self.puburi, int): kwargs.update(host="127.0.0.1", port=self.puburi) else: @@ -387,6 +387,7 @@ class SaltEvent: if self.subscriber is None: if "master_ip" not in self.opts: self.opts["master_ip"] = "" + kwargs["io_loop"] = self.io_loop self.subscriber = salt.transport.publish_client(self.opts, **kwargs) log.debug("Event connect subscriber %r", self.puburi) self.io_loop.spawn_callback(self.subscriber.connect) @@ -427,8 +428,11 @@ class SaltEvent: self.pusher = salt.utils.asynchronous.SyncWrapper( salt.transport.publish_server, args=(self.opts,), + kwargs={ + "pub_path": self.puburi, + "pull_path": self.pulluri, + } ) - log.error("PUSHER %r %r", self, self.pusher.io_loop.asyncio_loop) self.pusher.obj.pub_uri = "ipc://{}".format(self.puburi) self.pusher.obj.pull_uri = "ipc://{}".format(self.pulluri) # self.pusher = salt.utils.asynchronous.SyncWrapper( @@ -454,7 +458,11 @@ class SaltEvent: # self.pusher = salt.transport.ipc.IPCMessageClient( # self.pulluri, io_loop=self.io_loop # ) - self.pusher = salt.transport.publish_server(self.opts) + self.pusher = salt.transport.publish_server( + self.opts, + pub_path=self.puburi, + pull_path=self.pulluri + ) self.pusher.pub_uri = "ipc://{}".format(self.puburi) self.pusher.pull_uri = "ipc://{}".format(self.pulluri) # For the asynchronous case, the connect will be deferred to when @@ -682,21 +690,20 @@ class SaltEvent: ret = self._check_pending(tag, match_func) if ret is None: - with salt.utils.asynchronous.current_ioloop(self.io_loop): - if auto_reconnect: - raise_errors = self.raise_errors - self.raise_errors = True - while True: - try: - ret = self._get_event(wait, tag, match_func, no_block) - break - except tornado.iostream.StreamClosedError: - self.close_pub() - self.connect_pub(timeout=wait) - continue - self.raise_errors = raise_errors - else: - ret = self._get_event(wait, tag, match_func, no_block) + if auto_reconnect: + raise_errors = self.raise_errors + self.raise_errors = True + while True: + try: + ret = self._get_event(wait, tag, match_func, no_block) + break + except tornado.iostream.StreamClosedError: + self.close_pub() + self.connect_pub(timeout=wait) + continue + self.raise_errors = raise_errors + else: + ret = self._get_event(wait, tag, match_func, no_block) if ret is None or full: return ret diff --git a/tests/pytests/functional/transport/ipc/test_pub_server_channel.py b/tests/pytests/functional/transport/ipc/test_pub_server_channel.py index b4c8b0d0db5..83e4cfdf5de 100644 --- a/tests/pytests/functional/transport/ipc/test_pub_server_channel.py +++ b/tests/pytests/functional/transport/ipc/test_pub_server_channel.py @@ -76,10 +76,12 @@ def test_publish_to_pubserv_ipc(salt_master, salt_minion, transport): with PubServerChannelProcess(opts, minion_opts) as server_channel: send_num = 10000 expect = [] + log.error("Sending %d messages", send_num) for idx in range(send_num): expect.append(idx) load = {"tgt_type": "glob", "tgt": "*", "jid": idx} server_channel.publish(load) + log.error("Finished sending messages") results = server_channel.collector.results assert len(results) == send_num, "{} != {}, difference: {}".format( len(results), send_num, set(expect).difference(results) diff --git a/tests/support/pytest/transport.py b/tests/support/pytest/transport.py index c81c79c96ca..c0f01bdfbc6 100644 --- a/tests/support/pytest/transport.py +++ b/tests/support/pytest/transport.py @@ -142,6 +142,7 @@ class Collector(salt.utils.process.SignalHandlingProcess): self.start = last_msg serial = salt.payload.Serial(self.minion_config) crypticle = salt.crypt.Crypticle(self.minion_config, self.aes_key) + self.gotone = False try: while True: curr_time = time.time() @@ -175,6 +176,10 @@ class Collector(salt.utils.process.SignalHandlingProcess): if not self.zmq_filtering: log.exception("Failed to deserialize...") break + if self.gotone is False: + log.error("Collector started recieving") + self.gotone = True + log.error("Collector finished recieving") self.end = time.time() print(f"Total time {self.end - self.start}") finally: