1--[[
2Copyright (c) 2019, Vsevolod Stakhov <vsevolod@highsecure.ru>
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15]]--
16
-- Helpers that decide whether a message can be learned (can_learn) or
-- should be auto-learned (autolearn) by the Bayes classifier
18
19local lua_util = require "lua_util"
20local lua_verdict = require "lua_verdict"
-- Tag passed to lua_util.debugm for this module's debug messages
local N = "lua_bayes"

-- Module table collecting the public functions of this file
local exports = {}
24
--- Check whether a message may be learned by the Bayes classifier.
-- Learning is refused when the task already carries a Bayes probability
-- (mempool variable `bayes_prob`) placing it firmly in the requested class:
-- >= 0.95 when learning spam, <= 0.05 when learning ham. The check is
-- bypassed entirely for requests sent with a `Learn-Type: bulk` header.
-- @param task rspamd task
-- @param is_spam true to learn as spam, false to learn as ham
-- @param is_unlearn unused by this function
-- @return true when learning may proceed; otherwise false and a reason string
exports.can_learn = function(task, is_spam, is_unlearn)
  local learn_type = task:get_request_header('Learn-Type')
  local is_bulk = learn_type and tostring(learn_type) == 'bulk'

  if is_bulk then
    return true
  end

  local prob = task:get_mempool():get_variable('bayes_prob', 'double')
  if not prob then
    return true
  end

  local class_name, already_there
  if is_spam then
    class_name, already_there = 'spam', prob >= 0.95
  else
    class_name, already_there = 'ham', prob <= 0.05
  end

  if not already_there then
    return true
  end

  -- Confidence is the distance from the undecided point 0.5, scaled to 0..100
  local confidence = math.abs((prob - 0.5) * 200.0)
  return false, string.format('already in class %s; probability %.2f%%',
      class_name, confidence)
end
52
-- Debug-log the data behind a positive autolearn decision: message id,
-- SMTP sender, verdict, score compared to threshold, and MIME recipients.
local function log_can_autolearn(task, verdict, score, threshold)
  local from = task:get_from('smtp')
  local from_addr = 'undef'
  -- Guard against a nil result and (presumably possible) empty address list
  if from and from[1] and from[1].addr then
    from_addr = from[1].addr
  end

  local rcpts = {}
  local mr = task:get_recipients('mime')
  if mr then
    for _,r in ipairs(mr) do
      if r.addr then
        rcpts[#rcpts + 1] = r.addr
      end
    end
  end
  local mime_rcpts = 'undef'
  if #rcpts > 0 then
    mime_rcpts = table.concat(rcpts, ',')
  end

  -- spam and junk are learned when score >= threshold, ham when score <= threshold
  local cmp = '/'
  if verdict == 'ham' then
    cmp = '<='
  elseif verdict == 'spam' or verdict == 'junk' then
    cmp = '>='
  end

  lua_util.debugm(N, task, 'id: %s, from: <%s>: can autolearn %s: score %s %s %s, mime_rcpts: <%s>',
      task:get_header('Message-Id') or '<undef>',
      from_addr,
      verdict,
      string.format("%.2f", score),
      cmp,
      threshold,
      mime_rcpts)
end

-- Clear the learn flag of a class whose learn count is too far ahead of the
-- other class. For each class the ratio class_learns / (other_learns + 1) is
-- compared against 1 / min_balance (conf.min_balance, default 0.9), using the
-- 'spam_learns' / 'ham_learns' mempool counters. Returns the updated flags.
local function apply_learn_balance(task, conf, learn_spam, learn_ham)
  local mempool = task:get_mempool()
  local spam_learns = mempool:get_variable('spam_learns', 'int64') or 0
  local ham_learns = mempool:get_variable('ham_learns', 'int64') or 0

  if not (spam_learns > 0 or ham_learns > 0) then
    -- Nothing learned yet: there is no balance to enforce
    return learn_spam, learn_ham
  end

  local min_balance = conf.min_balance or 0.9
  local max_ratio = 1.0 / min_balance

  local spam_learns_ratio = spam_learns / (ham_learns + 1)
  if learn_spam and spam_learns_ratio > max_ratio then
    lua_util.debugm(N, task,
        'skip learning spam, balance is not satisfied: %s > %s; %s spam learns; %s ham learns',
        spam_learns_ratio, max_ratio, spam_learns, ham_learns)
    learn_spam = false
  end

  local ham_learns_ratio = ham_learns / (spam_learns + 1)
  if learn_ham and ham_learns_ratio > max_ratio then
    lua_util.debugm(N, task,
        'skip learning ham, balance is not satisfied: %s > %s; %s spam learns; %s ham learns',
        ham_learns_ratio, max_ratio, spam_learns, ham_learns)
    learn_ham = false
  end

  return learn_spam, learn_ham
end

--- Decide whether a task should be auto-learned by the Bayes classifier.
-- The decision uses the "bayes" verdict from lua_verdict:
--  * when both conf.spam_threshold and conf.ham_threshold are set, a 'spam'
--    verdict is learned as spam if score >= spam_threshold, a 'junk' verdict
--    if conf.junk_threshold is set and score >= junk_threshold, and a 'ham'
--    verdict is learned as ham if score <= ham_threshold;
--  * otherwise, when conf.learn_verdict is set, 'spam'/'junk' verdicts are
--    learned as spam and 'ham' as ham.
-- With conf.check_balance set, a class may additionally be skipped when its
-- learn count is too far ahead of the other class (see apply_learn_balance).
-- @param task rspamd task
-- @param conf autolearn configuration table
-- @return 'spam', 'ham', or nil when the task should not be learned
exports.autolearn = function(task, conf)
  local verdict,score = lua_verdict.get_specific_verdict("bayes", task)

  if verdict == 'passthrough' then
    -- No need to autolearn
    lua_util.debugm(N, task, 'no need to autolearn - verdict: %s',
        verdict)
    return
  end

  local learn_spam,learn_ham = false, false

  if conf.spam_threshold and conf.ham_threshold then
    if verdict == 'spam' then
      if score >= conf.spam_threshold then
        log_can_autolearn(task, verdict, score, conf.spam_threshold)
        learn_spam = true
      end
    elseif verdict == 'junk' then
      if conf.junk_threshold and score >= conf.junk_threshold then
        log_can_autolearn(task, verdict, score, conf.junk_threshold)
        learn_spam = true
      end
    elseif verdict == 'ham' then
      if score <= conf.ham_threshold then
        log_can_autolearn(task, verdict, score, conf.ham_threshold)
        learn_ham = true
      end
    end
  elseif conf.learn_verdict then
    if verdict == 'spam' or verdict == 'junk' then
      learn_spam = true
    elseif verdict == 'ham' then
      learn_ham = true
    end
  end

  if conf.check_balance and (learn_spam or learn_ham) then
    learn_spam, learn_ham = apply_learn_balance(task, conf, learn_spam, learn_ham)
  end

  if learn_spam then
    return 'spam'
  elseif learn_ham then
    return 'ham'
  end
end
148
-- Module interface: can_learn() and autolearn()
return exports