1# Copyright DataStax, Inc.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15from __future__ import absolute_import  # to enable import io from stdlib
16from collections import namedtuple
17import logging
18import socket
19from uuid import UUID
20
21import six
22from six.moves import range
23import io
24
25from cassandra import ProtocolVersion
26from cassandra import type_codes, DriverException
27from cassandra import (Unavailable, WriteTimeout, ReadTimeout,
28                       WriteFailure, ReadFailure, FunctionFailure,
29                       AlreadyExists, InvalidRequest, Unauthorized,
30                       UnsupportedOperation, UserFunctionDescriptor,
31                       UserAggregateDescriptor, SchemaTargetType)
32from cassandra.marshal import (int32_pack, int32_unpack, uint16_pack, uint16_unpack,
33                               int8_pack, int8_unpack, uint64_pack, header_pack,
34                               v3_header_pack, uint32_pack)
35from cassandra.cqltypes import (AsciiType, BytesType, BooleanType,
36                                CounterColumnType, DateType, DecimalType,
37                                DoubleType, FloatType, Int32Type,
38                                InetAddressType, IntegerType, ListType,
39                                LongType, MapType, SetType, TimeUUIDType,
40                                UTF8Type, VarcharType, UUIDType, UserType,
41                                TupleType, lookup_casstype, SimpleDateType,
42                                TimeType, ByteType, ShortType, DurationType)
43from cassandra import WriteType
44from cassandra.cython_deps import HAVE_CYTHON, HAVE_NUMPY
45from cassandra import util
46
47log = logging.getLogger(__name__)
48
49
class NotSupportedError(Exception):
    """Raised when the driver meets something it cannot handle.

    Within this module it is raised for an unknown data type code found in
    result metadata.
    """
52
53
class InternalError(Exception):
    """Exception type for internal driver errors.

    Its raise sites are not part of this section of the module.
    """
56
# Record describing one column: owning keyspace, owning table, column name
# and column type. Built for bind markers in
# ResultMessage.recv_prepared_metadata().
ColumnMetadata = namedtuple("ColumnMetadata", ['keyspace_name', 'table_name', 'name', 'type'])

# Direction bit of the frame header.
# NOTE(review): presumably set on frames travelling from server to client;
# the code that tests it is not in this section -- confirm there.
HEADER_DIRECTION_TO_CLIENT = 0x80
HEADER_DIRECTION_MASK = 0x80

# Single-bit flags. NOTE(review): by their names these belong to the frame
# header flags byte; the code that sets/reads them is not in this section.
COMPRESSED_FLAG = 0x01
TRACING_FLAG = 0x02
CUSTOM_PAYLOAD_FLAG = 0x04
WARNING_FLAG = 0x08
USE_BETA_FLAG = 0x10
# Bitwise complement of USE_BETA_FLAG, i.e. a mask that clears that bit.
USE_BETA_MASK = ~USE_BETA_FLAG

# opcode -> message class; populated through register_class().
_message_types_by_opcode = {}

# Unique sentinel object (distinguishable from None by identity).
_UNSET_VALUE = object()
72
73
def register_class(cls):
    """Record ``cls`` in the opcode registry under its ``opcode`` attribute.

    A later registration for the same opcode replaces the earlier entry.
    """
    opcode = cls.opcode
    _message_types_by_opcode[opcode] = cls
76
77
def get_registered_classes():
    """Return a shallow copy of the opcode -> message class registry.

    Mutating the returned dict does not affect the registry itself.
    """
    snapshot = dict(_message_types_by_opcode)
    return snapshot
80
81
class _RegisterMessageType(type):
    """Metaclass that registers every newly created class by opcode.

    Classes whose name starts with an underscore are treated as private
    bases and are not registered.
    """

    def __init__(cls, name, bases, dct):
        is_private = name.startswith('_')
        if is_private:
            return
        register_class(cls)
86
87
@six.add_metaclass(_RegisterMessageType)
class _MessageType(object):
    """Common base of all protocol message classes.

    Provides class-level defaults for ``tracing``, ``custom_payload`` and
    ``warnings`` plus a generic ``repr``.
    """

    tracing = False
    custom_payload = None
    warnings = None

    def update_custom_payload(self, other):
        """Merge the mapping ``other`` into this message's custom payload.

        A falsy ``other`` is a no-op. Raises ``ValueError`` when the merged
        payload holds more than 65535 entries; the merge has already been
        applied at that point.
        """
        if not other:
            return
        if not self.custom_payload:
            # Create a per-instance dict rather than touching the class default.
            self.custom_payload = {}
        payload = self.custom_payload
        payload.update(other)
        if len(payload) > 65535:
            raise ValueError("Custom payload map exceeds max count allowed by protocol (65535)")

    def __repr__(self):
        rendered = ', '.join('%s=%r' % pair for pair in _get_params(self))
        return '<%s(%s)>' % (self.__class__.__name__, rendered)
105
106
def _get_params(message_obj):
    """Yield ``(name, value)`` pairs for the instance attributes worth showing.

    Skipped are attributes whose name is also an attribute of
    ``_MessageType``, names starting with an underscore, and callables.
    The result is a lazy generator over ``message_obj.__dict__``.
    """
    inherited = dir(_MessageType)
    return (
        (attr_name, attr_value)
        for attr_name, attr_value in message_obj.__dict__.items()
        if not (attr_name in inherited
                or attr_name.startswith('_')
                or callable(attr_value))
    )
113
114
# error code -> ErrorMessage subclass; filled in by the ErrorMessageSubclass
# metaclass and consulted by ErrorMessage.recv_body().
error_classes = {}
116
117
class ErrorMessage(_MessageType, Exception):
    """ERROR response (opcode 0x00); also usable as an exception.

    Attributes set by the constructor: ``code`` (integer error code),
    ``message`` (text sent by the server) and ``info`` (whatever the
    subclass' ``recv_error_info`` returned, ``None`` by default).
    """
    opcode = 0x00
    name = 'ERROR'
    summary = 'Unknown'

    def __init__(self, code, message, info):
        self.code, self.message, self.info = code, message, info

    @classmethod
    def recv_body(cls, f, protocol_version, *args):
        """Decode an error body from ``f`` into the most specific class.

        The class registered in ``error_classes`` for the received code is
        used; unknown codes fall back to ``cls``.
        """
        error_code = read_int(f)
        text = read_string(f)
        handler = error_classes.get(error_code, cls)
        details = handler.recv_error_info(f, protocol_version)
        return handler(code=error_code, message=text, info=details)

    def summary_msg(self):
        """Return a one-line description (UTF-8 encoded bytes on Python 2)."""
        template = 'Error from server: code=%04x [%s] message="%s"'
        text = template % (self.code, self.summary, self.message)
        if six.PY2 and isinstance(text, six.text_type):
            return text.encode('utf-8')
        return text

    def __str__(self):
        return '<%s>' % self.summary_msg()
    __repr__ = __str__

    @staticmethod
    def recv_error_info(f, protocol_version):
        """Read code-specific extra data; the base class reads nothing."""
        return None

    def to_exception(self):
        """Return the exception to surface for this error (itself here)."""
        return self
153
154
class ErrorMessageSubclass(_RegisterMessageType):
    """Metaclass that indexes error classes by ``error_code``.

    Its ``__init__`` does not chain to the parent metaclass, so it never
    calls ``register_class()`` itself.
    """

    def __init__(cls, name, bases, dct):
        code = cls.error_code
        if code is None:
            return
        # Identity test rather than truthiness: the server uses error code 0.
        error_classes[code] = cls
159
160
@six.add_metaclass(ErrorMessageSubclass)
class ErrorMessageSub(ErrorMessage):
    """Base for error classes that are looked up by ``error_code``.

    ``error_code`` is ``None`` here, which keeps this base itself out of
    ``error_classes``.
    """
    error_code = None
164
165
class RequestExecutionException(ErrorMessageSub):
    """Grouping base class for the 0x1xxx error messages defined below it."""
168
169
class RequestValidationException(ErrorMessageSub):
    """Grouping base class for the 0x2xxx error messages defined below it."""
172
173
class ServerError(ErrorMessageSub):
    """Error message for code 0x0000 ('Server error')."""
    error_code = 0x0000
    summary = 'Server error'
177
178
class ProtocolException(ErrorMessageSub):
    """Error message for code 0x000A ('Protocol error')."""
    error_code = 0x000A
    summary = 'Protocol error'
182
183
class BadCredentials(ErrorMessageSub):
    """Error message for code 0x0100 ('Bad credentials')."""
    error_code = 0x0100
    summary = 'Bad credentials'
187
188
class UnavailableErrorMessage(RequestExecutionException):
    """Error message for code 0x1000 ('Unavailable exception')."""
    error_code = 0x1000
    summary = 'Unavailable exception'

    @staticmethod
    def recv_error_info(f, protocol_version):
        """Read consistency level, required and alive replica counts, in that order."""
        consistency = read_consistency_level(f)
        required = read_int(f)
        alive = read_int(f)
        return {
            'consistency': consistency,
            'required_replicas': required,
            'alive_replicas': alive,
        }

    def to_exception(self):
        """Convert to ``Unavailable`` carrying the decoded info as keywords."""
        message = self.summary_msg()
        return Unavailable(message, **self.info)
203
204
class OverloadedErrorMessage(RequestExecutionException):
    """Error message for code 0x1001 ('Coordinator node overloaded')."""
    error_code = 0x1001
    summary = 'Coordinator node overloaded'
208
209
class IsBootstrappingErrorMessage(RequestExecutionException):
    """Error message for code 0x1002 ('Coordinator node is bootstrapping')."""
    error_code = 0x1002
    summary = 'Coordinator node is bootstrapping'
213
214
class TruncateError(RequestExecutionException):
    """Error message for code 0x1003 ('Error during truncate')."""
    error_code = 0x1003
    summary = 'Error during truncate'
218
219
class WriteTimeoutErrorMessage(RequestExecutionException):
    """Error message for code 0x1100 (write timed out at the coordinator)."""
    error_code = 0x1100
    summary = "Coordinator node timed out waiting for replica nodes' responses"

    @staticmethod
    def recv_error_info(f, protocol_version):
        """Read consistency, received/required response counts and write type.

        The write type arrives as a string and is mapped through
        ``WriteType.name_to_value``.
        """
        consistency = read_consistency_level(f)
        received = read_int(f)
        required = read_int(f)
        write_type_name = read_string(f)
        return {
            'consistency': consistency,
            'received_responses': received,
            'required_responses': required,
            'write_type': WriteType.name_to_value[write_type_name],
        }

    def to_exception(self):
        """Convert to ``WriteTimeout`` carrying the decoded info as keywords."""
        message = self.summary_msg()
        return WriteTimeout(message, **self.info)
235
236
class ReadTimeoutErrorMessage(RequestExecutionException):
    """Error message for code 0x1200 (read timed out at the coordinator)."""
    error_code = 0x1200
    summary = "Coordinator node timed out waiting for replica nodes' responses"

    @staticmethod
    def recv_error_info(f, protocol_version):
        """Read consistency, received/required response counts and the data flag."""
        consistency = read_consistency_level(f)
        received = read_int(f)
        required = read_int(f)
        data_flag = read_byte(f)
        return {
            'consistency': consistency,
            'received_responses': received,
            'required_responses': required,
            'data_retrieved': bool(data_flag),
        }

    def to_exception(self):
        """Convert to ``ReadTimeout`` carrying the decoded info as keywords."""
        message = self.summary_msg()
        return ReadTimeout(message, **self.info)
252
253
class ReadFailureMessage(RequestExecutionException):
    """Error message for code 0x1300 ('Replica(s) failed to execute read')."""
    error_code = 0x1300
    summary = "Replica(s) failed to execute read"

    @staticmethod
    def recv_error_info(f, protocol_version):
        """Read the read-failure details from ``f``.

        When the protocol version uses an error-code map, the map is read and
        ``failures`` is its length; otherwise ``failures`` is read as an int
        and ``error_code_map`` is ``None``.
        """
        info = {
            'consistency': read_consistency_level(f),
            'received_responses': read_int(f),
            'required_responses': read_int(f),
        }

        code_map = None
        if ProtocolVersion.uses_error_code_map(protocol_version):
            code_map = read_error_code_map(f)
            failure_count = len(code_map)
        else:
            failure_count = read_int(f)

        info['failures'] = failure_count
        info['error_code_map'] = code_map
        info['data_retrieved'] = bool(read_byte(f))
        return info

    def to_exception(self):
        """Convert to ``ReadFailure`` carrying the decoded info as keywords."""
        message = self.summary_msg()
        return ReadFailure(message, **self.info)
284
285
class FunctionFailureMessage(RequestExecutionException):
    """Error message for code 0x1400 ('User Defined Function failure')."""
    error_code = 0x1400
    summary = "User Defined Function failure"

    @staticmethod
    def recv_error_info(f, protocol_version):
        """Read keyspace, function name and the list of argument type names."""
        keyspace = read_string(f)
        function = read_string(f)
        arg_count = read_short(f)
        arg_types = [read_string(f) for _ in range(arg_count)]
        return {
            'keyspace': keyspace,
            'function': function,
            'arg_types': arg_types,
        }

    def to_exception(self):
        """Convert to ``FunctionFailure`` carrying the decoded info as keywords."""
        message = self.summary_msg()
        return FunctionFailure(message, **self.info)
300
301
class WriteFailureMessage(RequestExecutionException):
    """Error message for code 0x1500 ('Replica(s) failed to execute write')."""
    error_code = 0x1500
    summary = "Replica(s) failed to execute write"

    @staticmethod
    def recv_error_info(f, protocol_version):
        """Read the write-failure details from ``f``.

        When the protocol version uses an error-code map, the map is read and
        ``failures`` is its length; otherwise ``failures`` is read as an int
        and ``error_code_map`` is ``None``. The write type arrives as a
        string and is mapped through ``WriteType.name_to_value``.
        """
        info = {
            'consistency': read_consistency_level(f),
            'received_responses': read_int(f),
            'required_responses': read_int(f),
        }

        code_map = None
        if ProtocolVersion.uses_error_code_map(protocol_version):
            code_map = read_error_code_map(f)
            failure_count = len(code_map)
        else:
            failure_count = read_int(f)

        write_type_name = read_string(f)

        info['failures'] = failure_count
        info['error_code_map'] = code_map
        info['write_type'] = WriteType.name_to_value[write_type_name]
        return info

    def to_exception(self):
        """Convert to ``WriteFailure`` carrying the decoded info as keywords."""
        message = self.summary_msg()
        return WriteFailure(message, **self.info)
332
333
class CDCWriteException(RequestExecutionException):
    """Error message for code 0x1600 (write failed, CDC space exhausted)."""
    error_code = 0x1600
    summary = 'Failed to execute write due to CDC space exhaustion.'
337
338
class SyntaxException(RequestValidationException):
    """Error message for code 0x2000 ('Syntax error in CQL query')."""
    error_code = 0x2000
    summary = 'Syntax error in CQL query'
342
343
class UnauthorizedErrorMessage(RequestValidationException):
    """Error message for code 0x2100 ('Unauthorized')."""
    error_code = 0x2100
    summary = 'Unauthorized'

    def to_exception(self):
        """Convert to ``Unauthorized`` built from the summary message."""
        message = self.summary_msg()
        return Unauthorized(message)
350
351
class InvalidRequestException(RequestValidationException):
    """Error message for code 0x2200 ('Invalid query')."""
    error_code = 0x2200
    summary = 'Invalid query'

    def to_exception(self):
        """Convert to ``InvalidRequest`` built from the summary message."""
        message = self.summary_msg()
        return InvalidRequest(message)
358
359
class ConfigurationException(RequestValidationException):
    """Error message for code 0x2300 (query invalid due to configuration)."""
    error_code = 0x2300
    summary = 'Query invalid because of configuration issue'
363
364
class PreparedQueryNotFound(RequestValidationException):
    """Error message for code 0x2500 (prepared statement unknown to the node)."""
    error_code = 0x2500
    summary = 'Matching prepared statement not found on this node'

    @staticmethod
    def recv_error_info(f, protocol_version):
        """Read and return the query ID as a binary string."""
        query_id = read_binary_string(f)
        return query_id
373
374
class AlreadyExistsException(ConfigurationException):
    """Error message for code 0x2400 ('Item already exists')."""
    error_code = 0x2400
    summary = 'Item already exists'

    @staticmethod
    def recv_error_info(f, protocol_version):
        """Read the keyspace name followed by the table name."""
        keyspace = read_string(f)
        table = read_string(f)
        return {
            'keyspace': keyspace,
            'table': table,
        }

    def to_exception(self):
        """Convert to ``AlreadyExists`` built from the decoded info only."""
        details = self.info
        return AlreadyExists(**details)
388
389
class StartupMessage(_MessageType):
    """STARTUP request (opcode 0x01) carrying the CQL version and options."""
    opcode = 0x01
    name = 'STARTUP'

    KNOWN_OPTION_KEYS = {
        'CQL_VERSION',
        'COMPRESSION',
        'NO_COMPACT',
    }

    def __init__(self, cqlversion, options):
        self.cqlversion = cqlversion
        self.options = options

    def send_body(self, f, protocol_version):
        """Write the options as a string map with ``CQL_VERSION`` filled in.

        Works on a copy, so ``self.options`` is left untouched.
        """
        body = self.options.copy()
        body.update(CQL_VERSION=self.cqlversion)
        write_stringmap(f, body)
408
409
class ReadyMessage(_MessageType):
    """READY response (opcode 0x02); it has no body."""
    opcode = 0x02
    name = 'READY'

    @classmethod
    def recv_body(cls, *args):
        """Build the message without reading anything from the stream."""
        message = cls()
        return message
417
418
class AuthenticateMessage(_MessageType):
    """AUTHENTICATE response (opcode 0x03) naming the authenticator."""
    opcode = 0x03
    name = 'AUTHENTICATE'

    def __init__(self, authenticator):
        self.authenticator = authenticator

    @classmethod
    def recv_body(cls, f, *args):
        """Read the authenticator name (a string) from ``f``."""
        authenticator_name = read_string(f)
        return cls(authenticator=authenticator_name)
430
431
class CredentialsMessage(_MessageType):
    """CREDENTIALS request (opcode 0x04); only valid on protocol version 1."""
    opcode = 0x04
    name = 'CREDENTIALS'

    def __init__(self, creds):
        self.creds = creds

    def send_body(self, f, protocol_version):
        """Write the credential count followed by each key/value string pair.

        Raises ``UnsupportedOperation`` for any protocol version above 1.
        """
        if protocol_version > 1:
            raise UnsupportedOperation(
                "Credentials-based authentication is not supported with "
                "protocol version 2 or higher.  Use the SASL authentication "
                "mechanism instead.")
        credentials = self.creds
        write_short(f, len(credentials))
        for key, value in credentials.items():
            write_string(f, key)
            write_string(f, value)
449
450
class AuthChallengeMessage(_MessageType):
    """AUTH_CHALLENGE response (opcode 0x0E) carrying the challenge."""
    opcode = 0x0E
    name = 'AUTH_CHALLENGE'

    def __init__(self, challenge):
        self.challenge = challenge

    @classmethod
    def recv_body(cls, f, *args):
        """Read the challenge as a binary long string."""
        challenge = read_binary_longstring(f)
        return cls(challenge)
461
462
class AuthResponseMessage(_MessageType):
    """AUTH_RESPONSE request (opcode 0x0F) carrying the client's response."""
    opcode = 0x0F
    name = 'AUTH_RESPONSE'

    def __init__(self, response):
        self.response = response

    def send_body(self, f, protocol_version):
        """Write the response as a long string."""
        payload = self.response
        write_longstring(f, payload)
472
473
class AuthSuccessMessage(_MessageType):
    """AUTH_SUCCESS response (opcode 0x10) carrying the final token."""
    opcode = 0x10
    name = 'AUTH_SUCCESS'

    def __init__(self, token):
        self.token = token

    @classmethod
    def recv_body(cls, f, *args):
        """Read the token with ``read_longstring``.

        NOTE(review): AuthChallengeMessage uses the binary variant
        (``read_binary_longstring``); confirm the non-binary reader is
        intended for the token.
        """
        token = read_longstring(f)
        return cls(token)
484
485
class OptionsMessage(_MessageType):
    """OPTIONS request (opcode 0x05); it has an empty body."""
    opcode = 0x05
    name = 'OPTIONS'

    def send_body(self, f, protocol_version):
        """Write nothing: this message carries no body."""
        return None
492
493
class SupportedMessage(_MessageType):
    """SUPPORTED response (opcode 0x06) listing CQL versions and options."""
    opcode = 0x06
    name = 'SUPPORTED'

    def __init__(self, cql_versions, options):
        self.cql_versions = cql_versions
        self.options = options

    @classmethod
    def recv_body(cls, f, *args):
        """Read the string multimap and split off its ``CQL_VERSION`` entry.

        Raises ``KeyError`` if the server did not send ``CQL_VERSION``.
        """
        supported = read_stringmultimap(f)
        versions = supported.pop('CQL_VERSION')
        return cls(cql_versions=versions, options=supported)
507
508
# Bit flags for the flags field written by QueryMessage and ExecuteMessage.
# Each flag announces that the matching optional field follows in the body.
_VALUES_FLAG = 0x01
_SKIP_METADATA_FLAG = 0x02
_PAGE_SIZE_FLAG = 0x04
_WITH_PAGING_STATE_FLAG = 0x08
_WITH_SERIAL_CONSISTENCY_FLAG = 0x10
_PROTOCOL_TIMESTAMP = 0x20
_WITH_KEYSPACE_FLAG = 0x80
# Separate flag space: used by PrepareMessage, not by QUERY/EXECUTE.
_PREPARED_WITH_KEYSPACE_FLAG = 0x01
518
519
class QueryMessage(_MessageType):
    """QUERY request (opcode 0x07): a CQL string plus execution options.

    ``serial_consistency_level``, ``fetch_size`` and ``paging_state`` are
    treated as "not set" when falsy; ``timestamp`` and ``keyspace`` are
    treated as "not set" only when they are ``None``.
    """
    opcode = 0x07
    name = 'QUERY'

    def __init__(self, query, consistency_level, serial_consistency_level=None,
                 fetch_size=None, paging_state=None, timestamp=None, keyspace=None):
        self.query = query
        self.consistency_level = consistency_level
        self.serial_consistency_level = serial_consistency_level
        self.fetch_size = fetch_size
        self.paging_state = paging_state
        self.timestamp = timestamp
        self.keyspace = keyspace
        self._query_params = None  # internal only: may be set to a list of native-encoded values to send with the request

    def send_body(self, f, protocol_version):
        """Serialize the request body to the file-like object ``f``.

        Layout: query long string, consistency level, flags, then the
        optional fields in the order values, page size, paging state,
        serial consistency, timestamp, keyspace.

        Raises ``UnsupportedOperation`` when an option is set that the given
        ``protocol_version`` cannot express.
        """
        write_longstring(f, self.query)
        write_consistency_level(f, self.consistency_level)
        # First pass: compute the flags and reject unsupported options
        # before any optional field is written.
        flags = 0x00
        if self._query_params is not None:
            flags |= _VALUES_FLAG  # also v2+, but we're only setting params internally right now

        if self.serial_consistency_level:
            if protocol_version >= 2:
                flags |= _WITH_SERIAL_CONSISTENCY_FLAG
            else:
                raise UnsupportedOperation(
                    "Serial consistency levels require the use of protocol version "
                    "2 or higher. Consider setting Cluster.protocol_version to 2 "
                    "to support serial consistency levels.")

        if self.fetch_size:
            if protocol_version >= 2:
                flags |= _PAGE_SIZE_FLAG
            else:
                raise UnsupportedOperation(
                    "Automatic query paging may only be used with protocol version "
                    "2 or higher. Consider setting Cluster.protocol_version to 2.")

        if self.paging_state:
            if protocol_version >= 2:
                flags |= _WITH_PAGING_STATE_FLAG
            else:
                raise UnsupportedOperation(
                    "Automatic query paging may only be used with protocol version "
                    "2 or higher. Consider setting Cluster.protocol_version to 2.")

        # NOTE(review): unlike ExecuteMessage.send_body, no protocol-version
        # check guards the timestamp flag here -- confirm this is intended.
        if self.timestamp is not None:
            flags |= _PROTOCOL_TIMESTAMP

        if self.keyspace is not None:
            if ProtocolVersion.uses_keyspace_flag(protocol_version):
                flags |= _WITH_KEYSPACE_FLAG
            else:
                raise UnsupportedOperation(
                    "Keyspaces may only be set on queries with protocol version "
                    "5 or higher. Consider setting Cluster.protocol_version to 5.")

        # The width of the flags field depends on the protocol version.
        if ProtocolVersion.uses_int_query_flags(protocol_version):
            write_uint(f, flags)
        else:
            write_byte(f, flags)

        # Second pass: write the optional fields announced by the flags,
        # using the same conditions as above.
        if self._query_params is not None:
            write_short(f, len(self._query_params))
            for param in self._query_params:
                write_value(f, param)

        if self.fetch_size:
            write_int(f, self.fetch_size)
        if self.paging_state:
            write_longstring(f, self.paging_state)
        if self.serial_consistency_level:
            write_consistency_level(f, self.serial_consistency_level)
        if self.timestamp is not None:
            write_long(f, self.timestamp)
        if self.keyspace is not None:
            write_string(f, self.keyspace)
598
599
# Sentinel standing in for a "custom" column type: ResultMessage.read_type()
# reads a class name and resolves it with lookup_casstype() when it sees it.
CUSTOM_TYPE = object()

# Values of the `kind` int at the start of a RESULT body; see
# ResultMessage.recv_body() for how each one is decoded.
RESULT_KIND_VOID = 0x0001
RESULT_KIND_ROWS = 0x0002
RESULT_KIND_SET_KEYSPACE = 0x0003
RESULT_KIND_PREPARED = 0x0004
RESULT_KIND_SCHEMA_CHANGE = 0x0005
607
608
class ResultMessage(_MessageType):
    """RESULT response (opcode 0x08).

    ``kind`` is one of the ``RESULT_KIND_*`` constants and decides what
    ``results`` holds:

    * VOID: ``None``
    * ROWS: ``(column_names, parsed_rows)``
    * SET_KEYSPACE: the keyspace name
    * PREPARED: ``(query_id, bind_metadata, pk_indexes, result_metadata,
      result_metadata_id)``
    * SCHEMA_CHANGE: whatever ``EventMessage.recv_schema_change`` returns
    """
    opcode = 0x08
    name = 'RESULT'

    kind = None
    results = None
    paging_state = None

    # Names match type name in module scope. Most are imported from cassandra.cqltypes (except CUSTOM_TYPE)
    type_codes = _cqltypes_by_code = dict((v, globals()[k]) for k, v in type_codes.__dict__.items() if not k.startswith('_'))

    # Bits of the metadata flags int read by the recv_*_metadata methods.
    _FLAGS_GLOBAL_TABLES_SPEC = 0x0001
    _HAS_MORE_PAGES_FLAG = 0x0002
    _NO_METADATA_FLAG = 0x0004
    _METADATA_ID_FLAG = 0x0008

    def __init__(self, kind, results, paging_state=None, col_types=None):
        self.kind = kind
        self.results = results
        self.paging_state = paging_state
        self.col_types = col_types

    @classmethod
    def recv_body(cls, f, protocol_version, user_type_map, result_metadata):
        """Decode a RESULT body from ``f`` and return a ``ResultMessage``.

        ``result_metadata`` is the fallback column metadata used for ROWS
        results that arrive without metadata of their own.

        Raises ``DriverException`` for an unknown result kind.
        """
        kind = read_int(f)
        paging_state = None
        col_types = None
        if kind == RESULT_KIND_VOID:
            results = None
        elif kind == RESULT_KIND_ROWS:
            # The metadata id returned for rows is not used here.
            paging_state, col_types, results, result_metadata_id = cls.recv_results_rows(
                f, protocol_version, user_type_map, result_metadata)
        elif kind == RESULT_KIND_SET_KEYSPACE:
            ksname = read_string(f)
            results = ksname
        elif kind == RESULT_KIND_PREPARED:
            results = cls.recv_results_prepared(f, protocol_version, user_type_map)
        elif kind == RESULT_KIND_SCHEMA_CHANGE:
            results = cls.recv_results_schema_change(f, protocol_version)
        else:
            raise DriverException("Unknown RESULT kind: %d" % kind)
        return cls(kind, results, paging_state, col_types)

    @classmethod
    def recv_results_rows(cls, f, protocol_version, user_type_map, result_metadata):
        """Decode a ROWS result.

        Returns ``(paging_state, column_types, (column_names, parsed_rows),
        result_metadata_id)`` where ``parsed_rows`` is a list of tuples.

        Raises ``DriverException`` naming the offending column when a value
        cannot be decoded. If the failure cannot be pinned to a column, the
        original decoding exception is re-raised.
        """
        paging_state, column_metadata, result_metadata_id = cls.recv_results_metadata(f, user_type_map)
        # Rows sent without metadata rely on the metadata supplied by the caller.
        column_metadata = column_metadata or result_metadata
        rowcount = read_int(f)
        rows = [cls.recv_row(f, len(column_metadata)) for _ in range(rowcount)]
        colnames = [c[2] for c in column_metadata]
        coltypes = [c[3] for c in column_metadata]
        try:
            parsed_rows = [
                tuple(ctype.from_binary(val, protocol_version)
                      for ctype, val in zip(coltypes, row))
                for row in rows]
        except Exception:
            # Replay the decoding value by value to report which column failed.
            for row in rows:
                for colname, coltype, val in zip(colnames, coltypes, row):
                    try:
                        coltype.from_binary(val, protocol_version)
                    except Exception as e:
                        raise DriverException('Failed decoding result column "%s" of type %s: %s' % (colname,
                                                                                                     coltype.cql_parameterized_type(),
                                                                                                     str(e)))
            # The replay did not reproduce the failure: surface the original
            # error instead of falling through to an unbound ``parsed_rows``.
            raise
        return paging_state, coltypes, (colnames, parsed_rows), result_metadata_id

    @classmethod
    def recv_results_prepared(cls, f, protocol_version, user_type_map):
        """Decode a PREPARED result.

        Returns ``(query_id, bind_metadata, pk_indexes, result_metadata,
        result_metadata_id)``. ``result_metadata_id`` is only read when the
        protocol version uses prepared metadata, otherwise it is ``None``.
        """
        query_id = read_binary_string(f)
        if ProtocolVersion.uses_prepared_metadata(protocol_version):
            result_metadata_id = read_binary_string(f)
        else:
            result_metadata_id = None
        bind_metadata, pk_indexes, result_metadata, _ = cls.recv_prepared_metadata(f, protocol_version, user_type_map)
        return query_id, bind_metadata, pk_indexes, result_metadata, result_metadata_id

    @classmethod
    def recv_results_metadata(cls, f, user_type_map):
        """Decode result-set metadata.

        Returns ``(paging_state, column_metadata, result_metadata_id)``.
        ``column_metadata`` is a list of ``(keyspace, table, name, type)``
        tuples, or an empty list when the no-metadata flag is set.
        """
        flags = read_int(f)
        colcount = read_int(f)

        if flags & cls._HAS_MORE_PAGES_FLAG:
            paging_state = read_binary_longstring(f)
        else:
            paging_state = None

        if flags & cls._METADATA_ID_FLAG:
            result_metadata_id = read_binary_string(f)
        else:
            result_metadata_id = None

        no_meta = bool(flags & cls._NO_METADATA_FLAG)
        if no_meta:
            return paging_state, [], result_metadata_id

        # With a global table spec the keyspace/table pair is sent once and
        # shared by every column; otherwise each column carries its own.
        glob_tblspec = bool(flags & cls._FLAGS_GLOBAL_TABLES_SPEC)
        if glob_tblspec:
            ksname = read_string(f)
            cfname = read_string(f)
        column_metadata = []
        for _ in range(colcount):
            if glob_tblspec:
                colksname = ksname
                colcfname = cfname
            else:
                colksname = read_string(f)
                colcfname = read_string(f)
            colname = read_string(f)
            coltype = cls.read_type(f, user_type_map)
            column_metadata.append((colksname, colcfname, colname, coltype))
        return paging_state, column_metadata, result_metadata_id

    @classmethod
    def recv_prepared_metadata(cls, f, protocol_version, user_type_map):
        """Decode the metadata section of a PREPARED result.

        Returns ``(bind_metadata, pk_indexes, result_metadata,
        result_metadata_id)``. ``pk_indexes`` is ``None`` below protocol
        version 4; the last two items are ``None`` below protocol version 2.
        """
        flags = read_int(f)
        colcount = read_int(f)
        pk_indexes = None
        if protocol_version >= 4:
            num_pk_indexes = read_int(f)
            pk_indexes = [read_short(f) for _ in range(num_pk_indexes)]

        glob_tblspec = bool(flags & cls._FLAGS_GLOBAL_TABLES_SPEC)
        if glob_tblspec:
            ksname = read_string(f)
            cfname = read_string(f)
        bind_metadata = []
        for _ in range(colcount):
            if glob_tblspec:
                colksname = ksname
                colcfname = cfname
            else:
                colksname = read_string(f)
                colcfname = read_string(f)
            colname = read_string(f)
            coltype = cls.read_type(f, user_type_map)
            bind_metadata.append(ColumnMetadata(colksname, colcfname, colname, coltype))

        if protocol_version >= 2:
            # The result-set metadata follows the bind metadata; its paging
            # state is ignored here.
            _, result_metadata, result_metadata_id = cls.recv_results_metadata(f, user_type_map)
            return bind_metadata, pk_indexes, result_metadata, result_metadata_id
        else:
            return bind_metadata, pk_indexes, None, None

    @classmethod
    def recv_results_schema_change(cls, f, protocol_version):
        """Decode a SCHEMA_CHANGE result via ``EventMessage.recv_schema_change``."""
        return EventMessage.recv_schema_change(f, protocol_version)

    @classmethod
    def read_type(cls, f, user_type_map):
        """Read one (possibly nested) type description and return its type class.

        Collection, tuple and user-defined types recurse to read their
        element types. For user-defined types ``mapped_class`` is set from
        ``user_type_map[keyspace][udt_name]`` (``None`` when absent).

        Raises ``NotSupportedError`` for an unknown type code.
        """
        optid = read_short(f)
        try:
            typeclass = cls.type_codes[optid]
        except KeyError:
            raise NotSupportedError("Unknown data type code 0x%04x. Have to skip"
                                    " entire result set." % (optid,))
        if typeclass in (ListType, SetType):
            subtype = cls.read_type(f, user_type_map)
            typeclass = typeclass.apply_parameters((subtype,))
        elif typeclass == MapType:
            keysubtype = cls.read_type(f, user_type_map)
            valsubtype = cls.read_type(f, user_type_map)
            typeclass = typeclass.apply_parameters((keysubtype, valsubtype))
        elif typeclass == TupleType:
            num_items = read_short(f)
            types = tuple(cls.read_type(f, user_type_map) for _ in range(num_items))
            typeclass = typeclass.apply_parameters(types)
        elif typeclass == UserType:
            ks = read_string(f)
            udt_name = read_string(f)
            num_fields = read_short(f)
            names, types = zip(*((read_string(f), cls.read_type(f, user_type_map))
                                 for _ in range(num_fields)))
            specialized_type = typeclass.make_udt_class(ks, udt_name, names, types)
            specialized_type.mapped_class = user_type_map.get(ks, {}).get(udt_name)
            typeclass = specialized_type
        elif typeclass == CUSTOM_TYPE:
            classname = read_string(f)
            typeclass = lookup_casstype(classname)

        return typeclass

    @staticmethod
    def recv_row(f, colcount):
        """Read ``colcount`` raw (still encoded) values forming one row."""
        return [read_value(f) for _ in range(colcount)]
794
795
class PrepareMessage(_MessageType):
    """PREPARE request (opcode 0x09): a CQL string and an optional keyspace."""
    opcode = 0x09
    name = 'PREPARE'

    def __init__(self, query, keyspace=None):
        self.query = query
        self.keyspace = keyspace

    def send_body(self, f, protocol_version):
        """Serialize the request body to the file-like object ``f``.

        Layout: query long string, then (only on protocol versions that use
        prepare flags) the flags, then the keyspace when its flag is set.

        Raises ``UnsupportedOperation`` when a keyspace is given but the
        protocol version cannot carry it.
        """
        write_longstring(f, self.query)

        flags = 0x00

        if self.keyspace is not None:
            if ProtocolVersion.uses_keyspace_flag(protocol_version):
                flags |= _PREPARED_WITH_KEYSPACE_FLAG
            else:
                raise UnsupportedOperation(
                    "Keyspaces may only be set on queries with protocol version "
                    "5 or higher. Consider setting Cluster.protocol_version to 5.")

        if ProtocolVersion.uses_prepare_flags(protocol_version):
            write_uint(f, flags)
        else:
            # checks above should prevent this, but just to be safe...
            if flags:
                raise UnsupportedOperation(
                    "Attempted to set flags with value {flags:0=#8x} on "
                    "protocol version {pv}, which doesn't support flags "
                    "in prepared statements. "
                    "Consider setting Cluster.protocol_version to 5."
                    "".format(flags=flags, pv=protocol_version))

        # Write the keyspace exactly when its flag was advertised, so an
        # empty-string keyspace cannot leave the flag set without the field.
        if flags & _PREPARED_WITH_KEYSPACE_FLAG:
            write_string(f, self.keyspace)
832
833
class ExecuteMessage(_MessageType):
    """EXECUTE request (opcode 0x0A): run a previously prepared statement.

    ``serial_consistency_level``, ``fetch_size`` and ``paging_state`` are
    treated as "not set" when falsy; ``timestamp`` only when it is ``None``.
    """
    opcode = 0x0A
    name = 'EXECUTE'
    def __init__(self, query_id, query_params, consistency_level,
                 serial_consistency_level=None, fetch_size=None,
                 paging_state=None, timestamp=None, skip_meta=False,
                 result_metadata_id=None):
        self.query_id = query_id
        self.query_params = query_params
        self.consistency_level = consistency_level
        self.serial_consistency_level = serial_consistency_level
        self.fetch_size = fetch_size
        self.paging_state = paging_state
        self.timestamp = timestamp
        self.skip_meta = skip_meta
        self.result_metadata_id = result_metadata_id

    def send_body(self, f, protocol_version):
        """Serialize the request body to the file-like object ``f``.

        Protocol version 1 writes: query id, parameter values, consistency.
        Every other version writes: query id, (result metadata id when the
        version uses prepared metadata), consistency, flags, parameter
        values, then page size, paging state, serial consistency and
        timestamp for those that are set.

        Raises ``UnsupportedOperation`` when an option is set that the given
        ``protocol_version`` cannot express.
        """
        write_string(f, self.query_id)
        if ProtocolVersion.uses_prepared_metadata(protocol_version):
            write_string(f, self.result_metadata_id)
        if protocol_version == 1:
            if self.serial_consistency_level:
                raise UnsupportedOperation(
                    "Serial consistency levels require the use of protocol version "
                    "2 or higher. Consider setting Cluster.protocol_version to 2 "
                    "to support serial consistency levels.")
            if self.fetch_size or self.paging_state:
                raise UnsupportedOperation(
                    "Automatic query paging may only be used with protocol version "
                    "2 or higher. Consider setting Cluster.protocol_version to 2.")
            # v1 has no flags field: values come first, consistency last.
            # timestamp and skip_meta are not consulted on this path.
            write_short(f, len(self.query_params))
            for param in self.query_params:
                write_value(f, param)
            write_consistency_level(f, self.consistency_level)
        else:
            write_consistency_level(f, self.consistency_level)
            # Values are always sent, so the values flag is always set.
            flags = _VALUES_FLAG
            if self.serial_consistency_level:
                flags |= _WITH_SERIAL_CONSISTENCY_FLAG
            if self.fetch_size:
                flags |= _PAGE_SIZE_FLAG
            if self.paging_state:
                flags |= _WITH_PAGING_STATE_FLAG
            if self.timestamp is not None:
                if protocol_version >= 3:
                    flags |= _PROTOCOL_TIMESTAMP
                else:
                    raise UnsupportedOperation(
                        "Protocol-level timestamps may only be used with protocol version "
                        "3 or higher. Consider setting Cluster.protocol_version to 3.")
            if self.skip_meta:
                flags |= _SKIP_METADATA_FLAG

            # The width of the flags field depends on the protocol version.
            if ProtocolVersion.uses_int_query_flags(protocol_version):
                write_uint(f, flags)
            else:
                write_byte(f, flags)

            # Optional fields follow, using the same conditions as the flags.
            write_short(f, len(self.query_params))
            for param in self.query_params:
                write_value(f, param)
            if self.fetch_size:
                write_int(f, self.fetch_size)
            if self.paging_state:
                write_longstring(f, self.paging_state)
            if self.serial_consistency_level:
                write_consistency_level(f, self.serial_consistency_level)
            if self.timestamp is not None:
                write_long(f, self.timestamp)
904
905
906
class BatchMessage(_MessageType):
    """
    BATCH request carrying several statements in a single message.

    ``queries`` is a sequence of ``(prepared, string_or_query_id, params)``
    tuples. When ``prepared`` is false the second item is written as the
    query string; otherwise it is written as the id of a prepared statement.
    Each entry of ``params`` is written with ``write_value``.
    """
    opcode = 0x0D
    name = 'BATCH'

    def __init__(self, batch_type, queries, consistency_level,
                 serial_consistency_level=None, timestamp=None,
                 keyspace=None):
        self.batch_type = batch_type
        self.queries = queries
        self.consistency_level = consistency_level
        self.serial_consistency_level = serial_consistency_level
        self.timestamp = timestamp
        self.keyspace = keyspace

    def send_body(self, f, protocol_version):
        """
        Write the body of this BATCH request to ``f``.

        For protocol versions lower than 3 only the batch type, the
        statements and the consistency level are written. From version 3 on,
        a flags field follows, then the serial consistency level, timestamp
        and keyspace for whichever of those the flags announce.

        :param f: writable binary stream receiving the encoded body
        :param protocol_version: native protocol version in use
        :raises UnsupportedOperation: if a keyspace is set and
            ``ProtocolVersion.uses_keyspace_flag`` is false for
            ``protocol_version``
        """
        write_byte(f, self.batch_type.value)
        write_short(f, len(self.queries))
        for prepared, string_or_query_id, params in self.queries:
            if not prepared:
                write_byte(f, 0)
                write_longstring(f, string_or_query_id)
            else:
                write_byte(f, 1)
                write_short(f, len(string_or_query_id))
                f.write(string_or_query_id)
            write_short(f, len(params))
            for param in params:
                write_value(f, param)

        write_consistency_level(f, self.consistency_level)
        if protocol_version < 3:
            return

        # Decide once whether a keyspace is sent, so the flag and the field
        # can never disagree (an empty keyspace is treated as "not set").
        with_keyspace = bool(self.keyspace)

        flags = 0
        if self.serial_consistency_level:
            flags |= _WITH_SERIAL_CONSISTENCY_FLAG
        if self.timestamp is not None:
            flags |= _PROTOCOL_TIMESTAMP
        if with_keyspace:
            if not ProtocolVersion.uses_keyspace_flag(protocol_version):
                raise UnsupportedOperation(
                    "Keyspaces may only be set on queries with protocol version "
                    "5 or higher. Consider setting Cluster.protocol_version to 5.")
            flags |= _WITH_KEYSPACE_FLAG

        # Flags are a bitmask: use the unsigned writer, as ExecuteMessage does.
        if ProtocolVersion.uses_int_query_flags(protocol_version):
            write_uint(f, flags)
        else:
            write_byte(f, flags)

        if self.serial_consistency_level:
            write_consistency_level(f, self.serial_consistency_level)
        if self.timestamp is not None:
            write_long(f, self.timestamp)
        if with_keyspace:
            write_string(f, self.keyspace)
964
965
# Event type names that EventMessage.recv_body is able to decode.
known_event_types = frozenset([
    'TOPOLOGY_CHANGE',
    'STATUS_CHANGE',
    'SCHEMA_CHANGE',
])
971
972
class RegisterMessage(_MessageType):
    """
    REGISTER request whose body is the list of event type names held in
    ``event_list``.
    """
    opcode = 0x0B
    name = 'REGISTER'

    def __init__(self, event_list):
        self.event_list = event_list

    def send_body(self, f, protocol_version):
        """
        Write ``event_list`` to ``f`` as a string list: a short count
        followed by each name as a string.
        """
        events = self.event_list
        write_short(f, len(events))
        for event_name in events:
            write_string(f, event_name)
982
983
class EventMessage(_MessageType):
    """
    EVENT message: an event type name plus a dict of arguments decoded by
    the matching ``recv_<event type>`` classmethod.
    """
    opcode = 0x0C
    name = 'EVENT'

    def __init__(self, event_type, event_args):
        self.event_type = event_type
        self.event_args = event_args

    @classmethod
    def recv_body(cls, f, protocol_version, *args):
        """
        Decode an event from ``f``.

        The upper-cased event type read from the stream selects the decoder
        named ``recv_<lower-cased type>``.

        :raises NotSupportedError: if the type is not in ``known_event_types``
        """
        kind = read_string(f).upper()
        if kind not in known_event_types:
            raise NotSupportedError('Unknown event type %r' % kind)
        decode_args = getattr(cls, 'recv_' + kind.lower())
        return cls(event_type=kind, event_args=decode_args(f, protocol_version))

    @classmethod
    def recv_topology_change(cls, f, protocol_version):
        """
        Read a change type string followed by an inet address.
        """
        # "NEW_NODE" or "REMOVED_NODE"
        change = read_string(f)
        return {'change_type': change, 'address': read_inet(f)}

    @classmethod
    def recv_status_change(cls, f, protocol_version):
        """
        Read a change type string followed by an inet address.
        """
        # "UP" or "DOWN"
        change = read_string(f)
        return {'change_type': change, 'address': read_inet(f)}

    @classmethod
    def recv_schema_change(cls, f, protocol_version):
        """
        Read a schema change description and return it as a dict.

        The dict always has ``target_type``, ``change_type`` and ``keyspace``
        keys. Protocol versions lower than 3 add ``table`` when a non-empty
        table name was read. Later versions add ``function`` or ``aggregate``
        descriptors, or a key named after the lower-cased target type, for
        every target other than a keyspace.
        """
        # "CREATED", "DROPPED", or "UPDATED"
        change_type = read_string(f)

        if protocol_version < 3:
            keyspace = read_string(f)
            table = read_string(f)
            if not table:
                return {'target_type': SchemaTargetType.KEYSPACE, 'change_type': change_type, 'keyspace': keyspace}
            return {'target_type': SchemaTargetType.TABLE, 'change_type': change_type, 'keyspace': keyspace, 'table': table}

        target = read_string(f)
        keyspace = read_string(f)
        event = {'target_type': target, 'change_type': change_type, 'keyspace': keyspace}
        if target == SchemaTargetType.KEYSPACE:
            return event

        target_name = read_string(f)
        if target == SchemaTargetType.FUNCTION:
            arg_count = read_short(f)
            arg_types = [read_string(f) for _ in range(arg_count)]
            event['function'] = UserFunctionDescriptor(target_name, arg_types)
        elif target == SchemaTargetType.AGGREGATE:
            arg_count = read_short(f)
            arg_types = [read_string(f) for _ in range(arg_count)]
            event['aggregate'] = UserAggregateDescriptor(target_name, arg_types)
        else:
            event[target.lower()] = target_name
        return event
1038
1039
class _ProtocolHandler(object):
    """
    _ProtocolHandler handles encoding and decoding messages.

    This class can be specialized to compose Handlers which implement alternative
    result decoding or type deserialization. Class definitions are passed to :class:`cassandra.cluster.Cluster`
    on initialization.

    Contracted class methods are :meth:`_ProtocolHandler.encode_message` and :meth:`_ProtocolHandler.decode_message`.
    """

    message_types_by_opcode = _message_types_by_opcode.copy()
    """
    Default mapping of opcode to Message implementation. The default ``decode_message`` implementation uses
    this to instantiate a message and populate using ``recv_body``. This mapping can be updated to inject specialized
    result decoding implementations.
    """

    @classmethod
    def encode_message(cls, msg, stream_id, protocol_version, compressor, allow_beta_protocol_version):
        """
        Encodes a message into a complete frame (header followed by body)

        :param msg: the message, typically of cassandra.protocol._MessageType, generated by the driver
        :param stream_id: protocol stream id for the frame header
        :param protocol_version: version for the frame header, also used when encoding the contents
        :param compressor: optional compression function applied to a non-empty body
        :param allow_beta_protocol_version: when true, the beta flag is set in the frame header
        :return: the encoded frame as bytes
        """
        frame_flags = 0
        payload = io.BytesIO()
        if msg.custom_payload:
            if protocol_version < 4:
                raise UnsupportedOperation("Custom key/value payloads can only be used with protocol version 4 or higher")
            frame_flags |= CUSTOM_PAYLOAD_FLAG
            # the custom payload precedes the message body proper
            write_bytesmap(payload, msg.custom_payload)
        msg.send_body(payload, protocol_version)
        encoded = payload.getvalue()

        # an empty body is never handed to the compressor
        if compressor and encoded:
            encoded = compressor(encoded)
            frame_flags |= COMPRESSED_FLAG

        if msg.tracing:
            frame_flags |= TRACING_FLAG

        if allow_beta_protocol_version:
            frame_flags |= USE_BETA_FLAG

        frame = io.BytesIO()
        cls._write_header(frame, protocol_version, frame_flags, stream_id, msg.opcode, len(encoded))
        frame.write(encoded)

        return frame.getvalue()

    @staticmethod
    def _write_header(f, version, flags, stream_id, opcode, length):
        """
        Write a CQL protocol frame header.

        The packer is chosen by protocol version (``v3_header_pack`` from
        version 3 on, ``header_pack`` before); the body length follows as an int.
        """
        if version >= 3:
            pack_header = v3_header_pack
        else:
            pack_header = header_pack
        f.write(pack_header(version, flags, stream_id, opcode))
        write_int(f, length)

    @classmethod
    def decode_message(cls, protocol_version, user_type_map, stream_id, flags, opcode, body,
                       decompressor, result_metadata):
        """
        Decodes a native protocol message body

        :param protocol_version: version to use decoding contents
        :param user_type_map: map[keyspace name] = map[type name] = custom type to instantiate when deserializing this type
        :param stream_id: native protocol stream id from the frame header
        :param flags: native protocol flags bitmap from the header
        :param opcode: native protocol opcode from the header
        :param body: frame body
        :param decompressor: optional decompression function to inflate the body
        :param result_metadata: passed through unchanged to the message class's ``recv_body``
        :return: a message decoded from the body and frame attributes
        """
        if flags & COMPRESSED_FLAG:
            if decompressor is None:
                raise RuntimeError("No de-compressor available for compressed frame!")
            body = decompressor(body)
            flags ^= COMPRESSED_FLAG

        stream = io.BytesIO(body)

        # Each optional section is consumed in wire order and its flag is
        # cleared, so that anything left over can be reported below.
        trace_id = None
        if flags & TRACING_FLAG:
            trace_id = UUID(bytes=stream.read(16))
            flags ^= TRACING_FLAG

        warnings = None
        if flags & WARNING_FLAG:
            warnings = read_stringlist(stream)
            flags ^= WARNING_FLAG

        custom_payload = None
        if flags & CUSTOM_PAYLOAD_FLAG:
            custom_payload = read_bytesmap(stream)
            flags ^= CUSTOM_PAYLOAD_FLAG

        flags &= USE_BETA_MASK  # will only be set if we asserted it in connection establishment

        if flags:
            log.warning("Unknown protocol flags set: %02x. May cause problems.", flags)

        message_type = cls.message_types_by_opcode[opcode]
        msg = message_type.recv_body(stream, protocol_version, user_type_map, result_metadata)
        msg.stream_id = stream_id
        msg.trace_id = trace_id
        msg.custom_payload = custom_payload
        msg.warnings = warnings

        for server_warning in msg.warnings or ():
            log.warning("Server warning: %s", server_warning)

        return msg
1160
def cython_protocol_handler(colparser):
    """
    Given a column parser to deserialize ResultMessages, return a suitable
    Cython-based protocol handler class.

    Column parsers used with this factory in this module:

        - obj_parser.ListParser
            decodes result messages into a list of tuples

        - obj_parser.LazyParser
            decodes result messages lazily by returning an iterator

        - numpy_parser.NumpyParser
            decodes result messages into NumPy arrays

    The default ``ProtocolHandler`` is built with obj_parser.ListParser

    :param colparser: column parser instance handed to ``make_recv_results_rows``
    :return: a ``_ProtocolHandler`` subclass that decodes results with ``colparser``
    """
    # Imported here rather than at module level; presumably because the
    # compiled extension may be unavailable - TODO confirm.
    from cassandra.row_parser import make_recv_results_rows

    class FastResultMessage(ResultMessage):
        """
        Cython version of Result Message that has a faster implementation of
        recv_results_row.
        """
        # reverse mapping of ResultMessage.type_codes
        code_to_type = {value: key for key, value in ResultMessage.type_codes.items()}
        recv_results_rows = classmethod(make_recv_results_rows(colparser))

    class CythonProtocolHandler(_ProtocolHandler):
        """
        Use FastResultMessage to decode query result messages.
        """

        # private copy of the opcode table with the result decoder swapped in
        my_opcodes = _ProtocolHandler.message_types_by_opcode.copy()
        my_opcodes[FastResultMessage.opcode] = FastResultMessage
        message_types_by_opcode = my_opcodes

        col_parser = colparser

    return CythonProtocolHandler
1202
1203
# Select the default handlers: the pure-Python implementation when the
# Cython extensions are unavailable, Cython-backed handlers otherwise.
if not HAVE_CYTHON:
    # Use Python-based ProtocolHandler
    ProtocolHandler = _ProtocolHandler
    LazyProtocolHandler = None
else:
    from cassandra.obj_parser import ListParser, LazyParser
    ProtocolHandler = cython_protocol_handler(ListParser())
    LazyProtocolHandler = cython_protocol_handler(LazyParser())


# The NumPy-backed handler needs both the Cython extensions and NumPy.
if not (HAVE_CYTHON and HAVE_NUMPY):
    NumpyProtocolHandler = None
else:
    from cassandra.numpy_parser import NumpyParser
    NumpyProtocolHandler = cython_protocol_handler(NumpyParser())
1219
1220
def read_byte(f):
    """
    Read one byte from ``f`` and decode it with ``int8_unpack``.
    """
    raw = f.read(1)
    return int8_unpack(raw)
1223
1224
def write_byte(f, b):
    """
    Write ``b`` to ``f`` encoded with ``int8_pack``.
    """
    packed = int8_pack(b)
    f.write(packed)
1227
1228
def read_int(f):
    """
    Read four bytes from ``f`` and decode them with ``int32_unpack``.
    """
    raw = f.read(4)
    return int32_unpack(raw)
1231
1232
def write_int(f, i):
    """
    Write ``i`` to ``f`` encoded with ``int32_pack``.
    """
    packed = int32_pack(i)
    f.write(packed)
1235
1236
def write_uint(f, i):
    """
    Write ``i`` to ``f`` encoded with ``uint32_pack``.
    """
    packed = uint32_pack(i)
    f.write(packed)
1239
1240
def write_long(f, i):
    """
    Write ``i`` to ``f`` encoded with ``uint64_pack``.
    """
    # NOTE(review): the packer is the unsigned variant, so negative values
    # are presumably rejected by it - confirm against cassandra.marshal.
    packed = uint64_pack(i)
    f.write(packed)
1243
1244
def read_short(f):
    """
    Read two bytes from ``f`` and decode them with ``uint16_unpack``.
    """
    raw = f.read(2)
    return uint16_unpack(raw)
1247
1248
def write_short(f, s):
    """
    Write ``s`` to ``f`` encoded with ``uint16_pack``.
    """
    packed = uint16_pack(s)
    f.write(packed)
1251
1252
def read_consistency_level(f):
    """
    Read a consistency level from ``f``; it is encoded as a short.
    """
    level = read_short(f)
    return level
1255
1256
def write_consistency_level(f, cl):
    """
    Write the consistency level ``cl`` to ``f``; it is encoded as a short.
    """
    # same two-byte encoding that write_short produces
    f.write(uint16_pack(cl))
1259
1260
def read_string(f):
    """
    Read a short-length-prefixed byte string from ``f`` and decode it as UTF-8.
    """
    return read_binary_string(f).decode('utf8')
1265
1266
def read_binary_string(f):
    """
    Read a short length from ``f`` and return that many raw bytes, undecoded.
    """
    return f.read(read_short(f))
1271
1272
def write_string(f, s):
    """
    Write ``s`` to ``f`` prefixed with its byte length as a short.

    Text is encoded as UTF-8 first; anything else is written as given.
    """
    encoded = s.encode('utf8') if isinstance(s, six.text_type) else s
    write_short(f, len(encoded))
    f.write(encoded)
1278
1279
def read_binary_longstring(f):
    """
    Read an int length from ``f`` and return that many raw bytes, undecoded.
    """
    return f.read(read_int(f))
1284
1285
def read_longstring(f):
    """
    Read an int-length-prefixed byte string from ``f`` and decode it as UTF-8.
    """
    raw = read_binary_longstring(f)
    return raw.decode('utf8')
1288
1289
def write_longstring(f, s):
    """
    Write ``s`` to ``f`` prefixed with its byte length as an int.

    Text is encoded as UTF-8 first; anything else is written as given.
    """
    encoded = s.encode('utf8') if isinstance(s, six.text_type) else s
    write_int(f, len(encoded))
    f.write(encoded)
1295
1296
def read_stringlist(f):
    """
    Read a short count from ``f`` followed by that many strings.
    """
    # the count is consumed before the first string is read
    return [read_string(f) for _ in range(read_short(f))]
1300
1301
def write_stringlist(f, stringlist):
    """
    Write a short count to ``f`` followed by each entry of ``stringlist``
    as a string.
    """
    write_short(f, len(stringlist))
    for entry in stringlist:
        write_string(f, entry)
1306
1307
def read_stringmap(f):
    """
    Read a short count from ``f`` followed by that many key/value string
    pairs, returned as a dict.
    """
    pairs = {}
    for _ in range(read_short(f)):
        key = read_string(f)
        value = read_string(f)
        pairs[key] = value
    return pairs
1315
1316
def write_stringmap(f, strmap):
    """
    Write a short count to ``f`` followed by every key and value of
    ``strmap`` as strings.
    """
    write_short(f, len(strmap))
    for key, value in strmap.items():
        write_string(f, key)
        write_string(f, value)
1322
1323
def read_bytesmap(f):
    """
    Read a short count from ``f`` followed by that many pairs of a string key
    and a value (as decoded by ``read_value``), returned as a dict.
    """
    pairs = {}
    for _ in range(read_short(f)):
        key = read_string(f)
        value = read_value(f)
        pairs[key] = value
    return pairs
1331
1332
def write_bytesmap(f, bytesmap):
    """
    Write a short count to ``f`` followed by every entry of ``bytesmap``:
    the key as a string, the value via ``write_value``.
    """
    write_short(f, len(bytesmap))
    for key, value in bytesmap.items():
        write_string(f, key)
        write_value(f, value)
1338
1339
def read_stringmultimap(f):
    """
    Read a short count from ``f`` followed by that many pairs of a string key
    and a string list, returned as a dict of lists.
    """
    multimap = {}
    for _ in range(read_short(f)):
        key = read_string(f)
        values = read_stringlist(f)
        multimap[key] = values
    return multimap
1347
1348
def write_stringmultimap(f, strmmap):
    """
    Write a short count to ``f`` followed by every entry of ``strmmap``:
    the key as a string, the value as a string list.
    """
    write_short(f, len(strmmap))
    for key, values in strmmap.items():
        write_string(f, key)
        write_stringlist(f, values)
1354
1355
def read_error_code_map(f):
    """
    Read an int count from ``f`` followed by that many pairs of an inet
    address (without port) and a short code, returned as a dict keyed by
    address.
    """
    codes_by_endpoint = {}
    for _ in range(read_int(f)):
        endpoint = read_inet_addr_only(f)
        code = read_short(f)
        codes_by_endpoint[endpoint] = code
    return codes_by_endpoint
1363
1364
def read_value(f):
    """
    Read an int length from ``f`` and return that many bytes, or ``None``
    when the length is negative.
    """
    length = read_int(f)
    return f.read(length) if length >= 0 else None
1370
1371
def write_value(f, v):
    """
    Write the value ``v`` to ``f``.

    ``None`` is written as the length -1 and ``_UNSET_VALUE`` as the length
    -2, both with no payload; anything else is written as its length
    followed by its bytes.
    """
    if v is None:
        write_int(f, -1)
        return
    if v is _UNSET_VALUE:
        write_int(f, -2)
        return
    write_int(f, len(v))
    f.write(v)
1380
1381
def read_inet_addr_only(f):
    """
    Read a byte-length-prefixed IP address from ``f`` and return it in the
    text form produced by ``util.inet_ntop``.

    :raises InternalError: if the address is neither 4 nor 16 bytes long
    """
    length = read_byte(f)
    raw = f.read(length)
    if length not in (4, 16):
        raise InternalError("bad inet address: %r" % (raw,))
    # 4 bytes is an IPv4 address, 16 bytes an IPv6 one
    family = socket.AF_INET if length == 4 else socket.AF_INET6
    return util.inet_ntop(family, raw)
1392
1393
def read_inet(f):
    """
    Read an address followed by an int port from ``f`` and return them as an
    ``(addr, port)`` tuple.
    """
    # tuple items are evaluated left to right: address first, then port
    return (read_inet_addr_only(f), read_int(f))
1398
1399
def write_inet(f, addrtuple):
    """
    Write the ``(addr, port)`` pair ``addrtuple`` to ``f``: a byte length,
    the packed address bytes, then the port as an int.
    """
    host, port = addrtuple
    # an address containing ':' is treated as IPv6, anything else as IPv4
    family = socket.AF_INET6 if ':' in host else socket.AF_INET
    packed = util.inet_pton(family, host)
    write_byte(f, len(packed))
    f.write(packed)
    write_int(f, port)
1410