1--[[
2Copyright (c) 2016, Vsevolod Stakhov <vsevolod@highsecure.ru>
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15]]--
16
-- Skip the whole plugin when `confighelp` is set.
-- NOTE(review): presumably a global set while only configuration help is being collected; confirm.
if confighelp then return end
20
21-- This plugin implements dynamic updates for rspamd
22
23local ucl = require "ucl"
24local fun = require "fun"
25local rspamd_logger = require "rspamd_logger"
26local rspamd_config = rspamd_config
27local hash = require "rspamd_cryptobox_hash"
28local lua_util = require "lua_util"
29local N = "rspamd_update"
30local rspamd_version = rspamd_version
31local maps = {}
32local allow_rules = false -- Deny for now
33local global_priority = 1 -- Default for local rules
34
--- Register a score for every symbol listed in an update.
-- Each entry of `obj` is forwarded to `rspamd_config:set_metric_symbol`.
-- @param obj collection of symbol name -> score pairs, iterated with `fun.each`
-- @param priority priority attached to every registered symbol score
local function process_symbols(obj, priority)
  local function register_symbol(symbol_name, weight)
    local definition = {
      name = symbol_name,
      score = weight,
      priority = priority,
    }
    rspamd_config:set_metric_symbol(definition)
  end

  fun.each(register_symbol, obj)
end
44
--- Apply the score threshold of every action listed in an update.
-- Each entry of `obj` is forwarded to `rspamd_config:set_metric_action`.
-- @param obj collection of action name -> score pairs, iterated with `fun.each`
-- @param priority priority attached to every action setting
local function process_actions(obj, priority)
  local function apply_action(action_name, threshold)
    local definition = {
      action = action_name,
      score = threshold,
      priority = priority,
    }
    rspamd_config:set_metric_action(definition)
  end

  fun.each(apply_action, obj)
end
54
--- Compile and run every Lua chunk carried by an update.
-- SECURITY: this executes code taken verbatim from the update object. In this
-- file it is only reached when the `allow_rules` switch is enabled; keep it
-- that way unless the update source is fully trusted.
-- @param obj collection of rule key -> Lua source text pairs, iterated with `fun.each`
local function process_rules(obj)
  -- `load` accepts source strings only on Lua 5.2+ and LuaJIT; plain Lua 5.1
  -- needs `loadstring`, so prefer it whenever it exists.
  local compile = loadstring or load

  fun.each(function(key, code)
    local f, err = compile(code)
    if f then
      f()
    else
      -- The logger module is a table, not a callable: go through `errx` and
      -- include the compiler message so the broken rule can be diagnosed.
      rspamd_logger.errx(rspamd_config, 'cannot load rules for %s: %s', key, err)
    end
  end, obj)
end
65
--- Decide whether an update object may be applied to the running rspamd.
-- Both bounds are always evaluated, so every violated bound is logged.
-- @param obj parsed update object; may carry `min_version` and/or `max_version`
-- @return false when `obj` is nil/false or any bound is violated, true otherwise
local function check_version(obj)
  if not obj then
    return false
  end

  local compatible = true
  local lowest, highest = obj['min_version'], obj['max_version']

  -- A positive comparison result against the lower bound rejects the update
  if lowest and rspamd_version('cmp', lowest) > 0 then
    rspamd_logger.errx(rspamd_config, 'updates require at least %s version of rspamd',
      lowest)
    compatible = false
  end

  -- A negative comparison result against the upper bound rejects the update
  if highest and rspamd_version('cmp', highest) < 0 then
    rspamd_logger.errx(rspamd_config, 'updates require maximum %s version of rspamd',
      highest)
    compatible = false
  end

  return compatible
end
90
--- Build the callback invoked with the raw contents of an updates map.
-- The returned function parses `data` as UCL and, when `check_version`
-- accepts the parsed object, applies its `symbols` and `actions` sections and
-- (only if `allow_rules` is enabled) its `rules` section.
-- @return function(data) returning the first result of `parser:parse_string`
local function gen_callback()

  return function(data)
    local parser = ucl.parser()
    local res, err = parser:parse_string(data)

    if not res then
      -- Pass the parser message as a format argument: concatenating it into
      -- the format string would make the logger interpret any `%` it contains
      -- and would raise if no message was returned at all.
      rspamd_logger.warnx(rspamd_config, 'cannot parse updates map: %s', err)
      return res
    end

    local obj = parser:get_object()

    if check_version(obj) then
      if obj['symbols'] then
        process_symbols(obj['symbols'], global_priority)
      end
      if obj['actions'] then
        process_actions(obj['actions'], global_priority)
      end
      if allow_rules and obj['rules'] then
        process_rules(obj['rules'])
      end

      -- Digest of the raw payload; only needed for the log line below
      local h = hash.create()
      h:update(data)
      rspamd_logger.infox(rspamd_config, 'loaded new rules with hash "%s"',
        h:hex())
    end

    return res
  end
end
124
-- Configuration part: register one callback map per configured update source
-- and make sure maps fetched over HTTP(S) have a signing key where possible.
local section = rspamd_config:get_all_opt("rspamd_update")
if section and section.rules then
  -- Optional key applied to HTTP(S) maps that carry no signing key of their own
  local trusted_key = section.key

  -- Accept a single source as well as a list of sources
  if type(section.rules) ~= 'table' then
    section.rules = {section.rules}
  end

  fun.each(function(elt)
    local map = rspamd_config:add_map(elt, "rspamd updates map", nil, "callback")
    if not map then
      rspamd_logger.errx(rspamd_config, 'cannot load updates from %1', elt)
    else
      map:set_callback(gen_callback())
      -- Key the registry by the source itself. The literal key 'elt' made
      -- every map overwrite the previous one, so only the last map was
      -- checked below and the warning named "elt" instead of the source.
      maps[elt] = map
    end
  end, section.rules)

  fun.each(function(k, map)
    -- Check sanity for maps
    local proto = map:get_proto()
    if (proto == 'http' or proto == 'https') and not map:get_sign_key() then
      if trusted_key then
        map:set_sign_key(trusted_key)
      else
        rspamd_logger.warnx(rspamd_config, 'Map %s is loaded by HTTP and it is not signed', k)
      end
    end
  end, maps)
else
  rspamd_logger.infox(rspamd_config, 'Module is unconfigured')
  lua_util.disable_module(N, "config")
end
162