# Copyright DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Message classes for the native protocol: request bodies are serialized by
# ``send_body`` and response bodies are parsed by ``recv_body``.

from __future__ import absolute_import  # to enable import io from stdlib
from collections import namedtuple
import logging
import socket
from uuid import UUID

import six
from six.moves import range
import io

from cassandra import ProtocolVersion
from cassandra import type_codes, DriverException
from cassandra import (Unavailable, WriteTimeout, ReadTimeout,
                       WriteFailure, ReadFailure, FunctionFailure,
                       AlreadyExists, InvalidRequest, Unauthorized,
                       UnsupportedOperation, UserFunctionDescriptor,
                       UserAggregateDescriptor, SchemaTargetType)
from cassandra.marshal import (int32_pack, int32_unpack, uint16_pack, uint16_unpack,
                               int8_pack, int8_unpack, uint64_pack, header_pack,
                               v3_header_pack, uint32_pack)
# NOTE: several of these cqltypes names are not referenced directly; they are
# looked up by name via globals() when ResultMessage.type_codes is built.
from cassandra.cqltypes import (AsciiType, BytesType, BooleanType,
                                CounterColumnType, DateType, DecimalType,
                                DoubleType, FloatType, Int32Type,
                                InetAddressType, IntegerType, ListType,
                                LongType, MapType, SetType, TimeUUIDType,
                                UTF8Type, VarcharType, UUIDType, UserType,
                                TupleType, lookup_casstype, SimpleDateType,
                                TimeType, ByteType, ShortType, DurationType)
from cassandra import WriteType
from cassandra.cython_deps import HAVE_CYTHON, HAVE_NUMPY
from cassandra import util

log = logging.getLogger(__name__)


class NotSupportedError(Exception):
    """Raised when a response uses a data type code or event type this module cannot decode."""
    pass


class InternalError(Exception):
    pass

# Column specification record; recv_prepared_metadata builds bind metadata from it.
ColumnMetadata = namedtuple("ColumnMetadata", ['keyspace_name', 'table_name', 'name', 'type'])

HEADER_DIRECTION_TO_CLIENT = 0x80
HEADER_DIRECTION_MASK = 0x80

# Frame header flag bits; _ProtocolHandler.encode_message ORs several of these
# into the outgoing frame's flags.
COMPRESSED_FLAG = 0x01
TRACING_FLAG = 0x02
CUSTOM_PAYLOAD_FLAG = 0x04
WARNING_FLAG = 0x08
USE_BETA_FLAG = 0x10
USE_BETA_MASK = ~USE_BETA_FLAG

# opcode -> message class, populated by the _RegisterMessageType metaclass.
_message_types_by_opcode = {}

_UNSET_VALUE = object()


def register_class(cls):
    """Record ``cls`` in the opcode registry, keyed by ``cls.opcode``."""
    _message_types_by_opcode[cls.opcode] = cls


def get_registered_classes():
    """Return a shallow copy of the opcode -> message class registry."""
    return _message_types_by_opcode.copy()


class _RegisterMessageType(type):
    """Metaclass that registers every class whose name does not start with '_'."""

    def __init__(cls, name, bases, dct):
        if not name.startswith('_'):
            register_class(cls)


@six.add_metaclass(_RegisterMessageType)
class _MessageType(object):
    """Base class for all protocol messages (requests and responses)."""

    tracing = False
    custom_payload = None
    warnings = None

    def update_custom_payload(self, other):
        """
        Merge the mapping ``other`` into ``custom_payload``, creating it on first use.

        :raises ValueError: if the merged payload has more than 65535 entries.
        """
        if other:
            if not self.custom_payload:
                self.custom_payload = {}
            self.custom_payload.update(other)
            if len(self.custom_payload) > 65535:
                raise ValueError("Custom payload map exceeds max count allowed by protocol (65535)")

    def __repr__(self):
        return '<%s(%s)>' % (self.__class__.__name__, ', '.join('%s=%r' % i for i in _get_params(self)))


def _get_params(message_obj):
    """
    Yield ``(name, value)`` for the public, non-callable instance attributes of
    ``message_obj`` that are not also defined on ``_MessageType``; used by ``__repr__``.
    """
    base_attrs = dir(_MessageType)
    return (
        (n, a) for n, a in message_obj.__dict__.items()
        if n not in base_attrs and not n.startswith('_') and not callable(a)
    )


# error code -> ErrorMessage subclass, populated by ErrorMessageSubclass.
error_classes = {}


class ErrorMessage(_MessageType, Exception):
    """
    ERROR response (opcode 0x00).

    ``recv_body`` reads the error code and message, then delegates reading any
    code-specific extra info to the subclass registered in ``error_classes``
    (falling back to this class for unknown codes).
    """
    opcode = 0x00
    name = 'ERROR'
    summary = 'Unknown'

    def __init__(self, code, message, info):
        self.code = code
        self.message = message
        self.info = info

    @classmethod
    def recv_body(cls, f, protocol_version, *args):
        code = read_int(f)
        msg = read_string(f)
        subcls = error_classes.get(code, cls)
        extra_info = subcls.recv_error_info(f, protocol_version)
        return subcls(code=code, message=msg, info=extra_info)

    def summary_msg(self):
        """Return a one-line description of the error (UTF-8 bytes on Python 2)."""
        msg = 'Error from server: code=%04x [%s] message="%s"' \
              % (self.code, self.summary, self.message)
        if six.PY2 and isinstance(msg, six.text_type):
            msg = msg.encode('utf-8')
        return msg

    def __str__(self):
        return '<%s>' % self.summary_msg()
    __repr__ = __str__

    @staticmethod
    def recv_error_info(f, protocol_version):
        # Base errors carry no extra info; subclasses override to read it.
        pass

    def to_exception(self):
        """Return the exception to raise to callers; subclasses may map to driver exceptions."""
        return self


class ErrorMessageSubclass(_RegisterMessageType):
    """
    Metaclass recording classes with a non-None ``error_code`` in ``error_classes``.

    Unlike ``_RegisterMessageType.__init__`` it does not call ``register_class``,
    so error subclasses never replace ``ErrorMessage`` in the opcode registry.
    """

    def __init__(cls, name, bases, dct):
        if cls.error_code is not None:  # Server has an error code of 0.
            error_classes[cls.error_code] = cls


@six.add_metaclass(ErrorMessageSubclass)
class ErrorMessageSub(ErrorMessage):
    error_code = None


class RequestExecutionException(ErrorMessageSub):
    pass


class RequestValidationException(ErrorMessageSub):
    pass


class ServerError(ErrorMessageSub):
    summary = 'Server error'
    error_code = 0x0000


class ProtocolException(ErrorMessageSub):
    summary = 'Protocol error'
    error_code = 0x000A


class BadCredentials(ErrorMessageSub):
    summary = 'Bad credentials'
    error_code = 0x0100


class UnavailableErrorMessage(RequestExecutionException):
    summary = 'Unavailable exception'
    error_code = 0x1000

    @staticmethod
    def recv_error_info(f, protocol_version):
        return {
            'consistency': read_consistency_level(f),
            'required_replicas': read_int(f),
            'alive_replicas': read_int(f),
        }

    def to_exception(self):
        return Unavailable(self.summary_msg(), **self.info)


class OverloadedErrorMessage(RequestExecutionException):
    summary = 'Coordinator node overloaded'
    error_code = 0x1001


class IsBootstrappingErrorMessage(RequestExecutionException):
    summary = 'Coordinator node is bootstrapping'
    error_code = 0x1002


class TruncateError(RequestExecutionException):
    summary = 'Error during truncate'
    error_code = 0x1003


class WriteTimeoutErrorMessage(RequestExecutionException):
    summary = "Coordinator node timed out waiting for replica nodes' responses"
    error_code = 0x1100

    @staticmethod
    def recv_error_info(f, protocol_version):
        return {
            'consistency': read_consistency_level(f),
            'received_responses': read_int(f),
            'required_responses': read_int(f),
            'write_type': WriteType.name_to_value[read_string(f)],
        }

    def to_exception(self):
        return WriteTimeout(self.summary_msg(), **self.info)


class ReadTimeoutErrorMessage(RequestExecutionException):
    summary = "Coordinator node timed out waiting for replica nodes' responses"
    error_code = 0x1200

    @staticmethod
    def recv_error_info(f, protocol_version):
        return {
            'consistency': read_consistency_level(f),
            'received_responses': read_int(f),
            'required_responses': read_int(f),
            'data_retrieved': bool(read_byte(f)),
        }

    def to_exception(self):
        return ReadTimeout(self.summary_msg(), **self.info)


class ReadFailureMessage(RequestExecutionException):
    summary = "Replica(s) failed to execute read"
    error_code = 0x1300

    @staticmethod
    def recv_error_info(f, protocol_version):
        consistency = read_consistency_level(f)
        received_responses = read_int(f)
        required_responses = read_int(f)

        # Newer protocol versions send a per-replica error code map instead of
        # a bare failure count; the count is derived from the map's size.
        if ProtocolVersion.uses_error_code_map(protocol_version):
            error_code_map = read_error_code_map(f)
            failures = len(error_code_map)
        else:
            error_code_map = None
            failures = read_int(f)

        data_retrieved = bool(read_byte(f))

        return {
            'consistency': consistency,
            'received_responses': received_responses,
            'required_responses': required_responses,
            'failures': failures,
            'error_code_map': error_code_map,
            'data_retrieved': data_retrieved
        }

    def to_exception(self):
        return ReadFailure(self.summary_msg(), **self.info)


class FunctionFailureMessage(RequestExecutionException):
    summary = "User Defined Function failure"
    error_code = 0x1400

    @staticmethod
    def recv_error_info(f, protocol_version):
        return {
            'keyspace': read_string(f),
            'function': read_string(f),
            'arg_types': [read_string(f) for _ in range(read_short(f))],
        }

    def to_exception(self):
        return FunctionFailure(self.summary_msg(), **self.info)


class WriteFailureMessage(RequestExecutionException):
    summary = "Replica(s) failed to execute write"
    error_code = 0x1500

    @staticmethod
    def recv_error_info(f, protocol_version):
        consistency = read_consistency_level(f)
        received_responses = read_int(f)
        required_responses = read_int(f)

        # Same map-vs-count split as ReadFailureMessage.
        if ProtocolVersion.uses_error_code_map(protocol_version):
            error_code_map = read_error_code_map(f)
            failures = len(error_code_map)
        else:
            error_code_map = None
            failures = read_int(f)

        write_type = WriteType.name_to_value[read_string(f)]

        return {
            'consistency': consistency,
            'received_responses': received_responses,
            'required_responses': required_responses,
            'failures': failures,
            'error_code_map': error_code_map,
            'write_type': write_type
        }

    def to_exception(self):
        return WriteFailure(self.summary_msg(), **self.info)


class CDCWriteException(RequestExecutionException):
    summary = 'Failed to execute write due to CDC space exhaustion.'
    error_code = 0x1600


class SyntaxException(RequestValidationException):
    summary = 'Syntax error in CQL query'
    error_code = 0x2000


class UnauthorizedErrorMessage(RequestValidationException):
    summary = 'Unauthorized'
    error_code = 0x2100

    def to_exception(self):
        return Unauthorized(self.summary_msg())


class InvalidRequestException(RequestValidationException):
    summary = 'Invalid query'
    error_code = 0x2200

    def to_exception(self):
        return InvalidRequest(self.summary_msg())


class ConfigurationException(RequestValidationException):
    summary = 'Query invalid because of configuration issue'
    error_code = 0x2300


class PreparedQueryNotFound(RequestValidationException):
    summary = 'Matching prepared statement not found on this node'
    error_code = 0x2500

    @staticmethod
    def recv_error_info(f, protocol_version):
        # return the query ID
        return read_binary_string(f)


class AlreadyExistsException(ConfigurationException):
    summary = 'Item already exists'
    error_code = 0x2400

    @staticmethod
    def recv_error_info(f, protocol_version):
        return {
            'keyspace': read_string(f),
            'table': read_string(f),
        }

    def to_exception(self):
        return AlreadyExists(**self.info)


class StartupMessage(_MessageType):
    """STARTUP request: sends ``options`` plus ``CQL_VERSION`` as a string map."""
    opcode = 0x01
    name = 'STARTUP'

    KNOWN_OPTION_KEYS = set((
        'CQL_VERSION',
        'COMPRESSION',
        'NO_COMPACT'
    ))

    def __init__(self, cqlversion, options):
        self.cqlversion = cqlversion
        self.options = options

    def send_body(self, f, protocol_version):
        # Copy so the caller's options dict is not mutated.
        optmap = self.options.copy()
        optmap['CQL_VERSION'] = self.cqlversion
        write_stringmap(f, optmap)


class ReadyMessage(_MessageType):
    """READY response; its body is not read."""
    opcode = 0x02
    name = 'READY'

    @classmethod
    def recv_body(cls, *args):
        return cls()


class AuthenticateMessage(_MessageType):
    """AUTHENTICATE response carrying an authenticator name read as a string."""
    opcode = 0x03
    name = 'AUTHENTICATE'

    def __init__(self, authenticator):
        self.authenticator = authenticator

    @classmethod
    def recv_body(cls, f, *args):
        authname = read_string(f)
        return cls(authenticator=authname)


class CredentialsMessage(_MessageType):
    """CREDENTIALS request; only usable with protocol version 1."""
    opcode = 0x04
    name = 'CREDENTIALS'

    def __init__(self, creds):
        self.creds = creds

    def send_body(self, f, protocol_version):
        if protocol_version > 1:
            raise UnsupportedOperation(
                "Credentials-based authentication is not supported with "
                "protocol version 2 or higher. Use the SASL authentication "
                "mechanism instead.")
        write_short(f, len(self.creds))
        for credkey, credval in self.creds.items():
            write_string(f, credkey)
            write_string(f, credval)


class AuthChallengeMessage(_MessageType):
    """AUTH_CHALLENGE response carrying opaque challenge bytes."""
    opcode = 0x0E
    name = 'AUTH_CHALLENGE'

    def __init__(self, challenge):
        self.challenge = challenge

    @classmethod
    def recv_body(cls, f, *args):
        return cls(read_binary_longstring(f))


class AuthResponseMessage(_MessageType):
    """AUTH_RESPONSE request sending the client's response token."""
    opcode = 0x0F
    name = 'AUTH_RESPONSE'

    def __init__(self, response):
        self.response = response

    def send_body(self, f, protocol_version):
        write_longstring(f, self.response)


class AuthSuccessMessage(_MessageType):
    """AUTH_SUCCESS response carrying a final token."""
    opcode = 0x10
    name = 'AUTH_SUCCESS'

    def __init__(self, token):
        self.token = token

    @classmethod
    def recv_body(cls, f, *args):
        return cls(read_longstring(f))


class OptionsMessage(_MessageType):
    """OPTIONS request; has an empty body."""
    opcode = 0x05
    name = 'OPTIONS'

    def send_body(self, f, protocol_version):
        pass


class SupportedMessage(_MessageType):
    """SUPPORTED response: a string multimap with ``CQL_VERSION`` split out."""
    opcode = 0x06
    name = 'SUPPORTED'

    def __init__(self, cql_versions, options):
        self.cql_versions = cql_versions
        self.options = options

    @classmethod
    def recv_body(cls, f, *args):
        options = read_stringmultimap(f)
        cql_versions = options.pop('CQL_VERSION')
        return cls(cql_versions=cql_versions, options=options)


# used for QueryMessage and ExecuteMessage
_VALUES_FLAG = 0x01
_SKIP_METADATA_FLAG = 0x02
_PAGE_SIZE_FLAG = 0x04
_WITH_PAGING_STATE_FLAG = 0x08
_WITH_SERIAL_CONSISTENCY_FLAG = 0x10
_PROTOCOL_TIMESTAMP = 0x20
_WITH_KEYSPACE_FLAG = 0x80
# Flag bit used by PrepareMessage (its own flags field, not the query flags).
_PREPARED_WITH_KEYSPACE_FLAG = 0x01


class QueryMessage(_MessageType):
    """
    QUERY request: a CQL string, consistency level, a flags field, then the
    optional parts announced by those flags, written in flag order.
    """
    opcode = 0x07
    name = 'QUERY'

    def __init__(self, query, consistency_level, serial_consistency_level=None,
                 fetch_size=None, paging_state=None, timestamp=None, keyspace=None):
        self.query = query
        self.consistency_level = consistency_level
        self.serial_consistency_level = serial_consistency_level
        self.fetch_size = fetch_size
        self.paging_state = paging_state
        self.timestamp = timestamp
        self.keyspace = keyspace
        self._query_params = None  # only used internally. May be set to a list of native-encoded values to have them sent with the request.

    def send_body(self, f, protocol_version):
        """
        Serialize the request into ``f``.

        :raises UnsupportedOperation: if an option requires a newer protocol
            version than ``protocol_version``.
        """
        write_longstring(f, self.query)
        write_consistency_level(f, self.consistency_level)
        flags = 0x00
        if self._query_params is not None:
            flags |= _VALUES_FLAG  # also v2+, but we're only setting params internally right now

        if self.serial_consistency_level:
            if protocol_version >= 2:
                flags |= _WITH_SERIAL_CONSISTENCY_FLAG
            else:
                raise UnsupportedOperation(
                    "Serial consistency levels require the use of protocol version "
                    "2 or higher. Consider setting Cluster.protocol_version to 2 "
                    "to support serial consistency levels.")

        if self.fetch_size:
            if protocol_version >= 2:
                flags |= _PAGE_SIZE_FLAG
            else:
                raise UnsupportedOperation(
                    "Automatic query paging may only be used with protocol version "
                    "2 or higher. Consider setting Cluster.protocol_version to 2.")

        if self.paging_state:
            if protocol_version >= 2:
                flags |= _WITH_PAGING_STATE_FLAG
            else:
                raise UnsupportedOperation(
                    "Automatic query paging may only be used with protocol version "
                    "2 or higher. Consider setting Cluster.protocol_version to 2.")

        if self.timestamp is not None:
            flags |= _PROTOCOL_TIMESTAMP

        if self.keyspace is not None:
            if ProtocolVersion.uses_keyspace_flag(protocol_version):
                flags |= _WITH_KEYSPACE_FLAG
            else:
                raise UnsupportedOperation(
                    "Keyspaces may only be set on queries with protocol version "
                    "5 or higher. Consider setting Cluster.protocol_version to 5.")

        if ProtocolVersion.uses_int_query_flags(protocol_version):
            write_uint(f, flags)
        else:
            write_byte(f, flags)

        # Optional fields follow, each guarded by the same condition that set its flag.
        if self._query_params is not None:
            write_short(f, len(self._query_params))
            for param in self._query_params:
                write_value(f, param)

        if self.fetch_size:
            write_int(f, self.fetch_size)
        if self.paging_state:
            write_longstring(f, self.paging_state)
        if self.serial_consistency_level:
            write_consistency_level(f, self.serial_consistency_level)
        if self.timestamp is not None:
            write_long(f, self.timestamp)
        if self.keyspace is not None:
            write_string(f, self.keyspace)


# Sentinel mapped to the "custom" type code in ResultMessage.type_codes.
CUSTOM_TYPE = object()

RESULT_KIND_VOID = 0x0001
RESULT_KIND_ROWS = 0x0002
RESULT_KIND_SET_KEYSPACE = 0x0003
RESULT_KIND_PREPARED = 0x0004
RESULT_KIND_SCHEMA_CHANGE = 0x0005


class ResultMessage(_MessageType):
    """
    RESULT response. ``kind`` selects how ``results`` is populated:
    None (VOID), ``(colnames, parsed_rows)`` (ROWS), a keyspace name
    (SET_KEYSPACE), a prepared-statement tuple (PREPARED) or a schema-change
    event dict (SCHEMA_CHANGE).
    """
    opcode = 0x08
    name = 'RESULT'

    kind = None
    results = None
    paging_state = None

    # Names match type name in module scope. Most are imported from cassandra.cqltypes (except CUSTOM_TYPE)
    type_codes = _cqltypes_by_code = dict((v, globals()[k]) for k, v in type_codes.__dict__.items() if not k.startswith('_'))

    _FLAGS_GLOBAL_TABLES_SPEC = 0x0001
    _HAS_MORE_PAGES_FLAG = 0x0002
    _NO_METADATA_FLAG = 0x0004
    _METADATA_ID_FLAG = 0x0008

    def __init__(self, kind, results, paging_state=None, col_types=None):
        self.kind = kind
        self.results = results
        self.paging_state = paging_state
        self.col_types = col_types

    @classmethod
    def recv_body(cls, f, protocol_version, user_type_map, result_metadata):
        """
        Read the result kind and dispatch to the matching decoder.

        :raises DriverException: for an unrecognized result kind.
        """
        kind = read_int(f)
        paging_state = None
        col_types = None
        if kind == RESULT_KIND_VOID:
            results = None
        elif kind == RESULT_KIND_ROWS:
            # The result metadata id is not used by this message.
            paging_state, col_types, results, _ = cls.recv_results_rows(
                f, protocol_version, user_type_map, result_metadata)
        elif kind == RESULT_KIND_SET_KEYSPACE:
            ksname = read_string(f)
            results = ksname
        elif kind == RESULT_KIND_PREPARED:
            results = cls.recv_results_prepared(f, protocol_version, user_type_map)
        elif kind == RESULT_KIND_SCHEMA_CHANGE:
            results = cls.recv_results_schema_change(f, protocol_version)
        else:
            raise DriverException("Unknown RESULT kind: %d" % kind)
        return cls(kind, results, paging_state, col_types)

    @classmethod
    def recv_results_rows(cls, f, protocol_version, user_type_map, result_metadata):
        """
        Read row metadata and all rows, decoding each value with its column type.

        :returns: ``(paging_state, coltypes, (colnames, parsed_rows), result_metadata_id)``
        :raises DriverException: naming the first column whose value fails to decode.
        """
        paging_state, column_metadata, result_metadata_id = cls.recv_results_metadata(f, user_type_map)
        # recv_results_metadata returns [] when _NO_METADATA_FLAG is set; fall
        # back to the metadata the caller supplied in that case.
        column_metadata = column_metadata or result_metadata
        rowcount = read_int(f)
        rows = [cls.recv_row(f, len(column_metadata)) for _ in range(rowcount)]
        colnames = [c[2] for c in column_metadata]
        coltypes = [c[3] for c in column_metadata]
        try:
            parsed_rows = [
                tuple(ctype.from_binary(val, protocol_version)
                      for ctype, val in zip(coltypes, row))
                for row in rows]
        except Exception:
            # Bulk decoding failed: decode again value by value so the error
            # names the offending column.
            for row in rows:
                for colname, coltype, val in zip(colnames, coltypes, row):
                    try:
                        coltype.from_binary(val, protocol_version)
                    except Exception as e:
                        raise DriverException('Failed decoding result column "%s" of type %s: %s' % (colname,
                                                                                                     coltype.cql_parameterized_type(),
                                                                                                     str(e)))
            # The per-column pass did not reproduce the failure; re-raise the
            # original error rather than falling through with parsed_rows unbound.
            raise
        return paging_state, coltypes, (colnames, parsed_rows), result_metadata_id

    @classmethod
    def recv_results_prepared(cls, f, protocol_version, user_type_map):
        """Return ``(query_id, bind_metadata, pk_indexes, result_metadata, result_metadata_id)``."""
        query_id = read_binary_string(f)
        if ProtocolVersion.uses_prepared_metadata(protocol_version):
            result_metadata_id = read_binary_string(f)
        else:
            result_metadata_id = None
        bind_metadata, pk_indexes, result_metadata, _ = cls.recv_prepared_metadata(f, protocol_version, user_type_map)
        return query_id, bind_metadata, pk_indexes, result_metadata, result_metadata_id

    @classmethod
    def recv_results_metadata(cls, f, user_type_map):
        """
        Read a rows-metadata section.

        :returns: ``(paging_state, column_metadata, result_metadata_id)`` where
            ``column_metadata`` is a list of ``(keyspace, table, name, type)``
            tuples, or ``[]`` when the no-metadata flag is set.
        """
        flags = read_int(f)
        colcount = read_int(f)

        if flags & cls._HAS_MORE_PAGES_FLAG:
            paging_state = read_binary_longstring(f)
        else:
            paging_state = None

        if flags & cls._METADATA_ID_FLAG:
            result_metadata_id = read_binary_string(f)
        else:
            result_metadata_id = None

        no_meta = bool(flags & cls._NO_METADATA_FLAG)
        if no_meta:
            return paging_state, [], result_metadata_id

        # With a global table spec, keyspace/table are sent once rather than per column.
        glob_tblspec = bool(flags & cls._FLAGS_GLOBAL_TABLES_SPEC)
        if glob_tblspec:
            ksname = read_string(f)
            cfname = read_string(f)
        column_metadata = []
        for _ in range(colcount):
            if glob_tblspec:
                colksname = ksname
                colcfname = cfname
            else:
                colksname = read_string(f)
                colcfname = read_string(f)
            colname = read_string(f)
            coltype = cls.read_type(f, user_type_map)
            column_metadata.append((colksname, colcfname, colname, coltype))
        return paging_state, column_metadata, result_metadata_id

    @classmethod
    def recv_prepared_metadata(cls, f, protocol_version, user_type_map):
        """
        Read the bind-variable metadata of a PREPARED result and, for protocol
        version 2+, the following result metadata.

        :returns: ``(bind_metadata, pk_indexes, result_metadata, result_metadata_id)``;
            ``pk_indexes`` is None below protocol version 4.
        """
        flags = read_int(f)
        colcount = read_int(f)
        pk_indexes = None
        if protocol_version >= 4:
            num_pk_indexes = read_int(f)
            pk_indexes = [read_short(f) for _ in range(num_pk_indexes)]

        glob_tblspec = bool(flags & cls._FLAGS_GLOBAL_TABLES_SPEC)
        if glob_tblspec:
            ksname = read_string(f)
            cfname = read_string(f)
        bind_metadata = []
        for _ in range(colcount):
            if glob_tblspec:
                colksname = ksname
                colcfname = cfname
            else:
                colksname = read_string(f)
                colcfname = read_string(f)
            colname = read_string(f)
            coltype = cls.read_type(f, user_type_map)
            bind_metadata.append(ColumnMetadata(colksname, colcfname, colname, coltype))

        if protocol_version >= 2:
            _, result_metadata, result_metadata_id = cls.recv_results_metadata(f, user_type_map)
            return bind_metadata, pk_indexes, result_metadata, result_metadata_id
        else:
            return bind_metadata, pk_indexes, None, None

    @classmethod
    def recv_results_schema_change(cls, f, protocol_version):
        return EventMessage.recv_schema_change(f, protocol_version)

    @classmethod
    def read_type(cls, f, user_type_map):
        """
        Read a (possibly nested) type option and return the matching cqltypes class,
        parameterized for collections, tuples and UDTs.

        :raises NotSupportedError: for an unknown type code.
        """
        optid = read_short(f)
        try:
            typeclass = cls.type_codes[optid]
        except KeyError:
            raise NotSupportedError("Unknown data type code 0x%04x. Have to skip"
                                    " entire result set." % (optid,))
        if typeclass in (ListType, SetType):
            subtype = cls.read_type(f, user_type_map)
            typeclass = typeclass.apply_parameters((subtype,))
        elif typeclass == MapType:
            keysubtype = cls.read_type(f, user_type_map)
            valsubtype = cls.read_type(f, user_type_map)
            typeclass = typeclass.apply_parameters((keysubtype, valsubtype))
        elif typeclass == TupleType:
            num_items = read_short(f)
            types = tuple(cls.read_type(f, user_type_map) for _ in range(num_items))
            typeclass = typeclass.apply_parameters(types)
        elif typeclass == UserType:
            ks = read_string(f)
            udt_name = read_string(f)
            num_fields = read_short(f)
            names, types = zip(*((read_string(f), cls.read_type(f, user_type_map))
                                 for _ in range(num_fields)))
            specialized_type = typeclass.make_udt_class(ks, udt_name, names, types)
            # Attach a user-registered class for this UDT, if any.
            specialized_type.mapped_class = user_type_map.get(ks, {}).get(udt_name)
            typeclass = specialized_type
        elif typeclass == CUSTOM_TYPE:
            classname = read_string(f)
            typeclass = lookup_casstype(classname)

        return typeclass

    @staticmethod
    def recv_row(f, colcount):
        return [read_value(f) for _ in range(colcount)]


class PrepareMessage(_MessageType):
    """PREPARE request: the query string plus, when supported, a flags field and keyspace."""
    opcode = 0x09
    name = 'PREPARE'

    def __init__(self, query, keyspace=None):
        self.query = query
        self.keyspace = keyspace

    def send_body(self, f, protocol_version):
        """
        Serialize the request into ``f``.

        :raises UnsupportedOperation: if a keyspace is given on a protocol
            version that does not support it.
        """
        write_longstring(f, self.query)

        flags = 0x00

        if self.keyspace is not None:
            if ProtocolVersion.uses_keyspace_flag(protocol_version):
                flags |= _PREPARED_WITH_KEYSPACE_FLAG
            else:
                raise UnsupportedOperation(
                    "Keyspaces may only be set on queries with protocol version "
                    "5 or higher. Consider setting Cluster.protocol_version to 5.")

        if ProtocolVersion.uses_prepare_flags(protocol_version):
            write_uint(f, flags)
        else:
            # checks above should prevent this, but just to be safe...
            if flags:
                raise UnsupportedOperation(
                    "Attempted to set flags with value {flags:0=#8x} on "
                    "protocol version {pv}, which doesn't support flags "
                    "in prepared statements. "
                    "Consider setting Cluster.protocol_version to 5."
                    "".format(flags=flags, pv=protocol_version))

        # Use the same condition that set _PREPARED_WITH_KEYSPACE_FLAG so the
        # flag and the trailing keyspace string always agree (e.g. keyspace='').
        if self.keyspace is not None:
            write_string(f, self.keyspace)


class ExecuteMessage(_MessageType):
    """EXECUTE request for a previously prepared statement."""
    opcode = 0x0A
    name = 'EXECUTE'

    def __init__(self, query_id, query_params, consistency_level,
                 serial_consistency_level=None, fetch_size=None,
                 paging_state=None, timestamp=None, skip_meta=False,
                 result_metadata_id=None):
        self.query_id = query_id
        self.query_params = query_params
        self.consistency_level = consistency_level
        self.serial_consistency_level = serial_consistency_level
        self.fetch_size = fetch_size
        self.paging_state = paging_state
        self.timestamp = timestamp
        self.skip_meta = skip_meta
        self.result_metadata_id = result_metadata_id

    def send_body(self, f, protocol_version):
        """
        Serialize the request into ``f``. Protocol version 1 uses a fixed layout
        (values, then consistency); later versions use a flags field.

        :raises UnsupportedOperation: if an option requires a newer protocol version.
        """
        write_string(f, self.query_id)
        if ProtocolVersion.uses_prepared_metadata(protocol_version):
            write_string(f, self.result_metadata_id)
        if protocol_version == 1:
            if self.serial_consistency_level:
                raise UnsupportedOperation(
                    "Serial consistency levels require the use of protocol version "
                    "2 or higher. Consider setting Cluster.protocol_version to 2 "
                    "to support serial consistency levels.")
            if self.fetch_size or self.paging_state:
                raise UnsupportedOperation(
                    "Automatic query paging may only be used with protocol version "
                    "2 or higher. Consider setting Cluster.protocol_version to 2.")
            write_short(f, len(self.query_params))
            for param in self.query_params:
                write_value(f, param)
            write_consistency_level(f, self.consistency_level)
        else:
            write_consistency_level(f, self.consistency_level)
            flags = _VALUES_FLAG
            if self.serial_consistency_level:
                flags |= _WITH_SERIAL_CONSISTENCY_FLAG
            if self.fetch_size:
                flags |= _PAGE_SIZE_FLAG
            if self.paging_state:
                flags |= _WITH_PAGING_STATE_FLAG
            if self.timestamp is not None:
                if protocol_version >= 3:
                    flags |= _PROTOCOL_TIMESTAMP
                else:
                    raise UnsupportedOperation(
                        "Protocol-level timestamps may only be used with protocol version "
                        "3 or higher. Consider setting Cluster.protocol_version to 3.")
            if self.skip_meta:
                flags |= _SKIP_METADATA_FLAG

            if ProtocolVersion.uses_int_query_flags(protocol_version):
                write_uint(f, flags)
            else:
                write_byte(f, flags)

            write_short(f, len(self.query_params))
            for param in self.query_params:
                write_value(f, param)
            if self.fetch_size:
                write_int(f, self.fetch_size)
            if self.paging_state:
                write_longstring(f, self.paging_state)
            if self.serial_consistency_level:
                write_consistency_level(f, self.serial_consistency_level)
            if self.timestamp is not None:
                write_long(f, self.timestamp)


class BatchMessage(_MessageType):
    """
    BATCH request. ``queries`` is an iterable of
    ``(prepared, string_or_query_id, params)`` triples.
    """
    opcode = 0x0D
    name = 'BATCH'

    def __init__(self, batch_type, queries, consistency_level,
                 serial_consistency_level=None, timestamp=None,
                 keyspace=None):
        self.batch_type = batch_type
        self.queries = queries
        self.consistency_level = consistency_level
        self.serial_consistency_level = serial_consistency_level
        self.timestamp = timestamp
        self.keyspace = keyspace

    def send_body(self, f, protocol_version):
        """
        Serialize the request into ``f``. The flags field and the optional
        fields it announces are only written for protocol version 3+.

        :raises UnsupportedOperation: if a keyspace is given on a protocol
            version that does not support it.
        """
        write_byte(f, self.batch_type.value)
        write_short(f, len(self.queries))
        for prepared, string_or_query_id, params in self.queries:
            if not prepared:
                # Kind 0: the statement is sent as a query string.
                write_byte(f, 0)
                write_longstring(f, string_or_query_id)
            else:
                # Kind 1: the statement is sent as a prepared query id.
                write_byte(f, 1)
                write_short(f, len(string_or_query_id))
                f.write(string_or_query_id)
            write_short(f, len(params))
            for param in params:
                write_value(f, param)

        write_consistency_level(f, self.consistency_level)
        if protocol_version >= 3:
            flags = 0
            if self.serial_consistency_level:
                flags |= _WITH_SERIAL_CONSISTENCY_FLAG
            if self.timestamp is not None:
                flags |= _PROTOCOL_TIMESTAMP
            if self.keyspace:
                if ProtocolVersion.uses_keyspace_flag(protocol_version):
                    flags |= _WITH_KEYSPACE_FLAG
                else:
                    raise UnsupportedOperation(
                        "Keyspaces may only be set on queries with protocol version "
                        "5 or higher. Consider setting Cluster.protocol_version to 5.")

            # write_uint matches QueryMessage/ExecuteMessage/PrepareMessage.
            if ProtocolVersion.uses_int_query_flags(protocol_version):
                write_uint(f, flags)
            else:
                write_byte(f, flags)

            if self.serial_consistency_level:
                write_consistency_level(f, self.serial_consistency_level)
            if self.timestamp is not None:
                write_long(f, self.timestamp)

        # Same truthiness test as the _WITH_KEYSPACE_FLAG check above, so the
        # keyspace string is never written without its flag (e.g. keyspace='').
        if ProtocolVersion.uses_keyspace_flag(protocol_version):
            if self.keyspace:
                write_string(f, self.keyspace)


# Event types EventMessage.recv_body knows how to decode.
known_event_types = frozenset((
    'TOPOLOGY_CHANGE',
    'STATUS_CHANGE',
    'SCHEMA_CHANGE'
))


class RegisterMessage(_MessageType):
    """REGISTER request: the list of event types to subscribe to."""
    opcode = 0x0B
    name = 'REGISTER'

    def __init__(self, event_list):
        self.event_list = event_list

    def send_body(self, f, protocol_version):
        write_stringlist(f, self.event_list)


class EventMessage(_MessageType):
    """EVENT response: ``event_type`` plus a dict of decoded ``event_args``."""
    opcode = 0x0C
    name = 'EVENT'

    def __init__(self, event_type, event_args):
        self.event_type = event_type
        self.event_args = event_args

    @classmethod
    def recv_body(cls, f, protocol_version, *args):
        """
        Read the event type and dispatch to ``recv_<event_type lowercased>``.

        :raises NotSupportedError: for an event type not in ``known_event_types``.
        """
        event_type = read_string(f).upper()
        if event_type in known_event_types:
            read_method = getattr(cls, 'recv_' + event_type.lower())
            return cls(event_type=event_type, event_args=read_method(f, protocol_version))
        raise NotSupportedError('Unknown event type %r' % event_type)

    @classmethod
    def recv_topology_change(cls, f, protocol_version):
        # "NEW_NODE" or "REMOVED_NODE"
        change_type = read_string(f)
        address = read_inet(f)
        return dict(change_type=change_type, address=address)

    @classmethod
    def recv_status_change(cls, f, protocol_version):
        # "UP" or "DOWN"
        change_type = read_string(f)
        address = read_inet(f)
        return dict(change_type=change_type, address=address)

    @classmethod
    def recv_schema_change(cls, f, protocol_version):
        """
        Decode a schema change into a dict with ``target_type``, ``change_type``,
        ``keyspace`` and a target-specific entry. Protocol version 3+ sends an
        explicit target; earlier versions send keyspace and an optional table.
        """
        # "CREATED", "DROPPED", or "UPDATED"
        change_type = read_string(f)
        if protocol_version >= 3:
            target = read_string(f)
            keyspace = read_string(f)
            event = {'target_type': target, 'change_type': change_type, 'keyspace': keyspace}
            if target != SchemaTargetType.KEYSPACE:
                target_name = read_string(f)
                if target == SchemaTargetType.FUNCTION:
                    event['function'] = UserFunctionDescriptor(target_name, [read_string(f) for _ in range(read_short(f))])
                elif target == SchemaTargetType.AGGREGATE:
                    event['aggregate'] = UserAggregateDescriptor(target_name, [read_string(f) for _ in range(read_short(f))])
                else:
                    event[target.lower()] = target_name
        else:
            keyspace = read_string(f)
            table = read_string(f)
            if table:
                event = {'target_type': SchemaTargetType.TABLE, 'change_type': change_type, 'keyspace': keyspace, 'table': table}
            else:
                event = {'target_type': SchemaTargetType.KEYSPACE, 'change_type': change_type, 'keyspace': keyspace}
        return event


class _ProtocolHandler(object):
    """
    _ProtocolHander handles encoding and decoding messages.

    This class can be specialized to compose Handlers which implement alternative
    result decoding or type deserialization. Class definitions are passed to :class:`cassandra.cluster.Cluster`
    on initialization.

    Contracted class methods are :meth:`_ProtocolHandler.encode_message` and :meth:`_ProtocolHandler.decode_message`.
    """

    message_types_by_opcode = _message_types_by_opcode.copy()
    """
    Default mapping of opcode to Message implementation. The default ``decode_message`` implementation uses
    this to instantiate a message and populate using ``recv_body``. This mapping can be updated to inject specialized
    result decoding implementations.
    """

    @classmethod
    def encode_message(cls, msg, stream_id, protocol_version, compressor, allow_beta_protocol_version):
        """
        Encodes a message using the specified frame parameters, and compressor

        :param msg: the message, typically of cassandra.protocol._MessageType, generated by the driver
        :param stream_id: protocol stream id for the frame header
        :param protocol_version: version for the frame header, and used encoding contents
        :param compressor: optional compression function to be used on the body
        """
        flags = 0
        body = io.BytesIO()
        if msg.custom_payload:
            if protocol_version < 4:
                raise UnsupportedOperation("Custom key/value payloads can only be used with protocol version 4 or higher")
            flags |= CUSTOM_PAYLOAD_FLAG
            write_bytesmap(body, msg.custom_payload)
        msg.send_body(body, protocol_version)
        body = body.getvalue()

        if compressor and len(body) > 0:
            body = compressor(body)
            flags |= COMPRESSED_FLAG

        if msg.tracing:
            flags |= TRACING_FLAG

        if allow_beta_protocol_version:
            flags |= USE_BETA_FLAG

        buff = io.BytesIO()
        cls._write_header(buff, protocol_version, flags, stream_id, msg.opcode, len(body))
        buff.write(body)
return buff.getvalue() 1093 1094 @staticmethod 1095 def _write_header(f, version, flags, stream_id, opcode, length): 1096 """ 1097 Write a CQL protocol frame header. 1098 """ 1099 pack = v3_header_pack if version >= 3 else header_pack 1100 f.write(pack(version, flags, stream_id, opcode)) 1101 write_int(f, length) 1102 1103 @classmethod 1104 def decode_message(cls, protocol_version, user_type_map, stream_id, flags, opcode, body, 1105 decompressor, result_metadata): 1106 """ 1107 Decodes a native protocol message body 1108 1109 :param protocol_version: version to use decoding contents 1110 :param user_type_map: map[keyspace name] = map[type name] = custom type to instantiate when deserializing this type 1111 :param stream_id: native protocol stream id from the frame header 1112 :param flags: native protocol flags bitmap from the header 1113 :param opcode: native protocol opcode from the header 1114 :param body: frame body 1115 :param decompressor: optional decompression function to inflate the body 1116 :return: a message decoded from the body and frame attributes 1117 """ 1118 if flags & COMPRESSED_FLAG: 1119 if decompressor is None: 1120 raise RuntimeError("No de-compressor available for compressed frame!") 1121 body = decompressor(body) 1122 flags ^= COMPRESSED_FLAG 1123 1124 body = io.BytesIO(body) 1125 if flags & TRACING_FLAG: 1126 trace_id = UUID(bytes=body.read(16)) 1127 flags ^= TRACING_FLAG 1128 else: 1129 trace_id = None 1130 1131 if flags & WARNING_FLAG: 1132 warnings = read_stringlist(body) 1133 flags ^= WARNING_FLAG 1134 else: 1135 warnings = None 1136 1137 if flags & CUSTOM_PAYLOAD_FLAG: 1138 custom_payload = read_bytesmap(body) 1139 flags ^= CUSTOM_PAYLOAD_FLAG 1140 else: 1141 custom_payload = None 1142 1143 flags &= USE_BETA_MASK # will only be set if we asserted it in connection estabishment 1144 1145 if flags: 1146 log.warning("Unknown protocol flags set: %02x. 
May cause problems.", flags) 1147 1148 msg_class = cls.message_types_by_opcode[opcode] 1149 msg = msg_class.recv_body(body, protocol_version, user_type_map, result_metadata) 1150 msg.stream_id = stream_id 1151 msg.trace_id = trace_id 1152 msg.custom_payload = custom_payload 1153 msg.warnings = warnings 1154 1155 if msg.warnings: 1156 for w in msg.warnings: 1157 log.warning("Server warning: %s", w) 1158 1159 return msg 1160 1161def cython_protocol_handler(colparser): 1162 """ 1163 Given a column parser to deserialize ResultMessages, return a suitable 1164 Cython-based protocol handler. 1165 1166 There are three Cython-based protocol handlers: 1167 1168 - obj_parser.ListParser 1169 decodes result messages into a list of tuples 1170 1171 - obj_parser.LazyParser 1172 decodes result messages lazily by returning an iterator 1173 1174 - numpy_parser.NumPyParser 1175 decodes result messages into NumPy arrays 1176 1177 The default is to use obj_parser.ListParser 1178 """ 1179 from cassandra.row_parser import make_recv_results_rows 1180 1181 class FastResultMessage(ResultMessage): 1182 """ 1183 Cython version of Result Message that has a faster implementation of 1184 recv_results_row. 1185 """ 1186 # type_codes = ResultMessage.type_codes.copy() 1187 code_to_type = dict((v, k) for k, v in ResultMessage.type_codes.items()) 1188 recv_results_rows = classmethod(make_recv_results_rows(colparser)) 1189 1190 class CythonProtocolHandler(_ProtocolHandler): 1191 """ 1192 Use FastResultMessage to decode query result message messages. 
1193 """ 1194 1195 my_opcodes = _ProtocolHandler.message_types_by_opcode.copy() 1196 my_opcodes[FastResultMessage.opcode] = FastResultMessage 1197 message_types_by_opcode = my_opcodes 1198 1199 col_parser = colparser 1200 1201 return CythonProtocolHandler 1202 1203 1204if HAVE_CYTHON: 1205 from cassandra.obj_parser import ListParser, LazyParser 1206 ProtocolHandler = cython_protocol_handler(ListParser()) 1207 LazyProtocolHandler = cython_protocol_handler(LazyParser()) 1208else: 1209 # Use Python-based ProtocolHandler 1210 ProtocolHandler = _ProtocolHandler 1211 LazyProtocolHandler = None 1212 1213 1214if HAVE_CYTHON and HAVE_NUMPY: 1215 from cassandra.numpy_parser import NumpyParser 1216 NumpyProtocolHandler = cython_protocol_handler(NumpyParser()) 1217else: 1218 NumpyProtocolHandler = None 1219 1220 1221def read_byte(f): 1222 return int8_unpack(f.read(1)) 1223 1224 1225def write_byte(f, b): 1226 f.write(int8_pack(b)) 1227 1228 1229def read_int(f): 1230 return int32_unpack(f.read(4)) 1231 1232 1233def write_int(f, i): 1234 f.write(int32_pack(i)) 1235 1236 1237def write_uint(f, i): 1238 f.write(uint32_pack(i)) 1239 1240 1241def write_long(f, i): 1242 f.write(uint64_pack(i)) 1243 1244 1245def read_short(f): 1246 return uint16_unpack(f.read(2)) 1247 1248 1249def write_short(f, s): 1250 f.write(uint16_pack(s)) 1251 1252 1253def read_consistency_level(f): 1254 return read_short(f) 1255 1256 1257def write_consistency_level(f, cl): 1258 write_short(f, cl) 1259 1260 1261def read_string(f): 1262 size = read_short(f) 1263 contents = f.read(size) 1264 return contents.decode('utf8') 1265 1266 1267def read_binary_string(f): 1268 size = read_short(f) 1269 contents = f.read(size) 1270 return contents 1271 1272 1273def write_string(f, s): 1274 if isinstance(s, six.text_type): 1275 s = s.encode('utf8') 1276 write_short(f, len(s)) 1277 f.write(s) 1278 1279 1280def read_binary_longstring(f): 1281 size = read_int(f) 1282 contents = f.read(size) 1283 return contents 1284 1285 1286def 
read_longstring(f): 1287 return read_binary_longstring(f).decode('utf8') 1288 1289 1290def write_longstring(f, s): 1291 if isinstance(s, six.text_type): 1292 s = s.encode('utf8') 1293 write_int(f, len(s)) 1294 f.write(s) 1295 1296 1297def read_stringlist(f): 1298 numstrs = read_short(f) 1299 return [read_string(f) for _ in range(numstrs)] 1300 1301 1302def write_stringlist(f, stringlist): 1303 write_short(f, len(stringlist)) 1304 for s in stringlist: 1305 write_string(f, s) 1306 1307 1308def read_stringmap(f): 1309 numpairs = read_short(f) 1310 strmap = {} 1311 for _ in range(numpairs): 1312 k = read_string(f) 1313 strmap[k] = read_string(f) 1314 return strmap 1315 1316 1317def write_stringmap(f, strmap): 1318 write_short(f, len(strmap)) 1319 for k, v in strmap.items(): 1320 write_string(f, k) 1321 write_string(f, v) 1322 1323 1324def read_bytesmap(f): 1325 numpairs = read_short(f) 1326 bytesmap = {} 1327 for _ in range(numpairs): 1328 k = read_string(f) 1329 bytesmap[k] = read_value(f) 1330 return bytesmap 1331 1332 1333def write_bytesmap(f, bytesmap): 1334 write_short(f, len(bytesmap)) 1335 for k, v in bytesmap.items(): 1336 write_string(f, k) 1337 write_value(f, v) 1338 1339 1340def read_stringmultimap(f): 1341 numkeys = read_short(f) 1342 strmmap = {} 1343 for _ in range(numkeys): 1344 k = read_string(f) 1345 strmmap[k] = read_stringlist(f) 1346 return strmmap 1347 1348 1349def write_stringmultimap(f, strmmap): 1350 write_short(f, len(strmmap)) 1351 for k, v in strmmap.items(): 1352 write_string(f, k) 1353 write_stringlist(f, v) 1354 1355 1356def read_error_code_map(f): 1357 numpairs = read_int(f) 1358 error_code_map = {} 1359 for _ in range(numpairs): 1360 endpoint = read_inet_addr_only(f) 1361 error_code_map[endpoint] = read_short(f) 1362 return error_code_map 1363 1364 1365def read_value(f): 1366 size = read_int(f) 1367 if size < 0: 1368 return None 1369 return f.read(size) 1370 1371 1372def write_value(f, v): 1373 if v is None: 1374 write_int(f, -1) 1375 
elif v is _UNSET_VALUE: 1376 write_int(f, -2) 1377 else: 1378 write_int(f, len(v)) 1379 f.write(v) 1380 1381 1382def read_inet_addr_only(f): 1383 size = read_byte(f) 1384 addrbytes = f.read(size) 1385 if size == 4: 1386 addrfam = socket.AF_INET 1387 elif size == 16: 1388 addrfam = socket.AF_INET6 1389 else: 1390 raise InternalError("bad inet address: %r" % (addrbytes,)) 1391 return util.inet_ntop(addrfam, addrbytes) 1392 1393 1394def read_inet(f): 1395 addr = read_inet_addr_only(f) 1396 port = read_int(f) 1397 return (addr, port) 1398 1399 1400def write_inet(f, addrtuple): 1401 addr, port = addrtuple 1402 if ':' in addr: 1403 addrfam = socket.AF_INET6 1404 else: 1405 addrfam = socket.AF_INET 1406 addrbytes = util.inet_pton(addrfam, addr) 1407 write_byte(f, len(addrbytes)) 1408 f.write(addrbytes) 1409 write_int(f, port) 1410