1 /*-
2  * Copyright 2016 Vsevolod Stakhov
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 /*
17  * Bayesian classifier
18  */
19 #include "classifiers.h"
20 #include "rspamd.h"
21 #include "stat_internal.h"
22 #include "math.h"
23 
24 #define msg_err_bayes(...) rspamd_default_log_function (G_LOG_LEVEL_CRITICAL, \
25         "bayes", task->task_pool->tag.uid, \
26         G_STRFUNC, \
27         __VA_ARGS__)
28 #define msg_warn_bayes(...)   rspamd_default_log_function (G_LOG_LEVEL_WARNING, \
29         "bayes", task->task_pool->tag.uid, \
30         G_STRFUNC, \
31         __VA_ARGS__)
32 #define msg_info_bayes(...)   rspamd_default_log_function (G_LOG_LEVEL_INFO, \
33         "bayes", task->task_pool->tag.uid, \
34         G_STRFUNC, \
35         __VA_ARGS__)
36 #define msg_debug_bayes(...)  rspamd_conditional_debug_fast (NULL, task->from_addr, \
37         rspamd_bayes_log_id, "bayes", task->task_pool->tag.uid, \
38         G_STRFUNC, \
39         __VA_ARGS__)
40 
INIT_LOG_MODULE_PUBLIC(bayes)41 INIT_LOG_MODULE_PUBLIC(bayes)
42 
43 static inline GQuark
44 bayes_error_quark (void)
45 {
46 	return g_quark_from_static_string ("bayes-error");
47 }
48 
49 /**
50  * Returns probability of chisquare > value with specified number of freedom
51  * degrees
52  * @param value value to test
53  * @param freedom_deg number of degrees of freedom
54  * @return
55  */
56 static gdouble
inv_chi_square(struct rspamd_task * task,gdouble value,gint freedom_deg)57 inv_chi_square (struct rspamd_task *task, gdouble value, gint freedom_deg)
58 {
59 	double prob, sum, m;
60 	gint i;
61 
62 	errno = 0;
63 	m = -value;
64 	prob = exp (value);
65 
66 	if (errno == ERANGE) {
67 		/*
68 		 * e^x where x is large *NEGATIVE* number is OK, so we have a very strong
69 		 * confidence that inv-chi-square is close to zero
70 		 */
71 		msg_debug_bayes ("exp overflow");
72 
73 		if (value < 0) {
74 			return 0;
75 		}
76 		else {
77 			return 1.0;
78 		}
79 	}
80 
81 	sum = prob;
82 
83 	msg_debug_bayes ("m: %f, probability: %g", m, prob);
84 
85 	/*
86 	 * m is our confidence in class
87 	 * prob is e ^ x (small value since x is normally less than zero
88 	 * So we integrate over degrees of freedom and produce the total result
89 	 * from 1.0 (no confidence) to 0.0 (full confidence)
90 	 */
91 	for (i = 1; i < freedom_deg; i++) {
92 		prob *= m / (gdouble)i;
93 		sum += prob;
94 		msg_debug_bayes ("i=%d, probability: %g, sum: %g", i, prob, sum);
95 	}
96 
97 	return MIN (1.0, sum);
98 }
99 
100 struct bayes_task_closure {
101 	double ham_prob;
102 	double spam_prob;
103 	gdouble meta_skip_prob;
104 	guint64 processed_tokens;
105 	guint64 total_hits;
106 	guint64 text_tokens;
107 	struct rspamd_task *task;
108 };
109 
/*
 * Weights of non-unigram tokens, looked up by the token window index taken
 * modulo the number of entries.
 * Mathematically we use pow(complexity, complexity), where complexity is the
 * window index
 */
static const double feature_weight[] = {
	0,
	3125, /* 5^5 */
	256,  /* 4^4 */
	27,   /* 3^3 */
	1,
	0,
	0,
	0
};

/*
 * Weighted mean of an observed probability `prob` taken with weight `cnt`
 * and an assumed probability `assumed` taken with weight `weight`
 */
#define PROB_COMBINE(prob, cnt, weight, assumed) \
	(((weight) * (assumed) + (cnt) * (prob)) / ((weight) + (cnt)))
117 /*
118  * In this callback we calculate local probabilities for tokens
119  */
120 static void
bayes_classify_token(struct rspamd_classifier * ctx,rspamd_token_t * tok,struct bayes_task_closure * cl)121 bayes_classify_token (struct rspamd_classifier *ctx,
122 		rspamd_token_t *tok, struct bayes_task_closure *cl)
123 {
124 	guint i;
125 	gint id;
126 	guint spam_count = 0, ham_count = 0, total_count = 0;
127 	struct rspamd_statfile *st;
128 	struct rspamd_task *task;
129 	const gchar *token_type = "txt";
130 	double spam_prob, spam_freq, ham_freq, bayes_spam_prob, bayes_ham_prob,
131 		ham_prob, fw, w, val;
132 
133 	task = cl->task;
134 
135 #if 0
136 	if (tok->flags & RSPAMD_STAT_TOKEN_FLAG_LUA_META) {
137 		/* Ignore lua metatokens for now */
138 		return;
139 	}
140 #endif
141 
142 	if (tok->flags & RSPAMD_STAT_TOKEN_FLAG_META && cl->meta_skip_prob > 0) {
143 		val = rspamd_random_double_fast ();
144 
145 		if (val <= cl->meta_skip_prob) {
146 			if (tok->t1 && tok->t2) {
147 				msg_debug_bayes (
148 						"token(meta) %uL <%*s:%*s> probabilistically skipped",
149 						tok->data,
150 						(int) tok->t1->original.len, tok->t1->original.begin,
151 						(int) tok->t2->original.len, tok->t2->original.begin);
152 			}
153 
154 			return;
155 		}
156 	}
157 
158 	for (i = 0; i < ctx->statfiles_ids->len; i++) {
159 		id = g_array_index (ctx->statfiles_ids, gint, i);
160 		st = g_ptr_array_index (ctx->ctx->statfiles, id);
161 		g_assert (st != NULL);
162 		val = tok->values[id];
163 
164 		if (val > 0) {
165 			if (st->stcf->is_spam) {
166 				spam_count += val;
167 			}
168 			else {
169 				ham_count += val;
170 			}
171 
172 			total_count += val;
173 			cl->total_hits += val;
174 		}
175 	}
176 
177 	/* Probability for this token */
178 	if (total_count >= ctx->cfg->min_token_hits) {
179 		spam_freq = ((double)spam_count / MAX (1., (double) ctx->spam_learns));
180 		ham_freq = ((double)ham_count / MAX (1., (double)ctx->ham_learns));
181 		spam_prob = spam_freq / (spam_freq + ham_freq);
182 		ham_prob = ham_freq / (spam_freq + ham_freq);
183 
184 		if (tok->flags & RSPAMD_STAT_TOKEN_FLAG_UNIGRAM) {
185 			fw = 1.0;
186 		}
187 		else {
188 			fw = feature_weight[tok->window_idx %
189 					G_N_ELEMENTS (feature_weight)];
190 		}
191 
192 
193 		w = (fw * total_count) / (1.0 + fw * total_count);
194 
195 		bayes_spam_prob = PROB_COMBINE (spam_prob, total_count, w, 0.5);
196 
197 		if ((bayes_spam_prob > 0.5 && bayes_spam_prob < 0.5 + ctx->cfg->min_prob_strength) ||
198 			(bayes_spam_prob < 0.5 && bayes_spam_prob > 0.5 - ctx->cfg->min_prob_strength)) {
199 			msg_debug_bayes (
200 					"token %uL <%*s:%*s> skipped, probability not in range: %f",
201 					tok->data,
202 					(int) tok->t1->stemmed.len, tok->t1->stemmed.begin,
203 					(int) tok->t2->stemmed.len, tok->t2->stemmed.begin,
204 					bayes_spam_prob);
205 
206 			return;
207 		}
208 
209 		bayes_ham_prob = PROB_COMBINE (ham_prob, total_count, w, 0.5);
210 
211 		cl->spam_prob += log (bayes_spam_prob);
212 		cl->ham_prob += log (bayes_ham_prob);
213 		cl->processed_tokens ++;
214 
215 		if (!(tok->flags & RSPAMD_STAT_TOKEN_FLAG_META)) {
216 			cl->text_tokens ++;
217 		}
218 		else {
219 			token_type = "meta";
220 		}
221 
222 		if (tok->t1 && tok->t2) {
223 			msg_debug_bayes ("token(%s) %uL <%*s:%*s>: weight: %f, cf: %f, "
224 					"total_count: %ud, "
225 					"spam_count: %ud, ham_count: %ud,"
226 					"spam_prob: %.3f, ham_prob: %.3f, "
227 					"bayes_spam_prob: %.3f, bayes_ham_prob: %.3f, "
228 					"current spam probability: %.3f, current ham probability: %.3f",
229 					token_type,
230 					tok->data,
231 					(int) tok->t1->stemmed.len, tok->t1->stemmed.begin,
232 					(int) tok->t2->stemmed.len, tok->t2->stemmed.begin,
233 					fw, w, total_count, spam_count, ham_count,
234 					spam_prob, ham_prob,
235 					bayes_spam_prob, bayes_ham_prob,
236 					cl->spam_prob, cl->ham_prob);
237 		}
238 		else {
239 			msg_debug_bayes ("token(%s) %uL <?:?>: weight: %f, cf: %f, "
240 					"total_count: %ud, "
241 					"spam_count: %ud, ham_count: %ud,"
242 					"spam_prob: %.3f, ham_prob: %.3f, "
243 					"bayes_spam_prob: %.3f, bayes_ham_prob: %.3f, "
244 					"current spam probability: %.3f, current ham probability: %.3f",
245 					token_type,
246 					tok->data,
247 					fw, w, total_count, spam_count, ham_count,
248 					spam_prob, ham_prob,
249 					bayes_spam_prob, bayes_ham_prob,
250 					cl->spam_prob, cl->ham_prob);
251 		}
252 	}
253 }
254 
255 
256 
257 gboolean
bayes_init(struct rspamd_config * cfg,struct ev_loop * ev_base,struct rspamd_classifier * cl)258 bayes_init (struct rspamd_config *cfg,
259 			struct ev_loop *ev_base,
260 			struct rspamd_classifier *cl)
261 {
262 	cl->cfg->flags |= RSPAMD_FLAG_CLASSIFIER_INTEGER;
263 
264 	return TRUE;
265 }
266 
/**
 * Finalises the bayes classifier
 * @param cl classifier being finalised, not used by this function
 */
void
bayes_fin (struct rspamd_classifier *cl)
{
	/* Nothing to release: bayes_init only sets a configuration flag */
}
271 
272 gboolean
bayes_classify(struct rspamd_classifier * ctx,GPtrArray * tokens,struct rspamd_task * task)273 bayes_classify (struct rspamd_classifier * ctx,
274 		GPtrArray *tokens,
275 		struct rspamd_task *task)
276 {
277 	double final_prob, h, s, *pprob;
278 	gchar sumbuf[32];
279 	struct rspamd_statfile *st = NULL;
280 	struct bayes_task_closure cl;
281 	rspamd_token_t *tok;
282 	guint i, text_tokens = 0;
283 	gint id;
284 
285 	g_assert (ctx != NULL);
286 	g_assert (tokens != NULL);
287 
288 	memset (&cl, 0, sizeof (cl));
289 	cl.task = task;
290 
291 	/* Check min learns */
292 	if (ctx->cfg->min_learns > 0) {
293 		if (ctx->ham_learns < ctx->cfg->min_learns) {
294 			msg_info_task ("not classified as ham. The ham class needs more "
295 					"training samples. Currently: %ul; minimum %ud required",
296 					ctx->ham_learns, ctx->cfg->min_learns);
297 
298 			return TRUE;
299 		}
300 		if (ctx->spam_learns < ctx->cfg->min_learns) {
301 			msg_info_task ("not classified as spam. The spam class needs more "
302 					"training samples. Currently: %ul; minimum %ud required",
303 					ctx->spam_learns, ctx->cfg->min_learns);
304 
305 			return TRUE;
306 		}
307 	}
308 
309 	for (i = 0; i < tokens->len; i ++) {
310 		tok = g_ptr_array_index (tokens, i);
311 		if (!(tok->flags & RSPAMD_STAT_TOKEN_FLAG_META)) {
312 			text_tokens ++;
313 		}
314 	}
315 
316 	if (text_tokens == 0) {
317 		msg_info_task ("skipped classification as there are no text tokens. "
318 				"Total tokens: %ud",
319 				tokens->len);
320 
321 		return TRUE;
322 	}
323 
324 	/*
325 	 * Skip some metatokens if we don't have enough text tokens
326 	 */
327 	if (text_tokens > tokens->len - text_tokens) {
328 		cl.meta_skip_prob = 0.0;
329 	}
330 	else {
331 		cl.meta_skip_prob = 1.0 - text_tokens / tokens->len;
332 	}
333 
334 	for (i = 0; i < tokens->len; i ++) {
335 		tok = g_ptr_array_index (tokens, i);
336 
337 		bayes_classify_token (ctx, tok, &cl);
338 	}
339 
340 	if (cl.processed_tokens == 0) {
341 		msg_info_bayes ("no tokens found in bayes database "
342 				  "(%ud total tokens, %ud text tokens), ignore stats",
343 				tokens->len, text_tokens);
344 
345 		return TRUE;
346 	}
347 
348 	if (ctx->cfg->min_tokens > 0 &&
349 		cl.text_tokens < (gint)(ctx->cfg->min_tokens * 0.1)) {
350 		msg_info_bayes ("ignore bayes probability since we have "
351 						"found too few text tokens: %uL (of %ud checked), "
352 						"at least %d required",
353 						cl.text_tokens,
354 						text_tokens,
355 						(gint)(ctx->cfg->min_tokens * 0.1));
356 
357 		return TRUE;
358 	}
359 
360 	if (cl.spam_prob > -300 && cl.ham_prob > -300) {
361 		/* Fisher value is low enough to apply inv_chi_square */
362 		h = 1 - inv_chi_square (task, cl.spam_prob, cl.processed_tokens);
363 		s = 1 - inv_chi_square (task, cl.ham_prob, cl.processed_tokens);
364 	}
365 	else {
366 		/* Use naive method */
367 		if (cl.spam_prob < cl.ham_prob) {
368 			h = (1.0 - exp(cl.spam_prob - cl.ham_prob)) /
369 					(1.0 + exp(cl.spam_prob - cl.ham_prob));
370 			s = 1.0 - h;
371 		}
372 		else {
373 			s = (1.0 - exp(cl.ham_prob - cl.spam_prob)) /
374 				(1.0 + exp(cl.ham_prob - cl.spam_prob));
375 			h = 1.0 - s;
376 		}
377 	}
378 
379 	if (isfinite (s) && isfinite (h)) {
380 		final_prob = (s + 1.0 - h) / 2.;
381 		msg_debug_bayes (
382 				"got ham probability %.2f -> %.2f and spam probability %.2f -> %.2f,"
383 				" %L tokens processed of %ud total tokens;"
384 				" %uL text tokens found of %ud text tokens)",
385 				cl.ham_prob,
386 				h,
387 				cl.spam_prob,
388 				s,
389 				cl.processed_tokens,
390 				tokens->len,
391 				cl.text_tokens,
392 				text_tokens);
393 	}
394 	else {
395 		/*
396 		 * We have some overflow, hence we need to check which class
397 		 * is NaN
398 		 */
399 		if (isfinite (h)) {
400 			final_prob = 1.0;
401 			msg_debug_bayes ("spam class is full: no"
402 					" ham samples");
403 		}
404 		else if (isfinite (s)) {
405 			final_prob = 0.0;
406 			msg_debug_bayes ("ham class is full: no"
407 					" spam samples");
408 		}
409 		else {
410 			final_prob = 0.5;
411 			msg_warn_bayes ("spam and ham classes are both full");
412 		}
413 	}
414 
415 	pprob = rspamd_mempool_alloc (task->task_pool, sizeof (*pprob));
416 	*pprob = final_prob;
417 	rspamd_mempool_set_variable (task->task_pool, "bayes_prob", pprob, NULL);
418 
419 	if (cl.processed_tokens > 0 && fabs (final_prob - 0.5) > 0.05) {
420 		/* Now we can have exactly one HAM and exactly one SPAM statfiles per classifier */
421 		for (i = 0; i < ctx->statfiles_ids->len; i++) {
422 			id = g_array_index (ctx->statfiles_ids, gint, i);
423 			st = g_ptr_array_index (ctx->ctx->statfiles, id);
424 
425 			if (final_prob > 0.5 && st->stcf->is_spam) {
426 				break;
427 			}
428 			else if (final_prob < 0.5 && !st->stcf->is_spam) {
429 				break;
430 			}
431 		}
432 
433 		/* Correctly scale HAM */
434 		if (final_prob < 0.5) {
435 			final_prob = 1.0 - final_prob;
436 		}
437 
438 		/*
439 		 * Bayes p is from 0.5 to 1.0, but confidence is from 0 to 1, so
440 		 * we need to rescale it to display correctly
441 		 */
442 		rspamd_snprintf (sumbuf, sizeof (sumbuf), "%.2f%%",
443 				(final_prob - 0.5) * 200.);
444 		final_prob = rspamd_normalize_probability (final_prob, 0.5);
445 		g_assert (st != NULL);
446 
447 		if (final_prob > 1 || final_prob < 0) {
448 			msg_err_bayes ("internal error: probability %f is outside of the "
449 				  "allowed range [0..1]", final_prob);
450 
451 			if (final_prob > 1) {
452 				final_prob = 1.0;
453 			}
454 			else {
455 				final_prob = 0.0;
456 			}
457 		}
458 
459 		rspamd_task_insert_result (task,
460 				st->stcf->symbol,
461 				final_prob,
462 				sumbuf);
463 	}
464 
465 	return TRUE;
466 }
467 
468 gboolean
bayes_learn_spam(struct rspamd_classifier * ctx,GPtrArray * tokens,struct rspamd_task * task,gboolean is_spam,gboolean unlearn,GError ** err)469 bayes_learn_spam (struct rspamd_classifier * ctx,
470 		GPtrArray *tokens,
471 		struct rspamd_task *task,
472 		gboolean is_spam,
473 		gboolean unlearn,
474 		GError **err)
475 {
476 	guint i, j, total_cnt, spam_cnt, ham_cnt;
477 	gint id;
478 	struct rspamd_statfile *st;
479 	rspamd_token_t *tok;
480 	gboolean incrementing;
481 
482 	g_assert (ctx != NULL);
483 	g_assert (tokens != NULL);
484 
485 	incrementing = ctx->cfg->flags & RSPAMD_FLAG_CLASSIFIER_INCREMENTING_BACKEND;
486 
487 	for (i = 0; i < tokens->len; i++) {
488 		total_cnt = 0;
489 		spam_cnt = 0;
490 		ham_cnt = 0;
491 		tok = g_ptr_array_index (tokens, i);
492 
493 		for (j = 0; j < ctx->statfiles_ids->len; j++) {
494 			id = g_array_index (ctx->statfiles_ids, gint, j);
495 			st = g_ptr_array_index (ctx->ctx->statfiles, id);
496 			g_assert (st != NULL);
497 
498 			if (!!st->stcf->is_spam == !!is_spam) {
499 				if (incrementing) {
500 					tok->values[id] = 1;
501 				}
502 				else {
503 					tok->values[id]++;
504 				}
505 
506 				total_cnt += tok->values[id];
507 
508 				if (st->stcf->is_spam) {
509 					spam_cnt += tok->values[id];
510 				}
511 				else {
512 					ham_cnt += tok->values[id];
513 				}
514 			}
515 			else {
516 				if (tok->values[id] > 0 && unlearn) {
517 					/* Unlearning */
518 					if (incrementing) {
519 						tok->values[id] = -1;
520 					}
521 					else {
522 						tok->values[id]--;
523 					}
524 
525 					if (st->stcf->is_spam) {
526 						spam_cnt += tok->values[id];
527 					}
528 					else {
529 						ham_cnt += tok->values[id];
530 					}
531 					total_cnt += tok->values[id];
532 				}
533 				else if (incrementing) {
534 					tok->values[id] = 0;
535 				}
536 			}
537 		}
538 
539 		if (tok->t1 && tok->t2) {
540 			msg_debug_bayes ("token %uL <%*s:%*s>: window: %d, total_count: %d, "
541 					"spam_count: %d, ham_count: %d",
542 					tok->data,
543 					(int) tok->t1->stemmed.len, tok->t1->stemmed.begin,
544 					(int) tok->t2->stemmed.len, tok->t2->stemmed.begin,
545 					tok->window_idx, total_cnt, spam_cnt, ham_cnt);
546 		}
547 		else {
548 			msg_debug_bayes ("token %uL <?:?>: window: %d, total_count: %d, "
549 					"spam_count: %d, ham_count: %d",
550 					tok->data,
551 					tok->window_idx, total_cnt, spam_cnt, ham_cnt);
552 		}
553 	}
554 
555 	return TRUE;
556 }
557