1local lua_util = require "lua_util"
2local rspamd_util = require "rspamd_util"
3local fun = require "fun"
4
5local utility = {}
6
function utility.get_all_symbols(logs, ignore_symbols)
  -- Builds a sorted list of the distinct symbols found in `logs`.
  --
  -- Each log entry is split on single spaces; fields 4 through the
  -- next-to-last one are taken as symbols, with any whitespace stripped.
  -- Symbols for which `ignore_symbols[symbol]` is truthy are left out.

  local seen = {}

  for _, entry in pairs(logs) do
    local fields = lua_util.rspamd_str_split(entry, " ")
    for idx = 4, #fields - 1 do
      -- gsub also returns a match count; only the cleaned string is kept
      local symbol = fields[idx]:gsub("%s+", "")
      seen[symbol] = true
    end
  end

  local result = {}

  for symbol in pairs(seen) do
    if not ignore_symbols[symbol] then
      table.insert(result, symbol)
    end
  end

  table.sort(result)

  return result
end
34
function utility.read_log_file(file)
  -- Reads `file` and splits every line that contains "<basename>:" into two
  -- parts, where <basename> is `file` with everything up to the last "/"
  -- removed.
  --
  -- Returns two parallel arrays:
  --   lines    - the text from the start of the line up to and including the
  --              first character of the matched basename
  --   messages - the text following the ":" of the marker
  -- Lines that do not contain the marker are skipped.
  -- Raises an error (via assert) when the file cannot be opened.

  local lines = {}
  local messages = {}

  local fd = assert(io.open(file, "r"))
  local fname = string.gsub(file, "(.*/)(.*)", "%2")
  local marker = fname .. ':'

  for line in fd:lines() do
    -- Plain (non-pattern) search: file names routinely contain characters
    -- such as "-" or "." that are magic in Lua patterns and would otherwise
    -- make the marker fail to match its own literal text.
    local start, stop = string.find(line, marker, 1, true)

    if start and stop then
      -- NOTE(review): the slice ends at `start`, so it keeps the first
      -- character of the file name; left as is because consumers skip the
      -- last token of the line - confirm whether `start - 1` was intended.
      table.insert(lines, string.sub(line, 1, start))
      table.insert(messages, string.sub(line, stop + 1, -1))
    end
  end

  fd:close()

  return lines, messages
end
56
function utility.get_all_logs(dirs)
  -- Gathers log lines and their messages from one or more locations.
  --
  -- `dirs` is either a single path string or an array of path strings.
  -- A path ending in "/" is globbed for "*.log" files, each of which is
  -- read; any other path is read directly as a log file.
  -- Returns two parallel arrays: all log lines and all messages.

  local sources = dirs
  if type(sources) == 'string' then
    sources = { sources }
  end

  local all_logs, all_messages = {}, {}

  -- Reads a single file and appends its content to the accumulators
  local function collect(path)
    local logs, messages = utility.read_log_file(path)
    for i = 1, #logs do
      all_logs[#all_logs + 1] = logs[i]
      all_messages[#all_messages + 1] = messages[i]
    end
  end

  for _, dir in ipairs(sources) do
    if dir:sub(-1) == "/" then
      local pattern = dir:sub(1, -2) .. "/*.log"
      for _, file in pairs(rspamd_util.glob(pattern)) do
        collect(file)
      end
    else
      collect(dir)
    end
  end

  return all_logs, all_messages
end
89
function utility.get_all_symbol_scores(conf, ignore_symbols)
  -- Returns a map of symbol name -> score built from
  -- `conf:get_symbols_scores()`, leaving out every symbol for which
  -- `ignore_symbols[name]` is truthy.

  local function is_kept(name, _)
    return not ignore_symbols[name]
  end

  local function to_score(name, elt)
    return name, elt.score
  end

  local scores = conf:get_symbols_scores()

  return fun.tomap(fun.map(to_score, fun.filter(is_kept, scores)))
end
99
function utility.generate_statistics_from_logs(logs, messages, threshold)

  -- Computes classification statistics for a set of log lines.
  --
  -- Parameters:
  --   logs      - array of log lines; each is trimmed and split on " ".
  --               Token 1 is compared with "SPAM", token 2 is parsed as the
  --               score, tokens 4 .. n-1 are treated as symbols and the last
  --               token is parsed as a number when looking for the slowest
  --               message.
  --   messages  - array parallel to `logs`; messages[i] belongs to logs[i]
  --   threshold - a score >= threshold counts as "detected as spam"
  --
  -- Returns four values:
  --   file_stats        - table of aggregate counters and percentages
  --   all_symbols_stats - table keyed by symbol with per-symbol statistics
  --   all_fps           - messages of false positives
  --   all_fns           - messages of false negatives
  --
  -- Every ratio whose denominator is zero is reported as 0 instead of NaN.
  -- A line whose second token is not numeric still raises an error when the
  -- score is compared with `threshold`.
  -- `avg_scan_time` is initialised to 0 and is not updated by this function.

  -- part / total as a percentage; 0 when there is nothing to divide by
  local function percentage(part, total)
    if total > 0 then
      return part * 100 / total
    end
    return 0
  end

  local file_stats = {
    no_of_emails = 0,
    no_of_spam = 0,
    no_of_ham = 0,
    spam_percent = 0,
    ham_percent = 0,
    true_positives = 0,
    true_negatives = 0,
    false_negative_rate = 0,
    false_positive_rate = 0,
    overall_accuracy = 0,
    fscore = 0,
    avg_scan_time = 0,
    slowest_file = nil,
    slowest = 0
  }

  local all_symbols_stats = {}
  local all_fps = {}
  local all_fns = {}

  local false_positives = 0
  local false_negatives = 0
  local true_positives = 0
  local true_negatives = 0
  local no_of_emails = 0
  local no_of_spam = 0
  local no_of_ham = 0

  for i, log in ipairs(logs) do
    log = lua_util.rspamd_str_trim(log)
    log = lua_util.rspamd_str_split(log, " ")
    local message = messages[i]

    local is_spam = (log[1] == "SPAM")
    local score = tonumber(log[2])

    no_of_emails = no_of_emails + 1

    if is_spam then
      no_of_spam = no_of_spam + 1
    else
      no_of_ham = no_of_ham + 1
    end

    if is_spam and (score >= threshold) then
      true_positives = true_positives + 1
    elseif is_spam and (score < threshold) then
      false_negatives = false_negatives + 1
      table.insert(all_fns, message)
    elseif not is_spam and (score >= threshold) then
      false_positives = false_positives + 1
      table.insert(all_fps, message)
    else
      true_negatives = true_negatives + 1
    end

    for j = 4, (#log - 1) do
      local sym = log[j]
      local sym_stats = all_symbols_stats[sym]

      if sym_stats == nil then
        -- NOTE(review): `name` stores the message of the first line that
        -- hit this symbol, not the symbol itself - confirm this is intended.
        sym_stats = {
          name = message,
          no_of_hits = 0,
          spam_hits = 0,
          ham_hits = 0,
          spam_overall = 0
        }
        all_symbols_stats[sym] = sym_stats
      end

      sym_stats.no_of_hits = sym_stats.no_of_hits + 1

      if is_spam then
        sym_stats.spam_hits = sym_stats.spam_hits + 1
      else
        sym_stats.ham_hits = sym_stats.ham_hits + 1
      end
    end

    -- Find slowest message. Only lines that carry at least one symbol
    -- (five or more tokens) are considered; checking once per line gives
    -- the same result as checking once per symbol.
    if #log >= 5 then
      local scan_time = tonumber(log[#log]) or 0
      if scan_time > file_stats.slowest then
        file_stats.slowest = scan_time
        file_stats.slowest_file = message
      end
    end
  end

  -- Calculating file stats

  file_stats.no_of_ham = no_of_ham
  file_stats.no_of_spam = no_of_spam
  file_stats.no_of_emails = no_of_emails
  file_stats.true_positives = true_positives
  file_stats.true_negatives = true_negatives

  file_stats.spam_percent = percentage(no_of_spam, no_of_emails)
  file_stats.ham_percent = percentage(no_of_ham, no_of_emails)
  file_stats.overall_accuracy = percentage(true_positives + true_negatives,
      no_of_emails)
  file_stats.false_positive_rate = percentage(false_positives, no_of_ham)
  file_stats.false_negative_rate = percentage(false_negatives, no_of_spam)

  -- F-score is undefined (0/0) when nothing was, or should have been,
  -- flagged as spam; keep the default of 0 in that case.
  local fscore_denominator = 2 * true_positives
      + false_positives
      + false_negatives

  if fscore_denominator > 0 then
    file_stats.fscore = 2 * true_positives / fscore_denominator
  end

  -- Calculating symbol stats

  for _, symbol_stats in pairs(all_symbols_stats) do
    symbol_stats.spam_percent = percentage(symbol_stats.spam_hits, no_of_spam)
    symbol_stats.ham_percent = percentage(symbol_stats.ham_hits, no_of_ham)
    symbol_stats.overall = percentage(symbol_stats.no_of_hits, no_of_emails)

    local percent_sum = symbol_stats.spam_percent + symbol_stats.ham_percent

    if percent_sum > 0 then
      symbol_stats.spam_overall = symbol_stats.spam_percent / percent_sum
    else
      symbol_stats.spam_overall = 0
    end
  end

  return file_stats, all_symbols_stats, all_fps, all_fns
end
229
230return utility
231