1 /*
2   Copyright (c) 2015, 2020, Oracle and/or its affiliates. All rights reserved.
3 
4   This program is free software; you can redistribute it and/or modify
5   it under the terms of the GNU General Public License, version 2.0,
6   as published by the Free Software Foundation.
7 
8   This program is also distributed with certain software (including
9   but not limited to OpenSSL) that is licensed under separate terms,
10   as designated in a particular file or component or in included license
11   documentation.  The authors of MySQL hereby grant you an additional
12   permission to link the program and your derivative works with the
13   separately licensed software that they have included with MySQL.
14 
15   This program is distributed in the hope that it will be useful,
16   but WITHOUT ANY WARRANTY; without even the implied warranty of
17   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
18   GNU General Public License for more details.
19 
20   You should have received a copy of the GNU General Public License
21   along with this program; if not, write to the Free Software
22   Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
23 */
24 
25 #include "router_test_helpers.h"
26 
27 #include <cassert>
28 #include <cerrno>
29 #include <chrono>
30 #include <cstdlib>
31 #include <cstring>
32 #include <iostream>
33 #include <regex>
34 #include <stdexcept>
35 #include <thread>
36 
37 #ifndef _WIN32
38 #include <sys/socket.h>
39 #include <unistd.h>
40 #else
41 #include <direct.h>
42 #include <windows.h>
43 #include <winsock2.h>
44 #include <ws2tcpip.h>
45 #define getcwd _getcwd
46 #endif
47 
48 #include "keyring/keyring_manager.h"
49 #include "my_inttypes.h"  // ssize_t
50 #include "mysql/harness/filesystem.h"
51 #include "mysqlrouter/mysql_session.h"
52 #include "mysqlrouter/utils.h"
53 
54 using mysql_harness::Path;
55 using namespace std::chrono_literals;
56 
get_cmake_source_dir()57 Path get_cmake_source_dir() {
58   Path result;
59 
60   // PB2 specific source location
61   char *env_pb2workdir = std::getenv("PB2WORKDIR");
62   char *env_sourcename = std::getenv("SOURCENAME");
63   char *env_tmpdir = std::getenv("TMPDIR");
64   if ((env_pb2workdir && env_sourcename && env_tmpdir) &&
65       (strlen(env_pb2workdir) && strlen(env_tmpdir) &&
66        strlen(env_sourcename))) {
67     result = Path(env_tmpdir);
68     result.append(Path(env_sourcename));
69     if (result.exists()) {
70       return result;
71     }
72   }
73 
74   char *env_value = std::getenv("CMAKE_SOURCE_DIR");
75 
76   if (env_value == nullptr) {
77     // try a few places
78     result = Path(get_cwd()).join("..");
79     result = Path(result).real_path();
80   } else {
81     result = Path(env_value).real_path();
82   }
83 
84   if (!result.join("src")
85            .join("router")
86            .join("src")
87            .join("router_app.cc")
88            .is_regular()) {
89     throw std::runtime_error(
90         "Source directory not available. Use CMAKE_SOURCE_DIR environment "
91         "variable; was " +
92         result.str());
93   }
94 
95   return result;
96 }
97 
get_envvar_path(const std::string & envvar,Path alternative=Path ())98 Path get_envvar_path(const std::string &envvar, Path alternative = Path()) {
99   char *env_value = std::getenv(envvar.c_str());
100   Path result;
101   if (env_value == nullptr) {
102     result = alternative;
103   } else {
104     result = Path(env_value).real_path();
105   }
106   return result;
107 }
108 
/**
 * Returns the current working directory.
 *
 * @throws std::runtime_error (with strerror(errno) text) if getcwd() fails,
 *         e.g. when the path does not fit into FILENAME_MAX bytes.
 */
const std::string get_cwd() {
  char cwd_buf[FILENAME_MAX];
  const char *const cwd = getcwd(cwd_buf, sizeof(cwd_buf));
  if (cwd == nullptr) {
    const std::string reason = strerror(errno);
    throw std::runtime_error("getcwd failed: " + reason);
  }
  return cwd;
}
116 
change_cwd(std::string & dir)117 const std::string change_cwd(std::string &dir) {
118   auto cwd = get_cwd();
119 #ifndef _WIN32
120   if (chdir(dir.c_str()) == -1) {
121 #else
122   if (!SetCurrentDirectory(dir.c_str())) {
123 #endif
124     throw std::runtime_error("chdir failed: " + mysqlrouter::get_last_error());
125   }
126   return cwd;
127 }
128 
/**
 * Reads exactly `n_bytes` from `sockfd` into `buffer`, polling every 10ms.
 *
 * @returns the number of bytes read: `n_bytes` on success, or fewer if the
 *          peer closed the connection (EOF) before that many bytes arrived.
 * @throws std::runtime_error if more than `timeout_in_ms` milliseconds pass
 *         before the read completes, or if read()/recv() fails with a
 *         non-transient error. Transient conditions (EAGAIN, EWOULDBLOCK,
 *         EINTR; WSAEWOULDBLOCK on Windows) are retried until the timeout.
 */
size_t read_bytes_with_timeout(int sockfd, void *buffer, size_t n_bytes,
                               uint64_t timeout_in_ms) {
  // Measure elapsed time on the monotonic clock: system_clock may jump when
  // the wall clock is adjusted, shrinking or stretching the timeout.
  // Comparing elapsed time (instead of precomputing now + timeout) also
  // cannot overflow for very large timeout_in_ms values.
  const auto started = std::chrono::steady_clock::now();
  const auto elapsed_ms = [&started]() -> uint64_t {
    return static_cast<uint64_t>(
        std::chrono::duration_cast<std::chrono::milliseconds>(
            std::chrono::steady_clock::now() - started)
            .count());
  };

  char *const dest = static_cast<char *>(buffer);
  size_t bytes_read = 0;
  while (true) {
#ifndef _WIN32
    const ssize_t res = read(sockfd, dest + bytes_read, n_bytes - bytes_read);
#else
    WSASetLastError(0);
    const ssize_t res =
        recv(sockfd, dest + bytes_read, n_bytes - bytes_read, 0);
#endif

    if (res == 0) {  // peer closed the connection (EOF)
      return bytes_read;
    }

    if (res > 0) {
      // Account for the received data before consulting the clock, so a
      // read that completes the request is never reported as a timeout.
      bytes_read += static_cast<size_t>(res);
      if (bytes_read >= n_bytes) {
        assert(bytes_read == n_bytes);
        return bytes_read;
      }
    } else {
#ifndef _WIN32
      const int err = errno;
      bool transient = (err == EAGAIN || err == EINTR);
#if defined(EWOULDBLOCK) && EWOULDBLOCK != EAGAIN
      transient = transient || (err == EWOULDBLOCK);
#endif
      if (!transient) {
        throw std::runtime_error(std::string("read() failed: ") +
                                 strerror(err));
      }
#else
      const int err_code = WSAGetLastError();
      // WSAEWOULDBLOCK is the Windows counterpart of EAGAIN: just retry
      if (err_code != 0 && err_code != WSAEWOULDBLOCK) {
        throw std::runtime_error("recv() failed with error: " +
                                 get_last_error(err_code));
      }
#endif
    }

    if (elapsed_ms() > timeout_in_ms) {
      throw std::runtime_error("read() timed out");
    }

    std::this_thread::sleep_for(std::chrono::milliseconds(10));
  }
}
189 
#ifdef _WIN32
/**
 * Returns the system message text for Windows error code `err_code`.
 *
 * Uses a caller-provided buffer; FORMAT_MESSAGE_ALLOCATE_BUFFER must NOT be
 * passed here, because with it FormatMessage stores a pointer to a
 * LocalAlloc'ed buffer in lpBuffer instead of writing the text into it.
 * Falls back to "unknown error <code>" if no message is available.
 */
std::string get_last_error(int err_code) {
  char message[512];
  const DWORD len = FormatMessageA(
      FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, nullptr,
      static_cast<DWORD>(err_code), LANG_NEUTRAL, message, sizeof(message),
      nullptr);
  if (len == 0) {
    return "unknown error " + std::to_string(err_code);
  }

  std::string result(message, len);
  // system messages may end with a line break; drop it so the text can be
  // embedded in single-line error messages
  while (!result.empty() &&
         (result.back() == '\n' || result.back() == '\r')) {
    result.pop_back();
  }
  return result;
}
#endif
200 
/**
 * Initializes Winsock 2.2 on Windows; a no-op on other platforms.
 * Terminates the process with exit code 1 if WSAStartup() fails.
 */
void init_windows_sockets() {
#ifdef _WIN32
  WSADATA wsa_data;
  if (WSAStartup(MAKEWORD(2, 2), &wsa_data) != 0) {
    std::cerr << "WSAStartup() failed\n";
    exit(1);
  }
#endif
}
211 
/**
 * Returns true if the regular expression `pattern` matches anywhere in `s`.
 * An invalid pattern is reported on std::cerr and yields false.
 */
bool pattern_found(const std::string &s, const std::string &pattern) {
  try {
    const std::regex re(pattern);
    return std::regex_search(s, re);
  } catch (const std::regex_error &err) {
    std::cerr << ">" << err.what();
    return false;
  }
}
224 
namespace {
// Shuts down both directions of `sock` and releases it; returns the result
// of the platform close call (the shutdown result is ignored).
#ifdef _WIN32
int close_socket(SOCKET sock) {
  ::shutdown(sock, SD_BOTH);
  return closesocket(sock);
}
#else
int close_socket(int sock) {
  ::shutdown(sock, SHUT_RDWR);
  return close(sock);
}
#endif
}  // namespace
238 
239 bool wait_for_port_ready(uint16_t port, std::chrono::milliseconds timeout,
240                          const std::string &hostname) {
241   struct addrinfo hints, *ainfo;
242   memset(&hints, 0, sizeof hints);
243   hints.ai_family = AF_UNSPEC;
244   hints.ai_socktype = SOCK_STREAM;
245   hints.ai_flags = AI_PASSIVE;
246 
247   // Valgrind needs way more time
248   if (getenv("WITH_VALGRIND")) {
249     timeout *= 10;
250   }
251 
252   int status = getaddrinfo(hostname.c_str(), std::to_string(port).c_str(),
253                            &hints, &ainfo);
254   if (status != 0) {
255     throw std::runtime_error(
256         std::string("wait_for_port_ready(): getaddrinfo() failed: ") +
257         gai_strerror(status));
258   }
259   std::shared_ptr<void> exit_freeaddrinfo(nullptr,
260                                           [&](void *) { freeaddrinfo(ainfo); });
261 
262   const auto MSEC_STEP = 10ms;
263   const auto started = std::chrono::steady_clock::now();
264   do {
265     auto sock_id =
266         socket(ainfo->ai_family, ainfo->ai_socktype, ainfo->ai_protocol);
267     if (sock_id < 0) {
268       throw std::runtime_error("wait_for_port_ready(): socket() failed: " +
269                                std::to_string(mysqlrouter::get_socket_errno()));
270     }
271     std::shared_ptr<void> exit_close_socket(
272         nullptr, [&](void *) { close_socket(sock_id); });
273 
274 #ifdef _WIN32
275     // On Windows if the port is not ready yet when we try the connect() first
276     // time it will block for 500ms (depends on the OS wide configuration) and
277     // retry again internally. Here we sleep for 100ms but will save this 500ms
278     // for most of the cases which is still a good deal
279     std::this_thread::sleep_for(100ms);
280 #endif
281     status = connect(sock_id, ainfo->ai_addr, ainfo->ai_addrlen);
282     if (status < 0) {
283       // if the address is not available, it is a client side problem.
284 #ifdef _WIN32
285       if (WSAGetLastError() == WSAEADDRNOTAVAIL) {
286         throw std::system_error(mysqlrouter::get_socket_errno(),
287                                 std::system_category());
288       }
289 #else
290       if (errno == EADDRNOTAVAIL) {
291         throw std::system_error(mysqlrouter::get_socket_errno(),
292                                 std::generic_category());
293       }
294 #endif
295       const auto step = std::min(timeout, MSEC_STEP);
296       std::this_thread::sleep_for(std::chrono::milliseconds(step));
297       timeout -= step;
298     }
299   } while (status < 0 && timeout > std::chrono::steady_clock::now() - started);
300 
301   return status >= 0;
302 }
303 
304 void init_keyring(std::map<std::string, std::string> &default_section,
305                   const std::string &keyring_dir,
306                   const std::string &user /*= "mysql_router1_user"*/,
307                   const std::string &password /*= "root"*/) {
308   // init keyring
309   const std::string masterkey_file = Path(keyring_dir).join("master.key").str();
310   const std::string keyring_file = Path(keyring_dir).join("keyring").str();
311   mysql_harness::init_keyring(keyring_file, masterkey_file, true);
312   mysql_harness::Keyring *keyring = mysql_harness::get_keyring();
313   keyring->store(user, "password", password);
314   mysql_harness::flush_keyring();
315   mysql_harness::reset_keyring();
316 
317   // add relevant config settings to [DEFAULT] section
318   default_section["keyring_path"] = keyring_file;
319   default_section["master_key_path"] = masterkey_file;
320 }
321 
322 namespace {
323 
324 bool real_find_in_file(
325     const std::string &file_path,
326     const std::function<bool(const std::string &)> &predicate,
327     std::ifstream &in_file, std::streampos &cur_pos) {
328   if (!in_file.is_open()) {
329     in_file.clear();
330     Path file(file_path);
331     in_file.open(file.c_str(), std::ifstream::in);
332     if (!in_file) {
333       throw std::runtime_error("Error opening file " + file.str());
334     }
335     cur_pos = in_file.tellg();  // initialize properly
336   } else {
337     // set current position to the end of what was already read
338     in_file.clear();
339     in_file.seekg(cur_pos);
340   }
341 
342   std::string line;
343   while (std::getline(in_file, line)) {
344     cur_pos = in_file.tellg();
345     if (predicate(line)) {
346       return true;
347     }
348   }
349 
350   return false;
351 }
352 
353 }  // namespace
354 
355 bool find_in_file(const std::string &file_path,
356                   const std::function<bool(const std::string &)> &predicate,
357                   std::chrono::milliseconds sleep_time) {
358   const auto STEP = std::chrono::milliseconds(100);
359   std::ifstream in_file;
360   std::streampos cur_pos;
361   do {
362     try {
363       // This is proxy function to account for the fact that I/O can sometimes
364       // be slow.
365       if (real_find_in_file(file_path, predicate, in_file, cur_pos))
366         return true;
367     } catch (const std::runtime_error &) {
368       // report I/O error only on the last attempt
369       if (sleep_time == std::chrono::milliseconds(0)) {
370         std::cerr << "  find_in_file() failed, giving up." << std::endl;
371         throw;
372       }
373     }
374 
375     const auto sleep_for = std::min(STEP, sleep_time);
376     std::this_thread::sleep_for(sleep_for);
377     sleep_time -= sleep_for;
378 
379   } while (sleep_time > std::chrono::milliseconds(0));
380 
381   return false;
382 }
383 
384 std::string get_file_output(const std::string &file_name,
385                             const std::string &file_path,
386                             bool throw_on_error /*=false*/) {
387   return get_file_output(file_path + "/" + file_name, throw_on_error);
388 }
389 
390 std::string get_file_output(const std::string &file_name,
391                             bool throw_on_error /*=false*/) {
392   Path file(file_name);
393   std::ifstream in_file;
394   in_file.exceptions(std::ifstream::failbit | std::ifstream::badbit);
395   try {
396     in_file.open(file.c_str(), std::ifstream::in);
397   } catch (const std::exception &e) {
398     const std::string msg =
399         "Could not open file '" + file.str() + "' for reading: ";
400     if (throw_on_error)
401       throw std::runtime_error(msg + e.what());
402     else
403       return "<THIS ERROR COMES FROM TEST FRAMEWORK'S get_file_output(), IT IS "
404              "NOT PART OF PROCESS OUTPUT: " +
405              msg + e.what() + ">";
406   }
407   assert(in_file);
408 
409   std::string result;
410   try {
411     result.assign((std::istreambuf_iterator<char>(in_file)),
412                   std::istreambuf_iterator<char>());
413   } catch (const std::exception &e) {
414     const std::string msg = "Reading file '" + file.str() + "' failed: ";
415     if (throw_on_error)
416       throw std::runtime_error(msg + e.what());
417     else
418       return "<THIS ERROR COMES FROM TEST FRAMEWORK'S get_file_output(), IT IS "
419              "NOT PART OF PROCESS OUTPUT: " +
420              msg + e.what() + ">";
421   }
422 
423   return result;
424 }
425 
426 void connect_client_and_query_port(unsigned router_port, std::string &out_port,
427                                    bool should_fail) {
428   using mysqlrouter::MySQLSession;
429   MySQLSession client;
430 
431   if (should_fail) {
432     try {
433       client.connect("127.0.0.1", router_port, "username", "password", "", "");
434     } catch (const std::exception &exc) {
435       if (std::string(exc.what()).find("Error connecting to MySQL server") !=
436           std::string::npos) {
437         out_port = "";
438         return;
439       } else
440         throw;
441     }
442     throw std::runtime_error(
443         "connect_client_and_query_port: did not fail as expected");
444 
445   } else {
446     client.connect("127.0.0.1", router_port, "username", "password", "", "");
447   }
448 
449   std::unique_ptr<MySQLSession::ResultRow> result{
450       client.query_one("select @@port")};
451   if (nullptr == result.get()) {
452     throw std::runtime_error(
453         "connect_client_and_query_port: error querying the port");
454   }
455   if (1u != result->size()) {
456     throw std::runtime_error(
457         "connect_client_and_query_port: wrong number of columns returned " +
458         std::to_string(result->size()));
459   }
460   out_port = std::string((*result)[0]);
461 }
462