1 /*
2    Copyright (c) 2008, 2021, Oracle and/or its affiliates.
3     All rights reserved. Use is subject to license terms.
4 
5    This program is free software; you can redistribute it and/or modify
6    it under the terms of the GNU General Public License, version 2.0,
7    as published by the Free Software Foundation.
8 
9    This program is also distributed with certain software (including
10    but not limited to OpenSSL) that is licensed under separate terms,
11    as designated in a particular file or component or in included license
12    documentation.  The authors of MySQL hereby grant you an additional
13    permission to link the program and your derivative works with the
14    separately licensed software that they have included with MySQL.
15 
16    This program is distributed in the hope that it will be useful,
17    but WITHOUT ANY WARRANTY; without even the implied warranty of
18    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
19    GNU General Public License, version 2.0, for more details.
20 
21    You should have received a copy of the GNU General Public License
22    along with this program; if not, write to the Free Software
23    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301  USA
24 */
25 
#include <SqlClient.hpp>
#include <NDBT_Output.hpp>
#include <NdbSleep.h>

#include <vector>
29 
SqlClient(const char * _user,const char * _password,const char * _group_suffix)30 SqlClient::SqlClient(const char* _user,
31                        const char* _password,
32                        const char* _group_suffix):
33   connected(false),
34   mysql(NULL),
35   free_mysql(false)
36 {
37 
38   const char* env= getenv("MYSQL_HOME");
39   if (env && strlen(env))
40   {
41     default_file.assfmt("%s/my.cnf", env);
42   }
43 
44   if (_group_suffix != NULL){
45     default_group.assfmt("client%s", _group_suffix);
46   }
47   else {
48     default_group.assign("client.1.atrt");
49   }
50 
51   g_info << "default_file: " << default_file.c_str() << endl;
52   g_info << "default_group: " << default_group.c_str() << endl;
53 
54   user.assign(_user);
55   password.assign(_password);
56 }
57 
58 
SqlClient(MYSQL * mysql)59 SqlClient::SqlClient(MYSQL* mysql):
60   connected(true),
61   mysql(mysql),
62   free_mysql(false)
63 {
64 }
65 
66 
~SqlClient()67 SqlClient::~SqlClient(){
68   disconnect();
69 }
70 
71 
72 bool
isConnected()73 SqlClient::isConnected(){
74   if (connected == true)
75   {
76     require(mysql);
77     return true;
78   }
79   return connect() == 0;
80 }
81 
82 
83 int
connect()84 SqlClient::connect(){
85   disconnect();
86 
87 //  mysql_debug("d:t:O,/tmp/client.trace");
88 
89   if ((mysql= mysql_init(NULL)) == NULL){
90     g_err << "mysql_init failed" << endl;
91     return -1;
92   }
93 
94   /* Load connection parameters file and group */
95   if (mysql_options(mysql, MYSQL_READ_DEFAULT_FILE, default_file.c_str()) ||
96       mysql_options(mysql, MYSQL_READ_DEFAULT_GROUP, default_group.c_str()))
97   {
98     g_err << "mysql_options failed" << endl;
99     disconnect();
100     return 1;
101   }
102 
103   /*
104     Connect, read settings from my.cnf
105     NOTE! user and password can be stored there as well
106    */
107   if (mysql_real_connect(mysql, NULL, user.c_str(),
108                          password.c_str(), "atrt", 0, NULL, 0) == NULL)
109   {
110     g_err  << "Connection to atrt server failed: "<< mysql_error(mysql) << endl;
111     disconnect();
112     return -1;
113   }
114 
115   g_err << "Connected to MySQL " << mysql_get_server_info(mysql)<< endl;
116 
117   connected = true;
118   return 0;
119 }
120 
121 
122 bool
waitConnected(int timeout)123 SqlClient::waitConnected(int timeout) {
124   timeout*= 10;
125   while(!isConnected()){
126     if (timeout-- == 0)
127       return false;
128     NdbSleep_MilliSleep(100);
129   }
130   return true;
131 }
132 
133 
134 void
disconnect()135 SqlClient::disconnect(){
136   if (mysql != NULL){
137     if (free_mysql)
138       mysql_close(mysql);
139     mysql= NULL;
140   }
141   connected = false;
142 }
143 
144 
is_int_type(enum_field_types type)145 static bool is_int_type(enum_field_types type){
146   switch(type){
147   case MYSQL_TYPE_TINY:
148   case MYSQL_TYPE_SHORT:
149   case MYSQL_TYPE_LONGLONG:
150   case MYSQL_TYPE_INT24:
151   case MYSQL_TYPE_LONG:
152   case MYSQL_TYPE_ENUM:
153     return true;
154   default:
155     return false;
156   }
157   return false;
158 }
159 
160 
161 bool
runQuery(const char * sql,const Properties & args,SqlResultSet & rows)162 SqlClient::runQuery(const char* sql,
163                     const Properties& args,
164                     SqlResultSet& rows){
165 
166   rows.clear();
167   if (!isConnected())
168     return false;
169 
170   g_debug << "runQuery: " << endl
171           << " sql: '" << sql << "'" << endl;
172 
173 
174   MYSQL_STMT *stmt= mysql_stmt_init(mysql);
175   if (mysql_stmt_prepare(stmt, sql, strlen(sql)))
176   {
177     g_err << "Failed to prepare: " << mysql_error(mysql) << endl;
178     return false;
179   }
180 
181   uint params= mysql_stmt_param_count(stmt);
182   MYSQL_BIND bind_param[params];
183   bzero(bind_param, sizeof(bind_param));
184 
185   for(uint i= 0; i < mysql_stmt_param_count(stmt); i++)
186   {
187     BaseString name;
188     name.assfmt("%d", i);
189     // Parameters are named 0, 1, 2...
190     if (!args.contains(name.c_str()))
191     {
192       g_err << "param " << i << " missing" << endl;
193       require(false);
194     }
195     PropertiesType t;
196     Uint32 val_i;
197     const char* val_s;
198     args.getTypeOf(name.c_str(), &t);
199     switch(t) {
200     case PropertiesType_Uint32:
201       args.get(name.c_str(), &val_i);
202       bind_param[i].buffer_type= MYSQL_TYPE_LONG;
203       bind_param[i].buffer= (char*)&val_i;
204       g_debug << " param" << name.c_str() << ": " << val_i << endl;
205       break;
206     case PropertiesType_char:
207       args.get(name.c_str(), &val_s);
208       bind_param[i].buffer_type= MYSQL_TYPE_STRING;
209       bind_param[i].buffer= (char*)val_s;
210       bind_param[i].buffer_length= strlen(val_s);
211       g_debug << " param" << name.c_str() << ": " << val_s << endl;
212       break;
213     default:
214       require(false);
215       break;
216     }
217   }
218   if (mysql_stmt_bind_param(stmt, bind_param))
219   {
220     g_err << "Failed to bind param: " << mysql_error(mysql) << endl;
221     mysql_stmt_close(stmt);
222     return false;
223   }
224 
225   if (mysql_stmt_execute(stmt))
226   {
227     g_err << "Failed to execute: " << mysql_error(mysql) << endl;
228     mysql_stmt_close(stmt);
229     return false;
230   }
231 
232   /*
233     Update max_length, making it possible to know how big
234     buffers to allocate
235   */
236   my_bool one= 1;
237   mysql_stmt_attr_set(stmt, STMT_ATTR_UPDATE_MAX_LENGTH, (void*) &one);
238 
239   if (mysql_stmt_store_result(stmt))
240   {
241     g_err << "Failed to store result: " << mysql_error(mysql) << endl;
242     mysql_stmt_close(stmt);
243     return false;
244   }
245 
246   uint row= 0;
247   MYSQL_RES* res= mysql_stmt_result_metadata(stmt);
248   if (res != NULL)
249   {
250     MYSQL_FIELD *fields= mysql_fetch_fields(res);
251     uint num_fields= mysql_num_fields(res);
252     MYSQL_BIND bind_result[num_fields];
253     bzero(bind_result, sizeof(bind_result));
254 
255     for (uint i= 0; i < num_fields; i++)
256     {
257       if (is_int_type(fields[i].type)){
258         bind_result[i].buffer_type= MYSQL_TYPE_LONG;
259         bind_result[i].buffer= malloc(sizeof(int));
260       }
261       else
262       {
263         uint max_length= fields[i].max_length + 1;
264         bind_result[i].buffer_type= MYSQL_TYPE_STRING;
265         bind_result[i].buffer= malloc(max_length);
266         bind_result[i].buffer_length= max_length;
267       }
268     }
269 
270     if (mysql_stmt_bind_result(stmt, bind_result)){
271       g_err << "Failed to bind result: " << mysql_error(mysql) << endl;
272       mysql_stmt_close(stmt);
273       return false;
274     }
275 
276     while (mysql_stmt_fetch(stmt) != MYSQL_NO_DATA)
277     {
278       Properties curr(true);
279       for (uint i= 0; i < num_fields; i++){
280         if (is_int_type(fields[i].type))
281           curr.put(fields[i].name, *(int*)bind_result[i].buffer);
282         else
283           curr.put(fields[i].name, (char*)bind_result[i].buffer);
284       }
285       rows.put("row", row++, &curr);
286     }
287 
288     mysql_free_result(res);
289 
290     for (uint i= 0; i < num_fields; i++)
291       free(bind_result[i].buffer);
292 
293   }
294 
295   // Save stats in result set
296   rows.put("rows", row);
297   rows.put("affected_rows", mysql_affected_rows(mysql));
298   rows.put("mysql_errno", mysql_errno(mysql));
299   rows.put("mysql_error", mysql_error(mysql));
300   rows.put("mysql_sqlstate", mysql_sqlstate(mysql));
301   rows.put("insert_id", mysql_insert_id(mysql));
302 
303   mysql_stmt_close(stmt);
304   return true;
305 }
306 
307 
308 bool
doQuery(const char * query)309 SqlClient::doQuery(const char* query){
310   const Properties args;
311   SqlResultSet result;
312   return doQuery(query, args, result);
313 }
314 
315 
316 bool
doQuery(const char * query,SqlResultSet & result)317 SqlClient::doQuery(const char* query, SqlResultSet& result){
318   Properties args;
319   return doQuery(query, args, result);
320 }
321 
322 
323 bool
doQuery(const char * query,const Properties & args,SqlResultSet & result)324 SqlClient::doQuery(const char* query, const Properties& args,
325                    SqlResultSet& result){
326   if (!runQuery(query, args, result))
327     return false;
328   result.get_row(0); // Load first row
329   return true;
330 }
331 
332 
333 bool
doQuery(BaseString & str)334 SqlClient::doQuery(BaseString& str){
335   return doQuery(str.c_str());
336 }
337 
338 
339 bool
doQuery(BaseString & str,SqlResultSet & result)340 SqlClient::doQuery(BaseString& str, SqlResultSet& result){
341   return doQuery(str.c_str(), result);
342 }
343 
344 
345 bool
doQuery(BaseString & str,const Properties & args,SqlResultSet & result)346 SqlClient::doQuery(BaseString& str, const Properties& args,
347                    SqlResultSet& result){
348   return doQuery(str.c_str(), args, result);
349 }
350 
351 
352 
353 
354 bool
get_row(int row_num)355 SqlResultSet::get_row(int row_num){
356   if(!get("row", row_num, &m_curr_row)){
357     return false;
358   }
359   return true;
360 }
361 
362 bool
next(void)363 SqlResultSet::next(void){
364   return get_row(++m_curr_row_num);
365 }
366 
367 // Reset iterator
reset(void)368 void SqlResultSet::reset(void){
369   m_curr_row_num= -1;
370   m_curr_row= 0;
371 }
372 
373 // Remove row from resultset
remove()374 void SqlResultSet::remove(){
375   BaseString row_name;
376   row_name.assfmt("row_%d", m_curr_row_num);
377   Properties::remove(row_name.c_str());
378 }
379 
380 
SqlResultSet()381 SqlResultSet::SqlResultSet(): m_curr_row(0), m_curr_row_num(-1){
382 }
383 
~SqlResultSet()384 SqlResultSet::~SqlResultSet(){
385 }
386 
column(const char * col_name)387 const char* SqlResultSet::column(const char* col_name){
388   const char* value;
389   if (!m_curr_row){
390     g_err << "ERROR: SqlResultSet::column("<< col_name << ")" << endl
391           << "There is no row loaded, call next() before "
392           << "acessing the column values" << endl;
393     require(m_curr_row);
394   }
395   if (!m_curr_row->get(col_name, &value))
396     return NULL;
397   return value;
398 }
399 
columnAsInt(const char * col_name)400 uint SqlResultSet::columnAsInt(const char* col_name){
401   uint value;
402   if (!m_curr_row){
403     g_err << "ERROR: SqlResultSet::columnAsInt("<< col_name << ")" << endl
404           << "There is no row loaded, call next() before "
405           << "acessing the column values" << endl;
406     require(m_curr_row);
407   }
408   if (!m_curr_row->get(col_name, &value))
409     return (uint)-1;
410   return value;
411 }
412 
insertId()413 uint SqlResultSet::insertId(){
414   return get_int("insert_id");
415 }
416 
affectedRows()417 uint SqlResultSet::affectedRows(){
418   return get_int("affected_rows");
419 }
420 
numRows(void)421 uint SqlResultSet::numRows(void){
422   return get_int("rows");
423 }
424 
mysqlErrno(void)425 uint SqlResultSet::mysqlErrno(void){
426   return get_int("mysql_errno");
427 }
428 
429 
mysqlError(void)430 const char* SqlResultSet::mysqlError(void){
431   return get_string("mysql_error");
432 }
433 
mysqlSqlstate(void)434 const char* SqlResultSet::mysqlSqlstate(void){
435   return get_string("mysql_sqlstate");
436 }
437 
get_int(const char * name)438 uint SqlResultSet::get_int(const char* name){
439   uint value;
440   require(get(name, &value));
441   return value;
442 }
443 
get_string(const char * name)444 const char* SqlResultSet::get_string(const char* name){
445   const char* value;
446   require(get(name, &value));
447   return value;
448 }
449