1--[[
2Copyright (c) 2017, Andrew Lewis <nerf@judo.za.org>
3Copyright (c) 2017, Vsevolod Stakhov <vsevolod@highsecure.ru>
4
5Licensed under the Apache License, Version 2.0 (the "License");
6you may not use this file except in compliance with the License.
7You may obtain a copy of the License at
8
9    http://www.apache.org/licenses/LICENSE-2.0
10
11Unless required by applicable law or agreed to in writing, software
12distributed under the License is distributed on an "AS IS" BASIS,
13WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14See the License for the specific language governing permissions and
15limitations under the License.
16]]--
17
18-- A plugin that forces actions
19
-- Stop here when the global `confighelp` is set (presumably rspamd's
-- config-help mode); none of the registration code below runs then.
if confighelp then return end

-- Shared empty table, used as a safe fallback when indexing optional config tables
local E = {}
-- Module name: config section name, debug-log tag and pre-result `module` field
local N = 'force_actions'
-- Selector closures memoized by selector expression (used for message placeholders)
local selector_cache = {}
27
28local fun = require "fun"
29local lua_util = require "lua_util"
30local rspamd_cryptobox_hash = require "rspamd_cryptobox_hash"
31local rspamd_expression = require "rspamd_expression"
32local rspamd_logger = require "rspamd_logger"
33local lua_selectors = require "lua_selectors"
34
-- Build a symbol callback that forces an action when `params.expr` matches.
--
-- Params table fields:
--   expr    - rspamd expression whose atoms are symbol names (required)
--   act     - action passed to task:set_pre_result
--   pool    - memory pool handed to rspamd_expression.create
--   message - optional message; `${selector}` placeholders are replaced with
--             values extracted via lua_selectors
--   subject - optional subject applied via task:set_metric_subject
--   raction - optional set (action -> true); the rule is skipped unless the
--             current action is in it
--   honor   - optional set (action -> true); the rule is skipped when the
--             current action is in it
--   limit   - optional threshold; the expression result must be strictly
--             greater than it (defaults to 0)
--   flags   - optional comma-separated flags string for set_pre_result
--
-- Returns the callback followed by the list of expression atoms, or nothing
-- when the expression cannot be compiled (the error is logged).
local function gen_cb(params)

  -- Leading run of non-delimiter characters. Inside a character class none
  -- of these delimiters acts as a pattern magic character.
  local atom_pattern = '^[^, \t()><+!|&\n]*'

  -- Return the atom (symbol name) at the start of `str`: everything up to
  -- the first delimiter, or '' when `str` starts with a delimiter.
  -- A single anchored match is used so characters of `str` are never
  -- interpreted as pattern syntax (probing each char with a non-plain
  -- string.find made '.', '^', '$' end atoms and '(', '[', '%' raise errors).
  local function parse_atom(str)
    return string.match(str, atom_pattern)
  end

  -- Atom value: the absolute score of the symbol when it is present,
  -- raised to at least 0.001 so a zero-scored hit is distinguishable from
  -- a miss; 0 when the symbol is absent.
  local function process_atom(atom, task)
    if not task:has_symbol(atom) then
      return 0
    end
    local score = math.abs(task:get_symbol(atom)[1].score)
    if score < 0.001 then
      -- Adjust some low score to distinguish from pure zero
      score = 0.001
    end
    return score
  end

  local e, err = rspamd_expression.create(params.expr, {parse_atom, process_atom}, params.pool)
  -- Guard on the expression itself too: a nil expression would otherwise
  -- fail later on e:atoms() / e:process()
  if not e or err then
    rspamd_logger.errx(rspamd_config, 'Couldnt create expression [%1]: %2', params.expr, err)
    return
  end

  return function(task)

    -- gsub callback for `${selector}` placeholders: returns the value
    -- extracted by the (cached) selector, tables joined with ',', or a
    -- marker string when the selector cannot be created or yields nothing.
    local function process_message_selectors(selector_expr)
      -- create/reuse selector to extract value for this placeholder
      local selector = selector_cache[selector_expr]
      if not selector then
        selector = lua_selectors.create_selector_closure(rspamd_config, selector_expr, '', true)
        selector_cache[selector_expr] = selector
        if not selector then
          rspamd_logger.errx(task, 'could not create selector [%1]', selector_expr)
          return "((could not create selector))"
        end
      end
      local extracted = selector(task)
      if extracted then
        if type(extracted) == 'table' then
          extracted = table.concat(extracted, ',')
        end
      else
        rspamd_logger.errx(task, 'could not extract value with selector [%1]', selector_expr)
        extracted = '((error extracting value))'
      end
      return extracted
    end

    local cact = task:get_metric_action('default')
    -- Nothing to change if the current action already equals the forced one
    -- and there is no message or subject to apply
    if not params.message and not params.subject and params.act and cact == params.act then
      return false
    end
    if params.honor and params.honor[cact] then
      return false
    elseif params.raction and not params.raction[cact] then
      return false
    end

    local ret = e:process(task)
    lua_util.debugm(N, task, "expression %s returned %s", params.expr, ret)
    -- Without an explicit limit any positive result triggers the rule
    if ret > (params.limit or 0) then
      if params.subject then
        task:set_metric_subject(params.subject)
      end

      local flags = params.flags or ""

      if type(params.message) == 'string' then
        -- process selector expressions in the message
        local message = string.gsub(params.message, '%$%{(.-)%}', process_message_selectors)
        task:set_pre_result{action = params.act, message = message, module = N, flags = flags}
      else
        task:set_pre_result{action = params.act, module = N, flags = flags}
      end
      return true, params.act
    end

  end, e:atoms()

end
125
-- Make symbol `name` depend on every expression atom and log the registration.
local function register_atom_dependencies(name, expr, atoms)
  for _, atom in ipairs(atoms) do
    rspamd_config:register_dependency(name, atom)
  end
  rspamd_logger.infox(rspamd_config, 'Registered symbol %1 <%2> with dependencies [%3]',
      name, expr, table.concat(atoms, ','))
end

-- Legacy `actions` layout: action -> list of expressions, where each entry is
-- either an expression string or a table {expr, message, subject}.
-- Symbol names are derived from a hash of the expression text.
local function register_legacy_rules(opts)
  rspamd_logger.warnx(rspamd_config, 'Processing legacy config')
  for action, expressions in pairs(opts.actions) do
    if type(expressions) == 'table' then
      for _, entry in ipairs(expressions) do
        local expr, message, subject = entry, nil, nil
        if type(entry) == 'table' then
          expr, message, subject = entry[1], entry[2], entry[3]
        else
          message = (opts.messages or E)[entry]
        end

        if type(expr) == 'string' then
          -- expr, act, pool, message, subject, raction, honor, limit, flags
          local cb, atoms = gen_cb{
            expr = expr,
            act = action,
            pool = rspamd_config:get_mempool(),
            message = message,
            subject = subject,
          }
          if cb and atoms then
            local hash = rspamd_cryptobox_hash.create()
            hash:update(expr)
            local sym_name = 'FORCE_ACTION_' .. hash:hex():sub(1, 12):upper()
            rspamd_config:register_symbol({
              type = 'normal',
              name = sym_name,
              callback = cb,
              flags = 'empty',
            })
            register_atom_dependencies(sym_name, expr, atoms)
          end
        end
      end
    end
  end
end

-- Current `rules` layout: rule name -> settings table. Rules that depend on
-- the current action (require_action / honor_action) become postfilters;
-- all others are normal symbols depending on their expression atoms.
local function register_rules(rules)
  for rule_name, sett in pairs(rules) do
    local action, expr = sett.action, sett.expression

    if action and expr then
      local flag_list = {}
      if sett.least then flag_list[#flag_list + 1] = "least" end
      if sett.process_all then flag_list[#flag_list + 1] = "process_all" end

      local raction = lua_util.list_to_hash(sett.require_action)
      local honor = lua_util.list_to_hash(sett.honor_action)
      local cb, atoms = gen_cb{
        expr = expr,
        act = action,
        pool = rspamd_config:get_mempool(),
        message = sett.message,
        subject = sett.subject,
        raction = raction,
        honor = honor,
        limit = sett.limit,
        flags = table.concat(flag_list, ','),
      }

      if cb and atoms then
        local sym = {
          type = (raction or honor) and 'postfilter' or 'normal',
          name = 'FORCE_ACTION_' .. rule_name,
          callback = cb,
          flags = 'empty',
        }
        if sym.type == 'postfilter' then
          sym.priority = 10
        end
        rspamd_config:register_symbol(sym)

        if sym.type == 'normal' then
          register_atom_dependencies(sym.name, expr, atoms)
        else
          rspamd_logger.infox(rspamd_config, 'Registered symbol %1 <%2> as postfilter', sym.name, expr)
        end
      end
    end
  end
end

-- Read the module configuration and register one symbol per rule.
-- Returns false when the module section is absent; the legacy `actions`
-- layout takes precedence over `rules` when both are tables.
local function configure_module()
  local opts = rspamd_config:get_all_opt(N)
  if not opts then
    return false
  end

  if type(opts.actions) == 'table' then
    register_legacy_rules(opts)
  elseif type(opts.rules) == 'table' then
    register_rules(opts.rules)
  end
end
217
-- Register this plugin's symbols at load time; the return value (false when
-- the module section is missing) is intentionally not used.
configure_module()
219