1#!/usr/local/bin/lua
2-- Script for training with TREC compatible corpora
3-- Version: v1.3 - Aug/2006 - Fidelis Assis
4
5-- DSTTT is not used any more. An SSTTT variant is used instead where
6-- extra learnings, using only the header, are done when the previous
7-- was not enough. In general, this variant results in better accuracy
8-- and now is the default training method.
9--
10-- TOE (Train On Error) can still be used if both "threshold"
11-- and "header_learn_threshold" are set to 0, but normally SSTTT is
12-- much better than TOE.
13--
14-- See the on-line book "The CRM114 Discriminator Revealed!" by
15-- William S. Yerazunis for the definitions of TOE, SSTTT and DSTTT.
16-- There's a link to the book on http://crm114.sourceforge.net.
17
-- TOER means TOE + Reinforcements, another name for SSTTT. TONE -
-- Train On or Near Error - is yet another name sometimes used for
-- the same method.
21
22--[[------------------------------------------------------------------
23
24This program is free software. You can use it as you want.
25As usual, without warranty of any kind. Use it at your own risk.
26
27How to use:
28
29$ ./toer.lua <path_to_index> [<index_name>]
30
The index file is a list of message files with two fields separated
by a space per line. The first field is the judge ("spam" or "ham")
and the second is the message filename, relative to <path_to_index>.
Note that <path_to_index> is the directory that contains the index
file and must include a trailing slash, because filenames are built
by direct concatenation with it.
34
35Ex:
36
37spam ../data/message001
38ham ../data/message002
39ham ../data/message003
40spam ../data/message004
41...
42
43In this example, the message files are in a dir named "data", parallel
44to <path_to_index>.
45
46OBS: toer.lua doesn't affect the counters used for accuracy statistics.
47
48--]]----------------------------------------------------------------
49
50-- start of program
51
-- Localize the required module and frequently used libraries.
local osbf = require "osbf"  -- load osbf module
local string = string
local math = math

-- Tunable parameters for a training run.
local threshold_offset  = 5  -- [-5, +5]
			     -- -5 => less false negatives
			     --  0 => normal
			     -- +5 => less false positives
local delimiters	= "" -- token delimiters
--local delimiters	= ".@:/"
local num_buckets	= 94321 -- min value recommended for production
--local num_buckets	= 4000037 -- value used for TREC tests
local preserve_db       = false -- preserve databases or not between corpora
local max_text_size     = 500000 -- 0 means full document
local min_p_ratio       = 1  -- minimum probability ratio over the classes a
			     -- feature must have so as not to be ignored
local corpora_dir	= "./" -- default = local dir
local corpora_index	= "index" -- default index filename
local testsize		= 1000 -- number of messages in the testset
local train_in_testset  = true -- if true, keep learning inside the testset
local in_testset        = false -- initial value
local log_prefix        = "toer-lua" -- prefix for log/report filenames
local nonspam_index     = 1 -- index to the nonspam db in the table "classes"
local spam_index        = 2 -- index to the spam db in the table "classes"

-- Experimental constants
local thick_threshold                = 20 -- overtraining protection
local header_learn_threshold         = 14 -- header overtraining protection
local reinforcement_degree           = 0.6
local ham_reinforcement_limit        = 4 -- max header re-learnings for ham
local spam_reinforcement_limit       = 4 -- max header re-learnings for spam
local threshold_reinforcement_degree = 1.5

-- Flags passed to osbf.classify/osbf.learn.
-- NOTE(review): the numeric meanings are defined by the osbf C module;
-- see the OSBF-Lua documentation.
local classify_flags            = 0
local count_classification_flag = 2
local learn_flags               = 0
local mistake_flag              = 2
local reinforcement_flag        = 4

-- dbset is the set of single class databases to be used for classification
local dbset = {
	classes     = {"nonspam.cfc", "spam.cfc"},
	ncfs        = 1, -- split "classes" in 2 sublists. "ncfs" is
	                 -- the number of classes in the first sublist.
			 -- Here, the first sublist is {"nonspam.cfc"}
			 -- and the second {"spam.cfc"}.
	delimiters  = delimiters
}
101-------------------------------------------------------------------------
102
-- receives a file name and returns the number of lines
-- (counts newline characters, so a file without a final newline
-- does not count its last, unterminated line)
function count_lines(file)
  local fh = assert(io.open(file, "r"))
  local contents = fh:read("*all")
  fh:close()
  local count = 0
  for _ in string.gmatch(contents, "\n") do
    count = count + 1
  end
  return count
end
110
111-------------------------------------------------------------------------
112
-- receives a single class database filename and returns
-- a string with a statistics report of the database
function dbfile_stats (dbfile)
    local OSBF_Bayes_db_version = 5 -- OSBF-Bayes database identifier
    local report = "-- Statistics for " .. dbfile .. "\n"
    local version = "OSBF-Bayes"
    -- fixed: stats_lua was an accidental global; keep it local
    local stats_lua = osbf.stats(dbfile)
    if (stats_lua.version == OSBF_Bayes_db_version) then
      -- known database version: dump the full set of statistics
      report = report .. string.format(
        "%-35s%12s\n%-35s%12d\n%-35s%12.1f\n%-35s%12d\n%-35s%12d\n%-35s%12d\n",
        "Database version:", version,
        "Total buckets in database:", stats_lua.buckets,
        "Buckets used (%):", stats_lua.use * 100,
        "Trainings:", stats_lua.learnings,
        "Bucket size (bytes):", stats_lua.bucket_size,
        "Header size (bytes):", stats_lua.header_size)
      report = report .. string.format("%-35s%12d\n%-35s%12d\n%-35s%12d\n\n",
        "Number of chains:", stats_lua.chains,
        "Max chain len (buckets):", stats_lua.max_chain,
        "Average chain length (buckets):", stats_lua.avg_chain,
        "Max bucket displacement:", stats_lua.max_displacement)
    else
      -- unknown/unsupported database version: report only that fact
      report = report .. string.format("%-35s%12s\n", "Database version:",
          "Unknown")
    end

    return report
end
141
-- return the header of the message: everything up to and including
-- the newline that precedes the first blank line; if there is no
-- blank line, the whole text is returned
function header(text)
  local pos = string.find(text, "\n\n", 1, true)
  if pos then
    return string.sub(text, 1, pos)
  end
  return text
end
147
-- check if a file exists (i.e. can be opened for reading);
-- returns true on success, or nil plus an error message
function file_exists(file)
  local fh = io.open(file, "r")
  if not fh then
    return nil, "File not found"
  end
  fh:close()
  return true
end
158
159-------------------------------------------------------------------------
160
-- command line: ./toer.lua <path_to_index> [<index_name>]
-- fixed: without "or corpora_dir" a missing argument silently set
-- corpora_dir to nil, crashing later on concatenation; and the
-- documented optional <index_name> argument was never honored
corpora_dir = arg[1] or corpora_dir
if arg[2] then
  corpora_index = arg[2]
end

-- clean the databases, unless asked to preserve existing ones
if not preserve_db then
  osbf.remove_db(dbset.classes)
  assert(osbf.create_db(dbset.classes, num_buckets))
else
  -- preserving: create the databases only if one of them is missing
  if not (file_exists(dbset.classes[1]) and
	 file_exists(dbset.classes[2])) then
    assert(osbf.create_db(dbset.classes, num_buckets))
  end
end
173
-- Main training pass.  The outer loop runs exactly once; it is kept as
-- a loop to ease experiments with several parameter sets.  For each
-- message in the index: classify it, train on error, and reinforce
-- borderline ("unsure zone") classifications, re-learning with only the
-- message header when whole-text learning was not enough.
for i=1, 1 do

    -- log/report filenames encode the main parameters of this run
    suffix = string.format("o%d_t%d_u%g_b%d_r%d_%s_%s", threshold_offset,
		thick_threshold, header_learn_threshold, num_buckets,
		min_p_ratio, string.gsub(corpora_dir, "/", "_"), corpora_index)
    training_log = log_prefix .. "_training-log_" .. suffix
    training_stats_report = log_prefix .. "_training-stats_" .. suffix
    db_stats_report = log_prefix .. "_db-stats_" .. suffix

    -- start each run from empty databases unless preservation was asked
    if not preserve_db then
      osbf.remove_db(dbset.classes)
      assert(osbf.create_db(dbset.classes, num_buckets))
    end

    -- whole-corpus counters and testset-only counters
    local num_msgs, hams, spams, hams_test, spams_test = 0, 0, 0, 0, 0
    local false_positives, false_negatives, trainings,
    	  reinforcements = 0, 0, 0, 0
    local false_positives_test, false_negatives_test,
    	  reinforcements_test, trainings_test = 0, 0, 0, 0
    local total_messages = count_lines(corpora_dir .. corpora_index)
    -- the last "testsize" messages of the corpus form the testset
    local start_of_test = total_messages - testsize + 1

    ini = os.time() -- wall-clock start, for the duration report

    log = assert(io.open(training_log, "w"))
    s = assert(io.open(corpora_dir .. corpora_index, "r"))
    for line in s:lines() do
      -- index line format: "<judge> <filename>", judge is "ham" or "spam"
      local judge, msg_name = string.match(line, "^(%S+)%s+(%S+)$")
      local file_name = corpora_dir .. msg_name
      local msg = assert(io.open(file_name, "r"))
      local text = msg:read("*all")
      msg:close()

      -- truncate long messages, cutting back to a token boundary
      if max_text_size > 0 then
        text = string.sub(text, 1, max_text_size)
        -- NOTE(review): this match returns nil if the truncated text
        -- contains no whitespace at all -- confirm that is acceptable
        text = string.match(text, "^(.*)%s%S*$")
      end
      -- append the first five tokens again to give them extra weight
      -- NOTE(review): the match is nil for messages with fewer than
      -- five tokens, which would make this concatenation error out
      text = text .. " " .. string.match(text, "^%s*%S+%s+%S+%s+%S+%s+%S+")
      local lim_orig_header = header(text)

      local pR, p_array, i_pmax = osbf.classify(text, dbset, classify_flags)
      if (pR == nil) then
        error(p_array)
      end

      -- sign convention: pR < 0 => spam side; pR >= 0 => ham side
      if pR < 0 then
        class = "spam"
      else
	class = "ham"
      end

      num_msgs = num_msgs + 1
      in_testset = num_msgs >= start_of_test

      if (judge == "spam") then
	spams = spams + 1
	if in_testset then
	  spams_test = spams_test + 1
	end
	-- check classification
        if (pR >= 0) then
	  -- wrong classification, false negative
	  -- NOTE(review): "result" is assigned here and below but never
	  -- read anywhere in this file
          result = "1"
	  false_negatives = false_negatives + 1
	  if not in_testset or train_in_testset then
            assert(osbf.learn(text, dbset, spam_index, learn_flags))
	    local new_pR =  osbf.classify(text, dbset, classify_flags)
  	    trainings = trainings + 1

	    -- header reinforcement: if whole-text learning did not push
	    -- the score out of the unsure zone, re-learn with only the
	    -- header until it does or the iteration limit is reached
	    if (header_learn_threshold > 0) then
              if new_pR > (threshold_offset - thick_threshold) and
                 (pR - new_pR) < header_learn_threshold then
                local i = 0
		local old_pR
                local trd = threshold_reinforcement_degree *
                            (threshold_offset - thick_threshold)
                local rd = reinforcement_degree * header_learn_threshold
                repeat
                  old_pR = new_pR
                  osbf.learn(lim_orig_header, dbset, spam_index,
				reinforcement_flag)
                  new_pR = osbf.classify(text, dbset, classify_flags)
                  i = i + 1
                until i >= spam_reinforcement_limit or
                      new_pR < trd or (old_pR - new_pR) >= rd
              end
	    end

	  end

	  if in_testset then
	    false_negatives_test = false_negatives_test + 1
	    if train_in_testset then
	      trainings_test = trainings_test + 1
	    end
	  end
        else
	  -- correctly classified as spam. check thick_threshold
	  if pR > (threshold_offset - thick_threshold) then
	    -- within unsure zone
	    if not in_testset or train_in_testset then
	      -- do reinforcement
	      assert(osbf.learn(text, dbset, spam_index, learn_flags))
	      local new_pR =  osbf.classify(text, dbset, classify_flags)

  	      result = "r"

              if new_pR > (threshold_offset - thick_threshold) and
                 (pR - new_pR) < header_learn_threshold then
                local i = 0
                local old_pR
                local trd = threshold_reinforcement_degree *
                            (threshold_offset - thick_threshold)
                local rd = reinforcement_degree * header_learn_threshold
                repeat
                  old_pR = new_pR
                  osbf.learn(lim_orig_header, dbset, spam_index,
				reinforcement_flag)
                  new_pR = osbf.classify(text, dbset, classify_flags)
                  i = i + 1
                until i >= spam_reinforcement_limit or
                      new_pR < trd or (old_pR - new_pR) >= rd
              end

	      reinforcements = reinforcements + 1
	      if in_testset then
	        reinforcements_test = reinforcements_test + 1
	      end
	    end
	  else
	    -- OK, out of unsure zone
            result = "0"
          end
        end
      else
	hams = hams + 1
	if in_testset then
	  hams_test = hams_test + 1
	end
	-- check classification
        if (pR >= 0) then
	  -- correctly classified as ham. check thick_threshold
	  if pR < (threshold_offset + thick_threshold) then
	    -- within unsure zone
	    if not in_testset or train_in_testset then
	      -- do reinforcement
	      assert(osbf.learn(text, dbset, nonspam_index, learn_flags))
	      local new_pR =  osbf.classify(text, dbset, classify_flags)

  	      result = "r"
              if new_pR < (threshold_offset + thick_threshold) and
                (new_pR - pR) < header_learn_threshold then
                local i = 0
	        local old_pR
                local trd = threshold_reinforcement_degree *
                            (threshold_offset + thick_threshold)
                local rd = reinforcement_degree * header_learn_threshold
                repeat
                  old_pR = new_pR
                  osbf.learn(lim_orig_header, dbset, nonspam_index,
				reinforcement_flag)
		  new_pR, p_array = osbf.classify(text, dbset, classify_flags)
                  i = i + 1
                 -- NOTE(review): the ham loops stop with "i >" while the
                 -- spam loops use "i >=", giving ham one extra iteration;
                 -- confirm whether this asymmetry is intentional
                 until i > ham_reinforcement_limit or
                    new_pR > trd or (new_pR - old_pR) >= rd
	      end

	      reinforcements = reinforcements + 1
	      if in_testset then
	        reinforcements_test = reinforcements_test + 1
	      end
  	    end
	  else
	    -- OK, out of unsure zone
	    result = "0"
	  end
        else
	  -- wrong classification, false positive
          result = "1"
	  false_positives = false_positives + 1
	  if not in_testset or train_in_testset then
	    assert(osbf.learn(text, dbset, nonspam_index, learn_flags))
  	    trainings = trainings + 1
	  end
	  -- NOTE(review): unlike the false-negative branch, this classify
	  -- and the header-reinforcement loop below run even when training
	  -- in the testset is disabled -- confirm this is intended
	  local new_pR =  osbf.classify(text, dbset, classify_flags)

	  if in_testset then
	    false_positives_test = false_positives_test + 1
	    if train_in_testset then
	      trainings_test = trainings_test + 1
	    end
	  end

          if new_pR < (threshold_offset + thick_threshold) and
             (new_pR - pR) < header_learn_threshold then
            local i = 0
	    local old_pR
            local trd = threshold_reinforcement_degree *
                        (threshold_offset + thick_threshold)
            local rd = reinforcement_degree * header_learn_threshold
            repeat
              old_pR = new_pR
              osbf.learn(lim_orig_header, dbset, nonspam_index,
			reinforcement_flag)
              new_pR, p_array = osbf.classify(text, dbset, classify_flags)
              i = i + 1
            until i > ham_reinforcement_limit or
                  new_pR > trd or (new_pR - old_pR) >= rd
          end

        end
      end
      -- TREC-style log line; the score sign is flipped so that positive
      -- values mean spam, the convention expected by the TREC tools
      log:write("file=",file_name," judge=", judge, " class=", class,
	  " score=", string.format("%.4f", (0 - pR)), " user= genre= runid=none\n")
      log:flush()
    end
    s:close()
    local duration = os.time() - ini
    log:flush()
    log:close()

    -- print database stats report
    db_stats_fh = assert(io.open(db_stats_report, "w"))
    for _, dbfile in ipairs(dbset.classes) do
      db_stats_fh:write(dbfile_stats(dbfile))
    end
    db_stats_fh:close()

    -- print training stats report
    t_stats_fh = assert(io.open(training_stats_report, "w"))
    t_stats_fh:write("-- Training statistics report\n\n")
    t_stats_fh:write("Message corpus\n")
    t_stats_fh:write(string.format("  %-26s%7d\n  %-26s%7d\n  %-26s%7d\n\n",
	"Hams:", hams, "Spams:", spams, "Total messages:", hams+spams))

    t_stats_fh:write("Training (OSBFBayes)\n")
    t_stats_fh:write(string.format(
      "  %-26s%7d\n  %-26s%7d\n  %-26s%7d\n  %-26s%7d\n  %-26s%7d\n  %-26s%7d\n\n",
      "Thick treshold:", thick_threshold,
      "Header learn-treshold:", header_learn_threshold,
      "Trainings on error:", false_positives+false_negatives,
      "Reinforcements:", reinforcements,
      "Total learnings:", false_positives+false_negatives+reinforcements,
      "Duration (sec):", duration))

    t_stats_fh:write(
	string.format("Performance in the final %d messages (testset)\n",
			testsize))
    t_stats_fh:write(string.format(
      "  %-26s%7d\n  %-26s%7d\n  %-26s%7d\n",
      "Hams in testset:", hams_test,
      "Spams in testset:", spams_test,
      "False positives:", false_positives_test))

   t_stats_fh:write(string.format(
     "  %-26s%7d\n  %-26s%7d\n  %-26s%7d\n  %-26s%10.2f\n  " ..
	"%-26s%10.2f\n  %-26s%10.2f\n  %-26s%10.2f\n  %-26s%10.2f\n",
     "False negatives:", false_negatives_test,
     "Total errors in testset:", false_positives_test+false_negatives_test,
     "Reinforcements in testset:", reinforcements_test,
     "Ham recall (%):", 100 * (hams_test - false_positives_test)/hams_test,
     "Ham precision (%):", 100 * (hams_test - false_positives_test) /
     	(hams_test - false_positives_test + false_negatives_test),
     "Spam recall (%):", 100 * (spams_test - false_negatives_test)/spams_test,
     "Spam precision (%):", 100 * (spams_test - false_negatives_test) /
     	(spams_test - false_negatives_test + false_positives_test),
     "Accuracy (%):",
      100 * (1 - (false_positives_test+false_negatives_test)/testsize)))
   t_stats_fh:close()
   -- end of report
end
445
446