1 /*
2   Copyright (c) 2014-2017 DataStax
3 
4   Licensed under the Apache License, Version 2.0 (the "License");
5   you may not use this file except in compliance with the License.
6   You may obtain a copy of the License at
7 
8   http://www.apache.org/licenses/LICENSE-2.0
9 
10   Unless required by applicable law or agreed to in writing, software
11   distributed under the License is distributed on an "AS IS" BASIS,
12   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13   See the License for the specific language governing permissions and
14   limitations under the License.
15 */
16 
17 #include "socket.hpp"
18 
19 #include "logger.hpp"
20 
// Size in bytes of the stack buffer that receives decrypted data on each
// decrypt call.
#define SSL_READ_SIZE 8192
// Size in bytes of the stack buffer used to stage plaintext before each
// encrypt call.
#define SSL_WRITE_SIZE 8192
// Fixed-size parameter of the SmallVector that collects the encrypted
// buffers handed to uv_write().
#define SSL_ENCRYPTED_BUFS_COUNT 16

// Maximum number of read buffers kept on the reuse list.
#define MAX_BUFFER_REUSE_NO 8
// Size in bytes of a reusable read buffer. Parenthesized so the expansion
// stays a single operand next to operators such as `%`, `/` or unary minus.
#define BUFFER_REUSE_SIZE (64 * 1024)
27 
28 using namespace datastax::internal;
29 using namespace datastax::internal::core;
30 
// List of libuv buffer descriptors; its contiguous storage is passed to
// uv_write() when a plain socket write is flushed.
typedef Vector<uv_buf_t> UvBufVec;
32 
33 /**
34  * A basic socket write handler.
35  */
36 class SocketWrite : public SocketWriteBase {
37 public:
SocketWrite(Socket * socket)38   SocketWrite(Socket* socket)
39       : SocketWriteBase(socket) {}
40 
41   size_t flush();
42 };
43 
flush()44 size_t SocketWrite::flush() {
45   size_t total = 0;
46   if (!is_flushed_ && !buffers_.empty()) {
47     UvBufVec bufs;
48 
49     bufs.reserve(buffers_.size());
50 
51     for (BufferVec::const_iterator it = buffers_.begin(), end = buffers_.end(); it != end; ++it) {
52       total += it->size();
53       bufs.push_back(uv_buf_init(const_cast<char*>(it->data()), it->size()));
54     }
55 
56     is_flushed_ = true;
57     uv_stream_t* sock_stream = reinterpret_cast<uv_stream_t*>(tcp());
58     uv_write(&req_, sock_stream, bufs.data(), bufs.size(), SocketWrite::on_write);
59   }
60   return total;
61 }
62 
~SocketHandler()63 SocketHandler::~SocketHandler() {
64   while (!buffer_reuse_list_.empty()) {
65     uv_buf_t buf = buffer_reuse_list_.top();
66     Memory::free(buf.base);
67     buffer_reuse_list_.pop();
68   }
69 }
70 
new_pending_write(Socket * socket)71 SocketWriteBase* SocketHandler::new_pending_write(Socket* socket) {
72   return new SocketWrite(socket);
73 }
74 
alloc_buffer(size_t suggested_size,uv_buf_t * buf)75 void SocketHandler::alloc_buffer(size_t suggested_size, uv_buf_t* buf) {
76   if (suggested_size <= BUFFER_REUSE_SIZE) {
77     if (!buffer_reuse_list_.empty()) {
78       *buf = buffer_reuse_list_.top();
79       buffer_reuse_list_.pop();
80     } else {
81       *buf = uv_buf_init(reinterpret_cast<char*>(Memory::malloc(BUFFER_REUSE_SIZE)),
82                          BUFFER_REUSE_SIZE);
83     }
84   } else {
85     *buf = uv_buf_init(reinterpret_cast<char*>(Memory::malloc(suggested_size)), suggested_size);
86   }
87 }
88 
free_buffer(const uv_buf_t * buf)89 void SocketHandler::free_buffer(const uv_buf_t* buf) {
90   if (buf->len == BUFFER_REUSE_SIZE && buffer_reuse_list_.size() < MAX_BUFFER_REUSE_NO) {
91     buffer_reuse_list_.push(*buf);
92     return;
93   }
94   Memory::free(buf->base);
95 }
96 
97 /**
98  * A SSL socket write handler.
99  */
100 class SslSocketWrite : public SocketWriteBase {
101 public:
SslSocketWrite(Socket * socket,SslSession * ssl_session)102   SslSocketWrite(Socket* socket, SslSession* ssl_session)
103       : SocketWriteBase(socket)
104       , ssl_session_(ssl_session)
105       , encrypted_size_(0) {}
106 
107   virtual size_t flush();
108 
109 private:
110   void encrypt();
111 
112   static void on_write(uv_write_t* req, int status);
113 
114 private:
115   SslSession* ssl_session_;
116   size_t encrypted_size_;
117 };
118 
flush()119 size_t SslSocketWrite::flush() {
120   size_t total = 0;
121   if (!is_flushed_ && !buffers_.empty()) {
122     rb::RingBuffer::Position prev_pos = ssl_session_->outgoing().write_position();
123 
124     encrypt();
125 
126     SmallVector<uv_buf_t, SSL_ENCRYPTED_BUFS_COUNT> bufs;
127     total = encrypted_size_ = ssl_session_->outgoing().peek_multiple(prev_pos, &bufs);
128 
129     LOG_TRACE("Sending %u encrypted bytes", static_cast<unsigned int>(encrypted_size_));
130 
131     uv_stream_t* sock_stream = reinterpret_cast<uv_stream_t*>(tcp());
132     uv_write(&req_, sock_stream, bufs.data(), bufs.size(), SslSocketWrite::on_write);
133 
134     is_flushed_ = true;
135   }
136   return total;
137 }
138 
encrypt()139 void SslSocketWrite::encrypt() {
140   char buf[SSL_WRITE_SIZE];
141 
142   size_t copied = 0;
143   size_t offset = 0;
144   size_t total = 0;
145 
146   BufferVec::const_iterator it = buffers_.begin(), end = buffers_.end();
147 
148   LOG_TRACE("Copying %u bufs", static_cast<unsigned int>(buffers_.size()));
149 
150   bool is_done = (it == end);
151 
152   while (!is_done) {
153     assert(it->size() > 0);
154     size_t size = it->size();
155 
156     size_t to_copy = size - offset;
157     size_t available = SSL_WRITE_SIZE - copied;
158     if (available < to_copy) {
159       to_copy = available;
160     }
161 
162     memcpy(buf + copied, it->data() + offset, to_copy);
163 
164     copied += to_copy;
165     offset += to_copy;
166     total += to_copy;
167 
168     if (offset == size) {
169       ++it;
170       offset = 0;
171     }
172 
173     is_done = (it == end);
174 
175     if (is_done || copied == SSL_WRITE_SIZE) {
176       int rc = ssl_session_->encrypt(buf, copied);
177       if (rc <= 0 && ssl_session_->has_error()) {
178         LOG_ERROR("Unable to encrypt data: %s", ssl_session_->error_message().c_str());
179         socket_->defunct();
180         return;
181       }
182       copied = 0;
183     }
184   }
185 
186   LOG_TRACE("Copied %u bytes for encryption", static_cast<unsigned int>(total));
187 }
188 
on_write(uv_write_t * req,int status)189 void SslSocketWrite::on_write(uv_write_t* req, int status) {
190   if (status == 0) {
191     SslSocketWrite* socket_write = static_cast<SslSocketWrite*>(req->data);
192     socket_write->ssl_session_->outgoing().read(NULL, socket_write->encrypted_size_);
193   }
194   SocketWriteBase::on_write(req, status);
195 }
196 
new_pending_write(Socket * socket)197 SocketWriteBase* SslSocketHandler::new_pending_write(Socket* socket) {
198   return new SslSocketWrite(socket, ssl_session_.get());
199 }
200 
alloc_buffer(size_t suggested_size,uv_buf_t * buf)201 void SslSocketHandler::alloc_buffer(size_t suggested_size, uv_buf_t* buf) {
202   buf->base = ssl_session_->incoming().peek_writable(&suggested_size);
203   buf->len = suggested_size;
204 }
205 
on_read(Socket * socket,ssize_t nread,const uv_buf_t * buf)206 void SslSocketHandler::on_read(Socket* socket, ssize_t nread, const uv_buf_t* buf) {
207   if (nread < 0) return;
208 
209   ssl_session_->incoming().commit(nread);
210   char decrypted[SSL_READ_SIZE];
211   int rc = 0;
212   while ((rc = ssl_session_->decrypt(decrypted, sizeof(decrypted))) > 0) {
213     on_ssl_read(socket, decrypted, rc);
214   }
215   if (rc <= 0 && ssl_session_->has_error()) {
216     if (ssl_session_->error_code() == CASS_ERROR_SSL_CLOSED) {
217       LOG_DEBUG("SSL session closed");
218       socket->close();
219     } else {
220       LOG_ERROR("Unable to decrypt data: %s", ssl_session_->error_message().c_str());
221       socket->defunct();
222     }
223   }
224 }
225 
tcp()226 uv_tcp_t* SocketWriteBase::tcp() { return &socket_->tcp_; }
227 
on_close()228 void SocketWriteBase::on_close() {
229   for (RequestVec::iterator i = requests_.begin(), end = requests_.end(); i != end; ++i) {
230     (*i)->on_close();
231   }
232 }
233 
write(SocketRequest * request)234 int32_t SocketWriteBase::write(SocketRequest* request) {
235   size_t last_buffer_size = buffers_.size();
236   int32_t request_size = request->encode(&buffers_);
237   if (request_size <= 0) {
238     buffers_.resize(last_buffer_size); // Rollback
239     return request_size;
240   }
241 
242   requests_.push_back(request);
243 
244   return request_size;
245 }
246 
on_write(uv_write_t * req,int status)247 void SocketWriteBase::on_write(uv_write_t* req, int status) {
248   SocketWriteBase* pending_write = static_cast<SocketWriteBase*>(req->data);
249   pending_write->handle_write(req, status);
250 }
251 
handle_write(uv_write_t * req,int status)252 void SocketWriteBase::handle_write(uv_write_t* req, int status) {
253   Socket* socket = socket_;
254 
255   if (status != 0) {
256     if (!socket->is_closing()) {
257       LOG_ERROR("Socket write error '%s'", uv_strerror(status));
258       socket->defunct();
259     }
260   }
261 
262   if (socket->handler_) {
263     for (RequestVec::iterator i = requests_.begin(), end = requests_.end(); i != end; ++i) {
264       socket->handler_->on_write(socket, status, *i);
265     }
266   }
267 
268   socket->pending_writes_.remove(this);
269 
270   if (socket->free_writes_.size() < socket->max_reusable_write_objects_) {
271     clear();
272     socket->free_writes_.push_back(this);
273   } else {
274     delete this;
275   }
276 
277   socket->flush();
278 }
279 
Socket(const Address & address,size_t max_reusable_write_objects)280 Socket::Socket(const Address& address, size_t max_reusable_write_objects)
281     : is_defunct_(false)
282     , max_reusable_write_objects_(max_reusable_write_objects)
283     , address_(address) {
284   tcp_.data = this;
285 }
286 
~Socket()287 Socket::~Socket() { cleanup_free_writes(); }
288 
set_handler(SocketHandlerBase * handler)289 void Socket::set_handler(SocketHandlerBase* handler) {
290   handler_.reset(handler);
291   cleanup_free_writes();
292   free_writes_.clear();
293   if (handler_) {
294     uv_read_start(reinterpret_cast<uv_stream_t*>(&tcp_), Socket::alloc_buffer, Socket::on_read);
295   } else {
296     uv_read_stop(reinterpret_cast<uv_stream_t*>(&tcp_));
297   }
298 }
299 
write(SocketRequest * request)300 int32_t Socket::write(SocketRequest* request) {
301   if (!handler_) {
302     return SocketRequest::SOCKET_REQUEST_ERROR_NO_HANDLER;
303   }
304 
305   if (is_closing()) {
306     return SocketRequest::SOCKET_REQUEST_ERROR_CLOSED;
307   }
308 
309   if (pending_writes_.is_empty() || pending_writes_.back()->is_flushed()) {
310     if (!free_writes_.empty()) {
311       pending_writes_.add_to_back(free_writes_.back());
312       free_writes_.pop_back();
313     } else {
314       pending_writes_.add_to_back(handler_->new_pending_write(this));
315     }
316   }
317 
318   return pending_writes_.back()->write(request);
319 }
320 
write_and_flush(SocketRequest * request)321 int32_t Socket::write_and_flush(SocketRequest* request) {
322   int32_t result = write(request);
323   if (result > 0) {
324     flush();
325   }
326   return result;
327 }
328 
flush()329 size_t Socket::flush() {
330   if (pending_writes_.is_empty()) return 0;
331 
332   return pending_writes_.back()->flush();
333 }
334 
is_closing() const335 bool Socket::is_closing() const {
336   return uv_is_closing(reinterpret_cast<const uv_handle_t*>(&tcp_)) != 0;
337 }
338 
close()339 void Socket::close() {
340   uv_handle_t* handle = reinterpret_cast<uv_handle_t*>(&tcp_);
341   if (!uv_is_closing(handle)) {
342     uv_close(handle, on_close);
343   }
344 }
345 
defunct()346 void Socket::defunct() {
347   close();
348   is_defunct_ = true;
349 }
350 
alloc_buffer(uv_handle_t * handle,size_t suggested_size,uv_buf_t * buf)351 void Socket::alloc_buffer(uv_handle_t* handle, size_t suggested_size, uv_buf_t* buf) {
352   Socket* socket = static_cast<Socket*>(handle->data);
353   socket->handler_->alloc_buffer(suggested_size, buf);
354 }
355 
on_read(uv_stream_t * client,ssize_t nread,const uv_buf_t * buf)356 void Socket::on_read(uv_stream_t* client, ssize_t nread, const uv_buf_t* buf) {
357   Socket* socket = static_cast<Socket*>(client->data);
358   socket->handle_read(nread, buf);
359 }
360 
handle_read(ssize_t nread,const uv_buf_t * buf)361 void Socket::handle_read(ssize_t nread, const uv_buf_t* buf) {
362   if (nread < 0) {
363     if (nread != UV_EOF) {
364       LOG_ERROR("Socket read error '%s'", uv_strerror(nread));
365     }
366     defunct();
367   }
368   handler_->on_read(this, nread, buf);
369 }
370 
on_close(uv_handle_t * handle)371 void Socket::on_close(uv_handle_t* handle) {
372   Socket* socket = static_cast<Socket*>(handle->data);
373   socket->handle_close();
374 }
375 
handle_close()376 void Socket::handle_close() {
377   LOG_DEBUG("Socket(%p) to host %s closed", static_cast<void*>(this), address_.to_string().c_str());
378 
379   while (!pending_writes_.is_empty()) {
380     SocketWriteBase* pending_write = pending_writes_.pop_front();
381     pending_write->on_close();
382     delete pending_write;
383   }
384 
385   if (handler_) {
386     handler_->on_close();
387   }
388   dec_ref();
389 }
390 
cleanup_free_writes()391 void Socket::cleanup_free_writes() {
392   for (SocketWriteVec::iterator i = free_writes_.begin(), end = free_writes_.end(); i != end; ++i) {
393     delete *i;
394   }
395 }
396