1 /*
2 Copyright (c) 2019, 2020, Oracle and/or its affiliates. All rights reserved.
3
4 This program is free software; you can redistribute it and/or modify
5 it under the terms of the GNU General Public License, version 2.0,
6 as published by the Free Software Foundation.
7
8 This program is also distributed with certain software (including
9 but not limited to OpenSSL) that is licensed under separate terms,
10 as designated in a particular file or component or in included license
11 documentation. The authors of MySQL hereby grant you an additional
12 permission to link the program and your derivative works with the
13 separately licensed software that they have included with MySQL.
14
15 This program is distributed in the hope that it will be useful,
16 but WITHOUT ANY WARRANTY; without even the implied warranty of
17 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 GNU General Public License for more details.
19
20 You should have received a copy of the GNU General Public License
21 along with this program; if not, write to the Free Software
22 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
23 */
24
25 #include "router_component_testutils.h"
26
27 #ifdef RAPIDJSON_NO_SIZETYPEDEFINE
28 // if we build within the server, it will set RAPIDJSON_NO_SIZETYPEDEFINE
29 // globally and require to include my_rapidjson_size_t.h
30 #include "my_rapidjson_size_t.h"
31 #endif
32 #include <gmock/gmock.h>
33 #include <rapidjson/document.h>
34 #include <rapidjson/stringbuffer.h>
35 #include <fstream>
36 #include <thread>
37
38 #include "router_test_helpers.h"
39
40 #include "mock_server_rest_client.h"
41
42 namespace {
43 // default allocator for rapidJson (MemoryPoolAllocator) is broken for
44 // SparcSolaris
45 using JsonAllocator = rapidjson::CrtAllocator;
46 using JsonValue = rapidjson::GenericValue<rapidjson::UTF8<>, JsonAllocator>;
47 using JsonDocument =
48 rapidjson::GenericDocument<rapidjson::UTF8<>, JsonAllocator>;
49 using JsonStringBuffer =
50 rapidjson::GenericStringBuffer<rapidjson::UTF8<>, rapidjson::CrtAllocator>;
51 } // namespace
52
53 using namespace std::chrono_literals;
54
/**
 * Builds the JSON text of a metadata-cache state file.
 *
 * @param replication_goup_id value written as "group-replication-id"
 * @param metadata_servers_ports ports written, in order, as
 *        "mysql://127.0.0.1:<port>" entries of "cluster-metadata-servers"
 * @param view_id value written as "view-id"; the member is omitted when 0
 * @return the state file content (single line, no trailing newline)
 */
std::string create_state_file_content(
    const std::string &replication_goup_id,
    const std::vector<uint16_t> &metadata_servers_ports,
    const unsigned view_id /*= 0*/) {
  // comma-separated list of quoted metadata-server URIs
  std::string servers_json;
  const char *separator = "";
  for (const uint16_t port : metadata_servers_ports) {
    servers_json += separator;
    servers_json += "\"mysql://127.0.0.1:";
    servers_json += std::to_string(port);
    servers_json += '"';
    separator = ",";
  }

  std::string json = "{";
  json += R"("version": "1.0.0",)";
  json += R"("metadata-cache": {)";
  json += R"("group-replication-id": ")";
  json += replication_goup_id;
  json += R"(",)";
  json += R"("cluster-metadata-servers": [)";
  json += servers_json;
  json += "]";
  // "view-id" is only emitted for a non-zero view id
  if (view_id > 0) {
    json += R"(, "view-id":)";
    json += std::to_string(view_id);
  }
  json += "}}";
  return json;
}
82
// Makes the enclosing bool-returning function return false when `expr` does
// not hold. Wrapped in do { } while (0) so that `CHECK_TRUE(x);` is a single
// statement: a bare `if` would capture a following `else` (dangling-else)
// when used inside an un-braced if/else.
#define CHECK_TRUE(expr)       \
  do {                         \
    if (!(expr)) return false; \
  } while (0)
85
check_state_file_helper(const std::string & state_file_content,const std::string & expected_group_replication_id,const std::vector<uint16_t> expected_cluster_nodes,const unsigned expected_view_id,const std::string node_address)86 bool check_state_file_helper(const std::string &state_file_content,
87 const std::string &expected_group_replication_id,
88 const std::vector<uint16_t> expected_cluster_nodes,
89 const unsigned expected_view_id /*= 0*/,
90 const std::string node_address /*= "127.0.0.1"*/) {
91 JsonDocument json_doc;
92 json_doc.Parse(state_file_content.c_str());
93 const std::string kExpectedVersion = "1.0.0";
94
95 CHECK_TRUE(json_doc.HasMember("version"));
96 CHECK_TRUE(json_doc["version"].IsString());
97 CHECK_TRUE(kExpectedVersion == json_doc["version"].GetString());
98
99 CHECK_TRUE(json_doc.HasMember("metadata-cache"));
100 CHECK_TRUE(json_doc["metadata-cache"].IsObject());
101
102 auto metadata_cache_section = json_doc["metadata-cache"].GetObject();
103
104 CHECK_TRUE(metadata_cache_section.HasMember("group-replication-id"));
105 CHECK_TRUE(metadata_cache_section["group-replication-id"].IsString());
106 CHECK_TRUE(expected_group_replication_id ==
107 metadata_cache_section["group-replication-id"].GetString());
108
109 if (expected_view_id > 0) {
110 CHECK_TRUE(metadata_cache_section.HasMember("view-id"));
111 CHECK_TRUE(metadata_cache_section["view-id"].IsInt());
112 CHECK_TRUE(
113 expected_view_id ==
114 static_cast<unsigned>(metadata_cache_section["view-id"].GetInt()));
115 }
116
117 CHECK_TRUE(metadata_cache_section.HasMember("cluster-metadata-servers"));
118 CHECK_TRUE(metadata_cache_section["cluster-metadata-servers"].IsArray());
119 auto cluster_nodes =
120 metadata_cache_section["cluster-metadata-servers"].GetArray();
121 CHECK_TRUE(expected_cluster_nodes.size() == cluster_nodes.Size());
122 for (unsigned i = 0; i < cluster_nodes.Size(); ++i) {
123 CHECK_TRUE(cluster_nodes[i].IsString());
124 const std::string expected_cluster_node =
125 "mysql://" + node_address + ":" +
126 std::to_string(expected_cluster_nodes[i]);
127 CHECK_TRUE(expected_cluster_node == cluster_nodes[i].GetString());
128 }
129
130 return true;
131 }
132
check_state_file(const std::string & state_file,const std::string & expected_group_replication_id,const std::vector<uint16_t> expected_cluster_nodes,const unsigned expected_view_id,const std::string node_address)133 void check_state_file(const std::string &state_file,
134 const std::string &expected_group_replication_id,
135 const std::vector<uint16_t> expected_cluster_nodes,
136 const unsigned expected_view_id /*= 0*/,
137 const std::string node_address /*= "127.0.0.1"*/) {
138 bool result = false;
139 size_t steps = 0;
140 std::string state_file_content;
141 do {
142 state_file_content = get_file_output(state_file);
143 result = check_state_file_helper(
144 state_file_content, expected_group_replication_id,
145 expected_cluster_nodes, expected_view_id, node_address);
146 if (!result) {
147 std::this_thread::sleep_for(50ms);
148 }
149 } while ((!result) && (steps++ < 20));
150
151 if (!result) {
152 std::string expected_cluster_nodes_str;
153 for (size_t i = 0; i < expected_cluster_nodes.size(); ++i) {
154 expected_cluster_nodes_str +=
155 std::to_string(expected_cluster_nodes[i]) + " ";
156 }
157
158 FAIL() << "Unexpected state file content." << std::endl
159 << "expected_group_replication_id: " << expected_group_replication_id
160 << std::endl
161 << "expected_cluster_nodes: " << expected_cluster_nodes_str
162 << std::endl
163 << "expected_view_id: " << expected_view_id << std::endl
164 << "node_address: " << node_address << std::endl
165 << "state_file_content: " << state_file_content;
166 }
167
168 // check that we have write access to the file
169 // just append it with an empty line, that will not break it
170 EXPECT_NO_THROW({
171 std::ofstream ofs(state_file, std::ios::app);
172 ofs << "\n";
173 });
174 }
175
get_int_field_value(const std::string & json_string,const std::string & field_name)176 int get_int_field_value(const std::string &json_string,
177 const std::string &field_name) {
178 rapidjson::Document json_doc;
179 json_doc.Parse(json_string.c_str());
180 if (!json_doc.HasMember(field_name.c_str())) {
181 // that can mean this has not been set yet
182 return 0;
183 }
184
185 if (!json_doc[field_name.c_str()].IsInt()) {
186 // that can mean this has not been set yet
187 return 0;
188 }
189
190 return json_doc[field_name.c_str()].GetInt();
191 }
192
get_transaction_count(const std::string & json_string)193 int get_transaction_count(const std::string &json_string) {
194 return get_int_field_value(json_string, "transaction_count");
195 }
196
wait_for_transaction_count(const uint16_t http_port,const int expected_queries_count,std::chrono::milliseconds timeout)197 bool wait_for_transaction_count(const uint16_t http_port,
198 const int expected_queries_count,
199 std::chrono::milliseconds timeout) {
200 const std::chrono::milliseconds kStep = 20ms;
201 do {
202 std::string server_globals =
203 MockServerRestClient(http_port).get_globals_as_json_string();
204 if (get_transaction_count(server_globals) >= expected_queries_count)
205 return true;
206 std::this_thread::sleep_for(kStep);
207 timeout -= kStep;
208 } while (timeout > 0ms);
209
210 return false;
211 }
212
wait_for_transaction_count_increase(const uint16_t http_port,const int increment_by,std::chrono::milliseconds timeout)213 bool wait_for_transaction_count_increase(const uint16_t http_port,
214 const int increment_by,
215 std::chrono::milliseconds timeout) {
216 std::string server_globals =
217 MockServerRestClient(http_port).get_globals_as_json_string();
218 int expected_queries_count =
219 get_transaction_count(server_globals) + increment_by;
220
221 return wait_for_transaction_count(http_port, expected_queries_count, timeout);
222 }
223
wait_connection_dropped(mysqlrouter::MySQLSession & session,std::chrono::milliseconds timeout)224 bool wait_connection_dropped(mysqlrouter::MySQLSession &session,
225 std::chrono::milliseconds timeout) {
226 const auto kStep = 50ms;
227 do {
228 try {
229 session.query_one("select @@@port");
230 } catch (const mysqlrouter::MySQLSession::Error &) {
231 return true;
232 }
233
234 std::this_thread::sleep_for(kStep);
235 timeout -= kStep;
236 } while (timeout >= 0ms);
237
238 return false;
239 }
240