1 /*
2   Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
3 
4   This program is free software; you can redistribute it and/or modify
5   it under the terms of the GNU General Public License, version 2.0,
6   as published by the Free Software Foundation.
7 
8   This program is also distributed with certain software (including
9   but not limited to OpenSSL) that is licensed under separate terms,
10   as designated in a particular file or component or in included license
11   documentation.  The authors of MySQL hereby grant you an additional
12   permission to link the program and your derivative works with the
13   separately licensed software that they have included with MySQL.
14 
15   This program is distributed in the hope that it will be useful,
16   but WITHOUT ANY WARRANTY; without even the implied warranty of
17   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
18   GNU General Public License for more details.
19 
20   You should have received a copy of the GNU General Public License
21   along with this program; if not, write to the Free Software
22   Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
23 */
24 
25 #include "classic_mock_session.h"
26 
27 #include <thread>
28 #include "mysql_protocol_utils.h"
29 
30 #include "mysql/harness/logging/logging.h"
31 #include "mysqld_error.h"
32 IMPORT_LOG_FUNCTIONS()
33 
34 namespace server_mock {
35 
process_handshake()36 bool MySQLServerMockSessionClassic::process_handshake() {
37   using namespace mysql_protocol;
38 
39   bool is_first_packet = true;
40 
41   while (!killed_) {
42     std::vector<uint8_t> payload;
43     if (!is_first_packet) {
44       protocol_decoder_.read_message(client_socket_);
45       seq_no_ = protocol_decoder_.packet_seq() + 1;
46       payload = protocol_decoder_.get_payload();
47     }
48     is_first_packet = false;
49     if (true == handle_handshake(json_reader_->handle_handshake(payload))) {
50       // handshake is done
51       return true;
52     }
53   }
54 
55   return false;
56 }
57 
process_statements()58 bool MySQLServerMockSessionClassic::process_statements() {
59   using mysql_protocol::Command;
60 
61   while (!killed_) {
62     protocol_decoder_.read_message(client_socket_);
63     seq_no_ = protocol_decoder_.packet_seq() + 1;
64     auto cmd = protocol_decoder_.get_command_type();
65     switch (cmd) {
66       case Command::QUERY: {
67         std::string statement_received = protocol_decoder_.get_statement();
68 
69         try {
70           handle_statement(json_reader_->handle_statement(statement_received));
71         } catch (const std::exception &e) {
72           // handling statement failed. Return the error to the client
73           std::this_thread::sleep_for(json_reader_->get_default_exec_time());
74           log_error("executing statement failed: %s", e.what());
75           send_error(ER_PARSE_ERROR,
76                      std::string("executing statement failed: ") + e.what());
77 
78           // assume the connection is broken
79           return true;
80         }
81       } break;
82       case Command::QUIT:
83         log_info("received QUIT command from the client");
84         return true;
85       default:
86         std::cerr << "received unsupported command from the client: "
87                   << static_cast<int>(cmd) << "\n";
88         std::this_thread::sleep_for(json_reader_->get_default_exec_time());
89         send_error(ER_PARSE_ERROR,
90                    "Unsupported command: " + std::to_string(cmd));
91     }
92   }
93 
94   return true;
95 }
96 
handle_handshake(const HandshakeResponse & response)97 bool MySQLServerMockSessionClassic::handle_handshake(
98     const HandshakeResponse &response) {
99   using ResponseType = HandshakeResponse::ResponseType;
100 
101   std::this_thread::sleep_for(response.exec_time);
102 
103   switch (response.response_type) {
104     case ResponseType::GREETING: {
105       Greeting *greeting_resp =
106           dynamic_cast<Greeting *>(response.response.get());
107       harness_assert(greeting_resp);
108 
109       send_packet(
110           client_socket_,
111           protocol_encoder_.encode_greetings_message(
112               seq_no_++, greeting_resp->server_version(),
113               greeting_resp->connection_id(), greeting_resp->auth_data(),
114               greeting_resp->capabilities(), greeting_resp->auth_method(),
115               greeting_resp->character_set(), greeting_resp->status_flags()));
116     } break;
117     case ResponseType::AUTH_SWITCH: {
118       AuthSwitch *auth_switch_resp =
119           dynamic_cast<AuthSwitch *>(response.response.get());
120       harness_assert(auth_switch_resp);
121 
122       send_packet(client_socket_, protocol_encoder_.encode_auth_switch_message(
123                                       seq_no_++, auth_switch_resp->method(),
124                                       auth_switch_resp->data()));
125     } break;
126     case ResponseType::AUTH_FAST: {
127       // sha256-fast-auth is
128       // - 0x03
129       // - ok
130       send_packet(client_socket_,
131                   protocol_encoder_.encode_auth_fast_message(seq_no_++));
132 
133       send_ok(0, 0, 0, 0);
134 
135       return true;
136     }
137     case ResponseType::OK: {
138       OkResponse *ok_resp = dynamic_cast<OkResponse *>(response.response.get());
139       harness_assert(ok_resp);
140 
141       send_ok(0, ok_resp->last_insert_id, 0, ok_resp->warning_count);
142 
143       return true;
144     }
145     case ResponseType::ERROR: {
146       ErrorResponse *err_resp =
147           dynamic_cast<ErrorResponse *>(response.response.get());
148       harness_assert(err_resp);
149       send_error(err_resp->code, err_resp->msg);
150 
151       return true;
152     }
153     default:
154       throw std::runtime_error(
155           "Unsupported command in handle_handshake(): " +
156           std::to_string(static_cast<int>(response.response_type)));
157   }
158 
159   return false;
160 }
161 
send_error(const uint16_t error_code,const std::string & error_msg,const std::string & sql_state)162 void MySQLServerMockSessionClassic::send_error(const uint16_t error_code,
163                                                const std::string &error_msg,
164                                                const std::string &sql_state) {
165   auto buf = protocol_encoder_.encode_error_message(seq_no_++, error_code,
166                                                     sql_state, error_msg);
167   send_packet(client_socket_, buf);
168 }
169 
send_ok(const uint64_t affected_rows,const uint64_t last_insert_id,const uint16_t server_status,const uint16_t warning_count)170 void MySQLServerMockSessionClassic::send_ok(const uint64_t affected_rows,
171                                             const uint64_t last_insert_id,
172                                             const uint16_t server_status,
173                                             const uint16_t warning_count) {
174   auto buf = protocol_encoder_.encode_ok_message(
175       seq_no_++, affected_rows, last_insert_id, server_status, warning_count);
176   send_packet(client_socket_, buf);
177 }
178 
send_resultset(const ResultsetResponse & response,const std::chrono::microseconds delay_ms)179 void MySQLServerMockSessionClassic::send_resultset(
180     const ResultsetResponse &response,
181     const std::chrono::microseconds delay_ms) {
182   auto buf = protocol_encoder_.encode_columns_number_message(
183       seq_no_++, response.columns.size());
184   std::this_thread::sleep_for(delay_ms);
185   send_packet(client_socket_, buf);
186   for (const auto &column : response.columns) {
187     auto col_buf =
188         protocol_encoder_.encode_column_meta_message(seq_no_++, column);
189     send_packet(client_socket_, col_buf);
190   }
191   buf = protocol_encoder_.encode_eof_message(seq_no_++);
192   send_packet(client_socket_, buf);
193 
194   for (size_t i = 0; i < response.rows.size(); ++i) {
195     auto res_buf = protocol_encoder_.encode_row_message(
196         seq_no_++, response.columns, response.rows[i]);
197     send_packet(client_socket_, res_buf);
198   }
199   buf = protocol_encoder_.encode_eof_message(seq_no_++);
200   send_packet(client_socket_, buf);
201 }
202 
MySQLServerMockSessionClassic(const socket_t client_sock,std::unique_ptr<StatementReaderBase> statement_processor,const bool debug_mode)203 MySQLServerMockSessionClassic::MySQLServerMockSessionClassic(
204     const socket_t client_sock,
205     std::unique_ptr<StatementReaderBase> statement_processor,
206     const bool debug_mode)
207     : MySQLServerMockSession(client_sock, std::move(statement_processor),
208                              debug_mode),
209       protocol_decoder_{&read_packet} {}
210 
211 }  // namespace server_mock
212