fix: update cassandra_cql module to allow execution from CLI.

- add load_balancing_policy argument
This commit is contained in:
Cesar Augusto Sanchez 2022-08-16 23:58:44 -04:00 committed by Gareth J. Greenaway
parent 5515989e84
commit de97822be5
3 changed files with 708 additions and 56 deletions

1
changelog/59909.fixed Normal file
View file

@ -0,0 +1 @@
add load balancing policy default option and ensure the module can be executed with arguments from CLI

View file

@ -77,12 +77,27 @@ queries based on the internal schema of said version.
# defaults to 4, if not set
protocol_version: 3
Also all configuration could be passed directly to module as arguments.
.. code-block:: bash
salt minion1 cassandra_cql.info contact_points=delme-nextgen-01 port=9042 cql_user=cassandra cql_pass=cassandra protocol_version=4
salt minion1 cassandra_cql.info ssl_options='{"ca_certs": /path/to/-ca.crt}'
We can also provide the load balancing policy as arguments
.. code-block:: bash
salt minion1 cassandra_cql.cql_query "alter user cassandra with password 'cassandra2' ;" contact_points=scylladb cql_user=user1 cql_pass=password port=9142 protocol_version=4 ssl_options='{"ca_certs": path-to-client-ca.crt}' load_balancing_policy=DCAwareRoundRobinPolicy load_balancing_policy_args='{"local_dc": "datacenter1"}'
"""
import logging
import re
import ssl
import salt.loader.context
import salt.utils.json
import salt.utils.versions
from salt.exceptions import CommandExecutionError
@ -103,6 +118,21 @@ try:
ConnectionShutdown,
OperationTimedOut,
)
from cassandra.policies import (
DCAwareRoundRobinPolicy,
ExponentialReconnectionPolicy,
HostDistance,
HostFilterPolicy,
IdentityTranslator,
LoadBalancingPolicy,
NoSpeculativeExecutionPlan,
NoSpeculativeExecutionPolicy,
RetryPolicy,
RoundRobinPolicy,
SimpleConvictionPolicy,
TokenAwarePolicy,
WhiteListRoundRobinPolicy,
)
from cassandra.query import dict_factory
# pylint: enable=import-error,no-name-in-module
@ -110,6 +140,25 @@ try:
except ImportError:
pass
__salt_loader__ = salt.loader.context.LoaderContext()
__context__ = __salt_loader__.named_context("__context__", {})
LOAD_BALANCING_POLICY_MAP = {
"HostDistance": HostDistance,
"LoadBalancingPolicy": LoadBalancingPolicy,
"RoundRobinPolicy": RoundRobinPolicy,
"DCAwareRoundRobinPolicy": DCAwareRoundRobinPolicy,
"WhiteListRoundRobinPolicy": WhiteListRoundRobinPolicy,
"TokenAwarePolicy": TokenAwarePolicy,
"HostFilterPolicy": HostFilterPolicy,
"SimpleConvictionPolicy": SimpleConvictionPolicy,
"ExponentialReconnectionPolicy": ExponentialReconnectionPolicy,
"RetryPolicy": RetryPolicy,
"IdentityTranslator": IdentityTranslator,
"NoSpeculativeExecutionPlan": NoSpeculativeExecutionPlan,
"NoSpeculativeExecutionPolicy": NoSpeculativeExecutionPolicy,
}
def __virtual__():
"""
@ -127,6 +176,16 @@ def _async_log_errors(errors):
log.error("Cassandra_cql asynchronous call returned: %s", errors)
def _get_lbp_policy(name, **policy_args):
"""
Returns the Load Balancer Policy class by name
"""
if name in LOAD_BALANCING_POLICY_MAP:
return LOAD_BALANCING_POLICY_MAP.get(name)(**policy_args)
else:
log.error("The policy %s is not available", name)
def _load_properties(property_name, config_option, set_default=False, default=None):
"""
Load properties for the cassandra module from config or pillar.
@ -147,7 +206,7 @@ def _load_properties(property_name, config_option, set_default=False, default=No
"No property specified in function, trying to load from salt configuration"
)
try:
options = __salt__["config.option"]("cassandra")
options = __salt__["config.option"]("cassandra", default={})
except BaseException as e:
log.error("Failed to get cassandra config options. Reason: %s", e)
raise
@ -175,7 +234,9 @@ def _get_ssl_opts():
Parse out ssl_options for Cassandra cluster connection.
Make sure that the ssl_version (if any specified) is valid.
"""
sslopts = __salt__["config.option"]("cassandra").get("ssl_options", None)
sslopts = __salt__["config.option"]("cassandra", default={}).get(
"ssl_options", None
)
ssl_opts = {}
if sslopts:
@ -198,7 +259,14 @@ def _get_ssl_opts():
def _connect(
contact_points=None, port=None, cql_user=None, cql_pass=None, protocol_version=None
contact_points=None,
port=None,
cql_user=None,
cql_pass=None,
protocol_version=None,
load_balancing_policy=None,
load_balancing_policy_args=None,
ssl_options=None,
):
"""
Connect to a Cassandra cluster.
@ -212,7 +280,13 @@ def _connect(
:param port: The Cassandra cluster port, defaults to None.
:type port: int
:param protocol_version: Cassandra protocol version to use.
:type port: int
:type protocol_version: int
:param load_balancing_policy: cassandra.policy class name to use
:type load_balancing_policy: str
:param load_balancing_policy_args: cassandra.policy constructor args
:type load_balancing_policy_args: dict
:param ssl_options: Cassandra protocol version to use.
:type ssl_options: dict
:return: The session and cluster objects.
:rtype: cluster object, session object
"""
@ -239,7 +313,7 @@ def _connect(
__context__["cassandra_cql_returner_session"],
)
else:
if contact_points is None:
contact_points = _load_properties(
property_name=contact_points, config_option="cluster"
)
@ -248,21 +322,25 @@ def _connect(
if isinstance(contact_points, list)
else contact_points.split(",")
)
if port is None:
port = _load_properties(
property_name=port, config_option="port", set_default=True, default=9042
)
if cql_user is None:
cql_user = _load_properties(
property_name=cql_user,
config_option="username",
set_default=True,
default="cassandra",
)
if cql_pass is None:
cql_pass = _load_properties(
property_name=cql_pass,
config_option="password",
set_default=True,
default="cassandra",
)
if protocol_version is None:
protocol_version = _load_properties(
property_name=protocol_version,
config_option="protocol_version",
@ -270,9 +348,35 @@ def _connect(
default=4,
)
if load_balancing_policy_args is None:
load_balancing_policy_args = _load_properties(
property_name=load_balancing_policy_args,
config_option="load_balancing_policy_args",
set_default=True,
default={},
)
if load_balancing_policy is None:
load_balancing_policy = _load_properties(
property_name=load_balancing_policy,
config_option="load_balancing_policy",
set_default=True,
default="RoundRobinPolicy",
)
if load_balancing_policy_args:
lbp_policy_cls = _get_lbp_policy(
load_balancing_policy, **load_balancing_policy_args
)
else:
lbp_policy_cls = _get_lbp_policy(load_balancing_policy)
try:
auth_provider = PlainTextAuthProvider(username=cql_user, password=cql_pass)
if ssl_options is None:
ssl_opts = _get_ssl_opts()
else:
ssl_opts = ssl_options
if ssl_opts:
cluster = Cluster(
contact_points,
@ -280,6 +384,7 @@ def _connect(
auth_provider=auth_provider,
ssl_options=ssl_opts,
protocol_version=protocol_version,
load_balancing_policy=lbp_policy_cls,
compression=True,
)
else:
@ -288,6 +393,7 @@ def _connect(
port=port,
auth_provider=auth_provider,
protocol_version=protocol_version,
load_balancing_policy=lbp_policy_cls,
compression=True,
)
for recontimes in range(1, 4):
@ -312,14 +418,22 @@ def _connect(
return cluster, session
except TypeError:
pass
except (ConnectionException, ConnectionShutdown, NoHostAvailable):
except (ConnectionException, ConnectionShutdown, NoHostAvailable) as err:
log.error("Could not connect to Cassandra cluster at %s", contact_points)
raise CommandExecutionError(
"ERROR: Could not connect to Cassandra cluster."
)
raise CommandExecutionError(str(err))
def cql_query(query, contact_points=None, port=None, cql_user=None, cql_pass=None):
def cql_query(
query,
contact_points=None,
port=None,
cql_user=None,
cql_pass=None,
protocol_version=None,
load_balancing_policy=None,
load_balancing_policy_args=None,
ssl_options=None,
):
"""
Run a query on a Cassandra cluster and return a dictionary.
@ -335,6 +449,14 @@ def cql_query(query, contact_points=None, port=None, cql_user=None, cql_pass=Non
:type port: int
:param params: The parameters for the query, optional.
:type params: str
:param protocol_version: Cassandra protocol version to use.
:type protocol_version: int
:param load_balancing_policy: cassandra.policy class name to use
:type load_balancing_policy: str
:param load_balancing_policy_args: cassandra.policy constructor args
:type load_balancing_policy_args: dict
:param ssl_options: Cassandra protocol version to use.
:type ssl_options: dict
:return: A dictionary from the return values of the query
:rtype: list[dict]
@ -350,6 +472,10 @@ def cql_query(query, contact_points=None, port=None, cql_user=None, cql_pass=Non
port=port,
cql_user=cql_user,
cql_pass=cql_pass,
protocol_version=protocol_version,
load_balancing_policy=load_balancing_policy,
load_balancing_policy_args=load_balancing_policy_args,
ssl_options=ssl_options,
)
except CommandExecutionError:
log.critical("Could not get Cassandra cluster session.")
@ -416,6 +542,10 @@ def cql_query_with_prepare(
port=None,
cql_user=None,
cql_pass=None,
protocol_version=None,
load_balancing_policy=None,
load_balancing_policy_args=None,
ssl_options=None,
**kwargs
):
"""
@ -448,6 +578,14 @@ def cql_query_with_prepare(
:type port: int
:param params: The parameters for the query, optional.
:type params: str
:param protocol_version: Cassandra protocol version to use.
:type port: int
:param load_balancing_policy: cassandra.policy class name to use
:type load_balancing_policy: str
:param load_balancing_policy_args: cassandra.policy constructor args
:type load_balancing_policy_args: dict
:param ssl_options: Cassandra protocol version to use.
:type ssl_options: dict
:return: A dictionary from the return values of the query
:rtype: list[dict]
@ -472,6 +610,10 @@ def cql_query_with_prepare(
port=port,
cql_user=cql_user,
cql_pass=cql_pass,
protocol_version=None,
load_balancing_policy=load_balancing_policy,
load_balancing_policy_args=load_balancing_policy_args,
ssl_options=None,
)
except CommandExecutionError:
log.critical("Could not get Cassandra cluster session.")
@ -525,7 +667,16 @@ def cql_query_with_prepare(
return ret
def version(contact_points=None, port=None, cql_user=None, cql_pass=None):
def version(
contact_points=None,
port=None,
cql_user=None,
cql_pass=None,
protocol_version=None,
load_balancing_policy=None,
load_balancing_policy_args=None,
ssl_options=None,
):
"""
Show the Cassandra version.
@ -537,6 +688,14 @@ def version(contact_points=None, port=None, cql_user=None, cql_pass=None):
:type cql_pass: str
:param port: The Cassandra cluster port, defaults to None.
:type port: int
:param protocol_version: Cassandra protocol version to use.
:type protocol_version: int
:param load_balancing_policy: cassandra.policy class name to use
:type load_balancing_policy: str
:param load_balancing_policy_args: cassandra.policy constructor args
:type load_balancing_policy_args: dict
:param ssl_options: Cassandra protocol version to use.
:type ssl_options: dict
:return: The version for this Cassandra cluster.
:rtype: str
@ -551,7 +710,17 @@ def version(contact_points=None, port=None, cql_user=None, cql_pass=None):
query = "select release_version from system.local limit 1;"
try:
ret = cql_query(query, contact_points, port, cql_user, cql_pass)
ret = cql_query(
query,
contact_points=contact_points,
port=port,
cql_user=cql_user,
cql_pass=cql_pass,
protocol_version=protocol_version,
load_balancing_policy=load_balancing_policy,
load_balancing_policy_args=load_balancing_policy_args,
ssl_options=ssl_options,
)
except CommandExecutionError:
log.critical("Could not get Cassandra version.")
raise
@ -562,7 +731,16 @@ def version(contact_points=None, port=None, cql_user=None, cql_pass=None):
return ret[0].get("release_version")
def info(contact_points=None, port=None, cql_user=None, cql_pass=None):
def info(
contact_points=None,
port=None,
cql_user=None,
cql_pass=None,
protocol_version=None,
load_balancing_policy=None,
load_balancing_policy_args=None,
ssl_options=None,
):
"""
Show the Cassandra information for this cluster.
@ -574,6 +752,14 @@ def info(contact_points=None, port=None, cql_user=None, cql_pass=None):
:type cql_pass: str
:param port: The Cassandra cluster port, defaults to None.
:type port: int
:param protocol_version: Cassandra protocol version to use.
:type protocol_version: int
:param load_balancing_policy: cassandra.policy class name to use
:type load_balancing_policy: str
:param load_balancing_policy_args: cassandra.policy constructor args
:type load_balancing_policy_args: dict
:param ssl_options: Cassandra protocol version to use.
:type ssl_options: dict
:return: The information for this Cassandra cluster.
:rtype: dict
@ -601,7 +787,17 @@ def info(contact_points=None, port=None, cql_user=None, cql_pass=None):
ret = {}
try:
ret = cql_query(query, contact_points, port, cql_user, cql_pass)
ret = cql_query(
query,
contact_points=contact_points,
port=port,
cql_user=cql_user,
cql_pass=cql_pass,
protocol_version=protocol_version,
load_balancing_policy=load_balancing_policy,
load_balancing_policy_args=load_balancing_policy_args,
ssl_options=ssl_options,
)
except CommandExecutionError:
log.critical("Could not list Cassandra info.")
raise
@ -612,7 +808,16 @@ def info(contact_points=None, port=None, cql_user=None, cql_pass=None):
return ret
def list_keyspaces(contact_points=None, port=None, cql_user=None, cql_pass=None):
def list_keyspaces(
contact_points=None,
port=None,
cql_user=None,
cql_pass=None,
protocol_version=None,
load_balancing_policy=None,
load_balancing_policy_args=None,
ssl_options=None,
):
"""
List keyspaces in a Cassandra cluster.
@ -624,6 +829,14 @@ def list_keyspaces(contact_points=None, port=None, cql_user=None, cql_pass=None)
:type cql_pass: str
:param port: The Cassandra cluster port, defaults to None.
:type port: int
:param protocol_version: Cassandra protocol version to use.
:type protocol_version: int
:param load_balancing_policy: cassandra.policy class name to use
:type load_balancing_policy: str
:param load_balancing_policy_args: cassandra.policy constructor args
:type load_balancing_policy_args: dict
:param ssl_options: Cassandra protocol version to use.
:type ssl_options: dict
:return: The keyspaces in this Cassandra cluster.
:rtype: list[dict]
@ -643,7 +856,17 @@ def list_keyspaces(contact_points=None, port=None, cql_user=None, cql_pass=None)
ret = {}
try:
ret = cql_query(query, contact_points, port, cql_user, cql_pass)
ret = cql_query(
query,
contact_points=contact_points,
port=port,
cql_user=cql_user,
cql_pass=cql_pass,
protocol_version=protocol_version,
load_balancing_policy=load_balancing_policy,
load_balancing_policy_args=load_balancing_policy_args,
ssl_options=ssl_options,
)
except CommandExecutionError:
log.critical("Could not list keyspaces.")
raise
@ -655,7 +878,15 @@ def list_keyspaces(contact_points=None, port=None, cql_user=None, cql_pass=None)
def list_column_families(
keyspace=None, contact_points=None, port=None, cql_user=None, cql_pass=None
keyspace=None,
contact_points=None,
port=None,
cql_user=None,
cql_pass=None,
protocol_version=None,
load_balancing_policy=None,
load_balancing_policy_args=None,
ssl_options=None,
):
"""
List column families in a Cassandra cluster for all keyspaces or just the provided one.
@ -670,6 +901,14 @@ def list_column_families(
:type cql_pass: str
:param port: The Cassandra cluster port, defaults to None.
:type port: int
:param protocol_version: Cassandra protocol version to use.
:type protocol_version: int
:param load_balancing_policy: cassandra.policy class name to use
:type load_balancing_policy: str
:param load_balancing_policy_args: cassandra.policy constructor args
:type load_balancing_policy_args: dict
:param ssl_options: Cassandra protocol version to use.
:type ssl_options: dict
:return: The column families in this Cassandra cluster.
:rtype: list[dict]
@ -695,7 +934,17 @@ def list_column_families(
ret = {}
try:
ret = cql_query(query, contact_points, port, cql_user, cql_pass)
ret = cql_query(
query,
contact_points=contact_points,
port=port,
cql_user=cql_user,
cql_pass=cql_pass,
protocol_version=protocol_version,
load_balancing_policy=load_balancing_policy,
load_balancing_policy_args=load_balancing_policy_args,
ssl_options=ssl_options,
)
except CommandExecutionError:
log.critical("Could not list column families.")
raise
@ -707,7 +956,15 @@ def list_column_families(
def keyspace_exists(
keyspace, contact_points=None, port=None, cql_user=None, cql_pass=None
keyspace,
contact_points=None,
port=None,
cql_user=None,
cql_pass=None,
protocol_version=None,
load_balancing_policy=None,
load_balancing_policy_args=None,
ssl_options=None,
):
"""
Check if a keyspace exists in a Cassandra cluster.
@ -722,6 +979,14 @@ def keyspace_exists(
:type cql_pass: str
:param port: The Cassandra cluster port, defaults to None.
:type port: int
:param protocol_version: Cassandra protocol version to use.
:type protocol_version: int
:param load_balancing_policy: cassandra.policy class name to use
:type load_balancing_policy: str
:param load_balancing_policy_args: cassandra.policy constructor args
:type load_balancing_policy_args: dict
:param ssl_options: Cassandra protocol version to use.
:type ssl_options: dict
:return: The info for the keyspace or False if it does not exist.
:rtype: dict
@ -743,7 +1008,17 @@ def keyspace_exists(
}
try:
ret = cql_query(query, contact_points, port, cql_user, cql_pass)
ret = cql_query(
query,
contact_points=contact_points,
port=port,
cql_user=cql_user,
cql_pass=cql_pass,
protocol_version=protocol_version,
load_balancing_policy=load_balancing_policy,
load_balancing_policy_args=load_balancing_policy_args,
ssl_options=ssl_options,
)
except CommandExecutionError:
log.critical("Could not determine if keyspace exists.")
raise
@ -763,6 +1038,10 @@ def create_keyspace(
port=None,
cql_user=None,
cql_pass=None,
protocol_version=None,
load_balancing_policy=None,
load_balancing_policy_args=None,
ssl_options=None,
):
"""
Create a new keyspace in Cassandra.
@ -784,6 +1063,14 @@ def create_keyspace(
:type cql_pass: str
:param port: The Cassandra cluster port, defaults to None.
:type port: int
:param protocol_version: Cassandra protocol version to use.
:type protocol_version: int
:param load_balancing_policy: cassandra.policy class name to use
:type load_balancing_policy: str
:param load_balancing_policy_args: cassandra.policy constructor args
:type load_balancing_policy_args: dict
:param ssl_options: Cassandra protocol version to use.
:type ssl_options: dict
:return: The info for the keyspace or False if it does not exist.
:rtype: dict
@ -797,7 +1084,17 @@ def create_keyspace(
salt 'minion1' cassandra_cql.create_keyspace keyspace=newkeyspace replication_strategy=NetworkTopologyStrategy \
replication_datacenters='{"datacenter_1": 3, "datacenter_2": 2}'
"""
existing_keyspace = keyspace_exists(keyspace, contact_points, port)
existing_keyspace = keyspace_exists(
keyspace,
contact_points=contact_points,
cql_user=cql_user,
cql_pass=cql_pass,
port=port,
protocol_version=protocol_version,
load_balancing_policy=load_balancing_policy,
load_balancing_policy_args=load_balancing_policy_args,
ssl_options=ssl_options,
)
if not existing_keyspace:
# Add the strategy, replication_factor, etc.
replication_map = {"class": replication_strategy}
@ -822,7 +1119,17 @@ def create_keyspace(
)
try:
cql_query(query, contact_points, port, cql_user, cql_pass)
cql_query(
query,
contact_points=contact_points,
port=port,
cql_user=cql_user,
cql_pass=cql_pass,
protocol_version=protocol_version,
load_balancing_policy=load_balancing_policy,
load_balancing_policy_args=load_balancing_policy_args,
ssl_options=ssl_options,
)
except CommandExecutionError:
log.critical("Could not create keyspace.")
raise
@ -832,7 +1139,15 @@ def create_keyspace(
def drop_keyspace(
keyspace, contact_points=None, port=None, cql_user=None, cql_pass=None
keyspace,
contact_points=None,
port=None,
cql_user=None,
cql_pass=None,
protocol_version=None,
load_balancing_policy=None,
load_balancing_policy_args=None,
ssl_options=None,
):
"""
Drop a keyspace if it exists in a Cassandra cluster.
@ -847,6 +1162,14 @@ def drop_keyspace(
:type cql_pass: str
:param port: The Cassandra cluster port, defaults to None.
:type port: int
:param protocol_version: Cassandra protocol version to use.
:type protocol_version: int
:param load_balancing_policy: cassandra.policy class name to use
:type load_balancing_policy: str
:param load_balancing_policy_args: cassandra.policy constructor args
:type load_balancing_policy_args: dict
:param ssl_options: Cassandra protocol version to use.
:type ssl_options: dict
:return: The info for the keyspace or False if it does not exist.
:rtype: dict
@ -858,11 +1181,31 @@ def drop_keyspace(
salt 'minion1' cassandra_cql.drop_keyspace keyspace=test contact_points=minion1
"""
existing_keyspace = keyspace_exists(keyspace, contact_points, port)
existing_keyspace = keyspace_exists(
keyspace,
contact_points=contact_points,
cql_user=cql_user,
cql_pass=cql_pass,
port=port,
protocol_version=protocol_version,
load_balancing_policy=load_balancing_policy,
load_balancing_policy_args=load_balancing_policy_args,
ssl_options=ssl_options,
)
if existing_keyspace:
query = """drop keyspace {};""".format(keyspace)
try:
cql_query(query, contact_points, port, cql_user, cql_pass)
cql_query(
query,
contact_points=contact_points,
port=port,
cql_user=cql_user,
cql_pass=cql_pass,
protocol_version=protocol_version,
load_balancing_policy=load_balancing_policy,
load_balancing_policy_args=load_balancing_policy_args,
ssl_options=ssl_options,
)
except CommandExecutionError:
log.critical("Could not drop keyspace.")
raise
@ -873,7 +1216,16 @@ def drop_keyspace(
return True
def list_users(contact_points=None, port=None, cql_user=None, cql_pass=None):
def list_users(
contact_points=None,
port=None,
cql_user=None,
cql_pass=None,
protocol_version=None,
load_balancing_policy=None,
load_balancing_policy_args=None,
ssl_options=None,
):
"""
List existing users in this Cassandra cluster.
@ -885,6 +1237,14 @@ def list_users(contact_points=None, port=None, cql_user=None, cql_pass=None):
:type cql_user: str
:param cql_pass: The Cassandra user password if authentication is turned on.
:type cql_pass: str
:param protocol_version: Cassandra protocol version to use.
:type protocol_version: int
:param load_balancing_policy: cassandra.policy class name to use
:type load_balancing_policy: str
:param load_balancing_policy_args: cassandra.policy constructor args
:type load_balancing_policy_args: dict
:param ssl_options: Cassandra protocol version to use.
:type ssl_options: dict
:return: The list of existing users.
:rtype: dict
@ -901,7 +1261,17 @@ def list_users(contact_points=None, port=None, cql_user=None, cql_pass=None):
ret = {}
try:
ret = cql_query(query, contact_points, port, cql_user, cql_pass)
ret = cql_query(
query,
contact_points=contact_points,
port=port,
cql_user=cql_user,
cql_pass=cql_pass,
protocol_version=protocol_version,
load_balancing_policy=load_balancing_policy,
load_balancing_policy_args=load_balancing_policy_args,
ssl_options=ssl_options,
)
except CommandExecutionError:
log.critical("Could not list users.")
raise
@ -920,6 +1290,10 @@ def create_user(
port=None,
cql_user=None,
cql_pass=None,
protocol_version=None,
load_balancing_policy=None,
load_balancing_policy_args=None,
ssl_options=None,
):
"""
Create a new cassandra user with credentials and superuser status.
@ -938,6 +1312,14 @@ def create_user(
:type cql_pass: str
:param port: The Cassandra cluster port, defaults to None.
:type port: int
:param protocol_version: Cassandra protocol version to use.
:type protocol_version: int
:param load_balancing_policy: cassandra.policy class name to use
:type load_balancing_policy: str
:param load_balancing_policy_args: cassandra.policy constructor args
:type load_balancing_policy_args: dict
:param ssl_options: Cassandra protocol version to use.
:type ssl_options: dict
:return:
:rtype:
@ -964,7 +1346,17 @@ def create_user(
# The create user query doesn't actually return anything if the query succeeds.
# If the query fails, catch the exception, log a messange and raise it again.
try:
cql_query(query, contact_points, port, cql_user, cql_pass)
cql_query(
query,
contact_points=contact_points,
port=port,
cql_user=cql_user,
cql_pass=cql_pass,
protocol_version=protocol_version,
load_balancing_policy=load_balancing_policy,
load_balancing_policy_args=load_balancing_policy_args,
ssl_options=ssl_options,
)
except CommandExecutionError:
log.critical("Could not create user.")
raise
@ -984,6 +1376,10 @@ def list_permissions(
port=None,
cql_user=None,
cql_pass=None,
protocol_version=None,
load_balancing_policy=None,
load_balancing_policy_args=None,
ssl_options=None,
):
"""
List permissions.
@ -1004,6 +1400,14 @@ def list_permissions(
:type cql_pass: str
:param port: The Cassandra cluster port, defaults to None.
:type port: int
:param protocol_version: Cassandra protocol version to use.
:type protocol_version: int
:param load_balancing_policy: cassandra.policy class name to use
:type load_balancing_policy: str
:param load_balancing_policy_args: cassandra.policy constructor args
:type load_balancing_policy_args: dict
:param ssl_options: Cassandra protocol version to use.
:type ssl_options: dict
:return: Dictionary of permissions.
:rtype: dict
@ -1034,7 +1438,17 @@ def list_permissions(
ret = {}
try:
ret = cql_query(query, contact_points, port, cql_user, cql_pass)
ret = cql_query(
query,
contact_points=contact_points,
port=port,
cql_user=cql_user,
cql_pass=cql_pass,
protocol_version=protocol_version,
load_balancing_policy=load_balancing_policy,
load_balancing_policy_args=load_balancing_policy_args,
ssl_options=ssl_options,
)
except CommandExecutionError:
log.critical("Could not list permissions.")
raise
@ -1054,6 +1468,10 @@ def grant_permission(
port=None,
cql_user=None,
cql_pass=None,
protocol_version=None,
load_balancing_policy=None,
load_balancing_policy_args=None,
ssl_options=None,
):
"""
Grant permissions to a user.
@ -1074,6 +1492,14 @@ def grant_permission(
:type cql_pass: str
:param port: The Cassandra cluster port, defaults to None.
:type port: int
:param protocol_version: Cassandra protocol version to use.
:type protocol_version: int
:param load_balancing_policy: cassandra.policy class name to use
:type load_balancing_policy: str
:param load_balancing_policy_args: cassandra.policy constructor args
:type load_balancing_policy_args: dict
:param ssl_options: Cassandra protocol version to use.
:type ssl_options: dict
:return:
:rtype:
@ -1098,7 +1524,17 @@ def grant_permission(
log.debug("Attempting to grant permissions with query '%s'", query)
try:
cql_query(query, contact_points, port, cql_user, cql_pass)
cql_query(
query,
contact_points=contact_points,
port=port,
cql_user=cql_user,
cql_pass=cql_pass,
protocol_version=protocol_version,
load_balancing_policy=load_balancing_policy,
load_balancing_policy_args=load_balancing_policy_args,
ssl_options=ssl_options,
)
except CommandExecutionError:
log.critical("Could not grant permissions.")
raise

View file

@ -0,0 +1,215 @@
"""
Test case for the cassandra_cql module
"""
import logging
import pytest
import salt.modules.cassandra_cql as cql
from salt.exceptions import CommandExecutionError
from tests.support.mock import MagicMock, patch
log = logging.getLogger(__name__)
def test_cql_query(caplog):
"""
Test salt.modules.cassandra_cql.cql_query function
"""
mock_session = MagicMock()
mock_client = MagicMock()
mock = MagicMock(return_value=(mock_session, mock_client))
query = "query"
with patch.object(cql, "_connect", mock):
query_result = cql.cql_query(query)
assert query_result == []
query = {"5.0.1": "query1", "5.0.0": "query2"}
mock_version = MagicMock(return_value="5.0.1")
mock_session = MagicMock()
mock_client = MagicMock()
mock = MagicMock(return_value=(mock_session, mock_client))
with patch.object(cql, "version", mock_version):
with patch.object(cql, "_connect", mock):
query_result = cql.cql_query(query)
assert query_result == []
def test_cql_query_with_prepare(caplog):
"""
Test salt.modules.cassandra_cql.cql_query_with_prepare function
"""
mock_session = MagicMock()
mock_client = MagicMock()
mock = MagicMock(return_value=(mock_session, mock_client))
query = "query"
statement_args = {"arg1": "test"}
mock_context = MagicMock(
return_value={"cassandra_cql_prepared": {"statement_name": query}}
)
with patch.object(cql, "__context__", mock_context):
with patch.object(cql, "_connect", mock):
query_result = cql.cql_query_with_prepare(
query, "statement_name", statement_args
)
assert query_result == []
def test_version(caplog):
"""
Test salt.modules.cassandra_cql.version function
"""
mock_cql_query = MagicMock(return_value=[{"release_version": "5.0.1"}])
with patch.object(cql, "cql_query", mock_cql_query):
version = cql.version()
assert version == "5.0.1"
mock_cql_query = MagicMock(side_effect=CommandExecutionError)
with pytest.raises(CommandExecutionError) as err:
with patch.object(cql, "cql_query", mock_cql_query):
version = cql.version()
assert "{}".format(err.value) == ""
assert "Could not get Cassandra version." in caplog.text
for record in caplog.records:
assert record.levelname == "CRITICAL"
def test_info():
"""
Test salt.modules.cassandra_cql.info function
"""
expected = {"result": "info"}
mock_cql_query = MagicMock(return_value=expected)
with patch.object(cql, "cql_query", mock_cql_query):
info = cql.info()
assert info == expected
def test_list_keyspaces():
"""
Test salt.modules.cassandra_cql.list_keyspaces function
"""
expected = [{"keyspace_name": "name1"}, {"keyspace_name": "name2"}]
mock_cql_query = MagicMock(return_value=expected)
with patch.object(cql, "cql_query", mock_cql_query):
keyspaces = cql.list_keyspaces()
assert keyspaces == expected
def test_list_column_families():
"""
Test salt.modules.cassandra_cql.list_column_families function
"""
expected = [{"colum_name": "column1"}, {"column_name": "column2"}]
mock_cql_query = MagicMock(return_value=expected)
with patch.object(cql, "cql_query", mock_cql_query):
columns = cql.list_column_families()
assert columns == expected
def test_keyspace_exists():
"""
Test salt.modules.cassandra_cql.keyspace_exists function
"""
expected = "keyspace"
mock_cql_query = MagicMock(return_value=expected)
with patch.object(cql, "cql_query", mock_cql_query):
exists = cql.keyspace_exists("keyspace")
assert exists == bool(expected)
expected = []
mock_cql_query = MagicMock(return_value=expected)
with patch.object(cql, "cql_query", mock_cql_query):
exists = cql.keyspace_exists("keyspace")
assert exists == bool(expected)
def test_create_keyspace():
"""
Test salt.modules.cassandra_cql.create_keyspace function
"""
expected = None
mock_cql_query = MagicMock(return_value=expected)
with patch.object(cql, "cql_query", mock_cql_query):
result = cql.create_keyspace("keyspace")
assert result == expected
def test_drop_keyspace():
"""
Test salt.modules.cassandra_cql.drop_keyspace function
"""
expected = True
mock_cql_query = MagicMock(return_value=expected)
with patch.object(cql, "cql_query", mock_cql_query):
result = cql.drop_keyspace("keyspace")
assert result == expected
def test_list_users():
"""
Test salt.modules.cassandra_cql.list_users function
"""
expected = [{"name": "user1", "super": True}, {"name": "user2", "super": False}]
mock_cql_query = MagicMock(return_value=expected)
with patch.object(cql, "cql_query", mock_cql_query):
result = cql.list_users()
assert result == expected
def test_create_user():
"""
Test salt.modules.cassandra_cql.create_user function
"""
expected = True
mock_cql_query = MagicMock(return_value=expected)
with patch.object(cql, "cql_query", mock_cql_query):
result = cql.create_user("user", "password")
assert result == expected
def test_list_permissions():
"""
Test salt.modules.cassandra_cql.list_permissions function
"""
expected = [
{
"permission": "ALTER",
"resource": "<keyspace one>",
"role": "user1",
"username": "user1",
}
]
mock_cql_query = MagicMock(return_value=expected)
with patch.object(cql, "cql_query", mock_cql_query):
result = cql.list_permissions(username="user1", resource="one")
assert result == expected
def test_grant_permission():
"""
Test salt.modules.cassandra_cql.grant_permission function
"""
expected = True
mock_cql_query = MagicMock(return_value=expected)
with patch.object(cql, "cql_query", mock_cql_query):
result = cql.grant_permission(username="user1", resource="one")
assert result == expected