1 /*
2   Copyright (c) 2019, 2020, Oracle and/or its affiliates. All rights reserved.
3 
4   This program is free software; you can redistribute it and/or modify
5   it under the terms of the GNU General Public License, version 2.0,
6   as published by the Free Software Foundation.
7 
8   This program is also distributed with certain software (including
9   but not limited to OpenSSL) that is licensed under separate terms,
10   as designated in a particular file or component or in included license
11   documentation.  The authors of MySQL hereby grant you an additional
12   permission to link the program and your derivative works with the
13   separately licensed software that they have included with MySQL.
14 
15   This program is distributed in the hope that it will be useful,
16   but WITHOUT ANY WARRANTY; without even the implied warranty of
17   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
18   GNU General Public License for more details.
19 
20   You should have received a copy of the GNU General Public License
21   along with this program; if not, write to the Free Software
22   Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
23 */
24 
25 #ifdef RAPIDJSON_NO_SIZETYPEDEFINE
26 // if we build within the server, it will set RAPIDJSON_NO_SIZETYPEDEFINE
27 // globally and require to include my_rapidjson_size_t.h
28 #include "my_rapidjson_size_t.h"
29 #endif
30 
31 #include "x_mock_session.h"
32 
33 #include <google/protobuf/util/json_util.h>
34 #include <rapidjson/document.h>
35 #include <thread>
36 
37 #include "mysql/harness/logging/logging.h"
38 IMPORT_LOG_FUNCTIONS()
39 
40 #include "config.h"
41 #include "mysql_protocol_utils.h"
42 #include "mysqlx_error.h"
43 #include "mysqlxclient/xprotocol.h"
44 
45 namespace server_mock {
46 
47 struct MySQLServerMockSessionX::Impl {
Implserver_mock::MySQLServerMockSessionX::Impl48   Impl(socket_t client_socket, const XProtocolDecoder &protocol_decoder,
49        const std::vector<AsyncNotice> &async_notices)
50       : client_socket_(client_socket),
51         protocol_decoder_(protocol_decoder),
52         aync_notices_(async_notices) {}
53 
recv_headerserver_mock::MySQLServerMockSessionX::Impl54   bool recv_header(uint8_t *out_msg_id, std::size_t *out_buffer_size) {
55     union {
56       uint8_t header_buffer[5];
57       uint32_t payload_size;
58     };
59 
60     read_packet(client_socket_, &header_buffer[0], 5);
61 
62 #ifdef WORDS_BIGENDIAN
63     std::swap(header_buffer[0], header_buffer[3]);
64     std::swap(header_buffer[1], header_buffer[2]);
65 #endif
66 
67     *out_buffer_size = payload_size - 1;
68     *out_msg_id = header_buffer[4];
69 
70     return true;
71   }
72 
recv_single_messageserver_mock::MySQLServerMockSessionX::Impl73   std::unique_ptr<xcl::XProtocol::Message> recv_single_message(
74       xcl::XProtocol::Client_message_type_id *out_msg_id) {
75     std::size_t payload_size = 0;
76     uint8_t header_msg_id;
77 
78     const bool res = recv_header(&header_msg_id, &payload_size);
79     if (!res) return nullptr;
80 
81     std::unique_ptr<std::uint8_t[]> allocated_payload_buffer;
82     std::uint8_t *payload = nullptr;
83     if (payload_size > 0) {
84       allocated_payload_buffer.reset(new uint8_t[payload_size]);
85       payload = allocated_payload_buffer.get();
86       read_packet(client_socket_, payload, payload_size);
87     }
88 
89     *out_msg_id =
90         static_cast<xcl::XProtocol::Client_message_type_id>(header_msg_id);
91 
92     return protocol_decoder_.decode_message(header_msg_id, payload,
93                                             payload_size);
94   }
95 
client_socket_has_dataserver_mock::MySQLServerMockSessionX::Impl96   bool client_socket_has_data(const std::chrono::milliseconds timeout) {
97     return socket_has_data(client_socket_, static_cast<int>(timeout.count()));
98   }
99 
sendserver_mock::MySQLServerMockSessionX::Impl100   void send(const xcl::XProtocol::Server_message_type_id msg_id,
101             const xcl::XProtocol::Message &msg) {
102     std::string msg_buffer;
103     const std::uint8_t header_size = 5;
104 
105 #if (defined(GOOGLE_PROTOBUF_VERSION) && GOOGLE_PROTOBUF_VERSION > 3000000)
106     const std::size_t msg_size = msg.ByteSizeLong();
107 #else
108     const std::size_t msg_size = msg.ByteSize();
109 #endif
110 
111     msg_buffer.resize(msg_size + header_size);
112 
113     if (!msg.SerializeToArray(&msg_buffer[0] + header_size, msg_size)) {
114       throw std::runtime_error("Failed to serialize the message");
115     }
116 
117     const auto msg_size_to_buffer = static_cast<std::uint32_t>(msg_size + 1);
118 
119     memcpy(&msg_buffer[0], &msg_size_to_buffer, sizeof(std::uint32_t));
120 #ifdef WORDS_BIGENDIAN
121     std::swap(msg_buffer[0], msg_buffer[3]);
122     std::swap(msg_buffer[1], msg_buffer[2]);
123 #endif
124     msg_buffer[4] = msg_id;
125 
126     send_packet(client_socket_,
127                 reinterpret_cast<const std::uint8_t *>(msg_buffer.data()),
128                 msg_buffer.size());
129   }
130 
send_due_async_noticesserver_mock::MySQLServerMockSessionX::Impl131   void send_due_async_notices(
132       const std::chrono::time_point<std::chrono::system_clock> &start_time) {
133     const auto current_time = std::chrono::system_clock::now();
134     auto ms_passed = std::chrono::duration_cast<std::chrono::milliseconds>(
135                          current_time - start_time)
136                          .count();
137     for (auto it = aync_notices_.begin(); it != aync_notices_.end();) {
138       if (it->send_offset_ms.count() <= ms_passed) {
139         send_async_notice(*it);
140         it = aync_notices_.erase(it);
141       } else {
142         ++it;
143       }
144     }
145   }
146 
147  private:
148   stdx::expected<std::unique_ptr<xcl::XProtocol::Message>, std::string>
gr_state_changed_from_jsonserver_mock::MySQLServerMockSessionX::Impl149   gr_state_changed_from_json(const std::string &json_string) {
150     rapidjson::Document json_doc;
151     auto result{
152         std::make_unique<Mysqlx::Notice::GroupReplicationStateChanged>()};
153     json_doc.Parse(json_string.c_str());
154     if (json_doc.HasMember("type")) {
155       if (json_doc["type"].IsUint()) {
156         result->set_type(json_doc["type"].GetUint());
157       } else {
158         return stdx::make_unexpected(
159             "Invalid json type for field 'type', expected 'uint' got " +
160             std::to_string(json_doc["type"].GetType()));
161       }
162     }
163 
164     if (json_doc.HasMember("view_id")) {
165       if (json_doc["view_id"].IsString()) {
166         result->set_view_id(json_doc["view_id"].GetString());
167       } else {
168         return stdx::make_unexpected(
169             "Invalid json type for field 'view_id', expected 'string' got " +
170             std::to_string(json_doc["view_id"].GetType()));
171       }
172     }
173 
174     return std::unique_ptr<xcl::XProtocol::Message>(std::move(result));
175   }
176 
177   stdx::expected<std::unique_ptr<xcl::XProtocol::Message>, std::string>
get_notice_messageserver_mock::MySQLServerMockSessionX::Impl178   get_notice_message(const unsigned id, const std::string &payload) {
179     switch (id) {
180       case Mysqlx::Notice::Frame_Type_GROUP_REPLICATION_STATE_CHANGED: {
181         return gr_state_changed_from_json(payload);
182       }
183       // those we currently not use, if needed add a function encoding json
184       // string to the selected message type
185       case Mysqlx::Notice::Frame_Type_WARNING:
186       case Mysqlx::Notice::Frame_Type_SESSION_VARIABLE_CHANGED:
187       case Mysqlx::Notice::Frame_Type_SESSION_STATE_CHANGED:
188       default:
189         return stdx::make_unexpected("Unsupported notice id: " +
190                                      std::to_string(id));
191     }
192   }
193 
send_async_noticeserver_mock::MySQLServerMockSessionX::Impl194   void send_async_notice(const AsyncNotice &async_notice) {
195     Mysqlx::Notice::Frame notice_frame;
196     notice_frame.set_type(async_notice.type);
197     notice_frame.set_scope(async_notice.is_local
198                                ? Mysqlx::Notice::Frame_Scope_LOCAL
199                                : Mysqlx::Notice::Frame_Scope_GLOBAL);
200 
201     auto notice_msg =
202         get_notice_message(async_notice.type, async_notice.payload);
203 
204     if (!notice_msg)
205       throw std::runtime_error("Failed encoding notice message: " +
206                                notice_msg.error());
207 
208     notice_frame.set_payload(notice_msg.value()->SerializeAsString());
209 
210     send(xcl::XProtocol::Server_message_type_id::ServerMessages_Type_NOTICE,
211          notice_frame);
212   }
213 
214   socket_t client_socket_;
215   const XProtocolDecoder &protocol_decoder_;
216   std::vector<AsyncNotice> aync_notices_;
217 };
218 
MySQLServerMockSessionX(const socket_t client_sock,std::unique_ptr<StatementReaderBase> statement_processor,const bool debug_mode)219 MySQLServerMockSessionX::MySQLServerMockSessionX(
220     const socket_t client_sock,
221     std::unique_ptr<StatementReaderBase> statement_processor,
222     const bool debug_mode)
223     : MySQLServerMockSession(client_sock, std::move(statement_processor),
224                              debug_mode),
225       impl_(new MySQLServerMockSessionX::Impl(
226           client_sock, protocol_decoder_,
227           this->json_reader_->get_async_notices())) {}
228 
// Defined out-of-line, after the definition of Impl above, so that Impl is a
// complete type at the point where impl_ is destroyed (presumably impl_ is a
// std::unique_ptr<Impl> declared in the header — confirm).
MySQLServerMockSessionX::~MySQLServerMockSessionX() = default;
230 
process_handshake()231 bool MySQLServerMockSessionX::process_handshake() {
232   xcl::XProtocol::Client_message_type_id out_msg_id;
233   bool done = false;
234   while (!done) {
235     auto msg = impl_->recv_single_message(&out_msg_id);
236     switch (out_msg_id) {
237       case Mysqlx::ClientMessages::CON_CAPABILITIES_SET: {
238         Mysqlx::Connection::CapabilitiesSet *capab_msg =
239             dynamic_cast<Mysqlx::Connection::CapabilitiesSet *>(msg.get());
240         harness_assert(capab_msg != nullptr);
241         bool tls_request = false;
242         const auto capabilities = capab_msg->capabilities();
243         for (int i = 0; i < capabilities.capabilities_size(); ++i) {
244           const auto capability = capabilities.capabilities(i);
245           if (capability.name() == "tls") tls_request = true;
246         }
247 
248         // we do not support TLS so if the client requested it
249         // we need to reject it
250         if (tls_request) {
251           send_error(ER_X_CAPABILITIES_PREPARE_FAILED,
252                      "Capability prepare failed for tls");
253         } else {
254           Mysqlx::Ok ok_msg;
255           impl_->send(Mysqlx::ServerMessages::OK, ok_msg);
256         }
257         break;
258       }
259       case Mysqlx::ClientMessages::CON_CAPABILITIES_GET: {
260         Mysqlx::Connection::Capabilities msg_capab;
261         impl_->send(Mysqlx::ServerMessages::CONN_CAPABILITIES, msg_capab);
262         break;
263       }
264       case Mysqlx::ClientMessages::SESS_AUTHENTICATE_START: {
265         Mysqlx::Session::AuthenticateContinue msg_auth_cont;
266         msg_auth_cont.set_auth_data("abcd");
267         impl_->send(Mysqlx::ServerMessages::SESS_AUTHENTICATE_CONTINUE,
268                     msg_auth_cont);
269         break;
270       }
271       case Mysqlx::ClientMessages::SESS_AUTHENTICATE_CONTINUE: {
272         Mysqlx::Session::AuthenticateOk msg_auth_ok;
273         impl_->send(Mysqlx::ServerMessages::SESS_AUTHENTICATE_OK, msg_auth_ok);
274         done = true;
275         break;
276       }
277       case Mysqlx::ClientMessages::CON_CLOSE: {
278         return false;
279       }
280       default:
281         done = true;
282         break;
283     }
284   }
285 
286   return true;
287 }
288 
process_statements()289 bool MySQLServerMockSessionX::process_statements() {
290   const auto kTimerResolution = std::chrono::milliseconds(10);
291   const auto start_time = std::chrono::system_clock::now();
292   while (!killed_) {
293     impl_->send_due_async_notices(start_time);
294     if (!impl_->client_socket_has_data(kTimerResolution)) {
295       continue;
296     }
297 
298     xcl::XProtocol::Client_message_type_id out_msg_id;
299     auto msg = impl_->recv_single_message(&out_msg_id);
300     switch (out_msg_id) {
301       case Mysqlx::ClientMessages::SQL_STMT_EXECUTE: {
302         Mysqlx::Sql::StmtExecute *msg_stmt_execute =
303             dynamic_cast<Mysqlx::Sql::StmtExecute *>(msg.get());
304         harness_assert(msg_stmt_execute != nullptr);
305         const auto statement_received = msg_stmt_execute->stmt();
306         try {
307           handle_statement(json_reader_->handle_statement(statement_received));
308         } catch (const std::exception &e) {
309           // handling statement failed. Return the error to the client
310           std::this_thread::sleep_for(json_reader_->get_default_exec_time());
311           send_error(1064,
312                      std::string("executing statement failed: ") + e.what());
313 
314           // assume the connection is broken
315           return true;
316         }
317 
318       } break;
319 
320       case Mysqlx::ClientMessages::CON_CLOSE: {
321         log_info("received QUIT command from the client");
322         return true;
323       }
324 
325       default:
326         log_error("received unsupported message from the x-client: %d",
327                   static_cast<int>(out_msg_id));
328 
329         std::this_thread::sleep_for(json_reader_->get_default_exec_time());
330         send_error(1064, "Unsupported command: " + std::to_string(out_msg_id));
331     }
332   }
333 
334   return true;
335 }
336 
send_error(uint16_t error_code,const std::string & error_msg,const std::string & sql_state)337 void MySQLServerMockSessionX::send_error(uint16_t error_code,
338                                          const std::string &error_msg,
339                                          const std::string &sql_state) {
340   Mysqlx::Error err_msg;
341   protocol_encoder_.encode_error(err_msg, error_code, error_msg, sql_state);
342 
343   impl_->send(Mysqlx::ServerMessages::ERROR, err_msg);
344 }
345 
send_ok(const uint64_t,const uint64_t,const uint16_t,const uint16_t)346 void MySQLServerMockSessionX::send_ok(
347     const uint64_t /*affected_rows*/,  // TODO: notice with this data?
348     const uint64_t /*last_insert_id*/, const uint16_t /*server_status*/,
349     const uint16_t /*warning_count*/) {
350   Mysqlx::Sql::StmtExecuteOk ok_msg;
351   impl_->send(Mysqlx::ServerMessages::SQL_STMT_EXECUTE_OK, ok_msg);
352 }
353 
send_resultset(const ResultsetResponse & response,const std::chrono::microseconds delay_ms)354 void MySQLServerMockSessionX::send_resultset(
355     const ResultsetResponse &response,
356     const std::chrono::microseconds delay_ms) {
357   std::this_thread::sleep_for(delay_ms);
358 
359   for (const auto &column : response.columns) {
360     Mysqlx::Resultset::ColumnMetaData metadata_msg;
361     protocol_encoder_.encode_metadata(metadata_msg, column);
362     impl_->send(Mysqlx::ServerMessages::RESULTSET_COLUMN_META_DATA,
363                 metadata_msg);
364   }
365 
366   for (size_t i = 0; i < response.rows.size(); ++i) {
367     auto row = response.rows[i];
368     if (response.columns.size() != row.size()) {
369       throw std::runtime_error(
370           std::string("columns_info.size() != row_values.size() ") +
371           std::to_string(response.columns.size()) + std::string("!=") +
372           std::to_string(row.size()));
373     }
374     Mysqlx::Resultset::Row row_msg;
375     for (const auto &field : row) {
376       const bool is_null = !field.first;
377       protocol_encoder_.encode_row_field(
378           row_msg, protocol_encoder_.column_type_to_x(response.columns[i].type),
379           field.second, is_null);
380     }
381     impl_->send(Mysqlx::ServerMessages::RESULTSET_ROW, row_msg);
382   }
383 
384   Mysqlx::Resultset::FetchDone fetch_done_msg;
385   impl_->send(Mysqlx::ServerMessages::RESULTSET_FETCH_DONE, fetch_done_msg);
386   send_ok();
387 }
388 
389 }  // namespace server_mock
390