1 /* Copyright (c) 2018, 2020, Oracle and/or its affiliates. All rights reserved.
2 
3    This program is free software; you can redistribute it and/or modify
4    it under the terms of the GNU General Public License, version 2.0,
5    as published by the Free Software Foundation.
6 
7    This program is also distributed with certain software (including
8    but not limited to OpenSSL) that is licensed under separate terms,
9    as designated in a particular file or component or in included license
10    documentation.  The authors of MySQL hereby grant you an additional
11    permission to link the program and your derivative works with the
12    separately licensed software that they have included with MySQL.
13 
14    This program is distributed in the hope that it will be useful,
15    but WITHOUT ANY WARRANTY; without even the implied warranty of
16    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
17    GNU General Public License, version 2.0, for more details.
18 
19    You should have received a copy of the GNU General Public License
20    along with this program; if not, write to the Free Software
21    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301  USA */
22 
/* Tag for messages this plugin sends to the error log. It is defined ahead of
   the includes below, presumably so the logging headers pick it up. */
#define LOG_COMPONENT_TAG "test_sql_reset_connection"
24 
25 #include <fcntl.h>
26 #include <mysql/plugin.h>
27 #include <stdlib.h>
28 #include <sys/types.h>
29 #include <memory>
30 
31 #include <mysql/components/my_service.h>
32 #include <mysql/components/services/log_builtins.h>
33 #include <mysql/components/services/udf_registration.h>
34 #include <mysql/service_srv_session_info.h>
35 #include <mysqld_error.h>
36 
37 #include "m_string.h"
38 #include "my_dbug.h"
39 #include "my_inttypes.h"
40 #include "my_io.h"
41 #include "my_sys.h"  // my_write, my_malloc
42 #include "mysql_com.h"
43 #include "template_utils.h"
44 
/* Size of the scratch buffer used by the WRITE_* helpers. snprintf truncates
   longer output. */
#define STRING_BUFFER 256

/* Descriptor of the plugin's own log file. create_log_file() opens it and the
   plugin deinit function closes it. */
static File outfile;
48 
WRITE_STR(const char * format)49 static void WRITE_STR(const char *format) {
50   char buffer[STRING_BUFFER];
51   snprintf(buffer, sizeof(buffer), "%s", format);
52   my_write(outfile, (uchar *)buffer, strlen(buffer), MYF(0));
53 }
54 
55 template <typename T>
WRITE_VAL(const char * format,T value)56 void WRITE_VAL(const char *format, T value) {
57   char buffer[STRING_BUFFER];
58   snprintf(buffer, sizeof(buffer), format, value);
59   my_write(outfile, (uchar *)buffer, strlen(buffer), MYF(0));
60 }
61 
62 template <typename T1, typename T2>
WRITE_VAL2(const char * format,T1 value1,T2 value2)63 void WRITE_VAL2(const char *format, T1 value1, T2 value2) {
64   char buffer[STRING_BUFFER];
65   snprintf(buffer, sizeof(buffer), format, value1, value2);
66   my_write(outfile, (uchar *)buffer, strlen(buffer), MYF(0));
67 }
68 
/* Horizontal rule written between test sections in the log file. */
static const char *sep =
    "=======================================================================\n";

/* Write the separator line to the plugin log file. */
#define WRITE_SEP() \
  my_write(outfile, pointer_cast<const uchar *>(sep), strlen(sep), MYF(0))
74 
/* Service handles. init_logging_service_for_plugin() fills them in and
   deinit_logging_service_for_plugin() releases them. Presumably the
   LogPluginErr macro reads log_bi / log_bs; confirm in log_builtins.h. */
static SERVICE_TYPE(registry) *reg_srv = nullptr;
SERVICE_TYPE(log_builtins) *log_bi = nullptr;
SERVICE_TYPE(log_builtins_string) *log_bs = nullptr;
78 
/* Local copy of one column's metadata, filled in by sql_field_metadata().
   Each name is held in a fixed 256-byte buffer. */
struct st_send_field_n {
  char db_name[256];
  char table_name[256];
  char org_table_name[256];
  char col_name[256];
  char org_col_name[256];
  unsigned long length;
  unsigned int charsetnr;
  unsigned int flags;
  unsigned int decimals;
  enum_field_types type;
};
91 
/* Fixed-storage mirror of decimal_t's fields.
   NOTE(review): not referenced anywhere in the visible part of this file;
   confirm it is unused before removing. */
struct st_decimal_n {
  int intg, frac, len;
  bool sign;
  decimal_digit_t buf[256];
};
97 
98 struct st_plugin_ctx {
99   const CHARSET_INFO *resultcs;
100   uint meta_server_status;
101   uint meta_warn_count;
102   uint current_col;
103   uint num_cols;
104   uint num_rows;
105   st_send_field_n sql_field[8];
106   char sql_str_value[8][8][256];
107   size_t sql_str_len[8][8];
108 
109   uint server_status;
110   uint warn_count;
111   uint affected_rows;
112   uint last_insert_id;
113   char message[1024];
114 
115   uint sql_errno;
116   char err_msg[1024];
117   char sqlstate[6];
st_plugin_ctxst_plugin_ctx118   st_plugin_ctx() { reset(); }
119 
resetst_plugin_ctx120   void reset() {
121     resultcs = nullptr;
122     server_status = 0;
123     current_col = 0;
124     warn_count = 0;
125     num_cols = 0;
126     num_rows = 0;
127     memset(&sql_field, 0, 8 * sizeof(st_send_field_n));
128     memset(&sql_str_value, 0, 8 * 8 * 256 * sizeof(char));
129     memset(&sql_str_len, 0, 8 * 8 * sizeof(size_t));
130 
131     server_status = 0;
132     warn_count = 0;
133     affected_rows = 0;
134     last_insert_id = 0;
135     memset(&message, 0, sizeof(message));
136 
137     sql_errno = 0;
138     memset(&err_msg, 0, sizeof(err_msg));
139     memset(&sqlstate, 0, sizeof(sqlstate));
140   }
141 };
142 
sql_start_result_metadata(void * ctx,uint num_cols,uint,const CHARSET_INFO * resultcs)143 static int sql_start_result_metadata(void *ctx, uint num_cols, uint,
144                                      const CHARSET_INFO *resultcs) {
145   auto pctx = (struct st_plugin_ctx *)ctx;
146   DBUG_TRACE;
147   DBUG_PRINT("info", ("resultcs->number: %d", resultcs->number));
148   DBUG_PRINT("info", ("resultcs->csname: %s", resultcs->csname));
149   DBUG_PRINT("info", ("resultcs->name: %s", resultcs->name));
150   pctx->num_cols = num_cols;
151   pctx->resultcs = resultcs;
152   pctx->current_col = 0;
153   return false;
154 }
155 
sql_field_metadata(void * ctx,struct st_send_field * field,const CHARSET_INFO *)156 static int sql_field_metadata(void *ctx, struct st_send_field *field,
157                               const CHARSET_INFO *) {
158   auto pctx = (struct st_plugin_ctx *)ctx;
159   st_send_field_n *cfield = &pctx->sql_field[pctx->current_col];
160   DBUG_TRACE;
161   DBUG_PRINT("info", ("field->db_name: %s", field->db_name));
162   DBUG_PRINT("info", ("field->table_name: %s", field->table_name));
163   DBUG_PRINT("info", ("field->org_table_name: %s", field->org_table_name));
164   DBUG_PRINT("info", ("field->col_name: %s", field->col_name));
165   DBUG_PRINT("info", ("field->org_col_name: %s", field->org_col_name));
166   DBUG_PRINT("info", ("field->length: %d", (int)field->length));
167   DBUG_PRINT("info", ("field->charsetnr: %d", (int)field->charsetnr));
168   DBUG_PRINT("info", ("field->flags: %d", (int)field->flags));
169   DBUG_PRINT("info", ("field->decimals: %d", (int)field->decimals));
170   DBUG_PRINT("info", ("field->type: %d", (int)field->type));
171 
172   strcpy(cfield->db_name, field->db_name);
173   strcpy(cfield->table_name, field->table_name);
174   strcpy(cfield->org_table_name, field->org_table_name);
175   strcpy(cfield->col_name, field->col_name);
176   strcpy(cfield->org_col_name, field->org_col_name);
177   cfield->length = field->length;
178   cfield->charsetnr = field->charsetnr;
179   cfield->flags = field->flags;
180   cfield->decimals = field->decimals;
181   cfield->type = field->type;
182 
183   pctx->current_col++;
184   return false;
185 }
186 
sql_end_result_metadata(void * ctx,uint server_status,uint warn_count)187 static int sql_end_result_metadata(void *ctx, uint server_status,
188                                    uint warn_count) {
189   struct st_plugin_ctx *pctx = (struct st_plugin_ctx *)ctx;
190   DBUG_TRACE;
191   pctx->meta_server_status = server_status;
192   pctx->meta_warn_count = warn_count;
193   pctx->num_rows = 0;
194   return false;
195 }
196 
sql_start_row(void * ctx)197 static int sql_start_row(void *ctx) {
198   struct st_plugin_ctx *pctx = (struct st_plugin_ctx *)ctx;
199   DBUG_TRACE;
200   pctx->current_col = 0;
201   return false;
202 }
203 
sql_end_row(void * ctx)204 static int sql_end_row(void *ctx) {
205   struct st_plugin_ctx *pctx = (struct st_plugin_ctx *)ctx;
206   DBUG_TRACE;
207   pctx->num_rows++;
208   return false;
209 }
210 
sql_abort_row(void * ctx)211 static void sql_abort_row(void *ctx) {
212   struct st_plugin_ctx *pctx = (struct st_plugin_ctx *)ctx;
213   DBUG_TRACE;
214   pctx->current_col = 0;
215 }
216 
sql_get_client_capabilities(void *)217 static ulong sql_get_client_capabilities(void *) {
218   DBUG_TRACE;
219   return 0;
220 }
221 
sql_get_null(void * ctx)222 static int sql_get_null(void *ctx) {
223   struct st_plugin_ctx *pctx = (struct st_plugin_ctx *)ctx;
224   DBUG_TRACE;
225   uint row = pctx->num_rows;
226   uint col = pctx->current_col;
227   pctx->current_col++;
228 
229   memcpy(pctx->sql_str_value[row][col], "[NULL]", sizeof("[NULL]"));
230   pctx->sql_str_len[row][col] = sizeof("[NULL]") - 1;
231 
232   return false;
233 }
234 
sql_get_integer(void * ctx,longlong value)235 static int sql_get_integer(void *ctx, longlong value) {
236   struct st_plugin_ctx *pctx = (struct st_plugin_ctx *)ctx;
237   DBUG_TRACE;
238   uint row = pctx->num_rows;
239   uint col = pctx->current_col;
240   pctx->current_col++;
241 
242   size_t len = snprintf(pctx->sql_str_value[row][col],
243                         sizeof(pctx->sql_str_value[row][col]), "%lld", value);
244   pctx->sql_str_len[row][col] = len;
245 
246   return false;
247 }
248 
sql_get_longlong(void * ctx,longlong value,uint is_unsigned)249 static int sql_get_longlong(void *ctx, longlong value, uint is_unsigned) {
250   struct st_plugin_ctx *pctx = (struct st_plugin_ctx *)ctx;
251   DBUG_TRACE;
252   uint row = pctx->num_rows;
253   uint col = pctx->current_col;
254   pctx->current_col++;
255 
256   size_t len = snprintf(pctx->sql_str_value[row][col],
257                         sizeof(pctx->sql_str_value[row][col]),
258                         is_unsigned ? "%llu" : "%lld", value);
259 
260   pctx->sql_str_len[row][col] = len;
261 
262   return false;
263 }
264 
sql_get_decimal(void * ctx,const decimal_t * value)265 static int sql_get_decimal(void *ctx, const decimal_t *value) {
266   struct st_plugin_ctx *pctx = (struct st_plugin_ctx *)ctx;
267   DBUG_TRACE;
268   uint row = pctx->num_rows;
269   uint col = pctx->current_col;
270   pctx->current_col++;
271 
272   size_t len = snprintf(pctx->sql_str_value[row][col],
273                         sizeof(pctx->sql_str_value[row][col]),
274                         "%s%d.%d(%d)[%s]", value->sign ? "+" : "-", value->intg,
275                         value->frac, value->len, (char *)value->buf);
276   pctx->sql_str_len[row][col] = len;
277 
278   return false;
279 }
280 
sql_get_double(void * ctx,double value,uint32)281 static int sql_get_double(void *ctx, double value, uint32 /*decimals*/) {
282   struct st_plugin_ctx *pctx = (struct st_plugin_ctx *)ctx;
283   DBUG_TRACE;
284   uint row = pctx->num_rows;
285   uint col = pctx->current_col;
286   pctx->current_col++;
287 
288   size_t len = snprintf(pctx->sql_str_value[row][col],
289                         sizeof(pctx->sql_str_value[row][col]), "%3.7g", value);
290 
291   pctx->sql_str_len[row][col] = len;
292 
293   return false;
294 }
295 
sql_get_date(void * ctx,const MYSQL_TIME * value)296 static int sql_get_date(void *ctx, const MYSQL_TIME *value) {
297   struct st_plugin_ctx *pctx = (struct st_plugin_ctx *)ctx;
298   DBUG_TRACE;
299   uint row = pctx->num_rows;
300   uint col = pctx->current_col;
301   pctx->current_col++;
302 
303   size_t len =
304       snprintf(pctx->sql_str_value[row][col],
305                sizeof(pctx->sql_str_value[row][col]), "%s%4d-%02d-%02d",
306                value->neg ? "-" : "", value->year, value->month, value->day);
307   pctx->sql_str_len[row][col] = len;
308 
309   return false;
310 }
311 
sql_get_time(void * ctx,const MYSQL_TIME * value,uint)312 static int sql_get_time(void *ctx, const MYSQL_TIME *value, uint /*decimals*/) {
313   struct st_plugin_ctx *pctx = (struct st_plugin_ctx *)ctx;
314   DBUG_TRACE;
315   uint row = pctx->num_rows;
316   uint col = pctx->current_col;
317   pctx->current_col++;
318 
319   size_t len = snprintf(
320       pctx->sql_str_value[row][col], sizeof(pctx->sql_str_value[row][col]),
321       "%s%02d:%02d:%02d", value->neg ? "-" : "",
322       value->day ? (value->day * 24 + value->hour) : value->hour, value->minute,
323       value->second);
324 
325   pctx->sql_str_len[row][col] = len;
326 
327   return false;
328 }
329 
sql_get_datetime(void * ctx,const MYSQL_TIME * value,uint)330 static int sql_get_datetime(void *ctx, const MYSQL_TIME *value,
331                             uint /*decimals*/) {
332   struct st_plugin_ctx *pctx = (struct st_plugin_ctx *)ctx;
333   DBUG_TRACE;
334   uint row = pctx->num_rows;
335   uint col = pctx->current_col;
336   pctx->current_col++;
337 
338   size_t len = snprintf(
339       pctx->sql_str_value[row][col], sizeof(pctx->sql_str_value[row][col]),
340       "%s%4d-%02d-%02d %02d:%02d:%02d", value->neg ? "-" : "", value->year,
341       value->month, value->day, value->hour, value->minute, value->second);
342 
343   pctx->sql_str_len[row][col] = len;
344 
345   return false;
346 }
347 
sql_get_string(void * ctx,const char * const value,size_t length,const CHARSET_INFO * const)348 static int sql_get_string(void *ctx, const char *const value, size_t length,
349                           const CHARSET_INFO *const) {
350   struct st_plugin_ctx *pctx = (struct st_plugin_ctx *)ctx;
351   DBUG_TRACE;
352   uint row = pctx->num_rows;
353   uint col = pctx->current_col;
354   pctx->current_col++;
355 
356   strncpy(pctx->sql_str_value[row][col], value, length);
357   pctx->sql_str_len[row][col] = length;
358 
359   return false;
360 }
361 
sql_handle_ok(void * ctx,uint server_status,uint statement_warn_count,ulonglong affected_rows,ulonglong last_insert_id,const char * const message)362 static void sql_handle_ok(void *ctx, uint server_status,
363                           uint statement_warn_count, ulonglong affected_rows,
364                           ulonglong last_insert_id, const char *const message) {
365   struct st_plugin_ctx *pctx = (struct st_plugin_ctx *)ctx;
366   DBUG_TRACE;
367   /* This could be an EOF */
368   if (!pctx->num_cols) pctx->num_rows = 0;
369   pctx->server_status = server_status;
370   pctx->warn_count = statement_warn_count;
371   pctx->affected_rows = affected_rows;
372   pctx->last_insert_id = last_insert_id;
373   if (message) strncpy(pctx->message, message, sizeof(pctx->message) - 1);
374   pctx->message[sizeof(pctx->message) - 1] = '\0';
375 }
376 
sql_handle_error(void * ctx,uint sql_errno,const char * const err_msg,const char * const sqlstate)377 static void sql_handle_error(void *ctx, uint sql_errno,
378                              const char *const err_msg,
379                              const char *const sqlstate) {
380   struct st_plugin_ctx *pctx = (struct st_plugin_ctx *)ctx;
381   DBUG_TRACE;
382   pctx->sql_errno = sql_errno;
383   if (pctx->sql_errno) {
384     strcpy(pctx->err_msg, err_msg);
385     strcpy(pctx->sqlstate, sqlstate);
386   }
387   pctx->num_rows = 0;
388 }
389 
sql_shutdown(void *,int)390 static void sql_shutdown(void *, int) { DBUG_TRACE; }
391 
/* Callback table passed to command_service_run_command(). The initializer is
   positional, so the entries must stay in the member order of
   st_command_service_cbs. */
const struct st_command_service_cbs sql_cbs = {
    sql_start_result_metadata,
    sql_field_metadata,
    sql_end_result_metadata,
    sql_start_row,
    sql_end_row,
    sql_abort_row,
    sql_get_client_capabilities,
    sql_get_null,
    sql_get_integer,
    sql_get_longlong,
    sql_get_decimal,
    sql_get_double,
    sql_get_date,
    sql_get_time,
    sql_get_datetime,
    sql_get_string,
    sql_handle_ok,
    sql_handle_error,
    sql_shutdown,
};
413 
fieldtype2str(enum enum_field_types type)414 static const char *fieldtype2str(enum enum_field_types type) {
415   switch (type) {
416     case MYSQL_TYPE_BIT:
417       return "BIT";
418     case MYSQL_TYPE_BLOB:
419       return "BLOB";
420     case MYSQL_TYPE_DATE:
421       return "DATE";
422     case MYSQL_TYPE_DATETIME:
423       return "DATETIME";
424     case MYSQL_TYPE_NEWDECIMAL:
425       return "NEWDECIMAL";
426     case MYSQL_TYPE_DECIMAL:
427       return "DECIMAL";
428     case MYSQL_TYPE_DOUBLE:
429       return "DOUBLE";
430     case MYSQL_TYPE_ENUM:
431       return "ENUM";
432     case MYSQL_TYPE_FLOAT:
433       return "FLOAT";
434     case MYSQL_TYPE_GEOMETRY:
435       return "GEOMETRY";
436     case MYSQL_TYPE_INT24:
437       return "INT24";
438     case MYSQL_TYPE_LONG:
439       return "LONG";
440     case MYSQL_TYPE_LONGLONG:
441       return "LONGLONG";
442     case MYSQL_TYPE_LONG_BLOB:
443       return "LONG_BLOB";
444     case MYSQL_TYPE_MEDIUM_BLOB:
445       return "MEDIUM_BLOB";
446     case MYSQL_TYPE_NEWDATE:
447       return "NEWDATE";
448     case MYSQL_TYPE_NULL:
449       return "NULL";
450     case MYSQL_TYPE_SET:
451       return "SET";
452     case MYSQL_TYPE_SHORT:
453       return "SHORT";
454     case MYSQL_TYPE_STRING:
455       return "STRING";
456     case MYSQL_TYPE_TIME:
457       return "TIME";
458     case MYSQL_TYPE_TIMESTAMP:
459       return "TIMESTAMP";
460     case MYSQL_TYPE_TINY:
461       return "TINY";
462     case MYSQL_TYPE_TINY_BLOB:
463       return "TINY_BLOB";
464     case MYSQL_TYPE_VARCHAR:
465       return "VARCHAR";
466     case MYSQL_TYPE_VAR_STRING:
467       return "VAR_STRING";
468     case MYSQL_TYPE_YEAR:
469       return "YEAR";
470     default:
471       return "?-unknown-?";
472   }
473 }
474 
get_data_str(struct st_plugin_ctx * pctx)475 static void get_data_str(struct st_plugin_ctx *pctx) {
476   WRITE_STR(
477       "-----------------------------------------------------------------\n");
478   for (uint col = 0; col < pctx->num_cols; col++) {
479     WRITE_VAL("%s ", pctx->sql_field[col].col_name);
480     WRITE_VAL2("%s(%u)\t", fieldtype2str(pctx->sql_field[col].type),
481                pctx->sql_field[col].type);
482   }
483   WRITE_STR("\n");
484 
485   for (uint row = 0; row < pctx->num_rows; row++) {
486     for (uint col = 0; col < pctx->num_cols; col++) {
487       WRITE_VAL2("%s%s", pctx->sql_str_value[row][col],
488                  col < pctx->num_cols - 1 ? "\t\t\t" : "\n");
489     }
490   }
491 }
492 
query_execute(MYSQL_SESSION session,st_plugin_ctx * pctx,const std::string & query)493 static void query_execute(MYSQL_SESSION session, st_plugin_ctx *pctx,
494                           const std::string &query) {
495   WRITE_VAL("%s\n", query.c_str());
496   pctx->reset();
497 
498   COM_DATA cmd;
499   cmd.com_query.query = query.c_str();
500   cmd.com_query.length = query.size();
501   if (command_service_run_command(session, COM_QUERY, &cmd,
502                                   &my_charset_utf8_general_ci, &sql_cbs,
503                                   CS_TEXT_REPRESENTATION, pctx)) {
504     LogPluginErr(ERROR_LEVEL, ER_LOG_PRINTF_MSG, "fail query execution - %d:%s",
505                  pctx->sql_errno, pctx->err_msg);
506     return;
507   }
508   if (pctx->num_cols) get_data_str(pctx);
509 }
510 
/* Arguments handed to test_session_thread(): the plugin pointer passed on
   from the caller, and the test procedure to run with it on the new thread. */
struct Thread_data {
  void *p;
  void (*proc)(void *p);
};
515 
test_session_thread(void * ctxt)516 static void *test_session_thread(void *ctxt) {
517   auto thread_data = (Thread_data *)ctxt;
518 
519   if (srv_session_init_thread(thread_data->p))
520     LogPluginErr(ERROR_LEVEL, ER_LOG_PRINTF_MSG,
521                  "srv_session_init_thread failed.");
522 
523   thread_data->proc(thread_data->p);
524 
525   srv_session_deinit_thread();
526 
527   return nullptr;
528 }
529 
test_execute_in_thread(void * p,void (* proc)(void * p))530 void test_execute_in_thread(void *p, void (*proc)(void *p)) {
531   Thread_data thread_data{p, proc};
532 
533   my_thread_handle thread_handle;
534   my_thread_attr_t attr;
535   my_thread_attr_init(&attr);
536   (void)my_thread_attr_setdetachstate(&attr, MY_THREAD_CREATE_JOINABLE);
537 
538   if (my_thread_create(&thread_handle, &attr,
539                        (void *(*)(void *))test_session_thread,
540                        &thread_data) != 0) {
541     WRITE_STR("Could not create test services thread!\n");
542     exit(1);
543   }
544   void *ret;
545   my_thread_join(&thread_handle, &ret);
546 }
547 
ensure_api_ok(const char * function,int result)548 static void ensure_api_ok(const char *function, int result) {
549   if (result != 0) {
550     WRITE_VAL2("ERROR calling %s: returned %i\n", function, result);
551   }
552 }
553 
ensure_api_not_null(const char * function,void * result)554 static void ensure_api_not_null(const char *function, void *result) {
555   if (!result) {
556     WRITE_VAL("ERROR calling %s: returned NULL\n", function);
557   }
558 }
559 
/* Report, without aborting, a failing API call, tagged with the name of the
   calling function. The caller supplies the terminating semicolon, so the
   macros behave like ordinary function calls (for example, inside if/else). */
#define ENSURE_API_OK(call) ensure_api_ok(__FUNCTION__, (call))
#define ENSURE_API_NOT_NULL(call) ensure_api_not_null(__FUNCTION__, (call))
562 
reset_connection(MYSQL_SESSION st_session,st_plugin_ctx * pctx)563 static void reset_connection(MYSQL_SESSION st_session, st_plugin_ctx *pctx) {
564   COM_DATA cmd;
565   ENSURE_API_OK(command_service_run_command(
566       st_session, COM_RESET_CONNECTION, &cmd, &my_charset_utf8_general_ci,
567       &sql_cbs, CS_TEXT_REPRESENTATION, pctx));
568 }
569 
session_error_cb(void *,unsigned int sql_errno,const char * err_msg)570 static void session_error_cb(void *, unsigned int sql_errno,
571                              const char *err_msg) {
572   WRITE_STR("default error handler called\n");
573   WRITE_VAL("sql_errno = %i\n", sql_errno);
574   WRITE_VAL("errmsg = %s\n", err_msg);
575 }
576 
test_com_reset_connection(void * p)577 static void test_com_reset_connection(void *p) {
578   DBUG_TRACE;
579 
580   WRITE_STR("COM_RESET_CONNECTION\n");
581 
582   MYSQL_SESSION st_session;
583   ENSURE_API_NOT_NULL(st_session = srv_session_open(session_error_cb, p));
584 
585   my_thread_id session_id = srv_session_info_get_session_id(st_session);
586 
587   std::unique_ptr<st_plugin_ctx> ctx(new st_plugin_ctx());
588   query_execute(st_session, ctx.get(), "set @secret = 123");
589   query_execute(st_session, ctx.get(), "select @secret");
590   reset_connection(st_session, ctx.get());
591   query_execute(st_session, ctx.get(), "select @secret");
592 
593   WRITE_VAL("Has session ID changed: %i\n",
594             srv_session_info_get_session_id(st_session) != session_id);
595 
596   ENSURE_API_OK(srv_session_close(st_session));
597 }
598 
test_com_reset_connection_from_another_session(void * p)599 static void test_com_reset_connection_from_another_session(void *p) {
600   DBUG_TRACE;
601 
602   WRITE_STR("COM_RESET_CONNECTION from another session\n");
603 
604   MYSQL_SESSION st_session;
605   ENSURE_API_NOT_NULL(st_session = srv_session_open(NULL, p));
606 
607   my_thread_id session_id = srv_session_info_get_session_id(st_session);
608 
609   std::unique_ptr<st_plugin_ctx> ctx(new st_plugin_ctx());
610   query_execute(st_session, ctx.get(), "set @another_secret = 456");
611   query_execute(st_session, ctx.get(), "select @another_secret");
612   WRITE_STR(">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n");
613   query_execute(st_session, ctx.get(), "do reset_connection()");
614   WRITE_STR("<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<\n");
615   query_execute(st_session, ctx.get(), "select @another_secret");
616 
617   WRITE_VAL("Has session ID changed: %i\n",
618             srv_session_info_get_session_id(st_session) != session_id);
619 
620   ENSURE_API_OK(srv_session_close(st_session));
621 }
622 
test_sql(void * p)623 static void test_sql(void *p) {
624   DBUG_TRACE;
625 
626   WRITE_SEP();
627   test_execute_in_thread(p, test_com_reset_connection);
628   WRITE_SEP();
629   test_execute_in_thread(p, test_com_reset_connection_from_another_session);
630   WRITE_SEP();
631 }
632 
create_log_file(const char * log_name)633 static void create_log_file(const char *log_name) {
634   char filename[FN_REFLEN];
635 
636   fn_format(filename, log_name, "", ".log",
637             MY_REPLACE_EXT | MY_UNPACK_FILENAME);
638   unlink(filename);
639   outfile = my_open(filename, O_CREAT | O_RDWR, MYF(0));
640 }
641 
/* Base name of the plugin log file; create_log_file() appends ".log". */
static const char *log_filename = "test_sql_reset_connection";
643 
644 namespace {
645 void *plg = nullptr;
646 
647 using Udf_registrator = my_service<SERVICE_TYPE(udf_registration)>;
648 
reset_connection_init(UDF_INIT *,UDF_ARGS * args,char *)649 bool reset_connection_init(UDF_INIT *, UDF_ARGS *args, char *) {
650   return args->arg_count != 0;
651 }
652 
reset_connection_exe(UDF_INIT *,UDF_ARGS *,unsigned char *,unsigned char *)653 long long reset_connection_exe(UDF_INIT *, UDF_ARGS *, unsigned char *,
654                                unsigned char *) {
655   DBUG_TRACE;
656   test_execute_in_thread(plg, test_com_reset_connection);
657   return 0;
658 }
659 
register_udf_reset_connection()660 void register_udf_reset_connection() {
661   DBUG_TRACE;
662   auto reg = mysql_plugin_registry_acquire();
663   {
664     Udf_registrator udf_reg{"udf_registration", reg};
665     if (udf_reg.is_valid()) {
666       udf_reg->udf_register(
667           "reset_connection", INT_RESULT,
668           reinterpret_cast<Udf_func_any>(reset_connection_exe),
669           reset_connection_init, nullptr);
670     } else {
671       LogPluginErr(ERROR_LEVEL, ER_LOG_PRINTF_MSG, "fail udf registartion");
672     }
673   }
674   mysql_plugin_registry_release(reg);
675 }
676 
unregister_udf_reset_connection()677 void unregister_udf_reset_connection() {
678   DBUG_TRACE;
679   auto reg = mysql_plugin_registry_acquire();
680   {
681     Udf_registrator udf_reg{"udf_registration", reg};
682     if (udf_reg.is_valid()) {
683       int was_present = 0;
684       udf_reg->udf_unregister("reset_connection", &was_present);
685     }
686   }
687   mysql_plugin_registry_release(reg);
688 }
689 }  // namespace
690 
test_sql_service_plugin_init(void * p)691 static int test_sql_service_plugin_init(void *p) {
692   create_log_file(log_filename);
693   DBUG_TRACE;
694   if (init_logging_service_for_plugin(&reg_srv, &log_bi, &log_bs)) return 1;
695   LogPluginErr(INFORMATION_LEVEL, ER_LOG_PRINTF_MSG, "Installation.");
696 
697   plg = p;
698   register_udf_reset_connection();
699 
700   /* Test of service: sql */
701   test_sql(p);
702 
703   return 0;
704 }
705 
test_sql_service_plugin_deinit(void * p MY_ATTRIBUTE ((unused)))706 static int test_sql_service_plugin_deinit(void *p MY_ATTRIBUTE((unused))) {
707   DBUG_TRACE;
708   LogPluginErr(INFORMATION_LEVEL, ER_LOG_PRINTF_MSG, "Uninstallation.");
709 
710   unregister_udf_reset_connection();
711 
712   deinit_logging_service_for_plugin(&reg_srv, &log_bi, &log_bs);
713   my_close(outfile, MYF(0));
714   return 0;
715 }
716 
/* Daemon-plugin type descriptor; it carries only the interface version. */
struct st_mysql_daemon test_sql_service_plugin = {
    MYSQL_DAEMON_INTERFACE_VERSION};
719 
720 /*
721   Plugin library descriptor
722 */
723 
mysql_declare_plugin(test_daemon){
    MYSQL_DAEMON_PLUGIN,          /* Plugin type                     */
    &test_sql_service_plugin,     /* Type-specific descriptor        */
    "test_sql_reset_connection",  /* Plugin name                     */
    PLUGIN_AUTHOR_ORACLE,         /* Author                          */
    "Test sql reset connection",  /* Description                     */
    PLUGIN_LICENSE_GPL,           /* License                         */
    test_sql_service_plugin_init,   /* Plugin Init */
    nullptr,                        /* Plugin Check uninstall */
    test_sql_service_plugin_deinit, /* Plugin Deinit */
    0x0100 /* 1.0 */,
    nullptr, /* status variables                */
    nullptr, /* system variables                */
    nullptr, /* config options                  */
    0,       /* flags                           */
} mysql_declare_plugin_end;
740