1--[[
2Copyright (c) 2019, Vsevolod Stakhov <vsevolod@highsecure.ru>
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15]]--
16
17local fun = require 'fun'
18local meta_functions = require "lua_meta"
19local lua_util = require "lua_util"
20local rspamd_url = require "rspamd_url"
21local common = require "lua_selectors/common"
22local ts = require("tableshape").types
-- Shared empty table used as a nil-safe fallback in `x or E` expressions
local E = {}

-- Optional array schema whose elements must be one of the keys of `rspamd_url.flags`
local url_flags_ts = ts.array_of(ts.one_of(lua_util.keys(rspamd_url.flags))):is_optional()
26
--- Build a predicate that rejects urls carrying any of the listed flags.
-- @param exclude_flags array of flag names looked up in the table returned by `u:get_flags()`
-- @return function(u) returning `false` when any listed flag is set on `u`, `true` otherwise
local function gen_exclude_flags_filter(exclude_flags)
  -- True when at least one excluded flag is present in `url_flags`
  local function has_excluded(url_flags)
    for _, name in ipairs(exclude_flags) do
      if url_flags[name] then
        return true
      end
    end
    return false
  end

  return function(u)
    return not has_excluded(u:get_flags())
  end
end
36
37local extractors = {
38  -- Plain id function
39  ['id'] = {
40    ['get_value'] = function(_, args)
41      if args[1] then
42        return args[1], 'string'
43      end
44
45      return '','string'
46    end,
47    ['description'] = [[Return value from function's argument or an empty string,
48For example, `id('Something')` returns a string 'Something']],
49    ['args_schema'] = {ts.string:is_optional()}
50  },
51  -- Similar but for making lists
52  ['list'] = {
53    ['get_value'] = function(_, args)
54      if args[1] then
55        return fun.map(tostring, args), 'string_list'
56      end
57
58      return {},'string_list'
59    end,
60    ['description'] = [[Return a list from function's arguments or an empty list,
61For example, `list('foo', 'bar')` returns a list {'foo', 'bar'}]],
62  },
63  -- Get source IP address
64  ['ip'] = {
65    ['get_value'] = function(task)
66      local ip = task:get_ip()
67      if ip and ip:is_valid() then return ip,'userdata' end
68      return nil
69    end,
70    ['description'] = [[Get source IP address]],
71  },
72  -- Get MIME from
73  ['from'] = {
74    ['get_value'] = function(task, args)
75      local from
76      if type(args) == 'table' then
77        from = task:get_from(args)
78      else
79        from = task:get_from(0)
80      end
81      if ((from or E)[1] or E).addr then
82        return from[1],'table'
83      end
84      return nil
85    end,
86    ['description'] = [[Get MIME or SMTP from (e.g. `from('smtp')` or `from('mime')`,
87uses any type by default)]],
88  },
89  ['rcpts'] = {
90    ['get_value'] = function(task, args)
91      local rcpts
92      if type(args) == 'table' then
93        rcpts = task:get_recipients(args)
94      else
95        rcpts = task:get_recipients(0)
96      end
97      if ((rcpts or E)[1] or E).addr then
98        return rcpts,'table_list'
99      end
100      return nil
101    end,
102    ['description'] = [[Get MIME or SMTP rcpts (e.g. `rcpts('smtp')` or `rcpts('mime')`,
103uses any type by default)]],
104  },
105  -- Get country (ASN module must be executed first)
106  ['country'] = {
107    ['get_value'] = function(task)
108      local country = task:get_mempool():get_variable('country')
109      if not country then
110        return nil
111      else
112        return country,'string'
113      end
114    end,
115    ['description'] = [[Get country (ASN module must be executed first)]],
116  },
117  -- Get ASN number
118  ['asn'] = {
119    ['type'] = 'string',
120    ['get_value'] = function(task)
121      local asn = task:get_mempool():get_variable('asn')
122      if not asn then
123        return nil
124      else
125        return asn,'string'
126      end
127    end,
128    ['description'] = [[Get AS number (ASN module must be executed first)]],
129  },
130  -- Get authenticated username
131  ['user'] = {
132    ['get_value'] = function(task)
133      local auser = task:get_user()
134      if not auser then
135        return nil
136      else
137        return auser,'string'
138      end
139    end,
140    ['description'] = 'Get authenticated user name',
141  },
142  -- Get principal recipient
143  ['to'] = {
144    ['get_value'] = function(task)
145      return task:get_principal_recipient(),'string'
146    end,
147    ['description'] = 'Get principal recipient',
148  },
149  -- Get content digest
150  ['digest'] = {
151    ['get_value'] = function(task)
152      return task:get_digest(),'string'
153    end,
154    ['description'] = 'Get content digest',
155  },
156  -- Get list of all attachments digests
157  ['attachments'] = {
158    ['get_value'] = function(task, args)
159      local parts = task:get_parts() or E
160      local digests = {}
161      for i,p in ipairs(parts) do
162        if p:is_attachment() then
163          table.insert(digests, common.get_cached_or_raw_digest(task, i, p, args))
164        end
165      end
166
167      if #digests > 0 then
168        return digests,'string_list'
169      end
170
171      return nil
172    end,
173    ['description'] = [[Get list of all attachments digests.
174The first optional argument is encoding (`hex`, `base32` (and forms `bleach32`, `rbase32`), `base64`),
175the second optional argument is optional hash type (`blake2`, `sha256`, `sha1`, `sha512`, `md5`)]],
176    ['args_schema'] = common.digest_schema()
177
178  },
179  -- Get all attachments files
180  ['files'] = {
181    ['get_value'] = function(task)
182      local parts = task:get_parts() or E
183      local files = {}
184
185      for _,p in ipairs(parts) do
186        local fname = p:get_filename()
187        if fname then
188          table.insert(files, fname)
189        end
190      end
191
192      if #files > 0 then
193        return files,'string_list'
194      end
195
196      return nil
197    end,
198    ['description'] = 'Get all attachments files',
199  },
200  -- Get languages for text parts
201  ['languages'] = {
202    ['get_value'] = function(task)
203      local text_parts = task:get_text_parts() or E
204      local languages = {}
205
206      for _,p in ipairs(text_parts) do
207        local lang = p:get_language()
208        if lang then
209          table.insert(languages, lang)
210        end
211      end
212
213      if #languages > 0 then
214        return languages,'string_list'
215      end
216
217      return nil
218    end,
219    ['description'] = 'Get languages for text parts',
220  },
221  -- Get helo value
222  ['helo'] = {
223    ['get_value'] = function(task)
224      return task:get_helo(),'string'
225    end,
226    ['description'] = 'Get helo value',
227  },
228  -- Get header with the name that is expected as an argument. Returns list of
229  -- headers with this name
230  ['header'] = {
231    ['get_value'] = function(task, args)
232      local strong = false
233      if args[2] then
234        if args[2]:match('strong') then
235          strong = true
236        end
237
238        if args[2]:match('full') then
239          return task:get_header_full(args[1], strong),'table_list'
240        end
241
242        return task:get_header(args[1], strong),'string'
243      else
244        return task:get_header(args[1]),'string'
245      end
246    end,
247    ['description'] = [[Get header with the name that is expected as an argument.
248The optional second argument accepts list of flags:
249  - `full`: returns all headers with this name with all data (like task:get_header_full())
250  - `strong`: use case sensitive match when matching header's name]],
251    ['args_schema'] = {ts.string,
252                       (ts.pattern("strong") + ts.pattern("full")):is_optional()}
253  },
254  -- Get list of received headers (returns list of tables)
255  ['received'] = {
256    ['get_value'] = function(task, args)
257      local rh = task:get_received_headers()
258      if not rh[1] then
259        return nil
260      end
261      if args[1] then
262        return fun.map(function(r) return r[args[1]] end, rh), 'string_list'
263      end
264
265      return rh,'table_list'
266    end,
267    ['description'] = [[Get list of received headers.
268If no arguments specified, returns list of tables. Otherwise, selects a specific element,
269e.g. `by_hostname`]],
270  },
271  -- Get all urls
272  ['urls'] = {
273    ['get_value'] = function(task, args)
274      local urls = task:get_urls()
275      if not urls[1] then
276        return nil
277      end
278      if args[1] then
279        return fun.map(function(r) return r[args[1]](r) end, urls), 'string_list'
280      end
281      return urls,'userdata_list'
282    end,
283    ['description'] = [[Get list of all urls.
284If no arguments specified, returns list of url objects. Otherwise, calls a specific method,
285e.g. `get_tld`]],
286  },
287  -- Get specific urls
288  ['specific_urls'] = {
289    ['get_value'] = function(task, args)
290      local params = args[1] or {}
291      params.task = task
292      params.no_cache = true
293      if params.exclude_flags then
294        params.filter = gen_exclude_flags_filter(params.exclude_flags)
295      end
296      local urls = lua_util.extract_specific_urls(params)
297      if not urls[1] then
298        return nil
299      end
300      return urls,'userdata_list'
301    end,
302    ['description'] = [[Get most specific urls. Arguments are equal to the Lua API function]],
303    ['args_schema'] = {ts.shape{
304      limit = ts.number + ts.string / tonumber,
305      esld_limit = (ts.number + ts.string / tonumber):is_optional(),
306      exclude_flags = url_flags_ts,
307      flags = url_flags_ts,
308      flags_mode = ts.one_of{'explicit'}:is_optional(),
309      prefix = ts.string:is_optional(),
310      need_content = (ts.boolean + ts.string / lua_util.toboolean):is_optional(),
311      need_emails = (ts.boolean + ts.string / lua_util.toboolean):is_optional(),
312      need_images = (ts.boolean + ts.string / lua_util.toboolean):is_optional(),
313      ignore_redirected = (ts.boolean + ts.string / lua_util.toboolean):is_optional(),
314    }}
315  },
316  -- URLs filtered by flags
317  ['urls_filtered'] = {
318    ['get_value'] = function(task, args)
319      local urls = task:get_urls_filtered(args[1], args[2])
320      if not urls[1] then
321        return nil
322      end
323      return urls,'userdata_list'
324    end,
325    ['description'] = [[Get list of all urls filtered by flags_include/exclude
326(see rspamd_task:get_urls_filtered for description)]],
327    ['args_schema'] = {ts.array_of{
328      url_flags_ts:is_optional(), url_flags_ts:is_optional()
329    }}
330  },
331  -- Get all emails
332  ['emails'] = {
333    ['get_value'] = function(task, args)
334      local urls = task:get_emails()
335      if not urls[1] then
336        return nil
337      end
338      if args[1] then
339        return fun.map(function(r) return r[args[1]](r) end, urls), 'string_list'
340      end
341      return urls,'userdata_list'
342    end,
343    ['description'] = [[Get list of all emails.
344If no arguments specified, returns list of url objects. Otherwise, calls a specific method,
345e.g. `get_user`]],
346  },
347  -- Get specific pool var. The first argument must be variable name,
348  -- the second argument is optional and defines the type (string by default)
349  ['pool_var'] = {
350    ['get_value'] = function(task, args)
351      local type = args[2] or 'string'
352      return task:get_mempool():get_variable(args[1], type),(type)
353    end,
354    ['description'] = [[Get specific pool var. The first argument must be variable name,
355the second argument is optional and defines the type (string by default)]],
356    ['args_schema'] = {ts.string, ts.string:is_optional()}
357  },
358  -- Get value of specific key from task cache
359  ['task_cache'] = {
360    ['get_value'] = function(task, args)
361      local val = task:cache_get(args[1])
362      if not val then
363        return
364      end
365      if type(val) == 'table' then
366        if not val[1] then
367          return
368        end
369        return val, 'string_list'
370      end
371      return val, 'string'
372    end,
373    ['description'] = [[Get value of specific key from task cache. The first argument must be
374the key name]],
375    ['args_schema'] = {ts.string}
376  },
377  -- Get specific HTTP request header. The first argument must be header name.
378  ['request_header'] = {
379    ['get_value'] = function(task, args)
380      local hdr = task:get_request_header(args[1])
381      if hdr then
382        return hdr,'string'
383      end
384
385      return nil
386    end,
387    ['description'] = [[Get specific HTTP request header.
388The first argument must be header name.]],
389    ['args_schema'] = {ts.string}
390  },
391  -- Get task date, optionally formatted
392  ['time'] = {
393    ['get_value'] = function(task, args)
394      local what = args[1] or 'message'
395      local dt = task:get_date{format = what, gmt = true}
396
397      if dt then
398        if args[2] then
399          -- Should be in format !xxx, as dt is in GMT
400          return os.date(args[2], dt),'string'
401        end
402
403        return tostring(dt),'string'
404      end
405
406      return nil
407    end,
408    ['description'] = [[Get task timestamp. The first argument is type:
409  - `connect`: connection timestamp (default)
410  - `message`: timestamp as defined by `Date` header
411
412  The second argument is optional time format, see [os.date](http://pgl.yoyo.org/luai/i/os.date) description]],
413    ['args_schema'] = {ts.one_of{'connect', 'message'}:is_optional(),
414                       ts.string:is_optional()}
415  },
416  -- Get text words from a message
417  ['words'] = {
418    ['get_value'] = function(task, args)
419      local how = args[1] or 'stem'
420      local tp = task:get_text_parts()
421
422      if tp then
423        local rtype = 'string_list'
424        if how == 'full' then
425          rtype = 'table_list'
426        end
427
428        return lua_util.flatten(
429            fun.map(function(p)
430              return p:get_words(how)
431            end, tp)), rtype
432      end
433
434      return nil
435    end,
436    ['description'] = [[Get words from text parts
437  - `stem`: stemmed words (default)
438  - `raw`: raw words
439  - `norm`: normalised words (lowercased)
440  - `full`: list of tables
441  ]],
442    ['args_schema'] = { ts.one_of { 'stem', 'raw', 'norm', 'full' }:is_optional()},
443  },
444  -- Get queue ID
445  ['queueid'] = {
446    ['get_value'] = function(task)
447      local queueid = task:get_queue_id()
448      if queueid then return queueid,'string' end
449      return nil
450    end,
451    ['description'] = [[Get queue ID]],
452  },
453  -- Get ID of the task being processed
454  ['uid'] = {
455    ['get_value'] = function(task)
456      local uid = task:get_uid()
457      if uid then return uid,'string' end
458      return nil
459    end,
460    ['description'] = [[Get ID of the task being processed]],
461  },
462  -- Get message ID of the task being processed
463  ['messageid'] = {
464    ['get_value'] = function(task)
465      local mid = task:get_message_id()
466      if mid then return mid,'string' end
467      return nil
468    end,
469    ['description'] = [[Get message ID]],
470  },
471  -- Get specific symbol
472  ['symbol'] = {
473    ['get_value'] = function(task, args)
474      local symbol = task:get_symbol(args[1], args[2])
475      if symbol then
476        return symbol[1],'table'
477      end
478    end,
479    ['description'] = 'Get specific symbol. The first argument must be the symbol name. ' ..
480      'The second argument is an optional shadow result name. ' ..
481      'Returns the symbol table. See task:get_symbol()',
482    ['args_schema'] = {ts.string, ts.string:is_optional()}
483  },
484  -- Get full scan result
485  ['scan_result'] = {
486    ['get_value'] = function(task, args)
487      local res = task:get_metric_result(args[1])
488      if res then
489        return res,'table'
490      end
491    end,
492    ['description'] = 'Get full scan result (either default or shadow if shadow result name is specified)' ..
493        'Returns the result table. See task:get_metric_result()',
494    ['args_schema'] = {ts.string:is_optional()}
495  },
496  -- Get list of metatokens as strings
497  ['metatokens'] = {
498    ['get_value'] = function(task)
499      local tokens = meta_functions.gen_metatokens(task)
500      if not tokens[1] then
501        return nil
502      end
503      local res = {}
504      for _, t in ipairs(tokens) do
505        table.insert(res, tostring(t))
506      end
507      return res, 'string_list'
508    end,
509    ['description'] = 'Get metatokens for a message as strings',
510  },
511}
512
513return extractors
514