1 /*-
2 * Copyright 2016 Vsevolod Stakhov
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "config.h"
17 #include "learn_cache.h"
18 #include "rspamd.h"
19 #include "stat_api.h"
20 #include "stat_internal.h"
21 #include "cryptobox.h"
22 #include "ucl.h"
23 #include "hiredis.h"
24 #include "adapters/libev.h"
25 #include "lua/lua_common.h"
26 #include "libmime/message.h"
27
28 #define REDIS_DEFAULT_TIMEOUT 0.5
29 #define REDIS_STAT_TIMEOUT 30
30 #define REDIS_DEFAULT_PORT 6379
31 #define DEFAULT_REDIS_KEY "learned_ids"
32
33 static const gchar *M = "redis learn cache";
34
35 struct rspamd_redis_cache_ctx {
36 lua_State *L;
37 struct rspamd_statfile_config *stcf;
38 const gchar *password;
39 const gchar *dbname;
40 const gchar *redis_object;
41 gdouble timeout;
42 gint conf_ref;
43 };
44
45 struct rspamd_redis_cache_runtime {
46 struct rspamd_redis_cache_ctx *ctx;
47 struct rspamd_task *task;
48 struct upstream *selected;
49 ev_timer timer_ev;
50 redisAsyncContext *redis;
51 gboolean has_event;
52 };
53
54 static GQuark
rspamd_stat_cache_redis_quark(void)55 rspamd_stat_cache_redis_quark (void)
56 {
57 return g_quark_from_static_string (M);
58 }
59
60 static inline struct upstream_list *
rspamd_redis_get_servers(struct rspamd_redis_cache_ctx * ctx,const gchar * what)61 rspamd_redis_get_servers (struct rspamd_redis_cache_ctx *ctx,
62 const gchar *what)
63 {
64 lua_State *L = ctx->L;
65 struct upstream_list *res;
66
67 lua_rawgeti (L, LUA_REGISTRYINDEX, ctx->conf_ref);
68 lua_pushstring (L, what);
69 lua_gettable (L, -2);
70 res = *((struct upstream_list**)lua_touserdata (L, -1));
71 lua_settop (L, 0);
72
73 return res;
74 }
75
76 static void
rspamd_redis_cache_maybe_auth(struct rspamd_redis_cache_ctx * ctx,redisAsyncContext * redis)77 rspamd_redis_cache_maybe_auth (struct rspamd_redis_cache_ctx *ctx,
78 redisAsyncContext *redis)
79 {
80 if (ctx->password) {
81 redisAsyncCommand (redis, NULL, NULL, "AUTH %s", ctx->password);
82 }
83 if (ctx->dbname) {
84 redisAsyncCommand (redis, NULL, NULL, "SELECT %s", ctx->dbname);
85 }
86 }
87
88 /* Called on connection termination */
89 static void
rspamd_redis_cache_fin(gpointer data)90 rspamd_redis_cache_fin (gpointer data)
91 {
92 struct rspamd_redis_cache_runtime *rt = data;
93 redisAsyncContext *redis;
94
95 rt->has_event = FALSE;
96 ev_timer_stop (rt->task->event_loop, &rt->timer_ev);
97
98 if (rt->redis) {
99 redis = rt->redis;
100 rt->redis = NULL;
101 /* This calls for all callbacks pending */
102 redisAsyncFree (redis);
103 }
104 }
105
106 static void
rspamd_redis_cache_timeout(EV_P_ ev_timer * w,int revents)107 rspamd_redis_cache_timeout (EV_P_ ev_timer *w, int revents)
108 {
109 struct rspamd_redis_cache_runtime *rt =
110 (struct rspamd_redis_cache_runtime *)w->data;
111 struct rspamd_task *task;
112
113 task = rt->task;
114
115 msg_err_task ("connection to redis server %s timed out",
116 rspamd_upstream_name (rt->selected));
117 rspamd_upstream_fail (rt->selected, FALSE, "timeout");
118
119 if (rt->has_event) {
120 rspamd_session_remove_event (task->s, rspamd_redis_cache_fin, rt);
121 }
122 }
123
124 /* Called when we have checked the specified message id */
125 static void
rspamd_stat_cache_redis_get(redisAsyncContext * c,gpointer r,gpointer priv)126 rspamd_stat_cache_redis_get (redisAsyncContext *c, gpointer r, gpointer priv)
127 {
128 struct rspamd_redis_cache_runtime *rt = priv;
129 redisReply *reply = r;
130 struct rspamd_task *task;
131 glong val = 0;
132
133 task = rt->task;
134
135 if (c->err == 0) {
136 if (reply) {
137 if (G_LIKELY (reply->type == REDIS_REPLY_INTEGER)) {
138 val = reply->integer;
139 }
140 else if (reply->type == REDIS_REPLY_STRING) {
141 rspamd_strtol (reply->str, reply->len, &val);
142 }
143 else {
144 if (reply->type == REDIS_REPLY_ERROR) {
145 msg_err_task ("cannot learn %s: redis error: \"%s\"",
146 rt->ctx->stcf->symbol, reply->str);
147 }
148 else if (reply->type != REDIS_REPLY_NIL) {
149 msg_err_task ("bad learned type for %s: %d",
150 rt->ctx->stcf->symbol, reply->type);
151 }
152
153 val = 0;
154 }
155 }
156
157 if ((val > 0 && (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM)) ||
158 (val < 0 && (task->flags & RSPAMD_TASK_FLAG_LEARN_HAM))) {
159 /* Already learned */
160 msg_info_task ("<%s> has been already "
161 "learned as %s, ignore it", MESSAGE_FIELD (task, message_id),
162 (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM) ? "spam" : "ham");
163 task->flags |= RSPAMD_TASK_FLAG_ALREADY_LEARNED;
164 }
165 else if (val != 0) {
166 /* Unlearn flag */
167 task->flags |= RSPAMD_TASK_FLAG_UNLEARN;
168 }
169
170 rspamd_upstream_ok (rt->selected);
171 }
172 else {
173 rspamd_upstream_fail (rt->selected, FALSE, c->errstr);
174 }
175
176 if (rt->has_event) {
177 rspamd_session_remove_event (task->s, rspamd_redis_cache_fin, rt);
178 }
179 }
180
181 /* Called when we have learned the specified message id */
182 static void
rspamd_stat_cache_redis_set(redisAsyncContext * c,gpointer r,gpointer priv)183 rspamd_stat_cache_redis_set (redisAsyncContext *c, gpointer r, gpointer priv)
184 {
185 struct rspamd_redis_cache_runtime *rt = priv;
186 struct rspamd_task *task;
187
188 task = rt->task;
189
190 if (c->err == 0) {
191 /* XXX: we ignore results here */
192 rspamd_upstream_ok (rt->selected);
193 }
194 else {
195 rspamd_upstream_fail (rt->selected, FALSE, c->errstr);
196 }
197
198 if (rt->has_event) {
199 rspamd_session_remove_event (task->s, rspamd_redis_cache_fin, rt);
200 }
201 }
202
203 static void
rspamd_stat_cache_redis_generate_id(struct rspamd_task * task)204 rspamd_stat_cache_redis_generate_id (struct rspamd_task *task)
205 {
206 rspamd_cryptobox_hash_state_t st;
207 rspamd_token_t *tok;
208 guint i;
209 guchar out[rspamd_cryptobox_HASHBYTES];
210 gchar *b32out;
211 gchar *user = NULL;
212
213 rspamd_cryptobox_hash_init (&st, NULL, 0);
214
215 user = rspamd_mempool_get_variable (task->task_pool, "stat_user");
216 /* Use dedicated hash space for per users cache */
217 if (user != NULL) {
218 rspamd_cryptobox_hash_update (&st, user, strlen (user));
219 }
220
221 for (i = 0; i < task->tokens->len; i ++) {
222 tok = g_ptr_array_index (task->tokens, i);
223 rspamd_cryptobox_hash_update (&st, (guchar *)&tok->data,
224 sizeof (tok->data));
225 }
226
227 rspamd_cryptobox_hash_final (&st, out);
228
229 b32out = rspamd_mempool_alloc (task->task_pool,
230 sizeof (out) * 8 / 5 + 3);
231 i = rspamd_encode_base32_buf (out, sizeof (out), b32out,
232 sizeof (out) * 8 / 5 + 2, RSPAMD_BASE32_DEFAULT);
233
234 if (i > 0) {
235 /* Zero terminate */
236 b32out[i] = '\0';
237 }
238
239 rspamd_mempool_set_variable (task->task_pool, "words_hash", b32out, NULL);
240 }
241
242 gpointer
rspamd_stat_cache_redis_init(struct rspamd_stat_ctx * ctx,struct rspamd_config * cfg,struct rspamd_statfile * st,const ucl_object_t * cf)243 rspamd_stat_cache_redis_init (struct rspamd_stat_ctx *ctx,
244 struct rspamd_config *cfg,
245 struct rspamd_statfile *st,
246 const ucl_object_t *cf)
247 {
248 struct rspamd_redis_cache_ctx *cache_ctx;
249 struct rspamd_statfile_config *stf = st->stcf;
250 const ucl_object_t *obj;
251 gboolean ret = FALSE;
252 lua_State *L = (lua_State *)cfg->lua_state;
253 gint conf_ref = -1;
254
255 cache_ctx = g_malloc0 (sizeof (*cache_ctx));
256 cache_ctx->timeout = REDIS_DEFAULT_TIMEOUT;
257 cache_ctx->L = L;
258
259 /* First search in backend configuration */
260 obj = ucl_object_lookup (st->classifier->cfg->opts, "backend");
261 if (obj != NULL && ucl_object_type (obj) == UCL_OBJECT) {
262 ret = rspamd_lua_try_load_redis (L, obj, cfg, &conf_ref);
263 }
264
265 /* Now try statfiles config */
266 if (!ret && stf->opts) {
267 ret = rspamd_lua_try_load_redis (L, stf->opts, cfg, &conf_ref);
268 }
269
270 /* Now try classifier config */
271 if (!ret && st->classifier->cfg->opts) {
272 ret = rspamd_lua_try_load_redis (L, st->classifier->cfg->opts, cfg, &conf_ref);
273 }
274
275 /* Now try global redis settings */
276 if (!ret) {
277 obj = ucl_object_lookup (cfg->rcl_obj, "redis");
278
279 if (obj) {
280 const ucl_object_t *specific_obj;
281
282 specific_obj = ucl_object_lookup (obj, "statistics");
283
284 if (specific_obj) {
285 ret = rspamd_lua_try_load_redis (L,
286 specific_obj, cfg, &conf_ref);
287 }
288 else {
289 ret = rspamd_lua_try_load_redis (L,
290 obj, cfg, &conf_ref);
291 }
292 }
293 }
294
295 if (!ret) {
296 msg_err_config ("cannot init redis cache for %s", stf->symbol);
297 g_free (cache_ctx);
298 return NULL;
299 }
300
301 obj = ucl_object_lookup (st->classifier->cfg->opts, "cache_key");
302
303 if (obj) {
304 cache_ctx->redis_object = ucl_object_tostring (obj);
305 }
306 else {
307 cache_ctx->redis_object = DEFAULT_REDIS_KEY;
308 }
309
310 cache_ctx->conf_ref = conf_ref;
311
312 /* Check some common table values */
313 lua_rawgeti (L, LUA_REGISTRYINDEX, conf_ref);
314
315 lua_pushstring (L, "timeout");
316 lua_gettable (L, -2);
317 if (lua_type (L, -1) == LUA_TNUMBER) {
318 cache_ctx->timeout = lua_tonumber (L, -1);
319 }
320 lua_pop (L, 1);
321
322 lua_pushstring (L, "db");
323 lua_gettable (L, -2);
324 if (lua_type (L, -1) == LUA_TSTRING) {
325 cache_ctx->dbname = rspamd_mempool_strdup (cfg->cfg_pool,
326 lua_tostring (L, -1));
327 }
328 lua_pop (L, 1);
329
330 lua_pushstring (L, "password");
331 lua_gettable (L, -2);
332 if (lua_type (L, -1) == LUA_TSTRING) {
333 cache_ctx->password = rspamd_mempool_strdup (cfg->cfg_pool,
334 lua_tostring (L, -1));
335 }
336 lua_pop (L, 1);
337
338 lua_settop (L, 0);
339
340 cache_ctx->stcf = stf;
341
342 return (gpointer)cache_ctx;
343 }
344
345 gpointer
rspamd_stat_cache_redis_runtime(struct rspamd_task * task,gpointer c,gboolean learn)346 rspamd_stat_cache_redis_runtime (struct rspamd_task *task,
347 gpointer c, gboolean learn)
348 {
349 struct rspamd_redis_cache_ctx *ctx = c;
350 struct rspamd_redis_cache_runtime *rt;
351 struct upstream *up;
352 struct upstream_list *ups;
353 rspamd_inet_addr_t *addr;
354
355 g_assert (ctx != NULL);
356
357 if (task->tokens == NULL || task->tokens->len == 0) {
358 return NULL;
359 }
360
361 if (learn) {
362 ups = rspamd_redis_get_servers (ctx, "write_servers");
363
364 if (!ups) {
365 msg_err_task ("no write servers defined for %s, cannot learn",
366 ctx->stcf->symbol);
367 return NULL;
368 }
369
370 up = rspamd_upstream_get (ups,
371 RSPAMD_UPSTREAM_MASTER_SLAVE,
372 NULL,
373 0);
374 }
375 else {
376 ups = rspamd_redis_get_servers (ctx, "read_servers");
377
378 if (!ups) {
379 msg_err_task ("no read servers defined for %s, cannot check",
380 ctx->stcf->symbol);
381 return NULL;
382 }
383
384 up = rspamd_upstream_get (ups,
385 RSPAMD_UPSTREAM_ROUND_ROBIN,
386 NULL,
387 0);
388 }
389
390 if (up == NULL) {
391 msg_err_task ("no upstreams reachable");
392 return NULL;
393 }
394
395 rt = rspamd_mempool_alloc0 (task->task_pool, sizeof (*rt));
396 rt->selected = up;
397 rt->task = task;
398 rt->ctx = ctx;
399
400 addr = rspamd_upstream_addr_next (up);
401 g_assert (addr != NULL);
402
403 if (rspamd_inet_address_get_af (addr) == AF_UNIX) {
404 rt->redis = redisAsyncConnectUnix (rspamd_inet_address_to_string (addr));
405 }
406 else {
407 rt->redis = redisAsyncConnect (rspamd_inet_address_to_string (addr),
408 rspamd_inet_address_get_port (addr));
409 }
410
411 if (rt->redis == NULL) {
412 msg_warn_task ("cannot connect to redis server %s: %s",
413 rspamd_inet_address_to_string_pretty (addr),
414 strerror (errno));
415
416 return NULL;
417 }
418 else if (rt->redis->err != REDIS_OK) {
419 msg_warn_task ("cannot connect to redis server %s: %s",
420 rspamd_inet_address_to_string_pretty (addr),
421 rt->redis->errstr);
422 redisAsyncFree (rt->redis);
423 rt->redis = NULL;
424
425 return NULL;
426 }
427
428 redisLibevAttach (task->event_loop, rt->redis);
429
430 /* Now check stats */
431 rt->timer_ev.data = rt;
432 ev_timer_init (&rt->timer_ev, rspamd_redis_cache_timeout,
433 rt->ctx->timeout, 0.0);
434 rspamd_redis_cache_maybe_auth (ctx, rt->redis);
435
436 if (!learn) {
437 rspamd_stat_cache_redis_generate_id (task);
438 }
439
440 return rt;
441 }
442
443 gint
rspamd_stat_cache_redis_check(struct rspamd_task * task,gboolean is_spam,gpointer runtime)444 rspamd_stat_cache_redis_check (struct rspamd_task *task,
445 gboolean is_spam,
446 gpointer runtime)
447 {
448 struct rspamd_redis_cache_runtime *rt = runtime;
449 gchar *h;
450
451 if (rspamd_session_blocked (task->s)) {
452 return RSPAMD_LEARN_INGORE;
453 }
454
455 h = rspamd_mempool_get_variable (task->task_pool, "words_hash");
456
457 if (h == NULL) {
458 return RSPAMD_LEARN_INGORE;
459 }
460
461 if (redisAsyncCommand (rt->redis, rspamd_stat_cache_redis_get, rt,
462 "HGET %s %s",
463 rt->ctx->redis_object, h) == REDIS_OK) {
464 rspamd_session_add_event (task->s,
465 rspamd_redis_cache_fin,
466 rt,
467 M);
468 ev_timer_start (rt->task->event_loop, &rt->timer_ev);
469 rt->has_event = TRUE;
470 }
471
472 /* We need to return OK every time */
473 return RSPAMD_LEARN_OK;
474 }
475
476 gint
rspamd_stat_cache_redis_learn(struct rspamd_task * task,gboolean is_spam,gpointer runtime)477 rspamd_stat_cache_redis_learn (struct rspamd_task *task,
478 gboolean is_spam,
479 gpointer runtime)
480 {
481 struct rspamd_redis_cache_runtime *rt = runtime;
482 gchar *h;
483 gint flag;
484
485 if (rt == NULL || rt->ctx == NULL || rspamd_session_blocked (task->s)) {
486 return RSPAMD_LEARN_INGORE;
487 }
488
489 h = rspamd_mempool_get_variable (task->task_pool, "words_hash");
490 g_assert (h != NULL);
491
492 flag = (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM) ? 1 : -1;
493
494 if (redisAsyncCommand (rt->redis, rspamd_stat_cache_redis_set, rt,
495 "HSET %s %s %d",
496 rt->ctx->redis_object, h, flag) == REDIS_OK) {
497 rspamd_session_add_event (task->s,
498 rspamd_redis_cache_fin, rt, M);
499 ev_timer_start (rt->task->event_loop, &rt->timer_ev);
500 rt->has_event = TRUE;
501 }
502
503 /* We need to return OK every time */
504 return RSPAMD_LEARN_OK;
505 }
506
507 void
rspamd_stat_cache_redis_close(gpointer c)508 rspamd_stat_cache_redis_close (gpointer c)
509 {
510 struct rspamd_redis_cache_ctx *ctx = (struct rspamd_redis_cache_ctx *)c;
511 lua_State *L;
512
513 L = ctx->L;
514
515 if (ctx->conf_ref) {
516 luaL_unref (L, LUA_REGISTRYINDEX, ctx->conf_ref);
517 }
518
519 g_free (ctx);
520 }
521