1 /*
2   Copyright (c) 2019, 2020, Oracle and/or its affiliates. All rights reserved.
3 
4   This program is free software; you can redistribute it and/or modify
5   it under the terms of the GNU General Public License, version 2.0,
6   as published by the Free Software Foundation.
7 
8   This program is also distributed with certain software (including
9   but not limited to OpenSSL) that is licensed under separate terms,
10   as designated in a particular file or component or in included license
11   documentation.  The authors of MySQL hereby grant you an additional
12   permission to link the program and your derivative works with the
13   separately licensed software that they have included with MySQL.
14 
15   This program is distributed in the hope that it will be useful,
16   but WITHOUT ANY WARRANTY; without even the implied warranty of
17   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
18   GNU General Public License for more details.
19 
20   You should have received a copy of the GNU General Public License
21   along with this program; if not, write to the Free Software
22   Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
23 */
24 
25 #include "router_component_testutils.h"
26 
27 #ifdef RAPIDJSON_NO_SIZETYPEDEFINE
28 // if we build within the server, it will set RAPIDJSON_NO_SIZETYPEDEFINE
29 // globally and require to include my_rapidjson_size_t.h
30 #include "my_rapidjson_size_t.h"
31 #endif
#include <gmock/gmock.h>
#include <rapidjson/document.h>
#include <rapidjson/stringbuffer.h>
#include <chrono>
#include <fstream>
#include <thread>
37 
38 #include "router_test_helpers.h"
39 
40 #include "mock_server_rest_client.h"
41 
42 namespace {
43 // default allocator for rapidJson (MemoryPoolAllocator) is broken for
44 // SparcSolaris
45 using JsonAllocator = rapidjson::CrtAllocator;
46 using JsonValue = rapidjson::GenericValue<rapidjson::UTF8<>, JsonAllocator>;
47 using JsonDocument =
48     rapidjson::GenericDocument<rapidjson::UTF8<>, JsonAllocator>;
49 using JsonStringBuffer =
50     rapidjson::GenericStringBuffer<rapidjson::UTF8<>, rapidjson::CrtAllocator>;
51 }  // namespace
52 
53 using namespace std::chrono_literals;
54 
/**
 * Builds the text of a metadata-cache dynamic state file as one JSON line.
 *
 * @param replication_goup_id     value stored under "group-replication-id"
 *                                (inserted verbatim, not JSON-escaped)
 * @param metadata_servers_ports  each port becomes a
 *                                "mysql://127.0.0.1:<port>" entry of the
 *                                "cluster-metadata-servers" array
 * @param view_id                 emitted as "view-id" only when non-zero
 * @return the generated JSON text
 */
std::string create_state_file_content(
    const std::string &replication_goup_id,
    const std::vector<uint16_t> &metadata_servers_ports,
    const unsigned view_id /*= 0*/) {
  // Quoted server URIs joined by "," (no trailing separator).
  std::string servers_json;
  const char *separator = "";
  for (const auto port : metadata_servers_ports) {
    servers_json.append(separator)
        .append("\"mysql://127.0.0.1:")
        .append(std::to_string(port))
        .append("\"");
    separator = ",";
  }

  std::string json = "{";
  json += R"("version": "1.0.0",)";
  json += R"("metadata-cache": {)";
  json += R"("group-replication-id": ")" + replication_goup_id + R"(",)";
  json += R"("cluster-metadata-servers": [)" + servers_json + "]";
  if (view_id > 0) {
    json += R"(, "view-id":)" + std::to_string(view_id);
  }
  // close "metadata-cache" and the top-level object
  json += "}}";
  return json;
}
82 
// Returns false from the enclosing (bool-returning) function when expr is
// false. The do { } while (0) wrapper makes the expansion a single statement,
// so it cannot capture a following `else` (dangling-else) when used in an
// unbraced if/else, and it requires the trailing semicolon at call sites.
#define CHECK_TRUE(expr)       \
  do {                         \
    if (!(expr)) return false; \
  } while (0)
85 
check_state_file_helper(const std::string & state_file_content,const std::string & expected_group_replication_id,const std::vector<uint16_t> expected_cluster_nodes,const unsigned expected_view_id,const std::string node_address)86 bool check_state_file_helper(const std::string &state_file_content,
87                              const std::string &expected_group_replication_id,
88                              const std::vector<uint16_t> expected_cluster_nodes,
89                              const unsigned expected_view_id /*= 0*/,
90                              const std::string node_address /*= "127.0.0.1"*/) {
91   JsonDocument json_doc;
92   json_doc.Parse(state_file_content.c_str());
93   const std::string kExpectedVersion = "1.0.0";
94 
95   CHECK_TRUE(json_doc.HasMember("version"));
96   CHECK_TRUE(json_doc["version"].IsString());
97   CHECK_TRUE(kExpectedVersion == json_doc["version"].GetString());
98 
99   CHECK_TRUE(json_doc.HasMember("metadata-cache"));
100   CHECK_TRUE(json_doc["metadata-cache"].IsObject());
101 
102   auto metadata_cache_section = json_doc["metadata-cache"].GetObject();
103 
104   CHECK_TRUE(metadata_cache_section.HasMember("group-replication-id"));
105   CHECK_TRUE(metadata_cache_section["group-replication-id"].IsString());
106   CHECK_TRUE(expected_group_replication_id ==
107              metadata_cache_section["group-replication-id"].GetString());
108 
109   if (expected_view_id > 0) {
110     CHECK_TRUE(metadata_cache_section.HasMember("view-id"));
111     CHECK_TRUE(metadata_cache_section["view-id"].IsInt());
112     CHECK_TRUE(
113         expected_view_id ==
114         static_cast<unsigned>(metadata_cache_section["view-id"].GetInt()));
115   }
116 
117   CHECK_TRUE(metadata_cache_section.HasMember("cluster-metadata-servers"));
118   CHECK_TRUE(metadata_cache_section["cluster-metadata-servers"].IsArray());
119   auto cluster_nodes =
120       metadata_cache_section["cluster-metadata-servers"].GetArray();
121   CHECK_TRUE(expected_cluster_nodes.size() == cluster_nodes.Size());
122   for (unsigned i = 0; i < cluster_nodes.Size(); ++i) {
123     CHECK_TRUE(cluster_nodes[i].IsString());
124     const std::string expected_cluster_node =
125         "mysql://" + node_address + ":" +
126         std::to_string(expected_cluster_nodes[i]);
127     CHECK_TRUE(expected_cluster_node == cluster_nodes[i].GetString());
128   }
129 
130   return true;
131 }
132 
check_state_file(const std::string & state_file,const std::string & expected_group_replication_id,const std::vector<uint16_t> expected_cluster_nodes,const unsigned expected_view_id,const std::string node_address)133 void check_state_file(const std::string &state_file,
134                       const std::string &expected_group_replication_id,
135                       const std::vector<uint16_t> expected_cluster_nodes,
136                       const unsigned expected_view_id /*= 0*/,
137                       const std::string node_address /*= "127.0.0.1"*/) {
138   bool result = false;
139   size_t steps = 0;
140   std::string state_file_content;
141   do {
142     state_file_content = get_file_output(state_file);
143     result = check_state_file_helper(
144         state_file_content, expected_group_replication_id,
145         expected_cluster_nodes, expected_view_id, node_address);
146     if (!result) {
147       std::this_thread::sleep_for(50ms);
148     }
149   } while ((!result) && (steps++ < 20));
150 
151   if (!result) {
152     std::string expected_cluster_nodes_str;
153     for (size_t i = 0; i < expected_cluster_nodes.size(); ++i) {
154       expected_cluster_nodes_str +=
155           std::to_string(expected_cluster_nodes[i]) + " ";
156     }
157 
158     FAIL() << "Unexpected state file content." << std::endl
159            << "expected_group_replication_id: " << expected_group_replication_id
160            << std::endl
161            << "expected_cluster_nodes: " << expected_cluster_nodes_str
162            << std::endl
163            << "expected_view_id: " << expected_view_id << std::endl
164            << "node_address: " << node_address << std::endl
165            << "state_file_content: " << state_file_content;
166   }
167 
168   // check that we have write access to the file
169   // just append it with an empty line, that will not break it
170   EXPECT_NO_THROW({
171     std::ofstream ofs(state_file, std::ios::app);
172     ofs << "\n";
173   });
174 }
175 
get_int_field_value(const std::string & json_string,const std::string & field_name)176 int get_int_field_value(const std::string &json_string,
177                         const std::string &field_name) {
178   rapidjson::Document json_doc;
179   json_doc.Parse(json_string.c_str());
180   if (!json_doc.HasMember(field_name.c_str())) {
181     // that can mean this has not been set yet
182     return 0;
183   }
184 
185   if (!json_doc[field_name.c_str()].IsInt()) {
186     // that can mean this has not been set yet
187     return 0;
188   }
189 
190   return json_doc[field_name.c_str()].GetInt();
191 }
192 
get_transaction_count(const std::string & json_string)193 int get_transaction_count(const std::string &json_string) {
194   return get_int_field_value(json_string, "transaction_count");
195 }
196 
wait_for_transaction_count(const uint16_t http_port,const int expected_queries_count,std::chrono::milliseconds timeout)197 bool wait_for_transaction_count(const uint16_t http_port,
198                                 const int expected_queries_count,
199                                 std::chrono::milliseconds timeout) {
200   const std::chrono::milliseconds kStep = 20ms;
201   do {
202     std::string server_globals =
203         MockServerRestClient(http_port).get_globals_as_json_string();
204     if (get_transaction_count(server_globals) >= expected_queries_count)
205       return true;
206     std::this_thread::sleep_for(kStep);
207     timeout -= kStep;
208   } while (timeout > 0ms);
209 
210   return false;
211 }
212 
wait_for_transaction_count_increase(const uint16_t http_port,const int increment_by,std::chrono::milliseconds timeout)213 bool wait_for_transaction_count_increase(const uint16_t http_port,
214                                          const int increment_by,
215                                          std::chrono::milliseconds timeout) {
216   std::string server_globals =
217       MockServerRestClient(http_port).get_globals_as_json_string();
218   int expected_queries_count =
219       get_transaction_count(server_globals) + increment_by;
220 
221   return wait_for_transaction_count(http_port, expected_queries_count, timeout);
222 }
223 
wait_connection_dropped(mysqlrouter::MySQLSession & session,std::chrono::milliseconds timeout)224 bool wait_connection_dropped(mysqlrouter::MySQLSession &session,
225                              std::chrono::milliseconds timeout) {
226   const auto kStep = 50ms;
227   do {
228     try {
229       session.query_one("select @@@port");
230     } catch (const mysqlrouter::MySQLSession::Error &) {
231       return true;
232     }
233 
234     std::this_thread::sleep_for(kStep);
235     timeout -= kStep;
236   } while (timeout >= 0ms);
237 
238   return false;
239 }
240