1 /*
2  * Copyright (c) 2015, 2019, Oracle and/or its affiliates. All rights reserved.
3  *
4  * This program is free software; you can redistribute it and/or modify
5  * it under the terms of the GNU General Public License, version 2.0,
6  * as published by the Free Software Foundation.
7  *
8  * This program is also distributed with certain software (including
9  * but not limited to OpenSSL) that is licensed under separate terms,
10  * as designated in a particular file or component or in included license
11  * documentation.  The authors of MySQL hereby grant you an additional
12  * permission to link the program and your derivative works with the
13  * separately licensed software that they have included with MySQL.
14  *
15  * This program is distributed in the hope that it will be useful,
16  * but WITHOUT ANY WARRANTY; without even the implied warranty of
17  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
18  * GNU General Public License, version 2.0, for more details.
19  *
20  * You should have received a copy of the GNU General Public License
21  * along with this program; if not, write to the Free Software
22  * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301  USA
23  */
24 
25 #include <cstdint>
26 #include <fstream>
27 #include <stdexcept>
28 
29 #include "my_dbug.h"      // NOLINT(build/include_subdir)
30 #include "my_loglevel.h"  // NOLINT(build/include_subdir)
31 #include "my_sys.h"       // NOLINT(build/include_subdir)
32 #include "violite.h"      // NOLINT(build/include_subdir)
33 
34 #include "plugin/x/tests/driver/driver_command_line_options.h"
35 #include "plugin/x/tests/driver/processor/stream_processor.h"
36 
/**
  Message hook installed by main() as local_message_hook.

  Intentionally empty: every message routed through this hook is dropped,
  so none of the arguments are used.
*/
static void ignore_traces_from_libraries(enum loglevel /* ll */,
                                         uint32_t /* ecode */,
                                         va_list /* args */) {}
39 
/**
  Split a classic MySQL connection string into its components.

  Accepted forms (like the ones the command-line utilities use):
    [protocol://][user[:pass]@]host[:port][/db]
    [protocol://][user[:pass]@]host::socket[/db]

  *user and *host are always assigned. *protocol, *password, *port, *sock
  and *db are assigned only when the corresponding part is present, so they
  keep whatever value the caller stored there otherwise. When there is no
  "user@" part, the user name comes from the OS (GetUserNameA on Windows,
  the USER environment variable elsewhere; empty if unavailable).

  @param connstring     connection string to parse
  @param[out] protocol  text before "://"
  @param[out] user      text of the user part up to the first ':'
  @param[out] password  text of the user part after the first ':'
  @param[out] host      text of the server part up to the first ':'
  @param[out] port      number following "host:" (parsed with "%i")
  @param[out] sock      text after the second ':' of the server part
  @param[out] db        text after the first '/'
  @param[out] pwd_found optional; set to 1 when a password part is present,
                        0 otherwise

  @return false when a port part is present but does not start with a
          number (including an empty port, e.g. "host:"), true otherwise
*/
bool parse_mysql_connstring(const std::string &connstring,
                            std::string *protocol, std::string *user,
                            std::string *password, std::string *host, int *port,
                            std::string *sock, std::string *db,
                            int *pwd_found = nullptr) {
  if (pwd_found) *pwd_found = 0;

  std::string remaining = connstring;
  std::string::size_type p = remaining.find("://");
  if (p != std::string::npos) {
    *protocol = remaining.substr(0, p);
    remaining = remaining.substr(p + 3);
  }

  // Everything after the first '/' is the default schema.
  // NOTE(review): this also splits an absolute socket path such as
  // "user@::/tmp/x.sock", and a password containing '/'; confirm whether
  // such inputs need to be supported.
  std::string s = remaining;
  p = remaining.find('/');
  if (p != std::string::npos) {
    *db = remaining.substr(p + 1);
    s = remaining.substr(0, p);
  }

  // Search for the *last* '@' so that a password may itself contain '@'.
  p = s.rfind('@');
  std::string user_part;
  std::string server_part = (p == std::string::npos) ? s : s.substr(p + 1);
  if (p == std::string::npos) {
// by default, connect using the current OS username
#ifdef _WIN32
    char tmp_buffer[1024];
    char *tmp = tmp_buffer;
    DWORD tmp_size = sizeof(tmp_buffer);
    if (!GetUserNameA(tmp_buffer, &tmp_size)) {
      tmp = nullptr;
    }
#else
    const char *tmp = getenv("USER");
#endif
    user_part = tmp ? tmp : "";
  } else {
    user_part = s.substr(0, p);
  }

  if ((p = user_part.find(':')) != std::string::npos) {
    *user = user_part.substr(0, p);
    *password = user_part.substr(p + 1);
    if (pwd_found) *pwd_found = 1;
  } else {
    *user = user_part;
  }

  p = server_part.find(':');
  if (p != std::string::npos) {
    *host = server_part.substr(0, p);
    server_part = server_part.substr(p + 1);
    p = server_part.find(':');
    if (p != std::string::npos) {
      *sock = server_part.substr(p + 1);
    } else {
      // sscanf returns EOF (not 0) for an empty string, so anything other
      // than exactly one successful conversion is a malformed port.
      // NOTE(review): "%i" honours 0x / 0 prefixes, so "03306" is octal.
      if (sscanf(server_part.c_str(), "%i", port) != 1) return false;
    }
  } else {
    *host = server_part;
  }
  return true;
}
101 
parse_mysql_connstring(const std::string & uri,Connection_options * options)102 bool parse_mysql_connstring(const std::string &uri,
103                             Connection_options *options) {
104   int pwdfound;
105   std::string proto;
106   return parse_mysql_connstring(uri, &proto, &options->user, &options->password,
107                                 &options->host, &options->port,
108                                 &options->socket, &options->schema, &pwdfound);
109 }
110 
client_connect_and_process(const Driver_command_line_options & options,std::istream & input)111 int client_connect_and_process(const Driver_command_line_options &options,
112                                std::istream &input) {
113   Variable_container variables(options.m_variables);
114   Console console(options.m_console_options);
115   Connection_manager cm{options.m_connection_options, &variables, console};
116   Execution_context context(options.m_context_options, &cm, &variables,
117                             console);
118 
119   try {
120     context.m_script_stack.push({0, "main"});
121 
122     cm.connect_default(options.m_cap_expired_password,
123                        options.m_client_interactive, options.m_run_without_auth,
124                        options.m_connect_attrs);
125 
126     std::vector<Block_processor_ptr> eaters = create_block_processors(&context);
127     int result_code =
128         process_client_input(input, &eaters, &context.m_script_stack, console);
129 
130     if (!options.m_run_without_auth) cm.close_active(true);
131 
132     return result_code;
133   } catch (const xcl::XError &e) {
134     if (options.is_expected_error_set() &&
135         options.m_expected_error_code == e.error()) {
136       console.print("Application terminated with expected error: ", e.what(),
137                     " (code ", e.error(), ")\n");
138       return 0;
139     }
140     console.print_error_red(context.m_script_stack, e, '\n');
141 
142     return 1;
143   }
144 }
145 
get_input(Driver_command_line_options * opt,std::ifstream & file,std::stringstream & string)146 std::istream &get_input(Driver_command_line_options *opt, std::ifstream &file,
147                         std::stringstream &string) {
148   if (opt->m_has_file) {
149     if (!opt->m_sql.empty()) {
150       std::cerr << "ERROR: specified file and SQL to execute, please enter "
151                    "only one of those\n";
152       opt->exit_code = 1;
153     }
154 
155     file.open(opt->m_run_file.c_str());
156     file.rdbuf()->pubsetbuf(nullptr, 0);
157 
158     if (!file.is_open()) {
159       std::cerr << "ERROR: Could not open file " << opt->m_run_file << "\n";
160       opt->exit_code = 1;
161     }
162 
163     return file;
164   }
165 
166   if (!opt->m_sql.empty()) {
167     std::streampos position = string.tellp();
168 
169     string << "-->sql\n";
170     string << opt->m_sql << "\n";
171     string << "-->endsql\n";
172     string.seekp(position, std::ios::beg);
173 
174     return string;
175   }
176 
177   return std::cin;
178 }
179 
/**
  Report that the process could not be moved to the background and
  terminate with exit status 2.
*/
void unable_daemonize() {
  const char *const failure_message =
      "ERROR: Unable to put process in background\n";
  std::cerr << failure_message;
  exit(2);
}
184 
daemonize()185 static void daemonize() {
186 #ifdef WIN32
187   unable_daemonize();
188 #else
189   if (getppid() == 1)  // already a daemon
190     exit(0);
191   pid_t pid = fork();
192   if (pid < 0) unable_daemonize();
193   if (pid > 0) exit(0);
194   if (setsid() < 0) unable_daemonize();
195 #endif
196 }
197 
main(int argc,char ** argv)198 int main(int argc, char **argv) {
199   MY_INIT(argv[0]);
200   DBUG_TRACE;
201 
202   local_message_hook = ignore_traces_from_libraries;
203 
204   Driver_command_line_options options(argc, argv);
205 
206   if (options.exit_code != 0) return options.exit_code;
207 
208   if (options.m_daemon) daemonize();
209 
210   std::cout << std::unitbuf;
211   std::ifstream fs;
212   std::stringstream ss;
213   std::istream &input = get_input(&options, fs, ss);
214   if (options.m_uri.length()) {
215     parse_mysql_connstring(options.m_uri, &options.m_connection_options);
216   }
217 #ifdef WIN32
218   if (!have_tcpip) {
219     std::cerr << "OS doesn't have tcpip\n";
220     return 1;
221   }
222 #endif
223 
224   ssl_start();
225 
226   bool return_code = false;
227   try {
228     return_code = client_connect_and_process(options, input);
229     const bool is_ok = 0 == return_code;
230 
231     if (is_ok)
232       std::cerr << "ok\n";
233     else
234       std::cerr << "not ok\n";
235   } catch (std::exception &e) {
236     std::cerr << "ERROR: " << e.what() << "\n";
237     return_code = true;
238   }
239 
240   vio_end();
241   my_end(0);
242   return return_code;
243 }
244