-- Name arguments are unused here
-- luacheck: ignore 212

-- Handlers for firewall definition lines (%ZONE, %RATE, %LIST, %PATTERN,
-- %SEARCH). Each handler is called with the definition's name and the text
-- that follows it, and returns the object representing that definition.
local definition_handlers = {};

local http = require "net.http";
local timer = require "util.timer";
local set = require"util.set";
local new_throttle = require "util.throttle".create;
local hashes = require "util.hashes";
local jid = require "util.jid";
local lfs = require "lfs";

-- Default cap on the number of per-key throttles held by a multi-key %RATE;
-- a definition can override it with "(entries N)".
local multirate_cache_size = module:get_option_number("firewall_multirate_cache_limit", 1000);

-- %ZONE name: member1, member2 ...
-- Members are separated by commas and/or spaces. Returns the items table of
-- a util.set built from the members.
function definition_handlers.ZONE(zone_name, zone_members)
	local zone_member_list = {};
	for member in zone_members:gmatch("[^, ]+") do
		zone_member_list[#zone_member_list+1] = member;
	end
	return set.new(zone_member_list)._items;
end

-- Eviction callback for the util.cache used by multi-key %RATE definitions.
-- Returns false (refuse eviction) while the throttle is below its maximum
-- balance, since forgetting it would discard that key's rate-limit state.
-- Otherwise it returns nothing, which lets the entry be evicted.
local function evict_only_unthrottled(name, throttle)
	throttle:update();
	-- Check whether the throttle is at max balance (i.e. totally safe to forget about it)
	if throttle.balance < throttle.max then
		-- Not safe to forget
		return false;
	end
end

-- %RATE name: <rate> (burst <n>) (entries <n>) (allow overflow)
-- <rate> is the first number on the line; burst defaults to 1 and entries
-- to firewall_multirate_cache_limit. Returns a table with two constructors:
--   single() - a single throttle, util.throttle.create(rate*burst, burst)
--   multi()  - an object whose poll_on(key, amount) polls a per-key throttle,
--              keeping at most `entries` throttles in a util.cache. Unless
--              "(allow overflow)" is given, throttles that are still limiting
--              are not evicted, and new keys are denied once the cache is full.
function definition_handlers.RATE(name, line)
	local rate = assert(tonumber(line:match("([%d.]+)")), "Unable to parse rate");
	local burst = tonumber(line:match("%(%s*burst%s+([%d.]+)%s*%)")) or 1;
	local max_throttles = tonumber(line:match("%(%s*entries%s+([%d]+)%s*%)")) or multirate_cache_size;
	local deny_when_full = not line:match("%(allow overflow%)");
	return {
		single = function ()
			return new_throttle(rate*burst, burst);
		end;

		multi = function ()
			-- With no eviction callback (allow overflow), entries are evicted freely
			local cache = require "util.cache".new(max_throttles, deny_when_full and evict_only_unthrottled or nil);
			return {
				poll_on = function (_, key, amount)
					assert(key, "no key");
					local throttle = cache:get(key);
					if not throttle then
						throttle = new_throttle(rate*burst, burst);
						if not cache:set(key, throttle) then
							module:log("warn", "Multirate '%s' has hit its maximum number of active throttles (%d), denying new events", name, max_throttles);
							return false;
						end
					end
					return throttle:poll(amount);
				end;
			}
		end;
	};
end

-- List backends, keyed by the word at the start of a %LIST definition.
-- Each backend provides init(self, list_spec, opts), add, remove and contains.
local list_backends = {
	-- %LIST name: memory (limit: number)
	memory = {
		init = function (self, list_spec, opts)
			if opts.limit then
				local have_cache_lib, cache_lib = pcall(require, "util.cache");
				if not have_cache_lib then
					error("In-memory lists with a size limit require Prosody 0.10");
				end
				self.cache = cache_lib.new((assert(tonumber(opts.limit), "Invalid list limit")));
				if not self.cache.table then
					error("In-memory lists with a size limit require a newer version of Prosody 0.10");
				end
				-- Table view of the size-limited cache
				self.items = self.cache:table();
			else
				self.items = {};
			end
		end;
		add = function (self, item)
			self.items[item] = true;
		end;
		remove = function (self, item)
			self.items[item] = nil;
		end;
		contains = function (self, item)
			return self.items[item] == true;
		end;
	};

	-- %LIST name: http://example.com/ (ttl: number, pattern: pat, hash: sha1, checkcert: never|when-sni)
	-- Read-only list fetched from a URL and re-fetched every ttl seconds
	-- (default 3600), with shorter retry intervals after failures.
	http = {
		init = function (self, url, opts)
			local poll_interval = assert(tonumber(opts.ttl or "3600"), "invalid ttl for <"..url.."> (expected number of seconds)");
			local pattern = opts.pattern or "([^\r\n]+)\r?\n";
			-- The default pattern needs a newline after every entry, including the last
			local ensure_trailing_newline = not opts.pattern;
			assert(pcall(string.match, "", pattern), "invalid pattern for <"..url..">");
			if opts.hash then
				assert(opts.hash:match("^%w+$") and type(hashes[opts.hash]) == "function", "invalid hash function: "..opts.hash);
				-- NOTE(review): entries are compared to hash_function(item) verbatim;
				-- confirm the remote list uses the same output encoding (raw vs hex).
				self.hash_function = hashes[opts.hash];
			end
			local etag;
			local failure_count = 0;
			local retry_intervals = { 60, 120, 300 };
			-- By default only check the certificate if net.http supports SNI
			local sni_supported = http.features and http.features.sni;
			local insecure = false;
			if opts.checkcert == "never" then
				insecure = true;
			elseif (opts.checkcert == nil or opts.checkcert == "when-sni") and not sni_supported then
				-- No SNI support: certificate checks cannot be relied upon, so skip them
				insecure = true;
			end
			local function update_list()
				http.request(url, {
					insecure = insecure;
					headers = {
						["If-None-Match"] = etag;
					};
				}, function (body, code, response)
					local next_poll = poll_interval;
					if code == 200 and body then
						failure_count = 0;
						etag = response.headers.etag;
						if ensure_trailing_newline and body:sub(-1) ~= "\n" then
							body = body .. "\n";
						end
						local items = {};
						for entry in body:gmatch(pattern) do
							items[entry] = true;
						end
						self.items = items;
						module:log("debug", "Fetched updated list from <%s>", url);
					elseif code == 304 then
						failure_count = 0;
						module:log("debug", "List at <%s> is unchanged", url);
					elseif code == 0 or (code >= 400 and code <=599) then
						module:log("warn", "Failed to fetch list from <%s>: %d %s", url, code, tostring(body));
						failure_count = failure_count + 1;
						-- Back off through retry_intervals, then keep using the last one
						next_poll = retry_intervals[failure_count] or retry_intervals[#retry_intervals];
					end
					if next_poll > 0 then
						-- Random jitter of up to 60s
						timer.add_task(next_poll+math.random(0, 60), update_list);
					end
				end);
			end
			update_list();
		end;
		add = function ()
		end;
		remove = function ()
		end;
		contains = function (self, item)
			if self.hash_function then
				item = self.hash_function(item);
			end
			-- items is nil until the first successful fetch
			return self.items and self.items[item] == true;
		end;
	};

	-- %LIST name: file:/path/to/file (missing: ignore)
	-- Loads one entry per line when the list is created.
	file = {
		init = function (self, file_spec, opts)
			local n, items = 0, {};
			self.items = items;
			local filename = file_spec:gsub("^file:", "");
			if opts.missing == "ignore" and not lfs.attributes(filename, "mode") then
				module:log("debug", "Ignoring missing list file: %s", filename);
				return;
			end
			local file, err = io.open(filename);
			if not file then
				module:log("warn", "Failed to open list from %s: %s", filename, err);
				return;
			end
			for line in file:lines() do
				if not items[line] then
					n = n + 1;
					items[line] = true;
				end
			end
			-- file:lines() does not close the handle
			file:close();
			module:log("debug", "Loaded %d items from %s", n, filename);
		end;
		add = function (self, item)
			self.items[item] = true;
		end;
		remove = function (self, item)
			self.items[item] = nil;
		end;
		contains = function (self, item)
			return self.items and self.items[item] == true;
		end;
	};

	-- %LIST name: pubsubitemid:pubsub.example.com/node
	-- TODO or the actual URI scheme? Bit overkill maybe?
	-- TODO Publish items back to the service?
	-- Step 1: Receiving pubsub events and storing them in the list
	-- We'll start by using only the item id.
	-- TODO Invent some custom schema for this? Needed for just a set of strings?
	pubsubitemid = {
		init = function(self, pubsub_spec, opts)
			-- The spec still carries the "pubsubitemid:" backend prefix; drop it
			-- so that only "service/node" remains.
			local service_and_node = pubsub_spec:gsub("^%w+:", "");
			local service_addr, node = service_and_node:match("^([^/]*)/(.*)");
			module:depends("pubsub_subscription");
			module:add_item("pubsub-subscription", {
				service = service_addr;
				node = node;
				on_subscribed = function ()
					self.items = {};
				end;
				on_item = function (event)
					self:add(event.item.attr.id);
				end;
				on_retract = function (event)
					self:remove(event.item.attr.id);
				end;
				on_purge = function ()
					self.items = {};
				end;
				on_unsubscribed = function ()
					self.items = nil;
				end;
				on_delete= function ()
					self.items = nil;
				end;
			});
			-- TODO Initial fetch? Or should mod_pubsub_subscription do this?
		end;
		-- add/remove are no-ops while there is no active subscription (items == nil)
		add = function (self, item)
			if self.items then
				self.items[item] = true;
			end
		end;
		remove = function (self, item)
			if self.items then
				self.items[item] = nil;
			end
		end;
		contains = function (self, item)
			return self.items and self.items[item] == true;
		end;
	};
};
list_backends.https = list_backends.http;

-- Functions usable in a list's "filter:" option, applied to items before
-- they are added, removed or looked up.
local normalize_functions = {
	upper = string.upper, lower = string.lower;
	md5 = hashes.md5, sha1 = hashes.sha1, sha256 = hashes.sha256;
	prep = jid.prep, bare = jid.bare;
};

-- Wrap a list method so that its item argument is passed through filter first.
local function wrap_list_method(list_method, filter)
	return function (self, item)
		return list_method(self, filter(item));
	end
end

-- Instantiate a list of backend type list_backend. Backend methods are
-- inherited via __index, and init is called with list_def and opts. If
-- opts.filter names filters (normalize_functions keys, or "log"), they are
-- applied in order to items passed to add/remove/contains.
local function create_list(list_backend, list_def, opts)
	if not list_backends[list_backend] then
		-- tostring: list_backend is nil when the definition has no leading word
		error("Unknown list type '"..tostring(list_backend).."'", 0);
	end
	local list = setmetatable({}, { __index = list_backends[list_backend] });
	if list.init then
		list:init(list_def, opts);
	end
	if opts.filter then
		local filters = {};
		for func_name in opts.filter:gmatch("[%w_]+") do
			if func_name == "log" then
				table.insert(filters, function (s)
					module:log("debug", "Checking list <%s> for: %s", list_def, s);
					return s;
				end);
			else
				assert(normalize_functions[func_name], "Unknown list filter: "..func_name);
				table.insert(filters, normalize_functions[func_name]);
			end
		end

		local filter;
		local n = #filters;
		if n == 1 then
			filter = filters[1];
		else
			function filter(s)
				for i = 1, n do
					s = filters[i](s or "");
				end
				return s;
			end
		end

		list.add = wrap_list_method(list.add, filter);
		list.remove = wrap_list_method(list.remove, filter);
		list.contains = wrap_list_method(list.contains, filter);
	end
	return list;
end

--[[
Examples:

%LIST spammers: memory (limit: 1000, filter: bare)

%LIST spammers: file:/etc/spammers.txt (missing: ignore)

%LIST spammers: http://example.com/blacklist.txt (ttl: 3600)
]]

-- %LIST name: <backend spec> (key: value, ...)
-- The backend is the leading word of the definition; the first
-- whitespace-delimited token is passed to the backend's init, and the
-- parenthesised "key: value" pairs become opts.
function definition_handlers.LIST(list_name, list_definition)
	local list_backend = list_definition:match("^%w+");
	local opts = {};
	local opt_string = list_definition:match("^%S+%s+%((.+)%)");
	if opt_string then
		for opt_k, opt_v in opt_string:gmatch("(%w+): ?([^,]+)") do
			opts[opt_k] = opt_v;
		end
	end
	return create_list(list_backend, list_definition:match("^%S+"), opts);
end

-- %PATTERN name: <Lua pattern>
-- Raises an error if the pattern is not a valid Lua pattern; otherwise
-- returns it unchanged.
function definition_handlers.PATTERN(name, pattern)
	local ok, err = pcall(string.match, "", pattern);
	if not ok then
		error("Invalid pattern '"..name.."': "..err);
	end
	return pattern;
end

-- %SEARCH name: <text> - returned unchanged.
function definition_handlers.SEARCH(name, pattern)
	return pattern;
end

return definition_handlers;