1 /*
2 Copyright (c) DataStax, Inc.
3
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
7
8 http://www.apache.org/licenses/LICENSE-2.0
9
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
15 */
16
17 #include "mockssandra.hpp"
18
19 #include <assert.h>
20 #include <stdio.h>
21
22 #include "control_connection.hpp" // For host queries
23 #include "memory.hpp"
24 #include "scoped_lock.hpp"
25 #include "tracing_data_handler.hpp" // For tracing query
26 #include "utils.hpp"
27 #include "uuids.hpp"
28
29 #include <openssl/bio.h>
30 #include <openssl/dh.h>
31 #include <openssl/x509v3.h>
32
33 #ifdef WIN32
34 #include "winsock.h"
35 #endif
36
37 using datastax::internal::bind_callback;
38 using datastax::internal::escape_id;
39 using datastax::internal::Map;
40 using datastax::internal::Memory;
41 using datastax::internal::OStringStream;
42 using datastax::internal::ScopedMutex;
43 using datastax::internal::trim;
44 using datastax::internal::core::UuidGen;
45
// Size of the stack buffers used when draining OpenSSL (BIO_read/SSL_read)
// before handing the bytes to the socket or the protocol handler.
#define SSL_BUF_SIZE 8192
// Version strings for the mock server; presumably surfaced through mocked
// system tables (their users are outside this view -- TODO confirm).
#define CASSANDRA_VERSION "3.11.4"
#define DSE_VERSION "6.7.1"
#define DSE_CASSANDRA_VERSION "4.0.0.671"

// Choose the SSL server method. OpenSSL >= 1.1.0 and LibreSSL >= 2.3.2 provide
// the version-flexible TLS_server_method(); older releases only offer
// SSLv23_server_method(). LibreSSL is tested first because it also defines
// OPENSSL_VERSION_NUMBER (as 2.0.0), which would otherwise be misread.
#if defined(LIBRESSL_VERSION_NUMBER) || !defined(OPENSSL_VERSION_NUMBER)
#if (LIBRESSL_VERSION_NUMBER >= 0x20302000L)
#define SSL_SERVER_METHOD TLS_server_method
#else
#define SSL_SERVER_METHOD SSLv23_server_method
#endif
#else
#if (OPENSSL_VERSION_NUMBER >= 0x10100000L)
#define SSL_SERVER_METHOD TLS_server_method
#else
#define SSL_SERVER_METHOD SSLv23_server_method
#endif
#endif
66
67 namespace mockssandra {
68
69 namespace {
70
71 template <class T>
72 struct FreeDeleterImpl {};
73
74 #define MAKE_DELETER(type, free_func) \
75 template <> \
76 struct FreeDeleterImpl<type> { \
77 static void free(type* ptr) { free_func(ptr); } \
78 };
79
80 MAKE_DELETER(BIO, BIO_free)
81 MAKE_DELETER(DH, DH_free)
82 MAKE_DELETER(EVP_PKEY, EVP_PKEY_free)
83 MAKE_DELETER(EVP_PKEY_CTX, EVP_PKEY_CTX_free)
84 MAKE_DELETER(X509, X509_free)
85 MAKE_DELETER(X509_REQ, X509_REQ_free)
86 MAKE_DELETER(X509_EXTENSION, X509_EXTENSION_free)
87
88 template <class T>
89 struct FreeDeleter {
operator ()mockssandra::__anon49d184150111::FreeDeleter90 void operator()(T* ptr) const { FreeDeleterImpl<T>::free(ptr); }
91 };
92
93 template <class T>
94 class Scoped : public ScopedPtr<T, FreeDeleter<T> > {
95 public:
Scoped(T * ptr=NULL)96 Scoped(T* ptr = NULL)
97 : ScopedPtr<T, FreeDeleter<T> >(ptr) {}
98 };
99
print_ssl_error()100 void print_ssl_error() {
101 unsigned long err = ERR_get_error();
102 char buf[256];
103 ERR_error_string_n(err, buf, sizeof(buf));
104 fprintf(stderr, "%s\n", buf);
105 }
106
load_cert(const String & cert)107 X509* load_cert(const String& cert) {
108 X509* x509 = NULL;
109 Scoped<BIO> bio(BIO_new_mem_buf(const_cast<char*>(cert.c_str()), cert.length()));
110 if (PEM_read_bio_X509(bio.get(), &x509, NULL, NULL) == NULL) {
111 print_ssl_error();
112 return NULL;
113 }
114 return x509;
115 }
116
load_private_key(const String & key)117 EVP_PKEY* load_private_key(const String& key) {
118 EVP_PKEY* pkey = NULL;
119 Scoped<BIO> bio(BIO_new_mem_buf(const_cast<char*>(key.c_str()), key.length()));
120 if (!PEM_read_bio_PrivateKey(bio.get(), &pkey, NULL, NULL)) {
121 print_ssl_error();
122 return NULL;
123 }
124 return pkey;
125 }
126
dh_parameters()127 DH* dh_parameters() {
128 // Generated using the following command: `openssl dhparam -C 2048`
129 // Prime length of 2048 chosen to bypass client-side error:
130 // `SSL3_CHECK_CERT_AND_ALGORITHM:dh key too small`
131
132 // Note: This is not generated, programmatically, using something like the following:
133 // `DH_generate_parameters_ex(dh, 2048, DH_GENERATOR_5, NULL)`
134 // because DH prime generation takes a *REALLY* long time.
135 static const char* dh_parameters_pem =
136 "-----BEGIN DH PARAMETERS-----\n"
137 "MIIBCAKCAQEAusYypYO7u8mHelHjpDuUy7hjBgPw/KS03iSRnP5SNMB6OxVFslXv\n"
138 "s6McqEf218Fqpzi18tWA7fq3fvlT+Nx1Tda+Za5C8o5niRYxHks5N+RfnnrFf7vn\n"
139 "0lxrzsXP6es08Ts/UGMsp1nEaCSd/gjDglPgjdC1V/KmBsbT+8IwpbzPPdir0/jA\n"
140 "r+DXssZRZl7JtymGHXPkXTSBhsqSHamfzGRnAQFWToKAinqAdhY7pN/8krwvRj04\n"
141 "VYp84xAy2M6mWWqUm/kokN9QjAiT/DZRxZK8VhY7O9+oATo7/YPCMd9Em417O13k\n"
142 "+F0o/8IMaQvpmtlAsLc2ZKwGqqG+HD2dOwIBAg==\n"
143 "-----END DH PARAMETERS-----";
144 Scoped<BIO> bio(BIO_new_mem_buf(const_cast<char*>(dh_parameters_pem),
145 -1)); // Use null terminator for length
146 return PEM_read_bio_DHparams(bio.get(), NULL, NULL, NULL);
147 }
148
149 } // namespace
150
generate_key()151 String Ssl::generate_key() {
152 Scoped<EVP_PKEY_CTX> pctx(EVP_PKEY_CTX_new_id(EVP_PKEY_RSA, NULL));
153
154 EVP_PKEY_keygen_init(pctx.get());
155 EVP_PKEY_CTX_set_rsa_keygen_bits(pctx.get(), 2048);
156
157 Scoped<EVP_PKEY> pkey;
158 { // Generate RSA key
159 EVP_PKEY* temp = NULL;
160 EVP_PKEY_keygen(pctx.get(), &temp);
161 pkey.reset(temp);
162 }
163
164 Scoped<BIO> bio(BIO_new(BIO_s_mem()));
165 PEM_write_bio_PrivateKey(bio.get(), pkey.get(), NULL, NULL, 0, NULL, NULL);
166
167 BUF_MEM* mem = NULL;
168 BIO_get_mem_ptr(bio.get(), &mem);
169 return String(mem->data, mem->length);
170 }
171
// Generates a PEM-encoded X.509 v3 certificate (valid for one year) for the
// PEM private key `key` with subject C=US, CN=`cn`.
// - If both `ca_cert` and `ca_key` are given, the certificate is issued and
//   signed by that CA.
// - Otherwise it is self-signed; when `cn` is "CA" it is additionally marked
//   as a CA certificate (critical basicConstraints CA:TRUE).
// An empty `cn` defaults to the local hostname on Windows, "localhost" elsewhere.
// Returns an empty string on failure.
String Ssl::generate_cert(const String& key, String cn, String ca_cert, String ca_key) {
  // Assign the proper default hostname
  if (cn.empty()) {
#ifdef WIN32
    // Zero-filled and length-limited so the name is always terminated; fall
    // back to "localhost" if the lookup fails (e.g. Winsock not initialized).
    char win_hostname[64] = { 0 };
    if (gethostname(win_hostname, static_cast<int>(sizeof(win_hostname) - 1)) == 0 &&
        win_hostname[0] != '\0') {
      cn = win_hostname;
    } else {
      cn = "localhost";
    }
#else
    cn = "localhost";
#endif
  }

  Scoped<EVP_PKEY> pkey(load_private_key(key));
  if (!pkey) return "";

  Scoped<X509_REQ> x509_req;
  if (!ca_cert.empty() && !ca_key.empty()) {
    x509_req.reset(X509_REQ_new());
    if (!x509_req) return "";
    X509_REQ_set_version(x509_req.get(), 0); // 0 encodes PKCS#10 v1, the only defined CSR version
    X509_REQ_set_pubkey(x509_req.get(), pkey.get());

    X509_NAME* name = X509_REQ_get_subject_name(x509_req.get());
    X509_NAME_add_entry_by_txt(name, "C", MBSTRING_ASC,
                               reinterpret_cast<const unsigned char*>("US"), -1, -1, 0);
    X509_NAME_add_entry_by_txt(name, "CN", MBSTRING_ASC,
                               reinterpret_cast<const unsigned char*>(cn.c_str()), -1, -1, 0);
    X509_REQ_sign(x509_req.get(), pkey.get(), EVP_sha256());
  }

  Scoped<X509> x509(X509_new());
  if (!x509) return "";
  X509_set_version(x509.get(), 2); // 2 encodes X.509 v3 (needed for extensions)
  ASN1_INTEGER_set(X509_get_serialNumber(x509.get()), 0);
  X509_gmtime_adj(X509_get_notBefore(x509.get()), 0);
  X509_gmtime_adj(X509_get_notAfter(x509.get()), static_cast<long>(60 * 60 * 24 * 365));
  X509_set_pubkey(x509.get(), pkey.get());

  if (x509_req) {
    X509_set_subject_name(x509.get(), X509_REQ_get_subject_name(x509_req.get()));

    Scoped<X509> x509_ca(load_cert(ca_cert));
    if (!x509_ca) return "";
    // The issuer of the new certificate is the CA's *subject* name; using the
    // CA's issuer name is only correct by accident for self-signed CAs.
    X509_set_issuer_name(x509.get(), X509_get_subject_name(x509_ca.get()));

    Scoped<EVP_PKEY> pkey_ca(load_private_key(ca_key));
    if (!pkey_ca) return "";
    if (X509_sign(x509.get(), pkey_ca.get(), EVP_sha256()) <= 0) {
      print_ssl_error();
      return "";
    }
  } else {
    if (cn == "CA") { // Set the purpose as a CA certificate.
      X509V3_CTX x509v3_ctx;
      X509V3_set_ctx_nodb(&x509v3_ctx);
      X509V3_set_ctx(&x509v3_ctx, x509.get(), x509.get(), NULL, NULL, 0);

      Scoped<X509_EXTENSION> x509_ex(X509V3_EXT_conf_nid(NULL, &x509v3_ctx, NID_basic_constraints,
                                                         const_cast<char*>("critical,CA:TRUE")));
      if (!x509_ex) return "";

      X509_add_ext(x509.get(), x509_ex.get(), -1);
    }
    X509_NAME* name = X509_get_subject_name(x509.get());
    X509_NAME_add_entry_by_txt(name, "C", MBSTRING_ASC,
                               reinterpret_cast<const unsigned char*>("US"), -1, -1, 0);
    X509_NAME_add_entry_by_txt(name, "CN", MBSTRING_ASC,
                               reinterpret_cast<const unsigned char*>(cn.c_str()), -1, -1, 0);
    X509_set_issuer_name(x509.get(), name); // Self-signed: issuer == subject
    if (X509_sign(x509.get(), pkey.get(), EVP_sha256()) <= 0) {
      print_ssl_error();
      return "";
    }
  }

  String result;
  { // Write cert into string
    Scoped<BIO> bio(BIO_new(BIO_s_mem()));
    if (!bio || !PEM_write_bio_X509(bio.get(), x509.get())) {
      print_ssl_error();
      return "";
    }
    BUF_MEM* mem = NULL;
    BIO_get_mem_ptr(bio.get(), &mem);
    if (!mem) return "";
    result.append(mem->data, mem->length);
  }

  return result;
}
250
251 namespace internal {
252
253 struct WriteReq {
WriteReqmockssandra::internal::WriteReq254 WriteReq(const char* data, size_t len, ClientConnection* connection)
255 : data(data, len)
256 , connection(connection) {
257 req.data = this;
258 }
259 const String data;
260 ClientConnection* const connection;
261 uv_write_t req;
262 };
263
Tcp(void * data)264 Tcp::Tcp(void* data) { tcp_.data = data; }
265
init(uv_loop_t * loop)266 int Tcp::init(uv_loop_t* loop) { return uv_tcp_init(loop, &tcp_); }
267
bind(const struct sockaddr * addr)268 int Tcp::bind(const struct sockaddr* addr) { return uv_tcp_bind(&tcp_, addr, 0); }
269
as_handle()270 uv_handle_t* Tcp::as_handle() { return reinterpret_cast<uv_handle_t*>(&tcp_); }
271
as_stream()272 uv_stream_t* Tcp::as_stream() { return reinterpret_cast<uv_stream_t*>(&tcp_); }
273
ClientConnection(ServerConnection * server)274 ClientConnection::ClientConnection(ServerConnection* server)
275 : tcp_(this)
276 , server_(server)
277 , ssl_(server->ssl_context() ? SSL_new(server->ssl_context()) : NULL)
278 , incoming_bio_(ssl_ ? BIO_new(BIO_s_mem()) : NULL)
279 , outgoing_bio_(ssl_ ? BIO_new(BIO_s_mem()) : NULL) {
280 tcp_.init(server->loop());
281 if (ssl_) {
282 SSL_set_bio(ssl_, incoming_bio_, outgoing_bio_);
283 }
284 }
285
~ClientConnection()286 ClientConnection::~ClientConnection() {
287 if (ssl_) SSL_free(ssl_);
288 }
289
write(const String & data)290 int ClientConnection::write(const String& data) { return write(data.data(), data.length()); }
291
write(const char * data,size_t len)292 int ClientConnection::write(const char* data, size_t len) {
293 if (ssl_) {
294 return ssl_write(data, len);
295 } else {
296 return internal_write(data, len);
297 }
298 }
299
close()300 void ClientConnection::close() {
301 if (!uv_is_closing(tcp_.as_handle())) {
302 uv_close(tcp_.as_handle(), on_close);
303 }
304 }
305
accept()306 int ClientConnection::accept() {
307 int rc = server_->accept(tcp_.as_stream());
308 if (rc != 0) return rc;
309 return uv_read_start(tcp_.as_stream(), on_alloc, on_read);
310 }
311
sni_server_name() const312 const char* ClientConnection::sni_server_name() const {
313 if (ssl_) {
314 return SSL_get_servername(ssl_, TLSEXT_NAMETYPE_host_name);
315 }
316 return NULL;
317 }
318
on_close(uv_handle_t * handle)319 void ClientConnection::on_close(uv_handle_t* handle) {
320 ClientConnection* connection = static_cast<ClientConnection*>(handle->data);
321 connection->handle_close();
322 }
323
handle_close()324 void ClientConnection::handle_close() {
325 on_close();
326 server_->remove(this);
327 delete this;
328 }
329
on_alloc(uv_handle_t * handle,size_t suggested_size,uv_buf_t * buf)330 void ClientConnection::on_alloc(uv_handle_t* handle, size_t suggested_size, uv_buf_t* buf) {
331 buf->base = static_cast<char*>(Memory::malloc(suggested_size));
332 buf->len = suggested_size;
333 }
334
on_read(uv_stream_t * stream,ssize_t nread,const uv_buf_t * buf)335 void ClientConnection::on_read(uv_stream_t* stream, ssize_t nread, const uv_buf_t* buf) {
336 ClientConnection* connection = static_cast<ClientConnection*>(stream->data);
337 connection->handle_read(nread, buf);
338 Memory::free(buf->base);
339 }
340
handle_read(ssize_t nread,const uv_buf_t * buf)341 void ClientConnection::handle_read(ssize_t nread, const uv_buf_t* buf) {
342 if (nread < 0) {
343 if (nread != UV_EOF && nread != UV_ECONNRESET) {
344 fprintf(stderr, "Read failure: %s\n", uv_strerror(nread));
345 }
346 close();
347 return;
348 }
349 if (ssl_) {
350 on_ssl_read(buf->base, nread);
351 } else {
352 on_read(buf->base, nread);
353 }
354 }
355
on_write(uv_write_t * req,int status)356 void ClientConnection::on_write(uv_write_t* req, int status) {
357 WriteReq* write = static_cast<WriteReq*>(req->data);
358 write->connection->handle_write(status);
359 delete write;
360 }
361
handle_write(int status)362 void ClientConnection::handle_write(int status) {
363 if (status != 0) {
364 fprintf(stderr, "Write failure: %s\n", uv_strerror(status));
365 close();
366 return;
367 }
368
369 on_write();
370 }
371
internal_write(const char * data,size_t len)372 int ClientConnection::internal_write(const char* data, size_t len) {
373 uv_buf_t buf;
374 WriteReq* write = new WriteReq(data, len, this);
375 buf.base = const_cast<char*>(write->data.data());
376 buf.len = write->data.length();
377 int rc = uv_write(&write->req, tcp_.as_stream(), &buf, 1, on_write);
378 if (rc != 0) {
379 delete write;
380 }
381 return rc;
382 }
383
ssl_write(const char * data,size_t len)384 int ClientConnection::ssl_write(const char* data, size_t len) {
385 if (has_ssl_error(SSL_write(ssl_, data, len))) {
386 return -1;
387 }
388
389 char buf[SSL_BUF_SIZE];
390 int num_bytes;
391 while ((num_bytes = BIO_read(outgoing_bio_, buf, sizeof(buf))) > 0) {
392 int rc = internal_write(buf, num_bytes);
393 if (rc != 0) {
394 return rc;
395 }
396 }
397
398 return 0;
399 }
400
is_handshake_done()401 bool ClientConnection::is_handshake_done() { return SSL_is_init_finished(ssl_) != 0; }
402
has_ssl_error(int rc)403 bool ClientConnection::has_ssl_error(int rc) {
404 if (rc > 0) return false;
405
406 int err = SSL_get_error(ssl_, rc);
407 if (err == SSL_ERROR_ZERO_RETURN) {
408 close();
409 } else if (err != SSL_ERROR_WANT_READ && err != SSL_ERROR_NONE) {
410 const char* data;
411 int flags;
412 int err;
413 String error;
414 while ((err = ERR_get_error_line_data(NULL, NULL, &data, &flags)) != 0) {
415 char buf[256];
416 ERR_error_string_n(err, buf, sizeof(buf));
417 if (!error.empty()) error.push_back(',');
418 error.append(buf);
419 if (flags & ERR_TXT_STRING) {
420 error.push_back(':');
421 error.append(data);
422 }
423 }
424 fprintf(stderr, "SSL error: %s\n", error.c_str());
425 close();
426 return true;
427 }
428
429 return false;
430 }
431
on_ssl_read(const char * data,size_t len)432 void ClientConnection::on_ssl_read(const char* data, size_t len) {
433 int rc;
434 BIO_write(incoming_bio_, data, len);
435
436 if (!is_handshake_done()) {
437 int rc = SSL_accept(ssl_);
438 if (has_ssl_error(rc)) {
439 return;
440 }
441
442 char buf[SSL_BUF_SIZE];
443 bool data_written = false;
444 int num_bytes;
445 while ((num_bytes = BIO_read(outgoing_bio_, buf, sizeof(buf))) > 0) {
446 data_written = true;
447 internal_write(buf, num_bytes);
448 }
449
450 if (is_handshake_done() && data_written) {
451 return; // Handshake is not completed; ingore remaining data
452 }
453 } else {
454 char buf[SSL_BUF_SIZE];
455 while ((rc = SSL_read(ssl_, buf, sizeof(buf))) > 0) {
456 on_read(buf, rc);
457 }
458 has_ssl_error(rc);
459 }
460 }
461
ServerConnection(const Address & address,const ClientConnectionFactory & factory)462 ServerConnection::ServerConnection(const Address& address, const ClientConnectionFactory& factory)
463 : tcp_(this)
464 , event_loop_(NULL)
465 , state_(STATE_CLOSED)
466 , rc_(0)
467 , address_(address)
468 , factory_(factory)
469 , ssl_context_(NULL)
470 , connection_attempts_(0) {
471 uv_mutex_init(&mutex_);
472 uv_cond_init(&cond_);
473 }
474
~ServerConnection()475 ServerConnection::~ServerConnection() {
476 uv_mutex_destroy(&mutex_);
477 uv_cond_destroy(&cond_);
478 if (ssl_context_) {
479 SSL_CTX_free(ssl_context_);
480 }
481 }
482
loop()483 uv_loop_t* ServerConnection::loop() {
484 ScopedMutex l(&mutex_);
485 return event_loop_->loop();
486 }
487
use_ssl(const String & key,const String & cert,const String & ca_cert,bool require_client_cert)488 bool ServerConnection::use_ssl(const String& key, const String& cert,
489 const String& ca_cert /*= ""*/,
490 bool require_client_cert /*= false*/) {
491 if (ssl_context_) {
492 SSL_CTX_free(ssl_context_);
493 }
494
495 if ((ssl_context_ = SSL_CTX_new(SSL_SERVER_METHOD())) == NULL) {
496 print_ssl_error();
497 return false;
498 }
499
500 SSL_CTX_set_default_passwd_cb_userdata(ssl_context_, (void*)"");
501 SSL_CTX_set_default_passwd_cb(ssl_context_, on_password);
502 SSL_CTX_set_verify(ssl_context_, SSL_VERIFY_NONE, NULL);
503
504 { // Load server certificate
505 Scoped<X509> x509(load_cert(cert));
506 if (!x509) return false;
507 if (SSL_CTX_use_certificate(ssl_context_, x509.get()) <= 0) {
508 print_ssl_error();
509 return false;
510 }
511 }
512
513 if (!ca_cert.empty()) { // Load CA certificate
514
515 { // Add CA certificate to chain to send to the client
516 Scoped<X509> x509(load_cert(ca_cert));
517 if (!x509) return false;
518
519 if (SSL_CTX_add_extra_chain_cert(ssl_context_, x509.release()) <=
520 0) { // Certificate freed by function
521 print_ssl_error();
522 return false;
523 }
524 }
525
526 if (require_client_cert) {
527 Scoped<X509> x509(load_cert(ca_cert));
528 if (!x509) return false;
529
530 // Add CA certificate to chain to validate peer certificate
531 X509_STORE* cert_store = SSL_CTX_get_cert_store(ssl_context_);
532 if (X509_STORE_add_cert(cert_store, x509.get()) <= 0) {
533 print_ssl_error();
534 return false;
535 }
536
537 SSL_CTX_set_verify(ssl_context_, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, NULL);
538 }
539 }
540
541 Scoped<EVP_PKEY> pkey(load_private_key(key));
542 if (!pkey) return false;
543
544 if (SSL_CTX_use_PrivateKey(ssl_context_, pkey.get()) <= 0) {
545 print_ssl_error();
546 return false;
547 }
548
549 Scoped<DH> dh(dh_parameters());
550 if (!dh || !SSL_CTX_set_tmp_dh(ssl_context_, dh.get())) {
551 print_ssl_error();
552 return false;
553 }
554
555 return true;
556 }
557
558 using datastax::internal::core::Task;
559
560 class RunListen : public Task {
561 public:
RunListen(ServerConnection * server)562 RunListen(ServerConnection* server)
563 : server_(server) {}
564
run(EventLoop * event_loop)565 virtual void run(EventLoop* event_loop) { server_->internal_listen(); }
566
567 private:
568 ServerConnection* server_;
569 };
570
571 class RunClose : public Task {
572 public:
RunClose(ServerConnection * server)573 RunClose(ServerConnection* server)
574 : server_(server) {}
575
run(EventLoop * event_loop)576 virtual void run(EventLoop* event_loop) { server_->internal_close(); }
577
578 private:
579 ServerConnection* server_;
580 };
581
582 class RunTask : public Task {
583 public:
RunTask(const ServerConnectionTask::Ptr & task,const ServerConnection::Ptr & connection)584 RunTask(const ServerConnectionTask::Ptr& task, const ServerConnection::Ptr& connection)
585 : task_(task)
586 , connection_(connection) {}
587
run(EventLoop * event_loop)588 virtual void run(EventLoop* event_loop) { task_->run(connection_.get()); }
589
590 private:
591 ServerConnectionTask::Ptr task_;
592 ServerConnection::Ptr connection_;
593 };
594
listen(EventLoopGroup * event_loop_group)595 void ServerConnection::listen(EventLoopGroup* event_loop_group) {
596 ScopedMutex l(&mutex_);
597 if (state_ != STATE_CLOSED) return;
598 rc_ = 0;
599 state_ = STATE_PENDING;
600 event_loop_ = event_loop_group->add(new RunListen(this));
601 }
602
wait_listen()603 int ServerConnection::wait_listen() {
604 ScopedMutex l(&mutex_);
605 while (state_ == STATE_PENDING) {
606 uv_cond_wait(&cond_, l.get());
607 }
608 return rc_;
609 }
610
close()611 void ServerConnection::close() {
612 ScopedMutex l(&mutex_);
613 if (state_ != STATE_LISTENING && state_ != STATE_PENDING) return;
614 state_ = STATE_CLOSING;
615 event_loop_->add(new RunClose(this));
616 }
617
wait_close()618 void ServerConnection::wait_close() {
619 ScopedMutex l(&mutex_);
620 while (state_ == STATE_CLOSING) {
621 uv_cond_wait(&cond_, l.get());
622 }
623 }
624
connection_attempts() const625 unsigned ServerConnection::connection_attempts() const { return connection_attempts_.load(); }
run(const ServerConnectionTask::Ptr & task)626 void ServerConnection::run(const ServerConnectionTask::Ptr& task) {
627 ScopedMutex l(&mutex_);
628 if (state_ != STATE_LISTENING) return;
629 event_loop_->add(new RunTask(task, Ptr(this)));
630 }
631
internal_listen()632 void ServerConnection::internal_listen() {
633 int rc = 0;
634
635 rc = tcp_.init(loop());
636 if (rc != 0) {
637 fprintf(stderr, "Unable to initialize socket\n");
638 signal_listen(rc);
639 return;
640 }
641
642 inc_ref(); // For the TCP handle
643
644 Address::SocketStorage storage;
645 rc = tcp_.bind(address_.to_sockaddr(&storage));
646 if (rc != 0) {
647 fprintf(stderr, "Unable to bind address %s\n", address_.to_string(true).c_str());
648 uv_close(tcp_.as_handle(), on_close);
649 signal_listen(rc);
650 return;
651 }
652
653 rc = uv_listen(tcp_.as_stream(), 128, on_connection);
654 if (rc != 0) {
655 fprintf(stderr, "Unable to listen on address %s\n", address_.to_string(true).c_str());
656 uv_close(tcp_.as_handle(), on_close);
657 signal_listen(rc);
658 return;
659 }
660
661 signal_listen(rc);
662 }
663
accept(uv_stream_t * client)664 int ServerConnection::accept(uv_stream_t* client) { return uv_accept(tcp_.as_stream(), client); }
665
remove(ClientConnection * connection)666 void ServerConnection::remove(ClientConnection* connection) {
667 clients_.erase(std::remove(clients_.begin(), clients_.end(), connection), clients_.end());
668 maybe_close();
669 }
670
internal_close()671 void ServerConnection::internal_close() {
672 for (ClientConnections::iterator it = clients_.begin(), end = clients_.end(); it != end; ++it) {
673 (*it)->close();
674 }
675 maybe_close();
676 }
677
maybe_close()678 void ServerConnection::maybe_close() {
679 ScopedMutex l(&mutex_);
680 if (state_ == STATE_CLOSING && clients_.empty() && !uv_is_closing(tcp_.as_handle())) {
681 uv_close(tcp_.as_handle(), on_close);
682 }
683 }
684
signal_listen(int rc)685 void ServerConnection::signal_listen(int rc) {
686 ScopedMutex l(&mutex_);
687 if (rc != 0) {
688 rc_ = rc;
689 state_ = STATE_CLOSING;
690 } else {
691 state_ = STATE_LISTENING;
692 }
693 uv_cond_signal(&cond_);
694 }
695
signal_close()696 void ServerConnection::signal_close() {
697 ScopedMutex l(&mutex_);
698 event_loop_ = NULL;
699 state_ = STATE_CLOSED;
700 uv_cond_signal(&cond_);
701 }
702
on_connection(uv_stream_t * server,int status)703 void ServerConnection::on_connection(uv_stream_t* server, int status) {
704 ServerConnection* self = static_cast<ServerConnection*>(server->data);
705 self->handle_connection(status);
706 }
707
handle_connection(int status)708 void ServerConnection::handle_connection(int status) {
709 connection_attempts_.fetch_add(1);
710
711 if (status != 0) {
712 fprintf(stderr, "Listen failure: %s\n", uv_strerror(status));
713 return;
714 }
715
716 ClientConnection* connection = factory_.create(this);
717 if (connection && connection->on_accept() != 0) {
718 delete connection;
719 return;
720 }
721 clients_.push_back(connection);
722 }
723
on_close(uv_handle_t * handle)724 void ServerConnection::on_close(uv_handle_t* handle) {
725 ServerConnection* self = static_cast<ServerConnection*>(handle->data);
726 self->handle_close();
727 }
728
handle_close()729 void ServerConnection::handle_close() {
730 signal_close();
731 dec_ref();
732 }
733
on_password(char * buf,int size,int rwflag,void * password)734 int ServerConnection::on_password(char* buf, int size, int rwflag, void* password) {
735 strncpy(buf, (char*)(password), size);
736 buf[size - 1] = '\0';
737 return strlen(buf);
738 }
739
740 } // namespace internal
741
// Bounds guard for the decode_* helpers below. Requires a `const char* end`
// in scope. If the cursor `pos` lies past `end`, it logs `error` to stderr and
// returns `end + 1` from the enclosing function; callers detect failure
// by checking for a returned cursor greater than `end`.
#define CHECK(pos, error)                               \
  do {                                                  \
    if (end < (pos)) {                                  \
      fprintf(stderr, "Decoding error: %s\n", (error)); \
      return end + 1;                                   \
    }                                                   \
  } while (0)
749
decode_int8(const char * input,const char * end,int8_t * value)750 const char* decode_int8(const char* input, const char* end, int8_t* value) {
751 CHECK(input + 1, "Unable to decode byte");
752 *value = static_cast<int8_t>(input[0]);
753 return input + sizeof(int8_t);
754 }
755
decode_int16(const char * input,const char * end,int16_t * value)756 const char* decode_int16(const char* input, const char* end, int16_t* value) {
757 CHECK(input + 2, "Unable to decode signed short");
758 *value = (static_cast<int16_t>(static_cast<uint8_t>(input[1])) << 0) |
759 (static_cast<int16_t>(static_cast<uint8_t>(input[0])) << 8);
760 return input + sizeof(int16_t);
761 }
762
decode_uint16(const char * input,const char * end,uint16_t * value)763 const char* decode_uint16(const char* input, const char* end, uint16_t* value) {
764 CHECK(input + 2, "Unable to decode unsigned short");
765 *value = (static_cast<uint16_t>(static_cast<uint8_t>(input[1])) << 0) |
766 (static_cast<uint16_t>(static_cast<uint8_t>(input[0])) << 8);
767 return input + sizeof(uint16_t);
768 }
769
decode_int32(const char * input,const char * end,int32_t * value)770 const char* decode_int32(const char* input, const char* end, int32_t* value) {
771 CHECK(input + 4, "Unable to decode integer");
772 *value = (static_cast<int32_t>(static_cast<uint8_t>(input[3])) << 0) |
773 (static_cast<int32_t>(static_cast<uint8_t>(input[2])) << 8) |
774 (static_cast<int32_t>(static_cast<uint8_t>(input[1])) << 16) |
775 (static_cast<int32_t>(static_cast<uint8_t>(input[0])) << 24);
776 return input + sizeof(int32_t);
777 }
778
decode_int64(const char * input,const char * end,int64_t * value)779 const char* decode_int64(const char* input, const char* end, int64_t* value) {
780 CHECK(input + 8, "Unable to decode long");
781 *value = (static_cast<int64_t>(static_cast<uint8_t>(input[7])) << 0) |
782 (static_cast<int64_t>(static_cast<uint8_t>(input[6])) << 8) |
783 (static_cast<int64_t>(static_cast<uint8_t>(input[5])) << 16) |
784 (static_cast<int64_t>(static_cast<uint8_t>(input[4])) << 24) |
785 (static_cast<int64_t>(static_cast<uint8_t>(input[3])) << 32) |
786 (static_cast<int64_t>(static_cast<uint8_t>(input[2])) << 40) |
787 (static_cast<int64_t>(static_cast<uint8_t>(input[1])) << 48) |
788 (static_cast<int64_t>(static_cast<uint8_t>(input[0])) << 56);
789 return input + sizeof(int64_t);
790 }
791
decode_string(const char * input,const char * end,String * output)792 const char* decode_string(const char* input, const char* end, String* output) {
793 uint16_t len = 0;
794 const char* pos = decode_uint16(input, end, &len);
795 CHECK(pos + len, "Unable to decode string");
796 output->assign(pos, len);
797 return pos + len;
798 }
799
decode_long_string(const char * input,const char * end,String * output)800 const char* decode_long_string(const char* input, const char* end, String* output) {
801 int32_t len = 0;
802 const char* pos = decode_int32(input, end, &len);
803 CHECK(pos + len, "Unable to decode long string");
804 assert(len >= 0);
805 output->assign(pos, len);
806 return pos + len;
807 }
808
decode_bytes(const char * input,const char * end,String * output)809 const char* decode_bytes(const char* input, const char* end, String* output) {
810 int32_t len = 0;
811 const char* pos = decode_int32(input, end, &len);
812 if (len > 0) {
813 CHECK(pos + len, "Unable to decode bytes");
814 output->assign(pos, len);
815 }
816 return pos + len;
817 }
818
decode_uuid(const char * input,CassUuid * output)819 const char* decode_uuid(const char* input, CassUuid* output) {
820 output->time_and_version = static_cast<uint64_t>(static_cast<uint8_t>(input[3]));
821 output->time_and_version |= static_cast<uint64_t>(static_cast<uint8_t>(input[2])) << 8;
822 output->time_and_version |= static_cast<uint64_t>(static_cast<uint8_t>(input[1])) << 16;
823 output->time_and_version |= static_cast<uint64_t>(static_cast<uint8_t>(input[0])) << 24;
824
825 output->time_and_version |= static_cast<uint64_t>(static_cast<uint8_t>(input[5])) << 32;
826 output->time_and_version |= static_cast<uint64_t>(static_cast<uint8_t>(input[4])) << 40;
827
828 output->time_and_version |= static_cast<uint64_t>(static_cast<uint8_t>(input[7])) << 48;
829 output->time_and_version |= static_cast<uint64_t>(static_cast<uint8_t>(input[6])) << 56;
830
831 output->clock_seq_and_node = 0;
832 for (size_t i = 0; i < 8; ++i) {
833 output->clock_seq_and_node |= static_cast<uint64_t>(static_cast<uint8_t>(input[15 - i]))
834 << (8 * i);
835 }
836 return input + 16;
837 }
838
decode_string_map(const char * input,const char * end,Vector<std::pair<String,String>> * output)839 const char* decode_string_map(const char* input, const char* end,
840 Vector<std::pair<String, String> >* output) {
841
842 uint16_t len = 0;
843 const char* pos = decode_uint16(input, end, &len);
844 output->reserve(len);
845 for (int i = 0; i < len; ++i) {
846 String key;
847 String value;
848 pos = decode_string(pos, end, &key);
849 pos = decode_string(pos, end, &value);
850 output->push_back(std::pair<String, String>(key, value));
851 }
852 return pos;
853 }
854
decode_stringlist(const char * input,const char * end,Vector<String> * output)855 const char* decode_stringlist(const char* input, const char* end, Vector<String>* output) {
856 uint16_t len = 0;
857 const char* pos = decode_uint16(input, end, &len);
858 output->reserve(len);
859 for (int i = 0; i < len; ++i) {
860 String value;
861 pos = decode_string(pos, end, &value);
862 output->push_back(value);
863 }
864 return pos;
865 }
866
decode_values(const char * input,const char * end,Vector<String> * output)867 const char* decode_values(const char* input, const char* end, Vector<String>* output) {
868 uint16_t len = 0;
869 const char* pos = decode_uint16(input, end, &len);
870 output->reserve(len);
871 for (int i = 0; i < len; ++i) {
872 String value;
873 pos = decode_bytes(pos, end, &value);
874 output->push_back(value);
875 }
876 return pos;
877 }
878
decode_values_with_names(const char * input,const char * end,Vector<String> * names,Vector<String> * values)879 const char* decode_values_with_names(const char* input, const char* end, Vector<String>* names,
880 Vector<String>* values) {
881 uint16_t len = 0;
882 const char* pos = decode_uint16(input, end, &len);
883 names->reserve(len);
884 values->reserve(len);
885 for (int i = 0; i < len; ++i) {
886 String name;
887 pos = decode_string(pos, end, &name);
888 names->push_back(name);
889 String value;
890 pos = decode_bytes(pos, end, &value);
891 values->push_back(value);
892 }
893 return pos;
894 }
895
decode_query_params_v1(const char * input,const char * end,bool is_execute,QueryParameters * params)896 const char* decode_query_params_v1(const char* input, const char* end, bool is_execute,
897 QueryParameters* params) {
898 const char* pos = input;
899 if (is_execute) {
900 pos = decode_values(pos, end, ¶ms->values);
901 pos = decode_uint16(pos, end, ¶ms->consistency);
902 } else {
903 pos = decode_uint16(pos, end, ¶ms->consistency);
904 }
905 return pos;
906 }
907
decode_query_params_v2(const char * input,const char * end,QueryParameters * params)908 const char* decode_query_params_v2(const char* input, const char* end, QueryParameters* params) {
909 int8_t flags = 0;
910 const char* pos = input;
911 pos = decode_uint16(pos, end, ¶ms->consistency);
912 pos = decode_int8(pos, end, &flags);
913 params->flags = flags;
914 if (flags & QUERY_FLAG_VALUES) {
915 pos = decode_values(pos, end, ¶ms->values);
916 }
917 if (flags & QUERY_FLAG_PAGE_SIZE) {
918 pos = decode_int32(pos, end, ¶ms->result_page_size);
919 }
920 if (flags & QUERY_FLAG_PAGE_STATE) {
921 pos = decode_bytes(pos, end, ¶ms->paging_state);
922 }
923 if (flags & QUERY_FLAG_SERIAL_CONSISTENCY) {
924 pos = decode_uint16(pos, end, ¶ms->serial_consistency);
925 }
926 return pos;
927 }
928
decode_query_params_v3v4(const char * input,const char * end,QueryParameters * params)929 const char* decode_query_params_v3v4(const char* input, const char* end, QueryParameters* params) {
930 int8_t flags = 0;
931 const char* pos = input;
932 pos = decode_uint16(pos, end, ¶ms->consistency);
933 pos = decode_int8(pos, end, &flags);
934 params->flags = flags;
935 if (flags & QUERY_FLAG_VALUES && flags & QUERY_FLAG_NAMES_FOR_VALUES) {
936 pos = decode_values_with_names(pos, end, ¶ms->names, ¶ms->values);
937 } else if (flags & QUERY_FLAG_VALUES) {
938 pos = decode_values(pos, end, ¶ms->values);
939 }
940 if (flags & QUERY_FLAG_PAGE_SIZE) {
941 pos = decode_int32(pos, end, ¶ms->result_page_size);
942 }
943 if (flags & QUERY_FLAG_PAGE_STATE) {
944 pos = decode_bytes(pos, end, ¶ms->paging_state);
945 }
946 if (flags & QUERY_FLAG_SERIAL_CONSISTENCY) {
947 pos = decode_uint16(pos, end, ¶ms->serial_consistency);
948 }
949 if (flags & QUERY_FLAG_TIMESTAMP) {
950 pos = decode_int64(pos, end, ¶ms->timestamp);
951 }
952 return pos;
953 }
954
decode_query_params_v5(const char * input,const char * end,QueryParameters * params)955 const char* decode_query_params_v5(const char* input, const char* end, QueryParameters* params) {
956 int32_t flags = 0;
957 const char* pos = input;
958 pos = decode_uint16(pos, end, ¶ms->consistency);
959 pos = decode_int32(pos, end, &flags);
960 params->flags = flags;
961 if (flags & QUERY_FLAG_VALUES && flags & QUERY_FLAG_NAMES_FOR_VALUES) {
962 pos = decode_values_with_names(pos, end, ¶ms->names, ¶ms->values);
963 } else if (flags & QUERY_FLAG_VALUES) {
964 pos = decode_values(pos, end, ¶ms->values);
965 }
966 if (flags & QUERY_FLAG_PAGE_SIZE) {
967 pos = decode_int32(pos, end, ¶ms->result_page_size);
968 }
969 if (flags & QUERY_FLAG_PAGE_STATE) {
970 pos = decode_bytes(pos, end, ¶ms->paging_state);
971 }
972 if (flags & QUERY_FLAG_SERIAL_CONSISTENCY) {
973 pos = decode_uint16(pos, end, ¶ms->serial_consistency);
974 }
975 if (flags & QUERY_FLAG_TIMESTAMP) {
976 pos = decode_int64(pos, end, ¶ms->timestamp);
977 }
978 if (flags & QUERY_FLAG_KEYSPACE) {
979 pos = decode_string(pos, end, ¶ms->keyspace);
980 }
981 return pos;
982 }
983
decode_query_params(int version,const char * input,const char * end,bool is_execute,QueryParameters * params)984 const char* decode_query_params(int version, const char* input, const char* end, bool is_execute,
985 QueryParameters* params) {
986 const char* pos = input;
987 if (version == 1) {
988 pos = decode_query_params_v1(pos, end, is_execute, params);
989 } else if (version == 2) {
990 pos = decode_query_params_v2(pos, end, params);
991 } else if (version == 3 || version == 4) {
992 pos = decode_query_params_v3v4(pos, end, params);
993 } else if (version == 5) {
994 pos = decode_query_params_v5(pos, end, params);
995 } else {
996 assert(false && "Unsupported protocol version");
997 }
998 return pos;
999 }
1000
decode_prepare_params(int version,const char * input,const char * end,PrepareParameters * params)1001 const char* decode_prepare_params(int version, const char* input, const char* end,
1002 PrepareParameters* params) {
1003 const char* pos = input;
1004 if (version >= 5) {
1005 pos = decode_int32(pos, end, ¶ms->flags);
1006 if (params->flags & PREPARE_FLAGS_KEYSPACE) {
1007 pos = decode_string(pos, end, ¶ms->keyspace);
1008 }
1009 }
1010 return pos;
1011 }
1012
encode_int8(int8_t value,String * output)1013 int32_t encode_int8(int8_t value, String* output) {
1014 output->push_back(static_cast<char>(value));
1015 return 1;
1016 }
1017
encode_int16(int16_t value,String * output)1018 int32_t encode_int16(int16_t value, String* output) {
1019 output->push_back(static_cast<char>(value >> 8));
1020 output->push_back(static_cast<char>(value >> 0));
1021 return 2;
1022 }
1023
encode_uint16(uint16_t value,String * output)1024 int32_t encode_uint16(uint16_t value, String* output) {
1025 output->push_back(static_cast<char>(value >> 8));
1026 output->push_back(static_cast<char>(value >> 0));
1027 return 2;
1028 }
1029
encode_int32(int32_t value,String * output)1030 int32_t encode_int32(int32_t value, String* output) {
1031 output->push_back(static_cast<char>(value >> 24));
1032 output->push_back(static_cast<char>(value >> 16));
1033 output->push_back(static_cast<char>(value >> 8));
1034 output->push_back(static_cast<char>(value >> 0));
1035 return 4;
1036 }
1037
encode_string(const String & value,String * output)1038 int32_t encode_string(const String& value, String* output) {
1039 int32_t size = encode_uint16(value.size(), output) + value.size();
1040 output->append(value);
1041 return size + value.size();
1042 }
1043
encode_string_list(const Vector<String> & value,String * output)1044 int32_t encode_string_list(const Vector<String>& value, String* output) {
1045 int32_t size = encode_int16(value.size(), output);
1046 for (Vector<String>::const_iterator it = value.begin(), end = value.end(); it != end; ++it) {
1047 size += encode_string(*it, output);
1048 }
1049 return size;
1050 }
1051
encode_bytes(const String & value,String * output)1052 int32_t encode_bytes(const String& value, String* output) {
1053 int32_t size = encode_int32(value.size(), output) + value.size();
1054 output->append(value);
1055 return size + value.size();
1056 }
1057
encode_inet(const Address & value,String * output)1058 int32_t encode_inet(const Address& value, String* output) {
1059 uint8_t buf[16];
1060 uint8_t len = value.to_inet(buf);
1061 encode_int8(len, output);
1062 for (uint8_t i = 0; i < len; ++i) {
1063 output->push_back(static_cast<char>(buf[i]));
1064 }
1065 encode_int32(value.port(), output);
1066 return 1 + len + 4;
1067 }
1068
encode_uuid(CassUuid uuid,String * output)1069 int32_t encode_uuid(CassUuid uuid, String* output) {
1070 uint64_t time_and_version = uuid.time_and_version;
1071 char buf[16];
1072 buf[3] = static_cast<char>(time_and_version & 0x00000000000000FFLL);
1073 time_and_version >>= 8;
1074 buf[2] = static_cast<char>(time_and_version & 0x00000000000000FFLL);
1075 time_and_version >>= 8;
1076 buf[1] = static_cast<char>(time_and_version & 0x00000000000000FFLL);
1077 time_and_version >>= 8;
1078 buf[0] = static_cast<char>(time_and_version & 0x00000000000000FFLL);
1079 time_and_version >>= 8;
1080
1081 buf[5] = static_cast<char>(time_and_version & 0x00000000000000FFLL);
1082 time_and_version >>= 8;
1083 buf[4] = static_cast<char>(time_and_version & 0x00000000000000FFLL);
1084 time_and_version >>= 8;
1085
1086 buf[7] = static_cast<char>(time_and_version & 0x00000000000000FFLL);
1087 time_and_version >>= 8;
1088 buf[6] = static_cast<char>(time_and_version & 0x000000000000000FFLL);
1089
1090 uint64_t clock_seq_and_node = uuid.clock_seq_and_node;
1091 for (size_t i = 0; i < 8; ++i) {
1092 buf[15 - i] = static_cast<char>(clock_seq_and_node & 0x00000000000000FFL);
1093 clock_seq_and_node >>= 8;
1094 }
1095 output->append(buf, 16);
1096 return 16;
1097 }
1098
encode_string_map(const Map<String,Vector<String>> & value,String * output)1099 int32_t encode_string_map(const Map<String, Vector<String> >& value, String* output) {
1100 int32_t size = encode_uint16(value.size(), output);
1101 for (Map<String, Vector<String> >::const_iterator it = value.begin(); it != value.end(); ++it) {
1102 size += encode_string(it->first, output);
1103 size += encode_string_list(it->second, output);
1104 }
1105 return size;
1106 }
1107
encode_header(int8_t version,int8_t flags,int16_t stream,int8_t opcode,int32_t len)1108 static String encode_header(int8_t version, int8_t flags, int16_t stream, int8_t opcode,
1109 int32_t len) {
1110 String header;
1111 encode_int8(0x80 | version, &header);
1112 encode_int8(flags, &header);
1113 if (version >= 3) {
1114 encode_int16(stream, &header);
1115 } else {
1116 encode_int8(stream, &header);
1117 }
1118 encode_int8(opcode, &header);
1119 if (flags & FLAG_TRACING) {
1120 len += 16; // Add enough space for the tracing ID
1121 }
1122 encode_int32(len, &header);
1123 if (flags & FLAG_TRACING) {
1124 UuidGen gen;
1125 CassUuid tracing_id;
1126 gen.generate_random(&tracing_id);
1127 encode_uuid(tracing_id, &header);
1128 }
1129 return header;
1130 }
1131
text()1132 Type Type::text() { return Type(TYPE_VARCHAR); }
1133
inet()1134 Type Type::inet() { return Type(TYPE_INET); }
1135
uuid()1136 Type Type::uuid() { return Type(TYPE_UUID); }
1137
// Builds a list<sub_type> column type. The element type is stored as the first sub-type.
Type Type::list(const Type& sub_type) {
  Type list_type(TYPE_LIST);
  list_type.types_.push_back(sub_type);
  return list_type;
}
1143
encode(int protocol_version,String * output) const1144 void Type::encode(int protocol_version, String* output) const {
1145 switch (type_) {
1146 case TYPE_VARCHAR:
1147 case TYPE_INET:
1148 case TYPE_UUID:
1149 encode_int16(type_, output);
1150 break;
1151 case TYPE_LIST:
1152 encode_int16(type_, output);
1153 types_[0].encode(protocol_version, output);
1154 break;
1155 default:
1156 assert(false && "Unsupported type");
1157 break;
1158 };
1159 }
1160
encode(int protocol_version,String * output) const1161 void Column::encode(int protocol_version, String* output) const {
1162 encode_string(name_, output);
1163 type_.encode(protocol_version, output);
1164 }
1165
encode(int protocol_version,String * output) const1166 void Collection::encode(int protocol_version, String* output) const {
1167 encode_int32(values_.size(), output);
1168 for (Vector<Value>::const_iterator it = values_.begin(), end = values_.end(); it != end; ++it) {
1169 it->encode(protocol_version, output);
1170 }
1171 }
1172
Value()1173 Value::Value()
1174 : type_(NUL) {}
1175
Value(const String & value)1176 Value::Value(const String& value)
1177 : type_(VALUE)
1178 , value_(new String(value)) {}
1179
Value(const Collection & collection)1180 Value::Value(const Collection& collection)
1181 : type_(COLLECTION)
1182 , collection_(new Collection(collection)) {}
1183
// Deep-copies the payload owned by `other` (if any) so each Value frees only its own storage.
// NOTE(review): ~Value deletes these pointers, so a compiler-generated copy assignment would
// double-free. Confirm the header declares or disables operator=.
Value::Value(const Value& other)
    : type_(other.type_) {
  switch (type_) {
    case VALUE:
      value_ = new String(*other.value_);
      break;
    case COLLECTION:
      collection_ = new Collection(*other.collection_);
      break;
    default:
      break; // NUL carries no payload
  }
}
1192
~Value()1193 Value::~Value() {
1194 if (type_ == VALUE) {
1195 delete value_;
1196 } else if (type_ == COLLECTION) {
1197 delete collection_;
1198 }
1199 }
1200
encode(int protocol_version,String * output) const1201 void Value::encode(int protocol_version, String* output) const {
1202 if (type_ == NUL) {
1203 encode_int32(-1, output);
1204 } else if (type_ == VALUE) {
1205 encode_bytes(*value_, output);
1206 } else if (type_ == COLLECTION) {
1207 String buf;
1208 collection_->encode(protocol_version, &buf);
1209 encode_bytes(buf, output);
1210 }
1211 }
1212
text(const String & text)1213 Row::Builder& Row::Builder::text(const String& text) {
1214 values_.push_back(Value(text));
1215 return *this;
1216 }
1217
inet(const Address & inet)1218 Row::Builder& Row::Builder::inet(const Address& inet) {
1219 String value;
1220 uint8_t buf[16];
1221 uint8_t len = inet.to_inet(buf);
1222 for (uint8_t i = 0; i < len; ++i) {
1223 value.push_back(static_cast<char>(buf[i]));
1224 }
1225 values_.push_back(Value(value));
1226 return *this;
1227 }
1228
uuid(const CassUuid & uuid)1229 Row::Builder& Row::Builder::uuid(const CassUuid& uuid) {
1230 String value;
1231 encode_uuid(uuid, &value);
1232 values_.push_back(Value(value));
1233 return *this;
1234 }
1235
collection(const Collection & collection)1236 Row::Builder& Row::Builder::collection(const Collection& collection) {
1237 values_.push_back(Value(collection));
1238 return *this;
1239 }
1240
encode(int protocol_version,String * output) const1241 void Row::encode(int protocol_version, String* output) const {
1242 for (Vector<Value>::const_iterator it = values_.begin(), end = values_.end(); it != end; ++it) {
1243 it->encode(protocol_version, output);
1244 }
1245 }
1246
encode(int protocol_version) const1247 String ResultSet::encode(int protocol_version) const {
1248 String body;
1249
1250 encode_int32(RESULT_ROWS, &body); // Result type
1251
1252 encode_int32(RESULT_FLAG_GLOBAL_TABLESPEC, &body); // Flags
1253 encode_int32(columns_.size(), &body); // Column count
1254 encode_string(keyspace_name_, &body); // Global spec keyspace name
1255 encode_string(table_name_, &body); // Global spec table name
1256
1257 // Columns
1258 for (Vector<Column>::const_iterator it = columns_.begin(), end = columns_.end(); it != end;
1259 ++it) {
1260 it->encode(protocol_version, &body);
1261 }
1262
1263 encode_int32(rows_.size(), &body); // Row count
1264
1265 // Rows
1266 for (Vector<Row>::const_iterator it = rows_.begin(), end = rows_.end(); it != end; ++it) {
1267 it->encode(protocol_version, &body);
1268 }
1269
1270 return body;
1271 }
1272
reset()1273 Action::Builder& Action::Builder::reset() {
1274 first_.reset();
1275 last_ = NULL;
1276 return *this;
1277 }
1278
execute(Action * action)1279 Action::Builder& Action::Builder::execute(Action* action) {
1280 if (!first_) {
1281 first_.reset(action);
1282 }
1283 if (last_) {
1284 last_->next = action;
1285 }
1286 last_ = action;
1287 return *this;
1288 }
1289
execute_if(Action * action)1290 Action::Builder& Action::Builder::execute_if(Action* action) {
1291 if (last_ && last_->is_predicate()) {
1292 static_cast<Predicate*>(last_)->then = action;
1293 }
1294 return *this;
1295 }
1296
nop()1297 Action::Builder& Action::Builder::nop() { return execute(new Nop()); }
1298
wait(uint64_t timeout)1299 Action::Builder& Action::Builder::wait(uint64_t timeout) { return execute(new Wait(timeout)); }
1300
close()1301 Action::Builder& Action::Builder::close() { return execute(new Close()); }
1302
error(int32_t code,const String & message)1303 Action::Builder& Action::Builder::error(int32_t code, const String& message) {
1304 return execute(new SendError(code, message));
1305 }
1306
invalid_protocol()1307 Action::Builder& Action::Builder::invalid_protocol() {
1308 return error(ERROR_PROTOCOL_ERROR, "Invalid or unsupported protocol version");
1309 }
1310
invalid_opcode()1311 Action::Builder& Action::Builder::invalid_opcode() {
1312 return error(ERROR_PROTOCOL_ERROR, "Invalid opcode (or not implemented)");
1313 }
1314
ready()1315 Action::Builder& Action::Builder::ready() { return execute(new SendReady()); }
1316
authenticate(const String & class_name)1317 Action::Builder& Action::Builder::authenticate(const String& class_name) {
1318 return execute(new SendAuthenticate(class_name));
1319 }
1320
auth_challenge(const String & token)1321 Action::Builder& Action::Builder::auth_challenge(const String& token) {
1322 return execute(new SendAuthChallenge(token));
1323 }
1324
auth_success(const String & token)1325 Action::Builder& Action::Builder::auth_success(const String& token) {
1326 return execute(new SendAuthSuccess(token));
1327 }
1328
supported()1329 Action::Builder& Action::Builder::supported() { return execute(new SendSupported()); }
1330
up_event(const Address & address)1331 Action::Builder& Action::Builder::up_event(const Address& address) {
1332 return execute(new SendUpEvent(address));
1333 }
1334
void_result()1335 Action::Builder& Action::Builder::void_result() { return execute(new VoidResult()); }
1336
empty_rows_result(int32_t row_count)1337 Action::Builder& Action::Builder::empty_rows_result(int32_t row_count) {
1338 return execute(new EmptyRowsResult(row_count));
1339 }
1340
no_result()1341 Action::Builder& Action::Builder::no_result() { return execute(new NoResult()); }
1342
match_query(const Matches & matches)1343 Action::Builder& Action::Builder::match_query(const Matches& matches) {
1344 return execute(new MatchQuery(matches));
1345 }
1346
client_options()1347 Action::Builder& Action::Builder::client_options() { return execute(new ClientOptions()); }
1348
system_local()1349 Action::Builder& Action::Builder::system_local() { return execute(new SystemLocal()); }
1350
system_local_dse()1351 Action::Builder& Action::Builder::system_local_dse() { return execute(new SystemLocalDse()); }
1352
system_peers()1353 Action::Builder& Action::Builder::system_peers() { return execute(new SystemPeers()); }
1354
system_peers_dse()1355 Action::Builder& Action::Builder::system_peers_dse() { return execute(new SystemPeersDse()); }
1356
system_traces()1357 Action::Builder& Action::Builder::system_traces() { return execute(new SystemTraces()); }
1358
use_keyspace(const String & keyspace)1359 Action::Builder& Action::Builder::use_keyspace(const String& keyspace) {
1360 return execute((new UseKeyspace(keyspace)));
1361 }
1362
use_keyspace(const Vector<String> & keyspaces)1363 Action::Builder& Action::Builder::use_keyspace(const Vector<String>& keyspaces) {
1364 return execute((new UseKeyspace(keyspaces)));
1365 }
1366
plaintext_auth(const String & username,const String & password)1367 Action::Builder& Action::Builder::plaintext_auth(const String& username, const String& password) {
1368 return execute((new PlaintextAuth(username, password)));
1369 }
1370
validate_startup()1371 Action::Builder& Action::Builder::validate_startup() { return execute(new ValidateStartup()); }
1372
validate_credentials()1373 Action::Builder& Action::Builder::validate_credentials() {
1374 return execute(new ValidateCredentials());
1375 }
1376
validate_auth_response()1377 Action::Builder& Action::Builder::validate_auth_response() {
1378 return execute(new ValidateAuthResponse());
1379 }
1380
validate_register()1381 Action::Builder& Action::Builder::validate_register() { return execute(new ValidateRegister()); }
1382
validate_query()1383 Action::Builder& Action::Builder::validate_query() { return execute(new ValidateQuery()); }
1384
set_registered_for_events()1385 Action::Builder& Action::Builder::set_registered_for_events() {
1386 return execute(new SetRegisteredForEvents());
1387 }
1388
set_protocol_version()1389 Action::Builder& Action::Builder::set_protocol_version() {
1390 return execute(new SetProtocolVersion());
1391 }
1392
build()1393 Action* Action::Builder::build() { return first_.release(); }
1394
is_address(const Address & address)1395 Action::PredicateBuilder Action::Builder::is_address(const Address& address) {
1396 return PredicateBuilder(execute(new IsAddress(address)));
1397 }
1398
is_address(const String & address,int port)1399 Action::PredicateBuilder Action::Builder::is_address(const String& address, int port) {
1400 return PredicateBuilder(execute(new IsAddress(Address(address, port))));
1401 }
1402
is_query(const String & query)1403 Action::PredicateBuilder Action::Builder::is_query(const String& query) {
1404 return PredicateBuilder(execute(new IsQuery(query)));
1405 }
1406
run(Request * request) const1407 void Action::run(Request* request) const { on_run(request); }
1408
run_next(Request * request) const1409 void Action::run_next(Request* request) const {
1410 if (next) {
1411 next->on_run(request);
1412 }
1413 }
1414
Request(int8_t version,int8_t flags,int16_t stream,int8_t opcode,const String & body,ClientConnection * client)1415 Request::Request(int8_t version, int8_t flags, int16_t stream, int8_t opcode, const String& body,
1416 ClientConnection* client)
1417 : version_(version)
1418 , flags_(flags)
1419 , stream_(stream)
1420 , opcode_(opcode)
1421 , body_(body)
1422 , client_(client)
1423 , timer_action_(NULL) {
1424 (void)flags_; // TODO: Implement custom payload etc.
1425 }
1426
write(int8_t opcode,const String & body)1427 void Request::write(int8_t opcode, const String& body) { write(stream_, opcode, body); }
1428
write(int16_t stream,int8_t opcode,const String & body)1429 void Request::write(int16_t stream, int8_t opcode, const String& body) {
1430 client_->write(encode_header(version_, flags_, stream, opcode, body.size()) + body);
1431 }
1432
error(int32_t code,const String & message)1433 void Request::error(int32_t code, const String& message) {
1434 String body;
1435 encode_int32(code, &body);
1436 encode_string(message, &body);
1437 write(OPCODE_ERROR, body);
1438 }
1439
wait(uint64_t timeout,const Action * action)1440 void Request::wait(uint64_t timeout, const Action* action) {
1441 inc_ref();
1442 timer_action_ = action;
1443 timer_.start(client_->server()->loop(), timeout, bind_callback(&Request::on_timeout, this));
1444 }
1445
close()1446 void Request::close() { client_->close(); }
1447
decode_startup(Options * options)1448 bool Request::decode_startup(Options* options) {
1449 return decode_string_map(start(), end(), options) == end();
1450 }
1451
decode_credentials(Credentials * creds)1452 bool Request::decode_credentials(Credentials* creds) {
1453 return decode_string_map(start(), end(), creds) == end();
1454 }
1455
decode_auth_response(String * token)1456 bool Request::decode_auth_response(String* token) {
1457 return decode_bytes(start(), end(), token) == end();
1458 }
1459
decode_register(EventTypes * types)1460 bool Request::decode_register(EventTypes* types) {
1461 return decode_stringlist(start(), end(), types) == end();
1462 }
1463
decode_query(String * query,QueryParameters * params)1464 bool Request::decode_query(String* query, QueryParameters* params) {
1465 return decode_query_params(version_, decode_long_string(start(), end(), query), end(), false,
1466 params) == end();
1467 }
1468
decode_execute(String * id,QueryParameters * params)1469 bool Request::decode_execute(String* id, QueryParameters* params) {
1470 return decode_query_params(version_, decode_string(start(), end(), id), end(), true, params) ==
1471 end();
1472 }
1473
decode_prepare(String * query,PrepareParameters * params)1474 bool Request::decode_prepare(String* query, PrepareParameters* params) {
1475 return decode_prepare_params(version_, decode_long_string(start(), end(), query), end(),
1476 params) == end();
1477 }
1478
address() const1479 const Address& Request::address() const { return client_->server()->address(); }
1480
host(const Address & address) const1481 const Host& Request::host(const Address& address) const {
1482 return client_->cluster()->host(address);
1483 }
1484
hosts() const1485 Hosts Request::hosts() const { return client_->cluster()->hosts(); }
1486
on_timeout(Timer * timer)1487 void Request::on_timeout(Timer* timer) {
1488 timer_action_->run_next(this);
1489 dec_ref();
1490 }
1491
on_run(Request * request) const1492 void SendError::on_run(Request* request) const { request->error(code, message); }
1493
on_run(Request * request) const1494 void SendReady::on_run(Request* request) const { request->write(OPCODE_READY, String()); }
1495
on_run(Request * request) const1496 void SendAuthenticate::on_run(Request* request) const {
1497 String body;
1498 encode_string(class_name, &body);
1499 request->write(OPCODE_AUTHENTICATE, body);
1500 }
1501
on_run(Request * request) const1502 void SendAuthChallenge::on_run(Request* request) const {
1503 String body;
1504 encode_string(token, &body);
1505 request->write(OPCODE_AUTH_CHALLENGE, body);
1506 }
1507
on_run(Request * request) const1508 void SendAuthSuccess::on_run(Request* request) const {
1509 String body;
1510 encode_string(token, &body);
1511 request->write(OPCODE_AUTH_SUCCESS, body);
1512 }
1513
on_run(Request * request) const1514 void SendSupported::on_run(Request* request) const {
1515 String body;
1516 encode_uint16(0, &body);
1517 request->write(OPCODE_SUPPORTED, body);
1518 }
1519
on_run(Request * request) const1520 void SendUpEvent::on_run(Request* request) const {
1521 request->write(-1, OPCODE_EVENT, StatusChangeEvent::encode(StatusChangeEvent::UP, address));
1522 run_next(request);
1523 }
1524
on_run(Request * request) const1525 void VoidResult::on_run(Request* request) const {
1526 String body;
1527 encode_int32(RESULT_VOID, &body);
1528 request->write(OPCODE_RESULT, body);
1529 }
1530
on_run(Request * request) const1531 void EmptyRowsResult::on_run(Request* request) const {
1532 String query;
1533 QueryParameters params;
1534 if (!request->decode_query(&query, ¶ms)) {
1535 request->error(ERROR_PROTOCOL_ERROR, "Invalid query message");
1536 } else {
1537 String body;
1538 encode_int32(RESULT_ROWS, &body);
1539 encode_int32(0, &body); // Flags
1540 encode_int32(0, &body); // Column count
1541 encode_int32(row_count, &body); // Row count
1542 request->write(OPCODE_RESULT, body);
1543 }
1544 }
1545
on_run(Request * request) const1546 void NoResult::on_run(Request* request) const {}
1547
on_run(Request * request) const1548 void MatchQuery::on_run(Request* request) const {
1549 String query;
1550 QueryParameters params;
1551 if (!request->decode_query(&query, ¶ms)) {
1552 request->error(ERROR_PROTOCOL_ERROR, "Invalid query message");
1553 return;
1554 } else {
1555 for (Matches::const_iterator it = matches.begin(), end = matches.end(); it != end; ++it) {
1556 if (it->first == query) {
1557 request->write(OPCODE_RESULT, it->second.encode(request->version()));
1558 return;
1559 }
1560 }
1561 }
1562 run_next(request);
1563 }
1564
on_run(Request * request) const1565 void ClientOptions::on_run(Request* request) const {
1566 String query;
1567 QueryParameters params;
1568 if (!request->decode_query(&query, ¶ms)) {
1569 request->error(ERROR_PROTOCOL_ERROR, "Invalid query message");
1570 } else if (query == CLIENT_OPTIONS_QUERY) {
1571 ResultSet::Builder builder("client", "options");
1572 Row::Builder row_builder;
1573 for (Options::const_iterator it = request->client()->options().begin(),
1574 end = request->client()->options().end();
1575 it != end; ++it) {
1576 builder.column((*it).first, Type::text());
1577 row_builder.text((*it).second);
1578 }
1579 ResultSet client_options = builder.row(row_builder.build()).build();
1580
1581 request->write(OPCODE_RESULT, client_options.encode(request->version()));
1582 } else {
1583 run_next(request);
1584 }
1585 }
1586
on_run(Request * request) const1587 void SystemLocal::on_run(Request* request) const {
1588 String query;
1589 QueryParameters params;
1590 if (!request->decode_query(&query, ¶ms)) {
1591 request->error(ERROR_PROTOCOL_ERROR, "Invalid query message");
1592 } else if (query.find(SELECT_LOCAL) != String::npos) {
1593 const Host& host(request->host(request->address()));
1594
1595 ResultSet local_rs = ResultSet::Builder("system", "local")
1596 .column("key", Type::text())
1597 .column("data_center", Type::text())
1598 .column("rack", Type::text())
1599 .column("release_version", Type::text())
1600 .column("rpc_address", Type::inet())
1601 .column("partitioner", Type::text())
1602 .column("tokens", Type::list(Type::text()))
1603 .row(Row::Builder()
1604 .text(request->client()->server()->address().to_string())
1605 .text(host.dc)
1606 .text(host.rack)
1607 .text(CASSANDRA_VERSION)
1608 .inet(request->client()->server()->address())
1609 .text(host.partitioner)
1610 .collection(Collection::text(host.tokens))
1611 .build())
1612 .build();
1613
1614 request->write(OPCODE_RESULT, local_rs.encode(request->version()));
1615 } else {
1616 run_next(request);
1617 }
1618 }
1619
on_run(Request * request) const1620 void SystemLocalDse::on_run(Request* request) const {
1621 String query;
1622 QueryParameters params;
1623 if (!request->decode_query(&query, ¶ms)) {
1624 request->error(ERROR_PROTOCOL_ERROR, "Invalid query message");
1625 } else if (query.find(SELECT_LOCAL) != String::npos) {
1626 const Host& host(request->host(request->address()));
1627
1628 ResultSet local_rs = ResultSet::Builder("system", "local")
1629 .column("key", Type::text())
1630 .column("data_center", Type::text())
1631 .column("rack", Type::text())
1632 .column("dse_version", Type::text())
1633 .column("release_version", Type::text())
1634 .column("rpc_address", Type::inet())
1635 .column("partitioner", Type::text())
1636 .column("tokens", Type::list(Type::text()))
1637 .row(Row::Builder()
1638 .text(request->client()->server()->address().to_string())
1639 .text(host.dc)
1640 .text(host.rack)
1641 .text(DSE_VERSION)
1642 .text(DSE_CASSANDRA_VERSION)
1643 .inet(request->client()->server()->address())
1644 .text(host.partitioner)
1645 .collection(Collection::text(host.tokens))
1646 .build())
1647 .build();
1648
1649 request->write(OPCODE_RESULT, local_rs.encode(request->version()));
1650 } else {
1651 run_next(request);
1652 }
1653 }
1654
on_run(Request * request) const1655 void SystemPeers::on_run(Request* request) const {
1656 String query;
1657 QueryParameters params;
1658 if (!request->decode_query(&query, ¶ms)) {
1659 request->error(ERROR_PROTOCOL_ERROR, "Invalid query message");
1660 } else if (query.find(SELECT_PEERS) != String::npos) {
1661 const String where_clause(" WHERE peer = '");
1662 ResultSet::Builder peers_builder = ResultSet::Builder("system", "peers")
1663 .column("peer", Type::inet())
1664 .column("data_center", Type::text())
1665 .column("rack", Type::text())
1666 .column("release_version", Type::text())
1667 .column("rpc_address", Type::inet())
1668 .column("tokens", Type::list(Type::text()));
1669
1670 size_t pos = query.find(where_clause);
1671 if (pos == String::npos) {
1672 Hosts hosts(request->hosts());
1673 for (Hosts::const_iterator it = hosts.begin(), end = hosts.end(); it != end; ++it) {
1674 const Host& host(*it);
1675 if (host.address == request->address()) {
1676 continue;
1677 }
1678 peers_builder.row(Row::Builder()
1679 .inet(host.address)
1680 .text(host.dc)
1681 .text(host.rack)
1682 .text(CASSANDRA_VERSION)
1683 .inet(host.address)
1684 .collection(Collection::text(host.tokens))
1685 .build());
1686 }
1687 ResultSet peers_rs = peers_builder.build();
1688 request->write(OPCODE_RESULT, peers_rs.encode(request->version()));
1689 } else {
1690 pos += where_clause.size();
1691 size_t end_pos = query.find("'", pos);
1692 if (end_pos == String::npos) {
1693 request->error(ERROR_INVALID_QUERY, "Invalid WHERE clause");
1694 return;
1695 }
1696
1697 String ip = query.substr(pos, end_pos - pos);
1698 Address address(ip, request->address().port());
1699 if (!address.is_valid_and_resolved()) {
1700 request->error(ERROR_INVALID_QUERY, "Invalid inet address in WHERE clause");
1701 return;
1702 }
1703
1704 const Host& host(request->host(address));
1705 ResultSet peers_rs = peers_builder
1706 .row(Row::Builder()
1707 .inet(host.address)
1708 .text(host.dc)
1709 .text(host.rack)
1710 .text(CASSANDRA_VERSION)
1711 .inet(host.address)
1712 .collection(Collection::text(host.tokens))
1713 .build())
1714 .build();
1715 request->write(OPCODE_RESULT, peers_rs.encode(request->version()));
1716 }
1717 } else {
1718 run_next(request);
1719 }
1720 }
1721
on_run(Request * request) const1722 void SystemPeersDse::on_run(Request* request) const {
1723 String query;
1724 QueryParameters params;
1725 if (!request->decode_query(&query, ¶ms)) {
1726 request->error(ERROR_PROTOCOL_ERROR, "Invalid query message");
1727 } else if (query.find(SELECT_PEERS) != String::npos) {
1728 const String where_clause(" WHERE peer = '");
1729 ResultSet::Builder peers_builder = ResultSet::Builder("system", "peers")
1730 .column("peer", Type::inet())
1731 .column("data_center", Type::text())
1732 .column("rack", Type::text())
1733 .column("dse_version", Type::text())
1734 .column("release_version", Type::text())
1735 .column("rpc_address", Type::inet())
1736 .column("tokens", Type::list(Type::text()));
1737
1738 size_t pos = query.find(where_clause);
1739 if (pos == String::npos) {
1740 Hosts hosts(request->hosts());
1741 for (Hosts::const_iterator it = hosts.begin(), end = hosts.end(); it != end; ++it) {
1742 const Host& host(*it);
1743 if (host.address == request->address()) {
1744 continue;
1745 }
1746 peers_builder.row(Row::Builder()
1747 .inet(host.address)
1748 .text(host.dc)
1749 .text(host.rack)
1750 .text(DSE_VERSION)
1751 .text(DSE_CASSANDRA_VERSION)
1752 .inet(host.address)
1753 .collection(Collection::text(host.tokens))
1754 .build());
1755 }
1756 ResultSet peers_rs = peers_builder.build();
1757 request->write(OPCODE_RESULT, peers_rs.encode(request->version()));
1758 } else {
1759 pos += where_clause.size();
1760 size_t end_pos = query.find("'", pos);
1761 if (end_pos == String::npos) {
1762 request->error(ERROR_INVALID_QUERY, "Invalid WHERE clause");
1763 return;
1764 }
1765
1766 String ip = query.substr(pos, end_pos - pos);
1767 Address address(ip, request->address().port());
1768 if (!address.is_valid_and_resolved()) {
1769 request->error(ERROR_INVALID_QUERY, "Invalid inet address in WHERE clause");
1770 return;
1771 }
1772
1773 const Host& host(request->host(address));
1774 ResultSet peers_rs = peers_builder
1775 .row(Row::Builder()
1776 .inet(host.address)
1777 .text(host.dc)
1778 .text(host.rack)
1779 .text(CASSANDRA_VERSION)
1780 .inet(host.address)
1781 .collection(Collection::text(host.tokens))
1782 .build())
1783 .build();
1784 request->write(OPCODE_RESULT, peers_rs.encode(request->version()));
1785 }
1786 } else {
1787 run_next(request);
1788 }
1789 }
1790
on_run(Request * request) const1791 void SystemTraces::on_run(Request* request) const {
1792 String query;
1793 QueryParameters params;
1794 if (!request->decode_query(&query, ¶ms)) {
1795 request->error(ERROR_PROTOCOL_ERROR, "Invalid query message");
1796 } else if (query.find(SELECT_TRACES_SESSION) != String::npos) {
1797 if (params.values.empty() || params.values.front().size() < 16) {
1798 request->error(ERROR_INVALID_QUERY, "Query expects a UUID parameter (tracing)");
1799 return;
1800 }
1801 CassUuid tracing_id;
1802 decode_uuid(params.values.front().data(), &tracing_id);
1803 ResultSet session_rs = ResultSet::Builder("system_traces", "session")
1804 .column("session_id", Type::uuid())
1805 .row(Row::Builder().uuid(tracing_id).build())
1806 .build();
1807 request->write(OPCODE_RESULT, session_rs.encode(request->version()));
1808 } else {
1809 run_next(request);
1810 }
1811 }
1812
on_run(Request * request) const1813 void UseKeyspace::on_run(Request* request) const {
1814 String query;
1815 QueryParameters params;
1816 if (request->decode_query(&query, ¶ms)) {
1817 trim(query);
1818 if (query.compare(0, 3, "USE") == 0 || query.compare(0, 3, "use") == 0) {
1819 String keyspace(query.substr(query.find_first_not_of(" \t", 3)));
1820 for (Vector<String>::const_iterator it = keyspaces.begin(), end = keyspaces.end(); it != end;
1821 ++it) {
1822 String temp(*it);
1823 if (keyspace == escape_id(temp)) {
1824 String body;
1825 encode_int32(RESULT_SET_KEYSPACE, &body);
1826 encode_string(*it, &body);
1827 request->client()->set_keyspace(*it);
1828 request->write(OPCODE_RESULT, body);
1829 return;
1830 }
1831 }
1832 request->error(ERROR_INVALID_QUERY, "Keyspace '" + keyspace + "' does not exist");
1833 } else {
1834 run_next(request);
1835 }
1836 } else {
1837 request->error(ERROR_PROTOCOL_ERROR, "Invalid query message");
1838 }
1839 }
1840
on_run(Request * request) const1841 void PlaintextAuth::on_run(Request* request) const {
1842 String token;
1843 if (request->decode_auth_response(&token)) {
1844 String username, password;
1845 String::const_reverse_iterator last = token.rbegin();
1846 enum { PASSWORD, USERNAME } state = PASSWORD;
1847 for (String::const_reverse_iterator it = token.rbegin(), end = token.rend(); it != end; ++it) {
1848 if (*it == '\0') {
1849 if (state == PASSWORD) {
1850 password.assign(it.base(), last.base());
1851 state = USERNAME;
1852 } else if (state == USERNAME) {
1853 username.assign(it.base(), last.base());
1854 break;
1855 }
1856 last = it + 1;
1857 }
1858 }
1859
1860 if (username == this->username && password == this->password) {
1861 String body;
1862 encode_int32(-1, &body); // Null bytes
1863 request->write(OPCODE_AUTH_SUCCESS, body);
1864 } else {
1865 request->error(ERROR_BAD_CREDENTIALS, "Invalid credentials");
1866 }
1867 } else {
1868 request->error(ERROR_PROTOCOL_ERROR, "Invalid auth response message");
1869 }
1870 }
1871
on_run(Request * request) const1872 void ValidateStartup::on_run(Request* request) const {
1873 Options options;
1874 if (!request->decode_startup(&options)) {
1875 request->error(ERROR_PROTOCOL_ERROR, "Invalid startup message");
1876 } else {
1877 request->client()->set_options(options);
1878 run_next(request);
1879 }
1880 }
1881
on_run(Request * request) const1882 void ValidateCredentials::on_run(Request* request) const {
1883 Credentials creds;
1884 if (!request->decode_credentials(&creds)) {
1885 request->error(ERROR_PROTOCOL_ERROR, "Invalid credentials message");
1886 } else {
1887 run_next(request);
1888 }
1889 }
1890
on_run(Request * request) const1891 void ValidateAuthResponse::on_run(Request* request) const {
1892 String token;
1893 if (!request->decode_auth_response(&token)) {
1894 request->error(ERROR_PROTOCOL_ERROR, "Invalid auth response message");
1895 } else {
1896 run_next(request);
1897 }
1898 }
1899
on_run(Request * request) const1900 void ValidateRegister::on_run(Request* request) const {
1901 EventTypes types;
1902 if (!request->decode_register(&types)) {
1903 request->error(ERROR_PROTOCOL_ERROR, "Invalid register message");
1904 } else {
1905 run_next(request);
1906 }
1907 }
1908
on_run(Request * request) const1909 void ValidateQuery::on_run(Request* request) const {
1910 String query;
1911 QueryParameters params;
1912 if (!request->decode_query(&query, ¶ms)) {
1913 request->error(ERROR_PROTOCOL_ERROR, "Invalid query message");
1914 } else {
1915 run_next(request);
1916 }
1917 }
1918
on_run(Request * request) const1919 void SetRegisteredForEvents::on_run(Request* request) const {
1920 request->client()->set_registered_for_events();
1921 run_next(request);
1922 }
1923
on_run(Request * request) const1924 void SetProtocolVersion::on_run(Request* request) const {
1925 request->client()->set_protocol_version(request->version());
1926 run_next(request);
1927 }
1928
is_true(Request * request) const1929 bool IsAddress::is_true(Request* request) const {
1930 return request->client()->server()->address() == address;
1931 }
1932
is_true(Request * request) const1933 bool IsQuery::is_true(Request* request) const {
1934 String query;
1935 QueryParameters params;
1936 return request->decode_query(&query, ¶ms) && query == this->query;
1937 }
1938
RequestHandler(RequestHandler::Builder * builder,int lowest_supported_protocol_version,int highest_supported_protocol_version)1939 RequestHandler::RequestHandler(RequestHandler::Builder* builder,
1940 int lowest_supported_protocol_version,
1941 int highest_supported_protocol_version)
1942 : invalid_protocol_(builder->on_invalid_protocol().build())
1943 , invalid_opcode_(builder->on_invalid_opcode().build())
1944 , lowest_supported_protocol_version_(lowest_supported_protocol_version)
1945 , highest_supported_protocol_version_(highest_supported_protocol_version) {}
1946
build()1947 const RequestHandler* RequestHandler::Builder::build() {
1948 RequestHandler* handler(new RequestHandler(this, lowest_supported_protocol_version_,
1949 highest_supported_protocol_version_));
1950
1951 handler->actions_[OPCODE_STARTUP].reset(actions_[OPCODE_STARTUP].build());
1952 handler->actions_[OPCODE_OPTIONS].reset(actions_[OPCODE_OPTIONS].build());
1953 handler->actions_[OPCODE_CREDENTIALS].reset(actions_[OPCODE_CREDENTIALS].build());
1954 handler->actions_[OPCODE_QUERY].reset(actions_[OPCODE_QUERY].build());
1955 handler->actions_[OPCODE_PREPARE].reset(actions_[OPCODE_PREPARE].build());
1956 handler->actions_[OPCODE_EXECUTE].reset(actions_[OPCODE_EXECUTE].build());
1957 handler->actions_[OPCODE_REGISTER].reset(actions_[OPCODE_REGISTER].build());
1958 handler->actions_[OPCODE_AUTH_RESPONSE].reset(actions_[OPCODE_AUTH_RESPONSE].build());
1959
1960 return handler;
1961 }
1962
decode(ClientConnection * client,const char * data,int32_t len)1963 void ProtocolHandler::decode(ClientConnection* client, const char* data, int32_t len) {
1964 buffer_.append(data, len);
1965 int32_t result = decode_frame(client, buffer_.data(), buffer_.size());
1966 if (result > 0) {
1967 if (static_cast<size_t>(result) == buffer_.size()) {
1968 buffer_.clear();
1969 } else {
1970 // Not efficient, but concise. Copy the consumed part of the buffer
1971 // forward then resize the buffer to what's left over.
1972 std::copy(buffer_.begin() + result, buffer_.end(), buffer_.begin());
1973 buffer_.resize(buffer_.size() - result);
1974 }
1975 }
1976 }
1977
decode_frame(ClientConnection * client,const char * frame,int32_t len)1978 int32_t ProtocolHandler::decode_frame(ClientConnection* client, const char* frame, int32_t len) {
1979 const char* pos = frame;
1980 const char* end = pos + len;
1981 int32_t remaining = len;
1982
1983 while (remaining > 0) {
1984 switch (state_) {
1985 case PROTOCOL_VERSION:
1986 // Version requires a single byte and that's guaranteed by the loop check.
1987 version_ = *pos++;
1988 remaining--;
1989 if (version_ < request_handler_->lowest_supported_protocol_version() ||
1990 version_ > request_handler_->highest_supported_protocol_version()) {
1991 // Use the highest supported protocol version unless it's less than the lowest supported
1992 // then use the original request's protocol version.
1993 int8_t response_version = request_handler_->highest_supported_protocol_version();
1994 if (version_ < request_handler_->lowest_supported_protocol_version()) {
1995 response_version = version_;
1996 }
1997 Request::Ptr request(
1998 new Request(response_version, flags_, stream_, opcode_, String(), client));
1999 request_handler_->invalid_protocol(request.get());
2000 client->close();
2001 return -1;
2002 }
2003 state_ = HEADER;
2004 break;
2005 case HEADER:
2006 if ((version_ == 1 || version_ == 2) && remaining >= 7) {
2007 flags_ = *pos++;
2008 stream_ = *pos++;
2009 opcode_ = *pos++;
2010 pos = decode_int32(pos, end, &length_);
2011 remaining -= 7;
2012 } else if (version_ >= 3 && remaining >= 8) {
2013 flags_ = *pos++;
2014 pos = decode_int16(pos, end, &stream_);
2015 opcode_ = *pos++;
2016 pos = decode_int32(pos, end, &length_);
2017 remaining -= 8;
2018 } else {
2019 return len - remaining;
2020 }
2021
2022 if (length_ == 0) {
2023 decode_body(client, pos, 0);
2024 version_ = 0;
2025 flags_ = 0;
2026 opcode_ = 0;
2027 length_ = 0;
2028 state_ = PROTOCOL_VERSION;
2029 } else {
2030 state_ = BODY;
2031 }
2032 break;
2033 case BODY:
2034 if (remaining >= length_) {
2035 decode_body(client, pos, length_);
2036 pos += length_;
2037 remaining -= length_;
2038 } else {
2039 return len - remaining;
2040 }
2041 version_ = 0;
2042 flags_ = 0;
2043 opcode_ = 0;
2044 length_ = 0;
2045 state_ = PROTOCOL_VERSION;
2046 break;
2047 }
2048 }
2049
2050 return len; // All bytes have been consumed
2051 }
2052
decode_body(ClientConnection * client,const char * body,int32_t len)2053 void ProtocolHandler::decode_body(ClientConnection* client, const char* body, int32_t len) {
2054 Request::Ptr request(new Request(version_, flags_, stream_, opcode_, String(body, len), client));
2055 request_handler_->run(request.get());
2056 }
2057
on_read(const char * data,size_t len)2058 void ClientConnection::on_read(const char* data, size_t len) { handler_.decode(this, data, len); }
2059
Event(const String & event_body)2060 Event::Event(const String& event_body)
2061 : event_body_(event_body) {}
2062
run(internal::ServerConnection * server_connection)2063 void Event::run(internal::ServerConnection* server_connection) {
2064 for (internal::ClientConnections::const_iterator it = server_connection->clients().begin(),
2065 end = server_connection->clients().end();
2066 it != end; ++it) {
2067 ClientConnection* client = static_cast<ClientConnection*>(*it);
2068 if (client->is_registered_for_events() && client->protocol_version() > 0) {
2069 client->write(
2070 encode_header(client->protocol_version(), 0, -1, OPCODE_EVENT, event_body_.size()) +
2071 event_body_);
2072 }
2073 }
2074 }
2075
new_node(const Address & address)2076 Event::Ptr TopologyChangeEvent::new_node(const Address& address) {
2077 return Ptr(new TopologyChangeEvent(NEW_NODE, address));
2078 }
2079
moved_node(const Address & address)2080 Event::Ptr TopologyChangeEvent::moved_node(const Address& address) {
2081 return Ptr(new TopologyChangeEvent(MOVED_NODE, address));
2082 }
2083
removed_node(const Address & address)2084 Event::Ptr TopologyChangeEvent::removed_node(const Address& address) {
2085 return Ptr(new TopologyChangeEvent(REMOVED_NODE, address));
2086 }
2087
encode(TopologyChangeEvent::Type type,const Address & address)2088 String TopologyChangeEvent::encode(TopologyChangeEvent::Type type, const Address& address) {
2089 String body;
2090 encode_string("TOPOLOGY_CHANGE", &body);
2091 switch (type) {
2092 case NEW_NODE:
2093 encode_string("NEW_NODE", &body);
2094 break;
2095 case MOVED_NODE:
2096 encode_string("MOVED_NODE", &body);
2097 break;
2098 case REMOVED_NODE:
2099 encode_string("REMOVED_NODE", &body);
2100 break;
2101 };
2102 encode_inet(address, &body);
2103 return body;
2104 }
2105
up(const Address & address)2106 Event::Ptr StatusChangeEvent::up(const Address& address) {
2107 return Ptr(new StatusChangeEvent(UP, address));
2108 }
2109
down(const Address & address)2110 Event::Ptr StatusChangeEvent::down(const Address& address) {
2111 return Ptr(new StatusChangeEvent(DOWN, address));
2112 }
2113
encode(Type type,const Address & address)2114 String StatusChangeEvent::encode(Type type, const Address& address) {
2115 String body;
2116 encode_string("STATUS_CHANGE", &body);
2117 switch (type) {
2118 case UP:
2119 encode_string("UP", &body);
2120 break;
2121 case DOWN:
2122 encode_string("DOWN", &body);
2123 break;
2124 };
2125 encode_inet(address, &body);
2126 return body;
2127 }
2128
keyspace(SchemaChangeEvent::Type type,const String & keyspace_name)2129 Event::Ptr SchemaChangeEvent::keyspace(SchemaChangeEvent::Type type, const String& keyspace_name) {
2130 return Ptr(new SchemaChangeEvent(KEYSPACE, type, keyspace_name));
2131 }
2132
table(SchemaChangeEvent::Type type,const String & keyspace_name,const String & table_name)2133 Event::Ptr SchemaChangeEvent::table(SchemaChangeEvent::Type type, const String& keyspace_name,
2134 const String& table_name) {
2135 return Ptr(new SchemaChangeEvent(TABLE, type, keyspace_name, table_name));
2136 }
2137
user_type(SchemaChangeEvent::Type type,const String & keyspace_name,const String & user_type_name)2138 Event::Ptr SchemaChangeEvent::user_type(SchemaChangeEvent::Type type, const String& keyspace_name,
2139 const String& user_type_name) {
2140 return Ptr(new SchemaChangeEvent(USER_TYPE, type, keyspace_name, user_type_name));
2141 }
2142
function(SchemaChangeEvent::Type type,const String & keyspace_name,const String & function_name,const Vector<String> & args_types)2143 Event::Ptr SchemaChangeEvent::function(SchemaChangeEvent::Type type, const String& keyspace_name,
2144 const String& function_name,
2145 const Vector<String>& args_types) {
2146 return Ptr(new SchemaChangeEvent(FUNCTION, type, keyspace_name, function_name, args_types));
2147 }
2148
aggregate(SchemaChangeEvent::Type type,const String & keyspace_name,const String & aggregate_name,const Vector<String> & args_types)2149 Event::Ptr SchemaChangeEvent::aggregate(SchemaChangeEvent::Type type, const String& keyspace_name,
2150 const String& aggregate_name,
2151 const Vector<String>& args_types) {
2152 return Ptr(new SchemaChangeEvent(AGGREGATE, type, keyspace_name, aggregate_name, args_types));
2153 }
2154
encode(SchemaChangeEvent::Target target,SchemaChangeEvent::Type type,const String & keyspace_name,const String & target_name,const Vector<String> & arg_types)2155 String SchemaChangeEvent::encode(SchemaChangeEvent::Target target, SchemaChangeEvent::Type type,
2156 const String& keyspace_name, const String& target_name,
2157 const Vector<String>& arg_types) {
2158 String body;
2159 encode_string("SCHEMA_CHANGE", &body);
2160 switch (type) {
2161 case CREATED:
2162 encode_string("CREATED", &body);
2163 break;
2164 case UPDATED:
2165 encode_string("UPDATED", &body);
2166 break;
2167 case DROPPED:
2168 encode_string("DROPPED", &body);
2169 break;
2170 }
2171 switch (target) {
2172 case KEYSPACE:
2173 encode_string("KEYSPACE", &body);
2174 encode_string(keyspace_name, &body);
2175 break;
2176 case TABLE:
2177 encode_string("TABLE", &body);
2178 encode_string(keyspace_name, &body);
2179 encode_string(target_name, &body);
2180 break;
2181 ;
2182 case USER_TYPE:
2183 encode_string("TYPE", &body);
2184 encode_string(keyspace_name, &body);
2185 encode_string(target_name, &body);
2186 break;
2187 case FUNCTION:
2188 encode_string("FUNCTION", &body);
2189 encode_string(keyspace_name, &body);
2190 encode_string(target_name, &body);
2191 encode_string_list(arg_types, &body);
2192 break;
2193 case AGGREGATE:
2194 encode_string("AGGREGATE", &body);
2195 encode_string(keyspace_name, &body);
2196 encode_string(target_name, &body);
2197 encode_string_list(arg_types, &body);
2198 break;
2199 }
2200 return body;
2201 }
2202
init(AddressGenerator & generator,ClientConnectionFactory & factory,size_t num_nodes_dc1,size_t num_nodes_dc2)2203 void Cluster::init(AddressGenerator& generator, ClientConnectionFactory& factory,
2204 size_t num_nodes_dc1, size_t num_nodes_dc2) {
2205 for (size_t i = 0; i < num_nodes_dc1; ++i) {
2206 create_and_add_server(generator, factory, "dc1");
2207 }
2208 for (size_t i = 0; i < num_nodes_dc2; ++i) {
2209 create_and_add_server(generator, factory, "dc2");
2210 }
2211 }
2212
~Cluster()2213 Cluster::~Cluster() { stop_all(); }
2214
// Generates one key and a certificate for common name `cn`, and installs the pair on
// every server. Returns the certificate, or "" as soon as any server rejects the pair.
String Cluster::use_ssl(const String& cn /*= ""*/) {
  const String key(Ssl::generate_key());
  String cert(Ssl::generate_cert(key, cn));
  for (Servers::iterator it = servers_.begin(), end = servers_.end(); it != end; ++it) {
    if (!it->connection->use_ssl(key, cert)) {
      return "";
    }
  }
  return cert;
}
2226
start_all(EventLoopGroup * event_loop_group)2227 int Cluster::start_all(EventLoopGroup* event_loop_group) {
2228 start_all_async(event_loop_group);
2229 for (size_t i = 0; i < servers_.size(); ++i) {
2230 Server& server = servers_[i];
2231 int rc = server.connection->wait_listen();
2232 if (rc != 0) return rc;
2233 }
2234 return 0;
2235 }
2236
start_all_async(EventLoopGroup * event_loop_group)2237 void Cluster::start_all_async(EventLoopGroup* event_loop_group) {
2238 for (size_t i = 0; i < servers_.size(); ++i) {
2239 Server& server = servers_[i];
2240 server.connection->listen(event_loop_group);
2241 }
2242 }
2243
stop_all()2244 void Cluster::stop_all() {
2245 stop_all_async();
2246 for (size_t i = 0; i < servers_.size(); ++i) {
2247 Server& server = servers_[i];
2248 server.connection->wait_close();
2249 }
2250 }
2251
stop_all_async()2252 void Cluster::stop_all_async() {
2253 for (size_t i = 0; i < servers_.size(); ++i) {
2254 Server& server = servers_[i];
2255 server.connection->close();
2256 }
2257 }
2258
start(EventLoopGroup * event_loop_group,size_t node)2259 int Cluster::start(EventLoopGroup* event_loop_group, size_t node) {
2260 if (node < 1 || node > servers_.size()) {
2261 return -1;
2262 }
2263 Server& server = servers_[node - 1];
2264 server.connection->listen(event_loop_group);
2265 return server.connection->wait_listen();
2266 }
2267
start_async(EventLoopGroup * event_loop_group,size_t node)2268 void Cluster::start_async(EventLoopGroup* event_loop_group, size_t node) {
2269 if (node < 1 || node > servers_.size()) {
2270 return;
2271 }
2272 Server& server = servers_[node - 1];
2273 server.connection->listen(event_loop_group);
2274 }
2275
stop(size_t node)2276 void Cluster::stop(size_t node) {
2277 if (node < 1 || node > servers_.size()) {
2278 return;
2279 }
2280 Server& server = servers_[node - 1];
2281 server.connection->close();
2282 server.connection->wait_close();
2283 }
2284
stop_async(size_t node)2285 void Cluster::stop_async(size_t node) {
2286 if (node < 1 || node > servers_.size()) {
2287 return;
2288 }
2289 Server& server = servers_[node - 1];
2290 server.connection->close();
2291 }
2292
add(EventLoopGroup * event_loop_group,size_t node)2293 int Cluster::add(EventLoopGroup* event_loop_group, size_t node) {
2294 if (node < 1 || node > servers_.size()) {
2295 return -1;
2296 }
2297 Server& server = servers_[node - 1];
2298 bool is_removed = server.is_removed.exchange(false);
2299 server.connection->listen(event_loop_group);
2300 int rc = server.connection->wait_listen();
2301
2302 // Send the added node event after starting the socket
2303 if (is_removed) { // Only send topology change event if node was previously removed
2304 event(TopologyChangeEvent::new_node(server.connection->address()));
2305 }
2306
2307 return rc;
2308 }
2309
remove(size_t node)2310 void Cluster::remove(size_t node) {
2311 if (node < 1 || node > servers_.size()) {
2312 return;
2313 }
2314 Server& server = servers_[node - 1];
2315 bool is_removed = server.is_removed.exchange(true);
2316
2317 // Send the remove node event before closing the socket
2318 if (!is_removed) { // Only send the topology change event if node was previously active
2319 event(TopologyChangeEvent::removed_node(server.connection->address()));
2320 }
2321
2322 server.connection->close();
2323 server.connection->wait_close();
2324 }
2325
host(const Address & address) const2326 const Host& Cluster::host(const Address& address) const {
2327 for (Servers::const_iterator it = servers_.begin(), end = servers_.end(); it != end; ++it) {
2328 if (it->host.address == address) {
2329 return it->host;
2330 }
2331 }
2332
2333 throw Exception(ERROR_PROTOCOL_ERROR, "Unable to find host " + address.to_string());
2334 }
2335
hosts() const2336 Hosts Cluster::hosts() const {
2337 Hosts hosts;
2338 hosts.reserve(servers_.size());
2339 for (Servers::const_iterator it = servers_.begin(), end = servers_.end(); it != end; ++it) {
2340 if (!it->is_removed.load()) {
2341 hosts.push_back(it->host);
2342 }
2343 }
2344 return hosts;
2345 }
2346
connection_attempts(size_t node) const2347 unsigned Cluster::connection_attempts(size_t node) const {
2348 if (node < 1 || node > servers_.size()) {
2349 return 0;
2350 }
2351 const Server& server = servers_[node - 1];
2352 return server.connection->connection_attempts();
2353 }
2354
create_and_add_server(AddressGenerator & generator,ClientConnectionFactory & factory,const String & dc)2355 int Cluster::create_and_add_server(AddressGenerator& generator, ClientConnectionFactory& factory,
2356 const String& dc) {
2357 Address address(generator.next());
2358 Server server(Host(address, dc, "rack1", token_rng_),
2359 internal::ServerConnection::Ptr(new internal::ServerConnection(address, factory)));
2360
2361 servers_.push_back(server);
2362 return static_cast<int>(servers_.size());
2363 }
2364
event(const Event::Ptr & event)2365 void Cluster::event(const Event::Ptr& event) {
2366 for (Servers::const_iterator it = servers_.begin(), end = servers_.end(); it != end; ++it) {
2367 it->connection->run(internal::ServerConnectionTask::Ptr(event));
2368 }
2369 }
2370
next()2371 Address Ipv4AddressGenerator::next() {
2372 char buf[32];
2373 sprintf(buf, "%d.%d.%d.%d", (ip_ >> 24) & 0xff, (ip_ >> 16) & 0xff, (ip_ >> 8) & 0xff,
2374 ip_ & 0xff);
2375 ip_++;
2376 return Address(buf, port_);
2377 }
2378
Host(const Address & address,const String & dc,const String & rack,MT19937_64 & token_rng,int num_tokens)2379 Host::Host(const Address& address, const String& dc, const String& rack, MT19937_64& token_rng,
2380 int num_tokens)
2381 : address(address)
2382 , dc(dc)
2383 , rack(rack)
2384 , partitioner("org.apache.cassandra.dht.Murmur3Partitioner") {
2385 // Only murmur tokens currently supported
2386 for (int i = 0; i < num_tokens; ++i) {
2387 OStringStream ss;
2388 ss << static_cast<int64_t>(token_rng());
2389 tokens.push_back(ss.str());
2390 }
2391 }
2392
SimpleEventLoopGroup(size_t num_threads,const String & thread_name)2393 SimpleEventLoopGroup::SimpleEventLoopGroup(size_t num_threads,
2394 const String& thread_name /*= "mockssandra"*/)
2395 : RoundRobinEventLoopGroup(num_threads) {
2396 int rc = init(thread_name);
2397 UNUSED_(rc);
2398 assert(rc == 0 && "Unable to initialize simple event loop");
2399 run();
2400 }
2401
~SimpleEventLoopGroup()2402 SimpleEventLoopGroup::~SimpleEventLoopGroup() {
2403 close_handles();
2404 join();
2405 }
2406
SimpleRequestHandlerBuilder()2407 SimpleRequestHandlerBuilder::SimpleRequestHandlerBuilder()
2408 : RequestHandler::Builder() {
2409 on(OPCODE_STARTUP).validate_startup().ready();
2410 on(OPCODE_OPTIONS).supported();
2411 on(OPCODE_CREDENTIALS).validate_credentials().ready();
2412 on(OPCODE_AUTH_RESPONSE).validate_auth_response().auth_success("");
2413 on(OPCODE_REGISTER)
2414 .validate_register()
2415 .set_protocol_version()
2416 .set_registered_for_events()
2417 .ready();
2418 on(OPCODE_QUERY).system_local().system_peers().empty_rows_result(1);
2419 }
2420
AuthRequestHandlerBuilder(const String & username,const String & password)2421 AuthRequestHandlerBuilder::AuthRequestHandlerBuilder(const String& username, const String& password)
2422 : SimpleRequestHandlerBuilder() {
2423 on(mockssandra::OPCODE_STARTUP).validate_startup().authenticate("com.datastax.SomeAuthenticator");
2424 on(mockssandra::OPCODE_AUTH_RESPONSE).validate_auth_response().plaintext_auth(username, password);
2425 }
2426
2427 } // namespace mockssandra
2428