1 /*-
2  * Copyright 2016 Vsevolod Stakhov
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "config.h"
17 #include "learn_cache.h"
18 #include "rspamd.h"
19 #include "stat_api.h"
20 #include "stat_internal.h"
21 #include "cryptobox.h"
22 #include "ucl.h"
23 #include "hiredis.h"
24 #include "adapters/libev.h"
25 #include "lua/lua_common.h"
26 #include "libmime/message.h"
27 
/*
 * Default per-request timeout in seconds. It is used as the libev timer
 * interval and can be overridden by the "timeout" field of the redis config.
 */
#define REDIS_DEFAULT_TIMEOUT 0.5
/* Not referenced by the code visible in this file. */
#define REDIS_STAT_TIMEOUT 30
/* Not referenced here; ports are taken from the upstream addresses. */
#define REDIS_DEFAULT_PORT 6379
/* Redis hash key used when the classifier has no "cache_key" option. */
#define DEFAULT_REDIS_KEY "learned_ids"

/* Module tag: used as the session event name and as the quark string. */
static const gchar *M = "redis learn cache";

/*
 * Per-statfile configuration of the redis learn cache, created by
 * rspamd_stat_cache_redis_init() and released by
 * rspamd_stat_cache_redis_close().
 */
struct rspamd_redis_cache_ctx {
	lua_State *L;                          /* Lua state owning the redis config table */
	struct rspamd_statfile_config *stcf;   /* statfile config (symbol used in logs) */
	const gchar *password;                 /* sent with AUTH when non-NULL */
	const gchar *dbname;                   /* sent with SELECT when non-NULL */
	const gchar *redis_object;             /* name of the redis hash holding learned ids */
	gdouble timeout;                       /* per-request timeout, seconds */
	gint conf_ref;                         /* registry ref to the Lua redis config table */
};

/*
 * Per-task state for one redis connection, allocated from the task pool
 * by rspamd_stat_cache_redis_runtime().
 */
struct rspamd_redis_cache_runtime {
	struct rspamd_redis_cache_ctx *ctx;    /* owning cache configuration */
	struct rspamd_task *task;              /* task being checked or learned */
	struct upstream *selected;             /* upstream the connection was made to */
	ev_timer timer_ev;                     /* request timeout timer */
	redisAsyncContext *redis;              /* async connection; NULL once freed */
	gboolean has_event;                    /* TRUE while a session event is registered */
};
53 
54 static GQuark
rspamd_stat_cache_redis_quark(void)55 rspamd_stat_cache_redis_quark (void)
56 {
57 	return g_quark_from_static_string (M);
58 }
59 
60 static inline struct upstream_list *
rspamd_redis_get_servers(struct rspamd_redis_cache_ctx * ctx,const gchar * what)61 rspamd_redis_get_servers (struct rspamd_redis_cache_ctx *ctx,
62 						  const gchar *what)
63 {
64 	lua_State *L = ctx->L;
65 	struct upstream_list *res;
66 
67 	lua_rawgeti (L, LUA_REGISTRYINDEX, ctx->conf_ref);
68 	lua_pushstring (L, what);
69 	lua_gettable (L, -2);
70 	res = *((struct upstream_list**)lua_touserdata (L, -1));
71 	lua_settop (L, 0);
72 
73 	return res;
74 }
75 
76 static void
rspamd_redis_cache_maybe_auth(struct rspamd_redis_cache_ctx * ctx,redisAsyncContext * redis)77 rspamd_redis_cache_maybe_auth (struct rspamd_redis_cache_ctx *ctx,
78 		redisAsyncContext *redis)
79 {
80 	if (ctx->password) {
81 		redisAsyncCommand (redis, NULL, NULL, "AUTH %s", ctx->password);
82 	}
83 	if (ctx->dbname) {
84 		redisAsyncCommand (redis, NULL, NULL, "SELECT %s", ctx->dbname);
85 	}
86 }
87 
88 /* Called on connection termination */
89 static void
rspamd_redis_cache_fin(gpointer data)90 rspamd_redis_cache_fin (gpointer data)
91 {
92 	struct rspamd_redis_cache_runtime *rt = data;
93 	redisAsyncContext *redis;
94 
95 	rt->has_event = FALSE;
96 	ev_timer_stop (rt->task->event_loop, &rt->timer_ev);
97 
98 	if (rt->redis) {
99 		redis = rt->redis;
100 		rt->redis = NULL;
101 		/* This calls for all callbacks pending */
102 		redisAsyncFree (redis);
103 	}
104 }
105 
106 static void
rspamd_redis_cache_timeout(EV_P_ ev_timer * w,int revents)107 rspamd_redis_cache_timeout (EV_P_ ev_timer *w, int revents)
108 {
109 	struct rspamd_redis_cache_runtime *rt =
110 			(struct rspamd_redis_cache_runtime *)w->data;
111 	struct rspamd_task *task;
112 
113 	task = rt->task;
114 
115 	msg_err_task ("connection to redis server %s timed out",
116 			rspamd_upstream_name (rt->selected));
117 	rspamd_upstream_fail (rt->selected, FALSE, "timeout");
118 
119 	if (rt->has_event) {
120 		rspamd_session_remove_event (task->s, rspamd_redis_cache_fin, rt);
121 	}
122 }
123 
124 /* Called when we have checked the specified message id */
125 static void
rspamd_stat_cache_redis_get(redisAsyncContext * c,gpointer r,gpointer priv)126 rspamd_stat_cache_redis_get (redisAsyncContext *c, gpointer r, gpointer priv)
127 {
128 	struct rspamd_redis_cache_runtime *rt = priv;
129 	redisReply *reply = r;
130 	struct rspamd_task *task;
131 	glong val = 0;
132 
133 	task = rt->task;
134 
135 	if (c->err == 0) {
136 		if (reply) {
137 			if (G_LIKELY (reply->type == REDIS_REPLY_INTEGER)) {
138 				val = reply->integer;
139 			}
140 			else if (reply->type == REDIS_REPLY_STRING) {
141 				rspamd_strtol (reply->str, reply->len, &val);
142 			}
143 			else {
144 				if (reply->type == REDIS_REPLY_ERROR) {
145 					msg_err_task ("cannot learn %s: redis error: \"%s\"",
146 							rt->ctx->stcf->symbol, reply->str);
147 				}
148 				else if (reply->type != REDIS_REPLY_NIL) {
149 					msg_err_task ("bad learned type for %s: %d",
150 							rt->ctx->stcf->symbol, reply->type);
151 				}
152 
153 				val = 0;
154 			}
155 		}
156 
157 		if ((val > 0 && (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM)) ||
158 				(val < 0 && (task->flags & RSPAMD_TASK_FLAG_LEARN_HAM))) {
159 			/* Already learned */
160 			msg_info_task ("<%s> has been already "
161 					"learned as %s, ignore it", MESSAGE_FIELD (task, message_id),
162 					(task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM) ? "spam" : "ham");
163 			task->flags |= RSPAMD_TASK_FLAG_ALREADY_LEARNED;
164 		}
165 		else if (val != 0) {
166 			/* Unlearn flag */
167 			task->flags |= RSPAMD_TASK_FLAG_UNLEARN;
168 		}
169 
170 		rspamd_upstream_ok (rt->selected);
171 	}
172 	else {
173 		rspamd_upstream_fail (rt->selected, FALSE, c->errstr);
174 	}
175 
176 	if (rt->has_event) {
177 		rspamd_session_remove_event (task->s, rspamd_redis_cache_fin, rt);
178 	}
179 }
180 
181 /* Called when we have learned the specified message id */
182 static void
rspamd_stat_cache_redis_set(redisAsyncContext * c,gpointer r,gpointer priv)183 rspamd_stat_cache_redis_set (redisAsyncContext *c, gpointer r, gpointer priv)
184 {
185 	struct rspamd_redis_cache_runtime *rt = priv;
186 	struct rspamd_task *task;
187 
188 	task = rt->task;
189 
190 	if (c->err == 0) {
191 		/* XXX: we ignore results here */
192 		rspamd_upstream_ok (rt->selected);
193 	}
194 	else {
195 		rspamd_upstream_fail (rt->selected, FALSE, c->errstr);
196 	}
197 
198 	if (rt->has_event) {
199 		rspamd_session_remove_event (task->s, rspamd_redis_cache_fin, rt);
200 	}
201 }
202 
203 static void
rspamd_stat_cache_redis_generate_id(struct rspamd_task * task)204 rspamd_stat_cache_redis_generate_id (struct rspamd_task *task)
205 {
206 	rspamd_cryptobox_hash_state_t st;
207 	rspamd_token_t *tok;
208 	guint i;
209 	guchar out[rspamd_cryptobox_HASHBYTES];
210 	gchar *b32out;
211 	gchar *user = NULL;
212 
213 	rspamd_cryptobox_hash_init (&st, NULL, 0);
214 
215 	user = rspamd_mempool_get_variable (task->task_pool, "stat_user");
216 	/* Use dedicated hash space for per users cache */
217 	if (user != NULL) {
218 		rspamd_cryptobox_hash_update (&st, user, strlen (user));
219 	}
220 
221 	for (i = 0; i < task->tokens->len; i ++) {
222 		tok = g_ptr_array_index (task->tokens, i);
223 		rspamd_cryptobox_hash_update (&st, (guchar *)&tok->data,
224 				sizeof (tok->data));
225 	}
226 
227 	rspamd_cryptobox_hash_final (&st, out);
228 
229 	b32out = rspamd_mempool_alloc (task->task_pool,
230 			sizeof (out) * 8 / 5 + 3);
231 	i = rspamd_encode_base32_buf (out, sizeof (out), b32out,
232 			sizeof (out) * 8 / 5 + 2, RSPAMD_BASE32_DEFAULT);
233 
234 	if (i > 0) {
235 		/* Zero terminate */
236 		b32out[i] = '\0';
237 	}
238 
239 	rspamd_mempool_set_variable (task->task_pool, "words_hash", b32out, NULL);
240 }
241 
242 gpointer
rspamd_stat_cache_redis_init(struct rspamd_stat_ctx * ctx,struct rspamd_config * cfg,struct rspamd_statfile * st,const ucl_object_t * cf)243 rspamd_stat_cache_redis_init (struct rspamd_stat_ctx *ctx,
244 		struct rspamd_config *cfg,
245 		struct rspamd_statfile *st,
246 		const ucl_object_t *cf)
247 {
248 	struct rspamd_redis_cache_ctx *cache_ctx;
249 	struct rspamd_statfile_config *stf = st->stcf;
250 	const ucl_object_t *obj;
251 	gboolean ret = FALSE;
252 	lua_State *L = (lua_State *)cfg->lua_state;
253 	gint conf_ref = -1;
254 
255 	cache_ctx = g_malloc0 (sizeof (*cache_ctx));
256 	cache_ctx->timeout = REDIS_DEFAULT_TIMEOUT;
257 	cache_ctx->L = L;
258 
259 	/* First search in backend configuration */
260 	obj = ucl_object_lookup (st->classifier->cfg->opts, "backend");
261 	if (obj != NULL && ucl_object_type (obj) == UCL_OBJECT) {
262 		ret = rspamd_lua_try_load_redis (L, obj, cfg, &conf_ref);
263 	}
264 
265 	/* Now try statfiles config */
266 	if (!ret && stf->opts) {
267 		ret = rspamd_lua_try_load_redis (L, stf->opts, cfg, &conf_ref);
268 	}
269 
270 	/* Now try classifier config */
271 	if (!ret && st->classifier->cfg->opts) {
272 		ret = rspamd_lua_try_load_redis (L, st->classifier->cfg->opts, cfg, &conf_ref);
273 	}
274 
275 	/* Now try global redis settings */
276 	if (!ret) {
277 		obj = ucl_object_lookup (cfg->rcl_obj, "redis");
278 
279 		if (obj) {
280 			const ucl_object_t *specific_obj;
281 
282 			specific_obj = ucl_object_lookup (obj, "statistics");
283 
284 			if (specific_obj) {
285 				ret = rspamd_lua_try_load_redis (L,
286 						specific_obj, cfg, &conf_ref);
287 			}
288 			else {
289 				ret = rspamd_lua_try_load_redis (L,
290 						obj, cfg, &conf_ref);
291 			}
292 		}
293 	}
294 
295 	if (!ret) {
296 		msg_err_config ("cannot init redis cache for %s", stf->symbol);
297 		g_free (cache_ctx);
298 		return NULL;
299 	}
300 
301 	obj = ucl_object_lookup (st->classifier->cfg->opts, "cache_key");
302 
303 	if (obj) {
304 		cache_ctx->redis_object = ucl_object_tostring (obj);
305 	}
306 	else {
307 		cache_ctx->redis_object = DEFAULT_REDIS_KEY;
308 	}
309 
310 	cache_ctx->conf_ref = conf_ref;
311 
312 	/* Check some common table values */
313 	lua_rawgeti (L, LUA_REGISTRYINDEX, conf_ref);
314 
315 	lua_pushstring (L, "timeout");
316 	lua_gettable (L, -2);
317 	if (lua_type (L, -1) == LUA_TNUMBER) {
318 		cache_ctx->timeout = lua_tonumber (L, -1);
319 	}
320 	lua_pop (L, 1);
321 
322 	lua_pushstring (L, "db");
323 	lua_gettable (L, -2);
324 	if (lua_type (L, -1) == LUA_TSTRING) {
325 		cache_ctx->dbname = rspamd_mempool_strdup (cfg->cfg_pool,
326 				lua_tostring (L, -1));
327 	}
328 	lua_pop (L, 1);
329 
330 	lua_pushstring (L, "password");
331 	lua_gettable (L, -2);
332 	if (lua_type (L, -1) == LUA_TSTRING) {
333 		cache_ctx->password = rspamd_mempool_strdup (cfg->cfg_pool,
334 				lua_tostring (L, -1));
335 	}
336 	lua_pop (L, 1);
337 
338 	lua_settop (L, 0);
339 
340 	cache_ctx->stcf = stf;
341 
342 	return (gpointer)cache_ctx;
343 }
344 
345 gpointer
rspamd_stat_cache_redis_runtime(struct rspamd_task * task,gpointer c,gboolean learn)346 rspamd_stat_cache_redis_runtime (struct rspamd_task *task,
347 		gpointer c, gboolean learn)
348 {
349 	struct rspamd_redis_cache_ctx *ctx = c;
350 	struct rspamd_redis_cache_runtime *rt;
351 	struct upstream *up;
352 	struct upstream_list *ups;
353 	rspamd_inet_addr_t *addr;
354 
355 	g_assert (ctx != NULL);
356 
357 	if (task->tokens == NULL || task->tokens->len == 0) {
358 		return NULL;
359 	}
360 
361 	if (learn) {
362 		ups = rspamd_redis_get_servers (ctx, "write_servers");
363 
364 		if (!ups) {
365 			msg_err_task ("no write servers defined for %s, cannot learn",
366 					ctx->stcf->symbol);
367 			return NULL;
368 		}
369 
370 		up = rspamd_upstream_get (ups,
371 				RSPAMD_UPSTREAM_MASTER_SLAVE,
372 				NULL,
373 				0);
374 	}
375 	else {
376 		ups = rspamd_redis_get_servers (ctx, "read_servers");
377 
378 		if (!ups) {
379 			msg_err_task ("no read servers defined for %s, cannot check",
380 					ctx->stcf->symbol);
381 			return NULL;
382 		}
383 
384 		up = rspamd_upstream_get (ups,
385 				RSPAMD_UPSTREAM_ROUND_ROBIN,
386 				NULL,
387 				0);
388 	}
389 
390 	if (up == NULL) {
391 		msg_err_task ("no upstreams reachable");
392 		return NULL;
393 	}
394 
395 	rt = rspamd_mempool_alloc0 (task->task_pool, sizeof (*rt));
396 	rt->selected = up;
397 	rt->task = task;
398 	rt->ctx = ctx;
399 
400 	addr = rspamd_upstream_addr_next (up);
401 	g_assert (addr != NULL);
402 
403 	if (rspamd_inet_address_get_af (addr) == AF_UNIX) {
404 		rt->redis = redisAsyncConnectUnix (rspamd_inet_address_to_string (addr));
405 	}
406 	else {
407 		rt->redis = redisAsyncConnect (rspamd_inet_address_to_string (addr),
408 				rspamd_inet_address_get_port (addr));
409 	}
410 
411 	if (rt->redis == NULL) {
412 		msg_warn_task ("cannot connect to redis server %s: %s",
413 				rspamd_inet_address_to_string_pretty (addr),
414 				strerror (errno));
415 
416 		return NULL;
417 	}
418 	else if (rt->redis->err != REDIS_OK) {
419 		msg_warn_task ("cannot connect to redis server %s: %s",
420 				rspamd_inet_address_to_string_pretty (addr),
421 				rt->redis->errstr);
422 		redisAsyncFree (rt->redis);
423 		rt->redis = NULL;
424 
425 		return NULL;
426 	}
427 
428 	redisLibevAttach (task->event_loop, rt->redis);
429 
430 	/* Now check stats */
431 	rt->timer_ev.data = rt;
432 	ev_timer_init (&rt->timer_ev, rspamd_redis_cache_timeout,
433 			rt->ctx->timeout, 0.0);
434 	rspamd_redis_cache_maybe_auth (ctx, rt->redis);
435 
436 	if (!learn) {
437 		rspamd_stat_cache_redis_generate_id (task);
438 	}
439 
440 	return rt;
441 }
442 
443 gint
rspamd_stat_cache_redis_check(struct rspamd_task * task,gboolean is_spam,gpointer runtime)444 rspamd_stat_cache_redis_check (struct rspamd_task *task,
445 		gboolean is_spam,
446 		gpointer runtime)
447 {
448 	struct rspamd_redis_cache_runtime *rt = runtime;
449 	gchar *h;
450 
451 	if (rspamd_session_blocked (task->s)) {
452 		return RSPAMD_LEARN_INGORE;
453 	}
454 
455 	h = rspamd_mempool_get_variable (task->task_pool, "words_hash");
456 
457 	if (h == NULL) {
458 		return RSPAMD_LEARN_INGORE;
459 	}
460 
461 	if (redisAsyncCommand (rt->redis, rspamd_stat_cache_redis_get, rt,
462 			"HGET %s %s",
463 			rt->ctx->redis_object, h) == REDIS_OK) {
464 		rspamd_session_add_event (task->s,
465 				rspamd_redis_cache_fin,
466 				rt,
467 				M);
468 		ev_timer_start (rt->task->event_loop, &rt->timer_ev);
469 		rt->has_event = TRUE;
470 	}
471 
472 	/* We need to return OK every time */
473 	return RSPAMD_LEARN_OK;
474 }
475 
476 gint
rspamd_stat_cache_redis_learn(struct rspamd_task * task,gboolean is_spam,gpointer runtime)477 rspamd_stat_cache_redis_learn (struct rspamd_task *task,
478 		gboolean is_spam,
479 		gpointer runtime)
480 {
481 	struct rspamd_redis_cache_runtime *rt = runtime;
482 	gchar *h;
483 	gint flag;
484 
485 	if (rt == NULL || rt->ctx == NULL || rspamd_session_blocked (task->s)) {
486 		return RSPAMD_LEARN_INGORE;
487 	}
488 
489 	h = rspamd_mempool_get_variable (task->task_pool, "words_hash");
490 	g_assert (h != NULL);
491 
492 	flag = (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM) ? 1 : -1;
493 
494 	if (redisAsyncCommand (rt->redis, rspamd_stat_cache_redis_set, rt,
495 			"HSET %s %s %d",
496 			rt->ctx->redis_object, h, flag) == REDIS_OK) {
497 		rspamd_session_add_event (task->s,
498 				rspamd_redis_cache_fin, rt, M);
499 		ev_timer_start (rt->task->event_loop, &rt->timer_ev);
500 		rt->has_event = TRUE;
501 	}
502 
503 	/* We need to return OK every time */
504 	return RSPAMD_LEARN_OK;
505 }
506 
507 void
rspamd_stat_cache_redis_close(gpointer c)508 rspamd_stat_cache_redis_close (gpointer c)
509 {
510 	struct rspamd_redis_cache_ctx *ctx = (struct rspamd_redis_cache_ctx *)c;
511 	lua_State *L;
512 
513 	L = ctx->L;
514 
515 	if (ctx->conf_ref) {
516 		luaL_unref (L, LUA_REGISTRYINDEX, ctx->conf_ref);
517 	}
518 
519 	g_free (ctx);
520 }
521