1--[[
2Copyright (c) 2019, Vsevolod Stakhov <vsevolod@highsecure.ru>
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15]]--
16
17--[[[
18-- @module lua_magic/heuristics
19-- This module contains heuristics for some specific cases
20--]]
21
22local rspamd_trie = require "rspamd_trie"
23local rspamd_util = require "rspamd_util"
24local lua_util = require "lua_util"
25local bit = require "bit"
26local fun = require "fun"
27
-- Module name, used as the tag for debug logging
local N = "lua_magic"
-- Trie over OLE directory stream names; built by compile_tries()
local msoffice_trie
-- OLE directory stream names that identify a format (extension -> names).
-- compile_tries() turns every name into a regexp anchored at the entry start
-- with a \x{00} byte after each character (UTF-16LE form of ASCII names)
local msoffice_patterns = {
  doc = {[[WordDocument]]},
  xls = {[[Workbook]], [[Book]]},
  ppt = {[[PowerPoint Document]], [[Current User]]},
  vsd = {[[VisioDocument]]},
}
-- Trie over root storage CLSIDs; built by compile_tries()
local msoffice_trie_clsid
-- Root storage CLSIDs (extension -> list), written as hex strings of the 16
-- raw bytes as they appear in the directory entry; compile_tries() converts
-- every hex pair into a \x{..} escape of an exact-match regexp
local msoffice_clsids = {
  doc = {[[0609020000000000c000000000000046]]},
  xls = {[[1008020000000000c000000000000046]], [[2008020000000000c000000000000046]]},
  ppt = {[[108d81649b4fcf1186ea00aa00b929e8]]},
  msg = {[[46f0060000000000c000000000000046]], [[0b0d020000000000c000000000000046]]},
  msi = {[[84100c0000000000c000000000000046]]},
}
-- Trie over markers found at the start of zip parts; built by compile_tries()
local zip_trie
-- Regexps matched against the first 128 bytes of a zip part (extension -> list).
-- Most of them expect the `mimetype` member name to be directly followed by
-- its content in the raw archive bytes
local zip_patterns = {
  -- https://lists.oasis-open.org/archives/office/200505/msg00006.html
  odt = {
    [[mimetypeapplication/vnd\.oasis\.opendocument\.text]],
    [[mimetypeapplication/vnd\.oasis\.opendocument\.image]],
    [[mimetypeapplication/vnd\.oasis\.opendocument\.graphic]]
  },
  ods = {
    [[mimetypeapplication/vnd\.oasis\.opendocument\.spreadsheet]],
    [[mimetypeapplication/vnd\.oasis\.opendocument\.formula]],
    [[mimetypeapplication/vnd\.oasis\.opendocument\.chart]]
  },
  odp = {[[mimetypeapplication/vnd\.oasis\.opendocument\.presentation]]},
  epub = {[[epub\+zip]]},
  asice = {[[mimetypeapplication/vnd\.etsi\.asic-e\+zipPK]]},
  asics = {[[mimetypeapplication/vnd\.etsi\.asic-s\+zipPK]]},
}

-- Trie over text patterns; built by compile_tries()
local txt_trie
-- Regexps matched against the first 4096 bytes of text parts.
-- Every entry is {regexp, weight}; text_part_heuristic adds `weight` to the
-- extension's score for each occurrence of the regexp
local txt_patterns = {
  html = {
    {[[(?i)<html\b]], 32},
    {[[(?i)<script\b]], 20}, -- Commonly used by spammers
    {[[<script\s+type="text\/javascript">]], 31}, -- Another spammy pattern
    {[[(?i)<\!DOCTYPE HTML\b]], 33},
    {[[(?i)<body\b]], 20},
    {[[(?i)<table\b]], 20},
    {[[(?i)<a\b]], 10},
    {[[(?i)<p\b]], 10},
    {[[(?i)<div\b]], 10},
    {[[(?i)<span\b]], 10},
  },
  csv = {
    {[[(?:[-a-zA-Z0-9_]+\s*,){2,}(?:[-a-zA-Z0-9_]+,?[ ]*[\r\n])]], 20}
  },
  ics = {
    {[[^BEGIN:VCALENDAR\r?\n]], 40},
  },
  vcf = {
    {[[^BEGIN:VCARD\r?\n]], 40},
  },
  xml = {
    {[[<\?xml\b.+\?>]], 31},
  }
}

-- Used to match pattern index and extension: compile_tries() fills each array
-- with {ext, pattern} in the same order the regexps are passed to the trie,
-- so a trie match index `n` maps back to `indexes[n]`
local msoffice_clsid_indexes = {}
local msoffice_patterns_indexes = {}
local zip_patterns_indexes = {}
local txt_patterns_indexes = {}

-- Module table returned to callers
local exports = {}
98
--- Compiles all pattern tries used by this module (only once per process).
-- Every pattern group maps an extension to a list of patterns. While a group
-- is compiled, its `*_indexes` array receives `{ext, pattern}` in the same
-- order as the regexps handed to rspamd_trie, so a trie match index can be
-- mapped back to the extension it detects.
local function compile_tries()
  local trie_flags = rspamd_trie.flags
  local default_flags = bit.bor(trie_flags.re, trie_flags.dot_all,
      trie_flags.single_match, trie_flags.no_start)

  -- Flattens `groups` into a regexp list (converted by `to_regexp`), records
  -- the origin of every regexp in `index_map` and compiles the trie
  local function build_trie(groups, index_map, to_regexp, flags)
    local regexps = {}

    for ext, group in pairs(groups) do
      for _, pattern in ipairs(group) do
        regexps[#regexps + 1] = to_regexp(pattern)
        index_map[#index_map + 1] = {ext, pattern}
      end
    end

    return rspamd_trie.create(regexps, flags or default_flags)
  end

  -- Tries are module state: build them only on the first call
  if msoffice_trie then
    return
  end

  -- Directory entry names: every character is followed by a zero byte
  -- (UTF-16LE for ASCII names) and the regexp is anchored at the entry start
  local function ole_name_regexp(name)
    local wide = fun.iter(name):map(function(ch)
      return ch .. [[\x{00}]]
    end):totable()

    return '^' .. table.concat(wide)
  end

  -- Clsids: turns a hex string into an exact-match regexp of \x{..} escapes
  local function clsid_regexp(hex)
    local escaped = {}
    local pos = 1

    while pos <= #hex do
      escaped[#escaped + 1] = string.format('\\x{%s}', hex:sub(pos, pos + 1))
      pos = pos + 2
    end

    return '^' .. table.concat(escaped) .. '$'
  end

  msoffice_trie = build_trie(msoffice_patterns, msoffice_patterns_indexes,
      ole_name_regexp)
  msoffice_trie_clsid = build_trie(msoffice_clsids, msoffice_clsid_indexes,
      clsid_regexp)
  -- Zip markers are regexps already
  zip_trie = build_trie(zip_patterns, zip_patterns_indexes,
      function(pattern) return pattern end)
  -- Text patterns are {regexp, weight} pairs. They are compiled without
  -- single_match, presumably so that every occurrence is reported (callers
  -- weight scores by the number of positions)
  txt_trie = build_trie(txt_patterns, txt_patterns_indexes,
      function(entry) return entry[1] end,
      bit.bor(trie_flags.re, trie_flags.dot_all, trie_flags.no_start))
end

-- Call immediately on require
compile_tries()
155
--- OLE2 (compound file) heuristic: detects the MS Office format stored in an
-- OLE container (doc, xls, ppt, vsd, msg, msi).
-- Directory entries (128 bytes each) are scanned sequentially from the first
-- directory sector until a root storage with a known CLSID or a stream with a
-- known name is found.
-- @param input rspamd_text with the part content
-- @param log_obj logging object for debug output
-- @param _ unused
-- @param part unused
-- @return extension and confidence 60, or nil
local function detect_ole_format(input, log_obj, _, part)
  local inplen = #input
  if inplen < 0x31 + 4 then
    lua_util.debugm(N, log_obj, "short length: %s", inplen)
    return nil
  end

  -- Header: byte order mark at 0x1C and sector shift at 0x1E (1-based 29, 31)
  local bom, sec_shift = rspamd_util.unpack('<I2<I2', input:span(29, 4))
  if bom == 0xFFFE then
    bom = '<'
  elseif bom == 0xFEFF then
    lua_util.debugm(N, log_obj, "bom file!: %s", bom)
    bom = '>'
    -- Re-read the 16 bit sector shift in big endian order: bit.bswap works on
    -- 32 bit values and would move it far outside the accepted range
    sec_shift = rspamd_util.unpack('>I2', input:span(31, 2))
  else
    lua_util.debugm(N, log_obj, "bad bom: %s", bom)
    return nil
  end

  if sec_shift < 7 or sec_shift > 31 then
    lua_util.debugm(N, log_obj, "bad sec_size: %s", sec_shift)
    return nil
  end

  local sec_size = 2 ^ sec_shift

  -- SecID of first sector of the directory stream (header offset 0x30)
  local dir_secid = rspamd_util.unpack(bom .. 'I4', input:span(0x31, 4))
  -- Sectors follow the header, which is 512 bytes but fills a whole sector
  -- when sectors are larger (e.g. 4096 byte sectors); +1 makes it 1-based
  local directory_offset = math.max(512, sec_size) + dir_secid * sec_size + 1
  lua_util.debugm(N, log_obj, "directory: %s", directory_offset)

  if inplen < directory_offset then
    lua_util.debugm(N, log_obj, "short length: %s", inplen)
    return nil
  end

  -- Inspects the directory entry that starts at 1-based `offset`.
  -- Returns `true, ext` when a format is recognised, `true, nil` to continue
  -- with the next entry and `false, nil` to stop scanning.
  local function process_dir_entry(offset)
    -- Object type byte is at entry offset 66
    local dtype = input:byte(offset + 66)
    lua_util.debugm(N, log_obj, "dtype: %s, offset: %s", dtype, offset)

    if not dtype then
      -- Entry is cut by the end of input
      return false, nil
    end

    if dtype == 5 then
      -- Extract clsid (16 bytes at entry offset 80) unless the entry is
      -- truncated, to avoid reading past the end of input
      if offset + 80 + 16 - 1 <= inplen then
        local matches = msoffice_trie_clsid:match(input:span(offset + 80, 16))
        if matches then
          for n, _ in pairs(matches) do
            if msoffice_clsid_indexes[n] then
              lua_util.debugm(N, log_obj, "found valid clsid for %s",
                  msoffice_clsid_indexes[n][1])
              return true, msoffice_clsid_indexes[n][1]
            end
          end
        end
      end
      return true, nil
    elseif dtype == 2 then
      -- Match the entry name (first 64 bytes of the entry)
      local matches = msoffice_trie:match(input:span(offset, 64))
      if matches then
        for n, _ in pairs(matches) do
          if msoffice_patterns_indexes[n] then
            return true, msoffice_patterns_indexes[n][1]
          end
        end
      end
      return true, nil
    elseif dtype >= 0 and dtype < 5 then
      -- Other entry types: nothing to check, continue with the next entry
      return true, nil
    end

    -- Unexpected type value: stop scanning
    return false, nil
  end

  repeat
    local res, ext = process_dir_entry(directory_offset)

    if res and ext then
      return ext, 60
    end

    if not res then
      break
    end

    directory_offset = directory_offset + 128
  until directory_offset >= inplen

  return nil
end

exports.ole_format_heuristic = detect_ole_format
241
--- Picks the extension with the highest accumulated confidence.
-- Ties are broken by choosing the lexicographically smallest extension, so the
-- result does not depend on hash table traversal order.
-- @param res table mapping extension (string) -> confidence (number)
-- @return extension, confidence of the best candidate; nil if `res` is empty
local function process_top_detected(res)
  local best_ext, best_weight

  for ext, weight in pairs(res) do
    if best_ext == nil or weight > best_weight or
        (weight == best_weight and ext < best_ext) then
      best_ext, best_weight = ext, weight
    end
  end

  if best_ext then
    return best_ext, best_weight
  end

  return nil
end
255
--- Refines the type of an archive part by inspecting its contents.
-- For zip archives, up to 100 member names are scored per candidate extension
-- (OOXML parts, Java/Android manifests). When no candidate reaches a
-- confidence of 40, the first 128 bytes of the raw part are matched against
-- `zip_trie` (ODF, EPUB and ASiC markers).
-- @param part mime part containing the archive
-- @param arch archive object of the part
-- @param log_obj logging object for debug output
-- @return extension and confidence; defaults to the lowercased archive type
local function detect_archive_flaw(part, arch, log_obj, _)
  local arch_type = arch:get_type()

  if arch_type == 'zip' then
    -- ext -> confidence
    local scores = {
      docx = 0,
      xlsx = 0,
      pptx = 0,
      jar = 0,
      odt = 0,
      odp = 0,
      ods = 0,
      apk = 0,
    }
    -- Members whose exact name points to a single extension: {ext, increment}
    local named_members = {
      ['META-INF/MANIFEST.MF'] = {'jar', 40},
      ['AndroidManifest.xml'] = {'apk', 60},
    }
    -- Member path prefixes of the OOXML flavours: {prefix, ext}
    local ooxml_dirs = {
      {'xl/', 'xlsx'},
      {'word/', 'docx'},
      {'ppt/', 'pptx'},
    }

    for _, member in ipairs(arch:get_files(100) or {}) do
      if member == '[Content_Types].xml' then
        -- General msoffice marker: raises every OOXML candidate
        for _, ooxml_ext in ipairs({'docx', 'xlsx', 'pptx'}) do
          scores[ooxml_ext] = scores[ooxml_ext] + 10
        end
      elseif named_members[member] then
        local rule = named_members[member]
        scores[rule[1]] = scores[rule[1]] + rule[2]
      else
        for _, dir in ipairs(ooxml_dirs) do
          if member:sub(1, #dir[1]) == dir[1] then
            scores[dir[2]] = scores[dir[2]] + 30
            break
          end
        end
      end
    end

    local best_ext, best_weight = process_top_detected(scores)
    if best_weight >= 40 then
      return best_ext, best_weight
    end

    -- No decisive member names: look for markers at the start of the raw zip
    local content = part:get_content()
    if #content > 128 then
      local matches = zip_trie:match(content:span(1, 128))

      for n in pairs(matches or {}) do
        local found = zip_patterns_indexes[n]
        if found then
          lua_util.debugm(N, log_obj, "found zip pattern for %s", found[1])
          return found[1], 40
        end
      end
    end
  end

  return arch_type:lower(), 40
end
322
-- Lazily built LPeg grammar shared by all csv validations
local csv_grammar

--- Returns an LPeg grammar that matches one CSV line.
-- A field is either a double-quoted string (`""` escapes a quote) or a run of
-- characters other than `,`, `\n` and `"`. Fields are separated by `,` or
-- `\t`; at least two fields are required, and the line must be followed by
-- `\r`, `\n` or the end of the subject. The fold starts at 0 and adds 1 per
-- field capture, so a successful match returns the number of fields.
local function get_csv_grammar()
  if csv_grammar then
    return csv_grammar
  end

  local lpeg = require'lpeg'
  local P, S, C, Cs, Cc, Cf = lpeg.P, lpeg.S, lpeg.C, lpeg.Cs, lpeg.Cc, lpeg.Cf

  local quoted = P'"' * Cs(((P(1) - '"') + P'""' / '"')^0) * '"'
  local bare = C((1 - S',\n"')^0)
  local field = quoted + bare
  local separator = P(',') + P('\t')
  local line_end = S'\r\n' + -1

  csv_grammar = Cf(Cc(0) * field * P(separator * field)^1 * line_end,
      function(count) return count + 1 end)

  return csv_grammar
end
--- Checks that `content` looks like a consistent CSV table.
-- At most `max_matched_lines` lines from the first 32 KiB are parsed with the
-- grammar from get_csv_grammar(); every line must parse and yield the same
-- count as the first line.
-- @param part unused, kept for interface compatibility
-- @param content rspamd_text with the part content
-- @param log_obj logging object for debug output
-- @return true if all checked lines are consistent, false otherwise
local function validate_csv(part, content, log_obj)
  local max_chunk = 32768
  local chunk = content:sub(1, max_chunk)
  -- The grammar is the same for every line: fetch it once
  local grammar = get_csv_grammar()

  local expected_commas
  local matched_lines = 0
  local max_matched_lines = 10

  lua_util.debugm(N, log_obj, "check for csv pattern")

  for s in chunk:lines() do
    -- 1-based number of the line being checked (used in debug output)
    local line_no = matched_lines + 1
    local ncommas = grammar:match(s)

    if not ncommas then
      lua_util.debugm(N, log_obj, "not a csv line at line number %s",
          line_no)
      return false
    end

    if expected_commas and ncommas ~= expected_commas then
      -- Mismatched commas
      lua_util.debugm(N, log_obj, "missmatched commas on line %s: %s != %s",
          line_no, ncommas, expected_commas)
      return false
    elseif not expected_commas then
      if ncommas == 0 then
        lua_util.debugm(N, log_obj, "no commas in the first line")
        return false
      end
      expected_commas = ncommas
    end

    matched_lines = line_no

    -- Stop once exactly max_matched_lines lines have been validated
    if matched_lines >= max_matched_lines then
      break
    end
  end

  lua_util.debugm(N, log_obj, "csv content is sane: %s fields; %s lines checked",
      expected_commas, matched_lines)

  return true
end
383
--- Mime part heuristic: refines the detected type of archive parts via
-- detect_archive_flaw(); other parts are left alone (nil is returned).
exports.mime_part_heuristic = function(part, log_obj, _)
  if not part:is_archive() then
    return nil
  end

  return detect_archive_flaw(part, part:get_archive(), log_obj)
end
392
--- Examines a run of 8-bit bytes that starts at `start` (bytes[start] >= 127).
-- The run is accepted as (probably localized) text if it contains a
-- well-formed UTF-8 multi-byte sequence or consists of at least 3
-- consecutive 8-bit bytes.
-- @param bytes array of byte values
-- @param start index of the first 8-bit byte of the run
-- @param len number of bytes in `bytes`
-- @return true, nskip when accepted; `nskip` is the number of bytes after
--   `start` that were consumed, so scanning resumes at start + nskip + 1
-- @return false, 0 otherwise
local function rough_8bit_check(bytes, start, len)
  local idx = start
  local b = bytes[idx]
  local n8bit = 0

  while b and b >= 127 do
    local remain = len - idx -- bytes available after idx
    local tail -- continuation bytes of a well-formed UTF-8 sequence at idx

    if bit.band(b, 0xe0) == 0xc0 and remain >= 1 and
        bit.band(bytes[idx + 1], 0xc0) == 0x80 then
      tail = 1
    elseif bit.band(b, 0xf0) == 0xe0 and remain >= 2 and
        bit.band(bytes[idx + 1], 0xc0) == 0x80 and
        bit.band(bytes[idx + 2], 0xc0) == 0x80 then
      tail = 2
    elseif bit.band(b, 0xf8) == 0xf0 and remain >= 3 and
        bit.band(bytes[idx + 1], 0xc0) == 0x80 and
        bit.band(bytes[idx + 2], 0xc0) == 0x80 and
        bit.band(bytes[idx + 3], 0xc0) == 0x80 then
      tail = 3
    end

    if tail then
      -- Skip the 8-bit bytes before the sequence plus the sequence itself;
      -- the count is relative to `start`, which is what the caller advances
      return true, (idx - start) + tail
    end

    n8bit = n8bit + 1
    idx = idx + 1
    b = bytes[idx]
  end

  if n8bit >= 3 then
    -- Consume exactly the run; the byte that ended it is left to the caller
    return true, n8bit - 1
  end

  return false, 0
end

--- Decides whether a span of data looks like text.
-- Control characters other than CR, LF and TAB, as well as 8-bit bytes that
-- rough_8bit_check() rejects, are counted as non-printable; the span is text
-- when they make up at most 1/128 of its length.
-- @param span rspamd_text (anything providing `#` and `:bytes()`)
-- @param log_obj logging object for debug output
-- @return boolean
local function is_span_text(span, log_obj)
  local tlen = #span
  if tlen == 0 then
    return false
  end

  -- Convert the span into a Lua array of byte values once, so the loop
  -- below works on table lookups instead of per-byte C calls
  local bytes = span:bytes()
  local non_printable = 0
  local i = 1

  while i <= tlen do
    local b = bytes[i]

    if b < 0x20 and b ~= 0x0d and b ~= 0x0a and b ~= 0x09 then
      non_printable = non_printable + 1
    elseif b >= 127 then
      local ok, nskip = rough_8bit_check(bytes, i, tlen)

      if ok then
        i = i + nskip
      else
        non_printable = non_printable + 1
      end
    end

    i = i + 1
  end

  lua_util.debugm(N, log_obj, "text part check: %s printable, %s non-printable, %s total",
      tlen - non_printable, non_printable, tlen)

  return non_printable / tlen <= 0.0078125
end

--- Returns true when `file` ends with '.' .. `ext` (case-insensitive; `ext`
-- must be lowercase) and has at least one character before the dot.
local function has_extension(file, ext)
  local ext_len = ext:len()
  return file:len() > ext_len + 1
      and file:sub(-ext_len):lower() == ext
      and file:sub(-ext_len - 1, -ext_len - 1) == '.'
end

--- Text part heuristic: decides whether a part is text and, if so, which
-- text-based format it most likely is.
-- @param part mime part
-- @param log_obj logging object for debug output
-- @return extension and confidence; false for children of multipart/encrypted;
--   nothing (or nil) when the part does not look like a text format
exports.text_part_heuristic = function(part, log_obj, _)
  local parent = part:get_parent()

  if parent then
    local parent_type, parent_subtype = parent:get_type()

    if parent_type == 'multipart' and parent_subtype == 'encrypted' then
      -- Skip text heuristics for encrypted parts
      lua_util.debugm(N, log_obj, "text part check: parent is encrypted, not a text part")

      return false
    end
  end

  local content = part:get_content()
  local mtype, msubtype = part:get_type()
  local clen = #content

  if clen == 0 then
    return
  end

  local is_text
  if clen > 80 * 3 then
    -- Use chunks: the first 160 bytes and the last 80 bytes (ending at clen)
    is_text = is_span_text(content:span(1, 160), log_obj) and
        is_span_text(content:span(clen - 79, 80), log_obj)
  else
    is_text = is_span_text(content, log_obj)
  end

  if not is_text or mtype == 'message' then
    return
  end

  -- Try patterns on the first 4096 bytes
  local span_len = math.min(4096, clen)
  local start_span = content:span(1, span_len)
  local matches = txt_trie:match(start_span)
  local res = {}
  local fname = part:get_filename()

  if matches then
    -- Every occurrence of a pattern adds its weight to the extension score
    for n, positions in pairs(matches) do
      local pat = txt_patterns_indexes[n]
      local ext = pat and pat[1]
      if ext then
        local weight = pat[2][2]
        res[ext] = (res[ext] or 0) + weight * #positions
        lua_util.debugm(N, log_obj, "found txt pattern for %s: %s, total: %s; %s/%s announced",
            ext, weight * #positions, res[ext], mtype, msubtype)
      end
    end

    if res.html and res.html >= 40 then
      -- HTML has priority over something like js...
      return 'html', res.html
    end

    local ext, weight = process_top_detected(res)

    if weight then
      if weight >= 40 then
        -- Extra validation for csv extension
        if ext ~= 'csv' or validate_csv(part, content, log_obj) then
          return ext, weight
        end
      elseif fname and weight >= 20 then
        return ext, weight
      end
    end
  end

  -- Content type stuff
  if (mtype == 'text' or mtype == 'application') and
      (msubtype == 'html' or msubtype == 'xhtml+xml') then
    return 'html', 21
  end

  -- Extension stuff
  if fname and (has_extension(fname, 'htm') or has_extension(fname, 'html')) then
    return 'html', 21
  end

  if mtype ~= 'text' then
    -- Do not treat non text patterns as text
    return nil
  end

  return 'txt', 40
end
553
--- PDF heuristic: base confidence 10, plus 30 when the pattern was found
-- within the first 10 positions and plus 30 when the announced filename has
-- a `.pdf` extension (case-insensitive).
-- @param input unused
-- @param log_obj unused
-- @param pos position where the pdf pattern was matched
-- @param part mime part (its filename is inspected)
-- @return 'pdf' and the computed confidence
exports.pdf_format_heuristic = function(input, log_obj, pos, part)
  local announced_ext = string.match(part:get_filename() or '', '%.([^.]+)$')

  local near_start_bonus = (pos <= 10) and 30 or 0
  local name_bonus = 0
  if announced_ext and announced_ext:lower() == 'pdf' then
    name_bonus = 30
  end

  return 'pdf', 10 + near_start_bonus + name_bonus
end
568
--- PE heuristic: confirms a Windows PE executable.
-- The MS-DOS header stores the offset of the PE header as a little-endian
-- 32-bit integer (`e_lfanew`) at zero-based offset 0x3C; when it equals the
-- position reported for the PE magic match, the part is treated as an exe.
-- NOTE(review): assumes `pos` uses the same zero-based offset convention as
-- e_lfanew — confirm against the lua_magic pattern definitions.
-- @param input part content (string or rspamd_text)
-- @param log_obj unused
-- @param pos position reported for the PE magic match
-- @param part unused
-- @return 'exe', 30 or nothing
exports.pe_part_heuristic = function(input, log_obj, pos, part)
  if not input then
    return
  end

  -- e_lfanew occupies zero-based bytes 60..63, i.e. 1-based 61..64;
  -- sub() is inclusive on both ends, so (60, 64) would return 5 bytes
  local e_lfanew_offset = 0x3C
  local pe_ptr_bin = input:sub(e_lfanew_offset + 1, e_lfanew_offset + 4)
  if #pe_ptr_bin ~= 4 then
    return
  end

  -- it is an LE 32 bit integer
  local pe_ptr = rspamd_util.unpack("<I4", pe_ptr_bin)
  -- if pe header magic matches the offset, it is definitely a PE file
  if pe_ptr ~= pos then
    return
  end

  return 'exe', 30
end

return exports
591