--[[
Copyright (c) 2019, Vsevolod Stakhov <vsevolod@highsecure.ru>

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
]]--

-- This file contains functions to simplify bayes classifier auto-learning

local lua_util = require "lua_util"
local lua_verdict = require "lua_verdict"
local N = "lua_bayes"

local exports = {}

--- Decide whether a message may still be learned as the requested class.
-- Unless the `Learn-Type` request header equals 'bulk', the `bayes_prob`
-- mempool variable is consulted: a value >= 0.95 when learning spam, or
-- <= 0.05 when learning ham, means the message is already in that class.
-- @param task task object (must provide get_request_header and get_mempool)
-- @param is_spam truthy to learn as spam, falsy to learn as ham
-- @param is_unlearn accepted for interface compatibility; not read here
-- @return true when learning is allowed; otherwise false plus a reason string
exports.can_learn = function(task, is_spam, is_unlearn)
  local learn_type = task:get_request_header('Learn-Type')

  if not (learn_type and tostring(learn_type) == 'bulk') then
    local prob = task:get_mempool():get_variable('bayes_prob', 'double')

    if prob then
      local in_class = false
      local cl
      if is_spam then
        cl = 'spam'
        in_class = prob >= 0.95
      else
        cl = 'ham'
        in_class = prob <= 0.05
      end

      if in_class then
        -- Distance from the neutral 0.5 point, rescaled to 0..100
        return false, string.format(
            'already in class %s; probability %.2f%%',
            cl, math.abs((prob - 0.5) * 200.0))
      end
    end
  end

  return true
end

--- Decide whether (and as what) a message should be auto-learned.
-- The verdict and score come from lua_verdict.get_specific_verdict("bayes", task).
-- Two modes are supported, selected by `conf`:
--   * threshold mode (both `spam_threshold` and `ham_threshold` set): learn
--     spam when a 'spam' verdict scores >= spam_threshold or a 'junk' verdict
--     scores >= junk_threshold (if set); learn ham when a 'ham' verdict scores
--     <= ham_threshold;
--   * verdict mode (`learn_verdict` set): learn spam on 'spam'/'junk', ham on
--     'ham'.
-- With `check_balance` set, the decision is dropped when the class to be
-- learned already outweighs the other one by more than 1 / min_balance
-- (`min_balance` defaults to 0.9); the learn counters are read from the
-- `spam_learns` / `ham_learns` mempool variables.
-- @param task task object
-- @param conf autolearn configuration table (fields as described above)
-- @return 'spam', 'ham', or nothing when no learning should happen
exports.autolearn = function(task, conf)
  -- Emit a debug record describing why a message qualifies for autolearn
  local function log_can_autolearn(verdict, score, threshold)
    local from = task:get_from('smtp')
    local mime_rcpts = 'undef'
    local mr = task:get_recipients('mime')
    if mr then
      local addrs = {}
      for _, r in ipairs(mr) do
        addrs[#addrs + 1] = r.addr
      end
      if #addrs > 0 then
        mime_rcpts = table.concat(addrs, ',')
      end
    end

    -- Comparator actually applied by the caller for this verdict
    local cmp = '/'
    if verdict == 'ham' then
      cmp = '<='
    elseif verdict == 'spam' or verdict == 'junk' then
      cmp = '>='
    end

    lua_util.debugm(N, task, 'id: %s, from: <%s>: can autolearn %s: score %s %s %s, mime_rcpts: <%s>',
        task:get_header('Message-Id') or '<undef>',
        -- Guard against an empty address list as well as a missing one
        from and from[1] and from[1].addr or 'undef',
        verdict,
        string.format("%.2f", score),
        cmp,
        threshold,
        mime_rcpts)
  end

  -- We have autolearn config so let's figure out what is requested
  local verdict, score = lua_verdict.get_specific_verdict("bayes", task)
  local learn_spam, learn_ham = false, false

  if verdict == 'passthrough' then
    -- No need to autolearn
    lua_util.debugm(N, task, 'no need to autolearn - verdict: %s',
        verdict)
    return
  end

  if conf.spam_threshold and conf.ham_threshold then
    -- spam/ham thresholds are known to be set here; junk_threshold is optional
    if verdict == 'spam' then
      if score >= conf.spam_threshold then
        log_can_autolearn(verdict, score, conf.spam_threshold)
        learn_spam = true
      end
    elseif verdict == 'junk' then
      if conf.junk_threshold and score >= conf.junk_threshold then
        log_can_autolearn(verdict, score, conf.junk_threshold)
        learn_spam = true
      end
    elseif verdict == 'ham' then
      if score <= conf.ham_threshold then
        log_can_autolearn(verdict, score, conf.ham_threshold)
        learn_ham = true
      end
    end
  elseif conf.learn_verdict then
    if verdict == 'spam' or verdict == 'junk' then
      learn_spam = true
    elseif verdict == 'ham' then
      learn_ham = true
    end
  end

  if conf.check_balance then
    -- Check balance of learns
    local spam_learns = task:get_mempool():get_variable('spam_learns', 'int64') or 0
    local ham_learns = task:get_mempool():get_variable('ham_learns', 'int64') or 0

    local min_balance = conf.min_balance or 0.9

    if spam_learns > 0 or ham_learns > 0 then
      local max_ratio = 1.0 / min_balance

      -- The +1 keeps the ratio finite when the other class has no learns yet
      local spam_learns_ratio = spam_learns / (ham_learns + 1)
      if spam_learns_ratio > max_ratio and learn_spam then
        lua_util.debugm(N, task,
            'skip learning spam, balance is not satisfied: %s > %s; %s spam learns; %s ham learns',
            spam_learns_ratio, max_ratio, spam_learns, ham_learns)
        learn_spam = false
      end

      local ham_learns_ratio = ham_learns / (spam_learns + 1)
      if ham_learns_ratio > max_ratio and learn_ham then
        lua_util.debugm(N, task,
            'skip learning ham, balance is not satisfied: %s > %s; %s spam learns; %s ham learns',
            ham_learns_ratio, max_ratio, spam_learns, ham_learns)
        learn_ham = false
      end
    end
  end

  if learn_spam then
    return 'spam'
  elseif learn_ham then
    return 'ham'
  end
end

return exports