1 /* 2 Copyright (c) DataStax, Inc. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 #ifndef MOCKSSANDRA_HPP 18 #define MOCKSSANDRA_HPP 19 20 #include <uv.h> 21 22 #include <openssl/bio.h> 23 #include <openssl/err.h> 24 #include <openssl/ssl.h> 25 26 #include <stdint.h> 27 28 #include "address.hpp" 29 #include "event_loop.hpp" 30 #include "list.hpp" 31 #include "map.hpp" 32 #include "ref_counted.hpp" 33 #include "scoped_ptr.hpp" 34 #include "string.hpp" 35 #include "third_party/mt19937_64/mt19937_64.hpp" 36 #include "timer.hpp" 37 #include "vector.hpp" 38 39 #if defined(WIN32) || defined(_WIN32) 40 #undef ERROR_ALREADY_EXISTS 41 #undef ERROR 42 #undef X509_NAME 43 #endif 44 45 #define CLIENT_OPTIONS_QUERY "client.options" 46 47 using datastax::String; 48 using datastax::internal::Atomic; 49 using datastax::internal::List; 50 using datastax::internal::Map; 51 using datastax::internal::RefCounted; 52 using datastax::internal::ScopedPtr; 53 using datastax::internal::SharedRefPtr; 54 using datastax::internal::Vector; 55 using datastax::internal::core::Address; 56 using datastax::internal::core::EventLoop; 57 using datastax::internal::core::EventLoopGroup; 58 using datastax::internal::core::RoundRobinEventLoopGroup; 59 using datastax::internal::core::Task; 60 using datastax::internal::core::Timer; 61 62 namespace mockssandra { 63 64 class Ssl { 65 public: 66 static String generate_key(); 67 static String generate_cert(const String& key, String cn = "", String ca_cert = "", 68 String ca_key = ""); 69 }; 70 71 namespace internal { 72 73 class ServerConnection; 74 75 class Tcp { 76 public: 77 Tcp(void* data); 78 79 int init(uv_loop_t* loop); 80 int bind(const struct sockaddr* addr); 81 82 uv_handle_t* as_handle(); 83 uv_stream_t* as_stream(); 84 85 private: 86 uv_tcp_t tcp_; 87 }; 88 89 class ClientConnection { 90 public: 91 ClientConnection(ServerConnection* server); 92 virtual ~ClientConnection(); 93 server()94 ServerConnection* server() { return server_; } 95 on_accept()96 virtual int on_accept() { return accept(); } on_close()97 virtual void on_close() {} 98 on_read(const char * data,size_t len)99 virtual void on_read(const char* data, size_t len) {} on_write()100 virtual void on_write() {} 101 102 int write(const String& data); 103 int write(const char* data, size_t len); 104 void close(); 105 106 protected: 107 int accept(); 108 109 const char* sni_server_name() const; 110 111 private: 112 static void on_close(uv_handle_t* handle); 113 void handle_close(); 114 115 static void on_alloc(uv_handle_t* handle, size_t suggested_size, uv_buf_t* buf); 116 117 static void on_read(uv_stream_t* stream, ssize_t nread, const uv_buf_t* buf); 118 void handle_read(ssize_t nread, const uv_buf_t* buf); 119 120 static void on_write(uv_write_t* req, int status); 121 void handle_write(int status); 122 123 private: 124 int internal_write(const char* data, size_t len); 125 int ssl_write(const char* data, size_t len); 126 127 bool is_handshake_done(); 128 bool has_ssl_error(int rc); 129 130 void on_ssl_read(const char* data, size_t len); 131 132 private: 133 Tcp tcp_; 134 ServerConnection* const server_; 135 136 private: 137 enum SslHandshakeState { 138 SSL_HANDSHAKE_INPROGRESS, 139 SSL_HANDSHAKE_FINAL_WRITE, 140 SSL_HANDSHAKE_DONE 141 }; 142 143 SSL* ssl_; 144 BIO* incoming_bio_; 145 BIO* outgoing_bio_; 146 }; 147 148 class ClientConnectionFactory { 149 public: 150 virtual ClientConnection* create(ServerConnection* server) const = 0; ~ClientConnectionFactory()151 virtual ~ClientConnectionFactory() {} 152 }; 153 154 class ServerConnectionTask : public RefCounted<ServerConnectionTask> { 155 public: 156 typedef SharedRefPtr<ServerConnectionTask> Ptr; 157 ~ServerConnectionTask()158 virtual ~ServerConnectionTask() {} 159 virtual void run(ServerConnection* server_connection) = 0; 160 }; 161 162 typedef Vector<ClientConnection*> ClientConnections; 163 164 class ServerConnection : public RefCounted<ServerConnection> { 165 public: 166 typedef SharedRefPtr<ServerConnection> Ptr; 167 168 ServerConnection(const Address& address, const ClientConnectionFactory& factory); 169 ~ServerConnection(); 170 address() const171 const Address& address() const { return address_; } 172 uv_loop_t* loop(); ssl_context()173 SSL_CTX* ssl_context() { return ssl_context_; } clients() const174 const ClientConnections& clients() const { return clients_; } 175 176 bool use_ssl(const String& key, const String& cert, const String& ca_cert = "", 177 bool require_client_cert = false); 178 179 void listen(EventLoopGroup* event_loop_group); 180 int wait_listen(); 181 182 void close(); 183 void wait_close(); 184 185 unsigned connection_attempts() const; 186 void run(const ServerConnectionTask::Ptr& task); 187 188 private: 189 friend class ClientConnection; 190 int accept(uv_stream_t* client); 191 void remove(ClientConnection* connection); 192 193 private: 194 friend class RunListen; 195 friend class RunClose; 196 197 void internal_listen(); 198 void internal_close(); 199 void maybe_close(); 200 201 void signal_listen(int rc); 202 void signal_close(); 203 204 static void on_connection(uv_stream_t* stream, int status); 205 void handle_connection(int status); 206 207 static void on_close(uv_handle_t* handle); 208 void handle_close(); 209 210 static int on_password(char* buf, int size, int rwflag, void* password); 211 212 private: 213 enum State { STATE_CLOSED, STATE_CLOSING, STATE_PENDING, STATE_LISTENING }; 214 215 Tcp tcp_; 216 EventLoop* event_loop_; 217 State state_; 218 int rc_; 219 uv_mutex_t mutex_; 220 uv_cond_t cond_; 221 ClientConnections clients_; 222 const Address address_; 223 const ClientConnectionFactory& factory_; 224 SSL_CTX* ssl_context_; 225 Atomic<unsigned> connection_attempts_; 226 }; 227 228 } // namespace internal 229 230 enum { 231 FLAG_COMPRESSION = 0x01, 232 FLAG_TRACING = 0x02, 233 FLAG_CUSTOM_PAYLOAD = 0x04, 234 FLAG_WARNING = 0x08, 235 FLAG_BETA = 0x10 236 }; 237 238 enum { 239 OPCODE_ERROR = 0x00, 240 OPCODE_STARTUP = 0x01, 241 OPCODE_READY = 0x02, 242 OPCODE_AUTHENTICATE = 0x03, 243 OPCODE_CREDENTIALS = 0x04, 244 OPCODE_OPTIONS = 0x05, 245 OPCODE_SUPPORTED = 0x06, 246 OPCODE_QUERY = 0x07, 247 OPCODE_RESULT = 0x08, 248 OPCODE_PREPARE = 0x09, 249 OPCODE_EXECUTE = 0x0A, 250 OPCODE_REGISTER = 0x0B, 251 OPCODE_EVENT = 0x0C, 252 OPCODE_BATCH = 0x0D, 253 OPCODE_AUTH_CHALLENGE = 0x0E, 254 OPCODE_AUTH_RESPONSE = 0x0F, 255 OPCODE_AUTH_SUCCESS = 0x10, 256 OPCODE_LAST_ENTRY 257 }; 258 259 enum { 260 QUERY_FLAG_VALUES = 0x01, 261 QUERY_FLAG_SKIP_METADATA = 0x02, 262 QUERY_FLAG_PAGE_SIZE = 0x04, 263 QUERY_FLAG_PAGE_STATE = 0x08, 264 QUERY_FLAG_SERIAL_CONSISTENCY = 0x10, 265 QUERY_FLAG_TIMESTAMP = 0x20, 266 QUERY_FLAG_NAMES_FOR_VALUES = 0x40, 267 QUERY_FLAG_KEYSPACE = 0x80 268 }; 269 270 enum { PREPARE_FLAGS_KEYSPACE = 0x01 }; 271 272 enum { 273 ERROR_SERVER_ERROR = 0x0000, 274 ERROR_PROTOCOL_ERROR = 0x000A, 275 ERROR_BAD_CREDENTIALS = 0x0100, 276 ERROR_UNAVAILABLE = 0x1000, 277 ERROR_OVERLOADED = 0x1001, 278 ERROR_IS_BOOTSTRAPPING = 0x1002, 279 ERROR_TRUNCATE_ERROR = 0x1003, 280 ERROR_WRITE_TIMEOUT = 0x1100, 281 ERROR_READ_TIMEOUT = 0x1200, 282 ERROR_READ_FAILURE = 0x1300, 283 ERROR_FUNCTION_FAILURE = 0x1400, 284 ERROR_WRITE_FAILURE = 0x1500, 285 ERROR_SYNTAX_ERROR = 0x2000, 286 ERROR_UNAUTHORIZED = 0x2100, 287 ERROR_INVALID_QUERY = 0x2200, 288 ERROR_CONFIG_ERROR = 0x2300, 289 ERROR_ALREADY_EXISTS = 0x2400, 290 ERROR_UNPREPARED = 0x2500, 291 ERROR_CLIENT_WRITE_FAILURE = 0x8000 292 }; 293 294 enum { 295 RESULT_VOID = 0x0001, 296 RESULT_ROWS = 0x0002, 297 RESULT_SET_KEYSPACE = 0x0003, 298 RESULT_PREPARED = 0x0004, 299 RESULT_SCHEMA_CHANGE = 0x0005 300 }; 301 302 enum { 303 RESULT_FLAG_GLOBAL_TABLESPEC = 0x00000001, 304 RESULT_FLAG_HAS_MORE_PAGES = 0x00000002, 305 RESULT_FLAG_NO_METADATA = 0x00000004, 306 RESULT_FLAG_METADATA_CHANGED = 0x00000008, 307 RESULT_FLAG_CONTINUOUS_PAGING = 0x40000000, 308 RESULT_FLAG_LAST_CONTINUOUS_PAGE = 0x80000000 309 }; 310 311 enum { 312 TYPE_CUSTOM = 0x0000, 313 TYPE_ASCII = 0x0001, 314 TYPE_BIGINT = 0x0002, 315 TYPE_BLOG = 0x0003, 316 TYPE_BOOLEAN = 0x0004, 317 TYPE_COUNTER = 0x0005, 318 TYPE_DECIMAL = 0x0006, 319 TYPE_DOUBLE = 0x0007, 320 TYPE_FLOAT = 0x0008, 321 TYPE_INT = 0x0009, 322 TYPE_TIMESTAMP = 0x000B, 323 TYPE_UUID = 0x000C, 324 TYPE_VARCHAR = 0x000D, 325 TYPE_VARINT = 0x000E, 326 TYPE_TIMEUUD = 0x000F, 327 TYPE_INET = 0x0010, 328 TYPE_DATE = 0x0011, 329 TYPE_TIME = 0x0012, 330 TYPE_SMALLINT = 0x0013, 331 TYPE_TINYINT = 0x0014, 332 TYPE_LIST = 0x0020, 333 TYPE_MAP = 0x0021, 334 TYPE_SET = 0x0022, 335 TYPE_UDT = 0x0030, 336 TYPE_TUPLE = 0x0031 337 }; 338 339 typedef std::pair<String, String> Option; 340 typedef Vector<Option> Options; 341 typedef std::pair<String, String> Credential; 342 typedef Vector<Credential> Credentials; 343 typedef Vector<String> EventTypes; 344 typedef Vector<String> Values; 345 typedef Vector<String> Names; 346 347 struct PrepareParameters { 348 int32_t flags; 349 String keyspace; 350 }; 351 352 struct QueryParameters { 353 uint16_t consistency; 354 int32_t flags; 355 Values values; 356 Names names; 357 int32_t result_page_size; 358 String paging_state; 359 uint16_t serial_consistency; 360 int64_t timestamp; 361 String keyspace; 362 }; 363 364 int32_t encode_int32(int32_t value, String* output); 365 int32_t encode_string(const String& value, String* output); 366 int32_t encode_string_map(const Map<String, Vector<String> >& value, String* output); 367 368 class Type { 369 public: 370 static Type text(); 371 static Type inet(); 372 static Type uuid(); 373 static Type list(const Type& sub_type); 374 375 void encode(int protocol_version, String* output) const; 376 377 private: Type()378 Type() 379 : type_(-1) {} 380 Type(int type)381 Type(int type) 382 : type_(type) {} 383 384 friend class Vector<Type>; 385 386 private: 387 int type_; 388 String custom_; 389 Vector<String> names_; 390 Vector<Type> types_; 391 }; 392 393 class Column { 394 public: Column(const String & name,const Type type)395 Column(const String& name, const Type type) 396 : name_(name) 397 , type_(type) {} 398 399 void encode(int protocol_version, String* output) const; 400 401 private: 402 String name_; 403 Type type_; 404 }; 405 406 class ResultSet; 407 class Collection; 408 409 class Value { 410 private: 411 enum Type { NUL, VALUE, COLLECTION }; 412 413 public: 414 Value(); 415 Value(const String& value); 416 Value(const Collection& collection); 417 Value(const Value& other); 418 ~Value(); 419 420 void encode(int protocol_version, String* output) const; 421 422 private: 423 Type type_; 424 union { 425 String* value_; 426 Collection* collection_; 427 }; 428 }; 429 430 class Collection { 431 public: 432 class Builder { 433 public: Builder(const Type & sub_type)434 Builder(const Type& sub_type) 435 : sub_type_(sub_type) {} 436 text(const String & text)437 Builder& text(const String& text) { 438 values_.push_back(Value(text)); 439 return *this; 440 } 441 build()442 Collection build() { return Collection(sub_type_, values_); } 443 444 private: 445 Type sub_type_; 446 Vector<Value> values_; 447 }; 448 449 void encode(int protocol_version, String* output) const; 450 text(const Vector<String> & values)451 static Collection text(const Vector<String>& values) { 452 Collection::Builder builder(Type::text()); 453 for (Vector<String>::const_iterator it = values.begin(), end = values.end(); it != end; ++it) { 454 builder.text(*it); 455 } 456 return builder.build(); 457 } 458 459 private: Collection(const Type & sub_type,const Vector<Value> values)460 Collection(const Type& sub_type, const Vector<Value> values) 461 : sub_type_(sub_type) 462 , values_(values) {} 463 464 private: 465 const Type sub_type_; 466 const Vector<Value> values_; 467 }; 468 469 class Row { 470 public: 471 class Builder { 472 public: 473 Builder& text(const String& text); 474 475 Builder& inet(const Address& inet); 476 477 Builder& uuid(const CassUuid& uuid); 478 479 Builder& collection(const Collection& collection); 480 build() const481 Row build() const { return Row(values_); } 482 483 private: 484 Vector<Value> values_; 485 }; 486 487 void encode(int protocol_version, String* output) const; 488 489 private: Row(const Vector<Value> & values)490 Row(const Vector<Value>& values) 491 : values_(values) {} 492 493 private: 494 Vector<Value> values_; 495 }; 496 497 class ResultSet { 498 public: 499 class Builder { 500 public: Builder(const String & keyspace_name,const String & table_name)501 Builder(const String& keyspace_name, const String& table_name) 502 : keyspace_name_(keyspace_name) 503 , table_name_(table_name) {} 504 column(const String & name,const Type & type)505 Builder& column(const String& name, const Type& type) { 506 columns_.push_back(Column(name, type)); 507 return *this; 508 } 509 row(const Row & row)510 Builder& row(const Row& row) { 511 rows_.push_back(row); 512 return *this; 513 } 514 build() const515 ResultSet build() const { return ResultSet(keyspace_name_, table_name_, columns_, rows_); } 516 517 private: 518 const String keyspace_name_; 519 const String table_name_; 520 Vector<Column> columns_; 521 Vector<Row> rows_; 522 }; 523 524 String encode(int protocol_version) const; 525 column_count() const526 size_t column_count() const { return columns_.size(); } 527 528 private: ResultSet(const String & keyspace_name,const String & table_name,const Vector<Column> & columns,const Vector<Row> & rows)529 ResultSet(const String& keyspace_name, const String& table_name, const Vector<Column>& columns, 530 const Vector<Row>& rows) 531 : keyspace_name_(keyspace_name) 532 , table_name_(table_name) 533 , columns_(columns) 534 , rows_(rows) {} 535 536 private: 537 const String keyspace_name_; 538 const String table_name_; 539 const Vector<Column> columns_; 540 const Vector<Row> rows_; 541 }; 542 543 struct Exception : public std::exception { Exceptionmockssandra::Exception544 Exception(int code, const String& message) 545 : code(code) 546 , message(message) {} ~Exceptionmockssandra::Exception547 virtual ~Exception() throw() {} 548 const int code; 549 const String message; 550 }; 551 552 struct Host { Hostmockssandra::Host553 Host() {} 554 Host(const Address& address, const String& dc, const String& rack, MT19937_64& token_rng, 555 int num_tokens = 2); 556 Address address; 557 String dc; 558 String rack; 559 String partitioner; 560 Vector<String> tokens; 561 }; 562 563 typedef Vector<Host> Hosts; 564 565 class ClientConnection; 566 class Cluster; 567 class Request; 568 569 typedef std::pair<String, ResultSet> Match; 570 typedef Vector<Match> Matches; 571 572 struct Predicate; 573 574 struct Action { 575 class PredicateBuilder; 576 577 class Builder { 578 public: Builder()579 Builder() 580 : last_(NULL) {} 581 582 Builder& reset(); 583 584 Builder& execute(Action* action); 585 Builder& execute_if(Action* action); 586 587 Builder& nop(); 588 Builder& wait(uint64_t timeout); 589 Builder& close(); 590 Builder& error(int32_t code, const String& message); 591 Builder& invalid_protocol(); 592 Builder& invalid_opcode(); 593 594 Builder& ready(); 595 Builder& authenticate(const String& class_name); 596 Builder& auth_challenge(const String& token); 597 Builder& auth_success(const String& token = ""); 598 Builder& supported(); 599 Builder& up_event(const Address& address); 600 601 Builder& void_result(); 602 Builder& empty_rows_result(int32_t row_count); 603 Builder& no_result(); 604 Builder& match_query(const Matches& matches); 605 606 Builder& client_options(); 607 608 Builder& system_local(); 609 Builder& system_local_dse(); 610 Builder& system_peers(); 611 Builder& system_peers_dse(); 612 Builder& system_traces(); 613 614 Builder& use_keyspace(const String& keyspace); 615 Builder& use_keyspace(const Vector<String>& keyspaces); 616 Builder& plaintext_auth(const String& username = "cassandra", 617 const String& password = "cassandra"); 618 619 Builder& validate_startup(); 620 Builder& validate_credentials(); 621 Builder& validate_auth_response(); 622 Builder& validate_register(); 623 Builder& validate_query(); 624 625 Builder& set_registered_for_events(); 626 Builder& set_protocol_version(); 627 628 PredicateBuilder is_address(const Address& address); 629 PredicateBuilder is_address(const String& address, int port = 9042); 630 631 PredicateBuilder is_query(const String& query); 632 633 Action* build(); 634 635 private: 636 ScopedPtr<Action> first_; 637 Action* last_; 638 }; 639 640 class PredicateBuilder { 641 public: PredicateBuilder(Builder & builder)642 PredicateBuilder(Builder& builder) 643 : builder_(builder) {} 644 then(Builder & builder)645 Builder& then(Builder& builder) { return then(builder.build()); } 646 then(Action * action)647 Builder& then(Action* action) { return builder_.execute_if(action); } 648 649 private: 650 Builder& builder_; 651 }; 652 Actionmockssandra::Action653 Action() 654 : next(NULL) {} ~Actionmockssandra::Action655 virtual ~Action() { delete next; } 656 657 void run(Request* request) const; 658 void run_next(Request* request) const; 659 is_predicatemockssandra::Action660 virtual bool is_predicate() const { return false; } 661 virtual void on_run(Request* request) const = 0; 662 663 const Action* next; 664 }; 665 666 struct Predicate : public Action { Predicatemockssandra::Predicate667 Predicate() 668 : then(NULL) {} ~Predicatemockssandra::Predicate669 virtual ~Predicate() { delete then; } 670 is_predicatemockssandra::Predicate671 virtual bool is_predicate() const { return true; } 672 virtual bool is_true(Request* request) const = 0; 673 on_runmockssandra::Predicate674 virtual void on_run(Request* request) const { 675 if (is_true(request)) { 676 if (then) { 677 then->run(request); 678 } 679 } else { 680 run_next(request); 681 } 682 } 683 684 const Action* then; 685 }; 686 687 class Request 688 : public List<Request>::Node 689 , public RefCounted<Request> { 690 public: 691 typedef SharedRefPtr<Request> Ptr; 692 693 Request(int8_t version, int8_t flags, int16_t stream, int8_t opcode, const String& body, 694 ClientConnection* client); 695 version() const696 int8_t version() const { return version_; } stream() const697 int16_t stream() const { return stream_; } opcode() const698 int8_t opcode() const { return opcode_; } 699 client() const700 ClientConnection* client() const { return client_; } 701 702 void write(int8_t opcode, const String& body); 703 void write(int16_t stream, int8_t opcode, const String& body); 704 void error(int32_t code, const String& message); 705 void wait(uint64_t timeout, const Action* action); 706 void close(); 707 708 bool decode_startup(Options* options); 709 bool decode_credentials(Credentials* creds); 710 bool decode_auth_response(String* token); 711 bool decode_register(EventTypes* types); 712 bool decode_query(String* query, QueryParameters* params); 713 bool decode_execute(String* id, QueryParameters* params); 714 bool decode_prepare(String* query, PrepareParameters* params); 715 716 const Address& address() const; 717 const Host& host(const Address& address) const; 718 Hosts hosts() const; 719 720 private: 721 void on_timeout(Timer* timer); 722 start()723 const char* start() { return body_.data(); } end()724 const char* end() { return body_.data() + body_.size(); } 725 726 private: 727 const int8_t version_; 728 const int8_t flags_; 729 const int16_t stream_; 730 const int8_t opcode_; 731 const String body_; 732 ClientConnection* const client_; 733 Timer timer_; 734 const Action* timer_action_; 735 }; 736 737 struct Nop : public Action { on_runmockssandra::Nop738 virtual void on_run(Request* request) const {} 739 }; 740 741 struct Wait : public Action { Waitmockssandra::Wait742 Wait(uint64_t timeout) 743 : timeout(timeout) {} 744 on_runmockssandra::Wait745 virtual void on_run(Request* request) const { request->wait(timeout, this); } 746 747 const uint64_t timeout; 748 }; 749 750 struct Close : public Action { on_runmockssandra::Close751 virtual void on_run(Request* request) const { request->close(); } 752 }; 753 754 struct SendError : public Action { SendErrormockssandra::SendError755 SendError(int32_t code, const String& message) 756 : code(code) 757 , message(message) {} 758 759 virtual void on_run(Request* request) const; 760 761 int32_t code; 762 String message; 763 }; 764 765 struct SendReady : public Action { 766 virtual void on_run(Request* request) const; 767 }; 768 769 struct SendAuthenticate : public Action { SendAuthenticatemockssandra::SendAuthenticate770 SendAuthenticate(const String& class_name) 771 : class_name(class_name) {} 772 virtual void on_run(Request* request) const; 773 String class_name; 774 }; 775 776 struct SendAuthChallenge : public Action { SendAuthChallengemockssandra::SendAuthChallenge777 SendAuthChallenge(const String& token) 778 : token(token) {} 779 virtual void on_run(Request* request) const; 780 String token; 781 }; 782 783 struct SendAuthSuccess : public Action { SendAuthSuccessmockssandra::SendAuthSuccess784 SendAuthSuccess(const String& token) 785 : token(token) {} 786 virtual void on_run(Request* request) const; 787 String token; 788 }; 789 790 struct SendSupported : public Action { 791 virtual void on_run(Request* request) const; 792 }; 793 794 struct SendUpEvent : public Action { SendUpEventmockssandra::SendUpEvent795 SendUpEvent(const Address& address) 796 : address(address) {} 797 virtual void on_run(Request* request) const; 798 Address address; 799 }; 800 801 struct VoidResult : public Action { 802 virtual void on_run(Request* request) const; 803 }; 804 805 struct EmptyRowsResult : public Action { EmptyRowsResultmockssandra::EmptyRowsResult806 EmptyRowsResult(int row_count) 807 : row_count(row_count) {} 808 virtual void on_run(Request* request) const; 809 int32_t row_count; 810 }; 811 812 struct NoResult : public Action { 813 virtual void on_run(Request* request) const; 814 }; 815 816 struct MatchQuery : public Action { MatchQuerymockssandra::MatchQuery817 MatchQuery(const Matches& matches) 818 : matches(matches) {} 819 virtual void on_run(Request* request) const; 820 Matches matches; 821 }; 822 823 struct ClientOptions : public Action { 824 virtual void on_run(Request* request) const; 825 }; 826 827 struct SystemLocal : public Action { 828 virtual void on_run(Request* request) const; 829 }; 830 831 struct SystemLocalDse : public Action { 832 virtual void on_run(Request* request) const; 833 }; 834 835 struct SystemPeers : public Action { 836 virtual void on_run(Request* request) const; 837 }; 838 839 struct SystemPeersDse : public Action { 840 virtual void on_run(Request* request) const; 841 }; 842 843 struct SystemTraces : public Action { 844 virtual void on_run(Request* request) const; 845 }; 846 847 struct UseKeyspace : public Action { UseKeyspacemockssandra::UseKeyspace848 UseKeyspace(const String& keyspace) { keyspaces.push_back(keyspace); } UseKeyspacemockssandra::UseKeyspace849 UseKeyspace(const Vector<String>& keyspaces) 850 : keyspaces(keyspaces) {} 851 virtual void on_run(Request* request) const; 852 Vector<String> keyspaces; 853 }; 854 855 struct PlaintextAuth : public Action { PlaintextAuthmockssandra::PlaintextAuth856 PlaintextAuth(const String& username, const String& password) 857 : username(username) 858 , password(password) {} 859 virtual void on_run(Request* request) const; 860 String username; 861 String password; 862 }; 863 864 struct ValidateStartup : public Action { 865 virtual void on_run(Request* request) const; 866 }; 867 868 struct ValidateCredentials : public Action { 869 virtual void on_run(Request* request) const; 870 }; 871 872 struct ValidateAuthResponse : public Action { 873 virtual void on_run(Request* request) const; 874 }; 875 876 struct ValidateRegister : public Action { 877 virtual void on_run(Request* request) const; 878 }; 879 880 struct ValidateQuery : public Action { 881 virtual void on_run(Request* request) const; 882 }; 883 884 struct SetRegisteredForEvents : public Action { 885 virtual void on_run(Request* request) const; 886 }; 887 888 struct SetProtocolVersion : public Action { 889 virtual void on_run(Request* request) const; 890 }; 891 892 struct IsAddress : public Predicate { IsAddressmockssandra::IsAddress893 IsAddress(const Address& address) 894 : address(address) {} 895 virtual bool is_true(Request* request) const; 896 const Address address; 897 }; 898 899 struct IsQuery : public Predicate { IsQuerymockssandra::IsQuery900 IsQuery(const String& query) 901 : query(query) {} 902 virtual bool is_true(Request* request) const; 903 const String query; 904 }; 905 906 class RequestHandler { 907 public: 908 class Builder { 909 public: Builder()910 Builder() 911 : lowest_supported_protocol_version_(1) 912 , highest_supported_protocol_version_(5) { 913 invalid_protocol_.invalid_protocol(); 914 invalid_opcode_.invalid_opcode(); 915 } 916 on(int8_t opcode)917 Action::Builder& on(int8_t opcode) { 918 if (opcode < OPCODE_LAST_ENTRY) { 919 return actions_[opcode].reset(); 920 } 921 return dummy_.reset(); 922 } 923 on_invalid_protocol()924 Action::Builder& on_invalid_protocol() { return invalid_protocol_; } on_invalid_opcode()925 Action::Builder& on_invalid_opcode() { return invalid_opcode_; } 926 927 const RequestHandler* build(); 928 with_supported_protocol_versions(int lowest,int highest)929 Builder& with_supported_protocol_versions(int lowest, int highest) { 930 assert(highest >= lowest && "Invalid protocol versions"); 931 lowest_supported_protocol_version_ = lowest < 0 ? 0 : lowest; 932 highest_supported_protocol_version_ = highest > 5 ? 5 : highest; 933 return *this; 934 } 935 936 private: 937 Action::Builder actions_[OPCODE_LAST_ENTRY]; 938 Action::Builder invalid_protocol_; 939 Action::Builder invalid_opcode_; 940 Action::Builder dummy_; 941 int lowest_supported_protocol_version_; 942 int highest_supported_protocol_version_; 943 }; 944 945 RequestHandler(Builder* builder, int lowest_supported_protocol_version, 946 int highest_supported_protocol_version); 947 lowest_supported_protocol_version() const948 int lowest_supported_protocol_version() const { return lowest_supported_protocol_version_; } 949 highest_supported_protocol_version() const950 int highest_supported_protocol_version() const { return highest_supported_protocol_version_; } 951 invalid_protocol(Request * request) const952 void invalid_protocol(Request* request) const { invalid_protocol_->run(request); } 953 run(Request * request) const954 void run(Request* request) const { 955 const ScopedPtr<const Action>& action = actions_[request->opcode()]; 956 if (action) { 957 action->run(request); 958 } else { 959 invalid_opcode_->run(request); 960 } 961 } 962 963 private: 964 ScopedPtr<const Action> invalid_protocol_; 965 ScopedPtr<const Action> invalid_opcode_; 966 ScopedPtr<const Action> actions_[OPCODE_LAST_ENTRY]; 967 const int lowest_supported_protocol_version_; 968 const int highest_supported_protocol_version_; 969 }; 970 971 class ProtocolHandler { 972 public: ProtocolHandler(const RequestHandler * request_handler)973 ProtocolHandler(const RequestHandler* request_handler) 974 : request_handler_(request_handler) 975 , state_(PROTOCOL_VERSION) 976 , version_(0) 977 , flags_(0) 978 , stream_(0) 979 , opcode_(0) 980 , length_(0) {} 981 982 void decode(ClientConnection* client, const char* data, int32_t len); 983 984 private: 985 int32_t decode_frame(ClientConnection* client, const char* frame, int32_t len); 986 void decode_body(ClientConnection* client, const char* body, int32_t len); 987 988 enum State { PROTOCOL_VERSION, HEADER, BODY }; 989 990 private: 991 String buffer_; 992 const RequestHandler* request_handler_; 993 State state_; 994 int8_t version_; 995 int8_t flags_; 996 int16_t stream_; 997 int8_t opcode_; 998 int32_t length_; 999 }; 1000 1001 class ClientConnection : public internal::ClientConnection { 1002 public: ClientConnection(internal::ServerConnection * server,const RequestHandler * request_handler,const Cluster * cluster)1003 ClientConnection(internal::ServerConnection* server, const RequestHandler* request_handler, 1004 const Cluster* cluster) 1005 : internal::ClientConnection(server) 1006 , handler_(request_handler) 1007 , cluster_(cluster) 1008 , protocol_version_(-1) 1009 , is_registered_for_events_(false) {} 1010 1011 virtual void on_read(const char* data, size_t len); 1012 cluster() const1013 const Cluster* cluster() const { return cluster_; } 1014 protocol_version() const1015 int protocol_version() const { return protocol_version_; } set_protocol_version(int protocol_version)1016 void set_protocol_version(int protocol_version) { protocol_version_ = protocol_version; } 1017 is_registered_for_events() const1018 bool is_registered_for_events() const { return is_registered_for_events_; } set_registered_for_events()1019 void set_registered_for_events() { is_registered_for_events_ = true; } options() const1020 const Options& options() const { return options_; } set_options(const Options & options)1021 void set_options(const Options& options) { options_ = options; } 1022 keyspace() const1023 const String& keyspace() const { return keyspace_; } set_keyspace(const String & keyspace)1024 void set_keyspace(const String& keyspace) { keyspace_ = keyspace; } 1025 1026 private: 1027 ProtocolHandler handler_; 1028 String keyspace_; 1029 const Cluster* cluster_; 1030 int protocol_version_; 1031 bool is_registered_for_events_; 1032 Options options_; 1033 }; 1034 1035 class CloseConnection : public ClientConnection { 1036 public: CloseConnection(internal::ServerConnection * server,const RequestHandler * request_handler,const Cluster * cluster)1037 CloseConnection(internal::ServerConnection* server, const RequestHandler* request_handler, 1038 const Cluster* cluster) 1039 : ClientConnection(server, request_handler, cluster) {} 1040 on_accept()1041 int on_accept() { 1042 int rc = accept(); 1043 if (rc != 0) { 1044 return rc; 1045 } 1046 close(); 1047 return rc; 1048 } 1049 }; 1050 1051 class ClientConnectionFactory : public internal::ClientConnectionFactory { 1052 public: ClientConnectionFactory(const RequestHandler * request_handler,const Cluster * cluster)1053 ClientConnectionFactory(const RequestHandler* request_handler, const Cluster* cluster) 1054 : request_handler_(request_handler) 1055 , cluster_(cluster) 1056 , close_immediately_(false) {} 1057 use_close_immediately()1058 void use_close_immediately() { close_immediately_ = true; } 1059 create(internal::ServerConnection * server) const1060 virtual internal::ClientConnection* create(internal::ServerConnection* server) const { 1061 if (close_immediately_) { 1062 return new CloseConnection(server, request_handler_.get(), cluster_); 1063 } else { 1064 return new ClientConnection(server, request_handler_.get(), cluster_); 1065 } 1066 } 1067 1068 private: 1069 ScopedPtr<const RequestHandler> request_handler_; 1070 const Cluster* cluster_; 1071 bool close_immediately_; 1072 }; 1073 1074 class AddressGenerator { 1075 public: 1076 virtual Address next() = 0; 1077 }; 1078 1079 class Ipv4AddressGenerator : public AddressGenerator { 1080 public: Ipv4AddressGenerator(uint8_t a=127,uint8_t b=0,uint8_t c=0,uint8_t d=1,int port=9042)1081 Ipv4AddressGenerator(uint8_t a = 127, uint8_t b = 0, uint8_t c = 0, uint8_t d = 1, 1082 int port = 9042) 1083 : ip_((a << 24) | (b << 16) | (c << 8) | (d & 0xff)) 1084 , port_(port) {} 1085 1086 virtual Address next(); 1087 1088 private: 1089 uint32_t ip_; 1090 int port_; 1091 }; 1092 1093 class Event : public internal::ServerConnectionTask { 1094 public: 1095 typedef SharedRefPtr<Event> Ptr; 1096 1097 Event(const String& event_body); 1098 1099 virtual void run(internal::ServerConnection* server_connection); 1100 1101 private: 1102 String event_body_; 1103 }; 1104 1105 class TopologyChangeEvent : public Event { 1106 public: 1107 enum Type { NEW_NODE, MOVED_NODE, REMOVED_NODE }; 1108 TopologyChangeEvent(Type type,const Address & address)1109 TopologyChangeEvent(Type type, const Address& address) 1110 : Event(encode(type, address)) {} 1111 1112 static Ptr new_node(const Address& address); 1113 static Ptr moved_node(const Address& address); 1114 static Ptr removed_node(const Address& address); 1115 1116 static String encode(Type type, const Address& address); 1117 }; 1118 1119 class StatusChangeEvent : public Event { 1120 public: 1121 enum Type { UP, DOWN }; 1122 StatusChangeEvent(Type type,const Address & address)1123 StatusChangeEvent(Type type, const Address& address) 1124 : Event(encode(type, address)) {} 1125 1126 static Ptr up(const Address& address); 1127 static Ptr down(const Address& address); 1128 1129 static String encode(Type type, const Address& address); 1130 }; 1131 1132 class SchemaChangeEvent : public Event { 1133 public: 1134 enum Type { CREATED, UPDATED, DROPPED }; 1135 1136 enum Target { KEYSPACE, TABLE, USER_TYPE, FUNCTION, AGGREGATE }; 1137 SchemaChangeEvent(Target target,Type type,const String & keyspace_name,const String & target_name="",const Vector<String> & args_types=Vector<String> ())1138 SchemaChangeEvent(Target target, Type type, const String& keyspace_name, 1139 const String& target_name = "", 1140 const Vector<String>& args_types = Vector<String>()) 1141 : Event(encode(target, type, keyspace_name, target_name, args_types)) {} 1142 1143 static Ptr keyspace(Type type, const String& keyspace_name); 1144 static Ptr table(Type type, const String& keyspace_name, const String& table_name); 1145 static Ptr user_type(Type type, const String& keyspace_name, const String& user_type_name); 1146 static Ptr function(Type type, const String& keyspace_name, const String& function_name, 1147 const Vector<String>& args_types); 1148 static Ptr aggregate(Type type, const String& keyspace_name, const String& aggregate_name, 1149 const Vector<String>& args_types); 1150 1151 static String encode(Target target, Type type, const String& keyspace_name, 1152 const String& target_name, const Vector<String>& arg_types); 1153 }; 1154 1155 class Cluster { 1156 protected: 1157 void init(AddressGenerator& generator, ClientConnectionFactory& factory, size_t num_nodes_dc1, 1158 size_t num_nodes_dc2); 1159 1160 public: 1161 ~Cluster(); 1162 1163 String use_ssl(const String& cn = ""); 1164 1165 int start_all(EventLoopGroup* event_loop_group); 1166 void start_all_async(EventLoopGroup* event_loop_group); 1167 1168 void stop_all(); 1169 void stop_all_async(); 1170 1171 int start(EventLoopGroup* event_loop_group, size_t node); 1172 void start_async(EventLoopGroup* event_loop_group, size_t node); 1173 1174 void stop(size_t node); 1175 void stop_async(size_t node); 1176 1177 int add(EventLoopGroup* event_loop_group, size_t node); 1178 void remove(size_t node); 1179 1180 const Host& host(const Address& address) const; 1181 Hosts hosts() const; 1182 1183 unsigned connection_attempts(size_t node) const; 1184 1185 void event(const Event::Ptr& event); 1186 1187 private: 1188 struct Server { Servermockssandra::Cluster::Server1189 Server(const Host& host, const internal::ServerConnection::Ptr& connection) 1190 : host(host) 1191 , connection(connection) 1192 , is_removed(false) {} 1193 Servermockssandra::Cluster::Server1194 Server(const Server& server) 1195 : host(server.host) 1196 , connection(server.connection) 1197 , is_removed(server.is_removed.load()) {} 1198 operator =mockssandra::Cluster::Server1199 Server& operator=(const Server& server) { 1200 host = server.host; 1201 connection = server.connection; 1202 is_removed.store(server.is_removed.load()); 1203 return *this; 1204 } 1205 1206 Host host; 1207 internal::ServerConnection::Ptr connection; 1208 Atomic<bool> is_removed; 1209 }; 1210 1211 typedef Vector<Server> Servers; 1212 1213 int create_and_add_server(AddressGenerator& generator, ClientConnectionFactory& factory, 1214 const String& dc); 1215 1216 private: 1217 Servers servers_; 1218 MT19937_64 token_rng_; 1219 }; 1220 1221 class SimpleEventLoopGroup : public RoundRobinEventLoopGroup { 1222 public: 1223 SimpleEventLoopGroup(size_t num_threads = 1, const String& thread_name = "mockssandra"); 1224 ~SimpleEventLoopGroup(); 1225 }; 1226 1227 class SimpleRequestHandlerBuilder : public RequestHandler::Builder { 1228 public: 1229 SimpleRequestHandlerBuilder(); 1230 }; 1231 1232 class AuthRequestHandlerBuilder : public SimpleRequestHandlerBuilder { 1233 public: 1234 AuthRequestHandlerBuilder(const String& username = "cassandra", 1235 const String& password = "cassandra"); 1236 }; 1237 1238 class SimpleCluster : public Cluster { 1239 public: SimpleCluster(const RequestHandler * request_handler,size_t num_nodes_dc1=1,size_t num_nodes_dc2=0)1240 SimpleCluster(const RequestHandler* request_handler, size_t num_nodes_dc1 = 1, 1241 size_t num_nodes_dc2 = 0) 1242 : factory_(request_handler, this) 1243 , event_loop_group_(1) { 1244 init(generator_, factory_, num_nodes_dc1, num_nodes_dc2); 1245 } 1246 ~SimpleCluster()1247 ~SimpleCluster() { stop_all(); } 1248 use_close_immediately()1249 void use_close_immediately() { factory_.use_close_immediately(); } 1250 start_all()1251 int start_all() { return Cluster::start_all(&event_loop_group_); } 1252 start(size_t node)1253 int start(size_t node) { return Cluster::start(&event_loop_group_, node); } 1254 add(size_t node)1255 int add(size_t node) { return Cluster::add(&event_loop_group_, node); } 1256 1257 private: 1258 Ipv4AddressGenerator generator_; 1259 ClientConnectionFactory factory_; 1260 SimpleEventLoopGroup event_loop_group_; 1261 }; 1262 1263 class SimpleEchoServer { 1264 public: SimpleEchoServer()1265 SimpleEchoServer() 1266 : factory_(new EchoClientConnectionFactory()) 1267 , event_loop_group_(1) {} 1268 ~SimpleEchoServer()1269 ~SimpleEchoServer() { close(); } 1270 close()1271 void close() { 1272 if (server_) { 1273 server_->close(); 1274 server_->wait_close(); 1275 } 1276 } 1277 use_ssl(const String & cn="")1278 String use_ssl(const String& cn = "") { 1279 ssl_key_ = Ssl::generate_key(); 1280 ssl_cert_ = Ssl::generate_cert(ssl_key_, cn); 1281 return ssl_cert_; 1282 } 1283 use_connection_factory(internal::ClientConnectionFactory * factory)1284 void use_connection_factory(internal::ClientConnectionFactory* factory) { 1285 factory_.reset(factory); 1286 } 1287 listen(const Address & address=Address ("127.0.0.1",8888))1288 int listen(const Address& address = Address("127.0.0.1", 8888)) { 1289 server_.reset(new internal::ServerConnection(address, *factory_)); 1290 if (!ssl_key_.empty() && !ssl_cert_.empty() && !server_->use_ssl(ssl_key_, ssl_cert_)) { 1291 return -1; 1292 } 1293 server_->listen(&event_loop_group_); 1294 return server_->wait_listen(); 1295 } 1296 1297 private: 1298 class EchoConnection : public internal::ClientConnection { 1299 public: EchoConnection(internal::ServerConnection * server)1300 EchoConnection(internal::ServerConnection* server) 1301 : internal::ClientConnection(server) {} 1302 on_read(const char * data,size_t len)1303 virtual void on_read(const char* data, size_t len) { write(data, len); } 1304 }; 1305 1306 class EchoClientConnectionFactory : public internal::ClientConnectionFactory { 1307 public: create(internal::ServerConnection * server) const1308 virtual internal::ClientConnection* create(internal::ServerConnection* server) const { 1309 return new EchoConnection(server); 1310 } 1311 }; 1312 1313 private: 1314 ScopedPtr<internal::ClientConnectionFactory> factory_; 1315 SimpleEventLoopGroup event_loop_group_; 1316 internal::ServerConnection::Ptr server_; 1317 String ssl_key_; 1318 String ssl_cert_; 1319 }; 1320 1321 } // namespace mockssandra 1322 1323 #endif // MOCKSSANDRA_HPP 1324