Implements MessageClientPool to avoid blocking waiting for zeromq and tcp communications.

This commit is contained in:
kstreee 2017-05-31 14:45:48 +09:00
parent a10f0146a4
commit 94b9ea51eb
8 changed files with 288 additions and 38 deletions

View file

@ -785,6 +785,26 @@ what you are doing! Transports are explained in :ref:`Salt Transports
ret_port: 4606
zeromq: []
``sock_pool_size``
------------------
Default: 1
To avoid blocking waiting while writing a data to a socket, we support
socket pool for Salt applications. For example, a job with a large number
of target host list can cause long period blocking waiting. The option
is used by ZMQ and TCP transports, and the other transport methods don't
need the socket pool by definition. Most of Salt tools, including CLI,
are enough to use a single bucket of socket pool. On the other hands,
it is highly recommended to set the size of socket pool larger than 1
for other Salt applications, especially Salt API, which must write data
to socket concurrently.
.. code-block:: yaml
sock_pool_size: 15
Salt-SSH Configuration
======================

View file

@ -199,6 +199,9 @@ VALID_OPTS = {
# The directory containing unix sockets for things like the event bus
'sock_dir': str,
# The pool size of unix sockets, it is necessary to avoid blocking waiting for zeromq and tcp communications.
'sock_pool_size': int,
# Specifies how the file server should backup files, if enabled. The backups
# live in the cache dir.
'backup_mode': str,
@ -989,6 +992,7 @@ DEFAULT_MINION_OPTS = {
'grains_deep_merge': False,
'conf_file': os.path.join(salt.syspaths.CONFIG_DIR, 'minion'),
'sock_dir': os.path.join(salt.syspaths.SOCK_DIR, 'minion'),
'sock_pool_size': 1,
'backup_mode': '',
'renderer': 'yaml_jinja',
'renderer_whitelist': [],
@ -1212,6 +1216,7 @@ DEFAULT_MASTER_OPTS = {
'user': salt.utils.get_user(),
'worker_threads': 5,
'sock_dir': os.path.join(salt.syspaths.SOCK_DIR, 'master'),
'sock_pool_size': 1,
'ret_port': 4506,
'timeout': 5,
'keep_jobs': 24,
@ -2053,6 +2058,7 @@ def syndic_config(master_config_path,
'sock_dir': os.path.join(
opts['cachedir'], opts.get('syndic_sock_dir', opts['sock_dir'])
),
'sock_pool_size': master_opts['sock_pool_size'],
'cachedir': master_opts['cachedir'],
}
opts.update(syndic_opts)

View file

@ -3,9 +3,13 @@
Encapsulate the different transports available to Salt.
'''
from __future__ import absolute_import
import logging
# Import third party libs
import salt.ext.six as six
from salt.ext.six.moves import range
log = logging.getLogger(__name__)
def iter_transport_opts(opts):
@ -47,3 +51,25 @@ class Channel(object):
# salt.transport.channel.Channel.factory()
from salt.transport.client import ReqChannel
return ReqChannel.factory(opts, **kwargs)
class MessageClientPool(object):
def __init__(self, tgt, opts, args=None, kwargs=None):
sock_pool_size = opts['sock_pool_size'] if 'sock_pool_size' in opts else 1
if sock_pool_size < 1:
log.warn('sock_pool_size is not correctly set, \
the option should be greater than 0 but, {0}'.format(sock_pool_size))
sock_pool_size = 1
if args is None:
args = ()
if kwargs is None:
kwargs = {}
self.message_clients = [tgt(*args, **kwargs) for _ in range(sock_pool_size)]
def __del__(self):
while self.message_clients:
message_client = self.message_clients.pop()
if message_client is not None:
del message_client

View file

@ -267,9 +267,9 @@ class AsyncTCPReqChannel(salt.transport.client.ReqChannel):
host, port = parse.netloc.rsplit(':', 1)
self.master_addr = (host, int(port))
self._closing = False
self.message_client = SaltMessageClient(
self.opts, host, int(port), io_loop=self.io_loop,
resolver=resolver)
self.message_client = SaltMessageClientPool(self.opts,
args=(self.opts, host, int(port),),
kwargs={'io_loop': self.io_loop, 'resolver': resolver})
def close(self):
if self._closing:
@ -404,7 +404,7 @@ class AsyncTCPPubChannel(salt.transport.mixins.auth.AESPubClientMixin, salt.tran
def _do_transfer():
msg = self._package_load(self.auth.crypticle.dumps(load))
package = salt.transport.frame.frame_msg(msg, header=None)
yield self.message_client._stream.write(package)
yield self.message_client.write_to_stream(package)
raise tornado.gen.Return(True)
if force_auth or not self.auth.authenticated:
@ -494,13 +494,12 @@ class AsyncTCPPubChannel(salt.transport.mixins.auth.AESPubClientMixin, salt.tran
if not self.auth.authenticated:
yield self.auth.authenticate()
if self.auth.authenticated:
self.message_client = SaltMessageClient(
self.message_client = SaltMessageClientPool(
self.opts,
self.opts['master_ip'],
int(self.auth.creds['publish_port']),
io_loop=self.io_loop,
connect_callback=self.connect_callback,
disconnect_callback=self.disconnect_callback)
args=(self.opts, self.opts['master_ip'], int(self.auth.creds['publish_port']),),
kwargs={'io_loop': self.io_loop,
'connect_callback': self.connect_callback,
'disconnect_callback': self.disconnect_callback})
yield self.message_client.connect() # wait for the client to be connected
self.connected = True
# TODO: better exception handling...
@ -764,6 +763,42 @@ class TCPClientKeepAlive(tornado.tcpclient.TCPClient):
return stream.connect(addr)
class SaltMessageClientPool(salt.transport.MessageClientPool):
'''
Wrapper class of SaltMessageClient to avoid blocking waiting while writing data to socket.
'''
def __init__(self, opts, args=None, kwargs=None):
super(SaltMessageClientPool, self).__init__(SaltMessageClient, opts, args=args, kwargs=kwargs)
def __del__(self):
self.close()
def close(self):
for message_client in self.message_clients:
message_client.close()
@tornado.gen.coroutine
def connect(self):
futures = []
for message_client in self.message_clients:
futures.append(message_client.connect())
for future in futures:
yield future
raise tornado.gen.Return(None)
def on_recv(self, *args, **kwargs):
for message_client in self.message_clients:
message_client.on_recv(*args, **kwargs)
def send(self, *args, **kwargs):
message_clients = sorted(self.message_clients, key=lambda x: len(x.send_queue))
return message_clients[0].send(*args, **kwargs)
def write_to_stream(self, *args, **kwargs):
message_clients = sorted(self.message_clients, key=lambda x: len(x.send_queue))
return message_clients[0]._stream.write(*args, **kwargs)
# TODO consolidate with IPCClient
# TODO: limit in-flight messages.
# TODO: singleton? Something to not re-create the tcp connection so much

View file

@ -118,8 +118,9 @@ class AsyncZeroMQReqChannel(salt.transport.client.ReqChannel):
# copied. The reason is the same as the io_loop skip above.
setattr(result, key,
AsyncReqMessageClientPool(result.opts,
self.master_uri,
io_loop=result._io_loop))
args=(result.opts, self.master_uri,),
kwargs={'io_loop': self._io_loop}))
continue
setattr(result, key, copy.deepcopy(self.__dict__[key], memo))
return result
@ -156,9 +157,8 @@ class AsyncZeroMQReqChannel(salt.transport.client.ReqChannel):
# we don't need to worry about auth as a kwarg, since its a singleton
self.auth = salt.crypt.AsyncAuth(self.opts, io_loop=self._io_loop)
self.message_client = AsyncReqMessageClientPool(self.opts,
self.master_uri,
io_loop=self._io_loop,
)
args=(self.opts, self.master_uri,),
kwargs={'io_loop': self._io_loop})
def __del__(self):
'''
@ -815,32 +815,23 @@ class ZeroMQPubServerChannel(salt.transport.server.PubServerChannel):
context.term()
# TODO: unit tests!
class AsyncReqMessageClientPool(object):
def __init__(self, opts, addr, linger=0, io_loop=None, socket_pool=1):
self.opts = opts
self.addr = addr
self.linger = linger
self.io_loop = io_loop
self.socket_pool = socket_pool
self.message_clients = []
def destroy(self):
for message_client in self.message_clients:
message_client.destroy()
self.message_clients = []
class AsyncReqMessageClientPool(salt.transport.MessageClientPool):
'''
Wrapper class of AsyncReqMessageClientPool to avoid blocking waiting while writing data to socket.
'''
def __init__(self, opts, args=None, kwargs=None):
super(AsyncReqMessageClientPool, self).__init__(AsyncReqMessageClient, opts, args=args, kwargs=kwargs)
def __del__(self):
self.destroy()
def send(self, message, timeout=None, tries=3, future=None, callback=None, raw=False):
if len(self.message_clients) < self.socket_pool:
message_client = AsyncReqMessageClient(self.opts, self.addr, self.linger, self.io_loop)
self.message_clients.append(message_client)
return message_client.send(message, timeout, tries, future, callback, raw)
else:
available_clients = sorted(self.message_clients, key=lambda x: len(x.send_queue))
return available_clients[0].send(message, timeout, tries, future, callback, raw)
def destroy(self):
for message_client in self.message_clients:
message_client.destroy()
def send(self, *args, **kwargs):
message_clients = sorted(self.message_clients, key=lambda x: len(x.send_queue))
return message_clients[0].send(*args, **kwargs)
# TODO: unit tests!

View file

@ -10,7 +10,8 @@ import threading
import tornado.gen
import tornado.ioloop
from tornado.testing import AsyncTestCase
import tornado.concurrent
from tornado.testing import AsyncTestCase, gen_test
import salt.config
import salt.ext.six as six
@ -18,9 +19,12 @@ import salt.utils
import salt.transport.server
import salt.transport.client
import salt.exceptions
from salt.ext.six.moves import range
from salt.transport.tcp import SaltMessageClientPool
# Import Salt Testing libs
from salttesting import TestCase, skipIf
from salttesting.mock import MagicMock, patch
from salttesting.helpers import ensure_in_syspath
ensure_in_syspath('../')
import integration
@ -199,6 +203,81 @@ class AsyncPubChannelTest(BaseTCPPubCase, PubChannelMixin):
Tests around the publish system
'''
class SaltMessageClientPoolTest(AsyncTestCase):
def setUp(self):
super(SaltMessageClientPoolTest, self).setUp()
sock_pool_size = 5
with patch('salt.transport.tcp.SaltMessageClient.__init__', MagicMock(return_value=None)):
self.message_client_pool = SaltMessageClientPool({'sock_pool_size': sock_pool_size},
args=({}, '', 0))
self.original_message_clients = self.message_client_pool.message_clients
self.message_client_pool.message_clients = [MagicMock() for _ in range(sock_pool_size)]
def tearDown(self):
with patch('salt.transport.tcp.SaltMessageClient.close', MagicMock(return_value=None)):
del self.original_message_clients
super(SaltMessageClientPoolTest, self).tearDown()
def test_send(self):
for message_client_mock in self.message_client_pool.message_clients:
message_client_mock.send_queue = [0, 0, 0]
message_client_mock.send.return_value = []
self.assertEqual([], self.message_client_pool.send())
self.message_client_pool.message_clients[2].send_queue = [0]
self.message_client_pool.message_clients[2].send.return_value = [1]
self.assertEqual([1], self.message_client_pool.send())
def test_write_to_stream(self):
for message_client_mock in self.message_client_pool.message_clients:
message_client_mock.send_queue = [0, 0, 0]
message_client_mock._stream.write.return_value = []
self.assertEqual([], self.message_client_pool.write_to_stream(''))
self.message_client_pool.message_clients[2].send_queue = [0]
self.message_client_pool.message_clients[2]._stream.write.return_value = [1]
self.assertEqual([1], self.message_client_pool.write_to_stream(''))
def test_close(self):
for message_client_mock in self.message_client_pool.message_clients:
message_client_mock.close.return_value = None
self.message_client_pool.close()
for message_client_mock in self.message_client_pool.message_clients:
self.assertTrue(message_client_mock.close.called)
def test_on_recv(self):
for message_client_mock in self.message_client_pool.message_clients:
message_client_mock.on_recv.return_value = None
self.message_client_pool.on_recv()
for message_client_mock in self.message_client_pool.message_clients:
self.assertTrue(message_client_mock.on_recv.called)
def test_connect_all(self):
@gen_test
def test_connect(self):
yield self.message_client_pool.connect()
for message_client_mock in self.message_client_pool.message_clients:
future = tornado.concurrent.Future()
future.set_result('foo')
message_client_mock.connect.return_value = future
self.assertIsNone(test_connect(self))
def test_connect_partial(self):
@gen_test(timeout=0.1)
def test_connect(self):
yield self.message_client_pool.connect()
for idx, message_client_mock in enumerate(self.message_client_pool.message_clients):
future = tornado.concurrent.Future()
if idx % 2 == 0:
future.set_result('foo')
message_client_mock.connect.return_value = future
with self.assertRaises(tornado.ioloop.TimeoutError):
test_connect(self)
if __name__ == '__main__':
from integration import run_tests
run_tests(ClearReqTestCases, needs_daemon=False)

View file

@ -24,9 +24,12 @@ import salt.utils
import salt.transport.server
import salt.transport.client
import salt.exceptions
from salt.ext.six.moves import range
from salt.transport.zeromq import AsyncReqMessageClientPool
# Import Salt Testing libs
from salttesting import TestCase, skipIf
from salttesting.mock import MagicMock, patch
from salttesting.helpers import ensure_in_syspath
ensure_in_syspath('../')
@ -223,6 +226,41 @@ class AsyncPubChannelTest(BaseZMQPubCase, PubChannelMixin):
return zmq.eventloop.ioloop.ZMQIOLoop()
class AsyncReqMessageClientPoolTest(TestCase):
def setUp(self):
super(AsyncReqMessageClientPoolTest, self).setUp()
sock_pool_size = 5
with patch('salt.transport.zeromq.AsyncReqMessageClient.__init__', MagicMock(return_value=None)):
self.message_client_pool = AsyncReqMessageClientPool({'sock_pool_size': sock_pool_size},
args=({}, ''))
self.original_message_clients = self.message_client_pool.message_clients
self.message_client_pool.message_clients = [MagicMock() for _ in range(sock_pool_size)]
def tearDown(self):
with patch('salt.transport.zeromq.AsyncReqMessageClient.destroy', MagicMock(return_value=None)):
del self.original_message_clients
super(AsyncReqMessageClientPoolTest, self).tearDown()
def test_send(self):
for message_client_mock in self.message_client_pool.message_clients:
message_client_mock.send_queue = [0, 0, 0]
message_client_mock.send.return_value = []
self.assertEqual([], self.message_client_pool.send())
self.message_client_pool.message_clients[2].send_queue = [0]
self.message_client_pool.message_clients[2].send.return_value = [1]
self.assertEqual([1], self.message_client_pool.send())
def test_destroy(self):
for message_client_mock in self.message_client_pool.message_clients:
message_client_mock.destroy.return_value = None
self.message_client_pool.destroy()
for message_client_mock in self.message_client_pool.message_clients:
self.assertTrue(message_client_mock.destroy.called)
if __name__ == '__main__':
from integration import run_tests
run_tests(ClearReqTestCases, needs_daemon=False)

View file

@ -0,0 +1,55 @@
# -*- coding: utf-8 -*-
# Import python libs
from __future__ import absolute_import
import logging
from salt.transport import MessageClientPool
# Import Salt Testing libs
from salttesting import TestCase
from salttesting.helpers import ensure_in_syspath
log = logging.getLogger(__name__)
ensure_in_syspath('../')
class MessageClientPoolTest(TestCase):
class MockClass(object):
def __init__(self, *args, **kwargs):
self.args = args
self.kwargs = kwargs
def test_init(self):
opts = {'sock_pool_size': 10}
args = (0,)
kwargs = {'kwarg': 1}
message_client_pool = MessageClientPool(self.MockClass, opts, args=args, kwargs=kwargs)
self.assertEqual(opts['sock_pool_size'], len(message_client_pool.message_clients))
for message_client in message_client_pool.message_clients:
self.assertEqual(message_client.args, args)
self.assertEqual(message_client.kwargs, kwargs)
def test_init_without_config(self):
opts = {}
args = (0,)
kwargs = {'kwarg': 1}
message_client_pool = MessageClientPool(self.MockClass, opts, args=args, kwargs=kwargs)
# The size of pool is set as 1 by the MessageClientPool init method.
self.assertEqual(1, len(message_client_pool.message_clients))
for message_client in message_client_pool.message_clients:
self.assertEqual(message_client.args, args)
self.assertEqual(message_client.kwargs, kwargs)
def test_init_less_than_one(self):
opts = {'sock_pool_size': -1}
args = (0,)
kwargs = {'kwarg': 1}
message_client_pool = MessageClientPool(self.MockClass, opts, args=args, kwargs=kwargs)
# The size of pool is set as 1 by the MessageClientPool init method.
self.assertEqual(1, len(message_client_pool.message_clients))
for message_client in message_client_pool.message_clients:
self.assertEqual(message_client.args, args)
self.assertEqual(message_client.kwargs, kwargs)