1 /*
2   Copyright (c) DataStax, Inc.
3 
4   Licensed under the Apache License, Version 2.0 (the "License");
5   you may not use this file except in compliance with the License.
6   You may obtain a copy of the License at
7 
8   http://www.apache.org/licenses/LICENSE-2.0
9 
10   Unless required by applicable law or agreed to in writing, software
11   distributed under the License is distributed on an "AS IS" BASIS,
12   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13   See the License for the specific language governing permissions and
14   limitations under the License.
15 */
16 
17 #include "mockssandra.hpp"
18 
19 #include <assert.h>
20 #include <stdio.h>
21 
22 #include "control_connection.hpp" // For host queries
23 #include "memory.hpp"
24 #include "scoped_lock.hpp"
25 #include "tracing_data_handler.hpp" // For tracing query
26 #include "utils.hpp"
27 #include "uuids.hpp"
28 
29 #include <openssl/bio.h>
30 #include <openssl/dh.h>
31 #include <openssl/x509v3.h>
32 
33 #ifdef WIN32
34 #include "winsock.h"
35 #endif
36 
37 using datastax::internal::bind_callback;
38 using datastax::internal::escape_id;
39 using datastax::internal::Map;
40 using datastax::internal::Memory;
41 using datastax::internal::OStringStream;
42 using datastax::internal::ScopedMutex;
43 using datastax::internal::trim;
44 using datastax::internal::core::UuidGen;
45 
46 #define SSL_BUF_SIZE 8192
47 #define CASSANDRA_VERSION "3.11.4"
48 #define DSE_VERSION "6.7.1"
49 #define DSE_CASSANDRA_VERSION "4.0.0.671"
50 
51 #if defined(OPENSSL_VERSION_NUMBER) && \
52     !defined(LIBRESSL_VERSION_NUMBER) // Required as OPENSSL_VERSION_NUMBER for LibreSSL is defined
53                                       // as 2.0.0
54 #if (OPENSSL_VERSION_NUMBER >= 0x10100000L)
55 #define SSL_SERVER_METHOD TLS_server_method
56 #else
57 #define SSL_SERVER_METHOD SSLv23_server_method
58 #endif
59 #else
60 #if (LIBRESSL_VERSION_NUMBER >= 0x20302000L)
61 #define SSL_SERVER_METHOD TLS_server_method
62 #else
63 #define SSL_SERVER_METHOD SSLv23_server_method
64 #endif
65 #endif
66 
67 namespace mockssandra {
68 
69 namespace {
70 
// Primary template: intentionally empty. Only the specializations generated
// by MAKE_DELETER() provide a static `free()` function.
template <class T>
struct FreeDeleterImpl {};

// Specializes FreeDeleterImpl for `type` with a static `free()` that forwards
// the pointer to `free_func`.
#define MAKE_DELETER(type, free_func)               \
  template <>                                       \
  struct FreeDeleterImpl<type> {                    \
    static void free(type* ptr) { free_func(ptr); } \
  };

// Pair each OpenSSL handle type used in this file with its release function.
MAKE_DELETER(BIO, BIO_free)
MAKE_DELETER(DH, DH_free)
MAKE_DELETER(EVP_PKEY, EVP_PKEY_free)
MAKE_DELETER(EVP_PKEY_CTX, EVP_PKEY_CTX_free)
MAKE_DELETER(X509, X509_free)
MAKE_DELETER(X509_REQ, X509_REQ_free)
MAKE_DELETER(X509_EXTENSION, X509_EXTENSION_free)
87 
88 template <class T>
89 struct FreeDeleter {
operator ()mockssandra::__anon49d184150111::FreeDeleter90   void operator()(T* ptr) const { FreeDeleterImpl<T>::free(ptr); }
91 };
92 
93 template <class T>
94 class Scoped : public ScopedPtr<T, FreeDeleter<T> > {
95 public:
Scoped(T * ptr=NULL)96   Scoped(T* ptr = NULL)
97       : ScopedPtr<T, FreeDeleter<T> >(ptr) {}
98 };
99 
print_ssl_error()100 void print_ssl_error() {
101   unsigned long err = ERR_get_error();
102   char buf[256];
103   ERR_error_string_n(err, buf, sizeof(buf));
104   fprintf(stderr, "%s\n", buf);
105 }
106 
load_cert(const String & cert)107 X509* load_cert(const String& cert) {
108   X509* x509 = NULL;
109   Scoped<BIO> bio(BIO_new_mem_buf(const_cast<char*>(cert.c_str()), cert.length()));
110   if (PEM_read_bio_X509(bio.get(), &x509, NULL, NULL) == NULL) {
111     print_ssl_error();
112     return NULL;
113   }
114   return x509;
115 }
116 
load_private_key(const String & key)117 EVP_PKEY* load_private_key(const String& key) {
118   EVP_PKEY* pkey = NULL;
119   Scoped<BIO> bio(BIO_new_mem_buf(const_cast<char*>(key.c_str()), key.length()));
120   if (!PEM_read_bio_PrivateKey(bio.get(), &pkey, NULL, NULL)) {
121     print_ssl_error();
122     return NULL;
123   }
124   return pkey;
125 }
126 
dh_parameters()127 DH* dh_parameters() {
128   // Generated using the following command: `openssl dhparam -C 2048`
129   // Prime length of 2048 chosen to bypass client-side error:
130   // `SSL3_CHECK_CERT_AND_ALGORITHM:dh key too small`
131 
132   // Note: This is not generated, programmatically, using something like the following:
133   // `DH_generate_parameters_ex(dh, 2048, DH_GENERATOR_5, NULL)`
134   // because DH prime generation takes a *REALLY* long time.
135   static const char* dh_parameters_pem =
136       "-----BEGIN DH PARAMETERS-----\n"
137       "MIIBCAKCAQEAusYypYO7u8mHelHjpDuUy7hjBgPw/KS03iSRnP5SNMB6OxVFslXv\n"
138       "s6McqEf218Fqpzi18tWA7fq3fvlT+Nx1Tda+Za5C8o5niRYxHks5N+RfnnrFf7vn\n"
139       "0lxrzsXP6es08Ts/UGMsp1nEaCSd/gjDglPgjdC1V/KmBsbT+8IwpbzPPdir0/jA\n"
140       "r+DXssZRZl7JtymGHXPkXTSBhsqSHamfzGRnAQFWToKAinqAdhY7pN/8krwvRj04\n"
141       "VYp84xAy2M6mWWqUm/kokN9QjAiT/DZRxZK8VhY7O9+oATo7/YPCMd9Em417O13k\n"
142       "+F0o/8IMaQvpmtlAsLc2ZKwGqqG+HD2dOwIBAg==\n"
143       "-----END DH PARAMETERS-----";
144   Scoped<BIO> bio(BIO_new_mem_buf(const_cast<char*>(dh_parameters_pem),
145                                   -1)); // Use null terminator for length
146   return PEM_read_bio_DHparams(bio.get(), NULL, NULL, NULL);
147 }
148 
149 } // namespace
150 
generate_key()151 String Ssl::generate_key() {
152   Scoped<EVP_PKEY_CTX> pctx(EVP_PKEY_CTX_new_id(EVP_PKEY_RSA, NULL));
153 
154   EVP_PKEY_keygen_init(pctx.get());
155   EVP_PKEY_CTX_set_rsa_keygen_bits(pctx.get(), 2048);
156 
157   Scoped<EVP_PKEY> pkey;
158   { // Generate RSA key
159     EVP_PKEY* temp = NULL;
160     EVP_PKEY_keygen(pctx.get(), &temp);
161     pkey.reset(temp);
162   }
163 
164   Scoped<BIO> bio(BIO_new(BIO_s_mem()));
165   PEM_write_bio_PrivateKey(bio.get(), pkey.get(), NULL, NULL, 0, NULL, NULL);
166 
167   BUF_MEM* mem = NULL;
168   BIO_get_mem_ptr(bio.get(), &mem);
169   return String(mem->data, mem->length);
170 }
171 
generate_cert(const String & key,String cn,String ca_cert,String ca_key)172 String Ssl::generate_cert(const String& key, String cn, String ca_cert, String ca_key) {
173   // Assign the proper default hostname
174   if (cn.empty()) {
175 #ifdef WIN32
176     char win_hostname[64];
177     gethostname(win_hostname, 64);
178     cn = win_hostname;
179 #else
180     cn = "localhost";
181 #endif
182   }
183 
184   Scoped<EVP_PKEY> pkey(load_private_key(key));
185   if (!pkey) return "";
186 
187   Scoped<X509_REQ> x509_req;
188   if (!ca_cert.empty() && !ca_key.empty()) {
189     x509_req.reset(X509_REQ_new());
190     X509_REQ_set_version(x509_req.get(), 2);
191     X509_REQ_set_pubkey(x509_req.get(), pkey.get());
192 
193     X509_NAME* name = X509_REQ_get_subject_name(x509_req.get());
194     X509_NAME_add_entry_by_txt(name, "C", MBSTRING_ASC,
195                                reinterpret_cast<const unsigned char*>("US"), -1, -1, 0);
196     X509_NAME_add_entry_by_txt(name, "CN", MBSTRING_ASC,
197                                reinterpret_cast<const unsigned char*>(cn.c_str()), -1, -1, 0);
198     X509_REQ_sign(x509_req.get(), pkey.get(), EVP_sha256());
199   }
200 
201   Scoped<X509> x509(X509_new());
202   X509_set_version(x509.get(), 2);
203   ASN1_INTEGER_set(X509_get_serialNumber(x509.get()), 0);
204   X509_gmtime_adj(X509_get_notBefore(x509.get()), 0);
205   X509_gmtime_adj(X509_get_notAfter(x509.get()), static_cast<long>(60 * 60 * 24 * 365));
206   X509_set_pubkey(x509.get(), pkey.get());
207 
208   if (x509_req) {
209     X509_set_subject_name(x509.get(), X509_REQ_get_subject_name(x509_req.get()));
210 
211     Scoped<X509> x509_ca(load_cert(ca_cert));
212     if (!x509_ca) return "";
213     X509_set_issuer_name(x509.get(), X509_get_issuer_name(x509_ca.get()));
214 
215     Scoped<EVP_PKEY> pkey_ca(load_private_key(ca_key));
216     if (!pkey_ca) return "";
217     X509_sign(x509.get(), pkey_ca.get(), EVP_sha256());
218   } else {
219     if (cn == "CA") { // Set the purpose as a CA certificate.
220       X509V3_CTX x509v3_ctx;
221       X509V3_set_ctx_nodb(&x509v3_ctx);
222       X509V3_set_ctx(&x509v3_ctx, x509.get(), x509.get(), NULL, NULL, 0);
223 
224       Scoped<X509_EXTENSION> x509_ex(X509V3_EXT_conf_nid(NULL, &x509v3_ctx, NID_basic_constraints,
225                                                          const_cast<char*>("critical,CA:TRUE")));
226       if (!x509_ex) return "";
227 
228       X509_add_ext(x509.get(), x509_ex.get(), -1);
229     }
230     X509_NAME* name = X509_get_subject_name(x509.get());
231     X509_NAME_add_entry_by_txt(name, "C", MBSTRING_ASC,
232                                reinterpret_cast<const unsigned char*>("US"), -1, -1, 0);
233     X509_NAME_add_entry_by_txt(name, "CN", MBSTRING_ASC,
234                                reinterpret_cast<const unsigned char*>(cn.c_str()), -1, -1, 0);
235     X509_set_issuer_name(x509.get(), name);
236     X509_sign(x509.get(), pkey.get(), EVP_sha256());
237   }
238 
239   String result;
240   { // Write cert into string
241     Scoped<BIO> bio(BIO_new(BIO_s_mem()));
242     PEM_write_bio_X509(bio.get(), x509.get());
243     BUF_MEM* mem = NULL;
244     BIO_get_mem_ptr(bio.get(), &mem);
245     result.append(mem->data, mem->length);
246   }
247 
248   return result;
249 }
250 
251 namespace internal {
252 
253 struct WriteReq {
WriteReqmockssandra::internal::WriteReq254   WriteReq(const char* data, size_t len, ClientConnection* connection)
255       : data(data, len)
256       , connection(connection) {
257     req.data = this;
258   }
259   const String data;
260   ClientConnection* const connection;
261   uv_write_t req;
262 };
263 
Tcp(void * data)264 Tcp::Tcp(void* data) { tcp_.data = data; }
265 
init(uv_loop_t * loop)266 int Tcp::init(uv_loop_t* loop) { return uv_tcp_init(loop, &tcp_); }
267 
bind(const struct sockaddr * addr)268 int Tcp::bind(const struct sockaddr* addr) { return uv_tcp_bind(&tcp_, addr, 0); }
269 
as_handle()270 uv_handle_t* Tcp::as_handle() { return reinterpret_cast<uv_handle_t*>(&tcp_); }
271 
as_stream()272 uv_stream_t* Tcp::as_stream() { return reinterpret_cast<uv_stream_t*>(&tcp_); }
273 
ClientConnection(ServerConnection * server)274 ClientConnection::ClientConnection(ServerConnection* server)
275     : tcp_(this)
276     , server_(server)
277     , ssl_(server->ssl_context() ? SSL_new(server->ssl_context()) : NULL)
278     , incoming_bio_(ssl_ ? BIO_new(BIO_s_mem()) : NULL)
279     , outgoing_bio_(ssl_ ? BIO_new(BIO_s_mem()) : NULL) {
280   tcp_.init(server->loop());
281   if (ssl_) {
282     SSL_set_bio(ssl_, incoming_bio_, outgoing_bio_);
283   }
284 }
285 
~ClientConnection()286 ClientConnection::~ClientConnection() {
287   if (ssl_) SSL_free(ssl_);
288 }
289 
write(const String & data)290 int ClientConnection::write(const String& data) { return write(data.data(), data.length()); }
291 
write(const char * data,size_t len)292 int ClientConnection::write(const char* data, size_t len) {
293   if (ssl_) {
294     return ssl_write(data, len);
295   } else {
296     return internal_write(data, len);
297   }
298 }
299 
close()300 void ClientConnection::close() {
301   if (!uv_is_closing(tcp_.as_handle())) {
302     uv_close(tcp_.as_handle(), on_close);
303   }
304 }
305 
accept()306 int ClientConnection::accept() {
307   int rc = server_->accept(tcp_.as_stream());
308   if (rc != 0) return rc;
309   return uv_read_start(tcp_.as_stream(), on_alloc, on_read);
310 }
311 
sni_server_name() const312 const char* ClientConnection::sni_server_name() const {
313   if (ssl_) {
314     return SSL_get_servername(ssl_, TLSEXT_NAMETYPE_host_name);
315   }
316   return NULL;
317 }
318 
on_close(uv_handle_t * handle)319 void ClientConnection::on_close(uv_handle_t* handle) {
320   ClientConnection* connection = static_cast<ClientConnection*>(handle->data);
321   connection->handle_close();
322 }
323 
handle_close()324 void ClientConnection::handle_close() {
325   on_close();
326   server_->remove(this);
327   delete this;
328 }
329 
on_alloc(uv_handle_t * handle,size_t suggested_size,uv_buf_t * buf)330 void ClientConnection::on_alloc(uv_handle_t* handle, size_t suggested_size, uv_buf_t* buf) {
331   buf->base = static_cast<char*>(Memory::malloc(suggested_size));
332   buf->len = suggested_size;
333 }
334 
on_read(uv_stream_t * stream,ssize_t nread,const uv_buf_t * buf)335 void ClientConnection::on_read(uv_stream_t* stream, ssize_t nread, const uv_buf_t* buf) {
336   ClientConnection* connection = static_cast<ClientConnection*>(stream->data);
337   connection->handle_read(nread, buf);
338   Memory::free(buf->base);
339 }
340 
handle_read(ssize_t nread,const uv_buf_t * buf)341 void ClientConnection::handle_read(ssize_t nread, const uv_buf_t* buf) {
342   if (nread < 0) {
343     if (nread != UV_EOF && nread != UV_ECONNRESET) {
344       fprintf(stderr, "Read failure: %s\n", uv_strerror(nread));
345     }
346     close();
347     return;
348   }
349   if (ssl_) {
350     on_ssl_read(buf->base, nread);
351   } else {
352     on_read(buf->base, nread);
353   }
354 }
355 
on_write(uv_write_t * req,int status)356 void ClientConnection::on_write(uv_write_t* req, int status) {
357   WriteReq* write = static_cast<WriteReq*>(req->data);
358   write->connection->handle_write(status);
359   delete write;
360 }
361 
handle_write(int status)362 void ClientConnection::handle_write(int status) {
363   if (status != 0) {
364     fprintf(stderr, "Write failure: %s\n", uv_strerror(status));
365     close();
366     return;
367   }
368 
369   on_write();
370 }
371 
internal_write(const char * data,size_t len)372 int ClientConnection::internal_write(const char* data, size_t len) {
373   uv_buf_t buf;
374   WriteReq* write = new WriteReq(data, len, this);
375   buf.base = const_cast<char*>(write->data.data());
376   buf.len = write->data.length();
377   int rc = uv_write(&write->req, tcp_.as_stream(), &buf, 1, on_write);
378   if (rc != 0) {
379     delete write;
380   }
381   return rc;
382 }
383 
ssl_write(const char * data,size_t len)384 int ClientConnection::ssl_write(const char* data, size_t len) {
385   if (has_ssl_error(SSL_write(ssl_, data, len))) {
386     return -1;
387   }
388 
389   char buf[SSL_BUF_SIZE];
390   int num_bytes;
391   while ((num_bytes = BIO_read(outgoing_bio_, buf, sizeof(buf))) > 0) {
392     int rc = internal_write(buf, num_bytes);
393     if (rc != 0) {
394       return rc;
395     }
396   }
397 
398   return 0;
399 }
400 
is_handshake_done()401 bool ClientConnection::is_handshake_done() { return SSL_is_init_finished(ssl_) != 0; }
402 
has_ssl_error(int rc)403 bool ClientConnection::has_ssl_error(int rc) {
404   if (rc > 0) return false;
405 
406   int err = SSL_get_error(ssl_, rc);
407   if (err == SSL_ERROR_ZERO_RETURN) {
408     close();
409   } else if (err != SSL_ERROR_WANT_READ && err != SSL_ERROR_NONE) {
410     const char* data;
411     int flags;
412     int err;
413     String error;
414     while ((err = ERR_get_error_line_data(NULL, NULL, &data, &flags)) != 0) {
415       char buf[256];
416       ERR_error_string_n(err, buf, sizeof(buf));
417       if (!error.empty()) error.push_back(',');
418       error.append(buf);
419       if (flags & ERR_TXT_STRING) {
420         error.push_back(':');
421         error.append(data);
422       }
423     }
424     fprintf(stderr, "SSL error: %s\n", error.c_str());
425     close();
426     return true;
427   }
428 
429   return false;
430 }
431 
on_ssl_read(const char * data,size_t len)432 void ClientConnection::on_ssl_read(const char* data, size_t len) {
433   int rc;
434   BIO_write(incoming_bio_, data, len);
435 
436   if (!is_handshake_done()) {
437     int rc = SSL_accept(ssl_);
438     if (has_ssl_error(rc)) {
439       return;
440     }
441 
442     char buf[SSL_BUF_SIZE];
443     bool data_written = false;
444     int num_bytes;
445     while ((num_bytes = BIO_read(outgoing_bio_, buf, sizeof(buf))) > 0) {
446       data_written = true;
447       internal_write(buf, num_bytes);
448     }
449 
450     if (is_handshake_done() && data_written) {
451       return; // Handshake is not completed; ingore remaining data
452     }
453   } else {
454     char buf[SSL_BUF_SIZE];
455     while ((rc = SSL_read(ssl_, buf, sizeof(buf))) > 0) {
456       on_read(buf, rc);
457     }
458     has_ssl_error(rc);
459   }
460 }
461 
ServerConnection(const Address & address,const ClientConnectionFactory & factory)462 ServerConnection::ServerConnection(const Address& address, const ClientConnectionFactory& factory)
463     : tcp_(this)
464     , event_loop_(NULL)
465     , state_(STATE_CLOSED)
466     , rc_(0)
467     , address_(address)
468     , factory_(factory)
469     , ssl_context_(NULL)
470     , connection_attempts_(0) {
471   uv_mutex_init(&mutex_);
472   uv_cond_init(&cond_);
473 }
474 
~ServerConnection()475 ServerConnection::~ServerConnection() {
476   uv_mutex_destroy(&mutex_);
477   uv_cond_destroy(&cond_);
478   if (ssl_context_) {
479     SSL_CTX_free(ssl_context_);
480   }
481 }
482 
loop()483 uv_loop_t* ServerConnection::loop() {
484   ScopedMutex l(&mutex_);
485   return event_loop_->loop();
486 }
487 
use_ssl(const String & key,const String & cert,const String & ca_cert,bool require_client_cert)488 bool ServerConnection::use_ssl(const String& key, const String& cert,
489                                const String& ca_cert /*= ""*/,
490                                bool require_client_cert /*= false*/) {
491   if (ssl_context_) {
492     SSL_CTX_free(ssl_context_);
493   }
494 
495   if ((ssl_context_ = SSL_CTX_new(SSL_SERVER_METHOD())) == NULL) {
496     print_ssl_error();
497     return false;
498   }
499 
500   SSL_CTX_set_default_passwd_cb_userdata(ssl_context_, (void*)"");
501   SSL_CTX_set_default_passwd_cb(ssl_context_, on_password);
502   SSL_CTX_set_verify(ssl_context_, SSL_VERIFY_NONE, NULL);
503 
504   { // Load server certificate
505     Scoped<X509> x509(load_cert(cert));
506     if (!x509) return false;
507     if (SSL_CTX_use_certificate(ssl_context_, x509.get()) <= 0) {
508       print_ssl_error();
509       return false;
510     }
511   }
512 
513   if (!ca_cert.empty()) { // Load CA certificate
514 
515     { // Add CA certificate to chain to send to the client
516       Scoped<X509> x509(load_cert(ca_cert));
517       if (!x509) return false;
518 
519       if (SSL_CTX_add_extra_chain_cert(ssl_context_, x509.release()) <=
520           0) { // Certificate freed by function
521         print_ssl_error();
522         return false;
523       }
524     }
525 
526     if (require_client_cert) {
527       Scoped<X509> x509(load_cert(ca_cert));
528       if (!x509) return false;
529 
530       // Add CA certificate to chain to validate peer certificate
531       X509_STORE* cert_store = SSL_CTX_get_cert_store(ssl_context_);
532       if (X509_STORE_add_cert(cert_store, x509.get()) <= 0) {
533         print_ssl_error();
534         return false;
535       }
536 
537       SSL_CTX_set_verify(ssl_context_, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, NULL);
538     }
539   }
540 
541   Scoped<EVP_PKEY> pkey(load_private_key(key));
542   if (!pkey) return false;
543 
544   if (SSL_CTX_use_PrivateKey(ssl_context_, pkey.get()) <= 0) {
545     print_ssl_error();
546     return false;
547   }
548 
549   Scoped<DH> dh(dh_parameters());
550   if (!dh || !SSL_CTX_set_tmp_dh(ssl_context_, dh.get())) {
551     print_ssl_error();
552     return false;
553   }
554 
555   return true;
556 }
557 
558 using datastax::internal::core::Task;
559 
560 class RunListen : public Task {
561 public:
RunListen(ServerConnection * server)562   RunListen(ServerConnection* server)
563       : server_(server) {}
564 
run(EventLoop * event_loop)565   virtual void run(EventLoop* event_loop) { server_->internal_listen(); }
566 
567 private:
568   ServerConnection* server_;
569 };
570 
571 class RunClose : public Task {
572 public:
RunClose(ServerConnection * server)573   RunClose(ServerConnection* server)
574       : server_(server) {}
575 
run(EventLoop * event_loop)576   virtual void run(EventLoop* event_loop) { server_->internal_close(); }
577 
578 private:
579   ServerConnection* server_;
580 };
581 
582 class RunTask : public Task {
583 public:
RunTask(const ServerConnectionTask::Ptr & task,const ServerConnection::Ptr & connection)584   RunTask(const ServerConnectionTask::Ptr& task, const ServerConnection::Ptr& connection)
585       : task_(task)
586       , connection_(connection) {}
587 
run(EventLoop * event_loop)588   virtual void run(EventLoop* event_loop) { task_->run(connection_.get()); }
589 
590 private:
591   ServerConnectionTask::Ptr task_;
592   ServerConnection::Ptr connection_;
593 };
594 
listen(EventLoopGroup * event_loop_group)595 void ServerConnection::listen(EventLoopGroup* event_loop_group) {
596   ScopedMutex l(&mutex_);
597   if (state_ != STATE_CLOSED) return;
598   rc_ = 0;
599   state_ = STATE_PENDING;
600   event_loop_ = event_loop_group->add(new RunListen(this));
601 }
602 
wait_listen()603 int ServerConnection::wait_listen() {
604   ScopedMutex l(&mutex_);
605   while (state_ == STATE_PENDING) {
606     uv_cond_wait(&cond_, l.get());
607   }
608   return rc_;
609 }
610 
close()611 void ServerConnection::close() {
612   ScopedMutex l(&mutex_);
613   if (state_ != STATE_LISTENING && state_ != STATE_PENDING) return;
614   state_ = STATE_CLOSING;
615   event_loop_->add(new RunClose(this));
616 }
617 
wait_close()618 void ServerConnection::wait_close() {
619   ScopedMutex l(&mutex_);
620   while (state_ == STATE_CLOSING) {
621     uv_cond_wait(&cond_, l.get());
622   }
623 }
624 
connection_attempts() const625 unsigned ServerConnection::connection_attempts() const { return connection_attempts_.load(); }
run(const ServerConnectionTask::Ptr & task)626 void ServerConnection::run(const ServerConnectionTask::Ptr& task) {
627   ScopedMutex l(&mutex_);
628   if (state_ != STATE_LISTENING) return;
629   event_loop_->add(new RunTask(task, Ptr(this)));
630 }
631 
internal_listen()632 void ServerConnection::internal_listen() {
633   int rc = 0;
634 
635   rc = tcp_.init(loop());
636   if (rc != 0) {
637     fprintf(stderr, "Unable to initialize socket\n");
638     signal_listen(rc);
639     return;
640   }
641 
642   inc_ref(); // For the TCP handle
643 
644   Address::SocketStorage storage;
645   rc = tcp_.bind(address_.to_sockaddr(&storage));
646   if (rc != 0) {
647     fprintf(stderr, "Unable to bind address %s\n", address_.to_string(true).c_str());
648     uv_close(tcp_.as_handle(), on_close);
649     signal_listen(rc);
650     return;
651   }
652 
653   rc = uv_listen(tcp_.as_stream(), 128, on_connection);
654   if (rc != 0) {
655     fprintf(stderr, "Unable to listen on address %s\n", address_.to_string(true).c_str());
656     uv_close(tcp_.as_handle(), on_close);
657     signal_listen(rc);
658     return;
659   }
660 
661   signal_listen(rc);
662 }
663 
accept(uv_stream_t * client)664 int ServerConnection::accept(uv_stream_t* client) { return uv_accept(tcp_.as_stream(), client); }
665 
remove(ClientConnection * connection)666 void ServerConnection::remove(ClientConnection* connection) {
667   clients_.erase(std::remove(clients_.begin(), clients_.end(), connection), clients_.end());
668   maybe_close();
669 }
670 
internal_close()671 void ServerConnection::internal_close() {
672   for (ClientConnections::iterator it = clients_.begin(), end = clients_.end(); it != end; ++it) {
673     (*it)->close();
674   }
675   maybe_close();
676 }
677 
maybe_close()678 void ServerConnection::maybe_close() {
679   ScopedMutex l(&mutex_);
680   if (state_ == STATE_CLOSING && clients_.empty() && !uv_is_closing(tcp_.as_handle())) {
681     uv_close(tcp_.as_handle(), on_close);
682   }
683 }
684 
signal_listen(int rc)685 void ServerConnection::signal_listen(int rc) {
686   ScopedMutex l(&mutex_);
687   if (rc != 0) {
688     rc_ = rc;
689     state_ = STATE_CLOSING;
690   } else {
691     state_ = STATE_LISTENING;
692   }
693   uv_cond_signal(&cond_);
694 }
695 
signal_close()696 void ServerConnection::signal_close() {
697   ScopedMutex l(&mutex_);
698   event_loop_ = NULL;
699   state_ = STATE_CLOSED;
700   uv_cond_signal(&cond_);
701 }
702 
on_connection(uv_stream_t * server,int status)703 void ServerConnection::on_connection(uv_stream_t* server, int status) {
704   ServerConnection* self = static_cast<ServerConnection*>(server->data);
705   self->handle_connection(status);
706 }
707 
handle_connection(int status)708 void ServerConnection::handle_connection(int status) {
709   connection_attempts_.fetch_add(1);
710 
711   if (status != 0) {
712     fprintf(stderr, "Listen failure: %s\n", uv_strerror(status));
713     return;
714   }
715 
716   ClientConnection* connection = factory_.create(this);
717   if (connection && connection->on_accept() != 0) {
718     delete connection;
719     return;
720   }
721   clients_.push_back(connection);
722 }
723 
on_close(uv_handle_t * handle)724 void ServerConnection::on_close(uv_handle_t* handle) {
725   ServerConnection* self = static_cast<ServerConnection*>(handle->data);
726   self->handle_close();
727 }
728 
handle_close()729 void ServerConnection::handle_close() {
730   signal_close();
731   dec_ref();
732 }
733 
on_password(char * buf,int size,int rwflag,void * password)734 int ServerConnection::on_password(char* buf, int size, int rwflag, void* password) {
735   strncpy(buf, (char*)(password), size);
736   buf[size - 1] = '\0';
737   return strlen(buf);
738 }
739 
740 } // namespace internal
741 
// Bounds check used by the decode_*() functions below. If `pos` lies beyond
// `end` (a variable that must be in scope at the point of use), prints
// `error` to stderr and makes the *enclosing function* return `end + 1`.
// A returned position greater than `end` therefore marks a failed decode, and
// feeding it into another decode_*() call fails this check again.
#define CHECK(pos, error)                               \
  do {                                                  \
    if ((pos) > end) {                                  \
      fprintf(stderr, "Decoding error: %s\n", (error)); \
      return end + 1;                                   \
    }                                                   \
  } while (0)
749 
decode_int8(const char * input,const char * end,int8_t * value)750 const char* decode_int8(const char* input, const char* end, int8_t* value) {
751   CHECK(input + 1, "Unable to decode byte");
752   *value = static_cast<int8_t>(input[0]);
753   return input + sizeof(int8_t);
754 }
755 
decode_int16(const char * input,const char * end,int16_t * value)756 const char* decode_int16(const char* input, const char* end, int16_t* value) {
757   CHECK(input + 2, "Unable to decode signed short");
758   *value = (static_cast<int16_t>(static_cast<uint8_t>(input[1])) << 0) |
759            (static_cast<int16_t>(static_cast<uint8_t>(input[0])) << 8);
760   return input + sizeof(int16_t);
761 }
762 
decode_uint16(const char * input,const char * end,uint16_t * value)763 const char* decode_uint16(const char* input, const char* end, uint16_t* value) {
764   CHECK(input + 2, "Unable to decode unsigned short");
765   *value = (static_cast<uint16_t>(static_cast<uint8_t>(input[1])) << 0) |
766            (static_cast<uint16_t>(static_cast<uint8_t>(input[0])) << 8);
767   return input + sizeof(uint16_t);
768 }
769 
decode_int32(const char * input,const char * end,int32_t * value)770 const char* decode_int32(const char* input, const char* end, int32_t* value) {
771   CHECK(input + 4, "Unable to decode integer");
772   *value = (static_cast<int32_t>(static_cast<uint8_t>(input[3])) << 0) |
773            (static_cast<int32_t>(static_cast<uint8_t>(input[2])) << 8) |
774            (static_cast<int32_t>(static_cast<uint8_t>(input[1])) << 16) |
775            (static_cast<int32_t>(static_cast<uint8_t>(input[0])) << 24);
776   return input + sizeof(int32_t);
777 }
778 
decode_int64(const char * input,const char * end,int64_t * value)779 const char* decode_int64(const char* input, const char* end, int64_t* value) {
780   CHECK(input + 8, "Unable to decode long");
781   *value = (static_cast<int64_t>(static_cast<uint8_t>(input[7])) << 0) |
782            (static_cast<int64_t>(static_cast<uint8_t>(input[6])) << 8) |
783            (static_cast<int64_t>(static_cast<uint8_t>(input[5])) << 16) |
784            (static_cast<int64_t>(static_cast<uint8_t>(input[4])) << 24) |
785            (static_cast<int64_t>(static_cast<uint8_t>(input[3])) << 32) |
786            (static_cast<int64_t>(static_cast<uint8_t>(input[2])) << 40) |
787            (static_cast<int64_t>(static_cast<uint8_t>(input[1])) << 48) |
788            (static_cast<int64_t>(static_cast<uint8_t>(input[0])) << 56);
789   return input + sizeof(int64_t);
790 }
791 
decode_string(const char * input,const char * end,String * output)792 const char* decode_string(const char* input, const char* end, String* output) {
793   uint16_t len = 0;
794   const char* pos = decode_uint16(input, end, &len);
795   CHECK(pos + len, "Unable to decode string");
796   output->assign(pos, len);
797   return pos + len;
798 }
799 
decode_long_string(const char * input,const char * end,String * output)800 const char* decode_long_string(const char* input, const char* end, String* output) {
801   int32_t len = 0;
802   const char* pos = decode_int32(input, end, &len);
803   CHECK(pos + len, "Unable to decode long string");
804   assert(len >= 0);
805   output->assign(pos, len);
806   return pos + len;
807 }
808 
decode_bytes(const char * input,const char * end,String * output)809 const char* decode_bytes(const char* input, const char* end, String* output) {
810   int32_t len = 0;
811   const char* pos = decode_int32(input, end, &len);
812   if (len > 0) {
813     CHECK(pos + len, "Unable to decode bytes");
814     output->assign(pos, len);
815   }
816   return pos + len;
817 }
818 
decode_uuid(const char * input,CassUuid * output)819 const char* decode_uuid(const char* input, CassUuid* output) {
820   output->time_and_version = static_cast<uint64_t>(static_cast<uint8_t>(input[3]));
821   output->time_and_version |= static_cast<uint64_t>(static_cast<uint8_t>(input[2])) << 8;
822   output->time_and_version |= static_cast<uint64_t>(static_cast<uint8_t>(input[1])) << 16;
823   output->time_and_version |= static_cast<uint64_t>(static_cast<uint8_t>(input[0])) << 24;
824 
825   output->time_and_version |= static_cast<uint64_t>(static_cast<uint8_t>(input[5])) << 32;
826   output->time_and_version |= static_cast<uint64_t>(static_cast<uint8_t>(input[4])) << 40;
827 
828   output->time_and_version |= static_cast<uint64_t>(static_cast<uint8_t>(input[7])) << 48;
829   output->time_and_version |= static_cast<uint64_t>(static_cast<uint8_t>(input[6])) << 56;
830 
831   output->clock_seq_and_node = 0;
832   for (size_t i = 0; i < 8; ++i) {
833     output->clock_seq_and_node |= static_cast<uint64_t>(static_cast<uint8_t>(input[15 - i]))
834                                   << (8 * i);
835   }
836   return input + 16;
837 }
838 
decode_string_map(const char * input,const char * end,Vector<std::pair<String,String>> * output)839 const char* decode_string_map(const char* input, const char* end,
840                               Vector<std::pair<String, String> >* output) {
841 
842   uint16_t len = 0;
843   const char* pos = decode_uint16(input, end, &len);
844   output->reserve(len);
845   for (int i = 0; i < len; ++i) {
846     String key;
847     String value;
848     pos = decode_string(pos, end, &key);
849     pos = decode_string(pos, end, &value);
850     output->push_back(std::pair<String, String>(key, value));
851   }
852   return pos;
853 }
854 
decode_stringlist(const char * input,const char * end,Vector<String> * output)855 const char* decode_stringlist(const char* input, const char* end, Vector<String>* output) {
856   uint16_t len = 0;
857   const char* pos = decode_uint16(input, end, &len);
858   output->reserve(len);
859   for (int i = 0; i < len; ++i) {
860     String value;
861     pos = decode_string(pos, end, &value);
862     output->push_back(value);
863   }
864   return pos;
865 }
866 
decode_values(const char * input,const char * end,Vector<String> * output)867 const char* decode_values(const char* input, const char* end, Vector<String>* output) {
868   uint16_t len = 0;
869   const char* pos = decode_uint16(input, end, &len);
870   output->reserve(len);
871   for (int i = 0; i < len; ++i) {
872     String value;
873     pos = decode_bytes(pos, end, &value);
874     output->push_back(value);
875   }
876   return pos;
877 }
878 
decode_values_with_names(const char * input,const char * end,Vector<String> * names,Vector<String> * values)879 const char* decode_values_with_names(const char* input, const char* end, Vector<String>* names,
880                                      Vector<String>* values) {
881   uint16_t len = 0;
882   const char* pos = decode_uint16(input, end, &len);
883   names->reserve(len);
884   values->reserve(len);
885   for (int i = 0; i < len; ++i) {
886     String name;
887     pos = decode_string(pos, end, &name);
888     names->push_back(name);
889     String value;
890     pos = decode_bytes(pos, end, &value);
891     values->push_back(value);
892   }
893   return pos;
894 }
895 
decode_query_params_v1(const char * input,const char * end,bool is_execute,QueryParameters * params)896 const char* decode_query_params_v1(const char* input, const char* end, bool is_execute,
897                                    QueryParameters* params) {
898   const char* pos = input;
899   if (is_execute) {
900     pos = decode_values(pos, end, &params->values);
901     pos = decode_uint16(pos, end, &params->consistency);
902   } else {
903     pos = decode_uint16(pos, end, &params->consistency);
904   }
905   return pos;
906 }
907 
decode_query_params_v2(const char * input,const char * end,QueryParameters * params)908 const char* decode_query_params_v2(const char* input, const char* end, QueryParameters* params) {
909   int8_t flags = 0;
910   const char* pos = input;
911   pos = decode_uint16(pos, end, &params->consistency);
912   pos = decode_int8(pos, end, &flags);
913   params->flags = flags;
914   if (flags & QUERY_FLAG_VALUES) {
915     pos = decode_values(pos, end, &params->values);
916   }
917   if (flags & QUERY_FLAG_PAGE_SIZE) {
918     pos = decode_int32(pos, end, &params->result_page_size);
919   }
920   if (flags & QUERY_FLAG_PAGE_STATE) {
921     pos = decode_bytes(pos, end, &params->paging_state);
922   }
923   if (flags & QUERY_FLAG_SERIAL_CONSISTENCY) {
924     pos = decode_uint16(pos, end, &params->serial_consistency);
925   }
926   return pos;
927 }
928 
decode_query_params_v3v4(const char * input,const char * end,QueryParameters * params)929 const char* decode_query_params_v3v4(const char* input, const char* end, QueryParameters* params) {
930   int8_t flags = 0;
931   const char* pos = input;
932   pos = decode_uint16(pos, end, &params->consistency);
933   pos = decode_int8(pos, end, &flags);
934   params->flags = flags;
935   if (flags & QUERY_FLAG_VALUES && flags & QUERY_FLAG_NAMES_FOR_VALUES) {
936     pos = decode_values_with_names(pos, end, &params->names, &params->values);
937   } else if (flags & QUERY_FLAG_VALUES) {
938     pos = decode_values(pos, end, &params->values);
939   }
940   if (flags & QUERY_FLAG_PAGE_SIZE) {
941     pos = decode_int32(pos, end, &params->result_page_size);
942   }
943   if (flags & QUERY_FLAG_PAGE_STATE) {
944     pos = decode_bytes(pos, end, &params->paging_state);
945   }
946   if (flags & QUERY_FLAG_SERIAL_CONSISTENCY) {
947     pos = decode_uint16(pos, end, &params->serial_consistency);
948   }
949   if (flags & QUERY_FLAG_TIMESTAMP) {
950     pos = decode_int64(pos, end, &params->timestamp);
951   }
952   return pos;
953 }
954 
decode_query_params_v5(const char * input,const char * end,QueryParameters * params)955 const char* decode_query_params_v5(const char* input, const char* end, QueryParameters* params) {
956   int32_t flags = 0;
957   const char* pos = input;
958   pos = decode_uint16(pos, end, &params->consistency);
959   pos = decode_int32(pos, end, &flags);
960   params->flags = flags;
961   if (flags & QUERY_FLAG_VALUES && flags & QUERY_FLAG_NAMES_FOR_VALUES) {
962     pos = decode_values_with_names(pos, end, &params->names, &params->values);
963   } else if (flags & QUERY_FLAG_VALUES) {
964     pos = decode_values(pos, end, &params->values);
965   }
966   if (flags & QUERY_FLAG_PAGE_SIZE) {
967     pos = decode_int32(pos, end, &params->result_page_size);
968   }
969   if (flags & QUERY_FLAG_PAGE_STATE) {
970     pos = decode_bytes(pos, end, &params->paging_state);
971   }
972   if (flags & QUERY_FLAG_SERIAL_CONSISTENCY) {
973     pos = decode_uint16(pos, end, &params->serial_consistency);
974   }
975   if (flags & QUERY_FLAG_TIMESTAMP) {
976     pos = decode_int64(pos, end, &params->timestamp);
977   }
978   if (flags & QUERY_FLAG_KEYSPACE) {
979     pos = decode_string(pos, end, &params->keyspace);
980   }
981   return pos;
982 }
983 
decode_query_params(int version,const char * input,const char * end,bool is_execute,QueryParameters * params)984 const char* decode_query_params(int version, const char* input, const char* end, bool is_execute,
985                                 QueryParameters* params) {
986   const char* pos = input;
987   if (version == 1) {
988     pos = decode_query_params_v1(pos, end, is_execute, params);
989   } else if (version == 2) {
990     pos = decode_query_params_v2(pos, end, params);
991   } else if (version == 3 || version == 4) {
992     pos = decode_query_params_v3v4(pos, end, params);
993   } else if (version == 5) {
994     pos = decode_query_params_v5(pos, end, params);
995   } else {
996     assert(false && "Unsupported protocol version");
997   }
998   return pos;
999 }
1000 
decode_prepare_params(int version,const char * input,const char * end,PrepareParameters * params)1001 const char* decode_prepare_params(int version, const char* input, const char* end,
1002                                   PrepareParameters* params) {
1003   const char* pos = input;
1004   if (version >= 5) {
1005     pos = decode_int32(pos, end, &params->flags);
1006     if (params->flags & PREPARE_FLAGS_KEYSPACE) {
1007       pos = decode_string(pos, end, &params->keyspace);
1008     }
1009   }
1010   return pos;
1011 }
1012 
encode_int8(int8_t value,String * output)1013 int32_t encode_int8(int8_t value, String* output) {
1014   output->push_back(static_cast<char>(value));
1015   return 1;
1016 }
1017 
encode_int16(int16_t value,String * output)1018 int32_t encode_int16(int16_t value, String* output) {
1019   output->push_back(static_cast<char>(value >> 8));
1020   output->push_back(static_cast<char>(value >> 0));
1021   return 2;
1022 }
1023 
encode_uint16(uint16_t value,String * output)1024 int32_t encode_uint16(uint16_t value, String* output) {
1025   output->push_back(static_cast<char>(value >> 8));
1026   output->push_back(static_cast<char>(value >> 0));
1027   return 2;
1028 }
1029 
encode_int32(int32_t value,String * output)1030 int32_t encode_int32(int32_t value, String* output) {
1031   output->push_back(static_cast<char>(value >> 24));
1032   output->push_back(static_cast<char>(value >> 16));
1033   output->push_back(static_cast<char>(value >> 8));
1034   output->push_back(static_cast<char>(value >> 0));
1035   return 4;
1036 }
1037 
encode_string(const String & value,String * output)1038 int32_t encode_string(const String& value, String* output) {
1039   int32_t size = encode_uint16(value.size(), output) + value.size();
1040   output->append(value);
1041   return size + value.size();
1042 }
1043 
encode_string_list(const Vector<String> & value,String * output)1044 int32_t encode_string_list(const Vector<String>& value, String* output) {
1045   int32_t size = encode_int16(value.size(), output);
1046   for (Vector<String>::const_iterator it = value.begin(), end = value.end(); it != end; ++it) {
1047     size += encode_string(*it, output);
1048   }
1049   return size;
1050 }
1051 
encode_bytes(const String & value,String * output)1052 int32_t encode_bytes(const String& value, String* output) {
1053   int32_t size = encode_int32(value.size(), output) + value.size();
1054   output->append(value);
1055   return size + value.size();
1056 }
1057 
encode_inet(const Address & value,String * output)1058 int32_t encode_inet(const Address& value, String* output) {
1059   uint8_t buf[16];
1060   uint8_t len = value.to_inet(buf);
1061   encode_int8(len, output);
1062   for (uint8_t i = 0; i < len; ++i) {
1063     output->push_back(static_cast<char>(buf[i]));
1064   }
1065   encode_int32(value.port(), output);
1066   return 1 + len + 4;
1067 }
1068 
encode_uuid(CassUuid uuid,String * output)1069 int32_t encode_uuid(CassUuid uuid, String* output) {
1070   uint64_t time_and_version = uuid.time_and_version;
1071   char buf[16];
1072   buf[3] = static_cast<char>(time_and_version & 0x00000000000000FFLL);
1073   time_and_version >>= 8;
1074   buf[2] = static_cast<char>(time_and_version & 0x00000000000000FFLL);
1075   time_and_version >>= 8;
1076   buf[1] = static_cast<char>(time_and_version & 0x00000000000000FFLL);
1077   time_and_version >>= 8;
1078   buf[0] = static_cast<char>(time_and_version & 0x00000000000000FFLL);
1079   time_and_version >>= 8;
1080 
1081   buf[5] = static_cast<char>(time_and_version & 0x00000000000000FFLL);
1082   time_and_version >>= 8;
1083   buf[4] = static_cast<char>(time_and_version & 0x00000000000000FFLL);
1084   time_and_version >>= 8;
1085 
1086   buf[7] = static_cast<char>(time_and_version & 0x00000000000000FFLL);
1087   time_and_version >>= 8;
1088   buf[6] = static_cast<char>(time_and_version & 0x000000000000000FFLL);
1089 
1090   uint64_t clock_seq_and_node = uuid.clock_seq_and_node;
1091   for (size_t i = 0; i < 8; ++i) {
1092     buf[15 - i] = static_cast<char>(clock_seq_and_node & 0x00000000000000FFL);
1093     clock_seq_and_node >>= 8;
1094   }
1095   output->append(buf, 16);
1096   return 16;
1097 }
1098 
encode_string_map(const Map<String,Vector<String>> & value,String * output)1099 int32_t encode_string_map(const Map<String, Vector<String> >& value, String* output) {
1100   int32_t size = encode_uint16(value.size(), output);
1101   for (Map<String, Vector<String> >::const_iterator it = value.begin(); it != value.end(); ++it) {
1102     size += encode_string(it->first, output);
1103     size += encode_string_list(it->second, output);
1104   }
1105   return size;
1106 }
1107 
encode_header(int8_t version,int8_t flags,int16_t stream,int8_t opcode,int32_t len)1108 static String encode_header(int8_t version, int8_t flags, int16_t stream, int8_t opcode,
1109                             int32_t len) {
1110   String header;
1111   encode_int8(0x80 | version, &header);
1112   encode_int8(flags, &header);
1113   if (version >= 3) {
1114     encode_int16(stream, &header);
1115   } else {
1116     encode_int8(stream, &header);
1117   }
1118   encode_int8(opcode, &header);
1119   if (flags & FLAG_TRACING) {
1120     len += 16; // Add enough space for the tracing ID
1121   }
1122   encode_int32(len, &header);
1123   if (flags & FLAG_TRACING) {
1124     UuidGen gen;
1125     CassUuid tracing_id;
1126     gen.generate_random(&tracing_id);
1127     encode_uuid(tracing_id, &header);
1128   }
1129   return header;
1130 }
1131 
text()1132 Type Type::text() { return Type(TYPE_VARCHAR); }
1133 
inet()1134 Type Type::inet() { return Type(TYPE_INET); }
1135 
uuid()1136 Type Type::uuid() { return Type(TYPE_UUID); }
1137 
/**
 * Create a TYPE_LIST type whose element type is `sub_type`.
 */
Type Type::list(const Type& sub_type) {
  Type list_type(TYPE_LIST);
  list_type.types_.push_back(sub_type); // Element type is stored as the first sub-type
  return list_type;
}
1143 
encode(int protocol_version,String * output) const1144 void Type::encode(int protocol_version, String* output) const {
1145   switch (type_) {
1146     case TYPE_VARCHAR:
1147     case TYPE_INET:
1148     case TYPE_UUID:
1149       encode_int16(type_, output);
1150       break;
1151     case TYPE_LIST:
1152       encode_int16(type_, output);
1153       types_[0].encode(protocol_version, output);
1154       break;
1155     default:
1156       assert(false && "Unsupported type");
1157       break;
1158   };
1159 }
1160 
encode(int protocol_version,String * output) const1161 void Column::encode(int protocol_version, String* output) const {
1162   encode_string(name_, output);
1163   type_.encode(protocol_version, output);
1164 }
1165 
encode(int protocol_version,String * output) const1166 void Collection::encode(int protocol_version, String* output) const {
1167   encode_int32(values_.size(), output);
1168   for (Vector<Value>::const_iterator it = values_.begin(), end = values_.end(); it != end; ++it) {
1169     it->encode(protocol_version, output);
1170   }
1171 }
1172 
/** Construct a value of type NUL; no string or collection is allocated. */
Value::Value()
    : type_(NUL) {}
1175 
/** Construct a value of type VALUE holding a heap-allocated copy of `value`. */
Value::Value(const String& value)
    : type_(VALUE)
    , value_(new String(value)) {}
1179 
/** Construct a value of type COLLECTION holding a heap-allocated copy of `collection`. */
Value::Value(const Collection& collection)
    : type_(COLLECTION)
    , collection_(new Collection(collection)) {}
1183 
/**
 * Copy constructor: takes the type of `other` and deep-copies whichever
 * payload (string or collection) that type refers to.
 */
Value::Value(const Value& other)
    : type_(other.type_) {
  switch (type_) {
    case VALUE:
      value_ = new String(*other.value_);
      break;
    case COLLECTION:
      collection_ = new Collection(*other.collection_);
      break;
    default:
      // No payload to copy
      break;
  }
}
1192 
~Value()1193 Value::~Value() {
1194   if (type_ == VALUE) {
1195     delete value_;
1196   } else if (type_ == COLLECTION) {
1197     delete collection_;
1198   }
1199 }
1200 
encode(int protocol_version,String * output) const1201 void Value::encode(int protocol_version, String* output) const {
1202   if (type_ == NUL) {
1203     encode_int32(-1, output);
1204   } else if (type_ == VALUE) {
1205     encode_bytes(*value_, output);
1206   } else if (type_ == COLLECTION) {
1207     String buf;
1208     collection_->encode(protocol_version, &buf);
1209     encode_bytes(buf, output);
1210   }
1211 }
1212 
text(const String & text)1213 Row::Builder& Row::Builder::text(const String& text) {
1214   values_.push_back(Value(text));
1215   return *this;
1216 }
1217 
inet(const Address & inet)1218 Row::Builder& Row::Builder::inet(const Address& inet) {
1219   String value;
1220   uint8_t buf[16];
1221   uint8_t len = inet.to_inet(buf);
1222   for (uint8_t i = 0; i < len; ++i) {
1223     value.push_back(static_cast<char>(buf[i]));
1224   }
1225   values_.push_back(Value(value));
1226   return *this;
1227 }
1228 
uuid(const CassUuid & uuid)1229 Row::Builder& Row::Builder::uuid(const CassUuid& uuid) {
1230   String value;
1231   encode_uuid(uuid, &value);
1232   values_.push_back(Value(value));
1233   return *this;
1234 }
1235 
collection(const Collection & collection)1236 Row::Builder& Row::Builder::collection(const Collection& collection) {
1237   values_.push_back(Value(collection));
1238   return *this;
1239 }
1240 
encode(int protocol_version,String * output) const1241 void Row::encode(int protocol_version, String* output) const {
1242   for (Vector<Value>::const_iterator it = values_.begin(), end = values_.end(); it != end; ++it) {
1243     it->encode(protocol_version, output);
1244   }
1245 }
1246 
encode(int protocol_version) const1247 String ResultSet::encode(int protocol_version) const {
1248   String body;
1249 
1250   encode_int32(RESULT_ROWS, &body); // Result type
1251 
1252   encode_int32(RESULT_FLAG_GLOBAL_TABLESPEC, &body); // Flags
1253   encode_int32(columns_.size(), &body);              // Column count
1254   encode_string(keyspace_name_, &body);              // Global spec keyspace name
1255   encode_string(table_name_, &body);                 // Global spec table name
1256 
1257   // Columns
1258   for (Vector<Column>::const_iterator it = columns_.begin(), end = columns_.end(); it != end;
1259        ++it) {
1260     it->encode(protocol_version, &body);
1261   }
1262 
1263   encode_int32(rows_.size(), &body); // Row count
1264 
1265   // Rows
1266   for (Vector<Row>::const_iterator it = rows_.begin(), end = rows_.end(); it != end; ++it) {
1267     it->encode(protocol_version, &body);
1268   }
1269 
1270   return body;
1271 }
1272 
reset()1273 Action::Builder& Action::Builder::reset() {
1274   first_.reset();
1275   last_ = NULL;
1276   return *this;
1277 }
1278 
execute(Action * action)1279 Action::Builder& Action::Builder::execute(Action* action) {
1280   if (!first_) {
1281     first_.reset(action);
1282   }
1283   if (last_) {
1284     last_->next = action;
1285   }
1286   last_ = action;
1287   return *this;
1288 }
1289 
execute_if(Action * action)1290 Action::Builder& Action::Builder::execute_if(Action* action) {
1291   if (last_ && last_->is_predicate()) {
1292     static_cast<Predicate*>(last_)->then = action;
1293   }
1294   return *this;
1295 }
1296 
nop()1297 Action::Builder& Action::Builder::nop() { return execute(new Nop()); }
1298 
wait(uint64_t timeout)1299 Action::Builder& Action::Builder::wait(uint64_t timeout) { return execute(new Wait(timeout)); }
1300 
close()1301 Action::Builder& Action::Builder::close() { return execute(new Close()); }
1302 
error(int32_t code,const String & message)1303 Action::Builder& Action::Builder::error(int32_t code, const String& message) {
1304   return execute(new SendError(code, message));
1305 }
1306 
invalid_protocol()1307 Action::Builder& Action::Builder::invalid_protocol() {
1308   return error(ERROR_PROTOCOL_ERROR, "Invalid or unsupported protocol version");
1309 }
1310 
invalid_opcode()1311 Action::Builder& Action::Builder::invalid_opcode() {
1312   return error(ERROR_PROTOCOL_ERROR, "Invalid opcode (or not implemented)");
1313 }
1314 
ready()1315 Action::Builder& Action::Builder::ready() { return execute(new SendReady()); }
1316 
authenticate(const String & class_name)1317 Action::Builder& Action::Builder::authenticate(const String& class_name) {
1318   return execute(new SendAuthenticate(class_name));
1319 }
1320 
auth_challenge(const String & token)1321 Action::Builder& Action::Builder::auth_challenge(const String& token) {
1322   return execute(new SendAuthChallenge(token));
1323 }
1324 
auth_success(const String & token)1325 Action::Builder& Action::Builder::auth_success(const String& token) {
1326   return execute(new SendAuthSuccess(token));
1327 }
1328 
supported()1329 Action::Builder& Action::Builder::supported() { return execute(new SendSupported()); }
1330 
up_event(const Address & address)1331 Action::Builder& Action::Builder::up_event(const Address& address) {
1332   return execute(new SendUpEvent(address));
1333 }
1334 
void_result()1335 Action::Builder& Action::Builder::void_result() { return execute(new VoidResult()); }
1336 
empty_rows_result(int32_t row_count)1337 Action::Builder& Action::Builder::empty_rows_result(int32_t row_count) {
1338   return execute(new EmptyRowsResult(row_count));
1339 }
1340 
no_result()1341 Action::Builder& Action::Builder::no_result() { return execute(new NoResult()); }
1342 
match_query(const Matches & matches)1343 Action::Builder& Action::Builder::match_query(const Matches& matches) {
1344   return execute(new MatchQuery(matches));
1345 }
1346 
client_options()1347 Action::Builder& Action::Builder::client_options() { return execute(new ClientOptions()); }
1348 
system_local()1349 Action::Builder& Action::Builder::system_local() { return execute(new SystemLocal()); }
1350 
system_local_dse()1351 Action::Builder& Action::Builder::system_local_dse() { return execute(new SystemLocalDse()); }
1352 
system_peers()1353 Action::Builder& Action::Builder::system_peers() { return execute(new SystemPeers()); }
1354 
system_peers_dse()1355 Action::Builder& Action::Builder::system_peers_dse() { return execute(new SystemPeersDse()); }
1356 
system_traces()1357 Action::Builder& Action::Builder::system_traces() { return execute(new SystemTraces()); }
1358 
use_keyspace(const String & keyspace)1359 Action::Builder& Action::Builder::use_keyspace(const String& keyspace) {
1360   return execute((new UseKeyspace(keyspace)));
1361 }
1362 
use_keyspace(const Vector<String> & keyspaces)1363 Action::Builder& Action::Builder::use_keyspace(const Vector<String>& keyspaces) {
1364   return execute((new UseKeyspace(keyspaces)));
1365 }
1366 
plaintext_auth(const String & username,const String & password)1367 Action::Builder& Action::Builder::plaintext_auth(const String& username, const String& password) {
1368   return execute((new PlaintextAuth(username, password)));
1369 }
1370 
validate_startup()1371 Action::Builder& Action::Builder::validate_startup() { return execute(new ValidateStartup()); }
1372 
validate_credentials()1373 Action::Builder& Action::Builder::validate_credentials() {
1374   return execute(new ValidateCredentials());
1375 }
1376 
validate_auth_response()1377 Action::Builder& Action::Builder::validate_auth_response() {
1378   return execute(new ValidateAuthResponse());
1379 }
1380 
validate_register()1381 Action::Builder& Action::Builder::validate_register() { return execute(new ValidateRegister()); }
1382 
validate_query()1383 Action::Builder& Action::Builder::validate_query() { return execute(new ValidateQuery()); }
1384 
set_registered_for_events()1385 Action::Builder& Action::Builder::set_registered_for_events() {
1386   return execute(new SetRegisteredForEvents());
1387 }
1388 
set_protocol_version()1389 Action::Builder& Action::Builder::set_protocol_version() {
1390   return execute(new SetProtocolVersion());
1391 }
1392 
build()1393 Action* Action::Builder::build() { return first_.release(); }
1394 
is_address(const Address & address)1395 Action::PredicateBuilder Action::Builder::is_address(const Address& address) {
1396   return PredicateBuilder(execute(new IsAddress(address)));
1397 }
1398 
is_address(const String & address,int port)1399 Action::PredicateBuilder Action::Builder::is_address(const String& address, int port) {
1400   return PredicateBuilder(execute(new IsAddress(Address(address, port))));
1401 }
1402 
is_query(const String & query)1403 Action::PredicateBuilder Action::Builder::is_query(const String& query) {
1404   return PredicateBuilder(execute(new IsQuery(query)));
1405 }
1406 
run(Request * request) const1407 void Action::run(Request* request) const { on_run(request); }
1408 
run_next(Request * request) const1409 void Action::run_next(Request* request) const {
1410   if (next) {
1411     next->on_run(request);
1412   }
1413 }
1414 
/**
 * Construct a request from the fields of a received frame.
 *
 * @param version Protocol version of the frame.
 * @param flags Frame flags.
 * @param stream Stream id of the frame.
 * @param opcode Opcode of the frame.
 * @param body Frame body; a copy is stored.
 * @param client Connection the request belongs to; used to write responses.
 */
Request::Request(int8_t version, int8_t flags, int16_t stream, int8_t opcode, const String& body,
                 ClientConnection* client)
    : version_(version)
    , flags_(flags)
    , stream_(stream)
    , opcode_(opcode)
    , body_(body)
    , client_(client)
    , timer_action_(NULL) {
  // Marks flags_ as used; the flags are also passed to encode_header() when a
  // response is written.
  (void)flags_; // TODO: Implement custom payload etc.
}
1426 
write(int8_t opcode,const String & body)1427 void Request::write(int8_t opcode, const String& body) { write(stream_, opcode, body); }
1428 
write(int16_t stream,int8_t opcode,const String & body)1429 void Request::write(int16_t stream, int8_t opcode, const String& body) {
1430   client_->write(encode_header(version_, flags_, stream, opcode, body.size()) + body);
1431 }
1432 
error(int32_t code,const String & message)1433 void Request::error(int32_t code, const String& message) {
1434   String body;
1435   encode_int32(code, &body);
1436   encode_string(message, &body);
1437   write(OPCODE_ERROR, body);
1438 }
1439 
wait(uint64_t timeout,const Action * action)1440 void Request::wait(uint64_t timeout, const Action* action) {
1441   inc_ref();
1442   timer_action_ = action;
1443   timer_.start(client_->server()->loop(), timeout, bind_callback(&Request::on_timeout, this));
1444 }
1445 
close()1446 void Request::close() { client_->close(); }
1447 
decode_startup(Options * options)1448 bool Request::decode_startup(Options* options) {
1449   return decode_string_map(start(), end(), options) == end();
1450 }
1451 
decode_credentials(Credentials * creds)1452 bool Request::decode_credentials(Credentials* creds) {
1453   return decode_string_map(start(), end(), creds) == end();
1454 }
1455 
decode_auth_response(String * token)1456 bool Request::decode_auth_response(String* token) {
1457   return decode_bytes(start(), end(), token) == end();
1458 }
1459 
decode_register(EventTypes * types)1460 bool Request::decode_register(EventTypes* types) {
1461   return decode_stringlist(start(), end(), types) == end();
1462 }
1463 
decode_query(String * query,QueryParameters * params)1464 bool Request::decode_query(String* query, QueryParameters* params) {
1465   return decode_query_params(version_, decode_long_string(start(), end(), query), end(), false,
1466                              params) == end();
1467 }
1468 
decode_execute(String * id,QueryParameters * params)1469 bool Request::decode_execute(String* id, QueryParameters* params) {
1470   return decode_query_params(version_, decode_string(start(), end(), id), end(), true, params) ==
1471          end();
1472 }
1473 
decode_prepare(String * query,PrepareParameters * params)1474 bool Request::decode_prepare(String* query, PrepareParameters* params) {
1475   return decode_prepare_params(version_, decode_long_string(start(), end(), query), end(),
1476                                params) == end();
1477 }
1478 
address() const1479 const Address& Request::address() const { return client_->server()->address(); }
1480 
host(const Address & address) const1481 const Host& Request::host(const Address& address) const {
1482   return client_->cluster()->host(address);
1483 }
1484 
hosts() const1485 Hosts Request::hosts() const { return client_->cluster()->hosts(); }
1486 
on_timeout(Timer * timer)1487 void Request::on_timeout(Timer* timer) {
1488   timer_action_->run_next(this);
1489   dec_ref();
1490 }
1491 
on_run(Request * request) const1492 void SendError::on_run(Request* request) const { request->error(code, message); }
1493 
on_run(Request * request) const1494 void SendReady::on_run(Request* request) const { request->write(OPCODE_READY, String()); }
1495 
on_run(Request * request) const1496 void SendAuthenticate::on_run(Request* request) const {
1497   String body;
1498   encode_string(class_name, &body);
1499   request->write(OPCODE_AUTHENTICATE, body);
1500 }
1501 
on_run(Request * request) const1502 void SendAuthChallenge::on_run(Request* request) const {
1503   String body;
1504   encode_string(token, &body);
1505   request->write(OPCODE_AUTH_CHALLENGE, body);
1506 }
1507 
on_run(Request * request) const1508 void SendAuthSuccess::on_run(Request* request) const {
1509   String body;
1510   encode_string(token, &body);
1511   request->write(OPCODE_AUTH_SUCCESS, body);
1512 }
1513 
on_run(Request * request) const1514 void SendSupported::on_run(Request* request) const {
1515   String body;
1516   encode_uint16(0, &body);
1517   request->write(OPCODE_SUPPORTED, body);
1518 }
1519 
on_run(Request * request) const1520 void SendUpEvent::on_run(Request* request) const {
1521   request->write(-1, OPCODE_EVENT, StatusChangeEvent::encode(StatusChangeEvent::UP, address));
1522   run_next(request);
1523 }
1524 
on_run(Request * request) const1525 void VoidResult::on_run(Request* request) const {
1526   String body;
1527   encode_int32(RESULT_VOID, &body);
1528   request->write(OPCODE_RESULT, body);
1529 }
1530 
on_run(Request * request) const1531 void EmptyRowsResult::on_run(Request* request) const {
1532   String query;
1533   QueryParameters params;
1534   if (!request->decode_query(&query, &params)) {
1535     request->error(ERROR_PROTOCOL_ERROR, "Invalid query message");
1536   } else {
1537     String body;
1538     encode_int32(RESULT_ROWS, &body);
1539     encode_int32(0, &body);         // Flags
1540     encode_int32(0, &body);         // Column count
1541     encode_int32(row_count, &body); // Row count
1542     request->write(OPCODE_RESULT, body);
1543   }
1544 }
1545 
on_run(Request * request) const1546 void NoResult::on_run(Request* request) const {}
1547 
on_run(Request * request) const1548 void MatchQuery::on_run(Request* request) const {
1549   String query;
1550   QueryParameters params;
1551   if (!request->decode_query(&query, &params)) {
1552     request->error(ERROR_PROTOCOL_ERROR, "Invalid query message");
1553     return;
1554   } else {
1555     for (Matches::const_iterator it = matches.begin(), end = matches.end(); it != end; ++it) {
1556       if (it->first == query) {
1557         request->write(OPCODE_RESULT, it->second.encode(request->version()));
1558         return;
1559       }
1560     }
1561   }
1562   run_next(request);
1563 }
1564 
on_run(Request * request) const1565 void ClientOptions::on_run(Request* request) const {
1566   String query;
1567   QueryParameters params;
1568   if (!request->decode_query(&query, &params)) {
1569     request->error(ERROR_PROTOCOL_ERROR, "Invalid query message");
1570   } else if (query == CLIENT_OPTIONS_QUERY) {
1571     ResultSet::Builder builder("client", "options");
1572     Row::Builder row_builder;
1573     for (Options::const_iterator it = request->client()->options().begin(),
1574                                  end = request->client()->options().end();
1575          it != end; ++it) {
1576       builder.column((*it).first, Type::text());
1577       row_builder.text((*it).second);
1578     }
1579     ResultSet client_options = builder.row(row_builder.build()).build();
1580 
1581     request->write(OPCODE_RESULT, client_options.encode(request->version()));
1582   } else {
1583     run_next(request);
1584   }
1585 }
1586 
on_run(Request * request) const1587 void SystemLocal::on_run(Request* request) const {
1588   String query;
1589   QueryParameters params;
1590   if (!request->decode_query(&query, &params)) {
1591     request->error(ERROR_PROTOCOL_ERROR, "Invalid query message");
1592   } else if (query.find(SELECT_LOCAL) != String::npos) {
1593     const Host& host(request->host(request->address()));
1594 
1595     ResultSet local_rs = ResultSet::Builder("system", "local")
1596                              .column("key", Type::text())
1597                              .column("data_center", Type::text())
1598                              .column("rack", Type::text())
1599                              .column("release_version", Type::text())
1600                              .column("rpc_address", Type::inet())
1601                              .column("partitioner", Type::text())
1602                              .column("tokens", Type::list(Type::text()))
1603                              .row(Row::Builder()
1604                                       .text(request->client()->server()->address().to_string())
1605                                       .text(host.dc)
1606                                       .text(host.rack)
1607                                       .text(CASSANDRA_VERSION)
1608                                       .inet(request->client()->server()->address())
1609                                       .text(host.partitioner)
1610                                       .collection(Collection::text(host.tokens))
1611                                       .build())
1612                              .build();
1613 
1614     request->write(OPCODE_RESULT, local_rs.encode(request->version()));
1615   } else {
1616     run_next(request);
1617   }
1618 }
1619 
on_run(Request * request) const1620 void SystemLocalDse::on_run(Request* request) const {
1621   String query;
1622   QueryParameters params;
1623   if (!request->decode_query(&query, &params)) {
1624     request->error(ERROR_PROTOCOL_ERROR, "Invalid query message");
1625   } else if (query.find(SELECT_LOCAL) != String::npos) {
1626     const Host& host(request->host(request->address()));
1627 
1628     ResultSet local_rs = ResultSet::Builder("system", "local")
1629                              .column("key", Type::text())
1630                              .column("data_center", Type::text())
1631                              .column("rack", Type::text())
1632                              .column("dse_version", Type::text())
1633                              .column("release_version", Type::text())
1634                              .column("rpc_address", Type::inet())
1635                              .column("partitioner", Type::text())
1636                              .column("tokens", Type::list(Type::text()))
1637                              .row(Row::Builder()
1638                                       .text(request->client()->server()->address().to_string())
1639                                       .text(host.dc)
1640                                       .text(host.rack)
1641                                       .text(DSE_VERSION)
1642                                       .text(DSE_CASSANDRA_VERSION)
1643                                       .inet(request->client()->server()->address())
1644                                       .text(host.partitioner)
1645                                       .collection(Collection::text(host.tokens))
1646                                       .build())
1647                              .build();
1648 
1649     request->write(OPCODE_RESULT, local_rs.encode(request->version()));
1650   } else {
1651     run_next(request);
1652   }
1653 }
1654 
on_run(Request * request) const1655 void SystemPeers::on_run(Request* request) const {
1656   String query;
1657   QueryParameters params;
1658   if (!request->decode_query(&query, &params)) {
1659     request->error(ERROR_PROTOCOL_ERROR, "Invalid query message");
1660   } else if (query.find(SELECT_PEERS) != String::npos) {
1661     const String where_clause(" WHERE peer = '");
1662     ResultSet::Builder peers_builder = ResultSet::Builder("system", "peers")
1663                                            .column("peer", Type::inet())
1664                                            .column("data_center", Type::text())
1665                                            .column("rack", Type::text())
1666                                            .column("release_version", Type::text())
1667                                            .column("rpc_address", Type::inet())
1668                                            .column("tokens", Type::list(Type::text()));
1669 
1670     size_t pos = query.find(where_clause);
1671     if (pos == String::npos) {
1672       Hosts hosts(request->hosts());
1673       for (Hosts::const_iterator it = hosts.begin(), end = hosts.end(); it != end; ++it) {
1674         const Host& host(*it);
1675         if (host.address == request->address()) {
1676           continue;
1677         }
1678         peers_builder.row(Row::Builder()
1679                               .inet(host.address)
1680                               .text(host.dc)
1681                               .text(host.rack)
1682                               .text(CASSANDRA_VERSION)
1683                               .inet(host.address)
1684                               .collection(Collection::text(host.tokens))
1685                               .build());
1686       }
1687       ResultSet peers_rs = peers_builder.build();
1688       request->write(OPCODE_RESULT, peers_rs.encode(request->version()));
1689     } else {
1690       pos += where_clause.size();
1691       size_t end_pos = query.find("'", pos);
1692       if (end_pos == String::npos) {
1693         request->error(ERROR_INVALID_QUERY, "Invalid WHERE clause");
1694         return;
1695       }
1696 
1697       String ip = query.substr(pos, end_pos - pos);
1698       Address address(ip, request->address().port());
1699       if (!address.is_valid_and_resolved()) {
1700         request->error(ERROR_INVALID_QUERY, "Invalid inet address in WHERE clause");
1701         return;
1702       }
1703 
1704       const Host& host(request->host(address));
1705       ResultSet peers_rs = peers_builder
1706                                .row(Row::Builder()
1707                                         .inet(host.address)
1708                                         .text(host.dc)
1709                                         .text(host.rack)
1710                                         .text(CASSANDRA_VERSION)
1711                                         .inet(host.address)
1712                                         .collection(Collection::text(host.tokens))
1713                                         .build())
1714                                .build();
1715       request->write(OPCODE_RESULT, peers_rs.encode(request->version()));
1716     }
1717   } else {
1718     run_next(request);
1719   }
1720 }
1721 
on_run(Request * request) const1722 void SystemPeersDse::on_run(Request* request) const {
1723   String query;
1724   QueryParameters params;
1725   if (!request->decode_query(&query, &params)) {
1726     request->error(ERROR_PROTOCOL_ERROR, "Invalid query message");
1727   } else if (query.find(SELECT_PEERS) != String::npos) {
1728     const String where_clause(" WHERE peer = '");
1729     ResultSet::Builder peers_builder = ResultSet::Builder("system", "peers")
1730                                            .column("peer", Type::inet())
1731                                            .column("data_center", Type::text())
1732                                            .column("rack", Type::text())
1733                                            .column("dse_version", Type::text())
1734                                            .column("release_version", Type::text())
1735                                            .column("rpc_address", Type::inet())
1736                                            .column("tokens", Type::list(Type::text()));
1737 
1738     size_t pos = query.find(where_clause);
1739     if (pos == String::npos) {
1740       Hosts hosts(request->hosts());
1741       for (Hosts::const_iterator it = hosts.begin(), end = hosts.end(); it != end; ++it) {
1742         const Host& host(*it);
1743         if (host.address == request->address()) {
1744           continue;
1745         }
1746         peers_builder.row(Row::Builder()
1747                               .inet(host.address)
1748                               .text(host.dc)
1749                               .text(host.rack)
1750                               .text(DSE_VERSION)
1751                               .text(DSE_CASSANDRA_VERSION)
1752                               .inet(host.address)
1753                               .collection(Collection::text(host.tokens))
1754                               .build());
1755       }
1756       ResultSet peers_rs = peers_builder.build();
1757       request->write(OPCODE_RESULT, peers_rs.encode(request->version()));
1758     } else {
1759       pos += where_clause.size();
1760       size_t end_pos = query.find("'", pos);
1761       if (end_pos == String::npos) {
1762         request->error(ERROR_INVALID_QUERY, "Invalid WHERE clause");
1763         return;
1764       }
1765 
1766       String ip = query.substr(pos, end_pos - pos);
1767       Address address(ip, request->address().port());
1768       if (!address.is_valid_and_resolved()) {
1769         request->error(ERROR_INVALID_QUERY, "Invalid inet address in WHERE clause");
1770         return;
1771       }
1772 
1773       const Host& host(request->host(address));
1774       ResultSet peers_rs = peers_builder
1775                                .row(Row::Builder()
1776                                         .inet(host.address)
1777                                         .text(host.dc)
1778                                         .text(host.rack)
1779                                         .text(CASSANDRA_VERSION)
1780                                         .inet(host.address)
1781                                         .collection(Collection::text(host.tokens))
1782                                         .build())
1783                                .build();
1784       request->write(OPCODE_RESULT, peers_rs.encode(request->version()));
1785     }
1786   } else {
1787     run_next(request);
1788   }
1789 }
1790 
on_run(Request * request) const1791 void SystemTraces::on_run(Request* request) const {
1792   String query;
1793   QueryParameters params;
1794   if (!request->decode_query(&query, &params)) {
1795     request->error(ERROR_PROTOCOL_ERROR, "Invalid query message");
1796   } else if (query.find(SELECT_TRACES_SESSION) != String::npos) {
1797     if (params.values.empty() || params.values.front().size() < 16) {
1798       request->error(ERROR_INVALID_QUERY, "Query expects a UUID parameter (tracing)");
1799       return;
1800     }
1801     CassUuid tracing_id;
1802     decode_uuid(params.values.front().data(), &tracing_id);
1803     ResultSet session_rs = ResultSet::Builder("system_traces", "session")
1804                                .column("session_id", Type::uuid())
1805                                .row(Row::Builder().uuid(tracing_id).build())
1806                                .build();
1807     request->write(OPCODE_RESULT, session_rs.encode(request->version()));
1808   } else {
1809     run_next(request);
1810   }
1811 }
1812 
on_run(Request * request) const1813 void UseKeyspace::on_run(Request* request) const {
1814   String query;
1815   QueryParameters params;
1816   if (request->decode_query(&query, &params)) {
1817     trim(query);
1818     if (query.compare(0, 3, "USE") == 0 || query.compare(0, 3, "use") == 0) {
1819       String keyspace(query.substr(query.find_first_not_of(" \t", 3)));
1820       for (Vector<String>::const_iterator it = keyspaces.begin(), end = keyspaces.end(); it != end;
1821            ++it) {
1822         String temp(*it);
1823         if (keyspace == escape_id(temp)) {
1824           String body;
1825           encode_int32(RESULT_SET_KEYSPACE, &body);
1826           encode_string(*it, &body);
1827           request->client()->set_keyspace(*it);
1828           request->write(OPCODE_RESULT, body);
1829           return;
1830         }
1831       }
1832       request->error(ERROR_INVALID_QUERY, "Keyspace '" + keyspace + "' does not exist");
1833     } else {
1834       run_next(request);
1835     }
1836   } else {
1837     request->error(ERROR_PROTOCOL_ERROR, "Invalid query message");
1838   }
1839 }
1840 
on_run(Request * request) const1841 void PlaintextAuth::on_run(Request* request) const {
1842   String token;
1843   if (request->decode_auth_response(&token)) {
1844     String username, password;
1845     String::const_reverse_iterator last = token.rbegin();
1846     enum { PASSWORD, USERNAME } state = PASSWORD;
1847     for (String::const_reverse_iterator it = token.rbegin(), end = token.rend(); it != end; ++it) {
1848       if (*it == '\0') {
1849         if (state == PASSWORD) {
1850           password.assign(it.base(), last.base());
1851           state = USERNAME;
1852         } else if (state == USERNAME) {
1853           username.assign(it.base(), last.base());
1854           break;
1855         }
1856         last = it + 1;
1857       }
1858     }
1859 
1860     if (username == this->username && password == this->password) {
1861       String body;
1862       encode_int32(-1, &body); // Null bytes
1863       request->write(OPCODE_AUTH_SUCCESS, body);
1864     } else {
1865       request->error(ERROR_BAD_CREDENTIALS, "Invalid credentials");
1866     }
1867   } else {
1868     request->error(ERROR_PROTOCOL_ERROR, "Invalid auth response message");
1869   }
1870 }
1871 
on_run(Request * request) const1872 void ValidateStartup::on_run(Request* request) const {
1873   Options options;
1874   if (!request->decode_startup(&options)) {
1875     request->error(ERROR_PROTOCOL_ERROR, "Invalid startup message");
1876   } else {
1877     request->client()->set_options(options);
1878     run_next(request);
1879   }
1880 }
1881 
on_run(Request * request) const1882 void ValidateCredentials::on_run(Request* request) const {
1883   Credentials creds;
1884   if (!request->decode_credentials(&creds)) {
1885     request->error(ERROR_PROTOCOL_ERROR, "Invalid credentials message");
1886   } else {
1887     run_next(request);
1888   }
1889 }
1890 
on_run(Request * request) const1891 void ValidateAuthResponse::on_run(Request* request) const {
1892   String token;
1893   if (!request->decode_auth_response(&token)) {
1894     request->error(ERROR_PROTOCOL_ERROR, "Invalid auth response message");
1895   } else {
1896     run_next(request);
1897   }
1898 }
1899 
on_run(Request * request) const1900 void ValidateRegister::on_run(Request* request) const {
1901   EventTypes types;
1902   if (!request->decode_register(&types)) {
1903     request->error(ERROR_PROTOCOL_ERROR, "Invalid register message");
1904   } else {
1905     run_next(request);
1906   }
1907 }
1908 
on_run(Request * request) const1909 void ValidateQuery::on_run(Request* request) const {
1910   String query;
1911   QueryParameters params;
1912   if (!request->decode_query(&query, &params)) {
1913     request->error(ERROR_PROTOCOL_ERROR, "Invalid query message");
1914   } else {
1915     run_next(request);
1916   }
1917 }
1918 
// Flags the requesting client as registered for events, then continues with
// the next action.
void SetRegisteredForEvents::on_run(Request* request) const {
  request->client()->set_registered_for_events();
  run_next(request);
}
1923 
// Records the protocol version of this request on the client, then continues
// with the next action.
void SetProtocolVersion::on_run(Request* request) const {
  request->client()->set_protocol_version(request->version());
  run_next(request);
}
1928 
is_true(Request * request) const1929 bool IsAddress::is_true(Request* request) const {
1930   return request->client()->server()->address() == address;
1931 }
1932 
is_true(Request * request) const1933 bool IsQuery::is_true(Request* request) const {
1934   String query;
1935   QueryParameters params;
1936   return request->decode_query(&query, &params) && query == this->query;
1937 }
1938 
// Builds the handlers for invalid protocol versions and invalid opcodes from
// `builder` and records the inclusive range of supported protocol versions.
// The per-opcode actions are not set here.
RequestHandler::RequestHandler(RequestHandler::Builder* builder,
                               int lowest_supported_protocol_version,
                               int highest_supported_protocol_version)
    : invalid_protocol_(builder->on_invalid_protocol().build())
    , invalid_opcode_(builder->on_invalid_opcode().build())
    , lowest_supported_protocol_version_(lowest_supported_protocol_version)
    , highest_supported_protocol_version_(highest_supported_protocol_version) {}
1946 
build()1947 const RequestHandler* RequestHandler::Builder::build() {
1948   RequestHandler* handler(new RequestHandler(this, lowest_supported_protocol_version_,
1949                                              highest_supported_protocol_version_));
1950 
1951   handler->actions_[OPCODE_STARTUP].reset(actions_[OPCODE_STARTUP].build());
1952   handler->actions_[OPCODE_OPTIONS].reset(actions_[OPCODE_OPTIONS].build());
1953   handler->actions_[OPCODE_CREDENTIALS].reset(actions_[OPCODE_CREDENTIALS].build());
1954   handler->actions_[OPCODE_QUERY].reset(actions_[OPCODE_QUERY].build());
1955   handler->actions_[OPCODE_PREPARE].reset(actions_[OPCODE_PREPARE].build());
1956   handler->actions_[OPCODE_EXECUTE].reset(actions_[OPCODE_EXECUTE].build());
1957   handler->actions_[OPCODE_REGISTER].reset(actions_[OPCODE_REGISTER].build());
1958   handler->actions_[OPCODE_AUTH_RESPONSE].reset(actions_[OPCODE_AUTH_RESPONSE].build());
1959 
1960   return handler;
1961 }
1962 
decode(ClientConnection * client,const char * data,int32_t len)1963 void ProtocolHandler::decode(ClientConnection* client, const char* data, int32_t len) {
1964   buffer_.append(data, len);
1965   int32_t result = decode_frame(client, buffer_.data(), buffer_.size());
1966   if (result > 0) {
1967     if (static_cast<size_t>(result) == buffer_.size()) {
1968       buffer_.clear();
1969     } else {
1970       // Not efficient, but concise. Copy the consumed part of the buffer
1971       // forward then resize the buffer to what's left over.
1972       std::copy(buffer_.begin() + result, buffer_.end(), buffer_.begin());
1973       buffer_.resize(buffer_.size() - result);
1974     }
1975   }
1976 }
1977 
decode_frame(ClientConnection * client,const char * frame,int32_t len)1978 int32_t ProtocolHandler::decode_frame(ClientConnection* client, const char* frame, int32_t len) {
1979   const char* pos = frame;
1980   const char* end = pos + len;
1981   int32_t remaining = len;
1982 
1983   while (remaining > 0) {
1984     switch (state_) {
1985       case PROTOCOL_VERSION:
1986         // Version requires a single byte and that's guaranteed by the loop check.
1987         version_ = *pos++;
1988         remaining--;
1989         if (version_ < request_handler_->lowest_supported_protocol_version() ||
1990             version_ > request_handler_->highest_supported_protocol_version()) {
1991           // Use the highest supported protocol version unless it's less than the lowest supported
1992           // then use the original request's protocol version.
1993           int8_t response_version = request_handler_->highest_supported_protocol_version();
1994           if (version_ < request_handler_->lowest_supported_protocol_version()) {
1995             response_version = version_;
1996           }
1997           Request::Ptr request(
1998               new Request(response_version, flags_, stream_, opcode_, String(), client));
1999           request_handler_->invalid_protocol(request.get());
2000           client->close();
2001           return -1;
2002         }
2003         state_ = HEADER;
2004         break;
2005       case HEADER:
2006         if ((version_ == 1 || version_ == 2) && remaining >= 7) {
2007           flags_ = *pos++;
2008           stream_ = *pos++;
2009           opcode_ = *pos++;
2010           pos = decode_int32(pos, end, &length_);
2011           remaining -= 7;
2012         } else if (version_ >= 3 && remaining >= 8) {
2013           flags_ = *pos++;
2014           pos = decode_int16(pos, end, &stream_);
2015           opcode_ = *pos++;
2016           pos = decode_int32(pos, end, &length_);
2017           remaining -= 8;
2018         } else {
2019           return len - remaining;
2020         }
2021 
2022         if (length_ == 0) {
2023           decode_body(client, pos, 0);
2024           version_ = 0;
2025           flags_ = 0;
2026           opcode_ = 0;
2027           length_ = 0;
2028           state_ = PROTOCOL_VERSION;
2029         } else {
2030           state_ = BODY;
2031         }
2032         break;
2033       case BODY:
2034         if (remaining >= length_) {
2035           decode_body(client, pos, length_);
2036           pos += length_;
2037           remaining -= length_;
2038         } else {
2039           return len - remaining;
2040         }
2041         version_ = 0;
2042         flags_ = 0;
2043         opcode_ = 0;
2044         length_ = 0;
2045         state_ = PROTOCOL_VERSION;
2046         break;
2047     }
2048   }
2049 
2050   return len; // All bytes have been consumed
2051 }
2052 
decode_body(ClientConnection * client,const char * body,int32_t len)2053 void ProtocolHandler::decode_body(ClientConnection* client, const char* body, int32_t len) {
2054   Request::Ptr request(new Request(version_, flags_, stream_, opcode_, String(body, len), client));
2055   request_handler_->run(request.get());
2056 }
2057 
// Forwards bytes read from the client to the protocol handler for decoding.
void ClientConnection::on_read(const char* data, size_t len) { handler_.decode(this, data, len); }
2059 
// Stores the already encoded event body; run() prefixes it with a frame
// header when sending it to clients.
Event::Event(const String& event_body)
    : event_body_(event_body) {}
2062 
run(internal::ServerConnection * server_connection)2063 void Event::run(internal::ServerConnection* server_connection) {
2064   for (internal::ClientConnections::const_iterator it = server_connection->clients().begin(),
2065                                                    end = server_connection->clients().end();
2066        it != end; ++it) {
2067     ClientConnection* client = static_cast<ClientConnection*>(*it);
2068     if (client->is_registered_for_events() && client->protocol_version() > 0) {
2069       client->write(
2070           encode_header(client->protocol_version(), 0, -1, OPCODE_EVENT, event_body_.size()) +
2071           event_body_);
2072     }
2073   }
2074 }
2075 
new_node(const Address & address)2076 Event::Ptr TopologyChangeEvent::new_node(const Address& address) {
2077   return Ptr(new TopologyChangeEvent(NEW_NODE, address));
2078 }
2079 
moved_node(const Address & address)2080 Event::Ptr TopologyChangeEvent::moved_node(const Address& address) {
2081   return Ptr(new TopologyChangeEvent(MOVED_NODE, address));
2082 }
2083 
removed_node(const Address & address)2084 Event::Ptr TopologyChangeEvent::removed_node(const Address& address) {
2085   return Ptr(new TopologyChangeEvent(REMOVED_NODE, address));
2086 }
2087 
encode(TopologyChangeEvent::Type type,const Address & address)2088 String TopologyChangeEvent::encode(TopologyChangeEvent::Type type, const Address& address) {
2089   String body;
2090   encode_string("TOPOLOGY_CHANGE", &body);
2091   switch (type) {
2092     case NEW_NODE:
2093       encode_string("NEW_NODE", &body);
2094       break;
2095     case MOVED_NODE:
2096       encode_string("MOVED_NODE", &body);
2097       break;
2098     case REMOVED_NODE:
2099       encode_string("REMOVED_NODE", &body);
2100       break;
2101   };
2102   encode_inet(address, &body);
2103   return body;
2104 }
2105 
up(const Address & address)2106 Event::Ptr StatusChangeEvent::up(const Address& address) {
2107   return Ptr(new StatusChangeEvent(UP, address));
2108 }
2109 
down(const Address & address)2110 Event::Ptr StatusChangeEvent::down(const Address& address) {
2111   return Ptr(new StatusChangeEvent(DOWN, address));
2112 }
2113 
encode(Type type,const Address & address)2114 String StatusChangeEvent::encode(Type type, const Address& address) {
2115   String body;
2116   encode_string("STATUS_CHANGE", &body);
2117   switch (type) {
2118     case UP:
2119       encode_string("UP", &body);
2120       break;
2121     case DOWN:
2122       encode_string("DOWN", &body);
2123       break;
2124   };
2125   encode_inet(address, &body);
2126   return body;
2127 }
2128 
keyspace(SchemaChangeEvent::Type type,const String & keyspace_name)2129 Event::Ptr SchemaChangeEvent::keyspace(SchemaChangeEvent::Type type, const String& keyspace_name) {
2130   return Ptr(new SchemaChangeEvent(KEYSPACE, type, keyspace_name));
2131 }
2132 
table(SchemaChangeEvent::Type type,const String & keyspace_name,const String & table_name)2133 Event::Ptr SchemaChangeEvent::table(SchemaChangeEvent::Type type, const String& keyspace_name,
2134                                     const String& table_name) {
2135   return Ptr(new SchemaChangeEvent(TABLE, type, keyspace_name, table_name));
2136 }
2137 
user_type(SchemaChangeEvent::Type type,const String & keyspace_name,const String & user_type_name)2138 Event::Ptr SchemaChangeEvent::user_type(SchemaChangeEvent::Type type, const String& keyspace_name,
2139                                         const String& user_type_name) {
2140   return Ptr(new SchemaChangeEvent(USER_TYPE, type, keyspace_name, user_type_name));
2141 }
2142 
function(SchemaChangeEvent::Type type,const String & keyspace_name,const String & function_name,const Vector<String> & args_types)2143 Event::Ptr SchemaChangeEvent::function(SchemaChangeEvent::Type type, const String& keyspace_name,
2144                                        const String& function_name,
2145                                        const Vector<String>& args_types) {
2146   return Ptr(new SchemaChangeEvent(FUNCTION, type, keyspace_name, function_name, args_types));
2147 }
2148 
aggregate(SchemaChangeEvent::Type type,const String & keyspace_name,const String & aggregate_name,const Vector<String> & args_types)2149 Event::Ptr SchemaChangeEvent::aggregate(SchemaChangeEvent::Type type, const String& keyspace_name,
2150                                         const String& aggregate_name,
2151                                         const Vector<String>& args_types) {
2152   return Ptr(new SchemaChangeEvent(AGGREGATE, type, keyspace_name, aggregate_name, args_types));
2153 }
2154 
encode(SchemaChangeEvent::Target target,SchemaChangeEvent::Type type,const String & keyspace_name,const String & target_name,const Vector<String> & arg_types)2155 String SchemaChangeEvent::encode(SchemaChangeEvent::Target target, SchemaChangeEvent::Type type,
2156                                  const String& keyspace_name, const String& target_name,
2157                                  const Vector<String>& arg_types) {
2158   String body;
2159   encode_string("SCHEMA_CHANGE", &body);
2160   switch (type) {
2161     case CREATED:
2162       encode_string("CREATED", &body);
2163       break;
2164     case UPDATED:
2165       encode_string("UPDATED", &body);
2166       break;
2167     case DROPPED:
2168       encode_string("DROPPED", &body);
2169       break;
2170   }
2171   switch (target) {
2172     case KEYSPACE:
2173       encode_string("KEYSPACE", &body);
2174       encode_string(keyspace_name, &body);
2175       break;
2176     case TABLE:
2177       encode_string("TABLE", &body);
2178       encode_string(keyspace_name, &body);
2179       encode_string(target_name, &body);
2180       break;
2181       ;
2182     case USER_TYPE:
2183       encode_string("TYPE", &body);
2184       encode_string(keyspace_name, &body);
2185       encode_string(target_name, &body);
2186       break;
2187     case FUNCTION:
2188       encode_string("FUNCTION", &body);
2189       encode_string(keyspace_name, &body);
2190       encode_string(target_name, &body);
2191       encode_string_list(arg_types, &body);
2192       break;
2193     case AGGREGATE:
2194       encode_string("AGGREGATE", &body);
2195       encode_string(keyspace_name, &body);
2196       encode_string(target_name, &body);
2197       encode_string_list(arg_types, &body);
2198       break;
2199   }
2200   return body;
2201 }
2202 
init(AddressGenerator & generator,ClientConnectionFactory & factory,size_t num_nodes_dc1,size_t num_nodes_dc2)2203 void Cluster::init(AddressGenerator& generator, ClientConnectionFactory& factory,
2204                    size_t num_nodes_dc1, size_t num_nodes_dc2) {
2205   for (size_t i = 0; i < num_nodes_dc1; ++i) {
2206     create_and_add_server(generator, factory, "dc1");
2207   }
2208   for (size_t i = 0; i < num_nodes_dc2; ++i) {
2209     create_and_add_server(generator, factory, "dc2");
2210   }
2211 }
2212 
// Stops all servers (and waits for them to close) when the cluster goes away.
Cluster::~Cluster() { stop_all(); }
2214 
// Generates a key and a certificate (for common name `cn`) and applies the
// pair to every server. Returns the certificate, or an empty string as soon
// as one server fails to accept it.
String Cluster::use_ssl(const String& cn /*= ""*/) {
  const String key(Ssl::generate_key());
  const String cert(Ssl::generate_cert(key, cn));

  for (Servers::iterator it = servers_.begin(), end = servers_.end(); it != end; ++it) {
    const bool is_applied = it->connection->use_ssl(key, cert);
    if (!is_applied) {
      return "";
    }
  }
  return cert;
}
2226 
start_all(EventLoopGroup * event_loop_group)2227 int Cluster::start_all(EventLoopGroup* event_loop_group) {
2228   start_all_async(event_loop_group);
2229   for (size_t i = 0; i < servers_.size(); ++i) {
2230     Server& server = servers_[i];
2231     int rc = server.connection->wait_listen();
2232     if (rc != 0) return rc;
2233   }
2234   return 0;
2235 }
2236 
start_all_async(EventLoopGroup * event_loop_group)2237 void Cluster::start_all_async(EventLoopGroup* event_loop_group) {
2238   for (size_t i = 0; i < servers_.size(); ++i) {
2239     Server& server = servers_[i];
2240     server.connection->listen(event_loop_group);
2241   }
2242 }
2243 
stop_all()2244 void Cluster::stop_all() {
2245   stop_all_async();
2246   for (size_t i = 0; i < servers_.size(); ++i) {
2247     Server& server = servers_[i];
2248     server.connection->wait_close();
2249   }
2250 }
2251 
stop_all_async()2252 void Cluster::stop_all_async() {
2253   for (size_t i = 0; i < servers_.size(); ++i) {
2254     Server& server = servers_[i];
2255     server.connection->close();
2256   }
2257 }
2258 
start(EventLoopGroup * event_loop_group,size_t node)2259 int Cluster::start(EventLoopGroup* event_loop_group, size_t node) {
2260   if (node < 1 || node > servers_.size()) {
2261     return -1;
2262   }
2263   Server& server = servers_[node - 1];
2264   server.connection->listen(event_loop_group);
2265   return server.connection->wait_listen();
2266 }
2267 
start_async(EventLoopGroup * event_loop_group,size_t node)2268 void Cluster::start_async(EventLoopGroup* event_loop_group, size_t node) {
2269   if (node < 1 || node > servers_.size()) {
2270     return;
2271   }
2272   Server& server = servers_[node - 1];
2273   server.connection->listen(event_loop_group);
2274 }
2275 
stop(size_t node)2276 void Cluster::stop(size_t node) {
2277   if (node < 1 || node > servers_.size()) {
2278     return;
2279   }
2280   Server& server = servers_[node - 1];
2281   server.connection->close();
2282   server.connection->wait_close();
2283 }
2284 
stop_async(size_t node)2285 void Cluster::stop_async(size_t node) {
2286   if (node < 1 || node > servers_.size()) {
2287     return;
2288   }
2289   Server& server = servers_[node - 1];
2290   server.connection->close();
2291 }
2292 
add(EventLoopGroup * event_loop_group,size_t node)2293 int Cluster::add(EventLoopGroup* event_loop_group, size_t node) {
2294   if (node < 1 || node > servers_.size()) {
2295     return -1;
2296   }
2297   Server& server = servers_[node - 1];
2298   bool is_removed = server.is_removed.exchange(false);
2299   server.connection->listen(event_loop_group);
2300   int rc = server.connection->wait_listen();
2301 
2302   // Send the added node event after starting the socket
2303   if (is_removed) { // Only send topology change event if node was previously removed
2304     event(TopologyChangeEvent::new_node(server.connection->address()));
2305   }
2306 
2307   return rc;
2308 }
2309 
remove(size_t node)2310 void Cluster::remove(size_t node) {
2311   if (node < 1 || node > servers_.size()) {
2312     return;
2313   }
2314   Server& server = servers_[node - 1];
2315   bool is_removed = server.is_removed.exchange(true);
2316 
2317   // Send the remove node event before closing the socket
2318   if (!is_removed) { // Only send the topology change event if node was previously active
2319     event(TopologyChangeEvent::removed_node(server.connection->address()));
2320   }
2321 
2322   server.connection->close();
2323   server.connection->wait_close();
2324 }
2325 
host(const Address & address) const2326 const Host& Cluster::host(const Address& address) const {
2327   for (Servers::const_iterator it = servers_.begin(), end = servers_.end(); it != end; ++it) {
2328     if (it->host.address == address) {
2329       return it->host;
2330     }
2331   }
2332 
2333   throw Exception(ERROR_PROTOCOL_ERROR, "Unable to find host " + address.to_string());
2334 }
2335 
hosts() const2336 Hosts Cluster::hosts() const {
2337   Hosts hosts;
2338   hosts.reserve(servers_.size());
2339   for (Servers::const_iterator it = servers_.begin(), end = servers_.end(); it != end; ++it) {
2340     if (!it->is_removed.load()) {
2341       hosts.push_back(it->host);
2342     }
2343   }
2344   return hosts;
2345 }
2346 
connection_attempts(size_t node) const2347 unsigned Cluster::connection_attempts(size_t node) const {
2348   if (node < 1 || node > servers_.size()) {
2349     return 0;
2350   }
2351   const Server& server = servers_[node - 1];
2352   return server.connection->connection_attempts();
2353 }
2354 
create_and_add_server(AddressGenerator & generator,ClientConnectionFactory & factory,const String & dc)2355 int Cluster::create_and_add_server(AddressGenerator& generator, ClientConnectionFactory& factory,
2356                                    const String& dc) {
2357   Address address(generator.next());
2358   Server server(Host(address, dc, "rack1", token_rng_),
2359                 internal::ServerConnection::Ptr(new internal::ServerConnection(address, factory)));
2360 
2361   servers_.push_back(server);
2362   return static_cast<int>(servers_.size());
2363 }
2364 
event(const Event::Ptr & event)2365 void Cluster::event(const Event::Ptr& event) {
2366   for (Servers::const_iterator it = servers_.begin(), end = servers_.end(); it != end; ++it) {
2367     it->connection->run(internal::ServerConnectionTask::Ptr(event));
2368   }
2369 }
2370 
next()2371 Address Ipv4AddressGenerator::next() {
2372   char buf[32];
2373   sprintf(buf, "%d.%d.%d.%d", (ip_ >> 24) & 0xff, (ip_ >> 16) & 0xff, (ip_ >> 8) & 0xff,
2374           ip_ & 0xff);
2375   ip_++;
2376   return Address(buf, port_);
2377 }
2378 
Host(const Address & address,const String & dc,const String & rack,MT19937_64 & token_rng,int num_tokens)2379 Host::Host(const Address& address, const String& dc, const String& rack, MT19937_64& token_rng,
2380            int num_tokens)
2381     : address(address)
2382     , dc(dc)
2383     , rack(rack)
2384     , partitioner("org.apache.cassandra.dht.Murmur3Partitioner") {
2385   // Only murmur tokens currently supported
2386   for (int i = 0; i < num_tokens; ++i) {
2387     OStringStream ss;
2388     ss << static_cast<int64_t>(token_rng());
2389     tokens.push_back(ss.str());
2390   }
2391 }
2392 
SimpleEventLoopGroup(size_t num_threads,const String & thread_name)2393 SimpleEventLoopGroup::SimpleEventLoopGroup(size_t num_threads,
2394                                            const String& thread_name /*= "mockssandra"*/)
2395     : RoundRobinEventLoopGroup(num_threads) {
2396   int rc = init(thread_name);
2397   UNUSED_(rc);
2398   assert(rc == 0 && "Unable to initialize simple event loop");
2399   run();
2400 }
2401 
// Tear down the group: close its handles first, then join.
// NOTE(review): presumably close_handles() lets the loops finish and join()
// blocks until the loop threads exit -- confirm against the base class.
SimpleEventLoopGroup::~SimpleEventLoopGroup() {
  close_handles();
  join();
}
2406 
// Build the default request handling table: one action chain per opcode,
// registered through on(). The steps of each chain are listed in the order
// they are appended.
// NOTE(review): the meaning of each step (e.g. that ready() replies READY)
// is inferred from its name -- confirm against the action builder.
SimpleRequestHandlerBuilder::SimpleRequestHandlerBuilder()
    : RequestHandler::Builder() {
  // STARTUP: validate, then ready
  on(OPCODE_STARTUP).validate_startup().ready();
  // OPTIONS: supported
  on(OPCODE_OPTIONS).supported();
  // CREDENTIALS: validate, then ready
  on(OPCODE_CREDENTIALS).validate_credentials().ready();
  // AUTH_RESPONSE: validate, then auth success with an empty string argument
  on(OPCODE_AUTH_RESPONSE).validate_auth_response().auth_success("");
  // REGISTER: validate, set the protocol version, mark as registered for
  // events, then ready
  on(OPCODE_REGISTER)
      .validate_register()
      .set_protocol_version()
      .set_registered_for_events()
      .ready();
  // QUERY: system_local, then system_peers, then empty_rows_result(1)
  // (presumably the last step is the fallback for any other query -- confirm)
  on(OPCODE_QUERY).system_local().system_peers().empty_rows_result(1);
}
2420 
// Start from the SimpleRequestHandlerBuilder table, then register
// authentication-specific chains for STARTUP and AUTH_RESPONSE.
//
// `username` and `password` are passed to plaintext_auth() for the
// AUTH_RESPONSE chain.
// NOTE(review): presumably on() replaces the chain the base constructor
// installed for the same opcode -- confirm.
AuthRequestHandlerBuilder::AuthRequestHandlerBuilder(const String& username, const String& password)
    : SimpleRequestHandlerBuilder() {
  // STARTUP: validate, then authenticate with a fixed authenticator class name
  on(mockssandra::OPCODE_STARTUP).validate_startup().authenticate("com.datastax.SomeAuthenticator");
  // AUTH_RESPONSE: validate, then plaintext auth with the given credentials
  on(mockssandra::OPCODE_AUTH_RESPONSE).validate_auth_response().plaintext_auth(username, password);
}
2426 
2427 } // namespace mockssandra
2428