1 /* Copyright (c) 2004-2018 Dovecot authors, see the included COPYING file */
2 
3 #include "lib.h"
4 #include "array.h"
5 #include "ioloop.h"
6 #include "hash.h"
7 #include "str.h"
8 #include "time-util.h"
9 #include "sql-api-private.h"
10 
11 #include <time.h>
12 
13 struct event_category event_category_sql = {
14 	.name = "sql",
15 };
16 
17 struct sql_db_module_register sql_db_module_register = { 0 };
18 ARRAY_TYPE(sql_drivers) sql_drivers;
19 
sql_drivers_init(void)20 void sql_drivers_init(void)
21 {
22 	i_array_init(&sql_drivers, 8);
23 }
24 
sql_drivers_deinit(void)25 void sql_drivers_deinit(void)
26 {
27 	array_free(&sql_drivers);
28 }
29 
sql_driver_lookup(const char * name)30 static const struct sql_db *sql_driver_lookup(const char *name)
31 {
32 	const struct sql_db *const *drivers;
33 	unsigned int i, count;
34 
35 	drivers = array_get(&sql_drivers, &count);
36 	for (i = 0; i < count; i++) {
37 		if (strcmp(drivers[i]->name, name) == 0)
38 			return drivers[i];
39 	}
40 	return NULL;
41 }
42 
sql_driver_register(const struct sql_db * driver)43 void sql_driver_register(const struct sql_db *driver)
44 {
45 	if (sql_driver_lookup(driver->name) != NULL) {
46 		i_fatal("sql_driver_register(%s): Already registered",
47 			driver->name);
48 	}
49 	array_push_back(&sql_drivers, &driver);
50 }
51 
sql_driver_unregister(const struct sql_db * driver)52 void sql_driver_unregister(const struct sql_db *driver)
53 {
54 	const struct sql_db *const *drivers;
55 	unsigned int i, count;
56 
57 	drivers = array_get(&sql_drivers, &count);
58 	for (i = 0; i < count; i++) {
59 		if (drivers[i] == driver) {
60 			array_delete(&sql_drivers, i, 1);
61 			break;
62 		}
63 	}
64 }
65 
sql_init(const char * db_driver,const char * connect_string)66 struct sql_db *sql_init(const char *db_driver, const char *connect_string)
67 {
68 	const char *error;
69 	struct sql_db *db;
70 	struct sql_settings set = {
71 		.driver = db_driver,
72 		.connect_string = connect_string,
73 	};
74 
75 	if (sql_init_full(&set, &db, &error) < 0)
76 		i_fatal("%s", error);
77 	return db;
78 }
79 
sql_init_full(const struct sql_settings * set,struct sql_db ** db_r,const char ** error_r)80 int sql_init_full(const struct sql_settings *set, struct sql_db **db_r,
81 		  const char **error_r)
82 {
83 	const struct sql_db *driver;
84 	struct sql_db *db;
85 	int ret = 0;
86 
87 	i_assert(set->connect_string != NULL);
88 
89 	driver = sql_driver_lookup(set->driver);
90 	if (driver == NULL) {
91 		*error_r = t_strdup_printf("Unknown database driver '%s'", set->driver);
92 		return -1;
93 	}
94 
95 	if ((driver->flags & SQL_DB_FLAG_POOLED) == 0) {
96 		if (driver->v.init_full == NULL) {
97 			db = driver->v.init(set->connect_string);
98 		} else
99 			ret = driver->v.init_full(set, &db, error_r);
100 	} else
101 		ret = driver_sqlpool_init_full(set, driver, &db, error_r);
102 
103 	if (ret < 0)
104 		return -1;
105 
106 	sql_init_common(db);
107 	*db_r = db;
108 	return 0;
109 }
110 
sql_init_common(struct sql_db * db)111 void sql_init_common(struct sql_db *db)
112 {
113 	db->refcount = 1;
114 	i_array_init(&db->module_contexts, 5);
115 	hash_table_create(&db->prepared_stmt_hash, default_pool, 0,
116 			  str_hash, strcmp);
117 }
118 
sql_ref(struct sql_db * db)119 void sql_ref(struct sql_db *db)
120 {
121 	i_assert(db->refcount > 0);
122 	db->refcount++;
123 }
124 
125 static void
default_sql_prepared_statement_deinit(struct sql_prepared_statement * prep_stmt)126 default_sql_prepared_statement_deinit(struct sql_prepared_statement *prep_stmt)
127 {
128 	i_free(prep_stmt->query_template);
129 	i_free(prep_stmt);
130 }
131 
sql_prepared_statements_free(struct sql_db * db)132 static void sql_prepared_statements_free(struct sql_db *db)
133 {
134 	struct hash_iterate_context *iter;
135 	struct sql_prepared_statement *prep_stmt;
136 	char *query;
137 
138 	iter = hash_table_iterate_init(db->prepared_stmt_hash);
139 	while (hash_table_iterate(iter, db->prepared_stmt_hash, &query, &prep_stmt)) {
140 		i_assert(prep_stmt->refcount == 0);
141 		if (prep_stmt->db->v.prepared_statement_deinit != NULL)
142 			prep_stmt->db->v.prepared_statement_deinit(prep_stmt);
143 		else
144 			default_sql_prepared_statement_deinit(prep_stmt);
145 	}
146 	hash_table_iterate_deinit(&iter);
147 	hash_table_clear(db->prepared_stmt_hash, TRUE);
148 }
149 
sql_unref(struct sql_db ** _db)150 void sql_unref(struct sql_db **_db)
151 {
152 	struct sql_db *db = *_db;
153 
154 	*_db = NULL;
155 
156 	i_assert(db->refcount > 0);
157 	if (db->v.unref != NULL)
158 		db->v.unref(db);
159 	if (--db->refcount > 0)
160 		return;
161 
162 	timeout_remove(&db->to_reconnect);
163 	sql_prepared_statements_free(db);
164 	hash_table_destroy(&db->prepared_stmt_hash);
165 	db->v.deinit(db);
166 }
167 
sql_get_flags(struct sql_db * db)168 enum sql_db_flags sql_get_flags(struct sql_db *db)
169 {
170 	if (db->v.get_flags != NULL)
171 		return db->v.get_flags(db);
172 	else
173 		return db->flags;
174 }
175 
sql_connect(struct sql_db * db)176 int sql_connect(struct sql_db *db)
177 {
178 	time_t now;
179 
180 	switch (db->state) {
181 	case SQL_DB_STATE_DISCONNECTED:
182 		break;
183 	case SQL_DB_STATE_CONNECTING:
184 		return 0;
185 	default:
186 		return 1;
187 	}
188 
189 	/* don't try reconnecting more than once a second */
190 	now = time(NULL);
191 	if (db->last_connect_try + (time_t)db->connect_delay > now)
192 		return -1;
193 	db->last_connect_try = now;
194 
195 	return db->v.connect(db);
196 }
197 
sql_disconnect(struct sql_db * db)198 void sql_disconnect(struct sql_db *db)
199 {
200 	timeout_remove(&db->to_reconnect);
201 	db->v.disconnect(db);
202 }
203 
sql_escape_string(struct sql_db * db,const char * string)204 const char *sql_escape_string(struct sql_db *db, const char *string)
205 {
206 	return db->v.escape_string(db, string);
207 }
208 
sql_escape_blob(struct sql_db * db,const unsigned char * data,size_t size)209 const char *sql_escape_blob(struct sql_db *db,
210 			    const unsigned char *data, size_t size)
211 {
212 	return db->v.escape_blob(db, data, size);
213 }
214 
sql_exec(struct sql_db * db,const char * query)215 void sql_exec(struct sql_db *db, const char *query)
216 {
217 	db->v.exec(db, query);
218 }
219 
220 #undef sql_query
sql_query(struct sql_db * db,const char * query,sql_query_callback_t * callback,void * context)221 void sql_query(struct sql_db *db, const char *query,
222 	       sql_query_callback_t *callback, void *context)
223 {
224 	db->v.query(db, query, callback, context);
225 }
226 
sql_query_s(struct sql_db * db,const char * query)227 struct sql_result *sql_query_s(struct sql_db *db, const char *query)
228 {
229 	return db->v.query_s(db, query);
230 }
231 
232 static struct sql_prepared_statement *
default_sql_prepared_statement_init(struct sql_db * db,const char * query_template)233 default_sql_prepared_statement_init(struct sql_db *db,
234 				    const char *query_template)
235 {
236 	struct sql_prepared_statement *prep_stmt;
237 
238 	prep_stmt = i_new(struct sql_prepared_statement, 1);
239 	prep_stmt->db = db;
240 	prep_stmt->refcount = 1;
241 	prep_stmt->query_template = i_strdup(query_template);
242 	return prep_stmt;
243 }
244 
245 static struct sql_statement *
default_sql_statement_init_prepared(struct sql_prepared_statement * stmt)246 default_sql_statement_init_prepared(struct sql_prepared_statement *stmt)
247 {
248 	return sql_statement_init(stmt->db, stmt->query_template);
249 }
250 
sql_statement_get_query(struct sql_statement * stmt)251 const char *sql_statement_get_query(struct sql_statement *stmt)
252 {
253 	string_t *query = t_str_new(128);
254 	const char *const *args;
255 	unsigned int i, args_count, arg_pos = 0;
256 
257 	args = array_get(&stmt->args, &args_count);
258 
259 	for (i = 0; stmt->query_template[i] != '\0'; i++) {
260 		if (stmt->query_template[i] == '?') {
261 			if (arg_pos >= args_count ||
262 			    args[arg_pos] == NULL) {
263 				i_panic("lib-sql: Missing bind for arg #%u in statement: %s",
264 					arg_pos, stmt->query_template);
265 			}
266 			str_append(query, args[arg_pos++]);
267 		} else {
268 			str_append_c(query, stmt->query_template[i]);
269 		}
270 	}
271 	if (arg_pos != args_count) {
272 		i_panic("lib-sql: Too many bind args (%u) for statement: %s",
273 			args_count, stmt->query_template);
274 	}
275 	return str_c(query);
276 }
277 
278 static void
default_sql_statement_query(struct sql_statement * stmt,sql_query_callback_t * callback,void * context)279 default_sql_statement_query(struct sql_statement *stmt,
280 			    sql_query_callback_t *callback, void *context)
281 {
282 	sql_query(stmt->db, sql_statement_get_query(stmt),
283 		  callback, context);
284 	pool_unref(&stmt->pool);
285 }
286 
287 static struct sql_result *
default_sql_statement_query_s(struct sql_statement * stmt)288 default_sql_statement_query_s(struct sql_statement *stmt)
289 {
290 	struct sql_result *result =
291 		sql_query_s(stmt->db, sql_statement_get_query(stmt));
292 	pool_unref(&stmt->pool);
293 	return result;
294 }
295 
default_sql_update_stmt(struct sql_transaction_context * ctx,struct sql_statement * stmt,unsigned int * affected_rows)296 static void default_sql_update_stmt(struct sql_transaction_context *ctx,
297 				    struct sql_statement *stmt,
298 				    unsigned int *affected_rows)
299 {
300 	ctx->db->v.update(ctx, sql_statement_get_query(stmt),
301 			  affected_rows);
302 	pool_unref(&stmt->pool);
303 }
304 
305 struct sql_prepared_statement *
sql_prepared_statement_init(struct sql_db * db,const char * query_template)306 sql_prepared_statement_init(struct sql_db *db, const char *query_template)
307 {
308 	struct sql_prepared_statement *stmt;
309 
310 	stmt = hash_table_lookup(db->prepared_stmt_hash, query_template);
311 	if (stmt != NULL) {
312 		stmt->refcount++;
313 		return stmt;
314 	}
315 
316 	if (db->v.prepared_statement_init != NULL)
317 		stmt = db->v.prepared_statement_init(db, query_template);
318 	else
319 		stmt = default_sql_prepared_statement_init(db, query_template);
320 
321 	hash_table_insert(db->prepared_stmt_hash, stmt->query_template, stmt);
322 	return stmt;
323 }
324 
sql_prepared_statement_unref(struct sql_prepared_statement ** _prep_stmt)325 void sql_prepared_statement_unref(struct sql_prepared_statement **_prep_stmt)
326 {
327 	struct sql_prepared_statement *prep_stmt = *_prep_stmt;
328 
329 	*_prep_stmt = NULL;
330 
331 	i_assert(prep_stmt->refcount > 0);
332 	prep_stmt->refcount--;
333 }
334 
335 static void
sql_statement_init_fields(struct sql_statement * stmt,struct sql_db * db)336 sql_statement_init_fields(struct sql_statement *stmt, struct sql_db *db)
337 {
338 	stmt->db = db;
339 	p_array_init(&stmt->args, stmt->pool, 8);
340 }
341 
342 struct sql_statement *
sql_statement_init(struct sql_db * db,const char * query_template)343 sql_statement_init(struct sql_db *db, const char *query_template)
344 {
345 	struct sql_statement *stmt;
346 
347 	if (db->v.statement_init != NULL)
348 		stmt = db->v.statement_init(db, query_template);
349 	else {
350 		pool_t pool = pool_alloconly_create("sql statement", 1024);
351 		stmt = p_new(pool, struct sql_statement, 1);
352 		stmt->pool = pool;
353 	}
354 	stmt->query_template = p_strdup(stmt->pool, query_template);
355 	sql_statement_init_fields(stmt, db);
356 	return stmt;
357 }
358 
359 struct sql_statement *
sql_statement_init_prepared(struct sql_prepared_statement * prep_stmt)360 sql_statement_init_prepared(struct sql_prepared_statement *prep_stmt)
361 {
362 	struct sql_statement *stmt;
363 
364 	if (prep_stmt->db->v.statement_init_prepared == NULL)
365 		return default_sql_statement_init_prepared(prep_stmt);
366 
367 	stmt = prep_stmt->db->v.statement_init_prepared(prep_stmt);
368 	sql_statement_init_fields(stmt, prep_stmt->db);
369 	return stmt;
370 }
371 
sql_statement_abort(struct sql_statement ** _stmt)372 void sql_statement_abort(struct sql_statement **_stmt)
373 {
374 	struct sql_statement *stmt = *_stmt;
375 
376 	*_stmt = NULL;
377 	if (stmt->db->v.statement_abort != NULL)
378 		stmt->db->v.statement_abort(stmt);
379 	pool_unref(&stmt->pool);
380 }
381 
sql_statement_set_timestamp(struct sql_statement * stmt,const struct timespec * ts)382 void sql_statement_set_timestamp(struct sql_statement *stmt,
383 				 const struct timespec *ts)
384 {
385 	if (stmt->db->v.statement_set_timestamp != NULL)
386 		stmt->db->v.statement_set_timestamp(stmt, ts);
387 }
388 
sql_statement_bind_str(struct sql_statement * stmt,unsigned int column_idx,const char * value)389 void sql_statement_bind_str(struct sql_statement *stmt,
390 			    unsigned int column_idx, const char *value)
391 {
392 	const char *escaped_value =
393 		p_strdup_printf(stmt->pool, "'%s'",
394 				sql_escape_string(stmt->db, value));
395 	array_idx_set(&stmt->args, column_idx, &escaped_value);
396 
397 	if (stmt->db->v.statement_bind_str != NULL)
398 		stmt->db->v.statement_bind_str(stmt, column_idx, value);
399 }
400 
sql_statement_bind_binary(struct sql_statement * stmt,unsigned int column_idx,const void * value,size_t value_size)401 void sql_statement_bind_binary(struct sql_statement *stmt,
402 			       unsigned int column_idx, const void *value,
403 			       size_t value_size)
404 {
405 	const char *value_str =
406 		p_strdup_printf(stmt->pool, "%s",
407 				sql_escape_blob(stmt->db, value, value_size));
408 	array_idx_set(&stmt->args, column_idx, &value_str);
409 
410 	if (stmt->db->v.statement_bind_binary != NULL) {
411 		stmt->db->v.statement_bind_binary(stmt, column_idx,
412 						  value, value_size);
413 	}
414 }
415 
sql_statement_bind_int64(struct sql_statement * stmt,unsigned int column_idx,int64_t value)416 void sql_statement_bind_int64(struct sql_statement *stmt,
417 			      unsigned int column_idx, int64_t value)
418 {
419 	const char *value_str = p_strdup_printf(stmt->pool, "%"PRId64, value);
420 	array_idx_set(&stmt->args, column_idx, &value_str);
421 
422 	if (stmt->db->v.statement_bind_int64 != NULL)
423 		stmt->db->v.statement_bind_int64(stmt, column_idx, value);
424 }
425 
426 #undef sql_statement_query
sql_statement_query(struct sql_statement ** _stmt,sql_query_callback_t * callback,void * context)427 void sql_statement_query(struct sql_statement **_stmt,
428 			 sql_query_callback_t *callback, void *context)
429 {
430 	struct sql_statement *stmt = *_stmt;
431 
432 	*_stmt = NULL;
433 	if (stmt->db->v.statement_query != NULL)
434 		stmt->db->v.statement_query(stmt, callback, context);
435 	else
436 		default_sql_statement_query(stmt, callback, context);
437 }
438 
sql_statement_query_s(struct sql_statement ** _stmt)439 struct sql_result *sql_statement_query_s(struct sql_statement **_stmt)
440 {
441 	struct sql_statement *stmt = *_stmt;
442 
443 	*_stmt = NULL;
444 	if (stmt->db->v.statement_query_s != NULL)
445 		return stmt->db->v.statement_query_s(stmt);
446 	else
447 		return default_sql_statement_query_s(stmt);
448 }
449 
sql_result_ref(struct sql_result * result)450 void sql_result_ref(struct sql_result *result)
451 {
452 	result->refcount++;
453 }
454 
sql_result_unref(struct sql_result * result)455 void sql_result_unref(struct sql_result *result)
456 {
457 	i_assert(result->refcount > 0);
458 	if (--result->refcount > 0)
459 		return;
460 
461 	i_free(result->map);
462 	result->v.free(result);
463 }
464 
465 static const struct sql_field_def *
sql_field_def_find(const struct sql_field_def * fields,const char * name)466 sql_field_def_find(const struct sql_field_def *fields, const char *name)
467 {
468 	unsigned int i;
469 
470 	for (i = 0; fields[i].name != NULL; i++) {
471 		if (strcasecmp(fields[i].name, name) == 0)
472 			return &fields[i];
473 	}
474 	return NULL;
475 }
476 
477 static void
sql_result_build_map(struct sql_result * result,const struct sql_field_def * fields,size_t dest_size)478 sql_result_build_map(struct sql_result *result,
479 		     const struct sql_field_def *fields, size_t dest_size)
480 {
481 	const struct sql_field_def *def;
482 	const char *name;
483 	unsigned int i, count, field_size = 0;
484 
485 	count = sql_result_get_fields_count(result);
486 
487 	result->map_size = count;
488 	result->map = i_new(struct sql_field_map, result->map_size);
489 	for (i = 0; i < count; i++) {
490 		name = sql_result_get_field_name(result, i);
491 		def = sql_field_def_find(fields, name);
492 		if (def != NULL) {
493 			result->map[i].type = def->type;
494 			result->map[i].offset = def->offset;
495 			switch (def->type) {
496 			case SQL_TYPE_STR:
497 				field_size = sizeof(const char *);
498 				break;
499 			case SQL_TYPE_UINT:
500 				field_size = sizeof(unsigned int);
501 				break;
502 			case SQL_TYPE_ULLONG:
503 				field_size = sizeof(unsigned long long);
504 				break;
505 			case SQL_TYPE_BOOL:
506 				field_size = sizeof(bool);
507 				break;
508 			}
509 			i_assert(def->offset + field_size <= dest_size);
510 		} else {
511 			result->map[i].offset = SIZE_MAX;
512 		}
513 	}
514 }
515 
sql_result_setup_fetch(struct sql_result * result,const struct sql_field_def * fields,void * dest,size_t dest_size)516 void sql_result_setup_fetch(struct sql_result *result,
517 			    const struct sql_field_def *fields,
518 			    void *dest, size_t dest_size)
519 {
520 	if (result->map == NULL)
521 		sql_result_build_map(result, fields, dest_size);
522 	result->fetch_dest = dest;
523 	result->fetch_dest_size = dest_size;
524 }
525 
sql_result_fetch(struct sql_result * result)526 static void sql_result_fetch(struct sql_result *result)
527 {
528 	unsigned int i, count;
529 	const char *value;
530 	void *ptr;
531 
532 	memset(result->fetch_dest, 0, result->fetch_dest_size);
533 	count = result->map_size;
534 	for (i = 0; i < count; i++) {
535 		if (result->map[i].offset == SIZE_MAX)
536 			continue;
537 
538 		value = sql_result_get_field_value(result, i);
539 		ptr = STRUCT_MEMBER_P(result->fetch_dest,
540 				      result->map[i].offset);
541 
542 		switch (result->map[i].type) {
543 		case SQL_TYPE_STR: {
544 			*((const char **)ptr) = value;
545 			break;
546 		}
547 		case SQL_TYPE_UINT: {
548 			if (value != NULL &&
549 			    str_to_uint(value, (unsigned int *)ptr) < 0)
550 				i_error("sql: Value not uint: %s", value);
551 			break;
552 		}
553 		case SQL_TYPE_ULLONG: {
554 			if (value != NULL &&
555 			    str_to_ullong(value, (unsigned long long *)ptr) < 0)
556 				i_error("sql: Value not ullong: %s", value);
557 			break;
558 		}
559 		case SQL_TYPE_BOOL: {
560 			if (value != NULL && (*value == 't' || *value == '1'))
561 				*((bool *)ptr) = TRUE;
562 			break;
563 		}
564 		}
565 	}
566 }
567 
sql_result_next_row(struct sql_result * result)568 int sql_result_next_row(struct sql_result *result)
569 {
570 	int ret;
571 
572 	if ((ret = result->v.next_row(result)) <= 0)
573 		return ret;
574 
575 	if (result->fetch_dest != NULL)
576 		sql_result_fetch(result);
577 	return 1;
578 }
579 
580 #undef sql_result_more
sql_result_more(struct sql_result ** result,sql_query_callback_t * callback,void * context)581 void sql_result_more(struct sql_result **result,
582 		     sql_query_callback_t *callback, void *context)
583 {
584 	i_assert((*result)->v.more != NULL);
585 
586 	(*result)->v.more(result, TRUE, callback, context);
587 }
588 
589 static void
sql_result_more_sync_callback(struct sql_result * result,void * context)590 sql_result_more_sync_callback(struct sql_result *result, void *context)
591 {
592 	struct sql_result **dest_result = context;
593 
594 	*dest_result = result;
595 }
596 
sql_result_more_s(struct sql_result ** result)597 void sql_result_more_s(struct sql_result **result)
598 {
599 	i_assert((*result)->v.more != NULL);
600 
601 	(*result)->v.more(result, FALSE, sql_result_more_sync_callback, result);
602 	/* the callback must have been called */
603 	i_assert(*result != NULL);
604 }
605 
sql_result_get_fields_count(struct sql_result * result)606 unsigned int sql_result_get_fields_count(struct sql_result *result)
607 {
608 	return result->v.get_fields_count(result);
609 }
610 
sql_result_get_field_name(struct sql_result * result,unsigned int idx)611 const char *sql_result_get_field_name(struct sql_result *result,
612 				      unsigned int idx)
613 {
614 	return result->v.get_field_name(result, idx);
615 }
616 
sql_result_find_field(struct sql_result * result,const char * field_name)617 int sql_result_find_field(struct sql_result *result, const char *field_name)
618 {
619 	return result->v.find_field(result, field_name);
620 }
621 
sql_result_get_field_value(struct sql_result * result,unsigned int idx)622 const char *sql_result_get_field_value(struct sql_result *result,
623 				       unsigned int idx)
624 {
625 	return result->v.get_field_value(result, idx);
626 }
627 
628 const unsigned char *
sql_result_get_field_value_binary(struct sql_result * result,unsigned int idx,size_t * size_r)629 sql_result_get_field_value_binary(struct sql_result *result,
630 				  unsigned int idx, size_t *size_r)
631 {
632 	return result->v.get_field_value_binary(result, idx, size_r);
633 }
634 
sql_result_find_field_value(struct sql_result * result,const char * field_name)635 const char *sql_result_find_field_value(struct sql_result *result,
636 					const char *field_name)
637 {
638 	return result->v.find_field_value(result, field_name);
639 }
640 
sql_result_get_values(struct sql_result * result)641 const char *const *sql_result_get_values(struct sql_result *result)
642 {
643 	return result->v.get_values(result);
644 }
645 
sql_result_get_error(struct sql_result * result)646 const char *sql_result_get_error(struct sql_result *result)
647 {
648 	return result->v.get_error(result);
649 }
650 
sql_result_get_error_type(struct sql_result * result)651 enum sql_result_error_type sql_result_get_error_type(struct sql_result *result)
652 {
653 	return result->error_type;
654 }
655 
656 static void
sql_result_not_connected_free(struct sql_result * result ATTR_UNUSED)657 sql_result_not_connected_free(struct sql_result *result ATTR_UNUSED)
658 {
659 }
660 
661 static int
sql_result_not_connected_next_row(struct sql_result * result ATTR_UNUSED)662 sql_result_not_connected_next_row(struct sql_result *result ATTR_UNUSED)
663 {
664 	return -1;
665 }
666 
667 static const char *
sql_result_not_connected_get_error(struct sql_result * result ATTR_UNUSED)668 sql_result_not_connected_get_error(struct sql_result *result ATTR_UNUSED)
669 {
670 	return SQL_ERRSTR_NOT_CONNECTED;
671 }
672 
sql_transaction_begin(struct sql_db * db)673 struct sql_transaction_context *sql_transaction_begin(struct sql_db *db)
674 {
675 	return db->v.transaction_begin(db);
676 }
677 
678 #undef sql_transaction_commit
sql_transaction_commit(struct sql_transaction_context ** _ctx,sql_commit_callback_t * callback,void * context)679 void sql_transaction_commit(struct sql_transaction_context **_ctx,
680 			    sql_commit_callback_t *callback, void *context)
681 {
682 	struct sql_transaction_context *ctx = *_ctx;
683 
684 	*_ctx = NULL;
685 	ctx->db->v.transaction_commit(ctx, callback, context);
686 }
687 
sql_transaction_commit_s(struct sql_transaction_context ** _ctx,const char ** error_r)688 int sql_transaction_commit_s(struct sql_transaction_context **_ctx,
689 			     const char **error_r)
690 {
691 	struct sql_transaction_context *ctx = *_ctx;
692 
693 	*_ctx = NULL;
694 	return ctx->db->v.transaction_commit_s(ctx, error_r);
695 }
696 
sql_transaction_rollback(struct sql_transaction_context ** _ctx)697 void sql_transaction_rollback(struct sql_transaction_context **_ctx)
698 {
699 	struct sql_transaction_context *ctx = *_ctx;
700 
701 	*_ctx = NULL;
702 	ctx->db->v.transaction_rollback(ctx);
703 }
704 
sql_update(struct sql_transaction_context * ctx,const char * query)705 void sql_update(struct sql_transaction_context *ctx, const char *query)
706 {
707 	ctx->db->v.update(ctx, query, NULL);
708 }
709 
sql_update_stmt(struct sql_transaction_context * ctx,struct sql_statement ** _stmt)710 void sql_update_stmt(struct sql_transaction_context *ctx,
711 		     struct sql_statement **_stmt)
712 {
713 	struct sql_statement *stmt = *_stmt;
714 
715 	*_stmt = NULL;
716 	if (ctx->db->v.update_stmt != NULL)
717 		ctx->db->v.update_stmt(ctx, stmt, NULL);
718 	else
719 		default_sql_update_stmt(ctx, stmt, NULL);
720 }
721 
sql_update_get_rows(struct sql_transaction_context * ctx,const char * query,unsigned int * affected_rows)722 void sql_update_get_rows(struct sql_transaction_context *ctx, const char *query,
723 			 unsigned int *affected_rows)
724 {
725 	ctx->db->v.update(ctx, query, affected_rows);
726 }
727 
sql_update_stmt_get_rows(struct sql_transaction_context * ctx,struct sql_statement ** _stmt,unsigned int * affected_rows)728 void sql_update_stmt_get_rows(struct sql_transaction_context *ctx,
729 			      struct sql_statement **_stmt,
730 			      unsigned int *affected_rows)
731 {
732 	struct sql_statement *stmt = *_stmt;
733 
734 	*_stmt = NULL;
735 	if (ctx->db->v.update_stmt != NULL)
736 		ctx->db->v.update_stmt(ctx, stmt, affected_rows);
737 	else
738 		default_sql_update_stmt(ctx, stmt, affected_rows);
739 }
740 
sql_db_set_state(struct sql_db * db,enum sql_db_state state)741 void sql_db_set_state(struct sql_db *db, enum sql_db_state state)
742 {
743 	enum sql_db_state old_state = db->state;
744 
745 	if (db->state == state)
746 		return;
747 
748 	db->state = state;
749 	if (db->state_change_callback != NULL) {
750 		db->state_change_callback(db, old_state,
751 					  db->state_change_context);
752 	}
753 }
754 
/* Append a copy of query (allocated from pool) to the end of ctx's
   singly linked query list (ctx->head .. ctx->tail). */
void sql_transaction_add_query(struct sql_transaction_context *ctx, pool_t pool,
			       const char *query, unsigned int *affected_rows)
{
	struct sql_transaction_query *tquery =
		p_new(pool, struct sql_transaction_query, 1);

	tquery->trans = ctx;
	tquery->query = p_strdup(pool, query);
	tquery->affected_rows = affected_rows;

	if (ctx->head != NULL)
		ctx->tail->next = tquery;
	else
		ctx->head = tquery;
	ctx->tail = tquery;
}
771 
sql_connection_log_finished(struct sql_db * db)772 void sql_connection_log_finished(struct sql_db *db)
773 {
774 	struct event_passthrough *e = event_create_passthrough(db->event)->
775 		set_name(SQL_CONNECTION_FINISHED);
776 	e_debug(e->event(),
777 		"Connection finished (queries=%"PRIu64", slow queries=%"PRIu64")",
778 		db->succeeded_queries + db->failed_queries,
779 		db->slow_queries);
780 }
781 
782 struct event_passthrough *
sql_query_finished_event(struct sql_db * db,struct event * event,const char * query,bool success,int * duration_r)783 sql_query_finished_event(struct sql_db *db, struct event *event, const char *query,
784 			 bool success, int *duration_r)
785 {
786 	int diff;
787 	struct timeval tv;
788 	event_get_create_time(event, &tv);
789 	struct event_passthrough *e = event_create_passthrough(event)->
790 			set_name(SQL_QUERY_FINISHED)->
791 			add_str("query_first_word", t_strcut(query, ' '));
792 	diff = timeval_diff_msecs(&ioloop_timeval, &tv);
793 
794 	if (!success) {
795 		db->failed_queries++;
796 	} else {
797 		db->succeeded_queries++;
798 	}
799 
800 	if (diff >= SQL_SLOW_QUERY_MSEC) {
801 		e->add_str("slow_query", "y");
802 		db->slow_queries++;
803 	}
804 
805 	if (duration_r != NULL)
806 		*duration_r = diff;
807 
808 	return e;
809 }
810 
sql_transaction_finished_event(struct sql_transaction_context * ctx)811 struct event_passthrough *sql_transaction_finished_event(struct sql_transaction_context *ctx)
812 {
813 	return event_create_passthrough(ctx->event)->
814 		set_name(SQL_TRANSACTION_FINISHED);
815 }
816 
sql_wait(struct sql_db * db)817 void sql_wait(struct sql_db *db)
818 {
819 	if (db->v.wait != NULL)
820 		db->v.wait(db);
821 }
822 
823 
824 struct sql_result sql_not_connected_result = {
825 	.v = {
826 		sql_result_not_connected_free,
827 		sql_result_not_connected_next_row,
828 		NULL, NULL, NULL, NULL, NULL, NULL, NULL,
829 		sql_result_not_connected_get_error,
830 		NULL,
831 	},
832 	.failed_try_retry = TRUE
833 };
834