mirror of
https://github.com/saltstack/salt.git
synced 2025-04-17 10:10:20 +00:00
Implements MessageClientPool to avoid blocking waiting for zeromq and tcp communications.
This commit is contained in:
parent
a10f0146a4
commit
94b9ea51eb
8 changed files with 288 additions and 38 deletions
|
@ -785,6 +785,26 @@ what you are doing! Transports are explained in :ref:`Salt Transports
|
|||
ret_port: 4606
|
||||
zeromq: []
|
||||
|
||||
``sock_pool_size``
|
||||
------------------
|
||||
|
||||
Default: 1
|
||||
|
||||
To avoid blocking waiting while writing a data to a socket, we support
|
||||
socket pool for Salt applications. For example, a job with a large number
|
||||
of target host list can cause long period blocking waiting. The option
|
||||
is used by ZMQ and TCP transports, and the other transport methods don't
|
||||
need the socket pool by definition. Most of Salt tools, including CLI,
|
||||
are enough to use a single bucket of socket pool. On the other hands,
|
||||
it is highly recommended to set the size of socket pool larger than 1
|
||||
for other Salt applications, especially Salt API, which must write data
|
||||
to socket concurrently.
|
||||
|
||||
.. code-block:: yaml
|
||||
|
||||
sock_pool_size: 15
|
||||
|
||||
|
||||
Salt-SSH Configuration
|
||||
======================
|
||||
|
||||
|
|
|
@ -199,6 +199,9 @@ VALID_OPTS = {
|
|||
# The directory containing unix sockets for things like the event bus
|
||||
'sock_dir': str,
|
||||
|
||||
# The pool size of unix sockets, it is necessary to avoid blocking waiting for zeromq and tcp communications.
|
||||
'sock_pool_size': int,
|
||||
|
||||
# Specifies how the file server should backup files, if enabled. The backups
|
||||
# live in the cache dir.
|
||||
'backup_mode': str,
|
||||
|
@ -989,6 +992,7 @@ DEFAULT_MINION_OPTS = {
|
|||
'grains_deep_merge': False,
|
||||
'conf_file': os.path.join(salt.syspaths.CONFIG_DIR, 'minion'),
|
||||
'sock_dir': os.path.join(salt.syspaths.SOCK_DIR, 'minion'),
|
||||
'sock_pool_size': 1,
|
||||
'backup_mode': '',
|
||||
'renderer': 'yaml_jinja',
|
||||
'renderer_whitelist': [],
|
||||
|
@ -1212,6 +1216,7 @@ DEFAULT_MASTER_OPTS = {
|
|||
'user': salt.utils.get_user(),
|
||||
'worker_threads': 5,
|
||||
'sock_dir': os.path.join(salt.syspaths.SOCK_DIR, 'master'),
|
||||
'sock_pool_size': 1,
|
||||
'ret_port': 4506,
|
||||
'timeout': 5,
|
||||
'keep_jobs': 24,
|
||||
|
@ -2053,6 +2058,7 @@ def syndic_config(master_config_path,
|
|||
'sock_dir': os.path.join(
|
||||
opts['cachedir'], opts.get('syndic_sock_dir', opts['sock_dir'])
|
||||
),
|
||||
'sock_pool_size': master_opts['sock_pool_size'],
|
||||
'cachedir': master_opts['cachedir'],
|
||||
}
|
||||
opts.update(syndic_opts)
|
||||
|
|
|
@ -3,9 +3,13 @@
|
|||
Encapsulate the different transports available to Salt.
|
||||
'''
|
||||
from __future__ import absolute_import
|
||||
import logging
|
||||
|
||||
# Import third party libs
|
||||
import salt.ext.six as six
|
||||
from salt.ext.six.moves import range
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def iter_transport_opts(opts):
|
||||
|
@ -47,3 +51,25 @@ class Channel(object):
|
|||
# salt.transport.channel.Channel.factory()
|
||||
from salt.transport.client import ReqChannel
|
||||
return ReqChannel.factory(opts, **kwargs)
|
||||
|
||||
|
||||
class MessageClientPool(object):
|
||||
def __init__(self, tgt, opts, args=None, kwargs=None):
|
||||
sock_pool_size = opts['sock_pool_size'] if 'sock_pool_size' in opts else 1
|
||||
if sock_pool_size < 1:
|
||||
log.warn('sock_pool_size is not correctly set, \
|
||||
the option should be greater than 0 but, {0}'.format(sock_pool_size))
|
||||
sock_pool_size = 1
|
||||
|
||||
if args is None:
|
||||
args = ()
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
|
||||
self.message_clients = [tgt(*args, **kwargs) for _ in range(sock_pool_size)]
|
||||
|
||||
def __del__(self):
|
||||
while self.message_clients:
|
||||
message_client = self.message_clients.pop()
|
||||
if message_client is not None:
|
||||
del message_client
|
||||
|
|
|
@ -267,9 +267,9 @@ class AsyncTCPReqChannel(salt.transport.client.ReqChannel):
|
|||
host, port = parse.netloc.rsplit(':', 1)
|
||||
self.master_addr = (host, int(port))
|
||||
self._closing = False
|
||||
self.message_client = SaltMessageClient(
|
||||
self.opts, host, int(port), io_loop=self.io_loop,
|
||||
resolver=resolver)
|
||||
self.message_client = SaltMessageClientPool(self.opts,
|
||||
args=(self.opts, host, int(port),),
|
||||
kwargs={'io_loop': self.io_loop, 'resolver': resolver})
|
||||
|
||||
def close(self):
|
||||
if self._closing:
|
||||
|
@ -404,7 +404,7 @@ class AsyncTCPPubChannel(salt.transport.mixins.auth.AESPubClientMixin, salt.tran
|
|||
def _do_transfer():
|
||||
msg = self._package_load(self.auth.crypticle.dumps(load))
|
||||
package = salt.transport.frame.frame_msg(msg, header=None)
|
||||
yield self.message_client._stream.write(package)
|
||||
yield self.message_client.write_to_stream(package)
|
||||
raise tornado.gen.Return(True)
|
||||
|
||||
if force_auth or not self.auth.authenticated:
|
||||
|
@ -494,13 +494,12 @@ class AsyncTCPPubChannel(salt.transport.mixins.auth.AESPubClientMixin, salt.tran
|
|||
if not self.auth.authenticated:
|
||||
yield self.auth.authenticate()
|
||||
if self.auth.authenticated:
|
||||
self.message_client = SaltMessageClient(
|
||||
self.message_client = SaltMessageClientPool(
|
||||
self.opts,
|
||||
self.opts['master_ip'],
|
||||
int(self.auth.creds['publish_port']),
|
||||
io_loop=self.io_loop,
|
||||
connect_callback=self.connect_callback,
|
||||
disconnect_callback=self.disconnect_callback)
|
||||
args=(self.opts, self.opts['master_ip'], int(self.auth.creds['publish_port']),),
|
||||
kwargs={'io_loop': self.io_loop,
|
||||
'connect_callback': self.connect_callback,
|
||||
'disconnect_callback': self.disconnect_callback})
|
||||
yield self.message_client.connect() # wait for the client to be connected
|
||||
self.connected = True
|
||||
# TODO: better exception handling...
|
||||
|
@ -764,6 +763,42 @@ class TCPClientKeepAlive(tornado.tcpclient.TCPClient):
|
|||
return stream.connect(addr)
|
||||
|
||||
|
||||
class SaltMessageClientPool(salt.transport.MessageClientPool):
|
||||
'''
|
||||
Wrapper class of SaltMessageClient to avoid blocking waiting while writing data to socket.
|
||||
'''
|
||||
def __init__(self, opts, args=None, kwargs=None):
|
||||
super(SaltMessageClientPool, self).__init__(SaltMessageClient, opts, args=args, kwargs=kwargs)
|
||||
|
||||
def __del__(self):
|
||||
self.close()
|
||||
|
||||
def close(self):
|
||||
for message_client in self.message_clients:
|
||||
message_client.close()
|
||||
|
||||
@tornado.gen.coroutine
|
||||
def connect(self):
|
||||
futures = []
|
||||
for message_client in self.message_clients:
|
||||
futures.append(message_client.connect())
|
||||
for future in futures:
|
||||
yield future
|
||||
raise tornado.gen.Return(None)
|
||||
|
||||
def on_recv(self, *args, **kwargs):
|
||||
for message_client in self.message_clients:
|
||||
message_client.on_recv(*args, **kwargs)
|
||||
|
||||
def send(self, *args, **kwargs):
|
||||
message_clients = sorted(self.message_clients, key=lambda x: len(x.send_queue))
|
||||
return message_clients[0].send(*args, **kwargs)
|
||||
|
||||
def write_to_stream(self, *args, **kwargs):
|
||||
message_clients = sorted(self.message_clients, key=lambda x: len(x.send_queue))
|
||||
return message_clients[0]._stream.write(*args, **kwargs)
|
||||
|
||||
|
||||
# TODO consolidate with IPCClient
|
||||
# TODO: limit in-flight messages.
|
||||
# TODO: singleton? Something to not re-create the tcp connection so much
|
||||
|
|
|
@ -118,8 +118,9 @@ class AsyncZeroMQReqChannel(salt.transport.client.ReqChannel):
|
|||
# copied. The reason is the same as the io_loop skip above.
|
||||
setattr(result, key,
|
||||
AsyncReqMessageClientPool(result.opts,
|
||||
self.master_uri,
|
||||
io_loop=result._io_loop))
|
||||
args=(result.opts, self.master_uri,),
|
||||
kwargs={'io_loop': self._io_loop}))
|
||||
|
||||
continue
|
||||
setattr(result, key, copy.deepcopy(self.__dict__[key], memo))
|
||||
return result
|
||||
|
@ -156,9 +157,8 @@ class AsyncZeroMQReqChannel(salt.transport.client.ReqChannel):
|
|||
# we don't need to worry about auth as a kwarg, since its a singleton
|
||||
self.auth = salt.crypt.AsyncAuth(self.opts, io_loop=self._io_loop)
|
||||
self.message_client = AsyncReqMessageClientPool(self.opts,
|
||||
self.master_uri,
|
||||
io_loop=self._io_loop,
|
||||
)
|
||||
args=(self.opts, self.master_uri,),
|
||||
kwargs={'io_loop': self._io_loop})
|
||||
|
||||
def __del__(self):
|
||||
'''
|
||||
|
@ -815,32 +815,23 @@ class ZeroMQPubServerChannel(salt.transport.server.PubServerChannel):
|
|||
context.term()
|
||||
|
||||
|
||||
# TODO: unit tests!
|
||||
class AsyncReqMessageClientPool(object):
|
||||
def __init__(self, opts, addr, linger=0, io_loop=None, socket_pool=1):
|
||||
self.opts = opts
|
||||
self.addr = addr
|
||||
self.linger = linger
|
||||
self.io_loop = io_loop
|
||||
self.socket_pool = socket_pool
|
||||
self.message_clients = []
|
||||
|
||||
def destroy(self):
|
||||
for message_client in self.message_clients:
|
||||
message_client.destroy()
|
||||
self.message_clients = []
|
||||
class AsyncReqMessageClientPool(salt.transport.MessageClientPool):
|
||||
'''
|
||||
Wrapper class of AsyncReqMessageClientPool to avoid blocking waiting while writing data to socket.
|
||||
'''
|
||||
def __init__(self, opts, args=None, kwargs=None):
|
||||
super(AsyncReqMessageClientPool, self).__init__(AsyncReqMessageClient, opts, args=args, kwargs=kwargs)
|
||||
|
||||
def __del__(self):
|
||||
self.destroy()
|
||||
|
||||
def send(self, message, timeout=None, tries=3, future=None, callback=None, raw=False):
|
||||
if len(self.message_clients) < self.socket_pool:
|
||||
message_client = AsyncReqMessageClient(self.opts, self.addr, self.linger, self.io_loop)
|
||||
self.message_clients.append(message_client)
|
||||
return message_client.send(message, timeout, tries, future, callback, raw)
|
||||
else:
|
||||
available_clients = sorted(self.message_clients, key=lambda x: len(x.send_queue))
|
||||
return available_clients[0].send(message, timeout, tries, future, callback, raw)
|
||||
def destroy(self):
|
||||
for message_client in self.message_clients:
|
||||
message_client.destroy()
|
||||
|
||||
def send(self, *args, **kwargs):
|
||||
message_clients = sorted(self.message_clients, key=lambda x: len(x.send_queue))
|
||||
return message_clients[0].send(*args, **kwargs)
|
||||
|
||||
|
||||
# TODO: unit tests!
|
||||
|
|
|
@ -10,7 +10,8 @@ import threading
|
|||
|
||||
import tornado.gen
|
||||
import tornado.ioloop
|
||||
from tornado.testing import AsyncTestCase
|
||||
import tornado.concurrent
|
||||
from tornado.testing import AsyncTestCase, gen_test
|
||||
|
||||
import salt.config
|
||||
import salt.ext.six as six
|
||||
|
@ -18,9 +19,12 @@ import salt.utils
|
|||
import salt.transport.server
|
||||
import salt.transport.client
|
||||
import salt.exceptions
|
||||
from salt.ext.six.moves import range
|
||||
from salt.transport.tcp import SaltMessageClientPool
|
||||
|
||||
# Import Salt Testing libs
|
||||
from salttesting import TestCase, skipIf
|
||||
from salttesting.mock import MagicMock, patch
|
||||
from salttesting.helpers import ensure_in_syspath
|
||||
ensure_in_syspath('../')
|
||||
import integration
|
||||
|
@ -199,6 +203,81 @@ class AsyncPubChannelTest(BaseTCPPubCase, PubChannelMixin):
|
|||
Tests around the publish system
|
||||
'''
|
||||
|
||||
|
||||
class SaltMessageClientPoolTest(AsyncTestCase):
|
||||
def setUp(self):
|
||||
super(SaltMessageClientPoolTest, self).setUp()
|
||||
sock_pool_size = 5
|
||||
with patch('salt.transport.tcp.SaltMessageClient.__init__', MagicMock(return_value=None)):
|
||||
self.message_client_pool = SaltMessageClientPool({'sock_pool_size': sock_pool_size},
|
||||
args=({}, '', 0))
|
||||
self.original_message_clients = self.message_client_pool.message_clients
|
||||
self.message_client_pool.message_clients = [MagicMock() for _ in range(sock_pool_size)]
|
||||
|
||||
def tearDown(self):
|
||||
with patch('salt.transport.tcp.SaltMessageClient.close', MagicMock(return_value=None)):
|
||||
del self.original_message_clients
|
||||
super(SaltMessageClientPoolTest, self).tearDown()
|
||||
|
||||
def test_send(self):
|
||||
for message_client_mock in self.message_client_pool.message_clients:
|
||||
message_client_mock.send_queue = [0, 0, 0]
|
||||
message_client_mock.send.return_value = []
|
||||
self.assertEqual([], self.message_client_pool.send())
|
||||
self.message_client_pool.message_clients[2].send_queue = [0]
|
||||
self.message_client_pool.message_clients[2].send.return_value = [1]
|
||||
self.assertEqual([1], self.message_client_pool.send())
|
||||
|
||||
def test_write_to_stream(self):
|
||||
for message_client_mock in self.message_client_pool.message_clients:
|
||||
message_client_mock.send_queue = [0, 0, 0]
|
||||
message_client_mock._stream.write.return_value = []
|
||||
self.assertEqual([], self.message_client_pool.write_to_stream(''))
|
||||
self.message_client_pool.message_clients[2].send_queue = [0]
|
||||
self.message_client_pool.message_clients[2]._stream.write.return_value = [1]
|
||||
self.assertEqual([1], self.message_client_pool.write_to_stream(''))
|
||||
|
||||
def test_close(self):
|
||||
for message_client_mock in self.message_client_pool.message_clients:
|
||||
message_client_mock.close.return_value = None
|
||||
self.message_client_pool.close()
|
||||
for message_client_mock in self.message_client_pool.message_clients:
|
||||
self.assertTrue(message_client_mock.close.called)
|
||||
|
||||
def test_on_recv(self):
|
||||
for message_client_mock in self.message_client_pool.message_clients:
|
||||
message_client_mock.on_recv.return_value = None
|
||||
self.message_client_pool.on_recv()
|
||||
for message_client_mock in self.message_client_pool.message_clients:
|
||||
self.assertTrue(message_client_mock.on_recv.called)
|
||||
|
||||
def test_connect_all(self):
|
||||
@gen_test
|
||||
def test_connect(self):
|
||||
yield self.message_client_pool.connect()
|
||||
|
||||
for message_client_mock in self.message_client_pool.message_clients:
|
||||
future = tornado.concurrent.Future()
|
||||
future.set_result('foo')
|
||||
message_client_mock.connect.return_value = future
|
||||
|
||||
self.assertIsNone(test_connect(self))
|
||||
|
||||
def test_connect_partial(self):
|
||||
@gen_test(timeout=0.1)
|
||||
def test_connect(self):
|
||||
yield self.message_client_pool.connect()
|
||||
|
||||
for idx, message_client_mock in enumerate(self.message_client_pool.message_clients):
|
||||
future = tornado.concurrent.Future()
|
||||
if idx % 2 == 0:
|
||||
future.set_result('foo')
|
||||
message_client_mock.connect.return_value = future
|
||||
|
||||
with self.assertRaises(tornado.ioloop.TimeoutError):
|
||||
test_connect(self)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
from integration import run_tests
|
||||
run_tests(ClearReqTestCases, needs_daemon=False)
|
||||
|
|
|
@ -24,9 +24,12 @@ import salt.utils
|
|||
import salt.transport.server
|
||||
import salt.transport.client
|
||||
import salt.exceptions
|
||||
from salt.ext.six.moves import range
|
||||
from salt.transport.zeromq import AsyncReqMessageClientPool
|
||||
|
||||
# Import Salt Testing libs
|
||||
from salttesting import TestCase, skipIf
|
||||
from salttesting.mock import MagicMock, patch
|
||||
from salttesting.helpers import ensure_in_syspath
|
||||
ensure_in_syspath('../')
|
||||
|
||||
|
@ -223,6 +226,41 @@ class AsyncPubChannelTest(BaseZMQPubCase, PubChannelMixin):
|
|||
return zmq.eventloop.ioloop.ZMQIOLoop()
|
||||
|
||||
|
||||
class AsyncReqMessageClientPoolTest(TestCase):
|
||||
def setUp(self):
|
||||
super(AsyncReqMessageClientPoolTest, self).setUp()
|
||||
sock_pool_size = 5
|
||||
with patch('salt.transport.zeromq.AsyncReqMessageClient.__init__', MagicMock(return_value=None)):
|
||||
self.message_client_pool = AsyncReqMessageClientPool({'sock_pool_size': sock_pool_size},
|
||||
args=({}, ''))
|
||||
self.original_message_clients = self.message_client_pool.message_clients
|
||||
self.message_client_pool.message_clients = [MagicMock() for _ in range(sock_pool_size)]
|
||||
|
||||
def tearDown(self):
|
||||
with patch('salt.transport.zeromq.AsyncReqMessageClient.destroy', MagicMock(return_value=None)):
|
||||
del self.original_message_clients
|
||||
super(AsyncReqMessageClientPoolTest, self).tearDown()
|
||||
|
||||
def test_send(self):
|
||||
for message_client_mock in self.message_client_pool.message_clients:
|
||||
message_client_mock.send_queue = [0, 0, 0]
|
||||
message_client_mock.send.return_value = []
|
||||
|
||||
self.assertEqual([], self.message_client_pool.send())
|
||||
|
||||
self.message_client_pool.message_clients[2].send_queue = [0]
|
||||
self.message_client_pool.message_clients[2].send.return_value = [1]
|
||||
self.assertEqual([1], self.message_client_pool.send())
|
||||
|
||||
def test_destroy(self):
|
||||
for message_client_mock in self.message_client_pool.message_clients:
|
||||
message_client_mock.destroy.return_value = None
|
||||
|
||||
self.message_client_pool.destroy()
|
||||
for message_client_mock in self.message_client_pool.message_clients:
|
||||
self.assertTrue(message_client_mock.destroy.called)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
from integration import run_tests
|
||||
run_tests(ClearReqTestCases, needs_daemon=False)
|
||||
|
|
55
tests/unit/transport_test.py
Normal file
55
tests/unit/transport_test.py
Normal file
|
@ -0,0 +1,55 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Import python libs
|
||||
from __future__ import absolute_import
|
||||
import logging
|
||||
|
||||
from salt.transport import MessageClientPool
|
||||
|
||||
# Import Salt Testing libs
|
||||
from salttesting import TestCase
|
||||
from salttesting.helpers import ensure_in_syspath
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
ensure_in_syspath('../')
|
||||
|
||||
|
||||
class MessageClientPoolTest(TestCase):
|
||||
|
||||
class MockClass(object):
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
|
||||
def test_init(self):
|
||||
opts = {'sock_pool_size': 10}
|
||||
args = (0,)
|
||||
kwargs = {'kwarg': 1}
|
||||
message_client_pool = MessageClientPool(self.MockClass, opts, args=args, kwargs=kwargs)
|
||||
self.assertEqual(opts['sock_pool_size'], len(message_client_pool.message_clients))
|
||||
for message_client in message_client_pool.message_clients:
|
||||
self.assertEqual(message_client.args, args)
|
||||
self.assertEqual(message_client.kwargs, kwargs)
|
||||
|
||||
def test_init_without_config(self):
|
||||
opts = {}
|
||||
args = (0,)
|
||||
kwargs = {'kwarg': 1}
|
||||
message_client_pool = MessageClientPool(self.MockClass, opts, args=args, kwargs=kwargs)
|
||||
# The size of pool is set as 1 by the MessageClientPool init method.
|
||||
self.assertEqual(1, len(message_client_pool.message_clients))
|
||||
for message_client in message_client_pool.message_clients:
|
||||
self.assertEqual(message_client.args, args)
|
||||
self.assertEqual(message_client.kwargs, kwargs)
|
||||
|
||||
def test_init_less_than_one(self):
|
||||
opts = {'sock_pool_size': -1}
|
||||
args = (0,)
|
||||
kwargs = {'kwarg': 1}
|
||||
message_client_pool = MessageClientPool(self.MockClass, opts, args=args, kwargs=kwargs)
|
||||
# The size of pool is set as 1 by the MessageClientPool init method.
|
||||
self.assertEqual(1, len(message_client_pool.message_clients))
|
||||
for message_client in message_client_pool.message_clients:
|
||||
self.assertEqual(message_client.args, args)
|
||||
self.assertEqual(message_client.kwargs, kwargs)
|
Loading…
Add table
Reference in a new issue