1 /*
2   Copyright (c) DataStax, Inc.
3 
4   Licensed under the Apache License, Version 2.0 (the "License");
5   you may not use this file except in compliance with the License.
6   You may obtain a copy of the License at
7 
8   http://www.apache.org/licenses/LICENSE-2.0
9 
10   Unless required by applicable law or agreed to in writing, software
11   distributed under the License is distributed on an "AS IS" BASIS,
12   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13   See the License for the specific language governing permissions and
14   limitations under the License.
15 */
16 
17 #ifndef MOCKSSANDRA_HPP
18 #define MOCKSSANDRA_HPP
19 
20 #include <uv.h>
21 
22 #include <openssl/bio.h>
23 #include <openssl/err.h>
24 #include <openssl/ssl.h>
25 
26 #include <stdint.h>
27 
28 #include "address.hpp"
29 #include "event_loop.hpp"
30 #include "list.hpp"
31 #include "map.hpp"
32 #include "ref_counted.hpp"
33 #include "scoped_ptr.hpp"
34 #include "string.hpp"
35 #include "third_party/mt19937_64/mt19937_64.hpp"
36 #include "timer.hpp"
37 #include "vector.hpp"
38 
39 #if defined(WIN32) || defined(_WIN32)
40 #undef ERROR_ALREADY_EXISTS
41 #undef ERROR
42 #undef X509_NAME
43 #endif
44 
45 #define CLIENT_OPTIONS_QUERY "client.options"
46 
47 using datastax::String;
48 using datastax::internal::Atomic;
49 using datastax::internal::List;
50 using datastax::internal::Map;
51 using datastax::internal::RefCounted;
52 using datastax::internal::ScopedPtr;
53 using datastax::internal::SharedRefPtr;
54 using datastax::internal::Vector;
55 using datastax::internal::core::Address;
56 using datastax::internal::core::EventLoop;
57 using datastax::internal::core::EventLoopGroup;
58 using datastax::internal::core::RoundRobinEventLoopGroup;
59 using datastax::internal::core::Task;
60 using datastax::internal::core::Timer;
61 
62 namespace mockssandra {
63 
// Helpers that produce TLS key/certificate material for the mock server.
// Both functions are implemented out of line; the serialization format of the
// returned Strings (presumably PEM) is not visible here -- NOTE(review): confirm.
class Ssl {
public:
  // Returns a newly generated private key serialized into a String.
  static String generate_key();
  // Returns a certificate for `key` with common name `cn`. `ca_cert`/`ca_key`
  // presumably select a signing CA, with self-signing when they are empty
  // -- TODO confirm against the implementation.
  static String generate_cert(const String& key, String cn = "", String ca_cert = "",
                              String ca_key = "");
};
70 
71 namespace internal {
72 
73 class ServerConnection;
74 
// Thin wrapper that owns a libuv TCP handle (uv_tcp_t) by value.
class Tcp {
public:
  // `data` is presumably attached to the handle as libuv user data so static
  // callbacks can recover their owner -- NOTE(review): confirm in the .cpp.
  Tcp(void* data);

  // Initializes the handle on `loop`; returns a libuv-style status code.
  int init(uv_loop_t* loop);
  // Binds the handle to `addr`; returns a libuv-style status code.
  int bind(const struct sockaddr* addr);

  // The same underlying handle viewed as the generic libuv handle types.
  uv_handle_t* as_handle();
  uv_stream_t* as_stream();

private:
  uv_tcp_t tcp_;
};
88 
// One accepted client socket on a ServerConnection. Subclasses customize
// behavior through the public virtual hooks (on_accept/on_close/on_read/
// on_write); socket I/O and optional TLS live in the out-of-line implementation.
class ClientConnection {
public:
  ClientConnection(ServerConnection* server);
  virtual ~ClientConnection();

  // The server this connection belongs to (stored as a raw pointer; presumably
  // not owned).
  ServerConnection* server() { return server_; }

  // Invoked for a new connection; the default just calls accept() and returns
  // its status.
  virtual int on_accept() { return accept(); }
  // Hook for connection close; no-op by default.
  virtual void on_close() {}

  // Hook receiving incoming bytes; no-op by default. NOTE(review): presumably
  // delivers plaintext (after TLS decryption) when SSL is enabled.
  virtual void on_read(const char* data, size_t len) {}
  // Hook for write completion; no-op by default.
  virtual void on_write() {}

  int write(const String& data);
  int write(const char* data, size_t len);
  void close();

protected:
  int accept();

  // Server name requested by the client via TLS SNI (semantics without SSL
  // are defined in the implementation).
  const char* sni_server_name() const;

private:
  // libuv callbacks; each static function presumably forwards to the matching
  // handle_* member.
  static void on_close(uv_handle_t* handle);
  void handle_close();

  static void on_alloc(uv_handle_t* handle, size_t suggested_size, uv_buf_t* buf);

  static void on_read(uv_stream_t* stream, ssize_t nread, const uv_buf_t* buf);
  void handle_read(ssize_t nread, const uv_buf_t* buf);

  static void on_write(uv_write_t* req, int status);
  void handle_write(int status);

private:
  // Write paths (plain vs. TLS) and TLS helpers.
  int internal_write(const char* data, size_t len);
  int ssl_write(const char* data, size_t len);

  bool is_handshake_done();
  bool has_ssl_error(int rc);

  void on_ssl_read(const char* data, size_t len);

private:
  Tcp tcp_;
  ServerConnection* const server_;

private:
  // NOTE(review): no member of this enum type is declared in this class; it is
  // presumably only used locally by the implementation.
  enum SslHandshakeState {
    SSL_HANDSHAKE_INPROGRESS,
    SSL_HANDSHAKE_FINAL_WRITE,
    SSL_HANDSHAKE_DONE
  };

  // OpenSSL session and its memory BIOs. NOTE(review): these are raw pointers
  // and copying is not disabled; if the destructor frees them, an accidental
  // copy would double-free.
  SSL* ssl_;
  BIO* incoming_bio_;
  BIO* outgoing_bio_;
};
147 
148 class ClientConnectionFactory {
149 public:
150   virtual ClientConnection* create(ServerConnection* server) const = 0;
~ClientConnectionFactory()151   virtual ~ClientConnectionFactory() {}
152 };
153 
154 class ServerConnectionTask : public RefCounted<ServerConnectionTask> {
155 public:
156   typedef SharedRefPtr<ServerConnectionTask> Ptr;
157 
~ServerConnectionTask()158   virtual ~ServerConnectionTask() {}
159   virtual void run(ServerConnection* server_connection) = 0;
160 };
161 
162 typedef Vector<ClientConnection*> ClientConnections;
163 
// Reference-counted listening socket (optionally TLS). Accepted sockets are
// presumably turned into ClientConnections via the stored factory and tracked
// in clients_.
class ServerConnection : public RefCounted<ServerConnection> {
public:
  typedef SharedRefPtr<ServerConnection> Ptr;

  // NOTE(review): `factory` is kept by reference (factory_), so it must outlive
  // this object.
  ServerConnection(const Address& address, const ClientConnectionFactory& factory);
  ~ServerConnection();

  const Address& address() const { return address_; }
  uv_loop_t* loop();
  SSL_CTX* ssl_context() { return ssl_context_; }
  // Connections currently tracked by this server.
  const ClientConnections& clients() const { return clients_; }

  // Enables TLS using the given key and certificate (plus an optional CA and
  // client-certificate requirement). Returns a success flag.
  bool use_ssl(const String& key, const String& cert, const String& ca_cert = "",
               bool require_client_cert = false);

  // Starts listening using a loop from `event_loop_group`. wait_listen()
  // presumably blocks on mutex_/cond_ until signal_listen() publishes the
  // result code (rc_) and returns it.
  void listen(EventLoopGroup* event_loop_group);
  int wait_listen();

  // Requests shutdown; wait_close() presumably blocks until signal_close().
  void close();
  void wait_close();

  // Number of incoming connection attempts observed (atomic counter).
  unsigned connection_attempts() const;
  // Runs `task` against this server (execution context defined out of line).
  void run(const ServerConnectionTask::Ptr& task);

private:
  friend class ClientConnection;
  int accept(uv_stream_t* client);
  void remove(ClientConnection* connection);

private:
  // Presumably the event-loop tasks that call internal_listen/internal_close.
  friend class RunListen;
  friend class RunClose;

  void internal_listen();
  void internal_close();
  void maybe_close();

  void signal_listen(int rc);
  void signal_close();

  static void on_connection(uv_stream_t* stream, int status);
  void handle_connection(int status);

  static void on_close(uv_handle_t* handle);
  void handle_close();

  // OpenSSL password callback signature (pem_password_cb-style).
  static int on_password(char* buf, int size, int rwflag, void* password);

private:
  enum State { STATE_CLOSED, STATE_CLOSING, STATE_PENDING, STATE_LISTENING };

  Tcp tcp_;
  EventLoop* event_loop_;
  State state_;
  int rc_;
  uv_mutex_t mutex_;
  uv_cond_t cond_;
  ClientConnections clients_;
  const Address address_;
  const ClientConnectionFactory& factory_;
  SSL_CTX* ssl_context_;
  Atomic<unsigned> connection_attempts_;
};
227 
228 } // namespace internal
229 
// Frame header flag bits.
enum {
  FLAG_COMPRESSION = 1 << 0,
  FLAG_TRACING = 1 << 1,
  FLAG_CUSTOM_PAYLOAD = 1 << 2,
  FLAG_WARNING = 1 << 3,
  FLAG_BETA = 1 << 4
};

// Frame opcodes. OPCODE_LAST_ENTRY must stay last: it is one past the highest
// opcode and sizes the per-opcode action tables.
enum {
  OPCODE_ERROR = 0x00,          // response
  OPCODE_STARTUP = 0x01,        // request
  OPCODE_READY = 0x02,          // response
  OPCODE_AUTHENTICATE = 0x03,   // response
  OPCODE_CREDENTIALS = 0x04,    // request (protocol v1)
  OPCODE_OPTIONS = 0x05,        // request
  OPCODE_SUPPORTED = 0x06,      // response
  OPCODE_QUERY = 0x07,          // request
  OPCODE_RESULT = 0x08,         // response
  OPCODE_PREPARE = 0x09,        // request
  OPCODE_EXECUTE = 0x0A,        // request
  OPCODE_REGISTER = 0x0B,       // request
  OPCODE_EVENT = 0x0C,          // response (server push)
  OPCODE_BATCH = 0x0D,          // request
  OPCODE_AUTH_CHALLENGE = 0x0E, // response
  OPCODE_AUTH_RESPONSE = 0x0F,  // request
  OPCODE_AUTH_SUCCESS = 0x10,   // response
  OPCODE_LAST_ENTRY
};

// QUERY/EXECUTE parameter flag bits.
enum {
  QUERY_FLAG_VALUES = 1 << 0,
  QUERY_FLAG_SKIP_METADATA = 1 << 1,
  QUERY_FLAG_PAGE_SIZE = 1 << 2,
  QUERY_FLAG_PAGE_STATE = 1 << 3,
  QUERY_FLAG_SERIAL_CONSISTENCY = 1 << 4,
  QUERY_FLAG_TIMESTAMP = 1 << 5,
  QUERY_FLAG_NAMES_FOR_VALUES = 1 << 6,
  QUERY_FLAG_KEYSPACE = 1 << 7
};

// PREPARE flag bits.
enum { PREPARE_FLAGS_KEYSPACE = 1 << 0 };

// ERROR response codes.
enum {
  ERROR_SERVER_ERROR = 0x0000,
  ERROR_PROTOCOL_ERROR = 0x000A,
  ERROR_BAD_CREDENTIALS = 0x0100,
  ERROR_UNAVAILABLE = 0x1000,
  ERROR_OVERLOADED = 0x1001,
  ERROR_IS_BOOTSTRAPPING = 0x1002,
  ERROR_TRUNCATE_ERROR = 0x1003,
  ERROR_WRITE_TIMEOUT = 0x1100,
  ERROR_READ_TIMEOUT = 0x1200,
  ERROR_READ_FAILURE = 0x1300,
  ERROR_FUNCTION_FAILURE = 0x1400,
  ERROR_WRITE_FAILURE = 0x1500,
  ERROR_SYNTAX_ERROR = 0x2000,
  ERROR_UNAUTHORIZED = 0x2100,
  ERROR_INVALID_QUERY = 0x2200,
  ERROR_CONFIG_ERROR = 0x2300,
  ERROR_ALREADY_EXISTS = 0x2400,
  ERROR_UNPREPARED = 0x2500,
  ERROR_CLIENT_WRITE_FAILURE = 0x8000
};

// RESULT body kinds.
enum {
  RESULT_VOID = 0x0001,
  RESULT_ROWS = 0x0002,
  RESULT_SET_KEYSPACE = 0x0003,
  RESULT_PREPARED = 0x0004,
  RESULT_SCHEMA_CHANGE = 0x0005
};

// Rows-metadata flag bits. Kept in hex: the top bit (0x80000000) cannot be
// written as a shift of a signed int without overflowing.
enum {
  RESULT_FLAG_GLOBAL_TABLESPEC = 0x01,
  RESULT_FLAG_HAS_MORE_PAGES = 0x02,
  RESULT_FLAG_NO_METADATA = 0x04,
  RESULT_FLAG_METADATA_CHANGED = 0x08,
  RESULT_FLAG_CONTINUOUS_PAGING = 0x40000000,
  RESULT_FLAG_LAST_CONTINUOUS_PAGE = 0x80000000
};
310 
// Column type codes.
// TYPE_BLOG and TYPE_TIMEUUD are historical misspellings of BLOB and TIMEUUID.
// They are kept so existing callers still compile; new code should use the
// correctly spelled aliases, which carry the same values.
enum {
  TYPE_CUSTOM = 0x0000,
  TYPE_ASCII = 0x0001,
  TYPE_BIGINT = 0x0002,
  TYPE_BLOG = 0x0003,
  TYPE_BLOB = TYPE_BLOG,
  TYPE_BOOLEAN = 0x0004,
  TYPE_COUNTER = 0x0005,
  TYPE_DECIMAL = 0x0006,
  TYPE_DOUBLE = 0x0007,
  TYPE_FLOAT = 0x0008,
  TYPE_INT = 0x0009,
  TYPE_TIMESTAMP = 0x000B,
  TYPE_UUID = 0x000C,
  TYPE_VARCHAR = 0x000D,
  TYPE_VARINT = 0x000E,
  TYPE_TIMEUUD = 0x000F,
  TYPE_TIMEUUID = TYPE_TIMEUUD,
  TYPE_INET = 0x0010,
  TYPE_DATE = 0x0011,
  TYPE_TIME = 0x0012,
  TYPE_SMALLINT = 0x0013,
  TYPE_TINYINT = 0x0014,
  TYPE_LIST = 0x0020,
  TYPE_MAP = 0x0021,
  TYPE_SET = 0x0022,
  TYPE_UDT = 0x0030,
  TYPE_TUPLE = 0x0031
};
338 
339 typedef std::pair<String, String> Option;
340 typedef Vector<Option> Options;
341 typedef std::pair<String, String> Credential;
342 typedef Vector<Credential> Credentials;
343 typedef Vector<String> EventTypes;
344 typedef Vector<String> Values;
345 typedef Vector<String> Names;
346 
// Parameters decoded from a PREPARE body (see Request::decode_prepare()).
// NOTE(review): plain aggregate -- `flags` is uninitialized unless the decoder
// (or the caller) sets it.
struct PrepareParameters {
  int32_t flags;    // PREPARE_FLAGS_* bits
  String keyspace;
};

// Parameters decoded from a QUERY/EXECUTE body (see Request::decode_query()
// and Request::decode_execute()). NOTE(review): integer fields are
// uninitialized in a default-constructed instance; presumably only the fields
// whose QUERY_FLAG_* bit is set in `flags` are filled in by the decoder.
struct QueryParameters {
  uint16_t consistency;
  int32_t flags;          // QUERY_FLAG_* bits
  Values values;
  Names names;
  int32_t result_page_size;
  String paging_state;
  uint16_t serial_consistency;
  int64_t timestamp;
  String keyspace;
};
363 
// Wire-format encoders (implemented out of line). Each presumably appends the
// encoding of `value` to `*output` and returns a size/count -- TODO confirm the
// exact meaning of the int32_t return value in the implementation.
int32_t encode_int32(int32_t value, String* output);
int32_t encode_string(const String& value, String* output);
int32_t encode_string_map(const Map<String, Vector<String> >& value, String* output);
367 
368 class Type {
369 public:
370   static Type text();
371   static Type inet();
372   static Type uuid();
373   static Type list(const Type& sub_type);
374 
375   void encode(int protocol_version, String* output) const;
376 
377 private:
Type()378   Type()
379       : type_(-1) {}
380 
Type(int type)381   Type(int type)
382       : type_(type) {}
383 
384   friend class Vector<Type>;
385 
386 private:
387   int type_;
388   String custom_;
389   Vector<String> names_;
390   Vector<Type> types_;
391 };
392 
393 class Column {
394 public:
Column(const String & name,const Type type)395   Column(const String& name, const Type type)
396       : name_(name)
397       , type_(type) {}
398 
399   void encode(int protocol_version, String* output) const;
400 
401 private:
402   String name_;
403   Type type_;
404 };
405 
406 class ResultSet;
407 class Collection;
408 
409 class Value {
410 private:
411   enum Type { NUL, VALUE, COLLECTION };
412 
413 public:
414   Value();
415   Value(const String& value);
416   Value(const Collection& collection);
417   Value(const Value& other);
418   ~Value();
419 
420   void encode(int protocol_version, String* output) const;
421 
422 private:
423   Type type_;
424   union {
425     String* value_;
426     Collection* collection_;
427   };
428 };
429 
430 class Collection {
431 public:
432   class Builder {
433   public:
Builder(const Type & sub_type)434     Builder(const Type& sub_type)
435         : sub_type_(sub_type) {}
436 
text(const String & text)437     Builder& text(const String& text) {
438       values_.push_back(Value(text));
439       return *this;
440     }
441 
build()442     Collection build() { return Collection(sub_type_, values_); }
443 
444   private:
445     Type sub_type_;
446     Vector<Value> values_;
447   };
448 
449   void encode(int protocol_version, String* output) const;
450 
text(const Vector<String> & values)451   static Collection text(const Vector<String>& values) {
452     Collection::Builder builder(Type::text());
453     for (Vector<String>::const_iterator it = values.begin(), end = values.end(); it != end; ++it) {
454       builder.text(*it);
455     }
456     return builder.build();
457   }
458 
459 private:
Collection(const Type & sub_type,const Vector<Value> values)460   Collection(const Type& sub_type, const Vector<Value> values)
461       : sub_type_(sub_type)
462       , values_(values) {}
463 
464 private:
465   const Type sub_type_;
466   const Vector<Value> values_;
467 };
468 
469 class Row {
470 public:
471   class Builder {
472   public:
473     Builder& text(const String& text);
474 
475     Builder& inet(const Address& inet);
476 
477     Builder& uuid(const CassUuid& uuid);
478 
479     Builder& collection(const Collection& collection);
480 
build() const481     Row build() const { return Row(values_); }
482 
483   private:
484     Vector<Value> values_;
485   };
486 
487   void encode(int protocol_version, String* output) const;
488 
489 private:
Row(const Vector<Value> & values)490   Row(const Vector<Value>& values)
491       : values_(values) {}
492 
493 private:
494   Vector<Value> values_;
495 };
496 
497 class ResultSet {
498 public:
499   class Builder {
500   public:
Builder(const String & keyspace_name,const String & table_name)501     Builder(const String& keyspace_name, const String& table_name)
502         : keyspace_name_(keyspace_name)
503         , table_name_(table_name) {}
504 
column(const String & name,const Type & type)505     Builder& column(const String& name, const Type& type) {
506       columns_.push_back(Column(name, type));
507       return *this;
508     }
509 
row(const Row & row)510     Builder& row(const Row& row) {
511       rows_.push_back(row);
512       return *this;
513     }
514 
build() const515     ResultSet build() const { return ResultSet(keyspace_name_, table_name_, columns_, rows_); }
516 
517   private:
518     const String keyspace_name_;
519     const String table_name_;
520     Vector<Column> columns_;
521     Vector<Row> rows_;
522   };
523 
524   String encode(int protocol_version) const;
525 
column_count() const526   size_t column_count() const { return columns_.size(); }
527 
528 private:
ResultSet(const String & keyspace_name,const String & table_name,const Vector<Column> & columns,const Vector<Row> & rows)529   ResultSet(const String& keyspace_name, const String& table_name, const Vector<Column>& columns,
530             const Vector<Row>& rows)
531       : keyspace_name_(keyspace_name)
532       , table_name_(table_name)
533       , columns_(columns)
534       , rows_(rows) {}
535 
536 private:
537   const String keyspace_name_;
538   const String table_name_;
539   const Vector<Column> columns_;
540   const Vector<Row> rows_;
541 };
542 
543 struct Exception : public std::exception {
Exceptionmockssandra::Exception544   Exception(int code, const String& message)
545       : code(code)
546       , message(message) {}
~Exceptionmockssandra::Exception547   virtual ~Exception() throw() {}
548   const int code;
549   const String message;
550 };
551 
// Description of one simulated node: address, datacenter/rack placement,
// partitioner name and owned tokens.
struct Host {
  // All fields are left default-constructed.
  Host() {}
  // Presumably fills `tokens` with `num_tokens` values drawn from `token_rng`
  // -- TODO confirm in the implementation.
  Host(const Address& address, const String& dc, const String& rack, MT19937_64& token_rng,
       int num_tokens = 2);
  Address address;
  String dc;
  String rack;
  String partitioner; // NOTE(review): not a parameter of either constructor here.
  Vector<String> tokens;
};

typedef Vector<Host> Hosts;
564 
// Forward declarations for types defined later in this header.
class ClientConnection;
class Cluster;
class Request;

// A query string paired with the canned ResultSet returned for it
// (consumed by MatchQuery).
typedef std::pair<String, ResultSet> Match;
typedef Vector<Match> Matches;

struct Predicate;
573 
// A node in a singly linked chain of steps run against a Request. Each Action
// owns its successor: the destructor deletes `next`, so deleting the head frees
// the whole chain (recursively, one stack frame per node).
// NOTE(review): copying is not disabled; a copy would share and double-delete
// `next`.
struct Action {
  class PredicateBuilder;

  // Fluent builder that appends Actions to a chain (first_ = head, last_ =
  // tail). Each method appends one step and returns *this.
  class Builder {
  public:
    Builder()
        : last_(NULL) {}

    // Presumably discards any chain built so far and returns *this.
    Builder& reset();

    // Append an arbitrary action; execute_if is used by PredicateBuilder::then().
    Builder& execute(Action* action);
    Builder& execute_if(Action* action);

    // Control-flow / connection steps.
    Builder& nop();
    Builder& wait(uint64_t timeout);
    Builder& close();
    Builder& error(int32_t code, const String& message);
    Builder& invalid_protocol();
    Builder& invalid_opcode();

    // Handshake and event responses.
    Builder& ready();
    Builder& authenticate(const String& class_name);
    Builder& auth_challenge(const String& token);
    Builder& auth_success(const String& token = "");
    Builder& supported();
    Builder& up_event(const Address& address);

    // RESULT responses.
    Builder& void_result();
    Builder& empty_rows_result(int32_t row_count);
    Builder& no_result();
    Builder& match_query(const Matches& matches);

    Builder& client_options();

    // Canned system-table responses.
    Builder& system_local();
    Builder& system_local_dse();
    Builder& system_peers();
    Builder& system_peers_dse();
    Builder& system_traces();

    Builder& use_keyspace(const String& keyspace);
    Builder& use_keyspace(const Vector<String>& keyspaces);
    Builder& plaintext_auth(const String& username = "cassandra",
                            const String& password = "cassandra");

    // Request validation steps.
    Builder& validate_startup();
    Builder& validate_credentials();
    Builder& validate_auth_response();
    Builder& validate_register();
    Builder& validate_query();

    // Connection state updates.
    Builder& set_registered_for_events();
    Builder& set_protocol_version();

    // Conditional branches: finish with PredicateBuilder::then(...).
    PredicateBuilder is_address(const Address& address);
    PredicateBuilder is_address(const String& address, int port = 9042);

    PredicateBuilder is_query(const String& query);

    // Returns the built chain; presumably releases ownership of first_ to the
    // caller -- confirm in the implementation.
    Action* build();

  private:
    ScopedPtr<Action> first_;
    Action* last_;
  };

  // Completes a conditional step: then() passes the branch action to the
  // originating Builder's execute_if() and returns that Builder.
  class PredicateBuilder {
  public:
    PredicateBuilder(Builder& builder)
        : builder_(builder) {}

    Builder& then(Builder& builder) { return then(builder.build()); }

    Builder& then(Action* action) { return builder_.execute_if(action); }

  private:
    Builder& builder_;
  };

  Action()
      : next(NULL) {}
  virtual ~Action() { delete next; }

  void run(Request* request) const;
  void run_next(Request* request) const;

  // Overridden to return true by Predicate.
  virtual bool is_predicate() const { return false; }
  // The step's behavior; implemented by each concrete action.
  virtual void on_run(Request* request) const = 0;

  // Owned successor in the chain (NULL at the tail).
  const Action* next;
};
665 
666 struct Predicate : public Action {
Predicatemockssandra::Predicate667   Predicate()
668       : then(NULL) {}
~Predicatemockssandra::Predicate669   virtual ~Predicate() { delete then; }
670 
is_predicatemockssandra::Predicate671   virtual bool is_predicate() const { return true; }
672   virtual bool is_true(Request* request) const = 0;
673 
on_runmockssandra::Predicate674   virtual void on_run(Request* request) const {
675     if (is_true(request)) {
676       if (then) {
677         then->run(request);
678       }
679     } else {
680       run_next(request);
681     }
682   }
683 
684   const Action* then;
685 };
686 
// One decoded request frame from a client connection. It is the context that
// Actions operate on: it exposes the frame header and raw body, helpers to
// reply on the same connection, and decoders for the body.
class Request
    : public List<Request>::Node
    , public RefCounted<Request> {
public:
  typedef SharedRefPtr<Request> Ptr;

  Request(int8_t version, int8_t flags, int16_t stream, int8_t opcode, const String& body,
          ClientConnection* client);

  // Frame header fields.
  int8_t version() const { return version_; }
  int16_t stream() const { return stream_; }
  int8_t opcode() const { return opcode_; }

  // Connection the frame arrived on (raw pointer; presumably not owned).
  ClientConnection* client() const { return client_; }

  // Reply helpers. The first write() overload presumably replies on this
  // request's stream id; the second targets an explicit stream.
  void write(int8_t opcode, const String& body);
  void write(int16_t stream, int8_t opcode, const String& body);
  void error(int32_t code, const String& message);
  // Starts timer_ for `timeout` and records `action` in timer_action_;
  // presumably on_timeout() continues with that action -- confirm in the .cpp.
  void wait(uint64_t timeout, const Action* action);
  void close();

  // Body decoders: parse the raw body into the output arguments and return a
  // success flag (presumably false for a malformed body).
  bool decode_startup(Options* options);
  bool decode_credentials(Credentials* creds);
  bool decode_auth_response(String* token);
  bool decode_register(EventTypes* types);
  bool decode_query(String* query, QueryParameters* params);
  bool decode_execute(String* id, QueryParameters* params);
  bool decode_prepare(String* query, PrepareParameters* params);

  // Topology lookups (implemented out of line).
  const Address& address() const;
  const Host& host(const Address& address) const;
  Hosts hosts() const;

private:
  void on_timeout(Timer* timer);

  // Bounds of the raw body bytes [start(), end()).
  const char* start() { return body_.data(); }
  const char* end() { return body_.data() + body_.size(); }

private:
  const int8_t version_;
  const int8_t flags_;
  const int16_t stream_;
  const int8_t opcode_;
  const String body_;
  ClientConnection* const client_;
  Timer timer_;
  const Action* timer_action_;
};
736 
737 struct Nop : public Action {
on_runmockssandra::Nop738   virtual void on_run(Request* request) const {}
739 };
740 
741 struct Wait : public Action {
Waitmockssandra::Wait742   Wait(uint64_t timeout)
743       : timeout(timeout) {}
744 
on_runmockssandra::Wait745   virtual void on_run(Request* request) const { request->wait(timeout, this); }
746 
747   const uint64_t timeout;
748 };
749 
750 struct Close : public Action {
on_runmockssandra::Close751   virtual void on_run(Request* request) const { request->close(); }
752 };
753 
754 struct SendError : public Action {
SendErrormockssandra::SendError755   SendError(int32_t code, const String& message)
756       : code(code)
757       , message(message) {}
758 
759   virtual void on_run(Request* request) const;
760 
761   int32_t code;
762   String message;
763 };
764 
765 struct SendReady : public Action {
766   virtual void on_run(Request* request) const;
767 };
768 
769 struct SendAuthenticate : public Action {
SendAuthenticatemockssandra::SendAuthenticate770   SendAuthenticate(const String& class_name)
771       : class_name(class_name) {}
772   virtual void on_run(Request* request) const;
773   String class_name;
774 };
775 
776 struct SendAuthChallenge : public Action {
SendAuthChallengemockssandra::SendAuthChallenge777   SendAuthChallenge(const String& token)
778       : token(token) {}
779   virtual void on_run(Request* request) const;
780   String token;
781 };
782 
783 struct SendAuthSuccess : public Action {
SendAuthSuccessmockssandra::SendAuthSuccess784   SendAuthSuccess(const String& token)
785       : token(token) {}
786   virtual void on_run(Request* request) const;
787   String token;
788 };
789 
790 struct SendSupported : public Action {
791   virtual void on_run(Request* request) const;
792 };
793 
794 struct SendUpEvent : public Action {
SendUpEventmockssandra::SendUpEvent795   SendUpEvent(const Address& address)
796       : address(address) {}
797   virtual void on_run(Request* request) const;
798   Address address;
799 };
800 
801 struct VoidResult : public Action {
802   virtual void on_run(Request* request) const;
803 };
804 
805 struct EmptyRowsResult : public Action {
EmptyRowsResultmockssandra::EmptyRowsResult806   EmptyRowsResult(int row_count)
807       : row_count(row_count) {}
808   virtual void on_run(Request* request) const;
809   int32_t row_count;
810 };
811 
812 struct NoResult : public Action {
813   virtual void on_run(Request* request) const;
814 };
815 
816 struct MatchQuery : public Action {
MatchQuerymockssandra::MatchQuery817   MatchQuery(const Matches& matches)
818       : matches(matches) {}
819   virtual void on_run(Request* request) const;
820   Matches matches;
821 };
822 
823 struct ClientOptions : public Action {
824   virtual void on_run(Request* request) const;
825 };
826 
827 struct SystemLocal : public Action {
828   virtual void on_run(Request* request) const;
829 };
830 
831 struct SystemLocalDse : public Action {
832   virtual void on_run(Request* request) const;
833 };
834 
835 struct SystemPeers : public Action {
836   virtual void on_run(Request* request) const;
837 };
838 
839 struct SystemPeersDse : public Action {
840   virtual void on_run(Request* request) const;
841 };
842 
843 struct SystemTraces : public Action {
844   virtual void on_run(Request* request) const;
845 };
846 
847 struct UseKeyspace : public Action {
UseKeyspacemockssandra::UseKeyspace848   UseKeyspace(const String& keyspace) { keyspaces.push_back(keyspace); }
UseKeyspacemockssandra::UseKeyspace849   UseKeyspace(const Vector<String>& keyspaces)
850       : keyspaces(keyspaces) {}
851   virtual void on_run(Request* request) const;
852   Vector<String> keyspaces;
853 };
854 
855 struct PlaintextAuth : public Action {
PlaintextAuthmockssandra::PlaintextAuth856   PlaintextAuth(const String& username, const String& password)
857       : username(username)
858       , password(password) {}
859   virtual void on_run(Request* request) const;
860   String username;
861   String password;
862 };
863 
// Validation steps for the corresponding request bodies; on_run() bodies are
// implemented out of line.
struct ValidateStartup : public Action {
  virtual void on_run(Request* request) const;
};

struct ValidateCredentials : public Action {
  virtual void on_run(Request* request) const;
};

struct ValidateAuthResponse : public Action {
  virtual void on_run(Request* request) const;
};

struct ValidateRegister : public Action {
  virtual void on_run(Request* request) const;
};

struct ValidateQuery : public Action {
  virtual void on_run(Request* request) const;
};

// Connection-state steps; presumably call the matching setters on the
// request's ClientConnection -- confirm in the implementation.
struct SetRegisteredForEvents : public Action {
  virtual void on_run(Request* request) const;
};

struct SetProtocolVersion : public Action {
  virtual void on_run(Request* request) const;
};
891 
892 struct IsAddress : public Predicate {
IsAddressmockssandra::IsAddress893   IsAddress(const Address& address)
894       : address(address) {}
895   virtual bool is_true(Request* request) const;
896   const Address address;
897 };
898 
899 struct IsQuery : public Predicate {
IsQuerymockssandra::IsQuery900   IsQuery(const String& query)
901       : query(query) {}
902   virtual bool is_true(Request* request) const;
903   const String query;
904 };
905 
906 class RequestHandler {
907 public:
908   class Builder {
909   public:
Builder()910     Builder()
911         : lowest_supported_protocol_version_(1)
912         , highest_supported_protocol_version_(5) {
913       invalid_protocol_.invalid_protocol();
914       invalid_opcode_.invalid_opcode();
915     }
916 
on(int8_t opcode)917     Action::Builder& on(int8_t opcode) {
918       if (opcode < OPCODE_LAST_ENTRY) {
919         return actions_[opcode].reset();
920       }
921       return dummy_.reset();
922     }
923 
on_invalid_protocol()924     Action::Builder& on_invalid_protocol() { return invalid_protocol_; }
on_invalid_opcode()925     Action::Builder& on_invalid_opcode() { return invalid_opcode_; }
926 
927     const RequestHandler* build();
928 
with_supported_protocol_versions(int lowest,int highest)929     Builder& with_supported_protocol_versions(int lowest, int highest) {
930       assert(highest >= lowest && "Invalid protocol versions");
931       lowest_supported_protocol_version_ = lowest < 0 ? 0 : lowest;
932       highest_supported_protocol_version_ = highest > 5 ? 5 : highest;
933       return *this;
934     }
935 
936   private:
937     Action::Builder actions_[OPCODE_LAST_ENTRY];
938     Action::Builder invalid_protocol_;
939     Action::Builder invalid_opcode_;
940     Action::Builder dummy_;
941     int lowest_supported_protocol_version_;
942     int highest_supported_protocol_version_;
943   };
944 
945   RequestHandler(Builder* builder, int lowest_supported_protocol_version,
946                  int highest_supported_protocol_version);
947 
lowest_supported_protocol_version() const948   int lowest_supported_protocol_version() const { return lowest_supported_protocol_version_; }
949 
highest_supported_protocol_version() const950   int highest_supported_protocol_version() const { return highest_supported_protocol_version_; }
951 
invalid_protocol(Request * request) const952   void invalid_protocol(Request* request) const { invalid_protocol_->run(request); }
953 
run(Request * request) const954   void run(Request* request) const {
955     const ScopedPtr<const Action>& action = actions_[request->opcode()];
956     if (action) {
957       action->run(request);
958     } else {
959       invalid_opcode_->run(request);
960     }
961   }
962 
963 private:
964   ScopedPtr<const Action> invalid_protocol_;
965   ScopedPtr<const Action> invalid_opcode_;
966   ScopedPtr<const Action> actions_[OPCODE_LAST_ENTRY];
967   const int lowest_supported_protocol_version_;
968   const int highest_supported_protocol_version_;
969 };
970 
971 class ProtocolHandler {
972 public:
ProtocolHandler(const RequestHandler * request_handler)973   ProtocolHandler(const RequestHandler* request_handler)
974       : request_handler_(request_handler)
975       , state_(PROTOCOL_VERSION)
976       , version_(0)
977       , flags_(0)
978       , stream_(0)
979       , opcode_(0)
980       , length_(0) {}
981 
982   void decode(ClientConnection* client, const char* data, int32_t len);
983 
984 private:
985   int32_t decode_frame(ClientConnection* client, const char* frame, int32_t len);
986   void decode_body(ClientConnection* client, const char* body, int32_t len);
987 
988   enum State { PROTOCOL_VERSION, HEADER, BODY };
989 
990 private:
991   String buffer_;
992   const RequestHandler* request_handler_;
993   State state_;
994   int8_t version_;
995   int8_t flags_;
996   int16_t stream_;
997   int8_t opcode_;
998   int32_t length_;
999 };
1000 
1001 class ClientConnection : public internal::ClientConnection {
1002 public:
ClientConnection(internal::ServerConnection * server,const RequestHandler * request_handler,const Cluster * cluster)1003   ClientConnection(internal::ServerConnection* server, const RequestHandler* request_handler,
1004                    const Cluster* cluster)
1005       : internal::ClientConnection(server)
1006       , handler_(request_handler)
1007       , cluster_(cluster)
1008       , protocol_version_(-1)
1009       , is_registered_for_events_(false) {}
1010 
1011   virtual void on_read(const char* data, size_t len);
1012 
cluster() const1013   const Cluster* cluster() const { return cluster_; }
1014 
protocol_version() const1015   int protocol_version() const { return protocol_version_; }
set_protocol_version(int protocol_version)1016   void set_protocol_version(int protocol_version) { protocol_version_ = protocol_version; }
1017 
is_registered_for_events() const1018   bool is_registered_for_events() const { return is_registered_for_events_; }
set_registered_for_events()1019   void set_registered_for_events() { is_registered_for_events_ = true; }
options() const1020   const Options& options() const { return options_; }
set_options(const Options & options)1021   void set_options(const Options& options) { options_ = options; }
1022 
keyspace() const1023   const String& keyspace() const { return keyspace_; }
set_keyspace(const String & keyspace)1024   void set_keyspace(const String& keyspace) { keyspace_ = keyspace; }
1025 
1026 private:
1027   ProtocolHandler handler_;
1028   String keyspace_;
1029   const Cluster* cluster_;
1030   int protocol_version_;
1031   bool is_registered_for_events_;
1032   Options options_;
1033 };
1034 
1035 class CloseConnection : public ClientConnection {
1036 public:
CloseConnection(internal::ServerConnection * server,const RequestHandler * request_handler,const Cluster * cluster)1037   CloseConnection(internal::ServerConnection* server, const RequestHandler* request_handler,
1038                   const Cluster* cluster)
1039       : ClientConnection(server, request_handler, cluster) {}
1040 
on_accept()1041   int on_accept() {
1042     int rc = accept();
1043     if (rc != 0) {
1044       return rc;
1045     }
1046     close();
1047     return rc;
1048   }
1049 };
1050 
1051 class ClientConnectionFactory : public internal::ClientConnectionFactory {
1052 public:
ClientConnectionFactory(const RequestHandler * request_handler,const Cluster * cluster)1053   ClientConnectionFactory(const RequestHandler* request_handler, const Cluster* cluster)
1054       : request_handler_(request_handler)
1055       , cluster_(cluster)
1056       , close_immediately_(false) {}
1057 
use_close_immediately()1058   void use_close_immediately() { close_immediately_ = true; }
1059 
create(internal::ServerConnection * server) const1060   virtual internal::ClientConnection* create(internal::ServerConnection* server) const {
1061     if (close_immediately_) {
1062       return new CloseConnection(server, request_handler_.get(), cluster_);
1063     } else {
1064       return new ClientConnection(server, request_handler_.get(), cluster_);
1065     }
1066   }
1067 
1068 private:
1069   ScopedPtr<const RequestHandler> request_handler_;
1070   const Cluster* cluster_;
1071   bool close_immediately_;
1072 };
1073 
1074 class AddressGenerator {
1075 public:
1076   virtual Address next() = 0;
1077 };
1078 
1079 class Ipv4AddressGenerator : public AddressGenerator {
1080 public:
Ipv4AddressGenerator(uint8_t a=127,uint8_t b=0,uint8_t c=0,uint8_t d=1,int port=9042)1081   Ipv4AddressGenerator(uint8_t a = 127, uint8_t b = 0, uint8_t c = 0, uint8_t d = 1,
1082                        int port = 9042)
1083       : ip_((a << 24) | (b << 16) | (c << 8) | (d & 0xff))
1084       , port_(port) {}
1085 
1086   virtual Address next();
1087 
1088 private:
1089   uint32_t ip_;
1090   int port_;
1091 };
1092 
1093 class Event : public internal::ServerConnectionTask {
1094 public:
1095   typedef SharedRefPtr<Event> Ptr;
1096 
1097   Event(const String& event_body);
1098 
1099   virtual void run(internal::ServerConnection* server_connection);
1100 
1101 private:
1102   String event_body_;
1103 };
1104 
1105 class TopologyChangeEvent : public Event {
1106 public:
1107   enum Type { NEW_NODE, MOVED_NODE, REMOVED_NODE };
1108 
TopologyChangeEvent(Type type,const Address & address)1109   TopologyChangeEvent(Type type, const Address& address)
1110       : Event(encode(type, address)) {}
1111 
1112   static Ptr new_node(const Address& address);
1113   static Ptr moved_node(const Address& address);
1114   static Ptr removed_node(const Address& address);
1115 
1116   static String encode(Type type, const Address& address);
1117 };
1118 
1119 class StatusChangeEvent : public Event {
1120 public:
1121   enum Type { UP, DOWN };
1122 
StatusChangeEvent(Type type,const Address & address)1123   StatusChangeEvent(Type type, const Address& address)
1124       : Event(encode(type, address)) {}
1125 
1126   static Ptr up(const Address& address);
1127   static Ptr down(const Address& address);
1128 
1129   static String encode(Type type, const Address& address);
1130 };
1131 
1132 class SchemaChangeEvent : public Event {
1133 public:
1134   enum Type { CREATED, UPDATED, DROPPED };
1135 
1136   enum Target { KEYSPACE, TABLE, USER_TYPE, FUNCTION, AGGREGATE };
1137 
SchemaChangeEvent(Target target,Type type,const String & keyspace_name,const String & target_name="",const Vector<String> & args_types=Vector<String> ())1138   SchemaChangeEvent(Target target, Type type, const String& keyspace_name,
1139                     const String& target_name = "",
1140                     const Vector<String>& args_types = Vector<String>())
1141       : Event(encode(target, type, keyspace_name, target_name, args_types)) {}
1142 
1143   static Ptr keyspace(Type type, const String& keyspace_name);
1144   static Ptr table(Type type, const String& keyspace_name, const String& table_name);
1145   static Ptr user_type(Type type, const String& keyspace_name, const String& user_type_name);
1146   static Ptr function(Type type, const String& keyspace_name, const String& function_name,
1147                       const Vector<String>& args_types);
1148   static Ptr aggregate(Type type, const String& keyspace_name, const String& aggregate_name,
1149                        const Vector<String>& args_types);
1150 
1151   static String encode(Target target, Type type, const String& keyspace_name,
1152                        const String& target_name, const Vector<String>& arg_types);
1153 };
1154 
1155 class Cluster {
1156 protected:
1157   void init(AddressGenerator& generator, ClientConnectionFactory& factory, size_t num_nodes_dc1,
1158             size_t num_nodes_dc2);
1159 
1160 public:
1161   ~Cluster();
1162 
1163   String use_ssl(const String& cn = "");
1164 
1165   int start_all(EventLoopGroup* event_loop_group);
1166   void start_all_async(EventLoopGroup* event_loop_group);
1167 
1168   void stop_all();
1169   void stop_all_async();
1170 
1171   int start(EventLoopGroup* event_loop_group, size_t node);
1172   void start_async(EventLoopGroup* event_loop_group, size_t node);
1173 
1174   void stop(size_t node);
1175   void stop_async(size_t node);
1176 
1177   int add(EventLoopGroup* event_loop_group, size_t node);
1178   void remove(size_t node);
1179 
1180   const Host& host(const Address& address) const;
1181   Hosts hosts() const;
1182 
1183   unsigned connection_attempts(size_t node) const;
1184 
1185   void event(const Event::Ptr& event);
1186 
1187 private:
1188   struct Server {
Servermockssandra::Cluster::Server1189     Server(const Host& host, const internal::ServerConnection::Ptr& connection)
1190         : host(host)
1191         , connection(connection)
1192         , is_removed(false) {}
1193 
Servermockssandra::Cluster::Server1194     Server(const Server& server)
1195         : host(server.host)
1196         , connection(server.connection)
1197         , is_removed(server.is_removed.load()) {}
1198 
operator =mockssandra::Cluster::Server1199     Server& operator=(const Server& server) {
1200       host = server.host;
1201       connection = server.connection;
1202       is_removed.store(server.is_removed.load());
1203       return *this;
1204     }
1205 
1206     Host host;
1207     internal::ServerConnection::Ptr connection;
1208     Atomic<bool> is_removed;
1209   };
1210 
1211   typedef Vector<Server> Servers;
1212 
1213   int create_and_add_server(AddressGenerator& generator, ClientConnectionFactory& factory,
1214                             const String& dc);
1215 
1216 private:
1217   Servers servers_;
1218   MT19937_64 token_rng_;
1219 };
1220 
1221 class SimpleEventLoopGroup : public RoundRobinEventLoopGroup {
1222 public:
1223   SimpleEventLoopGroup(size_t num_threads = 1, const String& thread_name = "mockssandra");
1224   ~SimpleEventLoopGroup();
1225 };
1226 
1227 class SimpleRequestHandlerBuilder : public RequestHandler::Builder {
1228 public:
1229   SimpleRequestHandlerBuilder();
1230 };
1231 
1232 class AuthRequestHandlerBuilder : public SimpleRequestHandlerBuilder {
1233 public:
1234   AuthRequestHandlerBuilder(const String& username = "cassandra",
1235                             const String& password = "cassandra");
1236 };
1237 
1238 class SimpleCluster : public Cluster {
1239 public:
SimpleCluster(const RequestHandler * request_handler,size_t num_nodes_dc1=1,size_t num_nodes_dc2=0)1240   SimpleCluster(const RequestHandler* request_handler, size_t num_nodes_dc1 = 1,
1241                 size_t num_nodes_dc2 = 0)
1242       : factory_(request_handler, this)
1243       , event_loop_group_(1) {
1244     init(generator_, factory_, num_nodes_dc1, num_nodes_dc2);
1245   }
1246 
~SimpleCluster()1247   ~SimpleCluster() { stop_all(); }
1248 
use_close_immediately()1249   void use_close_immediately() { factory_.use_close_immediately(); }
1250 
start_all()1251   int start_all() { return Cluster::start_all(&event_loop_group_); }
1252 
start(size_t node)1253   int start(size_t node) { return Cluster::start(&event_loop_group_, node); }
1254 
add(size_t node)1255   int add(size_t node) { return Cluster::add(&event_loop_group_, node); }
1256 
1257 private:
1258   Ipv4AddressGenerator generator_;
1259   ClientConnectionFactory factory_;
1260   SimpleEventLoopGroup event_loop_group_;
1261 };
1262 
1263 class SimpleEchoServer {
1264 public:
SimpleEchoServer()1265   SimpleEchoServer()
1266       : factory_(new EchoClientConnectionFactory())
1267       , event_loop_group_(1) {}
1268 
~SimpleEchoServer()1269   ~SimpleEchoServer() { close(); }
1270 
close()1271   void close() {
1272     if (server_) {
1273       server_->close();
1274       server_->wait_close();
1275     }
1276   }
1277 
use_ssl(const String & cn="")1278   String use_ssl(const String& cn = "") {
1279     ssl_key_ = Ssl::generate_key();
1280     ssl_cert_ = Ssl::generate_cert(ssl_key_, cn);
1281     return ssl_cert_;
1282   }
1283 
use_connection_factory(internal::ClientConnectionFactory * factory)1284   void use_connection_factory(internal::ClientConnectionFactory* factory) {
1285     factory_.reset(factory);
1286   }
1287 
listen(const Address & address=Address ("127.0.0.1",8888))1288   int listen(const Address& address = Address("127.0.0.1", 8888)) {
1289     server_.reset(new internal::ServerConnection(address, *factory_));
1290     if (!ssl_key_.empty() && !ssl_cert_.empty() && !server_->use_ssl(ssl_key_, ssl_cert_)) {
1291       return -1;
1292     }
1293     server_->listen(&event_loop_group_);
1294     return server_->wait_listen();
1295   }
1296 
1297 private:
1298   class EchoConnection : public internal::ClientConnection {
1299   public:
EchoConnection(internal::ServerConnection * server)1300     EchoConnection(internal::ServerConnection* server)
1301         : internal::ClientConnection(server) {}
1302 
on_read(const char * data,size_t len)1303     virtual void on_read(const char* data, size_t len) { write(data, len); }
1304   };
1305 
1306   class EchoClientConnectionFactory : public internal::ClientConnectionFactory {
1307   public:
create(internal::ServerConnection * server) const1308     virtual internal::ClientConnection* create(internal::ServerConnection* server) const {
1309       return new EchoConnection(server);
1310     }
1311   };
1312 
1313 private:
1314   ScopedPtr<internal::ClientConnectionFactory> factory_;
1315   SimpleEventLoopGroup event_loop_group_;
1316   internal::ServerConnection::Ptr server_;
1317   String ssl_key_;
1318   String ssl_cert_;
1319 };
1320 
1321 } // namespace mockssandra
1322 
1323 #endif // MOCKSSANDRA_HPP
1324