1 /* spmfilter - mail filtering framework
2  * Copyright (C) 2009-2012 Axel Steiner, Werner Detter and SpaceNet AG
3  *
4  * This program is free software; you can redistribute it and/or
5  * modify it under the terms of the GNU Lesser General Public
6  * License as published by the Free Software Foundation; either
7  * version 3 of the License, or (at your option) any later version.
8  *
9  * This program is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
12  * Lesser General Public License for more details.
13  *
14  * You should have received a copy of the GNU Lesser General Public
15  * License along with this program.  If not, see <http://www.gnu.org/licenses/>.
16  */
17 
18 #define _GNU_SOURCE
19 
20 #include <stdlib.h>
21 #include <string.h>
22 #include <stdarg.h>
23 #include <unistd.h>
24 #include <time.h>
25 #include <assert.h>
26 #include <zdb.h>
27 
28 #include "smf_settings.h"
29 #include "smf_session.h"
30 #include "smf_trace.h"
31 #include "smf_lookup.h"
32 #include "smf_lookup_sql.h"
33 #include "smf_core.h"
34 #include "smf_dict.h"
35 #include "smf_list.h"
36 #include "smf_internal.h"
37 
38 #define THIS_MODULE "lookup_sql"
39 
smf_lookup_sql_abort_handler(const char * error)40 void smf_lookup_sql_abort_handler(const char *error) {
41     TRACE(TRACE_ERR, "%s", error);
42 }
43 
/*
 * Pick one entry of settings->sql_host at random (used for the "balance"
 * backend_connection mode).
 *
 * The PRNG is seeded only once per process: the previous code reseeded with
 * time(NULL) on every call, so all calls within the same second picked the
 * same host. It also returned the head of the list on nearly every call
 * (the index test was inverted), and divided by zero on an empty list.
 *
 * Returns the host string stored in the list (not a copy; do not free it),
 * or NULL when no SQL hosts are configured.
 */
char *smf_lookup_sql_get_rand_host(SMFSettings_T *settings) {
    static int seeded = 0;
    SMFListElem_T *e = NULL;
    int size;
    int target;

    assert(settings);

    TRACE(TRACE_DEBUG,"trying to get random sql server");

    if (settings->sql_host == NULL)
        return NULL;

    size = smf_list_size(settings->sql_host);
    if (size <= 0)
        return NULL; /* avoid rand() % 0 */

    if (!seeded) {
        srand((unsigned int)time(NULL));
        seeded = 1;
    }
    target = rand() % size; /* 0-based index into the host list */

    for (e = smf_list_head(settings->sql_host); e != NULL; e = e->next) {
        if (target-- == 0)
            return (char *)smf_list_data(e);
    }
    return NULL;
}
64 
/*
 * Give a connection obtained from the pool back to it.
 * libzdb's Connection_close() returns the connection to its pool; it does
 * not tear down the pool itself.
 */
void smf_lookup_sql_con_close(Connection_T c) {
    Connection_T pooled = c;

    TRACE(TRACE_LOOKUP,"returning connection to pool");
    Connection_close(pooled);
}
70 
/*
 * Build the libzdb connection URL for the configured SQL backend, in the form
 *   <driver>://<host>[:<port>][/<name>][?user=<u>[&password=<p>][&charset=<c>]]
 * For sqlite the database name follows "://" directly (no '/'), and a
 * leading '~' is expanded to $HOME.
 *
 * host: server to connect to. When NULL (and the driver is not sqlite), the
 *       host comes from settings->sql_host. In "balance" mode a random entry
 *       is used; otherwise the first entry is used.
 *
 * Returns a newly allocated string that the caller must free(), or NULL on
 * error (no driver configured, no host available, allocation failure).
 *
 * NOTE(review): user, password and charset are inserted verbatim, without
 * URL encoding. Values containing '&', '?' or '%' will produce a broken
 * URL. The final trace line also logs the password.
 */
char *smf_lookup_sql_get_dsn(SMFSettings_T *settings, char *host) {
    char *sdsn = NULL;
    char *expanded_name = NULL;
    const char *db_host = host;
    int is_sqlite;

    assert(settings);

    if (settings->sql_driver == NULL) {
        TRACE(TRACE_ERR,"error, no sql driver defined!");
        return NULL;
    }
    is_sqlite = (strcasecmp(settings->sql_driver,"sqlite") == 0);

    /* pick a server from the configured list when none was given */
    if (db_host == NULL && !is_sqlite) {
        if (settings->backend_connection != NULL &&
                strcasecmp(settings->backend_connection,"balance") == 0) {
            db_host = smf_lookup_sql_get_rand_host(settings);
        } else if (settings->sql_host != NULL &&
                smf_list_head(settings->sql_host) != NULL) {
            db_host = (char *)smf_list_data(smf_list_head(settings->sql_host));
        }

        if (db_host == NULL) {
            TRACE(TRACE_ERR,"error, no sql host defined!");
            return NULL;
        }
    }

    sdsn = calloc(1, sizeof(char));
    if (sdsn == NULL) {
        TRACE(TRACE_ERR,"failed to allocate memory for sql dsn");
        return NULL;
    }

    smf_core_strcat_printf(&sdsn,"%s://",settings->sql_driver);

    if (db_host != NULL)
        smf_core_strcat_printf(&sdsn,"%s",db_host);

    if (settings->sql_port)
        smf_core_strcat_printf(&sdsn,":%u",settings->sql_port);

    if (settings->sql_name) {
        if (is_sqlite) {
            const char *name = settings->sql_name;

            /* Expand a leading ~ into a local copy, so the shared
             * settings->sql_name is neither overwritten nor leaked. */
            if (name[0] == '~') {
                const char *homedir = getenv("HOME");

                if (homedir == NULL ||
                        asprintf(&expanded_name, "%s%s", homedir, name + 1) == -1) {
                    expanded_name = NULL; /* contents undefined on failure */
                    TRACE(TRACE_ERR,"can't expand ~ in db name");
                } else {
                    name = expanded_name;
                }
            }

            smf_core_strcat_printf(&sdsn, "%s", name);
            free(expanded_name);
        } else {
            smf_core_strcat_printf(&sdsn,"/%s",settings->sql_name);
        }
    }

    if (settings->sql_user && strlen((const char *)settings->sql_user)) {
        smf_core_strcat_printf(&sdsn,"?user=%s", settings->sql_user);

        if (settings->sql_pass && strlen((const char *)settings->sql_pass))
            smf_core_strcat_printf(&sdsn,"&password=%s", settings->sql_pass);

        if (strcasecmp(settings->sql_driver,"mysql") == 0) {
            if (settings->sql_encoding && strlen((const char *)settings->sql_encoding))
                smf_core_strcat_printf(&sdsn,"&charset=%s", settings->sql_encoding);
        }
    }

    TRACE(TRACE_LOOKUP,"sql db at url: [%s]", sdsn);
    return sdsn;
}
133 
/*
 * Create and start a libzdb connection pool for dsn, and store it in
 * settings->lookup_connection. Any existing pool is torn down first.
 *
 * Pool behaviour:
 * - The pool is capped at settings->sql_max_connections when that is > 0.
 * - A reaper runs every sweep_interval seconds.
 * - The abort handler is installed only for non-sqlite drivers.
 * - Startup is verified by checking out one connection and pinging it.
 *
 * Returns 0 on success. Returns -1 on failure, in which case
 * settings->lookup_connection is left NULL.
 */
int smf_lookup_sql_start_pool(SMFSettings_T *settings, char *dsn) {
    const int sweep_interval = 60;
    Connection_T c = NULL;
    SMFSQLConnection_T *con = NULL;

    assert(settings);
    assert(dsn);

    /* drop any previous pool before creating a new one */
    if (settings->lookup_connection != NULL)
        smf_lookup_sql_disconnect(settings);

    con = malloc(sizeof *con);
    if (con == NULL) {
        TRACE(TRACE_ERR,"failed to allocate memory for database connection");
        return -1;
    }
    con->pool = NULL;

    /* URL_new() returns NULL for an unparsable url; the dsn is not logged
     * here because it may contain the password. */
    con->url = URL_new(dsn);
    if (con->url == NULL) {
        TRACE(TRACE_ERR,"invalid database url");
        free(con);
        return -1;
    }

    con->pool = ConnectionPool_new(con->url);
    if (con->pool == NULL) {
        TRACE(TRACE_ERR,"error creating database connection pool");
        /* clean up directly: there is no pool to stop */
        URL_free(&con->url);
        free(con);
        return -1;
    }

    settings->lookup_connection = (void *)con;

    if (settings->sql_max_connections > 0) {
        if (settings->sql_max_connections < (unsigned int)ConnectionPool_getInitialConnections(con->pool))
            ConnectionPool_setInitialConnections(con->pool, settings->sql_max_connections);
        ConnectionPool_setMaxConnections(con->pool, settings->sql_max_connections);
        TRACE(TRACE_LOOKUP,"database connection pool created with maximum connections of [%u]",
            (unsigned int)settings->sql_max_connections);
    }

    ConnectionPool_setReaper(con->pool, sweep_interval);
    TRACE(TRACE_LOOKUP, "run a database connection reaper thread every [%d] seconds", sweep_interval);

    if (strcasecmp(settings->sql_driver,"sqlite") != 0)
        ConnectionPool_setAbortHandler(con->pool, smf_lookup_sql_abort_handler);

    ConnectionPool_start(con->pool);

    c = ConnectionPool_getConnection(con->pool);
    if (c == NULL) {
        TRACE(TRACE_ERR,"failed to get a database connection from the pool");
        smf_lookup_sql_disconnect(settings);
        return -1;
    }
    if (Connection_ping(c) == 0) {
        TRACE(TRACE_ERR,"database server did not answer ping");
        /* hand the connection back before the pool is torn down */
        smf_lookup_sql_con_close(c);
        smf_lookup_sql_disconnect(settings);
        return -1;
    }
    smf_lookup_sql_con_close(c);

    TRACE(TRACE_LOOKUP, "database connection pool started with [%d] connections, max [%d]",
        ConnectionPool_getInitialConnections(con->pool), ConnectionPool_getMaxConnections(con->pool));

    return 0;
}
188 
/*
 * Establish the SQL lookup pool.
 *
 * The primary dsn comes from smf_lookup_sql_get_dsn(settings, NULL). If that
 * fails, every host in settings->sql_host is tried in list order as a
 * failover, until one pool starts.
 *
 * Returns 0 when a pool was started, -1 otherwise.
 */
int smf_lookup_sql_connect(SMFSettings_T *settings) {
    char *dsn = NULL;
    int ret = -1;
    SMFListElem_T *elem = NULL;

    assert(settings);

    /* get_dsn may return NULL (no driver / no host); start_pool asserts
     * a non-NULL dsn, so only call it with a valid one. */
    dsn = smf_lookup_sql_get_dsn(settings, NULL);
    if (dsn != NULL)
        ret = smf_lookup_sql_start_pool(settings, dsn);

    if (ret != 0) {
        TRACE(TRACE_ERR,"failed to initialize sql pool\n");

        /* check failover connections */
        if (settings->sql_host != NULL)
            elem = smf_list_head(settings->sql_host);

        while (elem != NULL) {
            free(dsn);
            dsn = smf_lookup_sql_get_dsn(settings, (char *)smf_list_data(elem));

            if (dsn != NULL && (ret = smf_lookup_sql_start_pool(settings, dsn)) == 0)
                break;

            elem = elem->next;
        }
    }

    if (ret == 0) {
        TRACE(TRACE_LOOKUP,"successfully initialized sql pool\n");
    } else {
        TRACE(TRACE_LOOKUP,"failed initialized sql pool\n");
    }

    free(dsn);

    return ret;
}
226 
/*
 * Tear down the SQL lookup pool stored in settings->lookup_connection, if any.
 *
 * The pool is stopped and freed, then its URL object and the wrapper struct
 * are freed, and settings->lookup_connection is reset to NULL.
 *
 * A partially initialised wrapper (pool or url still NULL) is handled safely.
 * libzdb's ConnectionPool_stop/ConnectionPool_free/URL_free reject NULL
 * arguments. Calling this with no active connection is a no-op.
 */
void smf_lookup_sql_disconnect(SMFSettings_T *settings) {
    SMFSQLConnection_T *con = NULL;

    assert(settings);

    if (settings->lookup_connection == NULL)
        return;

    con = (SMFSQLConnection_T *)settings->lookup_connection;

    TRACE(TRACE_LOOKUP,"closing database connection");
    if (con->pool != NULL) {
        ConnectionPool_stop(con->pool);
        ConnectionPool_free(&con->pool);
    }
    /* the pool does not own the URL; it is freed separately */
    if (con->url != NULL)
        URL_free(&con->url);

    free(con);
    settings->lookup_connection = NULL;
}
242 
smf_lookup_sql_get_connection(ConnectionPool_T pool)243 Connection_T smf_lookup_sql_get_connection(ConnectionPool_T pool) {
244     int i=0, k=0;
245     Connection_T c;
246     while (i++<30) {
247         c = ConnectionPool_getConnection(pool);
248 
249         if (c) {
250             if(Connection_ping(c) == 1) break;
251             else Connection_close(c);
252         }
253         if((int)(i % 5)==0) {
254             TRACE(TRACE_WARNING, "Thread is having trouble obtaining a database connection. Try [%d]", i);
255             k = ConnectionPool_reapConnections(pool);
256             TRACE(TRACE_LOOKUP, "Database reaper closed [%d] stale connections", k);
257         }
258         sleep(1);
259     }
260     if (! c) {
261         TRACE(TRACE_ERR,"[%p] can't get a database connection from the pool! max [%d] size [%d] active [%d]",
262             pool,
263             ConnectionPool_getMaxConnections(pool),
264             ConnectionPool_size(pool),
265             ConnectionPool_active(pool));
266     }
267 
268     assert(c);
269     TRACE(TRACE_LOOKUP,"[%p] got connection from pool", c);
270     return c;
271 }
272 
smf_lookup_sql_query(SMFSettings_T * settings,SMFSession_T * session,const char * q,...)273 SMFList_T *smf_lookup_sql_query(SMFSettings_T *settings, SMFSession_T *session, const char *q, ...) {
274     SMFSQLConnection_T *con;
275     Connection_T c;
276     ResultSet_T r;
277     SMFList_T *result;
278     va_list ap;
279     char *query;
280     int i;
281 
282     va_start(ap, q);
283     vasprintf(&query,q,ap);
284     va_end(ap);
285     smf_core_strstrip(query);
286 
287     if (strlen(query) == 0) return NULL;
288 
289     /* active connection? */
290     if (settings->lookup_connection == NULL)
291         if(smf_lookup_sql_connect(settings) != 0) return NULL;
292 
293     con = (SMFSQLConnection_T *)settings->lookup_connection;
294     if (con->pool == NULL)
295         if (smf_lookup_sql_connect(settings) != 0) return NULL;
296 
297     if (smf_list_new(&result,smf_internal_dict_list_destroy)!=0) {
298         return NULL;
299     } else {
300         c = smf_lookup_sql_get_connection(con->pool);
301 
302         TRY
303             r = Connection_executeQuery(c, query,NULL);
304         CATCH(SQLException)
305             STRACE(TRACE_ERR,session->id,"SQL error: %s\n", Connection_getLastError(c));
306             return NULL;
307         END_TRY;
308 
309         while (ResultSet_next(r)) {
310             SMFDict_T *d = smf_dict_new();
311 
312             for (i=1; i <= ResultSet_getColumnCount(r); i++) {
313                 int blob_size = 0;
314                 char *c = (char *)ResultSet_getColumnName(r,i);
315                 char *col_name = NULL;
316                 col_name = strdup(c);
317                 const void *data = ResultSet_getBlob(r, i, &blob_size);
318 
319                 smf_dict_set(d,col_name,data);
320                 free(col_name);
321             }
322 
323             if (smf_list_append(result,d) != 0) return NULL;
324         }
325 
326         STRACE(TRACE_LOOKUP,session->id,"query [%s] returned [%d] rows", query, result->size);
327     }
328 
329     free(query);
330     smf_lookup_sql_con_close(c);
331 
332     /* if not persistent, close connection */
333     if (settings->lookup_persistent != 1)
334         smf_lookup_sql_disconnect(settings);
335 
336     return result;
337 }
338 
339