
local lfs = require "lfs";
local resolve_relative_path = require "core.configmanager".resolve_relative_path;
local envload = require "util.envload".envload;
local logger = require "util.logger".init;
local it = require "util.iterators";
local set = require "util.set";

-- [definition_type] = definition_factory(param)
local definitions = module:shared("definitions");

-- When a definition instance has been instantiated, it lives here
-- [definition_type][definition_name] = definition_object
local active_definitions = {
	ZONE = {
		-- Default zone that includes all local hosts
		["$local"] = setmetatable({}, { __index = prosody.hosts });
	};
};

-- Built-in chains. Array part = event names the compiled chain is hooked to.
local default_chains = {
	preroute = {
		type = "event";
		priority = 0.1;
		"pre-message/bare", "pre-message/full", "pre-message/host";
		"pre-presence/bare", "pre-presence/full", "pre-presence/host";
		"pre-iq/bare", "pre-iq/full", "pre-iq/host";
	};
	deliver = {
		type = "event";
		priority = 0.1;
		"message/bare", "message/full", "message/host";
		"presence/bare", "presence/full", "presence/host";
		"iq/bare", "iq/full", "iq/host";
	};
	deliver_remote = {
		type = "event"; "route/remote";
		priority = 0.1;
	};
};

local extra_chains = module:get_option("firewall_extra_chains", {});

-- Effective chain table: defaults, overridden/extended by configured extra chains
local chains = {};
for k,v in pairs(default_chains) do
	chains[k] = v;
end
for k,v in pairs(extra_chains) do
	chains[k] = v;
end

-- NOTE: idsafe/meta/metaq/stripslashes are deliberately declared without `local`
-- (module environment globals); presumably they are used by the files loaded via
-- module:require below (definitions/conditions/actions) — verify before localising.

-- Returns the input if it is safe to be used as a variable name, otherwise nil
function idsafe(name)
	return name:match("^%a[%w_]*$");
end

-- Functions usable in $<...|func> expansions. Each takes a Lua expression (string)
-- and returns a new expression string plus the list of dependencies it needs.
local meta_funcs = {
	bare = function (code)
		return "jid_bare("..code..")", {"jid_bare"};
	end;
	node = function (code)
		return "(jid_split("..code.."))", {"jid_split"};
	end;
	host = function (code)
		return "(select(2, jid_split("..code..")))", {"jid_split"};
	end;
	resource = function (code)
		return "(select(3, jid_split("..code..")))", {"jid_split"};
	end;
};

-- Run quoted (%q) strings through this to allow them to contain code. e.g.: LOG=Received: $(stanza:top_tag())
-- s: a %q-quoted Lua string literal; deps: optional array that receives the names
-- of dependencies the generated code needs; extra: optional table used for $$name
-- substitutions. Returns a Lua expression string.
function meta(s, deps, extra)
	return (s:gsub("$(%b())", function (expr)
			-- $(lua expression): splice raw code into the string
			expr = expr:gsub("\\(.)", "%1");
			return [["..tostring(]]..expr..[[).."]];
		end)
		:gsub("$(%b<>)", function (expr)
			-- $<path||"default"|func|func>: stanza lookup with optional default and filters
			expr = expr:sub(2,-2);
			local default = "<undefined>";
			expr = expr:gsub("||(%b\"\")$", function (default_string)
				default = stripslashes(default_string:sub(2,-2));
				return "";
			end);
			local func_chain = expr:match("|[%w|]+$");
			if func_chain then
				expr = expr:sub(1, -1-#func_chain);
			end
			local code;
			if expr:match("^@") then
				-- Skip stanza:find() for simple attribute lookup
				local attr_name = expr:sub(2);
				if deps and (attr_name == "to" or attr_name == "from" or attr_name == "type") then
					-- These attributes may be cached in locals
					code = attr_name;
					table.insert(deps, attr_name);
				else
					code = "stanza.attr["..("%q"):format(attr_name).."]";
				end
			elseif expr:match("^%w+#$") then
				code = ("stanza:get_child_text(%q)"):format(expr:sub(1, -2));
			else
				code = ("stanza:find(%q)"):format(expr);
			end
			if func_chain then
				for func_name in func_chain:gmatch("|(%w+)") do
					-- to/from are already available in local variables, use those if possible
					if (code == "to" or code == "from") and func_name == "bare" then
						code = "bare_"..code;
						table.insert(deps, code);
					elseif (code == "to" or code == "from") and (func_name == "node" or func_name == "host" or func_name == "resource") then
						table.insert(deps, "split_"..code);
						code = code.."_"..func_name;
					else
						assert(meta_funcs[func_name], "unknown function: "..func_name);
						local new_code, new_deps = meta_funcs[func_name](code);
						code = new_code;
						if new_deps and #new_deps > 0 then
							assert(deps, "function not supported here: "..func_name);
							for _, dep in ipairs(new_deps) do
								table.insert(deps, dep);
							end
						end
					end
				end
			end
			return "\"..tostring("..code.." or "..("%q"):format(default)..")..\"";
		end)
		:gsub("$$(%a+)", extra or {})
		-- Drop empty string pieces left at either end by the substitutions above
		:gsub([[^""%.%.]], "")
		:gsub([[%.%.""$]], ""));
end

-- Quote a raw string with %q and then expand it with meta()
function metaq(s, ...)
	return meta(("%q"):format(s), ...);
end

local escape_chars = {
	a = "\a", b = "\b", f = "\f", n = "\n", r = "\r", t = "\t",
	v = "\v", ["\\"] = "\\", ["\""] = "\"", ["\'"] = "\'"
};
-- Interpret backslash escapes in s (unknown escapes are left as-is)
function stripslashes(s)
	return (s:gsub("\\(.)", escape_chars));
end

-- Dependency locations:
-- <type lib>
-- <type global>
-- function handler()
--   <local deps>
--   if <conditions> then
--     <actions>
--   end
-- end

-- Named snippets of generated code. global_code goes in the factory (once per
-- chain), local_code goes at the top of the event handler. A dependency written
-- as "name:param" calls the corresponding function with param.
local available_deps = {
	st = { global_code = [[local st = require "util.stanza";]]};
	it = { global_code = [[local it = require "util.iterators";]]};
	it_count = { global_code = [[local it_count = it.count;]], depends = { "it" } };
	current_host = { global_code = [[local current_host = module.host;]] };
	jid_split = {
		global_code = [[local jid_split = require "util.jid".split;]];
	};
	jid_bare = {
		global_code = [[local jid_bare = require "util.jid".bare;]];
	};
	to = { local_code = [[local to = stanza.attr.to or jid_bare(session.full_jid);]]; depends = { "jid_bare" } };
	from = { local_code = [[local from = stanza.attr.from;]] };
	type = { local_code = [[local type = stanza.attr.type;]] };
	name = { local_code = [[local name = stanza.name;]] };
	split_to = { -- The stanza's split to address
		depends = { "jid_split", "to" };
		local_code = [[local to_node, to_host, to_resource = jid_split(to);]];
	};
	split_from = { -- The stanza's split from address
		depends = { "jid_split", "from" };
		local_code = [[local from_node, from_host, from_resource = jid_split(from);]];
	};
	bare_to = { depends = { "jid_bare", "to" }, local_code = "local bare_to = jid_bare(to)"};
	bare_from = { depends = { "jid_bare", "from" }, local_code = "local bare_from = jid_bare(from)"};
	group_contains = {
		global_code = [[local group_contains = module:depends("groups").group_contains]];
	};
	is_admin = { global_code = [[local is_admin = require "core.usermanager".is_admin;]]};
	core_post_stanza = { global_code = [[local core_post_stanza = prosody.core_post_stanza;]] };
	zone = { global_code = function (zone)
		local var = zone;
		if var == "$local" then
			var = "_local"; -- See #1090
		else
			assert(idsafe(var), "Invalid zone name: "..zone);
		end
		return ("local zone_%s = zones[%q] or {};"):format(var, zone);
	end };
	date_time = { global_code = [[local os_date = os.date]]; local_code = [[local current_date_time = os_date("*t");]] };
	time = { local_code = function (what)
		local defs = {};
		for field in what:gmatch("%a+") do
			table.insert(defs, ("local current_%s = current_date_time.%s;"):format(field, field));
		end
		return table.concat(defs, " ");
	end, depends = { "date_time" }; };
	timestamp = { global_code = [[local get_time = require "socket".gettime;]]; local_code = [[local current_timestamp = get_time();]]; };
	globalthrottle = {
		global_code = function (throttle)
			assert(idsafe(throttle), "Invalid rate limit name: "..throttle);
			assert(active_definitions.RATE[throttle], "Unknown rate limit: "..throttle);
			return ("local global_throttle_%s = rates.%s:single();"):format(throttle, throttle);
		end;
	};
	multithrottle = {
		global_code = function (throttle)
			assert(pcall(require, "util.cache"), "Using LIMIT with 'on' requires Prosody 0.10 or higher");
			assert(idsafe(throttle), "Invalid rate limit name: "..throttle);
			assert(active_definitions.RATE[throttle], "Unknown rate limit: "..throttle);
			return ("local multi_throttle_%s = rates.%s:multi();"):format(throttle, throttle);
		end;
	};
	rostermanager = {
		global_code = [[local rostermanager = require "core.rostermanager";]];
	};
	roster_entry = {
		local_code = [[local roster_entry = (to_node and rostermanager.load_roster(to_node, to_host) or {})[bare_from];]];
		depends = { "rostermanager", "split_to", "bare_from" };
	};
	list = { global_code = function (list)
			assert(idsafe(list), "Invalid list name: "..list);
			assert(active_definitions.LIST[list], "Unknown list: "..list);
			return ("local list_%s = lists[%q];"):format(list, list);
		end
	};
	search = {
		local_code = function (search_name)
			local search_path = assert(active_definitions.SEARCH[search_name], "Undefined search path: "..search_name);
			return ("local search_%s = tostring(stanza:find(%q) or \"\")"):format(search_name, search_path);
		end;
	};
	pattern = {
		local_code = function (pattern_name)
			local pattern = assert(active_definitions.PATTERN[pattern_name], "Undefined pattern: "..pattern_name);
			return ("local pattern_%s = %q"):format(pattern_name, pattern);
		end;
	};
	tokens = {
		local_code = function (search_and_pattern)
			local search_name, pattern_name = search_and_pattern:match("^([^%-]+)-(.+)$");
			local code = ([[local tokens_%s_%s = {};
			if search_%s then
				for s in search_%s:gmatch(pattern_%s) do
					tokens_%s_%s[s] = true;
				end
			end
			]]):format(search_name, pattern_name, search_name, search_name, pattern_name, search_name, pattern_name);
			return code, { "search:"..search_name, "pattern:"..pattern_name };
		end;
	};
	scan_list = {
		global_code = [[local function scan_list(list, items) for item in pairs(items) do if list:contains(item) then return true; end end end]];
	}
};

-- Add the code for `dependency` (and, first, everything it depends on) to `code`.
-- code.included_deps tracks state per dependency: nil = not seen, false = being
-- included (used to detect cycles), true = done. Global code is appended to
-- code.global_header; local code is appended to the array part of `code`.
local function include_dep(dependency, code)
	local dep, dep_param = dependency:match("^([^:]+):?(.*)$");
	local dep_info = available_deps[dep];
	if not dep_info then
		module:log("error", "Dependency not found: %s", dep);
		return;
	end
	if code.included_deps[dependency] ~= nil then
		if code.included_deps[dependency] ~= true then
			module:log("error", "Circular dependency on %s", dep);
		end
		return;
	end
	code.included_deps[dependency] = false; -- Pending flag (used to detect circular references)
	for _, dep_dep in ipairs(dep_info.depends or {}) do
		include_dep(dep_dep, code);
	end
	if dep_info.global_code then
		if dep_param ~= "" then
			-- Parameterised dependency: global_code is a generator function
			local global_code, deps = dep_info.global_code(dep_param);
			if deps then
				for _, dep_dep in ipairs(deps) do
					include_dep(dep_dep, code);
				end
			end
			table.insert(code.global_header, global_code);
		else
			table.insert(code.global_header, dep_info.global_code);
		end
	end
	if dep_info.local_code then
		if dep_param ~= "" then
			local local_code, deps = dep_info.local_code(dep_param);
			if deps then
				for _, dep_dep in ipairs(deps) do
					include_dep(dep_dep, code);
				end
			end
			table.insert(code, "\n\t\t-- "..dep.."\n\t\t"..local_code.."\n");
		else
			table.insert(code, "\n\t\t-- "..dep.."\n\t\t"..dep_info.local_code.."\n");
		end
	end
	code.included_deps[dependency] = true;
end

local definition_handlers = module:require("definitions");
local condition_handlers = module:require("conditions");
local action_handlers = module:require("actions");

if module:get_option_boolean("firewall_experimental_user_marks", false) then
	module:require"marks";
end

-- Create an empty rule and append it to ruleset[chain]
local function new_rule(ruleset, chain)
	assert(chain, "no chain specified");
	local rule = { conditions = {}, actions = {}, deps = {} };
	table.insert(ruleset[chain], rule);
	return rule;
end

-- Parse a firewall script into a ruleset: { [chain_name] = { rule, ... } } where
-- each rule is { conditions = {code...}, actions = {code...}, deps = {name...} }.
-- Returns the ruleset, or nil plus an error message.
local function parse_firewall_rules(filename)
	local line_no = 0;

	local function errmsg(err)
		return "Error compiling "..filename.." on line "..line_no..": "..err;
	end

	local ruleset = {
		deliver = {};
	};

	local chain = "deliver"; -- Default chain
	local rule;

	local file, err = io.open(filename);
	if not file then return nil, err; end

	-- Read the whole script up front so the handle is closed even when
	-- parsing bails out early with an error below.
	local lines = {};
	for line in file:lines() do
		lines[#lines+1] = line;
	end
	file:close();

	local state; -- nil -> "rules" -> "actions" -> nil -> ...

	local line_hold;
	for _, line in ipairs(lines) do
		line = line:match("^%s*(.-)%s*$");
		-- A trailing backslash continues the line onto the next one
		if line_hold and line:sub(-1,-1) ~= "\\" then
			line = line_hold..line;
			line_hold = nil;
		elseif line:sub(-1,-1) == "\\" then
			line_hold = (line_hold or "")..line:sub(1,-2);
		end
		line_no = line_no + 1;

		if state == "actions" and not line_hold and line ~= ""
		and not line:find("^[#;]") and not line:find("^[%w_ ]+[%.=]") then
			-- A non-action line directly after a rule's actions ends that rule.
			-- (Previously the finished rule was inserted into the chain a second
			-- time and this line was silently dropped, which could turn the
			-- following actions into an unconditional rule.) The line is now
			-- processed normally below, e.g. as the first condition of a new rule.
			state = nil;
			rule = nil;
		end

		if line_hold or line:find("^[#;]") then -- luacheck: ignore 542
			-- No action; comment or partial line
		elseif line == "" then
			if state == "rules" then
				return nil, ("Expected an action on line %d for preceding criteria")
					:format(line_no);
			end
			state = nil;
		elseif not(state) and line:sub(1, 2) == "::" then
			chain = line:gsub("^::%s*", "");
			local chain_info = chains[chain];
			if not chain_info then
				if chain:match("^user/") then
					chains[chain] = { type = "event", priority = 1, pass_return = false };
				else
					return nil, errmsg("Unknown chain: "..chain);
				end
			elseif chain_info.type ~= "event" then
				return nil, errmsg("Only event chains supported at the moment");
			end
			ruleset[chain] = ruleset[chain] or {};
		elseif not(state) and line:sub(1,1) == "%" then -- Definition (zone, limit, etc.)
			local what, name = line:match("^%%%s*([%w_]+) +([^ :]+)");
			if not what then
				-- Avoid concatenating nil into the error message below
				return nil, errmsg("Unable to parse definition");
			elseif not definition_handlers[what] then
				return nil, errmsg("Definition of unknown object: "..what);
			elseif not name or not idsafe(name) then
				return nil, errmsg("Invalid "..what.." name");
			end

			local val = line:match(": ?(.*)$");
			-- NOTE(review): `val` is non-nil whenever the line contains ':' (including
			-- ":<"), so this read-from-file branch appears unreachable as written.
			if not val and line:find(":<") then -- Read from file
				local fn = line:match(":< ?(.-)%s*$");
				if not fn then
					return nil, errmsg("Unable to parse filename");
				end
				local f, err = io.open(fn);
				if not f then return nil, errmsg(err); end
				local contents = f:read("*a");
				f:close();
				val = contents:gsub("\r?\n", " "):gsub("%s+$", "");
			end
			if not val then
				return nil, errmsg("No value given for definition");
			end
			val = stripslashes(val);
			local ok, ret = pcall(definition_handlers[what], name, val);
			if not ok then
				return nil, errmsg(ret);
			end

			if not active_definitions[what] then
				active_definitions[what] = {};
			end
			active_definitions[what][name] = ret;
		elseif line:find("^[%w_ ]+[%.=]") then
			-- Action
			if state == nil then
				-- This is a standalone action with no conditions
				rule = new_rule(ruleset, chain);
			end
			state = "actions";
			-- Action handlers?
			local action = line:match("^[%w_ ]+"):upper():gsub(" ", "_");
			if not action_handlers[action] then
				return nil, ("Unknown action on line %d: %s"):format(line_no, action or "<unknown>");
			end
			table.insert(rule.actions, "-- "..line)
			local ok, action_string, action_deps = pcall(action_handlers[action], line:match("=(.+)$"));
			if not ok then
				return nil, errmsg(action_string);
			end
			table.insert(rule.actions, action_string);
			for _, dep in ipairs(action_deps or {}) do
				table.insert(rule.deps, dep);
			end
		else
			-- Condition
			if not state then
				state = "rules";
				rule = new_rule(ruleset, chain);
			end
			-- Check standard modifiers for the condition (e.g. NOT)
			local negated;
			local condition = line:match("^[^:=%.?]*");
			if condition:find("%f[%w]NOT%f[^%w]") then
				local s, e = condition:match("%f[%w]()NOT()%f[^%w]");
				condition = (condition:sub(1,s-1)..condition:sub(e+1, -1)):match("^%s*(.-)%s*$");
				negated = true;
			end
			condition = condition:gsub(" ", "_");
			if not condition_handlers[condition] then
				return nil, ("Unknown condition on line %d: %s"):format(line_no, (condition:gsub("_", " ")));
			end
			-- Get the code for this condition
			local ok, condition_code, condition_deps = pcall(condition_handlers[condition], line:match(":%s?(.+)$"));
			if not ok then
				return nil, errmsg(condition_code);
			end
			if negated then condition_code = "not("..condition_code..")"; end
			table.insert(rule.conditions, condition_code);
			for _, dep in ipairs(condition_deps or {}) do
				table.insert(rule.deps, dep);
			end
		end
	end
	return ruleset;
end

-- Turn a parsed ruleset into Lua source: { [chain_name] = code_string }, where each
-- code_string is "return function (definitions, fire_event, log, module, pass_return) ..."
-- evaluating to a factory that returns the chain's event handler.
local function process_firewall_rules(ruleset)
	-- Compile ruleset and return complete code

	local chain_handlers = {};

	-- Loop through the chains in the parsed ruleset (e.g. incoming, outgoing)
	for chain_name, rules in pairs(ruleset) do
		local code = { included_deps = {}, global_header = {} };
		local condition_uses = {};
		-- This inner loop assumes chain is an event-based, not a filter-based
		-- chain (filter-based will be added later)
		-- First pass: count how often each (un-negated) condition is used
		for _, rule in ipairs(rules) do
			for _, condition in ipairs(rule.conditions) do
				if condition:find("^not%(.+%)$") then
					condition = condition:match("^not%((.+)%)$");
				end
				condition_uses[condition] = (condition_uses[condition] or 0) + 1;
			end
		end

		-- Second pass: emit code; conditions used more than once are cached in locals
		local condition_cache, n_conditions = {}, 0;
		for _, rule in ipairs(rules) do
			for _, dep in ipairs(rule.deps) do
				include_dep(dep, code);
			end
			table.insert(code, "\n\t\t");
			local rule_code;
			if #rule.conditions > 0 then
				for i, condition in ipairs(rule.conditions) do
					local negated = condition:match("^not%(.+%)$");
					if negated then
						condition = condition:match("^not%((.+)%)$");
					end
					if condition_uses[condition] > 1 then
						local name = condition_cache[condition];
						if not name then
							n_conditions = n_conditions + 1;
							name = "condition"..n_conditions;
							condition_cache[condition] = name;
							table.insert(code, "local "..name.." = "..condition..";\n\t\t");
						end
						rule.conditions[i] = (negated and "not(" or "")..name..(negated and ")" or "");
					else
						rule.conditions[i] = (negated and "not(" or "(")..condition..")";
					end
				end

				rule_code = "if "..table.concat(rule.conditions, " and ").." then\n\t\t\t"
					..table.concat(rule.actions, "\n\t\t\t")
					.."\n\t\tend\n";
			else
				rule_code = table.concat(rule.actions, "\n\t\t");
			end
			table.insert(code, rule_code);
		end

		-- e.g. "local zones = definitions.ZONE;" for every definition type
		for name in pairs(definition_handlers) do
			table.insert(code.global_header, 1, "local "..name:lower().."s = definitions."..name..";");
		end

		local code_string = "return function (definitions, fire_event, log, module, pass_return)\n\t"
			..table.concat(code.global_header, "\n\t")
			.."\n\tlocal db = require 'util.debug';\n\n\t"
			.."return function (event)\n\t\t"
			.."local stanza, session = event.stanza, event.origin;\n"
			..table.concat(code, "")
			.."\n\tend;\nend";

		chain_handlers[chain_name] = code_string;
	end

	return chain_handlers;
end

-- Parse and compile a firewall script file.
-- Returns { [chain_name] = code_string }, or nil plus an error message.
local function compile_firewall_rules(filename)
	local ruleset, err = parse_firewall_rules(filename);
	if not ruleset then return nil, err; end
	local chain_handlers = process_firewall_rules(ruleset);
	return chain_handlers;
end
-- Compile handler code into a factory that produces a valid event handler. Factory accepts
-- a value to be returned on PASS
-- Returns the factory, or nil plus an error message if the generated code fails to compile.
local function compile_handler(code_string, filename)
	-- Prepare event handler function
	local chunk, err = envload(code_string, "="..filename, _G);
	if not chunk then
		return nil, "Error compiling (probably a compiler bug, please report): "..err;
	end
	local function fire_event(name, data)
		return module:fire_event(name, data);
	end
	return function (pass_return)
		return chunk()(active_definitions, fire_event, logger(filename), module, pass_return); -- Returns event handler with upvalues
	end
end

-- Resolve a script path relative to the config directory, or relative to this
-- module's own directory when prefixed with "module:".
local function resolve_script_path(script_path)
	local relative_to = prosody.paths.config;
	if script_path:match("^module:") then
		relative_to = module.path:sub(1, -#("/mod_"..module.name..".lua"));
		script_path = script_path:match("^module:(.+)$");
	end
	return resolve_relative_path(relative_to, script_path);
end

-- [filename] = { last_modified = ..., events_hooked = { [name] = handler } }
local loaded_scripts = {};

-- Load (or reload, if its modification time changed) a firewall script and hook
-- its compiled chains. Errors are logged, not raised.
function load_script(script)
	script = resolve_script_path(script);
	local last_modified = (lfs.attributes(script) or {}).modification or os.time();
	if loaded_scripts[script] then
		if loaded_scripts[script].last_modified == last_modified then
			return; -- Already loaded, and source file hasn't changed
		end
		module:log("debug", "Reloading %s", script);
		-- Already loaded, but the source file has changed
		-- unload it now, and we'll load the new version below
		unload_script(script, true);
	end
	local chain_functions, err = compile_firewall_rules(script);

	if not chain_functions then
		module:log("error", "Error compiling %s: %s", script, err or "unknown error");
		return;
	end

	-- Loop through the chains in the script, and for each chain attach the compiled code to the
	-- relevant events, keeping track in events_hooked so we can cleanly unload later
	local events_hooked = {};
	for chain, handler_code in pairs(chain_functions) do
		local new_handler, err = compile_handler(handler_code, "mod_firewall::"..chain);
		if not new_handler then
			module:log("error", "Compilation error for %s: %s", script, err);
		else
			local chain_definition = chains[chain];
			if chain_definition and chain_definition.type == "event" then
				local handler = new_handler(chain_definition.pass_return);
				for _, event_name in ipairs(chain_definition) do
					events_hooked[event_name] = handler;
					module:hook(event_name, handler, chain_definition.priority);
				end
			elseif chain:sub(1, 5) ~= "user/" then
				-- Was `not chain:sub(1, 5) == "user/"`, which parses as
				-- `(not chain:sub(1, 5)) == "user/"` and is always false.
				module:log("warn", "Unknown chain %q", chain);
			end
			-- Every chain is also reachable as an event for JUMP_CHAIN-style calls
			local event_name, handler = "firewall/chains/"..chain, new_handler(false);
			events_hooked[event_name] = handler;
			module:hook(event_name, handler);
		end
	end
	loaded_scripts[script] = { last_modified = last_modified, events_hooked = events_hooked };
	module:log("debug", "Loaded %s", script);
end

--COMPAT w/0.9 (no module:unhook()!)
local function module_unhook(event, handler)
	return module:unhook_object_event((hosts[module.host] or prosody).events, event, handler);
end

-- Unhook every handler registered for a loaded script and forget it.
-- is_reload suppresses the "Unloaded" debug message.
function unload_script(script, is_reload)
	script = resolve_script_path(script);
	local script_info = loaded_scripts[script];
	if not script_info then
		return; -- Script not loaded
	end
	local events_hooked = script_info.events_hooked;
	for event_name, event_handler in pairs(events_hooked) do
		module_unhook(event_name, event_handler);
		events_hooked[event_name] = nil;
	end
	loaded_scripts[script] = nil;
	if not is_reload then
		module:log("debug", "Unloaded %s", script);
	end
end
-- Given a set of scripts (e.g. from config) figure out which ones need to
-- be loaded, which are already loaded but need unloading, and which to reload
function load_unload_scripts(script_list)
	local wanted_scripts = script_list / resolve_script_path;
	local currently_loaded = set.new(it.to_array(it.keys(loaded_scripts)));
	local scripts_to_unload = currently_loaded - wanted_scripts;
	for script in wanted_scripts do
		-- If the script is already loaded, this is fine - it will
		-- reload the script for us if the file has changed
		load_script(script);
	end
	for script in scripts_to_unload do
		unload_script(script);
	end
end

function module.load()
	if not prosody.arg then return end -- Don't run in prosodyctl
	local firewall_scripts = module:get_option_set("firewall_scripts", {});
	load_unload_scripts(firewall_scripts);
	-- Replace contents of definitions table (shared) with active definitions
	for k in it.keys(definitions) do definitions[k] = nil; end
	for k,v in pairs(active_definitions) do definitions[k] = v; end
end

function module.save()
	return { active_definitions = active_definitions, loaded_scripts = loaded_scripts };
end

function module.restore(state)
	active_definitions = state.active_definitions;
	loaded_scripts = state.loaded_scripts;
end

module:hook_global("config-reloaded", function ()
	load_unload_scripts(module:get_option_set("firewall_scripts", {}));
end);

-- prosodyctl entry point: print the compiled Lua for the given script files.
-- "-v" prints a self-contained module (definitions, factory calls and hooks);
-- "test" delegates to the bundled test runner.
function module.command(arg)
	if not arg[1] or arg[1] == "--help" then
		require"util.prosodyctl".show_usage([[mod_firewall <firewall.pfw>]], [[Compile files with firewall rules to Lua code]]);
		return 1;
	end
	local verbose = arg[1] == "-v";
	if verbose then table.remove(arg, 1); end

	if arg[1] == "test" then
		table.remove(arg, 1);
		return module:require("test")(arg);
	end

	local serialize = require "util.serialization".serialize;
	if verbose then
		print("local logger = require \"util.logger\".init;");
		print();
		print("local function fire_event(name, data)\n\tmodule:fire_event(name, data)\nend");
		print();
	end

	for _, filename in ipairs(arg) do
		filename = resolve_script_path(filename);
		print("do -- File "..filename);
		local chain_functions = assert(compile_firewall_rules(filename));
		if verbose then
			print();
			print("local active_definitions = "..serialize(active_definitions)..";");
			print();
		end
		local c = 0;
		for chain, handler_code in pairs(chain_functions) do
			c = c + 1;
			print("---- Chain "..chain:gsub("_", " "));
			local chain_func_name = "chain_"..tostring(c).."_"..chain:gsub("%p", "_");
			if not verbose then
				-- handler_code:sub(8) strips the leading "return "
				print(("%s = %s;"):format(chain_func_name, handler_code:sub(8)));
			else
				-- Pass `module` too: the generated factory takes it as a parameter
				-- (used by e.g. current_host / group_contains), and omitting it
				-- would shadow the global with nil.
				print(("local %s = (%s)(active_definitions, fire_event, logger(%q), module);"):format(chain_func_name, handler_code:sub(8), filename));
				print();

				local chain_definition = chains[chain];
				-- %g rather than %d: default priorities are 0.1, which %d prints
				-- as 0 on Lua 5.1 and rejects outright on Lua 5.3+.
				local priority = chain_definition and chain_definition.priority or 0;
				if chain_definition and chain_definition.type == "event" then
					for _, event_name in ipairs(chain_definition) do
						print(("module:hook(%q, %s, %g);"):format(event_name, chain_func_name, priority));
					end
				end
				print(("module:hook(%q, %s, %g);"):format("firewall/chains/"..chain, chain_func_name, priority));
			end

			print("---- End of chain "..chain);
			print();
		end
		print("end -- End of file "..filename);
	end
end