1 /*-
2 * Copyright 2016 Vsevolod Stakhov
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 /*
17 * Bayesian classifier
18 */
19 #include "classifiers.h"
20 #include "rspamd.h"
21 #include "stat_internal.h"
22 #include "math.h"
23
24 #define msg_err_bayes(...) rspamd_default_log_function (G_LOG_LEVEL_CRITICAL, \
25 "bayes", task->task_pool->tag.uid, \
26 G_STRFUNC, \
27 __VA_ARGS__)
28 #define msg_warn_bayes(...) rspamd_default_log_function (G_LOG_LEVEL_WARNING, \
29 "bayes", task->task_pool->tag.uid, \
30 G_STRFUNC, \
31 __VA_ARGS__)
32 #define msg_info_bayes(...) rspamd_default_log_function (G_LOG_LEVEL_INFO, \
33 "bayes", task->task_pool->tag.uid, \
34 G_STRFUNC, \
35 __VA_ARGS__)
36 #define msg_debug_bayes(...) rspamd_conditional_debug_fast (NULL, task->from_addr, \
37 rspamd_bayes_log_id, "bayes", task->task_pool->tag.uid, \
38 G_STRFUNC, \
39 __VA_ARGS__)
40
INIT_LOG_MODULE_PUBLIC(bayes)41 INIT_LOG_MODULE_PUBLIC(bayes)
42
43 static inline GQuark
44 bayes_error_quark (void)
45 {
46 return g_quark_from_static_string ("bayes-error");
47 }
48
49 /**
50 * Returns probability of chisquare > value with specified number of freedom
51 * degrees
52 * @param value value to test
53 * @param freedom_deg number of degrees of freedom
54 * @return
55 */
56 static gdouble
inv_chi_square(struct rspamd_task * task,gdouble value,gint freedom_deg)57 inv_chi_square (struct rspamd_task *task, gdouble value, gint freedom_deg)
58 {
59 double prob, sum, m;
60 gint i;
61
62 errno = 0;
63 m = -value;
64 prob = exp (value);
65
66 if (errno == ERANGE) {
67 /*
68 * e^x where x is large *NEGATIVE* number is OK, so we have a very strong
69 * confidence that inv-chi-square is close to zero
70 */
71 msg_debug_bayes ("exp overflow");
72
73 if (value < 0) {
74 return 0;
75 }
76 else {
77 return 1.0;
78 }
79 }
80
81 sum = prob;
82
83 msg_debug_bayes ("m: %f, probability: %g", m, prob);
84
85 /*
86 * m is our confidence in class
87 * prob is e ^ x (small value since x is normally less than zero
88 * So we integrate over degrees of freedom and produce the total result
89 * from 1.0 (no confidence) to 0.0 (full confidence)
90 */
91 for (i = 1; i < freedom_deg; i++) {
92 prob *= m / (gdouble)i;
93 sum += prob;
94 msg_debug_bayes ("i=%d, probability: %g, sum: %g", i, prob, sum);
95 }
96
97 return MIN (1.0, sum);
98 }
99
100 struct bayes_task_closure {
101 double ham_prob;
102 double spam_prob;
103 gdouble meta_skip_prob;
104 guint64 processed_tokens;
105 guint64 total_hits;
106 guint64 text_tokens;
107 struct rspamd_task *task;
108 };
109
/*
 * Per-window feature weights, looked up by the token window index modulo the
 * table size for non-unigram tokens.
 *
 * Entries 1..3 follow pow (n, n) with n = 6 - index (5^5, 4^4, 3^3),
 * entry 4 is 1 and all the other entries are 0.
 */
static const double feature_weight[] = {
	0,
	3125,
	256,
	27,
	1,
	0,
	0,
	0
};
115
/*
 * Weighted average of an observed probability `prob` (backed by `cnt`
 * observations) and an assumed prior `assumed` (carrying weight `weight`):
 *   (weight * assumed + cnt * prob) / (weight + cnt)
 * The result is undefined (0/0) when both `weight` and `cnt` are zero.
 */
#define PROB_COMBINE(prob, cnt, weight, assumed) \
	(((weight) * (assumed) + (cnt) * (prob)) / \
	 ((weight) + (cnt)))
117 /*
118 * In this callback we calculate local probabilities for tokens
119 */
120 static void
bayes_classify_token(struct rspamd_classifier * ctx,rspamd_token_t * tok,struct bayes_task_closure * cl)121 bayes_classify_token (struct rspamd_classifier *ctx,
122 rspamd_token_t *tok, struct bayes_task_closure *cl)
123 {
124 guint i;
125 gint id;
126 guint spam_count = 0, ham_count = 0, total_count = 0;
127 struct rspamd_statfile *st;
128 struct rspamd_task *task;
129 const gchar *token_type = "txt";
130 double spam_prob, spam_freq, ham_freq, bayes_spam_prob, bayes_ham_prob,
131 ham_prob, fw, w, val;
132
133 task = cl->task;
134
135 #if 0
136 if (tok->flags & RSPAMD_STAT_TOKEN_FLAG_LUA_META) {
137 /* Ignore lua metatokens for now */
138 return;
139 }
140 #endif
141
142 if (tok->flags & RSPAMD_STAT_TOKEN_FLAG_META && cl->meta_skip_prob > 0) {
143 val = rspamd_random_double_fast ();
144
145 if (val <= cl->meta_skip_prob) {
146 if (tok->t1 && tok->t2) {
147 msg_debug_bayes (
148 "token(meta) %uL <%*s:%*s> probabilistically skipped",
149 tok->data,
150 (int) tok->t1->original.len, tok->t1->original.begin,
151 (int) tok->t2->original.len, tok->t2->original.begin);
152 }
153
154 return;
155 }
156 }
157
158 for (i = 0; i < ctx->statfiles_ids->len; i++) {
159 id = g_array_index (ctx->statfiles_ids, gint, i);
160 st = g_ptr_array_index (ctx->ctx->statfiles, id);
161 g_assert (st != NULL);
162 val = tok->values[id];
163
164 if (val > 0) {
165 if (st->stcf->is_spam) {
166 spam_count += val;
167 }
168 else {
169 ham_count += val;
170 }
171
172 total_count += val;
173 cl->total_hits += val;
174 }
175 }
176
177 /* Probability for this token */
178 if (total_count >= ctx->cfg->min_token_hits) {
179 spam_freq = ((double)spam_count / MAX (1., (double) ctx->spam_learns));
180 ham_freq = ((double)ham_count / MAX (1., (double)ctx->ham_learns));
181 spam_prob = spam_freq / (spam_freq + ham_freq);
182 ham_prob = ham_freq / (spam_freq + ham_freq);
183
184 if (tok->flags & RSPAMD_STAT_TOKEN_FLAG_UNIGRAM) {
185 fw = 1.0;
186 }
187 else {
188 fw = feature_weight[tok->window_idx %
189 G_N_ELEMENTS (feature_weight)];
190 }
191
192
193 w = (fw * total_count) / (1.0 + fw * total_count);
194
195 bayes_spam_prob = PROB_COMBINE (spam_prob, total_count, w, 0.5);
196
197 if ((bayes_spam_prob > 0.5 && bayes_spam_prob < 0.5 + ctx->cfg->min_prob_strength) ||
198 (bayes_spam_prob < 0.5 && bayes_spam_prob > 0.5 - ctx->cfg->min_prob_strength)) {
199 msg_debug_bayes (
200 "token %uL <%*s:%*s> skipped, probability not in range: %f",
201 tok->data,
202 (int) tok->t1->stemmed.len, tok->t1->stemmed.begin,
203 (int) tok->t2->stemmed.len, tok->t2->stemmed.begin,
204 bayes_spam_prob);
205
206 return;
207 }
208
209 bayes_ham_prob = PROB_COMBINE (ham_prob, total_count, w, 0.5);
210
211 cl->spam_prob += log (bayes_spam_prob);
212 cl->ham_prob += log (bayes_ham_prob);
213 cl->processed_tokens ++;
214
215 if (!(tok->flags & RSPAMD_STAT_TOKEN_FLAG_META)) {
216 cl->text_tokens ++;
217 }
218 else {
219 token_type = "meta";
220 }
221
222 if (tok->t1 && tok->t2) {
223 msg_debug_bayes ("token(%s) %uL <%*s:%*s>: weight: %f, cf: %f, "
224 "total_count: %ud, "
225 "spam_count: %ud, ham_count: %ud,"
226 "spam_prob: %.3f, ham_prob: %.3f, "
227 "bayes_spam_prob: %.3f, bayes_ham_prob: %.3f, "
228 "current spam probability: %.3f, current ham probability: %.3f",
229 token_type,
230 tok->data,
231 (int) tok->t1->stemmed.len, tok->t1->stemmed.begin,
232 (int) tok->t2->stemmed.len, tok->t2->stemmed.begin,
233 fw, w, total_count, spam_count, ham_count,
234 spam_prob, ham_prob,
235 bayes_spam_prob, bayes_ham_prob,
236 cl->spam_prob, cl->ham_prob);
237 }
238 else {
239 msg_debug_bayes ("token(%s) %uL <?:?>: weight: %f, cf: %f, "
240 "total_count: %ud, "
241 "spam_count: %ud, ham_count: %ud,"
242 "spam_prob: %.3f, ham_prob: %.3f, "
243 "bayes_spam_prob: %.3f, bayes_ham_prob: %.3f, "
244 "current spam probability: %.3f, current ham probability: %.3f",
245 token_type,
246 tok->data,
247 fw, w, total_count, spam_count, ham_count,
248 spam_prob, ham_prob,
249 bayes_spam_prob, bayes_ham_prob,
250 cl->spam_prob, cl->ham_prob);
251 }
252 }
253 }
254
255
256
257 gboolean
bayes_init(struct rspamd_config * cfg,struct ev_loop * ev_base,struct rspamd_classifier * cl)258 bayes_init (struct rspamd_config *cfg,
259 struct ev_loop *ev_base,
260 struct rspamd_classifier *cl)
261 {
262 cl->cfg->flags |= RSPAMD_FLAG_CLASSIFIER_INTEGER;
263
264 return TRUE;
265 }
266
/**
 * Finalises the Bayes classifier. There is nothing to release, so this is
 * intentionally a no-op.
 * @param cl not used
 */
void
bayes_fin (struct rspamd_classifier *cl)
{
	(void) cl;
}
271
272 gboolean
bayes_classify(struct rspamd_classifier * ctx,GPtrArray * tokens,struct rspamd_task * task)273 bayes_classify (struct rspamd_classifier * ctx,
274 GPtrArray *tokens,
275 struct rspamd_task *task)
276 {
277 double final_prob, h, s, *pprob;
278 gchar sumbuf[32];
279 struct rspamd_statfile *st = NULL;
280 struct bayes_task_closure cl;
281 rspamd_token_t *tok;
282 guint i, text_tokens = 0;
283 gint id;
284
285 g_assert (ctx != NULL);
286 g_assert (tokens != NULL);
287
288 memset (&cl, 0, sizeof (cl));
289 cl.task = task;
290
291 /* Check min learns */
292 if (ctx->cfg->min_learns > 0) {
293 if (ctx->ham_learns < ctx->cfg->min_learns) {
294 msg_info_task ("not classified as ham. The ham class needs more "
295 "training samples. Currently: %ul; minimum %ud required",
296 ctx->ham_learns, ctx->cfg->min_learns);
297
298 return TRUE;
299 }
300 if (ctx->spam_learns < ctx->cfg->min_learns) {
301 msg_info_task ("not classified as spam. The spam class needs more "
302 "training samples. Currently: %ul; minimum %ud required",
303 ctx->spam_learns, ctx->cfg->min_learns);
304
305 return TRUE;
306 }
307 }
308
309 for (i = 0; i < tokens->len; i ++) {
310 tok = g_ptr_array_index (tokens, i);
311 if (!(tok->flags & RSPAMD_STAT_TOKEN_FLAG_META)) {
312 text_tokens ++;
313 }
314 }
315
316 if (text_tokens == 0) {
317 msg_info_task ("skipped classification as there are no text tokens. "
318 "Total tokens: %ud",
319 tokens->len);
320
321 return TRUE;
322 }
323
324 /*
325 * Skip some metatokens if we don't have enough text tokens
326 */
327 if (text_tokens > tokens->len - text_tokens) {
328 cl.meta_skip_prob = 0.0;
329 }
330 else {
331 cl.meta_skip_prob = 1.0 - text_tokens / tokens->len;
332 }
333
334 for (i = 0; i < tokens->len; i ++) {
335 tok = g_ptr_array_index (tokens, i);
336
337 bayes_classify_token (ctx, tok, &cl);
338 }
339
340 if (cl.processed_tokens == 0) {
341 msg_info_bayes ("no tokens found in bayes database "
342 "(%ud total tokens, %ud text tokens), ignore stats",
343 tokens->len, text_tokens);
344
345 return TRUE;
346 }
347
348 if (ctx->cfg->min_tokens > 0 &&
349 cl.text_tokens < (gint)(ctx->cfg->min_tokens * 0.1)) {
350 msg_info_bayes ("ignore bayes probability since we have "
351 "found too few text tokens: %uL (of %ud checked), "
352 "at least %d required",
353 cl.text_tokens,
354 text_tokens,
355 (gint)(ctx->cfg->min_tokens * 0.1));
356
357 return TRUE;
358 }
359
360 if (cl.spam_prob > -300 && cl.ham_prob > -300) {
361 /* Fisher value is low enough to apply inv_chi_square */
362 h = 1 - inv_chi_square (task, cl.spam_prob, cl.processed_tokens);
363 s = 1 - inv_chi_square (task, cl.ham_prob, cl.processed_tokens);
364 }
365 else {
366 /* Use naive method */
367 if (cl.spam_prob < cl.ham_prob) {
368 h = (1.0 - exp(cl.spam_prob - cl.ham_prob)) /
369 (1.0 + exp(cl.spam_prob - cl.ham_prob));
370 s = 1.0 - h;
371 }
372 else {
373 s = (1.0 - exp(cl.ham_prob - cl.spam_prob)) /
374 (1.0 + exp(cl.ham_prob - cl.spam_prob));
375 h = 1.0 - s;
376 }
377 }
378
379 if (isfinite (s) && isfinite (h)) {
380 final_prob = (s + 1.0 - h) / 2.;
381 msg_debug_bayes (
382 "got ham probability %.2f -> %.2f and spam probability %.2f -> %.2f,"
383 " %L tokens processed of %ud total tokens;"
384 " %uL text tokens found of %ud text tokens)",
385 cl.ham_prob,
386 h,
387 cl.spam_prob,
388 s,
389 cl.processed_tokens,
390 tokens->len,
391 cl.text_tokens,
392 text_tokens);
393 }
394 else {
395 /*
396 * We have some overflow, hence we need to check which class
397 * is NaN
398 */
399 if (isfinite (h)) {
400 final_prob = 1.0;
401 msg_debug_bayes ("spam class is full: no"
402 " ham samples");
403 }
404 else if (isfinite (s)) {
405 final_prob = 0.0;
406 msg_debug_bayes ("ham class is full: no"
407 " spam samples");
408 }
409 else {
410 final_prob = 0.5;
411 msg_warn_bayes ("spam and ham classes are both full");
412 }
413 }
414
415 pprob = rspamd_mempool_alloc (task->task_pool, sizeof (*pprob));
416 *pprob = final_prob;
417 rspamd_mempool_set_variable (task->task_pool, "bayes_prob", pprob, NULL);
418
419 if (cl.processed_tokens > 0 && fabs (final_prob - 0.5) > 0.05) {
420 /* Now we can have exactly one HAM and exactly one SPAM statfiles per classifier */
421 for (i = 0; i < ctx->statfiles_ids->len; i++) {
422 id = g_array_index (ctx->statfiles_ids, gint, i);
423 st = g_ptr_array_index (ctx->ctx->statfiles, id);
424
425 if (final_prob > 0.5 && st->stcf->is_spam) {
426 break;
427 }
428 else if (final_prob < 0.5 && !st->stcf->is_spam) {
429 break;
430 }
431 }
432
433 /* Correctly scale HAM */
434 if (final_prob < 0.5) {
435 final_prob = 1.0 - final_prob;
436 }
437
438 /*
439 * Bayes p is from 0.5 to 1.0, but confidence is from 0 to 1, so
440 * we need to rescale it to display correctly
441 */
442 rspamd_snprintf (sumbuf, sizeof (sumbuf), "%.2f%%",
443 (final_prob - 0.5) * 200.);
444 final_prob = rspamd_normalize_probability (final_prob, 0.5);
445 g_assert (st != NULL);
446
447 if (final_prob > 1 || final_prob < 0) {
448 msg_err_bayes ("internal error: probability %f is outside of the "
449 "allowed range [0..1]", final_prob);
450
451 if (final_prob > 1) {
452 final_prob = 1.0;
453 }
454 else {
455 final_prob = 0.0;
456 }
457 }
458
459 rspamd_task_insert_result (task,
460 st->stcf->symbol,
461 final_prob,
462 sumbuf);
463 }
464
465 return TRUE;
466 }
467
468 gboolean
bayes_learn_spam(struct rspamd_classifier * ctx,GPtrArray * tokens,struct rspamd_task * task,gboolean is_spam,gboolean unlearn,GError ** err)469 bayes_learn_spam (struct rspamd_classifier * ctx,
470 GPtrArray *tokens,
471 struct rspamd_task *task,
472 gboolean is_spam,
473 gboolean unlearn,
474 GError **err)
475 {
476 guint i, j, total_cnt, spam_cnt, ham_cnt;
477 gint id;
478 struct rspamd_statfile *st;
479 rspamd_token_t *tok;
480 gboolean incrementing;
481
482 g_assert (ctx != NULL);
483 g_assert (tokens != NULL);
484
485 incrementing = ctx->cfg->flags & RSPAMD_FLAG_CLASSIFIER_INCREMENTING_BACKEND;
486
487 for (i = 0; i < tokens->len; i++) {
488 total_cnt = 0;
489 spam_cnt = 0;
490 ham_cnt = 0;
491 tok = g_ptr_array_index (tokens, i);
492
493 for (j = 0; j < ctx->statfiles_ids->len; j++) {
494 id = g_array_index (ctx->statfiles_ids, gint, j);
495 st = g_ptr_array_index (ctx->ctx->statfiles, id);
496 g_assert (st != NULL);
497
498 if (!!st->stcf->is_spam == !!is_spam) {
499 if (incrementing) {
500 tok->values[id] = 1;
501 }
502 else {
503 tok->values[id]++;
504 }
505
506 total_cnt += tok->values[id];
507
508 if (st->stcf->is_spam) {
509 spam_cnt += tok->values[id];
510 }
511 else {
512 ham_cnt += tok->values[id];
513 }
514 }
515 else {
516 if (tok->values[id] > 0 && unlearn) {
517 /* Unlearning */
518 if (incrementing) {
519 tok->values[id] = -1;
520 }
521 else {
522 tok->values[id]--;
523 }
524
525 if (st->stcf->is_spam) {
526 spam_cnt += tok->values[id];
527 }
528 else {
529 ham_cnt += tok->values[id];
530 }
531 total_cnt += tok->values[id];
532 }
533 else if (incrementing) {
534 tok->values[id] = 0;
535 }
536 }
537 }
538
539 if (tok->t1 && tok->t2) {
540 msg_debug_bayes ("token %uL <%*s:%*s>: window: %d, total_count: %d, "
541 "spam_count: %d, ham_count: %d",
542 tok->data,
543 (int) tok->t1->stemmed.len, tok->t1->stemmed.begin,
544 (int) tok->t2->stemmed.len, tok->t2->stemmed.begin,
545 tok->window_idx, total_cnt, spam_cnt, ham_cnt);
546 }
547 else {
548 msg_debug_bayes ("token %uL <?:?>: window: %d, total_count: %d, "
549 "spam_count: %d, ham_count: %d",
550 tok->data,
551 tok->window_idx, total_cnt, spam_cnt, ham_cnt);
552 }
553 }
554
555 return TRUE;
556 }
557