From d85644015cf1a461b3e77904ac617e64ca5ec5c1 Mon Sep 17 00:00:00 2001 From: "Daniel A. Wozniak" Date: Tue, 14 Nov 2023 18:44:59 -0700 Subject: [PATCH] Only warn when connect was called --- salt/transport/base.py | 30 +++++++++++++++++---- salt/transport/tcp.py | 2 ++ salt/transport/zeromq.py | 9 +++++-- tests/pytests/unit/transport/test_base.py | 21 +++++++++++++++ tests/pytests/unit/transport/test_zeromq.py | 28 +++++++++++++++++++ 5 files changed, 83 insertions(+), 7 deletions(-) create mode 100644 tests/pytests/unit/transport/test_base.py diff --git a/salt/transport/base.py b/salt/transport/base.py index 30c57fb9f97..6fa6a5fee5d 100644 --- a/salt/transport/base.py +++ b/salt/transport/base.py @@ -97,18 +97,38 @@ def publish_client(opts, io_loop): raise Exception("Transport type not found: {}".format(ttype)) +class TransportWarning(Warning): + """ + Transport warning. + """ + + class Transport: def __init__(self, *args, **kwargs): self._trace = "\n".join(traceback.format_stack()[:-1]) if not hasattr(self, "_closing"): self._closing = False + if not hasattr(self, "_connect_called"): + self._connect_called = False + + def connect(self, *args, **kwargs): + self._connect_called = True # pylint: disable=W1701 def __del__(self): - if not self._closing: + """ + Warn the user if the transport's close method was never called. + + If the _closing attribute is missing we won't raise a warning. This + prevents issues when class's dunder init method is called with improper + arguments, and is later getting garbage collected. Users of this class + should take care to call super() and validate the functionality with a + test. + """ + if getattr(self, "_connect_called") and not getattr(self, "_closing", True): warnings.warn( - f"Unclosed transport {self!r} \n{self._trace}", - ResourceWarning, + f"Unclosed transport! {self!r} \n{self._trace}", + TransportWarning, source=self, ) @@ -137,7 +157,7 @@ class RequestClient(Transport): """ raise NotImplementedError - def connect(self): + def connect(self): # pylint: disable=W0221 """ Connect to the server / broker. """ @@ -233,7 +253,7 @@ class PublishClient(Transport): raise NotImplementedError @salt.ext.tornado.gen.coroutine - def connect(self, publish_port, connect_callback=None, disconnect_callback=None): + def connect(self, publish_port, connect_callback=None, disconnect_callback=None): # pylint: disable=W0221 """ Create a network connection to the the PublishServer or broker. """ diff --git a/salt/transport/tcp.py b/salt/transport/tcp.py index 94912c89497..2c3b5644fe6 100644 --- a/salt/transport/tcp.py +++ b/salt/transport/tcp.py @@ -231,6 +231,7 @@ class TCPPubClient(salt.transport.base.PublishClient): @salt.ext.tornado.gen.coroutine def connect(self, publish_port, connect_callback=None, disconnect_callback=None): + self._connect_called = True self.publish_port = publish_port self.message_client = MessageClient( self.opts, @@ -1054,6 +1055,7 @@ class TCPReqClient(salt.transport.base.RequestClient): @salt.ext.tornado.gen.coroutine def connect(self): + self._connect_called = True yield self.message_client.connect() @salt.ext.tornado.gen.coroutine diff --git a/salt/transport/zeromq.py b/salt/transport/zeromq.py index 12454216c24..e166d346926 100644 --- a/salt/transport/zeromq.py +++ b/salt/transport/zeromq.py @@ -207,6 +207,7 @@ class PublishClient(salt.transport.base.PublishClient): # TODO: this is the time to see if we are connected, maybe use the req channel to guess? @salt.ext.tornado.gen.coroutine def connect(self, publish_port, connect_callback=None, disconnect_callback=None): + self._connect_called = True self.publish_port = publish_port log.debug( "Connecting the Minion to the Master publish port, using the URI: %s", @@ -214,7 +215,8 @@ class PublishClient(salt.transport.base.PublishClient): ) log.debug("%r connecting to %s", self, self.master_pub) self._socket.connect(self.master_pub) - connect_callback(True) + if connect_callback is not None: + connect_callback(True) @property def master_pub(self): @@ -886,13 +888,16 @@ class RequestClient(salt.transport.base.RequestClient): io_loop=io_loop, ) self._closing = False + self._connect_called = False + @salt.ext.tornado.gen.coroutine def connect(self): + self._connect_called = True self.message_client.connect() @salt.ext.tornado.gen.coroutine def send(self, load, timeout=60): - self.connect() + yield self.connect() ret = yield self.message_client.send(load, timeout=timeout) raise salt.ext.tornado.gen.Return(ret) diff --git a/tests/pytests/unit/transport/test_base.py b/tests/pytests/unit/transport/test_base.py new file mode 100644 index 00000000000..da5a6fa2615 --- /dev/null +++ b/tests/pytests/unit/transport/test_base.py @@ -0,0 +1,21 @@ +""" +Unit tests for salt.transport.base. +""" +import pytest + +import salt.transport.base + +pytestmark = [ + pytest.mark.core_test, +] + + +def test_unclosed_warning(): + + transport = salt.transport.base.Transport() + assert transport._closing is False + assert transport._connect_called is False + transport.connect() + assert transport._connect_called is True + with pytest.warns(salt.transport.base.TransportWarning): + del transport diff --git a/tests/pytests/unit/transport/test_zeromq.py b/tests/pytests/unit/transport/test_zeromq.py index 2bad5f9ae5f..61f4aaf3f84 100644 --- a/tests/pytests/unit/transport/test_zeromq.py +++ b/tests/pytests/unit/transport/test_zeromq.py @@ -1498,3 +1498,31 @@ def test_pub_client_init(minion_opts, io_loop): client = salt.transport.zeromq.PublishClient(minion_opts, io_loop) client.send(b"asf") client.close() + + +async def test_unclosed_request_client(minion_opts, io_loop): + minion_opts["master_uri"] = "tcp://127.0.0.1:4506" + client = salt.transport.zeromq.RequestClient(minion_opts, io_loop) + await client.connect() + try: + assert client._closing is False + with pytest.warns(salt.transport.base.TransportWarning): + client.__del__() + finally: + client.close() + + +async def test_unclosed_publish_client(minion_opts, io_loop): + minion_opts["id"] = "minion" + minion_opts["__role"] = "minion" + minion_opts["master_ip"] = "127.0.0.1" + minion_opts["zmq_filtering"] = True + minion_opts["zmq_monitor"] = True + client = salt.transport.zeromq.PublishClient(minion_opts, io_loop) + await client.connect(2121) + try: + assert client._closing is False + with pytest.warns(salt.transport.base.TransportWarning): + client.__del__() + finally: + client.close()