Merge pull request #51963 from DSRCorporation/bugs/49147_ipc_subscriber

Allow multiple instances of IPCMessageSubscriber in one process
Daniel Wozniak 2019-03-27 15:48:55 -07:00 committed by GitHub
commit 7b2b5217bf
3 changed files with 236 additions and 133 deletions
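Below is a minimal usage sketch, not part of the diff, of what this change enables: two IPCMessageSubscriber instances in one process, each receiving every message published on the same socket. The socket path is illustrative, and an IPCMessagePublisher is assumed to already be serving on it.

# Minimal usage sketch (not part of the diff). Assumes Salt with this patch installed
# and an IPCMessagePublisher already serving on SOCKET_PATH; the path is illustrative.
import tornado.gen
import tornado.ioloop

import salt.transport.ipc

SOCKET_PATH = '/tmp/ipc_example.ipc'  # illustrative socket path


def make_handler(name):
    def handler(payload):
        print(name, payload)
    return handler


io_loop = tornado.ioloop.IOLoop.current()
sub1 = salt.transport.ipc.IPCMessageSubscriber(SOCKET_PATH, io_loop=io_loop)
sub2 = salt.transport.ipc.IPCMessageSubscriber(SOCKET_PATH, io_loop=io_loop)


@tornado.gen.coroutine
def start():
    yield sub1.connect()
    yield sub2.connect()
    # Both read loops are fed by the single shared IPCMessageSubscriberService,
    # so each subscriber sees every message published on the socket.
    sub1.read_async(make_handler('sub1'))
    sub2.read_async(make_handler('sub2'))


io_loop.spawn_callback(start)
io_loop.start()  # runs until interrupted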

@@ -131,6 +131,7 @@ MOCK_MODULES = [
'tornado.ioloop',
'tornado.iostream',
'tornado.netutil',
'tornado.queues',
'tornado.simple_httpclient',
'tornado.stack_context',
'tornado.web',
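The hunk above only registers 'tornado.queues' in MOCK_MODULES, presumably so the Sphinx documentation build can import the IPC transport without a real Tornado install, now that ipc.py imports tornado.queues. A hedged sketch of what such a mock list typically does (the actual mechanism in this repository may differ):

# Hedged sketch of the usual MOCK_MODULES pattern: each listed module is replaced
# with a stub in sys.modules so autodoc can import code that depends on it without
# the real dependency being installed. The real conf.py may do this differently.
import sys
from unittest import mock

MOCK_MODULES = ['tornado.queues']
for mod_name in MOCK_MODULES:
    sys.modules[mod_name] = mock.MagicMock()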


@@ -18,7 +18,8 @@ import tornado
import tornado.gen
import tornado.netutil
import tornado.concurrent
from tornado.locks import Semaphore
import tornado.queues
from tornado.locks import Lock
from tornado.ioloop import IOLoop, TimeoutError as TornadoTimeoutError
from tornado.iostream import IOStream
# Import Salt libs
@@ -582,11 +583,116 @@ class IPCMessagePublisher(object):
self.close()
class IPCMessageSubscriber(IPCClient):
class IPCMessageSubscriberService(IPCClient):
'''
A singleton IPC message subscriber service. A single instance starts once and feeds any number
of IPCMessageSubscriber instances with data, closing automatically when there are no more
subscribers.
For usage, refer to the IPCMessageSubscriber documentation.
'''
def __singleton_init__(self, socket_path, io_loop=None):
super(IPCMessageSubscriberService, self).__singleton_init__(
socket_path, io_loop=io_loop)
self.saved_data = []
self._read_in_progress = Lock()
self.handlers = weakref.WeakSet()
def _subscribe(self, handler):
self.handlers.add(handler)
def unsubscribe(self, handler):
self.handlers.discard(handler)
def _has_subscribers(self):
return bool(self.handlers)
def _feed_subscribers(self, data):
for subscriber in self.handlers:
subscriber._feed(data)
@tornado.gen.coroutine
def _read(self, timeout, callback=None):
try:
yield self._read_in_progress.acquire(timeout=0)
except tornado.gen.TimeoutError:
raise tornado.gen.Return(None)
log.debug('IPC Subscriber Service is starting to read')
# If no timeout is specified, set one so the service can periodically check
# whether any handler is still waiting for data.
if timeout is None:
timeout = 5
read_stream_future = None
while self._has_subscribers():
if read_stream_future is None:
read_stream_future = self.stream.read_bytes(4096, partial=True)
try:
wire_bytes = yield FutureWithTimeout(self.io_loop,
read_stream_future,
timeout)
read_stream_future = None
self.unpacker.feed(wire_bytes)
msgs = [msg['body'] for msg in self.unpacker]
self._feed_subscribers(msgs)
except TornadoTimeoutError:
# Continue checking whether any handlers are still waiting
# Keep 'read_stream_future' alive so it can be awaited again on the next iteration
continue
except tornado.iostream.StreamClosedError as exc:
log.trace('Subscriber disconnected from IPC %s', self.socket_path)
self._feed_subscribers([None])
break
except Exception as exc:
log.error('Exception occurred in Subscriber while handling stream: %s', exc)
self._feed_subscribers([exc])
break
log.debug('IPC Subscriber Service is stopping due to a lack of subscribers')
self._read_in_progress.release()
raise tornado.gen.Return(None)
@tornado.gen.coroutine
def read(self, handler, timeout=None):
'''
Asynchronously read messages and feed them to the subscribed handler.
:param handler: The IPCMessageSubscriber instance that will receive the data
:param timeout: Optional read timeout in seconds used by the internal read loop
'''
self._subscribe(handler)
while not self.connected():
try:
yield self.connect(timeout=5)
except tornado.iostream.StreamClosedError:
log.trace('Subscriber closed stream on IPC %s before connect', self.socket_path)
yield tornado.gen.sleep(1)
except Exception as exc:
log.error('Exception occurred while Subscriber connecting: %s', exc)
yield tornado.gen.sleep(1)
self._read(timeout)
def close(self):
'''
Routines to handle any cleanup before the instance shuts down.
Sockets and filehandles should be closed explicitly, to prevent
leaks.
'''
if not self._closing:
super(IPCMessageSubscriberService, self).close()
def __del__(self):
if IPCMessageSubscriberService in globals():
self.close()
class IPCMessageSubscriber(object):
'''
Salt IPC message subscriber
Create an IPC client to receive messages from IPC publisher
Create or reuse an IPC client to receive messages from IPC publisher
An example of a very simple IPCMessageSubscriber connecting to an IPCMessagePublisher.
This example assumes an already running IPCMessagePublisher.
@@ -615,147 +721,60 @@ class IPCMessageSubscriber(IPCClient):
# Wait for some data
package = ipc_subscriber.read_sync()
'''
def __singleton_init__(self, socket_path, io_loop=None):
super(IPCMessageSubscriber, self).__singleton_init__(
socket_path, io_loop=io_loop)
self._read_sync_future = None
self._read_stream_future = None
self._sync_ioloop_running = False
self.saved_data = []
self._sync_read_in_progress = Semaphore()
def __init__(self, socket_path, io_loop=None):
self.service = IPCMessageSubscriberService(socket_path, io_loop)
self.queue = tornado.queues.Queue()
def connected(self):
return self.service.connected()
def connect(self, callback=None, timeout=None):
return self.service.connect(callback=callback, timeout=timeout)
@tornado.gen.coroutine
def _read_sync(self, timeout):
yield self._sync_read_in_progress.acquire()
exc_to_raise = None
ret = None
try:
while True:
if self._read_stream_future is None:
self._read_stream_future = self.stream.read_bytes(4096, partial=True)
if timeout is None:
wire_bytes = yield self._read_stream_future
else:
future_with_timeout = FutureWithTimeout(
self.io_loop, self._read_stream_future, timeout)
wire_bytes = yield future_with_timeout
self._read_stream_future = None
# Remove the timeout once we get some data or an exception
# occurs. We will assume that the rest of the data is already
# there or is coming soon if an exception doesn't occur.
timeout = None
self.unpacker.feed(wire_bytes)
first = True
for framed_msg in self.unpacker:
if first:
ret = framed_msg['body']
first = False
else:
self.saved_data.append(framed_msg['body'])
if not first:
# We read at least one piece of data
break
except TornadoTimeoutError:
# In the timeout case, just return None.
# Keep 'self._read_stream_future' alive.
ret = None
except tornado.iostream.StreamClosedError as exc:
log.trace('Subscriber disconnected from IPC %s', self.socket_path)
self._read_stream_future = None
exc_to_raise = exc
except Exception as exc:
log.error('Exception occurred in Subscriber while handling stream: %s', exc)
self._read_stream_future = None
exc_to_raise = exc
if self._sync_ioloop_running:
# Stop the IO Loop so that self.io_loop.start() will return in
# read_sync().
self.io_loop.spawn_callback(self.io_loop.stop)
if exc_to_raise is not None:
raise exc_to_raise # pylint: disable=E0702
self._sync_read_in_progress.release()
raise tornado.gen.Return(ret)
def read_sync(self, timeout=None):
'''
Read a message from an IPC socket
The socket must already be connected.
The associated IO Loop must NOT be running.
:param int timeout: Timeout when receiving message
:return: message data if successful. None if timed out. Will raise an
exception for all other error conditions.
'''
if self.saved_data:
return self.saved_data.pop(0)
self._sync_ioloop_running = True
self._read_sync_future = self._read_sync(timeout)
self.io_loop.start()
self._sync_ioloop_running = False
ret_future = self._read_sync_future
self._read_sync_future = None
return ret_future.result()
def _feed(self, msgs):
for msg in msgs:
yield self.queue.put(msg)
@tornado.gen.coroutine
def _read_async(self, callback):
while not self.stream.closed():
try:
self._read_stream_future = self.stream.read_bytes(4096, partial=True)
wire_bytes = yield self._read_stream_future
self._read_stream_future = None
self.unpacker.feed(wire_bytes)
for framed_msg in self.unpacker:
body = framed_msg['body']
self.io_loop.spawn_callback(callback, body)
except tornado.iostream.StreamClosedError:
log.trace('Subscriber disconnected from IPC %s', self.socket_path)
break
except Exception as exc:
log.error('Exception occurred while Subscriber handling stream: %s', exc)
@tornado.gen.coroutine
def read_async(self, callback):
def read_async(self, callback, timeout=None):
'''
Asynchronously read messages and invoke a callback when they are ready.
:param callback: A callback with the received data
:param timeout: Optional timeout in seconds; if it expires before data arrives, None is returned
'''
while not self.connected():
self.service.read(self)
while True:
try:
yield self.connect(timeout=5)
except tornado.iostream.StreamClosedError:
log.trace('Subscriber closed stream on IPC %s before connect', self.socket_path)
yield tornado.gen.sleep(1)
except Exception as exc:
log.error('Exception occurred while Subscriber connecting: %s', exc)
yield tornado.gen.sleep(1)
yield self._read_async(callback)
if timeout is not None:
deadline = time.time() + timeout
else:
deadline = None
data = yield self.queue.get(timeout=deadline)
except tornado.gen.TimeoutError:
raise tornado.gen.Return(None)
if data is None:
break
elif isinstance(data, Exception):
raise data
elif callback:
self.service.io_loop.spawn_callback(callback, data)
else:
raise tornado.gen.Return(data)
def read_sync(self, timeout=None):
'''
Read a message from an IPC socket
The associated IO Loop must NOT be running.
:param int timeout: Timeout when receiving message
:return: message data if successful. None if timed out. Will raise an
exception for all other error conditions.
'''
return self.service.io_loop.run_sync(lambda: self.read_async(None, timeout))
def close(self):
'''
Routines to handle any cleanup before the instance shuts down.
Sockets and filehandles should be closed explicitly, to prevent
leaks.
'''
if not self._closing:
IPCClient.close(self)
# This will prevent this message from showing up:
# '[ERROR ] Future exception was never retrieved:
# StreamClosedError'
if self._read_sync_future is not None and self._read_sync_future.done():
self._read_sync_future.exception()
if self._read_stream_future is not None and self._read_stream_future.done():
self._read_stream_future.exception()
self.service.unsubscribe(self)
def __del__(self):
if IPCMessageSubscriber in globals():
self.close()
self.close()
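The heart of the rewrite is a fan-out: the singleton IPCMessageSubscriberService owns the IPC stream, hands every decoded message to each registered subscriber, and each IPCMessageSubscriber buffers its copy in its own tornado.queues.Queue, which read_async()/read_sync() then drain. Here is a standalone sketch of that pattern; the names below are illustrative and not taken from the PR.

# Standalone sketch of the fan-out pattern (illustrative, not code from the PR):
# one producer feeds a queue per subscriber; each consumer drains only its own queue.
import tornado.gen
import tornado.ioloop
import tornado.queues


class FanOut(object):
    def __init__(self):
        self.queues = []

    def subscribe(self):
        queue = tornado.queues.Queue()
        self.queues.append(queue)
        return queue

    def feed(self, msg):
        # Every subscriber receives its own copy of the message.
        for queue in self.queues:
            queue.put(msg)


@tornado.gen.coroutine
def main():
    hub = FanOut()
    q1, q2 = hub.subscribe(), hub.subscribe()
    hub.feed('TEST')
    first = yield q1.get()
    second = yield q2.get()
    print(first, second)  # prints: TEST TEST


tornado.ioloop.IOLoop.current().run_sync(main)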

@@ -8,6 +8,7 @@ from __future__ import absolute_import, print_function, unicode_literals
import os
import errno
import socket
import threading
import logging
import tornado.gen
@@ -154,3 +155,85 @@ class IPCMessageClient(BaseIPCReqCase):
self.channel.send({'stop': True})
self.wait()
self.assertEqual(self.payloads[:-1], [None, None, 'foo', 'foo'])
@skipIf(salt.utils.platform.is_windows(), 'Windows does not support Posix IPC')
class IPCMessagePubSubCase(tornado.testing.AsyncTestCase):
'''
Test IPC publish/subscribe messaging
'''
def setUp(self):
super(IPCMessagePubSubCase, self).setUp()
self.opts = {'ipc_write_buffer': 0}
self.socket_path = os.path.join(TMP, 'ipc_test.ipc')
self.pub_channel = self._get_pub_channel()
self.sub_channel = self._get_sub_channel()
def _get_pub_channel(self):
pub_channel = salt.transport.ipc.IPCMessagePublisher(
self.opts,
self.socket_path,
)
pub_channel.start()
return pub_channel
def _get_sub_channel(self):
sub_channel = salt.transport.ipc.IPCMessageSubscriber(
socket_path=self.socket_path,
io_loop=self.io_loop,
)
sub_channel.connect(callback=self.stop)
self.wait()
return sub_channel
def tearDown(self):
super(IPCMessagePubSubCase, self).tearDown()
try:
self.pub_channel.close()
except socket.error as exc:
if exc.errno != errno.EBADF:
# If it's not a bad file descriptor error, raise
raise
try:
self.sub_channel.close()
except socket.error as exc:
if exc.errno != errno.EBADF:
# If it's not a bad file descriptor error, raise
raise
os.unlink(self.socket_path)
del self.pub_channel
del self.sub_channel
def test_multi_client_reading(self):
# To be completely fair, let's create 2 clients.
client1 = self.sub_channel
client2 = self._get_sub_channel()
call_cnt = []
# Create a watchdog to avoid hanging in a sync loop (as the old code could)
evt = threading.Event()
def close_server():
if evt.wait(1):
return
client2.close()
self.stop()
watchdog = threading.Thread(target=close_server)
watchdog.start()
# Runs in ioloop thread so we're safe from race conditions here
def handler(raw):
call_cnt.append(raw)
if len(call_cnt) >= 2:
evt.set()
self.stop()
# Now have both clients wait for data at once
client1.read_async(handler)
client2.read_async(handler)
self.pub_channel.publish('TEST')
self.wait()
self.assertEqual(len(call_cnt), 2)
self.assertEqual(call_cnt[0], 'TEST')
self.assertEqual(call_cnt[1], 'TEST')
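To round this off, a small sketch of the synchronous path the rewritten read_sync() exposes. The socket path is illustrative, and whether a publisher is sending is left to the environment: read_sync() now wraps read_async(None, timeout) in io_loop.run_sync(), so a quiet socket returns None once the timeout expires instead of blocking forever.

# Sketch of the new read_sync() timeout behaviour (illustrative path; assumes a
# publisher may or may not be sending). The calling thread's IO loop must not be
# running, as the read_sync() docstring above states.
import salt.transport.ipc

SOCKET_PATH = '/tmp/ipc_example.ipc'  # illustrative socket path

sub = salt.transport.ipc.IPCMessageSubscriber(SOCKET_PATH)
# Internally delegates to io_loop.run_sync(lambda: read_async(None, timeout)),
# so this returns the next published message, or None after ~2 seconds of silence.
payload = sub.read_sync(timeout=2)
print(payload)
sub.close()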