1
2-- Name arguments are unused here
3-- luacheck: ignore 212
4
5local definition_handlers = {};
6
7local http = require "net.http";
8local timer = require "util.timer";
9local set = require"util.set";
10local new_throttle = require "util.throttle".create;
11local hashes = require "util.hashes";
12local jid = require "util.jid";
13local lfs = require "lfs";
14
-- Default cap on the number of per-key throttles a multi-key RATE keeps
-- (overridable per definition with "(entries N)")
local multirate_cache_size = module:get_option_number(
	"firewall_multirate_cache_limit",
	1000
);
16
--- Handler for %ZONE definitions.
-- Splits the member list on commas and spaces and returns the underlying
-- item table of a util.set built from those members.
function definition_handlers.ZONE(zone_name, zone_members)
	local members = {};
	for member in zone_members:gmatch("[^, ]+") do
		table.insert(members, member);
	end
	local zone = set.new(members);
	return zone._items;
end
24
-- Eviction callback for the per-key throttle cache built by the RATE handler.
-- Returns false (veto the eviction) while the throttle is below its maximum
-- balance, i.e. it is still tracking recent activity. Otherwise it returns
-- nothing, so the entry may be forgotten.
local function evict_only_unthrottled(name, throttle)
	throttle:update();
	if not (throttle.balance < throttle.max) then
		-- Fully recharged: nothing worth remembering about this key
		return;
	end
	return false;
end
34
--- Handler for %RATE definitions.
-- The first number in the line is the rate. Optional modifiers are
-- "(burst N)" (default 1), "(entries N)" (default multirate_cache_size) and
-- "(allow overflow)".
-- Returns a table of two constructors:
--   single() - a single throttle
--   multi()  - an object whose poll_on(key, amount) keeps one throttle per key
--              in a util.cache of at most `entries` throttles
function definition_handlers.RATE(name, line)
	local rate = assert(tonumber(line:match("([%d.]+)")), "Unable to parse rate");
	local burst = tonumber(line:match("%(%s*burst%s+([%d.]+)%s*%)")) or 1;
	local max_throttles = tonumber(line:match("%(%s*entries%s+([%d]+)%s*%)")) or multirate_cache_size;
	local deny_when_full = not line:match("%(allow overflow%)");

	-- Every throttle is created as new_throttle(rate*burst, burst)
	local function make_throttle()
		return new_throttle(rate*burst, burst);
	end

	local function make_multi()
		-- Unless overflow is allowed, refuse to evict throttles that are still active
		local on_evict;
		if deny_when_full then
			on_evict = evict_only_unthrottled;
		end
		local cache = require "util.cache".new(max_throttles, on_evict);

		local function poll_on(_, key, amount)
			assert(key, "no key");
			local throttle = cache:get(key);
			if not throttle then
				throttle = make_throttle();
				-- cache:set() fails when the cache is full and eviction was vetoed
				if not cache:set(key, throttle) then
					module:log("warn", "Multirate '%s' has hit its maximum number of active throttles (%d), denying new events", name, max_throttles);
					return false;
				end
			end
			return throttle:poll(amount);
		end

		return { poll_on = poll_on };
	end

	return {
		single = make_throttle;
		multi = make_multi;
	};
end
64
-- Registry of %LIST backends, keyed by the leading word of the list definition.
-- Each backend provides init(self, list_def, opts) plus add/remove/contains
-- methods; create_list() instantiates them via a metatable __index.
local list_backends = {
	-- %LIST name: memory (limit: number)
	memory = {
		-- Items live in a plain table, or, when "limit" is given, in a
		-- util.cache of that size exposed through cache:table().
		init = function (self, list_def, opts)
			if opts.limit then
				local have_cache_lib, cache_lib = pcall(require, "util.cache");
				if not have_cache_lib then
					error("In-memory lists with a size limit require Prosody 0.10");
				end
				self.cache = cache_lib.new((assert(tonumber(opts.limit), "Invalid list limit")));
				if not self.cache.table then
					error("In-memory lists with a size limit require a newer version of Prosody 0.10");
				end
				self.items = self.cache:table();
			else
				self.items = {};
			end
		end;
		add = function (self, item)
			self.items[item] = true;
		end;
		remove = function (self, item)
			self.items[item] = nil;
		end;
		contains = function (self, item)
			return self.items[item] == true;
		end;
	};

	-- %LIST name: http://example.com/ (ttl: number, pattern: pat, hash: sha1, checkcert: never|when-sni)
	-- Read-only list fetched over HTTP(S) and re-fetched every `ttl` seconds
	-- (default 3600). The list is empty (contains() is falsy) until the first
	-- successful fetch.
	http = {
		init = function (self, url, opts)
			local poll_interval = assert(tonumber(opts.ttl or "3600"), "invalid ttl for <"..url.."> (expected number of seconds)");
			-- Default: one entry per non-empty line. Unlike a pattern that requires
			-- a line terminator, this also keeps a final line without a newline.
			local pattern = opts.pattern or "[^\r\n]+";
			-- NOTE(review): matching against "" may not exercise every part of
			-- the pattern, so some malformed patterns can pass this check.
			assert(pcall(string.match, "", pattern), "invalid pattern for <"..url..">");
			if opts.hash then
				assert(opts.hash:match("^%w+$") and type(hashes[opts.hash]) == "function", "invalid hash function: "..opts.hash);
				-- NOTE(review): hash_function is called with a single argument in
				-- contains(); confirm its output encoding matches the list contents.
				self.hash_function = hashes[opts.hash];
			end
			local etag;
			local failure_count = 0;
			local retry_intervals = { 60, 120, 300 };

			-- Certificate verification, controlled by the "checkcert" option:
			--   "never"              - never verify
			--   unset or "when-sni"  - verify only if net.http advertises SNI support
			--   anything else        - always verify
			local sni_supported = http.features and http.features.sni;
			local insecure = false;
			if opts.checkcert == "never" then
				insecure = true;
			elseif (opts.checkcert == nil or opts.checkcert == "when-sni") and not sni_supported then
				insecure = true;
			end

			local function update_list()
				http.request(url, {
					insecure = insecure;
					headers = {
						["If-None-Match"] = etag;
					};
				}, function (body, code, response)
					local next_poll = poll_interval;
					if code == 200 and body then
						etag = response.headers.etag;
						local items = {};
						for entry in body:gmatch(pattern) do
							items[entry] = true;
						end
						self.items = items;
						-- Successful fetch: the next failure starts from the shortest retry
						failure_count = 0;
						module:log("debug", "Fetched updated list from <%s>", url);
					elseif code == 304 then
						failure_count = 0;
						module:log("debug", "List at <%s> is unchanged", url);
					elseif code == 0 or (code >= 400 and code <=599) then
						module:log("warn", "Failed to fetch list from <%s>: %d %s", url, code, tostring(body));
						failure_count = failure_count + 1;
						next_poll = retry_intervals[failure_count] or retry_intervals[#retry_intervals];
					end
					-- A ttl of 0 (or less) disables re-polling; otherwise add jitter
					if next_poll > 0 then
						timer.add_task(next_poll+math.random(0, 60), update_list);
					end
				end);
			end
			update_list();
		end;
		-- Remote lists are read-only
		add = function ()
		end;
		remove = function ()
		end;
		contains = function (self, item)
			if self.hash_function then
				item = self.hash_function(item);
			end
			return self.items and self.items[item] == true;
		end;
	};

	-- %LIST name: file:/path/to/file (missing: ignore)
	-- Loads one entry per line at startup. With "missing: ignore" a
	-- non-existent file yields an empty list without a warning.
	file = {
		init = function (self, file_spec, opts)
			local n, items = 0, {};
			self.items = items;
			local filename = file_spec:gsub("^file:", "");
			if opts.missing == "ignore" and not lfs.attributes(filename, "mode") then
				module:log("debug", "Ignoring missing list file: %s", filename);
				return;
			end
			local fh, err = io.open(filename);
			if not fh then
				module:log("warn", "Failed to open list from %s: %s", filename, err);
				return;
			end
			for line in fh:lines() do
				-- Tolerate CRLF line endings; a trailing CR could never match an item
				local entry = line:gsub("\r$", "");
				if not items[entry] then
					n = n + 1;
					items[entry] = true;
				end
			end
			-- fh:lines() does not close the handle, so release it explicitly
			fh:close();
			module:log("debug", "Loaded %d items from %s", n, filename);
		end;
		add = function (self, item)
			self.items[item] = true;
		end;
		remove = function (self, item)
			self.items[item] = nil;
		end;
		contains = function (self, item)
			return self.items and self.items[item] == true;
		end;
	};

	-- %LIST name: pubsubitemid:pubsub.example.com/node
	-- Contains the ids of the items published on the given pubsub node,
	-- kept up to date through mod_pubsub_subscription events.
	-- TODO or the actual URI scheme? Bit overkill maybe?
	-- TODO Publish items back to the service?
	-- TODO Invent some custom schema for this? Needed for just a set of strings?
	pubsubitemid = {
		init = function(self, pubsub_spec, opts)
			-- create_list() passes the whole first word of the definition,
			-- including the "pubsubitemid:" prefix, so strip it here
			local service_addr, node = pubsub_spec:match("^pubsubitemid:([^/]*)/(.*)");
			if not service_addr then
				module:log("warn", "Invalid list specification (expected 'pubsubitemid:<service>/<node>', got: '%s')", pubsub_spec);
				return;
			end
			module:depends("pubsub_subscription");
			module:add_item("pubsub-subscription", {
					service = service_addr;
					node = node;
					on_subscribed = function ()
						self.items = {};
					end;
					on_item = function (event)
						self:add(event.item.attr.id);
					end;
					on_retract = function (event)
						self:remove(event.item.attr.id);
					end;
					on_purge = function ()
						self.items = {};
					end;
					on_unsubscribed = function ()
						self.items = nil;
					end;
					on_delete= function ()
						self.items = nil;
					end;
				});
			-- TODO Initial fetch? Or should mod_pubsub_subscription do this?
		end;
		add = function (self, item)
			if self.items then
				self.items[item] = true;
			end
		end;
		remove = function (self, item)
			if self.items then
				self.items[item] = nil;
			end
		end;
		contains = function (self, item)
			return self.items and self.items[item] == true;
		end;
	};
};
-- https:// URLs are handled by the same backend as http://
list_backends.https = list_backends.http;
242
-- Functions usable by name in a list's "filter" option
local normalize_functions = {
	-- case folding
	upper = string.upper;
	lower = string.lower;
	-- hashing (util.hashes)
	md5 = hashes.md5;
	sha1 = hashes.sha1;
	sha256 = hashes.sha256;
	-- JID normalisation (util.jid)
	prep = jid.prep;
	bare = jid.bare;
};
248
-- Returns a method that runs its argument through `filter` and then delegates
-- to `list_method` with the same receiver.
local function wrap_list_method(list_method, filter)
	local function filtered(self, item)
		return list_method(self, filter(item));
	end
	return filtered;
end
254
-- Build a single normalisation function from a "filter" option such as
-- "lower bare" or "prep log". Raises an error for unknown filter names.
local function build_list_filter(filter_spec, list_def)
	local filters = {};
	for func_name in filter_spec:gmatch("[%w_]+") do
		if func_name == "log" then
			filters[#filters+1] = function (s)
				module:log("debug", "Checking list <%s> for: %s", list_def, s);
				return s;
			end;
		else
			filters[#filters+1] = assert(normalize_functions[func_name], "Unknown list filter: "..func_name);
		end
	end

	local n = #filters;
	if n == 1 then
		-- Single filter: call it directly
		return filters[1];
	end
	-- Chain the filters; a nil result from one stage is passed on as ""
	return function (s)
		for i = 1, n do
			s = filters[i](s or "");
		end
		return s;
	end;
end

--- Instantiate a list from one of the list_backends.
-- @param list_backend backend name (a key of list_backends); may be nil
-- @param list_def backend-specific definition (first word of the %LIST value)
-- @param opts table of options parsed from the definition
-- @return list object with add/remove/contains methods
-- Raises an error for an unknown backend or an unknown filter name.
local function create_list(list_backend, list_def, opts)
	local backend = list_backends[list_backend];
	if not backend then
		-- tostring(): list_backend is nil when the definition has no leading word
		error("Unknown list type '"..tostring(list_backend).."'", 0);
	end
	-- Resolve filters before init() so a bad filter option is reported before
	-- the backend starts any polling or subscriptions
	local filter = opts.filter and build_list_filter(opts.filter, list_def);

	local list = setmetatable({}, { __index = backend });
	if list.init then
		list:init(list_def, opts);
	end
	if filter then
		list.add = wrap_list_method(list.add, filter);
		list.remove = wrap_list_method(list.remove, filter);
		list.contains = wrap_list_method(list.contains, filter);
	end
	return list;
end
297
--[[
Example list definitions:

%LIST spammers: memory (limit: 1000)

%LIST spammers: file:/etc/spammers.txt

%LIST spammers: http://example.com/blacklist.txt
]]
306
--- Handler for %LIST definitions: "backend-spec (key: value, ...)".
-- The backend is the leading word of the definition, and the spec is the
-- first whitespace-delimited token. Options come from an optional
-- parenthesised list of "key: value" pairs; values cannot contain commas.
function definition_handlers.LIST(list_name, list_definition)
	local backend_name = list_definition:match("^%w+");
	local backend_spec = list_definition:match("^%S+");

	local options = {};
	local option_list = list_definition:match("^%S+%s+%((.+)%)");
	for key, value in (option_list or ""):gmatch("(%w+): ?([^,]+)") do
		options[key] = value;
	end

	return create_list(backend_name, backend_spec, options);
end
318
--- Handler for %PATTERN definitions: returns the Lua pattern unchanged if it
-- can be used with string.match, otherwise raises an error naming it.
-- NOTE(review): matching against "" may not exercise every part of the
-- pattern, so some malformed patterns could pass this check.
function definition_handlers.PATTERN(name, pattern)
	local usable, match_err = pcall(string.match, "", pattern);
	if usable then
		return pattern;
	end
	error("Invalid pattern '"..name.."': "..match_err);
end
326
--- Handler for %SEARCH definitions: the value is returned verbatim, unvalidated.
definition_handlers.SEARCH = function (name, pattern)
	return pattern;
end;
330
331return definition_handlers;
332