Only warn when connect was called

This commit is contained in:
Daniel A. Wozniak 2023-11-14 18:44:59 -07:00 committed by Pedro Algarvio
parent dcc9976d9b
commit d85644015c
5 changed files with 83 additions and 7 deletions

View file

@ -97,18 +97,38 @@ def publish_client(opts, io_loop):
raise Exception("Transport type not found: {}".format(ttype))
class TransportWarning(Warning):
"""
Transport warning.
"""
class Transport:
def __init__(self, *args, **kwargs):
self._trace = "\n".join(traceback.format_stack()[:-1])
if not hasattr(self, "_closing"):
self._closing = False
if not hasattr(self, "_connect_called"):
self._connect_called = False
def connect(self, *args, **kwargs):
self._connect_called = True
# pylint: disable=W1701
def __del__(self):
if not self._closing:
"""
Warn the user if the transport's close method was never called.
If the _closing attribute is missing we won't raise a warning. This
prevents issues when class's dunder init method is called with improper
arguments, and is later getting garbage collected. Users of this class
should take care to call super() and validate the functionality with a
test.
"""
if getattr(self, "_connect_called") and not getattr(self, "_closing", True):
warnings.warn(
f"Unclosed transport {self!r} \n{self._trace}",
ResourceWarning,
f"Unclosed transport! {self!r} \n{self._trace}",
TransportWarning,
source=self,
)
@ -137,7 +157,7 @@ class RequestClient(Transport):
"""
raise NotImplementedError
def connect(self):
def connect(self): # pylint: disable=W0221
"""
Connect to the server / broker.
"""
@ -233,7 +253,7 @@ class PublishClient(Transport):
raise NotImplementedError
@salt.ext.tornado.gen.coroutine
def connect(self, publish_port, connect_callback=None, disconnect_callback=None):
def connect(self, publish_port, connect_callback=None, disconnect_callback=None): # pylint: disable=W0221
"""
Create a network connection to the the PublishServer or broker.
"""

View file

@ -231,6 +231,7 @@ class TCPPubClient(salt.transport.base.PublishClient):
@salt.ext.tornado.gen.coroutine
def connect(self, publish_port, connect_callback=None, disconnect_callback=None):
self._connect_called = True
self.publish_port = publish_port
self.message_client = MessageClient(
self.opts,
@ -1054,6 +1055,7 @@ class TCPReqClient(salt.transport.base.RequestClient):
@salt.ext.tornado.gen.coroutine
def connect(self):
self._connect_called = True
yield self.message_client.connect()
@salt.ext.tornado.gen.coroutine

View file

@ -207,6 +207,7 @@ class PublishClient(salt.transport.base.PublishClient):
# TODO: this is the time to see if we are connected, maybe use the req channel to guess?
@salt.ext.tornado.gen.coroutine
def connect(self, publish_port, connect_callback=None, disconnect_callback=None):
self._connect_called = True
self.publish_port = publish_port
log.debug(
"Connecting the Minion to the Master publish port, using the URI: %s",
@ -214,7 +215,8 @@ class PublishClient(salt.transport.base.PublishClient):
)
log.debug("%r connecting to %s", self, self.master_pub)
self._socket.connect(self.master_pub)
connect_callback(True)
if connect_callback is not None:
connect_callback(True)
@property
def master_pub(self):
@ -886,13 +888,16 @@ class RequestClient(salt.transport.base.RequestClient):
io_loop=io_loop,
)
self._closing = False
self._connect_called = False
@salt.ext.tornado.gen.coroutine
def connect(self):
self._connect_called = True
self.message_client.connect()
@salt.ext.tornado.gen.coroutine
def send(self, load, timeout=60):
self.connect()
yield self.connect()
ret = yield self.message_client.send(load, timeout=timeout)
raise salt.ext.tornado.gen.Return(ret)

View file

@ -0,0 +1,21 @@
"""
Unit tests for salt.transport.base.
"""
import pytest
import salt.transport.base
pytestmark = [
pytest.mark.core_test,
]
def test_unclosed_warning():
transport = salt.transport.base.Transport()
assert transport._closing is False
assert transport._connect_called is False
transport.connect()
assert transport._connect_called is True
with pytest.warns(salt.transport.base.TransportWarning):
del transport

View file

@ -1498,3 +1498,31 @@ def test_pub_client_init(minion_opts, io_loop):
client = salt.transport.zeromq.PublishClient(minion_opts, io_loop)
client.send(b"asf")
client.close()
async def test_unclosed_request_client(minion_opts, io_loop):
minion_opts["master_uri"] = "tcp://127.0.0.1:4506"
client = salt.transport.zeromq.RequestClient(minion_opts, io_loop)
await client.connect()
try:
assert client._closing is False
with pytest.warns(salt.transport.base.TransportWarning):
client.__del__()
finally:
client.close()
async def test_unclosed_publish_client(minion_opts, io_loop):
minion_opts["id"] = "minion"
minion_opts["__role"] = "minion"
minion_opts["master_ip"] = "127.0.0.1"
minion_opts["zmq_filtering"] = True
minion_opts["zmq_monitor"] = True
client = salt.transport.zeromq.PublishClient(minion_opts, io_loop)
await client.connect(2121)
try:
assert client._closing is False
with pytest.warns(salt.transport.base.TransportWarning):
client.__del__()
finally:
client.close()