1 /*
2 Copyright (c) 2014-2017 DataStax
3
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
7
8 http://www.apache.org/licenses/LICENSE-2.0
9
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
15 */
16
17 #include "socket.hpp"
18
19 #include "logger.hpp"
20
// Size in bytes of the stack buffer that receives decrypted data on each
// SSL read pass.
#define SSL_READ_SIZE 8192
// Size in bytes of the stack staging buffer that is filled before each SSL
// encrypt call.
#define SSL_WRITE_SIZE 8192
// Size argument of the SmallVector that collects the encrypted uv_buf_t
// entries for a single write.
#define SSL_ENCRYPTED_BUFS_COUNT 16

// Maximum number of read buffers kept on the reuse list.
#define MAX_BUFFER_REUSE_NO 8
// Size in bytes of every reusable read buffer. The expression is
// parenthesized so the macro expands correctly inside larger expressions
// (e.g. `x / BUFFER_REUSE_SIZE`).
#define BUFFER_REUSE_SIZE (64 * 1024)
27
using namespace datastax::internal;
using namespace datastax::internal::core;

// Vector of libuv buffer descriptors; used to build the buffer list that is
// handed to uv_write().
typedef Vector<uv_buf_t> UvBufVec;
32
33 /**
34 * A basic socket write handler.
35 */
36 class SocketWrite : public SocketWriteBase {
37 public:
SocketWrite(Socket * socket)38 SocketWrite(Socket* socket)
39 : SocketWriteBase(socket) {}
40
41 size_t flush();
42 };
43
flush()44 size_t SocketWrite::flush() {
45 size_t total = 0;
46 if (!is_flushed_ && !buffers_.empty()) {
47 UvBufVec bufs;
48
49 bufs.reserve(buffers_.size());
50
51 for (BufferVec::const_iterator it = buffers_.begin(), end = buffers_.end(); it != end; ++it) {
52 total += it->size();
53 bufs.push_back(uv_buf_init(const_cast<char*>(it->data()), it->size()));
54 }
55
56 is_flushed_ = true;
57 uv_stream_t* sock_stream = reinterpret_cast<uv_stream_t*>(tcp());
58 uv_write(&req_, sock_stream, bufs.data(), bufs.size(), SocketWrite::on_write);
59 }
60 return total;
61 }
62
~SocketHandler()63 SocketHandler::~SocketHandler() {
64 while (!buffer_reuse_list_.empty()) {
65 uv_buf_t buf = buffer_reuse_list_.top();
66 Memory::free(buf.base);
67 buffer_reuse_list_.pop();
68 }
69 }
70
new_pending_write(Socket * socket)71 SocketWriteBase* SocketHandler::new_pending_write(Socket* socket) {
72 return new SocketWrite(socket);
73 }
74
alloc_buffer(size_t suggested_size,uv_buf_t * buf)75 void SocketHandler::alloc_buffer(size_t suggested_size, uv_buf_t* buf) {
76 if (suggested_size <= BUFFER_REUSE_SIZE) {
77 if (!buffer_reuse_list_.empty()) {
78 *buf = buffer_reuse_list_.top();
79 buffer_reuse_list_.pop();
80 } else {
81 *buf = uv_buf_init(reinterpret_cast<char*>(Memory::malloc(BUFFER_REUSE_SIZE)),
82 BUFFER_REUSE_SIZE);
83 }
84 } else {
85 *buf = uv_buf_init(reinterpret_cast<char*>(Memory::malloc(suggested_size)), suggested_size);
86 }
87 }
88
free_buffer(const uv_buf_t * buf)89 void SocketHandler::free_buffer(const uv_buf_t* buf) {
90 if (buf->len == BUFFER_REUSE_SIZE && buffer_reuse_list_.size() < MAX_BUFFER_REUSE_NO) {
91 buffer_reuse_list_.push(*buf);
92 return;
93 }
94 Memory::free(buf->base);
95 }
96
97 /**
98 * A SSL socket write handler.
99 */
100 class SslSocketWrite : public SocketWriteBase {
101 public:
SslSocketWrite(Socket * socket,SslSession * ssl_session)102 SslSocketWrite(Socket* socket, SslSession* ssl_session)
103 : SocketWriteBase(socket)
104 , ssl_session_(ssl_session)
105 , encrypted_size_(0) {}
106
107 virtual size_t flush();
108
109 private:
110 void encrypt();
111
112 static void on_write(uv_write_t* req, int status);
113
114 private:
115 SslSession* ssl_session_;
116 size_t encrypted_size_;
117 };
118
flush()119 size_t SslSocketWrite::flush() {
120 size_t total = 0;
121 if (!is_flushed_ && !buffers_.empty()) {
122 rb::RingBuffer::Position prev_pos = ssl_session_->outgoing().write_position();
123
124 encrypt();
125
126 SmallVector<uv_buf_t, SSL_ENCRYPTED_BUFS_COUNT> bufs;
127 total = encrypted_size_ = ssl_session_->outgoing().peek_multiple(prev_pos, &bufs);
128
129 LOG_TRACE("Sending %u encrypted bytes", static_cast<unsigned int>(encrypted_size_));
130
131 uv_stream_t* sock_stream = reinterpret_cast<uv_stream_t*>(tcp());
132 uv_write(&req_, sock_stream, bufs.data(), bufs.size(), SslSocketWrite::on_write);
133
134 is_flushed_ = true;
135 }
136 return total;
137 }
138
encrypt()139 void SslSocketWrite::encrypt() {
140 char buf[SSL_WRITE_SIZE];
141
142 size_t copied = 0;
143 size_t offset = 0;
144 size_t total = 0;
145
146 BufferVec::const_iterator it = buffers_.begin(), end = buffers_.end();
147
148 LOG_TRACE("Copying %u bufs", static_cast<unsigned int>(buffers_.size()));
149
150 bool is_done = (it == end);
151
152 while (!is_done) {
153 assert(it->size() > 0);
154 size_t size = it->size();
155
156 size_t to_copy = size - offset;
157 size_t available = SSL_WRITE_SIZE - copied;
158 if (available < to_copy) {
159 to_copy = available;
160 }
161
162 memcpy(buf + copied, it->data() + offset, to_copy);
163
164 copied += to_copy;
165 offset += to_copy;
166 total += to_copy;
167
168 if (offset == size) {
169 ++it;
170 offset = 0;
171 }
172
173 is_done = (it == end);
174
175 if (is_done || copied == SSL_WRITE_SIZE) {
176 int rc = ssl_session_->encrypt(buf, copied);
177 if (rc <= 0 && ssl_session_->has_error()) {
178 LOG_ERROR("Unable to encrypt data: %s", ssl_session_->error_message().c_str());
179 socket_->defunct();
180 return;
181 }
182 copied = 0;
183 }
184 }
185
186 LOG_TRACE("Copied %u bytes for encryption", static_cast<unsigned int>(total));
187 }
188
on_write(uv_write_t * req,int status)189 void SslSocketWrite::on_write(uv_write_t* req, int status) {
190 if (status == 0) {
191 SslSocketWrite* socket_write = static_cast<SslSocketWrite*>(req->data);
192 socket_write->ssl_session_->outgoing().read(NULL, socket_write->encrypted_size_);
193 }
194 SocketWriteBase::on_write(req, status);
195 }
196
new_pending_write(Socket * socket)197 SocketWriteBase* SslSocketHandler::new_pending_write(Socket* socket) {
198 return new SslSocketWrite(socket, ssl_session_.get());
199 }
200
alloc_buffer(size_t suggested_size,uv_buf_t * buf)201 void SslSocketHandler::alloc_buffer(size_t suggested_size, uv_buf_t* buf) {
202 buf->base = ssl_session_->incoming().peek_writable(&suggested_size);
203 buf->len = suggested_size;
204 }
205
on_read(Socket * socket,ssize_t nread,const uv_buf_t * buf)206 void SslSocketHandler::on_read(Socket* socket, ssize_t nread, const uv_buf_t* buf) {
207 if (nread < 0) return;
208
209 ssl_session_->incoming().commit(nread);
210 char decrypted[SSL_READ_SIZE];
211 int rc = 0;
212 while ((rc = ssl_session_->decrypt(decrypted, sizeof(decrypted))) > 0) {
213 on_ssl_read(socket, decrypted, rc);
214 }
215 if (rc <= 0 && ssl_session_->has_error()) {
216 if (ssl_session_->error_code() == CASS_ERROR_SSL_CLOSED) {
217 LOG_DEBUG("SSL session closed");
218 socket->close();
219 } else {
220 LOG_ERROR("Unable to decrypt data: %s", ssl_session_->error_message().c_str());
221 socket->defunct();
222 }
223 }
224 }
225
tcp()226 uv_tcp_t* SocketWriteBase::tcp() { return &socket_->tcp_; }
227
on_close()228 void SocketWriteBase::on_close() {
229 for (RequestVec::iterator i = requests_.begin(), end = requests_.end(); i != end; ++i) {
230 (*i)->on_close();
231 }
232 }
233
write(SocketRequest * request)234 int32_t SocketWriteBase::write(SocketRequest* request) {
235 size_t last_buffer_size = buffers_.size();
236 int32_t request_size = request->encode(&buffers_);
237 if (request_size <= 0) {
238 buffers_.resize(last_buffer_size); // Rollback
239 return request_size;
240 }
241
242 requests_.push_back(request);
243
244 return request_size;
245 }
246
on_write(uv_write_t * req,int status)247 void SocketWriteBase::on_write(uv_write_t* req, int status) {
248 SocketWriteBase* pending_write = static_cast<SocketWriteBase*>(req->data);
249 pending_write->handle_write(req, status);
250 }
251
handle_write(uv_write_t * req,int status)252 void SocketWriteBase::handle_write(uv_write_t* req, int status) {
253 Socket* socket = socket_;
254
255 if (status != 0) {
256 if (!socket->is_closing()) {
257 LOG_ERROR("Socket write error '%s'", uv_strerror(status));
258 socket->defunct();
259 }
260 }
261
262 if (socket->handler_) {
263 for (RequestVec::iterator i = requests_.begin(), end = requests_.end(); i != end; ++i) {
264 socket->handler_->on_write(socket, status, *i);
265 }
266 }
267
268 socket->pending_writes_.remove(this);
269
270 if (socket->free_writes_.size() < socket->max_reusable_write_objects_) {
271 clear();
272 socket->free_writes_.push_back(this);
273 } else {
274 delete this;
275 }
276
277 socket->flush();
278 }
279
Socket(const Address & address,size_t max_reusable_write_objects)280 Socket::Socket(const Address& address, size_t max_reusable_write_objects)
281 : is_defunct_(false)
282 , max_reusable_write_objects_(max_reusable_write_objects)
283 , address_(address) {
284 tcp_.data = this;
285 }
286
~Socket()287 Socket::~Socket() { cleanup_free_writes(); }
288
set_handler(SocketHandlerBase * handler)289 void Socket::set_handler(SocketHandlerBase* handler) {
290 handler_.reset(handler);
291 cleanup_free_writes();
292 free_writes_.clear();
293 if (handler_) {
294 uv_read_start(reinterpret_cast<uv_stream_t*>(&tcp_), Socket::alloc_buffer, Socket::on_read);
295 } else {
296 uv_read_stop(reinterpret_cast<uv_stream_t*>(&tcp_));
297 }
298 }
299
write(SocketRequest * request)300 int32_t Socket::write(SocketRequest* request) {
301 if (!handler_) {
302 return SocketRequest::SOCKET_REQUEST_ERROR_NO_HANDLER;
303 }
304
305 if (is_closing()) {
306 return SocketRequest::SOCKET_REQUEST_ERROR_CLOSED;
307 }
308
309 if (pending_writes_.is_empty() || pending_writes_.back()->is_flushed()) {
310 if (!free_writes_.empty()) {
311 pending_writes_.add_to_back(free_writes_.back());
312 free_writes_.pop_back();
313 } else {
314 pending_writes_.add_to_back(handler_->new_pending_write(this));
315 }
316 }
317
318 return pending_writes_.back()->write(request);
319 }
320
write_and_flush(SocketRequest * request)321 int32_t Socket::write_and_flush(SocketRequest* request) {
322 int32_t result = write(request);
323 if (result > 0) {
324 flush();
325 }
326 return result;
327 }
328
flush()329 size_t Socket::flush() {
330 if (pending_writes_.is_empty()) return 0;
331
332 return pending_writes_.back()->flush();
333 }
334
is_closing() const335 bool Socket::is_closing() const {
336 return uv_is_closing(reinterpret_cast<const uv_handle_t*>(&tcp_)) != 0;
337 }
338
close()339 void Socket::close() {
340 uv_handle_t* handle = reinterpret_cast<uv_handle_t*>(&tcp_);
341 if (!uv_is_closing(handle)) {
342 uv_close(handle, on_close);
343 }
344 }
345
defunct()346 void Socket::defunct() {
347 close();
348 is_defunct_ = true;
349 }
350
alloc_buffer(uv_handle_t * handle,size_t suggested_size,uv_buf_t * buf)351 void Socket::alloc_buffer(uv_handle_t* handle, size_t suggested_size, uv_buf_t* buf) {
352 Socket* socket = static_cast<Socket*>(handle->data);
353 socket->handler_->alloc_buffer(suggested_size, buf);
354 }
355
on_read(uv_stream_t * client,ssize_t nread,const uv_buf_t * buf)356 void Socket::on_read(uv_stream_t* client, ssize_t nread, const uv_buf_t* buf) {
357 Socket* socket = static_cast<Socket*>(client->data);
358 socket->handle_read(nread, buf);
359 }
360
handle_read(ssize_t nread,const uv_buf_t * buf)361 void Socket::handle_read(ssize_t nread, const uv_buf_t* buf) {
362 if (nread < 0) {
363 if (nread != UV_EOF) {
364 LOG_ERROR("Socket read error '%s'", uv_strerror(nread));
365 }
366 defunct();
367 }
368 handler_->on_read(this, nread, buf);
369 }
370
on_close(uv_handle_t * handle)371 void Socket::on_close(uv_handle_t* handle) {
372 Socket* socket = static_cast<Socket*>(handle->data);
373 socket->handle_close();
374 }
375
handle_close()376 void Socket::handle_close() {
377 LOG_DEBUG("Socket(%p) to host %s closed", static_cast<void*>(this), address_.to_string().c_str());
378
379 while (!pending_writes_.is_empty()) {
380 SocketWriteBase* pending_write = pending_writes_.pop_front();
381 pending_write->on_close();
382 delete pending_write;
383 }
384
385 if (handler_) {
386 handler_->on_close();
387 }
388 dec_ref();
389 }
390
cleanup_free_writes()391 void Socket::cleanup_free_writes() {
392 for (SocketWriteVec::iterator i = free_writes_.begin(), end = free_writes_.end(); i != end; ++i) {
393 delete *i;
394 }
395 }
396