1
2local lfs = require "lfs";
3local resolve_relative_path = require "core.configmanager".resolve_relative_path;
4local envload = require "util.envload".envload;
5local logger = require "util.logger".init;
6local it = require "util.iterators";
7local set = require "util.set";
8
-- Module-shared table: [definition_type] = definition_factory(param)
-- Its contents are replaced with active_definitions in module.load().
local definitions = module:shared("definitions");

-- Definitions instantiated by loaded scripts:
-- [definition_type][definition_name] = definition_object
local active_definitions = {
	ZONE = {
		-- Default zone that includes all local hosts (lookups fall through to prosody.hosts)
		["$local"] = setmetatable({}, { __index = prosody.hosts }),
	},
};

-- Built-in chains. The array part lists the events a chain is hooked to;
-- `priority` is the hook priority and `type` says how the chain is attached.
local default_chains = {
	preroute = {
		type = "event",
		priority = 0.1,
		"pre-message/bare", "pre-message/full", "pre-message/host",
		"pre-presence/bare", "pre-presence/full", "pre-presence/host",
		"pre-iq/bare", "pre-iq/full", "pre-iq/host",
	},
	deliver = {
		type = "event",
		priority = 0.1,
		"message/bare", "message/full", "message/host",
		"presence/bare", "presence/full", "presence/host",
		"iq/bare", "iq/full", "iq/host",
	},
	deliver_remote = {
		type = "event",
		priority = 0.1,
		"route/remote",
	},
};

local extra_chains = module:get_option("firewall_extra_chains", {});

-- Effective chain table: built-in chains first, then chains from the config,
-- which replace a built-in chain of the same name.
local chains = {};
for _, chain_source in ipairs({ default_chains, extra_chains }) do
	for chain_name, chain_info in pairs(chain_source) do
		chains[chain_name] = chain_info;
	end
end
51
--- Check whether `name` may be used as (part of) a Lua identifier in generated code.
-- A safe name starts with a letter and continues with letters, digits or underscores.
-- @param name string to check
-- @return `name` itself when it is safe, otherwise nil
function idsafe(name)
	local safe_name = name:match("^%a[%w_]*$");
	return safe_name;
end
56
-- Filters usable in `$<expr|filter>` templates (see meta()). Each filter takes
-- Lua source for an expression and returns new Lua source wrapping it, plus
-- the names of the available_deps that the new source relies on.
local meta_funcs = {};

-- |bare: wrap the value in jid_bare()
function meta_funcs.bare(code)
	local wrapped = "jid_bare(" .. code .. ")";
	return wrapped, { "jid_bare" };
end

-- |node: first result of jid_split() (extra parentheses truncate to one value)
function meta_funcs.node(code)
	local wrapped = "(jid_split(" .. code .. "))";
	return wrapped, { "jid_split" };
end

-- |host: second result of jid_split()
function meta_funcs.host(code)
	local wrapped = "(select(2, jid_split(" .. code .. ")))";
	return wrapped, { "jid_split" };
end

-- |resource: third result of jid_split()
function meta_funcs.resource(code)
	local wrapped = "(select(3, jid_split(" .. code .. ")))";
	return wrapped, { "jid_split" };
end
71
-- Run quoted (%q) strings through this to allow them to contain code. e.g.: LOG=Received: $(stanza:top_tag())
--
-- `s` is a Lua string literal as produced by ("%q"):format(text). The result is
-- Lua source for an expression that builds the final string at run time.
-- Placeholders:
--   $(lua expr)          -> tostring(lua expr)
--   $<path>              -> stanza:find(path)
--   $<name#>             -> stanza:get_child_text(name)
--   $<@attr>             -> stanza.attr[attr]; for to/from/type, the cached local
--                           of the same name when `deps` is given
--   $<...|f1|f2>         -> pass the value through meta_funcs (bare/node/host/resource)
--   $<...||"default">    -> default text when the value is nil/false (otherwise "<undefined>")
--   $$name               -> looked up in `extra` (left unchanged when not present)
-- @param s     quoted Lua string literal
-- @param deps  optional array; names of dependencies the generated code needs are appended to it
-- @param extra optional table for $$name substitutions
-- @return Lua source code (string)
function meta(s, deps, extra)
	return (s:gsub("$(%b())", function (expr)
			-- Undo the %q escaping inside the expression, then splice it in by
			-- closing the surrounding literal, concatenating, and reopening it
			expr = expr:gsub("\\(.)", "%1");
			return [["..tostring(]]..expr..[[).."]];
		end)
		:gsub("$(%b<>)", function (expr)
			expr = expr:sub(2,-2);
			local default = "<undefined>";
			-- NOTE(review): since `s` is %q-quoted, a literal `"` in the template arrives
			-- here as `\"`; confirm this `||"..."` pattern matches defaults as written in rule files.
			expr = expr:gsub("||(%b\"\")$", function (default_string)
				default = stripslashes(default_string:sub(2,-2));
				return "";
			end);
			-- Trailing "|func|func" filter chain, applied left to right below
			local func_chain = expr:match("|[%w|]+$");
			if func_chain then
				expr = expr:sub(1, -1-#func_chain);
			end
			local code;
			if expr:match("^@") then
				-- Skip stanza:find() for simple attribute lookup
				local attr_name = expr:sub(2);
				if deps and (attr_name == "to" or attr_name == "from" or attr_name == "type") then
					-- These attributes may be cached in locals
					code = attr_name;
					table.insert(deps, attr_name);
				else
					code = "stanza.attr["..("%q"):format(attr_name).."]";
				end
			elseif expr:match("^%w+#$") then
				code = ("stanza:get_child_text(%q)"):format(expr:sub(1, -2));
			else
				code = ("stanza:find(%q)"):format(expr);
			end
			if func_chain then
				for func_name in func_chain:gmatch("|(%w+)") do
					-- to/from are already available in local variables, use those if possible
					if (code == "to" or code == "from") and func_name == "bare" then
						code = "bare_"..code;
						table.insert(deps, code);
					elseif (code == "to" or code == "from") and (func_name == "node" or func_name == "host" or func_name == "resource") then
						table.insert(deps, "split_"..code);
						code = code.."_"..func_name;
					else
						assert(meta_funcs[func_name], "unknown function: "..func_name);
						local new_code, new_deps = meta_funcs[func_name](code);
						code = new_code;
						if new_deps and #new_deps > 0 then
							-- Without a deps list the required helpers cannot be emitted
							assert(deps, "function not supported here: "..func_name);
							for _, dep in ipairs(new_deps) do
								table.insert(deps, dep);
							end
						end
					end
				end
			end
			-- Close the literal, splice in the value (or the default), reopen the literal
			return "\"..tostring("..code.." or "..("%q"):format(default)..")..\"";
		end)
		-- $$name: table lookup in `extra`; unmatched names are kept as-is by gsub
		:gsub("$$(%a+)", extra or {})
		-- Drop the empty `"".. ` / `..""` left when a placeholder sits at either end
		:gsub([[^""%.%.]], "")
		:gsub([[%.%.""$]], ""));
end
133
--- Quote `s` as a Lua string literal with %q and expand its placeholders via meta().
-- Any further arguments (deps, extra) are forwarded to meta() unchanged.
function metaq(s, ...)
	local quoted = string.format("%q", s);
	return meta(quoted, ...);
end
137
-- Character following a backslash -> the character that escape denotes.
local escape_chars = {
	["a"] = "\a", ["b"] = "\b", ["f"] = "\f", ["n"] = "\n",
	["r"] = "\r", ["t"] = "\t", ["v"] = "\v",
	["\\"] = "\\", ['"'] = '"', ["'"] = "'",
};

--- Decode backslash escapes in `s`.
-- Known escapes (\n, \t, \\, \", ...) become the character they denote. An
-- unknown escape such as \q is left untouched, because gsub keeps the original
-- match when the replacement function returns nil.
-- @param s string to decode
-- @return the decoded string (exactly one value)
function stripslashes(s)
	local decoded = s:gsub("\\(.)", function (escaped)
		return escape_chars[escaped];
	end);
	return decoded;
end
145
146-- Dependency locations:
147-- <type lib>
148-- <type global>
149-- function handler()
150--   <local deps>
151--   if <conditions> then
152--     <actions>
153--   end
154-- end
155
-- Code fragments that generated chain handlers can depend on, keyed by name.
-- Conditions/actions request them as "name", or as "name:param" for entries
-- whose code is a function of a parameter. Fields:
--   global_code: emitted once per chain in the handler factory, outside the
--                event handler (string, or function(param) -> code)
--   local_code:  emitted inside the event handler, just before the first rule
--                that needs it (string, or function(param) -> code[, extra_deps])
--   depends:     dependencies that are emitted before this one
local available_deps = {
	st = { global_code = [[local st = require "util.stanza";]]};
	it = { global_code = [[local it = require "util.iterators";]]};
	it_count = { global_code = [[local it_count = it.count;]], depends = { "it" } };
	current_host = { global_code = [[local current_host = module.host;]] };
	jid_split = {
		global_code = [[local jid_split = require "util.jid".split;]];
	};
	jid_bare = {
		global_code = [[local jid_bare = require "util.jid".bare;]];
	};
	-- Falls back to the bare JID of the originating session when the stanza has no 'to'
	to = { local_code = [[local to = stanza.attr.to or jid_bare(session.full_jid);]]; depends = { "jid_bare" } };
	from = { local_code = [[local from = stanza.attr.from;]] };
	type = { local_code = [[local type = stanza.attr.type;]] };
	name = { local_code = [[local name = stanza.name;]] };
	split_to = { -- The stanza's split to address
		depends = { "jid_split", "to" };
		local_code = [[local to_node, to_host, to_resource = jid_split(to);]];
	};
	split_from = { -- The stanza's split from address
		depends = { "jid_split", "from" };
		local_code = [[local from_node, from_host, from_resource = jid_split(from);]];
	};
	bare_to = { depends = { "jid_bare", "to" }, local_code = "local bare_to = jid_bare(to)"};
	bare_from = { depends = { "jid_bare", "from" }, local_code = "local bare_from = jid_bare(from)"};
	group_contains = {
		global_code = [[local group_contains = module:depends("groups").group_contains]];
	};
	is_admin = { global_code = [[local is_admin = require "core.usermanager".is_admin;]]};
	core_post_stanza = { global_code = [[local core_post_stanza = prosody.core_post_stanza;]] };
	-- "zone:<name>": binds zone_<name> to the named ZONE definition ({} if missing).
	-- "$local" is not a valid identifier, so its local is named zone__local.
	zone = { global_code = function (zone)
		local var = zone;
		if var == "$local" then
			var = "_local"; -- See #1090
		else
			assert(idsafe(var), "Invalid zone name: "..zone);
		end
		return ("local zone_%s = zones[%q] or {};"):format(var, zone);
	end };
	date_time = { global_code = [[local os_date = os.date]]; local_code = [[local current_date_time = os_date("*t");]] };
	-- "time:<fields>": one current_<field> local per alphabetic word in <fields>,
	-- read from the os.date("*t") table
	time = { local_code = function (what)
		local defs = {};
		for field in what:gmatch("%a+") do
			table.insert(defs, ("local current_%s = current_date_time.%s;"):format(field, field));
		end
		return table.concat(defs, " ");
	end, depends = { "date_time" }; };
	timestamp = { global_code = [[local get_time = require "socket".gettime;]]; local_code = [[local current_timestamp = get_time();]]; };
	-- "globalthrottle:<rate>": a single shared throttle from the named RATE definition
	globalthrottle = {
		global_code = function (throttle)
			assert(idsafe(throttle), "Invalid rate limit name: "..throttle);
			assert(active_definitions.RATE[throttle], "Unknown rate limit: "..throttle);
			return ("local global_throttle_%s = rates.%s:single();"):format(throttle, throttle);
		end;
	};
	-- "multithrottle:<rate>": the RATE definition's :multi() object (requires util.cache)
	multithrottle = {
		global_code = function (throttle)
			assert(pcall(require, "util.cache"), "Using LIMIT with 'on' requires Prosody 0.10 or higher");
			assert(idsafe(throttle), "Invalid rate limit name: "..throttle);
			assert(active_definitions.RATE[throttle], "Unknown rate limit: "..throttle);
			return ("local multi_throttle_%s = rates.%s:multi();"):format(throttle, throttle);
		end;
	};
	rostermanager = {
		global_code = [[local rostermanager = require "core.rostermanager";]];
	};
	-- The sender's item in the recipient's roster; looked up in an empty table
	-- when the recipient has no node part or no roster is returned
	roster_entry = {
		local_code = [[local roster_entry = (to_node and rostermanager.load_roster(to_node, to_host) or {})[bare_from];]];
		depends = { "rostermanager", "split_to", "bare_from" };
	};
	-- "list:<name>": binds list_<name> to the named LIST definition
	list = { global_code = function (list)
			assert(idsafe(list), "Invalid list name: "..list);
			assert(active_definitions.LIST[list], "Unknown list: "..list);
			return ("local list_%s = lists[%q];"):format(list, list);
		end
	};
	-- "search:<name>": text found at the named SEARCH path ("" when nothing is found)
	search = {
		local_code = function (search_name)
			local search_path = assert(active_definitions.SEARCH[search_name], "Undefined search path: "..search_name);
			return ("local search_%s = tostring(stanza:find(%q) or \"\")"):format(search_name, search_path);
		end;
	};
	-- "pattern:<name>": the named PATTERN definition as a string local
	pattern = {
		local_code = function (pattern_name)
			local pattern = assert(active_definitions.PATTERN[pattern_name], "Undefined pattern: "..pattern_name);
			return ("local pattern_%s = %q"):format(pattern_name, pattern);
		end;
	};
	-- "tokens:<search>-<pattern>": set of every match of the pattern in the search
	-- result; the search/pattern locals are requested as extra dependencies
	tokens = {
		local_code = function (search_and_pattern)
			local search_name, pattern_name = search_and_pattern:match("^([^%-]+)-(.+)$");
			local code = ([[local tokens_%s_%s = {};
		if search_%s then
			for s in search_%s:gmatch(pattern_%s) do
				tokens_%s_%s[s] = true;
			end
		end
		]]):format(search_name, pattern_name, search_name, search_name, pattern_name, search_name, pattern_name);
			return code, { "search:"..search_name, "pattern:"..pattern_name };
		end;
	};
	-- scan_list(list, items): true if any key of `items` is contained in `list`
	scan_list = {
		global_code = [[local function scan_list(list, items) for item in pairs(items) do if list:contains(item) then return true; end end end]];
	}
};
261
local include_dep; -- forward declaration: render_dep_code() and include_dep() recurse into each other

-- Produce the code for one fragment (global_code or local_code) of a dependency.
-- `spec` is either literal code (a string), or a generator function that receives
-- the dependency parameter (the text after ':') and returns code plus an optional
-- list of further dependencies. Those further dependencies are included first.
-- Raises a descriptive error when the parameter does not fit the spec kind.
-- Previously this case failed later with an obscure concat/call error.
local function render_dep_code(dep, spec, dep_param, code)
	if type(spec) ~= "function" then
		if dep_param ~= "" then
			error(("Dependency %s does not take a parameter (got %q)"):format(dep, dep_param));
		end
		return spec;
	end
	if dep_param == "" then
		error(("Dependency %s requires a parameter, e.g. %s:name"):format(dep, dep));
	end
	local fragment, extra_deps = spec(dep_param);
	for _, dep_dep in ipairs(extra_deps or {}) do
		include_dep(dep_dep, code);
	end
	return fragment;
end

--- Emit the code for `dependency`, and recursively everything it depends on,
-- into the code buffer being assembled for one chain.
-- @param dependency "name" or "name:param", where name is a key of available_deps
-- @param code buffer table. Its array part receives handler-local code,
--   `code.global_header` receives factory-level code, and `code.included_deps`
--   records emitted dependencies (false = in progress, true = done).
include_dep = function (dependency, code)
	local dep, dep_param = dependency:match("^([^:]+):?(.*)$");
	local dep_info = available_deps[dep];
	if not dep_info then
		module:log("error", "Dependency not found: %s", dep);
		return;
	end
	local inclusion_state = code.included_deps[dependency];
	if inclusion_state ~= nil then
		if inclusion_state ~= true then
			module:log("error", "Circular dependency on %s", dep);
		end
		return;
	end
	code.included_deps[dependency] = false; -- Pending flag (used to detect circular references)
	for _, dep_dep in ipairs(dep_info.depends or {}) do
		include_dep(dep_dep, code);
	end
	if dep_info.global_code then
		local global_code = render_dep_code(dep, dep_info.global_code, dep_param, code);
		table.insert(code.global_header, global_code);
	end
	if dep_info.local_code then
		local local_code = render_dep_code(dep, dep_info.local_code, dep_param, code);
		table.insert(code, "\n\t\t-- "..dep.."\n\t\t"..local_code.."\n");
	end
	code.included_deps[dependency] = true;
end
307
-- Handler tables for the rule language, loaded with module:require()
local definition_handlers = module:require("definitions");
local condition_handlers = module:require("conditions");
local action_handlers = module:require("actions");

do
	-- Experimental user marks, only loaded when enabled in the config (off by default)
	local user_marks_enabled = module:get_option_boolean("firewall_experimental_user_marks", false);
	if user_marks_enabled then
		module:require("marks");
	end
end
315
--- Create an empty rule and append it to `ruleset[chain]`.
-- @param ruleset table mapping chain name -> array of rules
-- @param chain   chain name (required)
-- @return the new rule: { conditions = {}, actions = {}, deps = {} }
local function new_rule(ruleset, chain)
	assert(chain, "no chain specified");
	local fresh = { actions = {}, conditions = {}, deps = {} };
	local chain_rules = ruleset[chain];
	table.insert(chain_rules, fresh);
	return fresh;
end
322
-- Parse the lines of an already-open rules file into a ruleset.
-- Returns the ruleset (chain name -> array of rules), or nil plus an error
-- message. Does not close `file`; parse_firewall_rules() does that.
local function parse_rule_lines(filename, file)
	local line_no = 0;

	local function errmsg(err)
		return "Error compiling "..filename.." on line "..line_no..": "..err;
	end

	local ruleset = {
		deliver = {};
	};

	local chain = "deliver"; -- Default chain
	local rule;

	local state; -- nil -> "rules" -> "actions" -> nil -> ...

	local line_hold;
	for line in file:lines() do
		line = line:match("^%s*(.-)%s*$");
		-- A trailing backslash continues the line onto the next one
		if line_hold and line:sub(-1,-1) ~= "\\" then
			line = line_hold..line;
			line_hold = nil;
		elseif line:sub(-1,-1) == "\\" then
			line_hold = (line_hold or "")..line:sub(1,-2);
		end
		line_no = line_no + 1;

		local is_comment = line:find("^[#;]");
		local is_action = line:find("^[%w_ ]+[%.=]");

		if state == "actions" and not line_hold and not is_comment and line ~= "" and not is_action then
			-- Anything other than another action directly after a rule's actions
			-- (with no separating blank line) ends that rule. The rule is already
			-- in ruleset[chain] (new_rule() put it there), so it must not be
			-- inserted again. Handle this line as the start of whatever comes
			-- next instead of silently dropping it.
			state = nil;
			rule = nil;
		end

		if line_hold or is_comment then -- luacheck: ignore 542
			-- No action; comment or partial line
		elseif line == "" then
			if state == "rules" then
				return nil, ("Expected an action on line %d for preceding criteria")
					:format(line_no);
			end
			state = nil;
		elseif not(state) and line:sub(1, 2) == "::" then
			chain = line:gsub("^::%s*", "");
			local chain_info = chains[chain];
			if not chain_info then
				if chain:match("^user/") then
					chains[chain] = { type = "event", priority = 1, pass_return = false };
				else
					return nil, errmsg("Unknown chain: "..chain);
				end
			elseif chain_info.type ~= "event" then
				return nil, errmsg("Only event chains supported at the moment");
			end
			ruleset[chain] = ruleset[chain] or {};
		elseif not(state) and line:sub(1,1) == "%" then -- Definition (zone, limit, etc.)
			local what, name = line:match("^%%%s*([%w_]+) +([^ :]+)");
			if not what then
				-- Previously this crashed with "attempt to concatenate a nil value"
				return nil, errmsg("Malformed definition (expected '%TYPE name: value')");
			elseif not definition_handlers[what] then
				return nil, errmsg("Definition of unknown object: "..what);
			elseif not name or not idsafe(name) then
				return nil, errmsg("Invalid "..what.." name");
			end

			local val;
			-- ":<" right after the name (the first colon on the line) means
			-- "read the value from this file". It must be checked before the
			-- plain ": value" form, which would otherwise always match first.
			local fn = line:match("^[^:]*:< ?(.-)%s*$");
			if fn then -- Read from file
				if fn == "" then
					return nil, errmsg("Unable to parse filename");
				end
				local f, open_err = io.open(fn);
				if not f then return nil, errmsg(open_err); end
				local content = f:read("*a");
				f:close();
				if not content then
					return nil, errmsg("Unable to read "..fn);
				end
				val = content:gsub("\r?\n", " "):gsub("%s+$", "");
			else
				val = line:match(": ?(.*)$");
			end
			if not val then
				return nil, errmsg("No value given for definition");
			end
			val = stripslashes(val);
			local ok, ret = pcall(definition_handlers[what], name, val);
			if not ok then
				return nil, errmsg(ret);
			end

			if not active_definitions[what] then
				active_definitions[what] = {};
			end
			active_definitions[what][name] = ret;
		elseif is_action then
			-- Action
			if state == nil then
				-- This is a standalone action with no conditions
				rule = new_rule(ruleset, chain);
			end
			state = "actions";
			-- Action handlers?
			local action = line:match("^[%w_ ]+"):upper():gsub(" ", "_");
			if not action_handlers[action] then
				return nil, ("Unknown action on line %d: %s"):format(line_no, action or "<unknown>");
			end
			table.insert(rule.actions, "-- "..line)
			local ok, action_string, action_deps = pcall(action_handlers[action], line:match("=(.+)$"));
			if not ok then
				return nil, errmsg(action_string);
			end
			table.insert(rule.actions, action_string);
			for _, dep in ipairs(action_deps or {}) do
				table.insert(rule.deps, dep);
			end
		else
			-- Condition
			if not state then
				state = "rules";
				rule = new_rule(ruleset, chain);
			end
			-- Check standard modifiers for the condition (e.g. NOT)
			local negated;
			local condition = line:match("^[^:=%.?]*");
			if condition:find("%f[%w]NOT%f[^%w]") then
				local s, e = condition:match("%f[%w]()NOT()%f[^%w]");
				condition = (condition:sub(1,s-1)..condition:sub(e+1, -1)):match("^%s*(.-)%s*$");
				negated = true;
			end
			condition = condition:gsub(" ", "_");
			if not condition_handlers[condition] then
				return nil, ("Unknown condition on line %d: %s"):format(line_no, (condition:gsub("_", " ")));
			end
			-- Get the code for this condition
			local ok, condition_code, condition_deps = pcall(condition_handlers[condition], line:match(":%s?(.+)$"));
			if not ok then
				return nil, errmsg(condition_code);
			end
			if negated then condition_code = "not("..condition_code..")"; end
			table.insert(rule.conditions, condition_code);
			for _, dep in ipairs(condition_deps or {}) do
				table.insert(rule.deps, dep);
			end
		end
	end
	return ruleset;
end

--- Parse a firewall rules file.
-- @param filename path of the rules file
-- @return ruleset (chain name -> array of rules), or nil plus an error message
local function parse_firewall_rules(filename)
	local file, err = io.open(filename);
	if not file then return nil, err; end
	-- Parse under pcall so the file handle is closed on every exit path,
	-- including errors raised while parsing. Such errors are re-raised unchanged.
	local ok, ruleset, parse_err = pcall(parse_rule_lines, filename, file);
	file:close();
	if not ok then
		error(ruleset, 0);
	end
	return ruleset, parse_err;
end
461
--- Turn a parsed ruleset into Lua source, one chunk per chain.
-- Each chunk, when loaded and run, returns a factory
--   function (definitions, fire_event, log, module, pass_return)
-- which sets up the per-chain globals (definition tables and global_code of
-- dependencies) and returns the event handler `function (event)`.
-- A condition whose code occurs in more than one rule of a chain (ignoring a
-- leading not(...)) is evaluated once into a `conditionN` local, just before
-- the first rule that uses it, and later rules reuse that local.
-- NOTE(review): if an action in between modifies the stanza, later rules see
-- the cached value from before that action; confirm this is intended.
-- NOTE: rule.conditions is rewritten in place.
-- @param ruleset table: chain name -> array of rules (as built by new_rule)
-- @return table: chain name -> Lua source string
local function process_firewall_rules(ruleset)
	-- Compile ruleset and return complete code

	local chain_handlers = {};

	-- Loop through the chains in the parsed ruleset (e.g. deliver, preroute)
	for chain_name, rules in pairs(ruleset) do
		local code = { included_deps = {}, global_header = {} };
		-- How often each (un-negated) condition's code appears in this chain
		local condition_uses = {};
		-- This inner loop assumes chain is an event-based, not a filter-based
		-- chain (filter-based will be added later)
		for _, rule in ipairs(rules) do
			for _, condition in ipairs(rule.conditions) do
				-- NOTE(review): the greedy pattern would mis-strip code such as
				-- "not(a) or not(b)"; this only matters if such a condition is also
				-- cached (used more than once).
				if condition:find("^not%(.+%)$") then
					condition = condition:match("^not%((.+)%)$");
				end
				condition_uses[condition] = (condition_uses[condition] or 0) + 1;
			end
		end

		local condition_cache, n_conditions = {}, 0;
		for _, rule in ipairs(rules) do
			-- Handler-local dependency code is emitted right before the first rule needing it
			for _, dep in ipairs(rule.deps) do
				include_dep(dep, code);
			end
			table.insert(code, "\n\t\t");
			local rule_code;
			if #rule.conditions > 0 then
				for i, condition in ipairs(rule.conditions) do
					local negated = condition:match("^not%(.+%)$");
					if negated then
						condition = condition:match("^not%((.+)%)$");
					end
					if condition_uses[condition] > 1 then
						-- Shared condition: compute once into a local, then refer to it by name
						local name = condition_cache[condition];
						if not name then
							n_conditions = n_conditions + 1;
							name = "condition"..n_conditions;
							condition_cache[condition] = name;
							table.insert(code, "local "..name.." = "..condition..";\n\t\t");
						end
						rule.conditions[i] = (negated and "not(" or "")..name..(negated and ")" or "");
					else
						rule.conditions[i] = (negated and "not(" or "(")..condition..")";
					end
				end

				rule_code = "if "..table.concat(rule.conditions, " and ").." then\n\t\t\t"
					..table.concat(rule.actions, "\n\t\t\t")
					.."\n\t\tend\n";
			else
				-- Standalone actions run unconditionally
				rule_code = table.concat(rule.actions, "\n\t\t");
			end
			table.insert(code, rule_code);
		end

		-- Expose each definition type as a local, e.g. ZONE -> zones, LIST -> lists
		for name in pairs(definition_handlers) do
			table.insert(code.global_header, 1, "local "..name:lower().."s = definitions."..name..";");
		end

		-- NOTE(review): `db` (util.debug) is not referenced by code generated in
		-- this file; presumably kept for debugging or for action code.
		local code_string = "return function (definitions, fire_event, log, module, pass_return)\n\t"
			..table.concat(code.global_header, "\n\t")
			.."\n\tlocal db = require 'util.debug';\n\n\t"
			.."return function (event)\n\t\t"
			.."local stanza, session = event.stanza, event.origin;\n"
			..table.concat(code, "")
			.."\n\tend;\nend";

		chain_handlers[chain_name] = code_string;
	end

	return chain_handlers;
end
535
--- Parse a rules file and generate Lua source for every chain it uses.
-- @param filename path of the rules file
-- @return table chain name -> source string, or nil plus an error message
local function compile_firewall_rules(filename)
	local ruleset, err = parse_firewall_rules(filename);
	if ruleset then
		-- Parentheses keep the single-return contract of the original
		return (process_firewall_rules(ruleset));
	end
	return nil, err;
end
542
--- Load generated chain source and wrap it in a handler factory.
-- The factory accepts the value to be returned on PASS and builds a fresh
-- event handler each time it is called.
-- @param code_string Lua source produced by process_firewall_rules()
-- @param filename    chunk name, also used to create the handler's logger
-- @return factory function, or nil plus an error message
local function compile_handler(code_string, filename)
	-- Prepare event handler function
	local chunk, load_err = envload(code_string, "="..filename, _G);
	if not chunk then
		return nil, "Error compiling (probably a compiler bug, please report): "..load_err;
	end

	-- Rules fire events through this module
	local function fire_event(name, data)
		return module:fire_event(name, data);
	end

	local function build_handler(pass_return)
		local factory = chunk();
		-- Returns event handler with upvalues
		return factory(active_definitions, fire_event, logger(filename), module, pass_return);
	end
	return build_handler;
end
558
--- Resolve a firewall script path to a full path.
-- "module:<path>" is resolved against this module's directory; any other path
-- is resolved against the config directory.
local function resolve_script_path(script_path)
	local base_dir, relative_path = prosody.paths.config, script_path;
	if script_path:match("^module:") then
		-- Directory of this module's main file (the slash before the file name is kept)
		local main_file = "/mod_"..module.name..".lua";
		base_dir = module.path:sub(1, -#main_file);
		relative_path = script_path:match("^module:(.+)$");
	end
	return resolve_relative_path(base_dir, relative_path);
end
567
-- [filename] = { last_modified = ..., events_hooked = { [name] = handler } }
local loaded_scripts = {};

--- Load (or reload) a firewall script and hook its compiled chains.
-- Does nothing if the script is loaded and its modification time is unchanged.
-- If a previously loaded script fails to compile, its old rules stay active
-- instead of leaving the server without them. The next call retries.
-- @param script script path (resolved via resolve_script_path)
function load_script(script)
	script = resolve_script_path(script);
	local last_modified = (lfs.attributes(script) or {}).modification or os.time();
	local previous = loaded_scripts[script];
	if previous and previous.last_modified == last_modified then
		return; -- Already loaded, and source file hasn't changed
	end

	-- Compile before unloading anything, so a broken edit cannot remove working rules
	local chain_functions, err = compile_firewall_rules(script);
	if not chain_functions then
		module:log("error", "Error compiling %s: %s", script, err or "unknown error");
		if previous then
			module:log("warn", "Keeping previously loaded version of %s", script);
		end
		return;
	end

	if previous then
		-- Already loaded, but the source file has changed and the new version
		-- compiled: unload the old one now and hook the new version below
		module:log("debug", "Reloading %s", script);
		unload_script(script, true);
	end

	-- Loop through the chains in the script, and for each chain attach the compiled code to the
	-- relevant events, keeping track in events_hooked so we can cleanly unload later
	local events_hooked = {};
	for chain, handler_code in pairs(chain_functions) do
		local new_handler, compile_err = compile_handler(handler_code, "mod_firewall::"..chain);
		if not new_handler then
			module:log("error", "Compilation error for %s: %s", script, compile_err);
		else
			local chain_definition = chains[chain];
			if chain_definition and chain_definition.type == "event" then
				local handler = new_handler(chain_definition.pass_return);
				for _, event_name in ipairs(chain_definition) do
					events_hooked[event_name] = handler;
					module:hook(event_name, handler, chain_definition.priority);
				end
			elseif chain:sub(1, 5) ~= "user/" then
				-- Was `not chain:sub(1, 5) == "user/"`, which is always false
				module:log("warn", "Unknown chain %q", chain);
			end
			-- Every chain can also be invoked explicitly via this event
			local event_name, handler = "firewall/chains/"..chain, new_handler(false);
			events_hooked[event_name] = handler;
			module:hook(event_name, handler);
		end
	end
	loaded_scripts[script] = { last_modified = last_modified, events_hooked = events_hooked };
	module:log("debug", "Loaded %s", script);
end
616
--COMPAT w/0.9 (no module:unhook()!)
-- Remove `handler` from `event` on this host's event bus, falling back to
-- prosody.events when module.host has no entry in `hosts`.
local function module_unhook(event, handler)
	local host_entry = hosts[module.host];
	local event_bus = (host_entry or prosody).events;
	return module:unhook_object_event(event_bus, event, handler);
end
621
--- Unhook every handler registered for `script` and forget about it.
-- @param script    script path (resolved via resolve_script_path)
-- @param is_reload when true, the "Unloaded" debug message is skipped
function unload_script(script, is_reload)
	local resolved = resolve_script_path(script);
	local info = loaded_scripts[resolved];
	if not info then
		return; -- Script not loaded
	end
	local hooked = info.events_hooked;
	for event_name, handler in pairs(hooked) do
		module_unhook(event_name, handler);
		hooked[event_name] = nil;
	end
	loaded_scripts[resolved] = nil;
	if is_reload then
		return;
	end
	module:log("debug", "Unloaded %s", resolved);
end
638
--- Reconcile the loaded scripts with the set wanted by the config.
-- Every wanted script is passed to load_script(), which only recompiles when
-- the file changed. Every loaded script not in the set is unloaded.
-- @param script_list util.set of script paths
function load_unload_scripts(script_list)
	local wanted = script_list / resolve_script_path;
	local loaded_paths = {};
	for path in pairs(loaded_scripts) do
		loaded_paths[#loaded_paths + 1] = path;
	end
	-- Computed before loading, so it reflects what was loaded on entry
	local stale = set.new(loaded_paths) - wanted;
	for path in wanted do
		load_script(path);
	end
	for path in stale do
		unload_script(path);
	end
end
654
--- Module entry point: load the configured firewall_scripts, then publish the
-- resulting definitions through the shared `definitions` table.
function module.load()
	if not prosody.arg then -- Don't run in prosodyctl
		return;
	end
	local scripts = module:get_option_set("firewall_scripts", {});
	load_unload_scripts(scripts);
	-- Replace contents of definitions table (shared) with active definitions
	for def_type in pairs(definitions) do
		definitions[def_type] = nil;
	end
	for def_type, instances in pairs(active_definitions) do
		definitions[def_type] = instances;
	end
end
663
-- State carried across module reloads.
-- NOTE(review): whether this avoids recompiling scripts depends on when
-- Prosody calls restore() relative to module.load(); confirm the order.
function module.save()
	local state = {};
	state.active_definitions = active_definitions;
	state.loaded_scripts = loaded_scripts;
	return state;
end

function module.restore(state)
	loaded_scripts = state.loaded_scripts;
	active_definitions = state.active_definitions;
end
672
-- Re-read firewall_scripts whenever the server configuration is reloaded.
local function on_config_reloaded()
	local wanted = module:get_option_set("firewall_scripts", {});
	load_unload_scripts(wanted);
end
module:hook_global("config-reloaded", on_config_reloaded);
676
--- prosodyctl entry point.
--   mod_firewall [-v] <file.pfw>...  print the generated Lua for each file
--   mod_firewall test ...            run the test helper
-- With -v the output also includes the glue needed to hook each chain the
-- way load_script() does.
function module.command(arg)
	if not arg[1] or arg[1] == "--help" then
		require"util.prosodyctl".show_usage([[mod_firewall <firewall.pfw>]], [[Compile files with firewall rules to Lua code]]);
		return 1;
	end
	local verbose = arg[1] == "-v";
	if verbose then table.remove(arg, 1); end

	if arg[1] == "test" then
		table.remove(arg, 1);
		return module:require("test")(arg);
	end

	local serialize = require "util.serialization".serialize;
	if verbose then
		print("local logger = require \"util.logger\".init;");
		print();
		print("local function fire_event(name, data)\n\tmodule:fire_event(name, data)\nend");
		print();
	end

	for _, filename in ipairs(arg) do
		filename = resolve_script_path(filename);
		print("do -- File "..filename);
		local chain_functions = assert(compile_firewall_rules(filename));
		if verbose then
			print();
			print("local active_definitions = "..serialize(active_definitions)..";");
			print();
		end
		local c = 0;
		for chain, handler_code in pairs(chain_functions) do
			c = c + 1;
			print("---- Chain "..chain:gsub("_", " "));
			local chain_func_name = "chain_"..tostring(c).."_"..chain:gsub("%p", "_");
			-- handler_code begins with "return "; strip it to get the factory expression
			local factory_code = handler_code:sub(8);
			if not verbose then
				print(("%s = %s;"):format(chain_func_name, factory_code));
			else
				-- The factory's 4th parameter is `module`; omitting it would shadow
				-- the global with nil inside the generated code
				print(("local %s = (%s)(active_definitions, fire_event, logger(%q), module);"):format(chain_func_name, factory_code, filename));
				print();

				local chain_definition = chains[chain];
				if chain_definition and chain_definition.type == "event" then
					for _, event_name in ipairs(chain_definition) do
						-- %s, not %d: priorities such as 0.1 are not integers
						-- (%d raises an error on Lua 5.3+ and truncates to 0 on 5.1)
						print(("module:hook(%q, %s, %s);"):format(event_name, chain_func_name, tostring(chain_definition.priority or 0)));
					end
				end
				-- Mirrors load_script(), which hooks firewall/chains/* with the default priority
				-- (and chain_definition may be nil here)
				print(("module:hook(%q, %s);"):format("firewall/chains/"..chain, chain_func_name));
			end

			print("---- End of chain "..chain);
			print();
		end
		print("end -- End of file "..filename);
	end
end
734