From de97822be51c183cc6f6d60ac9126e7a22c21676 Mon Sep 17 00:00:00 2001 From: Cesar Augusto Sanchez Date: Tue, 16 Aug 2022 23:58:44 -0400 Subject: [PATCH] fix: update cassandra_cql module to allow execution from CLI. - add load_balancing_policy argument --- changelog/59909.fixed | 1 + salt/modules/cassandra_cql.py | 548 ++++++++++++++++-- .../unit/modules/test_cassandra_cql.py | 215 +++++++ 3 files changed, 708 insertions(+), 56 deletions(-) create mode 100644 changelog/59909.fixed create mode 100644 tests/pytests/unit/modules/test_cassandra_cql.py diff --git a/changelog/59909.fixed b/changelog/59909.fixed new file mode 100644 index 00000000000..b16c524c7a9 --- /dev/null +++ b/changelog/59909.fixed @@ -0,0 +1 @@ +add load balancing policy default option and ensure the module can be executed with arguments from CLI diff --git a/salt/modules/cassandra_cql.py b/salt/modules/cassandra_cql.py index d18689f39bf..8b09fb2c818 100644 --- a/salt/modules/cassandra_cql.py +++ b/salt/modules/cassandra_cql.py @@ -77,12 +77,27 @@ queries based on the internal schema of said version. # defaults to 4, if not set protocol_version: 3 + Also all configuration could be passed directly to module as arguments. + + .. code-block:: bash + + salt minion1 cassandra_cql.info contact_points=delme-nextgen-01 port=9042 cql_user=cassandra cql_pass=cassandra protocol_version=4 + + salt minion1 cassandra_cql.info ssl_options='{"ca_certs": /path/to/-ca.crt}' + + We can also provide the load balancing policy as arguments + + .. code-block:: bash + + salt minion1 cassandra_cql.cql_query "alter user cassandra with password 'cassandra2' ;" contact_points=scylladb cql_user=user1 cql_pass=password port=9142 protocol_version=4 ssl_options='{"ca_certs": path-to-client-ca.crt}' load_balancing_policy=DCAwareRoundRobinPolicy load_balancing_policy_args='{"local_dc": "datacenter1"}' + """ import logging import re import ssl +import salt.loader.context import salt.utils.json import salt.utils.versions from salt.exceptions import CommandExecutionError @@ -103,6 +118,21 @@ try: ConnectionShutdown, OperationTimedOut, ) + from cassandra.policies import ( + DCAwareRoundRobinPolicy, + ExponentialReconnectionPolicy, + HostDistance, + HostFilterPolicy, + IdentityTranslator, + LoadBalancingPolicy, + NoSpeculativeExecutionPlan, + NoSpeculativeExecutionPolicy, + RetryPolicy, + RoundRobinPolicy, + SimpleConvictionPolicy, + TokenAwarePolicy, + WhiteListRoundRobinPolicy, + ) from cassandra.query import dict_factory # pylint: enable=import-error,no-name-in-module @@ -110,6 +140,25 @@ try: except ImportError: pass +__salt_loader__ = salt.loader.context.LoaderContext() +__context__ = __salt_loader__.named_context("__context__", {}) + +LOAD_BALANCING_POLICY_MAP = { + "HostDistance": HostDistance, + "LoadBalancingPolicy": LoadBalancingPolicy, + "RoundRobinPolicy": RoundRobinPolicy, + "DCAwareRoundRobinPolicy": DCAwareRoundRobinPolicy, + "WhiteListRoundRobinPolicy": WhiteListRoundRobinPolicy, + "TokenAwarePolicy": TokenAwarePolicy, + "HostFilterPolicy": HostFilterPolicy, + "SimpleConvictionPolicy": SimpleConvictionPolicy, + "ExponentialReconnectionPolicy": ExponentialReconnectionPolicy, + "RetryPolicy": RetryPolicy, + "IdentityTranslator": IdentityTranslator, + "NoSpeculativeExecutionPlan": NoSpeculativeExecutionPlan, + "NoSpeculativeExecutionPolicy": NoSpeculativeExecutionPolicy, +} + def __virtual__(): """ @@ -127,6 +176,16 @@ def _async_log_errors(errors): log.error("Cassandra_cql asynchronous call returned: %s", errors) +def _get_lbp_policy(name, **policy_args): + """ + Returns the Load Balancer Policy class by name + """ + if name in LOAD_BALANCING_POLICY_MAP: + return LOAD_BALANCING_POLICY_MAP.get(name)(**policy_args) + else: + log.error("The policy %s is not available", name) + + def _load_properties(property_name, config_option, set_default=False, default=None): """ Load properties for the cassandra module from config or pillar. @@ -147,7 +206,7 @@ def _load_properties(property_name, config_option, set_default=False, default=No "No property specified in function, trying to load from salt configuration" ) try: - options = __salt__["config.option"]("cassandra") + options = __salt__["config.option"]("cassandra", default={}) except BaseException as e: log.error("Failed to get cassandra config options. Reason: %s", e) raise @@ -175,7 +234,9 @@ def _get_ssl_opts(): Parse out ssl_options for Cassandra cluster connection. Make sure that the ssl_version (if any specified) is valid. """ - sslopts = __salt__["config.option"]("cassandra").get("ssl_options", None) + sslopts = __salt__["config.option"]("cassandra", default={}).get( + "ssl_options", None + ) ssl_opts = {} if sslopts: @@ -198,7 +259,14 @@ def _get_ssl_opts(): def _connect( - contact_points=None, port=None, cql_user=None, cql_pass=None, protocol_version=None + contact_points=None, + port=None, + cql_user=None, + cql_pass=None, + protocol_version=None, + load_balancing_policy=None, + load_balancing_policy_args=None, + ssl_options=None, ): """ Connect to a Cassandra cluster. @@ -211,8 +279,14 @@ def _connect( :type cql_pass: str :param port: The Cassandra cluster port, defaults to None. :type port: int - :param protocol_version: Cassandra protocol version to use. - :type port: int + :param protocol_version: Cassandra protocol version to use. + :type protocol_version: int + :param load_balancing_policy: cassandra.policy class name to use + :type load_balancing_policy: str + :param load_balancing_policy_args: cassandra.policy constructor args + :type load_balancing_policy_args: dict + :param ssl_options: Cassandra protocol version to use. + :type ssl_options: dict :return: The session and cluster objects. :rtype: cluster object, session object """ @@ -239,40 +313,70 @@ def _connect( __context__["cassandra_cql_returner_session"], ) else: - - contact_points = _load_properties( - property_name=contact_points, config_option="cluster" - ) + if contact_points is None: + contact_points = _load_properties( + property_name=contact_points, config_option="cluster" + ) contact_points = ( contact_points if isinstance(contact_points, list) else contact_points.split(",") ) - port = _load_properties( - property_name=port, config_option="port", set_default=True, default=9042 - ) - cql_user = _load_properties( - property_name=cql_user, - config_option="username", - set_default=True, - default="cassandra", - ) - cql_pass = _load_properties( - property_name=cql_pass, - config_option="password", - set_default=True, - default="cassandra", - ) - protocol_version = _load_properties( - property_name=protocol_version, - config_option="protocol_version", - set_default=True, - default=4, - ) + if port is None: + port = _load_properties( + property_name=port, config_option="port", set_default=True, default=9042 + ) + if cql_user is None: + cql_user = _load_properties( + property_name=cql_user, + config_option="username", + set_default=True, + default="cassandra", + ) + if cql_pass is None: + cql_pass = _load_properties( + property_name=cql_pass, + config_option="password", + set_default=True, + default="cassandra", + ) + if protocol_version is None: + protocol_version = _load_properties( + property_name=protocol_version, + config_option="protocol_version", + set_default=True, + default=4, + ) + + if load_balancing_policy_args is None: + load_balancing_policy_args = _load_properties( + property_name=load_balancing_policy_args, + config_option="load_balancing_policy_args", + set_default=True, + default={}, + ) + + if load_balancing_policy is None: + load_balancing_policy = _load_properties( + property_name=load_balancing_policy, + config_option="load_balancing_policy", + set_default=True, + default="RoundRobinPolicy", + ) + + if load_balancing_policy_args: + lbp_policy_cls = _get_lbp_policy( + load_balancing_policy, **load_balancing_policy_args + ) + else: + lbp_policy_cls = _get_lbp_policy(load_balancing_policy) try: auth_provider = PlainTextAuthProvider(username=cql_user, password=cql_pass) - ssl_opts = _get_ssl_opts() + if ssl_options is None: + ssl_opts = _get_ssl_opts() + else: + ssl_opts = ssl_options if ssl_opts: cluster = Cluster( contact_points, @@ -280,6 +384,7 @@ def _connect( auth_provider=auth_provider, ssl_options=ssl_opts, protocol_version=protocol_version, + load_balancing_policy=lbp_policy_cls, compression=True, ) else: @@ -288,6 +393,7 @@ def _connect( port=port, auth_provider=auth_provider, protocol_version=protocol_version, + load_balancing_policy=lbp_policy_cls, compression=True, ) for recontimes in range(1, 4): @@ -312,14 +418,22 @@ def _connect( return cluster, session except TypeError: pass - except (ConnectionException, ConnectionShutdown, NoHostAvailable): + except (ConnectionException, ConnectionShutdown, NoHostAvailable) as err: log.error("Could not connect to Cassandra cluster at %s", contact_points) - raise CommandExecutionError( - "ERROR: Could not connect to Cassandra cluster." - ) + raise CommandExecutionError(str(err)) -def cql_query(query, contact_points=None, port=None, cql_user=None, cql_pass=None): +def cql_query( + query, + contact_points=None, + port=None, + cql_user=None, + cql_pass=None, + protocol_version=None, + load_balancing_policy=None, + load_balancing_policy_args=None, + ssl_options=None, +): """ Run a query on a Cassandra cluster and return a dictionary. @@ -335,6 +449,14 @@ def cql_query(query, contact_points=None, port=None, cql_user=None, cql_pass=Non :type port: int :param params: The parameters for the query, optional. :type params: str + :param protocol_version: Cassandra protocol version to use. + :type protocol_version: int + :param load_balancing_policy: cassandra.policy class name to use + :type load_balancing_policy: str + :param load_balancing_policy_args: cassandra.policy constructor args + :type load_balancing_policy_args: dict + :param ssl_options: Cassandra protocol version to use. + :type ssl_options: dict :return: A dictionary from the return values of the query :rtype: list[dict] @@ -350,6 +472,10 @@ def cql_query(query, contact_points=None, port=None, cql_user=None, cql_pass=Non port=port, cql_user=cql_user, cql_pass=cql_pass, + protocol_version=protocol_version, + load_balancing_policy=load_balancing_policy, + load_balancing_policy_args=load_balancing_policy_args, + ssl_options=ssl_options, ) except CommandExecutionError: log.critical("Could not get Cassandra cluster session.") @@ -416,6 +542,10 @@ def cql_query_with_prepare( port=None, cql_user=None, cql_pass=None, + protocol_version=None, + load_balancing_policy=None, + load_balancing_policy_args=None, + ssl_options=None, **kwargs ): """ @@ -448,6 +578,14 @@ def cql_query_with_prepare( :type port: int :param params: The parameters for the query, optional. :type params: str + :param protocol_version: Cassandra protocol version to use. + :type port: int + :param load_balancing_policy: cassandra.policy class name to use + :type load_balancing_policy: str + :param load_balancing_policy_args: cassandra.policy constructor args + :type load_balancing_policy_args: dict + :param ssl_options: Cassandra protocol version to use. + :type ssl_options: dict :return: A dictionary from the return values of the query :rtype: list[dict] @@ -472,6 +610,10 @@ def cql_query_with_prepare( port=port, cql_user=cql_user, cql_pass=cql_pass, + protocol_version=None, + load_balancing_policy=load_balancing_policy, + load_balancing_policy_args=load_balancing_policy_args, + ssl_options=None, ) except CommandExecutionError: log.critical("Could not get Cassandra cluster session.") @@ -525,7 +667,16 @@ def cql_query_with_prepare( return ret -def version(contact_points=None, port=None, cql_user=None, cql_pass=None): +def version( + contact_points=None, + port=None, + cql_user=None, + cql_pass=None, + protocol_version=None, + load_balancing_policy=None, + load_balancing_policy_args=None, + ssl_options=None, +): """ Show the Cassandra version. @@ -537,6 +688,14 @@ def version(contact_points=None, port=None, cql_user=None, cql_pass=None): :type cql_pass: str :param port: The Cassandra cluster port, defaults to None. :type port: int + :param protocol_version: Cassandra protocol version to use. + :type protocol_version: int + :param load_balancing_policy: cassandra.policy class name to use + :type load_balancing_policy: str + :param load_balancing_policy_args: cassandra.policy constructor args + :type load_balancing_policy_args: dict + :param ssl_options: Cassandra protocol version to use. + :type ssl_options: dict :return: The version for this Cassandra cluster. :rtype: str @@ -551,7 +710,17 @@ def version(contact_points=None, port=None, cql_user=None, cql_pass=None): query = "select release_version from system.local limit 1;" try: - ret = cql_query(query, contact_points, port, cql_user, cql_pass) + ret = cql_query( + query, + contact_points=contact_points, + port=port, + cql_user=cql_user, + cql_pass=cql_pass, + protocol_version=protocol_version, + load_balancing_policy=load_balancing_policy, + load_balancing_policy_args=load_balancing_policy_args, + ssl_options=ssl_options, + ) except CommandExecutionError: log.critical("Could not get Cassandra version.") raise @@ -562,7 +731,16 @@ def version(contact_points=None, port=None, cql_user=None, cql_pass=None): return ret[0].get("release_version") -def info(contact_points=None, port=None, cql_user=None, cql_pass=None): +def info( + contact_points=None, + port=None, + cql_user=None, + cql_pass=None, + protocol_version=None, + load_balancing_policy=None, + load_balancing_policy_args=None, + ssl_options=None, +): """ Show the Cassandra information for this cluster. @@ -574,6 +752,14 @@ def info(contact_points=None, port=None, cql_user=None, cql_pass=None): :type cql_pass: str :param port: The Cassandra cluster port, defaults to None. :type port: int + :param protocol_version: Cassandra protocol version to use. + :type protocol_version: int + :param load_balancing_policy: cassandra.policy class name to use + :type load_balancing_policy: str + :param load_balancing_policy_args: cassandra.policy constructor args + :type load_balancing_policy_args: dict + :param ssl_options: Cassandra protocol version to use. + :type ssl_options: dict :return: The information for this Cassandra cluster. :rtype: dict @@ -601,7 +787,17 @@ def info(contact_points=None, port=None, cql_user=None, cql_pass=None): ret = {} try: - ret = cql_query(query, contact_points, port, cql_user, cql_pass) + ret = cql_query( + query, + contact_points=contact_points, + port=port, + cql_user=cql_user, + cql_pass=cql_pass, + protocol_version=protocol_version, + load_balancing_policy=load_balancing_policy, + load_balancing_policy_args=load_balancing_policy_args, + ssl_options=ssl_options, + ) except CommandExecutionError: log.critical("Could not list Cassandra info.") raise @@ -612,7 +808,16 @@ def info(contact_points=None, port=None, cql_user=None, cql_pass=None): return ret -def list_keyspaces(contact_points=None, port=None, cql_user=None, cql_pass=None): +def list_keyspaces( + contact_points=None, + port=None, + cql_user=None, + cql_pass=None, + protocol_version=None, + load_balancing_policy=None, + load_balancing_policy_args=None, + ssl_options=None, +): """ List keyspaces in a Cassandra cluster. @@ -624,6 +829,14 @@ def list_keyspaces(contact_points=None, port=None, cql_user=None, cql_pass=None) :type cql_pass: str :param port: The Cassandra cluster port, defaults to None. :type port: int + :param protocol_version: Cassandra protocol version to use. + :type protocol_version: int + :param load_balancing_policy: cassandra.policy class name to use + :type load_balancing_policy: str + :param load_balancing_policy_args: cassandra.policy constructor args + :type load_balancing_policy_args: dict + :param ssl_options: Cassandra protocol version to use. + :type ssl_options: dict :return: The keyspaces in this Cassandra cluster. :rtype: list[dict] @@ -643,7 +856,17 @@ def list_keyspaces(contact_points=None, port=None, cql_user=None, cql_pass=None) ret = {} try: - ret = cql_query(query, contact_points, port, cql_user, cql_pass) + ret = cql_query( + query, + contact_points=contact_points, + port=port, + cql_user=cql_user, + cql_pass=cql_pass, + protocol_version=protocol_version, + load_balancing_policy=load_balancing_policy, + load_balancing_policy_args=load_balancing_policy_args, + ssl_options=ssl_options, + ) except CommandExecutionError: log.critical("Could not list keyspaces.") raise @@ -655,7 +878,15 @@ def list_keyspaces(contact_points=None, port=None, cql_user=None, cql_pass=None) def list_column_families( - keyspace=None, contact_points=None, port=None, cql_user=None, cql_pass=None + keyspace=None, + contact_points=None, + port=None, + cql_user=None, + cql_pass=None, + protocol_version=None, + load_balancing_policy=None, + load_balancing_policy_args=None, + ssl_options=None, ): """ List column families in a Cassandra cluster for all keyspaces or just the provided one. @@ -670,6 +901,14 @@ def list_column_families( :type cql_pass: str :param port: The Cassandra cluster port, defaults to None. :type port: int + :param protocol_version: Cassandra protocol version to use. + :type protocol_version: int + :param load_balancing_policy: cassandra.policy class name to use + :type load_balancing_policy: str + :param load_balancing_policy_args: cassandra.policy constructor args + :type load_balancing_policy_args: dict + :param ssl_options: Cassandra protocol version to use. + :type ssl_options: dict :return: The column families in this Cassandra cluster. :rtype: list[dict] @@ -695,7 +934,17 @@ def list_column_families( ret = {} try: - ret = cql_query(query, contact_points, port, cql_user, cql_pass) + ret = cql_query( + query, + contact_points=contact_points, + port=port, + cql_user=cql_user, + cql_pass=cql_pass, + protocol_version=protocol_version, + load_balancing_policy=load_balancing_policy, + load_balancing_policy_args=load_balancing_policy_args, + ssl_options=ssl_options, + ) except CommandExecutionError: log.critical("Could not list column families.") raise @@ -707,7 +956,15 @@ def list_column_families( def keyspace_exists( - keyspace, contact_points=None, port=None, cql_user=None, cql_pass=None + keyspace, + contact_points=None, + port=None, + cql_user=None, + cql_pass=None, + protocol_version=None, + load_balancing_policy=None, + load_balancing_policy_args=None, + ssl_options=None, ): """ Check if a keyspace exists in a Cassandra cluster. @@ -722,6 +979,14 @@ def keyspace_exists( :type cql_pass: str :param port: The Cassandra cluster port, defaults to None. :type port: int + :param protocol_version: Cassandra protocol version to use. + :type protocol_version: int + :param load_balancing_policy: cassandra.policy class name to use + :type load_balancing_policy: str + :param load_balancing_policy_args: cassandra.policy constructor args + :type load_balancing_policy_args: dict + :param ssl_options: Cassandra protocol version to use. + :type ssl_options: dict :return: The info for the keyspace or False if it does not exist. :rtype: dict @@ -743,7 +1008,17 @@ def keyspace_exists( } try: - ret = cql_query(query, contact_points, port, cql_user, cql_pass) + ret = cql_query( + query, + contact_points=contact_points, + port=port, + cql_user=cql_user, + cql_pass=cql_pass, + protocol_version=protocol_version, + load_balancing_policy=load_balancing_policy, + load_balancing_policy_args=load_balancing_policy_args, + ssl_options=ssl_options, + ) except CommandExecutionError: log.critical("Could not determine if keyspace exists.") raise @@ -763,6 +1038,10 @@ def create_keyspace( port=None, cql_user=None, cql_pass=None, + protocol_version=None, + load_balancing_policy=None, + load_balancing_policy_args=None, + ssl_options=None, ): """ Create a new keyspace in Cassandra. @@ -784,6 +1063,14 @@ def create_keyspace( :type cql_pass: str :param port: The Cassandra cluster port, defaults to None. :type port: int + :param protocol_version: Cassandra protocol version to use. + :type protocol_version: int + :param load_balancing_policy: cassandra.policy class name to use + :type load_balancing_policy: str + :param load_balancing_policy_args: cassandra.policy constructor args + :type load_balancing_policy_args: dict + :param ssl_options: Cassandra protocol version to use. + :type ssl_options: dict :return: The info for the keyspace or False if it does not exist. :rtype: dict @@ -797,7 +1084,17 @@ def create_keyspace( salt 'minion1' cassandra_cql.create_keyspace keyspace=newkeyspace replication_strategy=NetworkTopologyStrategy \ replication_datacenters='{"datacenter_1": 3, "datacenter_2": 2}' """ - existing_keyspace = keyspace_exists(keyspace, contact_points, port) + existing_keyspace = keyspace_exists( + keyspace, + contact_points=contact_points, + cql_user=cql_user, + cql_pass=cql_pass, + port=port, + protocol_version=protocol_version, + load_balancing_policy=load_balancing_policy, + load_balancing_policy_args=load_balancing_policy_args, + ssl_options=ssl_options, + ) if not existing_keyspace: # Add the strategy, replication_factor, etc. replication_map = {"class": replication_strategy} @@ -822,7 +1119,17 @@ def create_keyspace( ) try: - cql_query(query, contact_points, port, cql_user, cql_pass) + cql_query( + query, + contact_points=contact_points, + port=port, + cql_user=cql_user, + cql_pass=cql_pass, + protocol_version=protocol_version, + load_balancing_policy=load_balancing_policy, + load_balancing_policy_args=load_balancing_policy_args, + ssl_options=ssl_options, + ) except CommandExecutionError: log.critical("Could not create keyspace.") raise @@ -832,7 +1139,15 @@ def create_keyspace( def drop_keyspace( - keyspace, contact_points=None, port=None, cql_user=None, cql_pass=None + keyspace, + contact_points=None, + port=None, + cql_user=None, + cql_pass=None, + protocol_version=None, + load_balancing_policy=None, + load_balancing_policy_args=None, + ssl_options=None, ): """ Drop a keyspace if it exists in a Cassandra cluster. @@ -847,6 +1162,14 @@ def drop_keyspace( :type cql_pass: str :param port: The Cassandra cluster port, defaults to None. :type port: int + :param protocol_version: Cassandra protocol version to use. + :type protocol_version: int + :param load_balancing_policy: cassandra.policy class name to use + :type load_balancing_policy: str + :param load_balancing_policy_args: cassandra.policy constructor args + :type load_balancing_policy_args: dict + :param ssl_options: Cassandra protocol version to use. + :type ssl_options: dict :return: The info for the keyspace or False if it does not exist. :rtype: dict @@ -858,11 +1181,31 @@ def drop_keyspace( salt 'minion1' cassandra_cql.drop_keyspace keyspace=test contact_points=minion1 """ - existing_keyspace = keyspace_exists(keyspace, contact_points, port) + existing_keyspace = keyspace_exists( + keyspace, + contact_points=contact_points, + cql_user=cql_user, + cql_pass=cql_pass, + port=port, + protocol_version=protocol_version, + load_balancing_policy=load_balancing_policy, + load_balancing_policy_args=load_balancing_policy_args, + ssl_options=ssl_options, + ) if existing_keyspace: query = """drop keyspace {};""".format(keyspace) try: - cql_query(query, contact_points, port, cql_user, cql_pass) + cql_query( + query, + contact_points=contact_points, + port=port, + cql_user=cql_user, + cql_pass=cql_pass, + protocol_version=protocol_version, + load_balancing_policy=load_balancing_policy, + load_balancing_policy_args=load_balancing_policy_args, + ssl_options=ssl_options, + ) except CommandExecutionError: log.critical("Could not drop keyspace.") raise @@ -873,7 +1216,16 @@ def drop_keyspace( return True -def list_users(contact_points=None, port=None, cql_user=None, cql_pass=None): +def list_users( + contact_points=None, + port=None, + cql_user=None, + cql_pass=None, + protocol_version=None, + load_balancing_policy=None, + load_balancing_policy_args=None, + ssl_options=None, +): """ List existing users in this Cassandra cluster. @@ -885,6 +1237,14 @@ def list_users(contact_points=None, port=None, cql_user=None, cql_pass=None): :type cql_user: str :param cql_pass: The Cassandra user password if authentication is turned on. :type cql_pass: str + :param protocol_version: Cassandra protocol version to use. + :type protocol_version: int + :param load_balancing_policy: cassandra.policy class name to use + :type load_balancing_policy: str + :param load_balancing_policy_args: cassandra.policy constructor args + :type load_balancing_policy_args: dict + :param ssl_options: Cassandra protocol version to use. + :type ssl_options: dict :return: The list of existing users. :rtype: dict @@ -901,7 +1261,17 @@ def list_users(contact_points=None, port=None, cql_user=None, cql_pass=None): ret = {} try: - ret = cql_query(query, contact_points, port, cql_user, cql_pass) + ret = cql_query( + query, + contact_points=contact_points, + port=port, + cql_user=cql_user, + cql_pass=cql_pass, + protocol_version=protocol_version, + load_balancing_policy=load_balancing_policy, + load_balancing_policy_args=load_balancing_policy_args, + ssl_options=ssl_options, + ) except CommandExecutionError: log.critical("Could not list users.") raise @@ -920,6 +1290,10 @@ def create_user( port=None, cql_user=None, cql_pass=None, + protocol_version=None, + load_balancing_policy=None, + load_balancing_policy_args=None, + ssl_options=None, ): """ Create a new cassandra user with credentials and superuser status. @@ -938,6 +1312,14 @@ def create_user( :type cql_pass: str :param port: The Cassandra cluster port, defaults to None. :type port: int + :param protocol_version: Cassandra protocol version to use. + :type protocol_version: int + :param load_balancing_policy: cassandra.policy class name to use + :type load_balancing_policy: str + :param load_balancing_policy_args: cassandra.policy constructor args + :type load_balancing_policy_args: dict + :param ssl_options: Cassandra protocol version to use. + :type ssl_options: dict :return: :rtype: @@ -964,7 +1346,17 @@ def create_user( # The create user query doesn't actually return anything if the query succeeds. # If the query fails, catch the exception, log a messange and raise it again. try: - cql_query(query, contact_points, port, cql_user, cql_pass) + cql_query( + query, + contact_points=contact_points, + port=port, + cql_user=cql_user, + cql_pass=cql_pass, + protocol_version=protocol_version, + load_balancing_policy=load_balancing_policy, + load_balancing_policy_args=load_balancing_policy_args, + ssl_options=ssl_options, + ) except CommandExecutionError: log.critical("Could not create user.") raise @@ -984,6 +1376,10 @@ def list_permissions( port=None, cql_user=None, cql_pass=None, + protocol_version=None, + load_balancing_policy=None, + load_balancing_policy_args=None, + ssl_options=None, ): """ List permissions. @@ -1004,6 +1400,14 @@ def list_permissions( :type cql_pass: str :param port: The Cassandra cluster port, defaults to None. :type port: int + :param protocol_version: Cassandra protocol version to use. + :type protocol_version: int + :param load_balancing_policy: cassandra.policy class name to use + :type load_balancing_policy: str + :param load_balancing_policy_args: cassandra.policy constructor args + :type load_balancing_policy_args: dict + :param ssl_options: Cassandra protocol version to use. + :type ssl_options: dict :return: Dictionary of permissions. :rtype: dict @@ -1034,7 +1438,17 @@ def list_permissions( ret = {} try: - ret = cql_query(query, contact_points, port, cql_user, cql_pass) + ret = cql_query( + query, + contact_points=contact_points, + port=port, + cql_user=cql_user, + cql_pass=cql_pass, + protocol_version=protocol_version, + load_balancing_policy=load_balancing_policy, + load_balancing_policy_args=load_balancing_policy_args, + ssl_options=ssl_options, + ) except CommandExecutionError: log.critical("Could not list permissions.") raise @@ -1054,6 +1468,10 @@ def grant_permission( port=None, cql_user=None, cql_pass=None, + protocol_version=None, + load_balancing_policy=None, + load_balancing_policy_args=None, + ssl_options=None, ): """ Grant permissions to a user. @@ -1074,6 +1492,14 @@ def grant_permission( :type cql_pass: str :param port: The Cassandra cluster port, defaults to None. :type port: int + :param protocol_version: Cassandra protocol version to use. + :type protocol_version: int + :param load_balancing_policy: cassandra.policy class name to use + :type load_balancing_policy: str + :param load_balancing_policy_args: cassandra.policy constructor args + :type load_balancing_policy_args: dict + :param ssl_options: Cassandra protocol version to use. + :type ssl_options: dict :return: :rtype: @@ -1098,7 +1524,17 @@ def grant_permission( log.debug("Attempting to grant permissions with query '%s'", query) try: - cql_query(query, contact_points, port, cql_user, cql_pass) + cql_query( + query, + contact_points=contact_points, + port=port, + cql_user=cql_user, + cql_pass=cql_pass, + protocol_version=protocol_version, + load_balancing_policy=load_balancing_policy, + load_balancing_policy_args=load_balancing_policy_args, + ssl_options=ssl_options, + ) except CommandExecutionError: log.critical("Could not grant permissions.") raise diff --git a/tests/pytests/unit/modules/test_cassandra_cql.py b/tests/pytests/unit/modules/test_cassandra_cql.py new file mode 100644 index 00000000000..1845049d56f --- /dev/null +++ b/tests/pytests/unit/modules/test_cassandra_cql.py @@ -0,0 +1,215 @@ +""" +Test case for the cassandra_cql module +""" + + +import logging + +import pytest + +import salt.modules.cassandra_cql as cql +from salt.exceptions import CommandExecutionError +from tests.support.mock import MagicMock, patch + +log = logging.getLogger(__name__) + + +def test_cql_query(caplog): + """ + Test salt.modules.cassandra_cql.cql_query function + """ + + mock_session = MagicMock() + mock_client = MagicMock() + mock = MagicMock(return_value=(mock_session, mock_client)) + query = "query" + with patch.object(cql, "_connect", mock): + query_result = cql.cql_query(query) + + assert query_result == [] + + query = {"5.0.1": "query1", "5.0.0": "query2"} + mock_version = MagicMock(return_value="5.0.1") + mock_session = MagicMock() + mock_client = MagicMock() + mock = MagicMock(return_value=(mock_session, mock_client)) + with patch.object(cql, "version", mock_version): + with patch.object(cql, "_connect", mock): + query_result = cql.cql_query(query) + assert query_result == [] + + +def test_cql_query_with_prepare(caplog): + """ + Test salt.modules.cassandra_cql.cql_query_with_prepare function + """ + + mock_session = MagicMock() + mock_client = MagicMock() + mock = MagicMock(return_value=(mock_session, mock_client)) + query = "query" + statement_args = {"arg1": "test"} + + mock_context = MagicMock( + return_value={"cassandra_cql_prepared": {"statement_name": query}} + ) + with patch.object(cql, "__context__", mock_context): + with patch.object(cql, "_connect", mock): + query_result = cql.cql_query_with_prepare( + query, "statement_name", statement_args + ) + assert query_result == [] + + +def test_version(caplog): + """ + Test salt.modules.cassandra_cql.version function + """ + mock_cql_query = MagicMock(return_value=[{"release_version": "5.0.1"}]) + + with patch.object(cql, "cql_query", mock_cql_query): + version = cql.version() + assert version == "5.0.1" + + mock_cql_query = MagicMock(side_effect=CommandExecutionError) + with pytest.raises(CommandExecutionError) as err: + with patch.object(cql, "cql_query", mock_cql_query): + version = cql.version() + assert "{}".format(err.value) == "" + assert "Could not get Cassandra version." in caplog.text + for record in caplog.records: + assert record.levelname == "CRITICAL" + + +def test_info(): + """ + Test salt.modules.cassandra_cql.info function + """ + expected = {"result": "info"} + mock_cql_query = MagicMock(return_value=expected) + with patch.object(cql, "cql_query", mock_cql_query): + info = cql.info() + + assert info == expected + + +def test_list_keyspaces(): + """ + Test salt.modules.cassandra_cql.list_keyspaces function + """ + expected = [{"keyspace_name": "name1"}, {"keyspace_name": "name2"}] + mock_cql_query = MagicMock(return_value=expected) + with patch.object(cql, "cql_query", mock_cql_query): + keyspaces = cql.list_keyspaces() + + assert keyspaces == expected + + +def test_list_column_families(): + """ + Test salt.modules.cassandra_cql.list_column_families function + """ + expected = [{"colum_name": "column1"}, {"column_name": "column2"}] + mock_cql_query = MagicMock(return_value=expected) + with patch.object(cql, "cql_query", mock_cql_query): + columns = cql.list_column_families() + + assert columns == expected + + +def test_keyspace_exists(): + """ + Test salt.modules.cassandra_cql.keyspace_exists function + """ + expected = "keyspace" + mock_cql_query = MagicMock(return_value=expected) + with patch.object(cql, "cql_query", mock_cql_query): + exists = cql.keyspace_exists("keyspace") + + assert exists == bool(expected) + + expected = [] + mock_cql_query = MagicMock(return_value=expected) + with patch.object(cql, "cql_query", mock_cql_query): + exists = cql.keyspace_exists("keyspace") + + assert exists == bool(expected) + + +def test_create_keyspace(): + """ + Test salt.modules.cassandra_cql.create_keyspace function + """ + expected = None + mock_cql_query = MagicMock(return_value=expected) + with patch.object(cql, "cql_query", mock_cql_query): + result = cql.create_keyspace("keyspace") + + assert result == expected + + +def test_drop_keyspace(): + """ + Test salt.modules.cassandra_cql.drop_keyspace function + """ + expected = True + mock_cql_query = MagicMock(return_value=expected) + with patch.object(cql, "cql_query", mock_cql_query): + result = cql.drop_keyspace("keyspace") + + assert result == expected + + +def test_list_users(): + """ + Test salt.modules.cassandra_cql.list_users function + """ + expected = [{"name": "user1", "super": True}, {"name": "user2", "super": False}] + mock_cql_query = MagicMock(return_value=expected) + with patch.object(cql, "cql_query", mock_cql_query): + result = cql.list_users() + + assert result == expected + + +def test_create_user(): + """ + Test salt.modules.cassandra_cql.create_user function + """ + expected = True + mock_cql_query = MagicMock(return_value=expected) + with patch.object(cql, "cql_query", mock_cql_query): + result = cql.create_user("user", "password") + + assert result == expected + + +def test_list_permissions(): + """ + Test salt.modules.cassandra_cql.list_permissions function + """ + expected = [ + { + "permission": "ALTER", + "resource": "", + "role": "user1", + "username": "user1", + } + ] + mock_cql_query = MagicMock(return_value=expected) + with patch.object(cql, "cql_query", mock_cql_query): + result = cql.list_permissions(username="user1", resource="one") + + assert result == expected + + +def test_grant_permission(): + """ + Test salt.modules.cassandra_cql.grant_permission function + """ + expected = True + mock_cql_query = MagicMock(return_value=expected) + with patch.object(cql, "cql_query", mock_cql_query): + result = cql.grant_permission(username="user1", resource="one") + + assert result == expected