#!/usr/local/bin/lua
-- Script for training with TREC compatible corpora
-- Version: v1.3 - Aug/2006 - Fidelis Assis

-- DSTTT is not used any more. An SSTTT variant is used instead where
-- extra learnings, using only the header, are done when the previous
-- was not enough. In general, this variant results in better accuracy
-- and now is the default training method.
--
-- TOE (Train On Error) can still be used if both "threshold"
-- and "header_learn_threshold" are set to 0, but normally SSTTT is
-- much better than TOE.
--
-- See the on-line book "The CRM114 Discriminator Revealed!" by
-- William S. Yerazunis for the definitions of TOE, SSTTT and DSTTT.
-- There's a link to the book on http://crm114.sourceforge.net.

-- TOER means TOE + Reinforcements, another name for SSTTT. TONE -
-- Train On or Near Error - is yet another name I sometimes refer to
-- it as.

--[[------------------------------------------------------------------

This program is free software. You can use it as you want.
As usual, without warranty of any kind. Use it at your own risk.

How to use:

$ ./toer.lua <path_to_index> [<index_name>]

The index file is a list of message files with two fields separated
by a space per line. The first field is the judge ("spam" or "ham")
and the second is the message filename, relative to <path_to_index>.

Ex:

spam ../data/message001
ham ../data/message002
ham ../data/message003
spam ../data/message004
...

In this example, the message files are in a dir named "data", parallel
to <path_to_index>.

OBS: toer.lua doesn't affect the counters used for accuracy statistics.

--]]----------------------------------------------------------------

-- start of program

local osbf = require "osbf" -- load osbf module
local string = string

-- Tunable parameters ---------------------------------------------------

local threshold_offset = 5 -- [-5, +5]
                           -- -5 => less false negatives
                           --  0 => normal
                           -- +5 => less false positives
local delimiters = "" -- token delimiters
--local delimiters = ".@:/"
local num_buckets = 94321 -- min value recommended for production
--local num_buckets = 4000037 -- value used for TREC tests
local preserve_db = false -- preserve databases or not between corpora
local max_text_size = 500000 -- 0 means full document
local min_p_ratio = 1 -- minimum probability ratio over the classes a
                      -- feature must have so as not to be ignored
local corpora_dir = "./" -- default = local dir
local corpora_index = "index" -- default index filename
local testsize = 1000 -- number of messages in the testset
local train_in_testset = true
local in_testset = false -- initial value
local log_prefix = "toer-lua"
local nonspam_index = 1 -- index to the nonspam db in the table "classes"
local spam_index = 2 -- index to the spam db in the table "classes"

-- Experimental constants
local thick_threshold = 20 -- overtraining protection
local header_learn_threshold = 14 -- header overtraining protection
local reinforcement_degree = 0.6
local ham_reinforcement_limit = 4
local spam_reinforcement_limit = 4
local threshold_reinforcement_degree = 1.5

-- Flags (bit values understood by the osbf module)
local classify_flags = 0
local count_classification_flag = 2
local learn_flags = 0
local mistake_flag = 2
local reinforcement_flag = 4

-- dbset is the set of single class databases to be used for classification
local dbset = {
  classes = {"nonspam.cfc", "spam.cfc"},
  ncfs = 1, -- split "classes" in 2 sublists. "ncfs" is
            -- the number of classes in the first sublist.
            -- Here, the first sublist is {"nonspam.cfc"}
            -- and the second {"spam.cfc"}.
  delimiters = delimiters
}
-------------------------------------------------------------------------

-- receives a file name and returns the number of lines (newline count)
local function count_lines(file)
  local f = assert(io.open(file, "r"))
  local _, num_lines = string.gsub(f:read("*all"), '\n', '\n')
  f:close()
  return num_lines
end

-------------------------------------------------------------------------

-- receives a single class database filename and returns
-- a string with a statistics report of the database
local function dbfile_stats(dbfile)
  local OSBF_Bayes_db_version = 5 -- OSBF-Bayes database identifier
  local report = "-- Statistics for " .. dbfile .. "\n"
  local version = "OSBF-Bayes"
  local stats_lua = osbf.stats(dbfile)
  if stats_lua.version == OSBF_Bayes_db_version then
    report = report .. string.format(
      "%-35s%12s\n%-35s%12d\n%-35s%12.1f\n%-35s%12d\n%-35s%12d\n%-35s%12d\n",
      "Database version:", version,
      "Total buckets in database:", stats_lua.buckets,
      "Buckets used (%):", stats_lua.use * 100,
      "Trainings:", stats_lua.learnings,
      "Bucket size (bytes):", stats_lua.bucket_size,
      "Header size (bytes):", stats_lua.header_size)
    -- fixed: the format string used to have only three %-35s%12d pairs
    -- for four label/value pairs, silently dropping the
    -- "Max bucket displacement" line from the report
    report = report .. string.format(
      "%-35s%12d\n%-35s%12d\n%-35s%12d\n%-35s%12d\n\n",
      "Number of chains:", stats_lua.chains,
      "Max chain len (buckets):", stats_lua.max_chain,
      "Average chain length (buckets):", stats_lua.avg_chain,
      "Max bucket displacement:", stats_lua.max_displacement)
  else
    report = report ..
      string.format("%-35s%12s\n", "Database version:", "Unknown")
  end

  return report
end

-- return the header of the message (everything up to the first blank
-- line); falls back to the whole text when no blank line is found
local function header(text)
  local h = string.match(text, "^(.-\n)\n")
  return h or text
end

-- check if a file exists; returns true, or nil plus an error message
local function file_exists(file)
  local f = io.open(file, "r")
  if f then
    f:close()
    return true
  else
    return nil, "File not found"
  end
end

-------------------------------------------------------------------------

-- Command-line arguments, per the usage comment above. The optional
-- second argument (index name) is now honored, and both fall back to
-- the defaults declared at the top instead of crashing on nil.
corpora_dir = arg[1] or corpora_dir
corpora_index = arg[2] or corpora_index

-- clean the databases
if not preserve_db then
  osbf.remove_db(dbset.classes)
  assert(osbf.create_db(dbset.classes, num_buckets))
else
  if not (file_exists(dbset.classes[1]) and
          file_exists(dbset.classes[2])) then
    assert(osbf.create_db(dbset.classes, num_buckets))
  end
end

-- single-iteration corpus loop (kept so extra corpora/parameter sweeps
-- can be added by widening the range)
for _ = 1, 1 do

  local suffix = string.format("o%d_t%d_u%g_b%d_r%d_%s_%s", threshold_offset,
    thick_threshold, header_learn_threshold, num_buckets,
    min_p_ratio, string.gsub(corpora_dir, "/", "_"), corpora_index)
  local training_log = log_prefix .. "_training-log_" .. suffix
  local training_stats_report = log_prefix .. "_training-stats_" .. suffix
  local db_stats_report = log_prefix .. "_db-stats_" .. suffix

  if not preserve_db then
    osbf.remove_db(dbset.classes)
    assert(osbf.create_db(dbset.classes, num_buckets))
  end

  local num_msgs, hams, spams, hams_test, spams_test = 0, 0, 0, 0, 0
  local false_positives, false_negatives, trainings,
        reinforcements = 0, 0, 0, 0
  local false_positives_test, false_negatives_test,
        reinforcements_test, trainings_test = 0, 0, 0, 0
  local total_messages = count_lines(corpora_dir .. corpora_index)
  local start_of_test = total_messages - testsize + 1

  local ini = os.time()

  local log = assert(io.open(training_log, "w"))
  local s = assert(io.open(corpora_dir .. corpora_index, "r"))
  for line in s:lines() do
    local judge, msg_name = string.match(line, "^(%S+)%s+(%S+)$")
    local file_name = corpora_dir .. msg_name
    local msg = assert(io.open(file_name, "r"))
    local text = msg:read("*all")
    msg:close()

    if max_text_size > 0 then
      text = string.sub(text, 1, max_text_size)
      -- drop the possibly-truncated last token; keep the text unchanged
      -- when it contains no whitespace (the match used to return nil and
      -- crash on the concatenation below)
      text = string.match(text, "^(.*)%s%S*$") or text
    end
    -- repeat the first five tokens at the end of the text; skipped for
    -- degenerate messages with fewer than five tokens (used to crash
    -- concatenating a nil match)
    local first_tokens = string.match(text, "^%s*%S+%s+%S+%s+%S+%s+%S+")
    if first_tokens then
      text = text .. " " .. first_tokens
    end
    local lim_orig_header = header(text)

    local pR, p_array, i_pmax = osbf.classify(text, dbset, classify_flags)
    if pR == nil then
      -- on failure osbf.classify returns nil plus an error message
      error(p_array)
    end

    local class
    if pR < 0 then
      class = "spam"
    else
      class = "ham"
    end

    num_msgs = num_msgs + 1
    in_testset = num_msgs >= start_of_test

    local result
    if judge == "spam" then
      spams = spams + 1
      if in_testset then
        spams_test = spams_test + 1
      end
      -- check classification
      if pR >= 0 then
        -- wrong classification, false negative
        result = "1"
        false_negatives = false_negatives + 1
        if not in_testset or train_in_testset then
          assert(osbf.learn(text, dbset, spam_index, learn_flags))
          local new_pR = osbf.classify(text, dbset, classify_flags)
          trainings = trainings + 1

          -- header reinforcement: keep learning the header alone while
          -- the score stays in the unsure zone and isn't moving enough
          if header_learn_threshold > 0 then
            if new_pR > (threshold_offset - thick_threshold) and
               (pR - new_pR) < header_learn_threshold then
              local i = 0
              local old_pR
              local trd = threshold_reinforcement_degree *
                          (threshold_offset - thick_threshold)
              local rd = reinforcement_degree * header_learn_threshold
              repeat
                old_pR = new_pR
                osbf.learn(lim_orig_header, dbset, spam_index,
                           reinforcement_flag)
                new_pR = osbf.classify(text, dbset, classify_flags)
                i = i + 1
              -- NOTE(review): spam loops stop at ">= limit" while the ham
              -- loops use "> limit" (one extra pass) -- confirm intended
              until i >= spam_reinforcement_limit or
                    new_pR < trd or (old_pR - new_pR) >= rd
            end
          end
        end

        if in_testset then
          false_negatives_test = false_negatives_test + 1
          if train_in_testset then
            trainings_test = trainings_test + 1
          end
        end
      else
        -- correctly classified as spam. check thick_threshold
        if pR > (threshold_offset - thick_threshold) then
          -- within unsure zone
          if not in_testset or train_in_testset then
            -- do reinforcement
            assert(osbf.learn(text, dbset, spam_index, learn_flags))
            local new_pR = osbf.classify(text, dbset, classify_flags)

            result = "r"

            if new_pR > (threshold_offset - thick_threshold) and
               (pR - new_pR) < header_learn_threshold then
              local i = 0
              local old_pR
              local trd = threshold_reinforcement_degree *
                          (threshold_offset - thick_threshold)
              local rd = reinforcement_degree * header_learn_threshold
              repeat
                old_pR = new_pR
                osbf.learn(lim_orig_header, dbset, spam_index,
                           reinforcement_flag)
                new_pR = osbf.classify(text, dbset, classify_flags)
                i = i + 1
              until i >= spam_reinforcement_limit or
                    new_pR < trd or (old_pR - new_pR) >= rd
            end

            reinforcements = reinforcements + 1
            if in_testset then
              reinforcements_test = reinforcements_test + 1
            end
          end
        else
          -- OK, out of unsure zone
          result = "0"
        end
      end
    else
      hams = hams + 1
      if in_testset then
        hams_test = hams_test + 1
      end
      -- check classification
      if pR >= 0 then
        -- correctly classified as ham. check thick_threshold
        if pR < (threshold_offset + thick_threshold) then
          -- within unsure zone
          if not in_testset or train_in_testset then
            -- do reinforcement
            assert(osbf.learn(text, dbset, nonspam_index, learn_flags))
            local new_pR = osbf.classify(text, dbset, classify_flags)

            result = "r"
            if new_pR < (threshold_offset + thick_threshold) and
               (new_pR - pR) < header_learn_threshold then
              local i = 0
              local old_pR
              local trd = threshold_reinforcement_degree *
                          (threshold_offset + thick_threshold)
              local rd = reinforcement_degree * header_learn_threshold
              repeat
                old_pR = new_pR
                osbf.learn(lim_orig_header, dbset, nonspam_index,
                           reinforcement_flag)
                new_pR, p_array = osbf.classify(text, dbset, classify_flags)
                i = i + 1
              until i > ham_reinforcement_limit or
                    new_pR > trd or (new_pR - old_pR) >= rd
            end

            reinforcements = reinforcements + 1
            if in_testset then
              reinforcements_test = reinforcements_test + 1
            end
          end
        else
          -- OK, out of unsure zone
          result = "0"
        end
      else
        -- wrong classification, false positive
        result = "1"
        false_positives = false_positives + 1
        if not in_testset or train_in_testset then
          assert(osbf.learn(text, dbset, nonspam_index, learn_flags))
          trainings = trainings + 1
        end
        -- NOTE(review): unlike the false-negative branch, the
        -- reclassification and header reinforcement below run even when
        -- learning was skipped for the testset -- confirm intended
        local new_pR = osbf.classify(text, dbset, classify_flags)

        if in_testset then
          false_positives_test = false_positives_test + 1
          if train_in_testset then
            trainings_test = trainings_test + 1
          end
        end

        if new_pR < (threshold_offset + thick_threshold) and
           (new_pR - pR) < header_learn_threshold then
          local i = 0
          local old_pR
          local trd = threshold_reinforcement_degree *
                      (threshold_offset + thick_threshold)
          local rd = reinforcement_degree * header_learn_threshold
          repeat
            old_pR = new_pR
            osbf.learn(lim_orig_header, dbset, nonspam_index,
                       reinforcement_flag)
            new_pR, p_array = osbf.classify(text, dbset, classify_flags)
            i = i + 1
          until i > ham_reinforcement_limit or
                new_pR > trd or (new_pR - old_pR) >= rd
        end
      end
    end
    -- TREC-compatible log line; score sign is flipped so that larger
    -- values mean "more spammy"
    log:write("file=", file_name, " judge=", judge, " class=", class,
      " score=", string.format("%.4f", (0 - pR)), " user= genre= runid=none\n")
    log:flush()
  end
  s:close()
  local duration = os.time() - ini
  log:flush()
  log:close()

  -- print database stats report
  local db_stats_fh = assert(io.open(db_stats_report, "w"))
  for _, dbfile in ipairs(dbset.classes) do
    db_stats_fh:write(dbfile_stats(dbfile))
  end
  db_stats_fh:close()

  -- print training stats report
  local t_stats_fh = assert(io.open(training_stats_report, "w"))
  t_stats_fh:write("-- Training statistics report\n\n")
  t_stats_fh:write("Message corpus\n")
  t_stats_fh:write(string.format(" %-26s%7d\n %-26s%7d\n %-26s%7d\n\n",
    "Hams:", hams, "Spams:", spams, "Total messages:", hams + spams))

  t_stats_fh:write("Training (OSBFBayes)\n")
  -- fixed: "treshold" -> "threshold" in the two labels below
  t_stats_fh:write(string.format(
    " %-26s%7d\n %-26s%7d\n %-26s%7d\n %-26s%7d\n %-26s%7d\n %-26s%7d\n\n",
    "Thick threshold:", thick_threshold,
    "Header learn-threshold:", header_learn_threshold,
    "Trainings on error:", false_positives + false_negatives,
    "Reinforcements:", reinforcements,
    "Total learnings:", false_positives + false_negatives + reinforcements,
    "Duration (sec):", duration))

  t_stats_fh:write(
    string.format("Performance in the final %d messages (testset)\n",
      testsize))
  t_stats_fh:write(string.format(
    " %-26s%7d\n %-26s%7d\n %-26s%7d\n",
    "Hams in testset:", hams_test,
    "Spams in testset:", spams_test,
    "False positives:", false_positives_test))

  t_stats_fh:write(string.format(
    " %-26s%7d\n %-26s%7d\n %-26s%7d\n %-26s%10.2f\n " ..
    "%-26s%10.2f\n %-26s%10.2f\n %-26s%10.2f\n %-26s%10.2f\n",
    "False negatives:", false_negatives_test,
    "Total errors in testset:", false_positives_test + false_negatives_test,
    "Reinforcements in testset:", reinforcements_test,
    "Ham recall (%):", 100 * (hams_test - false_positives_test) / hams_test,
    "Ham precision (%):", 100 * (hams_test - false_positives_test) /
      (hams_test - false_positives_test + false_negatives_test),
    "Spam recall (%):", 100 * (spams_test - false_negatives_test) / spams_test,
    "Spam precision (%):", 100 * (spams_test - false_negatives_test) /
      (spams_test - false_negatives_test + false_positives_test),
    "Accuracy (%):",
    100 * (1 - (false_positives_test + false_negatives_test) / testsize)))
  t_stats_fh:close()
  -- end of report
end