1# Copyright DataStax, Inc.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15from collections import defaultdict
16import logging
17import six
18import threading
19
20from cassandra.cluster import Cluster, _NOT_SET, NoHostAvailable, UserTypeDoesNotExist, ConsistencyLevel
21from cassandra.query import SimpleStatement, dict_factory
22
23from cassandra.cqlengine import CQLEngineException
24from cassandra.cqlengine.statements import BaseCQLStatement
25
26
log = logging.getLogger(__name__)

# Required for passing timeout to Session.execute: execute() below uses this
# sentinel as the default ``timeout`` it forwards unchanged.
NOT_SET = _NOT_SET

# Module-level mirrors of the default connection's cluster and session; the
# connection-management functions in this module keep them in sync.
cluster = session = None

# Connection registry: name -> Connection. The DEFAULT_CONNECTION sentinel key
# is an alias for whichever connection is currently the default.
DEFAULT_CONNECTION = object()
_connections = dict()

# UDT classes may be declared before any connection exists, and sessions may be
# replaced later, so every registered type is remembered here
# (keyspace -> type name -> class) and re-registered whenever a session is set up.
udt_by_keyspace = defaultdict(dict)
42
43
def format_log_context(msg, connection=None, keyspace=None):
    """
    Prefix ``msg`` with the connection (and optionally keyspace) it concerns.

    :param msg: the log message text.
    :param connection: connection name; any falsy value is shown as
        ``DEFAULT_CONNECTION``.
    :param keyspace: keyspace name; omitted from the prefix when falsy.
    :returns: ``"[Connection: <c>, Keyspace: <ks>] <msg>"`` or
        ``"[Connection: <c>] <msg>"``.
    """
    context_fields = ['Connection: {0}'.format(connection or 'DEFAULT_CONNECTION')]
    if keyspace:
        context_fields.append('Keyspace: {0}'.format(keyspace))
    return '[{0}] {1}'.format(', '.join(context_fields), msg)
53
54
class UndefinedKeyspaceException(CQLEngineException):
    """CQLEngineException subclass for an undefined keyspace (not raised within this module itself)."""
57
58
class Connection(object):
    """CQLEngine Connection

    One registry entry: either a Cluster/Session created here from ``hosts``
    and ``cluster_options`` (see :meth:`setup`), or an existing session wrapped
    via :meth:`from_session`.
    """

    name = None
    hosts = None

    consistency = None
    retry_connect = False
    lazy_connect = False
    lazy_connect_lock = None
    cluster_options = None

    cluster = None
    session = None

    def __init__(self, name, hosts, consistency=None,
                 lazy_connect=False, retry_connect=False, cluster_options=None):
        self.hosts = hosts
        self.name = name
        self.consistency = consistency
        self.lazy_connect = lazy_connect
        self.retry_connect = retry_connect
        self.cluster_options = cluster_options if cluster_options else {}
        # Serializes the deferred connect performed by handle_lazy_connect().
        self.lazy_connect_lock = threading.RLock()

    @classmethod
    def from_session(cls, name, session):
        """Build a Connection around an existing ``session`` and its cluster, and prepare it for cqlengine."""
        instance = cls(name=name, hosts=session.hosts)
        instance.cluster, instance.session = session.cluster, session
        instance.setup_session()
        return instance

    def setup(self):
        """Setup the connection

        Validates ``cluster_options`` and, unless ``lazy_connect`` is set,
        creates the Cluster and connects a Session immediately.

        :raises CQLEngineException: if ``cluster_options`` contains
            ``username`` or ``password``.
        :raises NoHostAvailable: if connecting fails; when ``retry_connect`` is
            set, ``lazy_connect`` is re-armed first so the next use retries.
        """
        self._check_cluster_options()

        if self.lazy_connect:
            return

        self._connect()

    def _check_cluster_options(self):
        """Reject legacy credential options; credentials must go through the driver's auth_provider."""
        if 'username' in self.cluster_options or 'password' in self.cluster_options:
            raise CQLEngineException("Username & Password are now handled by using the native driver's auth_provider")

    def _connect(self):
        """Create the Cluster, connect a Session and prepare it for cqlengine (ignores ``lazy_connect``)."""
        global cluster, session

        self.cluster = Cluster(self.hosts, **self.cluster_options)
        try:
            self.session = self.cluster.connect()
            log.debug(format_log_context("connection initialized with internally created session", connection=self.name))
        except NoHostAvailable:
            if self.retry_connect:
                log.warning(format_log_context("connect failed, setting up for re-attempt on first use", connection=self.name))
                self.lazy_connect = True
            raise

        if self.consistency is not None:
            self.session.default_consistency_level = self.consistency

        # Keep the module-level mirrors current when this is the default connection.
        if DEFAULT_CONNECTION in _connections and _connections[DEFAULT_CONNECTION] == self:
            cluster = _connections[DEFAULT_CONNECTION].cluster
            session = _connections[DEFAULT_CONNECTION].session

        self.setup_session()

    def setup_session(self):
        """Prepare ``self.session`` for cqlengine: dict rows, tuple encoding, and all known UDTs."""
        self.session.row_factory = dict_factory
        enc = self.session.encoder
        # Encode Python tuples as CQL tuples.
        enc.mapping[tuple] = enc.cql_encode_tuple
        _register_known_types(self.session.cluster)

    def handle_lazy_connect(self):
        """
        Perform the deferred connect on first use; concurrent callers wait for it.

        ``lazy_connect`` is cleared only after the connect attempt has finished,
        so a caller that takes the lock-free fast path never sees a
        half-initialised session. Final flag state matches the previous
        behaviour: cleared on success or on a non-retryable failure, left set
        when NoHostAvailable occurs with ``retry_connect`` enabled.
        """
        # Fast path: nothing pending (already connected, or a previous
        # non-retryable attempt has failed).
        if not self.lazy_connect:
            return

        with self.lazy_connect_lock:
            # Another thread may have completed the connect while we waited.
            if not self.lazy_connect:
                return

            log.debug(format_log_context("Lazy connect enabled", connection=self.name))
            retry_later = False
            try:
                self._check_cluster_options()
                self._connect()
            except NoHostAvailable:
                retry_later = self.retry_connect
                raise
            finally:
                self.lazy_connect = retry_later
140
141
def register_connection(name, hosts=None, consistency=None, lazy_connect=False,
                        retry_connect=False, cluster_options=None, default=False,
                        session=None):
    """
    Add a connection to the connection registry. ``hosts`` and ``session`` are
    mutually exclusive, and ``consistency``, ``lazy_connect``,
    ``retry_connect``, and ``cluster_options`` only work with ``hosts``. Using
    ``hosts`` will create a new :class:`cassandra.cluster.Cluster` and
    :class:`cassandra.cluster.Session`.

    :param list hosts: list of hosts, (``contact_points`` for :class:`cassandra.cluster.Cluster`).
    :param int consistency: The default :class:`~.ConsistencyLevel` for the
        registered connection's new session. Defaults to
        ``ConsistencyLevel.LOCAL_ONE`` when not given. For use with ``hosts``
        only; will fail when used with ``session``.
    :param bool lazy_connect: True if should not connect until first use. For
        use with ``hosts`` only; will fail when used with ``session``.
    :param bool retry_connect: True if we should retry to connect even if there
        was a connection failure initially. The initial failure is still
        raised, but the connection stays registered so that its first use
        retries. For use with ``hosts`` only; will fail when used with ``session``.
    :param dict cluster_options: A dict of options to be used as keyword
        arguments to :class:`cassandra.cluster.Cluster`. For use with ``hosts``
        only; will fail when used with ``session``.
    :param bool default: If True, set the new connection as the cqlengine
        default
    :param Session session: A :class:`cassandra.cluster.Session` to be used in
        the created connection.
    :returns: the registered :class:`Connection`.
    :raises CQLEngineException: if ``session`` is combined with any
        ``hosts``-only argument.
    """

    if name in _connections:
        log.warning("Registering connection '{0}' when it already exists.".format(name))

    if session is not None:
        invalid_config_args = (hosts is not None or
                               consistency is not None or
                               lazy_connect is not False or
                               retry_connect is not False or
                               cluster_options is not None)
        if invalid_config_args:
            raise CQLEngineException(
                "Session configuration arguments and 'session' argument are mutually exclusive"
            )
        # from_session() already runs setup_session(); no second call needed.
        conn = Connection.from_session(name, session=session)
    else:  # use hosts argument
        if consistency is None:
            consistency = ConsistencyLevel.LOCAL_ONE
        conn = Connection(
            name, hosts=hosts,
            consistency=consistency, lazy_connect=lazy_connect,
            retry_connect=retry_connect, cluster_options=cluster_options
        )
        try:
            conn.setup()
        except NoHostAvailable:
            # With retry_connect, Connection.setup() re-arms lazy_connect
            # before re-raising. Register the connection anyway so that its
            # first use can retry instead of finding nothing in the registry.
            if conn.lazy_connect:
                _connections[name] = conn
                if default:
                    set_default_connection(name)
            raise

    _connections[name] = conn

    if default:
        set_default_connection(name)

    return conn
202
203
def unregister_connection(name):
    """
    Remove connection ``name`` from the registry and shut down its cluster.

    Unknown names are ignored. If ``name`` is the current default, the
    DEFAULT_CONNECTION alias is dropped and the module-level ``cluster`` and
    ``session`` mirrors are reset to None.
    """
    global cluster, session

    if name not in _connections:
        return

    is_current_default = (DEFAULT_CONNECTION in _connections and
                          _connections[name] == _connections[DEFAULT_CONNECTION])
    if is_current_default:
        del _connections[DEFAULT_CONNECTION]
        cluster = session = None

    # NOTE(review): if ``name`` is DEFAULT_CONNECTION itself, the alias was just
    # deleted above, so this lookup raises KeyError.
    removed = _connections[name]
    if removed.cluster:
        removed.cluster.shutdown()
    del _connections[name]
    log.debug("Connection '{0}' has been removed from the registry.".format(name))
220
221
def set_default_connection(name):
    """
    Make registered connection ``name`` the cqlengine default.

    Points the DEFAULT_CONNECTION alias at it and refreshes the module-level
    ``cluster`` and ``session`` mirrors from it.

    :raises CQLEngineException: if ``name`` is not registered.
    """
    global cluster, session

    if name not in _connections:
        raise CQLEngineException("Connection '{0}' doesn't exist.".format(name))

    log.debug("Connection '{0}' has been set as default.".format(name))
    chosen = _connections[name]
    _connections[DEFAULT_CONNECTION] = chosen
    cluster, session = chosen.cluster, chosen.session
232
233
def get_connection(name=None):
    """
    Look up a registered connection and run any deferred (lazy) connect.

    :param name: registry name; any falsy value selects the default connection.
    :raises CQLEngineException: if nothing is registered under that name.
    :returns: the :class:`Connection`, after ``handle_lazy_connect()``.
    """
    key = name or DEFAULT_CONNECTION
    if key not in _connections:
        raise CQLEngineException("Connection name '{0}' doesn't exist in the registry.".format(key))

    found = _connections[key]
    found.handle_lazy_connect()
    return found
246
247
def default():
    """
    Configures the default connection with ``hosts=None`` (the driver's default
    contact points, i.e. localhost), using the driver defaults except for
    row_factory, which cqlengine switches to ``dict_factory``.

    Warns if a default connection with a session already existed; it is
    replaced by a new ``'default'`` registration.
    """

    try:
        conn = get_connection()
        if conn.session:
            log.warning("configuring new default connection for cqlengine when one was already set")
    except Exception:
        # Best effort: having no (or a failing) existing default connection is
        # expected here. Unlike a bare ``except:``, this lets
        # KeyboardInterrupt/SystemExit propagate.
        pass

    register_connection('default', hosts=None, default=True)

    log.debug("cqlengine connection initialized with default session to localhost")
264
265
def set_session(s):
    """
    Configures the default connection with a preexisting :class:`cassandra.cluster.Session`

    Note: the mapper presently requires a Session :attr:`~.row_factory` set to ``dict_factory``.
    This may be relaxed in the future

    If no default connection is registered, one is created from ``s``. Otherwise
    the existing default connection's session and cluster are replaced by ``s``.
    In both cases the module-level ``cluster``/``session`` mirrors end up
    pointing at ``s``. If ``s`` has a keyspace, it becomes
    ``models.DEFAULT_KEYSPACE``.

    :param Session s: the session to use for the default connection.
    :raises CQLEngineException: if ``s.row_factory`` is not ``dict_factory`` when
        the check runs (a freshly registered connection has already switched it).
    """
    global cluster, session

    try:
        conn = get_connection()
    except CQLEngineException:
        # no default connection set; initialize one around the given session
        register_connection('default', session=s, default=True)
        conn = get_connection()
    else:
        # Only warn when replacing a pre-existing default connection, not the
        # one just created in the branch above.
        if conn.session:
            log.warning("configuring new default connection for cqlengine when one was already set")

    if s.row_factory is not dict_factory:
        raise CQLEngineException("Failed to initialize: 'Session.row_factory' must be 'dict_factory'.")
    conn.session = s
    conn.cluster = s.cluster

    # ``conn`` is the default connection, so keep the module-level mirrors in
    # sync, as set_default_connection() does.
    cluster, session = conn.cluster, conn.session

    # Set default keyspace from given session's keyspace
    if conn.session.keyspace:
        from cassandra.cqlengine import models
        models.DEFAULT_KEYSPACE = conn.session.keyspace

    conn.setup_session()

    log.debug("cqlengine default connection initialized with %s", s)
297
298
def setup(
        hosts,
        default_keyspace,
        consistency=None,
        lazy_connect=False,
        retry_connect=False,
        **kwargs):
    r"""
    Set up the driver connection used by the mapper

    Sets ``models.DEFAULT_KEYSPACE`` and registers the resulting connection as
    the ``'default'`` connection.

    :param list hosts: list of hosts, (``contact_points`` for :class:`cassandra.cluster.Cluster`)
    :param str default_keyspace: The default keyspace to use
    :param int consistency: The global default :class:`~.ConsistencyLevel`; ``ConsistencyLevel.LOCAL_ONE`` when not given
    :param bool lazy_connect: True if should not connect until first use
    :param bool retry_connect: True if we should retry to connect even if there was a connection failure initially
    :param \*\*kwargs: Pass-through keyword arguments for :class:`cassandra.cluster.Cluster`
    """

    from cassandra.cqlengine import models
    models.DEFAULT_KEYSPACE = default_keyspace

    register_connection('default', hosts=hosts, consistency=consistency, lazy_connect=lazy_connect,
                        retry_connect=retry_connect, cluster_options=kwargs, default=True)
322
323
def execute(query, params=None, consistency_level=None, timeout=NOT_SET, connection=None):
    """
    Execute ``query`` on the session of a registered connection.

    :param query: a :class:`~cassandra.query.SimpleStatement` (used as-is), a
        cqlengine ``BaseCQLStatement`` (rendered into a ``SimpleStatement`` with
        its own bind context and ``fetch_size``), a CQL string (wrapped in a
        ``SimpleStatement``), or any other object, which is passed through.
    :param params: bind parameters; replaced by ``query.get_context()`` for a
        ``BaseCQLStatement``.
    :param consistency_level: applied only when a new ``SimpleStatement`` is
        built here (``BaseCQLStatement`` or string input).
    :param timeout: forwarded unchanged to ``Session.execute``.
    :param connection: registry name; ``None`` selects the default connection.
    :raises CQLEngineException: if the connection is unknown or has no session.
    :returns: the result of ``Session.execute``.
    """
    conn = get_connection(connection)

    if not conn.session:
        raise CQLEngineException("It is required to setup() cqlengine before executing queries")

    if isinstance(query, SimpleStatement):
        pass  # already a driver statement; consistency_level is not applied to it
    elif isinstance(query, BaseCQLStatement):
        params = query.get_context()
        query = SimpleStatement(str(query), consistency_level=consistency_level, fetch_size=query.fetch_size)
    elif isinstance(query, six.string_types):
        query = SimpleStatement(query, consistency_level=consistency_level)

    if log.isEnabledFor(logging.DEBUG):
        # Only SimpleStatement is known (from the branches above) to carry
        # ``query_string``. Any other pass-through object falls back to str()
        # instead of raising AttributeError before the query is even sent.
        query_string = getattr(query, 'query_string', None)
        if query_string is None:
            query_string = str(query)
        log.debug(format_log_context(query_string, connection=connection))

    return conn.session.execute(query, params, timeout=timeout)
343
344
def get_session(connection=None):
    """Return the session of connection ``connection`` (default when None); may be None if not set up."""
    return get_connection(connection).session
348
349
def get_cluster(connection=None):
    """
    Return the cluster of connection ``connection`` (default when None).

    :raises CQLEngineException: if the connection is unknown or has no cluster yet.
    """
    configured = get_connection(connection).cluster
    if configured:
        return configured
    raise CQLEngineException("%s.cluster is not configured. Call one of the setup or default functions first." % __name__)
355
356
def register_udt(keyspace, type_name, klass, connection=None):
    """
    Record UDT class ``klass`` and register it on the connection's cluster if one exists.

    The class is always stored in ``udt_by_keyspace``, so it is also registered
    by ``_register_known_types`` whenever a session is set up later.
    """
    udt_by_keyspace[keyspace][type_name] = klass

    try:
        target_cluster = get_cluster(connection)
    except CQLEngineException:
        return  # no cluster yet; registration happens when a session is set up

    if not target_cluster:
        return

    try:
        target_cluster.register_user_type(keyspace, type_name, klass)
    except UserTypeDoesNotExist:
        pass  # new types are covered in management sync functions
370
371
def _register_known_types(cluster):
    """
    Register every UDT class recorded in ``udt_by_keyspace`` on ``cluster``.

    An empty/None keyspace key resolves to ``models.DEFAULT_KEYSPACE``. Types
    that do not exist yet are skipped.
    """
    from cassandra.cqlengine import models

    recorded = ((ks_key, type_name, klass)
                for ks_key, types_by_name in udt_by_keyspace.items()
                for type_name, klass in types_by_name.items())
    for ks_key, type_name, klass in recorded:
        target_keyspace = ks_key or models.DEFAULT_KEYSPACE
        try:
            cluster.register_user_type(target_keyspace, type_name, klass)
        except UserTypeDoesNotExist:
            pass  # new types are covered in management sync functions
380