1-- SPDX-License-Identifier: GPL-3.0-or-later
2-- Module interface
3local ffi = require('ffi')
-- Subnet rewrite rules appended by add_prefix() (called from config()). The 'finish'
-- layer below is built over this same table, so rules added after load still apply.
local prefixes_global = {}
5
-- Create subnet prefix rule
-- Parses `subnet` (textual subnet, e.g. '10.0.0.0/8') and the replacement address `addr`.
-- Returns the rule tuple {subnet bytes (cdata char[16]), prefix length in bits,
-- target address bytes, record type to rewrite}.
-- Raises an error when either the address or the subnet cannot be parsed.
local function matchprefix(subnet, addr)
	local target = kres.str2ip(addr)
	if target == nil then
		error('[renumber] invalid address: '..addr)
	end
	-- Addresses containing ':' are treated as IPv6 and rewrite AAAA records, others A records
	local rtype
	if string.find(addr, ':', 1, true) then
		rtype = kres.type.AAAA
	else
		rtype = kres.type.A
	end
	local net = ffi.new('char[16]')
	local nbits = ffi.C.kr_straddr_subnet(net, subnet)
	if nbits < 0 then
		error('[renumber] invalid subnet: '..subnet)
	end
	return {net, nbits, target, rtype}
end
16
-- Create name match rule
-- Records owned exactly by `name` get their whole address replaced by `addr`.
-- Returns the rule tuple {owner name (as produced by todname), nil, target address bytes,
-- record type}; the nil prefix length makes the whole target address be copied.
-- Raises an error when the address or the name cannot be parsed.
local function matchname(name, addr)
	local target = kres.str2ip(addr)
	if target == nil then error('[renumber] invalid address: '..addr) end
	local owner = todname(name)
	-- Validate the conversion result rather than the input string: a truthy `name` says
	-- nothing about whether it parsed. (Assumes todname() yields nil on failure — confirm.)
	if not owner then error('[renumber] invalid name: '..tostring(name)) end
	-- Addresses containing ':' are treated as IPv6 and rewrite AAAA records, others A records
	local addrtype = string.find(addr, ':', 1, true) and kres.type.AAAA or kres.type.A
	return {owner, nil, target, addrtype}
end
26
-- Add subnet prefix rewrite rule
-- Builds the rule with matchprefix() and appends it to the module-wide rule list.
-- Prefix lengths that are not a multiple of 8 only produce a warning; the rule is still added.
local function add_prefix(subnet, addr)
	local entry = matchprefix(subnet, addr)
	local nbits = entry[2]
	local partial_octet = nbits ~= nil and nbits % 8 ~= 0
	if partial_octet then
		log_warn(ffi.C.LOG_GRP_RENUMBER, 'network mask: only /8, /16, /24 etc. are supported (entire octets are rewritten)')
	end
	prefixes_global[#prefixes_global + 1] = entry
end
36
-- Match IP against given subnet or record owner
-- For prefix rules `subnet`/`bitlen` are the subnet bytes and prefix length in bits; for
-- name rules they are the owner name and nil. Returns true when `rr` has type `addrtype`
-- and either its address lies within the subnet or its owner equals `subnet`.
local function match_subnet(subnet, bitlen, addrtype, rr)
	if rr.type ~= addrtype then
		return false
	end
	if bitlen then
		local rdata = rr.rdata
		-- Require enough address bytes to cover the prefix before the bitwise compare
		if #rdata >= bitlen / 8 and ffi.C.kr_bitcmp(subnet, rdata, bitlen) == 0 then
			return true
		end
	end
	-- Name rules (and subnet misses) fall back to an exact owner comparison.
	-- NOTE(review): this compares values byte-for-byte; confirm owner names share letter case.
	return subnet == rr.owner
end
43
-- Renumber address record
-- Scratch buffer for building the rewritten address; 16 bytes holds the longest address
-- handled here (AAAA rdata).
local addr_buf = ffi.new('char[16]')

-- Apply the first applicable rule in `tbl` to the A/AAAA record `rr`.
-- Each rule is {subnet bytes or owner name, bit length or nil, target address bytes, record type}.
-- Returns `rr` (rdata modified in place) when a rule rewrote it, or nil when nothing changed,
-- so a non-nil result always means the record really was renumbered.
local function renumber_record(tbl, rr)
	for i = 1, #tbl do
		local prefix = tbl[i]
		-- Match record type to address family and record address to given subnet
		-- If provided, compare record owner to prefix name
		if match_subnet(prefix[1], prefix[2], prefix[4], rr) then
			-- Leading octets taken from the target: whole octets of the subnet prefix, or the
			-- entire target address for name rules (which carry no bit length)
			local target = prefix[3]
			local to_copy = prefix[2] or (#target * 8)
			local chunks = math.floor(to_copy / 8)
			local rdata = rr.rdata
			local rdlen = #rdata
			-- Rewrite only when the record holds at least `chunks` octets, fits the scratch
			-- buffer, and the target supplies that many octets. Otherwise leave the record
			-- untouched and keep looking, instead of reporting a rewrite that never happened.
			if rdlen >= chunks and rdlen <= ffi.sizeof(addr_buf) and #target >= chunks then
				ffi.copy(addr_buf, rdata, rdlen)
				ffi.copy(addr_buf, target, chunks) -- Rewrite prefix
				rr.rdata = ffi.string(addr_buf, rdlen)
				return rr
			end
		end
	end
	return nil
end
65
-- Renumber addresses based on config
-- Returns a layer callback (state, req) that rewrites A/AAAA records in the ANSWER section
-- of req.answer according to the rule list `prefixes`. The list is read on every call,
-- so rules appended later are honoured. If anything was rewritten, the packet is rebuilt
-- from its question plus the ANSWER records, leaving out RRSIGs.
local function rule(prefixes)
	return function (state, req)
		-- Failed requests pass through untouched
		if state == kres.FAIL then
			return state
		end
		local answer = req.answer
		-- Nothing to do for an empty ANSWER section (the rcode itself is not inspected)
		local rrs = answer:section(kres.section.ANSWER)
		local count = #rrs
		if count == 0 then
			return state
		end
		-- Try every address record against the rules; remember whether any rule fired
		local rewritten = false
		for idx = 1, count do
			local rec = rrs[idx]
			local rtype = rec.type
			if rtype == kres.type.A or rtype == kres.type.AAAA then
				local replacement = renumber_record(prefixes, rec)
				if replacement ~= nil then
					rrs[idx] = replacement
					rewritten = true
				end
			end
		end
		-- Untouched answers continue down the chain as they are
		if not rewritten then
			return state
		end
		-- Rebuild the packet. The question is read before recycle() and written back after
		-- it (recycle() presumably clears it). Only ANSWER records are re-added.
		-- NOTE(review): other sections and header flags are not restored here — confirm intended.
		local qname = answer:qname()
		local qclass = answer:qclass()
		local qtype = answer:qtype()
		answer:recycle()
		answer:question(qname, qclass, qtype)
		for idx = 1, count do
			local rec = rrs[idx]
			-- Signatures are dropped: rewritten data could no longer be validated against them
			if rec.type ~= kres.type.RRSIG then
				answer:put(rec.owner, rec.ttl, rec.class, rec.type, rec.rdata)
			end
		end
		return state
	end
end
105
-- Export module interface
-- prefix/name build rule tuples, rule builds a layer callback over a rule list, and
-- match_subnet is the record matcher used by the renumbering code.
local M = {
	prefix = matchprefix,
	name = matchname,
	rule = rule,
	match_subnet = match_subnet,
}
113
-- Config
-- Accepts a list of {subnet, target_address} pairs and registers each one as a prefix
-- rewrite rule via add_prefix(). Rules accumulate across calls; nothing is cleared.
-- A nil `conf` is a no-op. Raises an error when `conf` is not a non-empty list of tables,
-- or when a subnet/address fails to parse.
function M.config (conf)
	if conf == nil then return end
	local usage = '[renumber] expected { {prefix, target}, ... }'
	if type(conf) ~= 'table' or type(conf[1]) ~= 'table' then
		error(usage)
	end
	-- Check the shape of every entry before registering any of them. A malformed later
	-- entry then gets the descriptive error, instead of an opaque index error raised after
	-- the earlier rules were already added.
	for i = 2, #conf do
		if type(conf[i]) ~= 'table' then error(usage) end
	end
	for i = 1, #conf do add_prefix(conf[i][1], conf[i][2]) end
end
122
-- Layers
-- Renumbering runs in the 'finish' phase over the shared rule list, so rules registered
-- later through config() are picked up by this same callback.
local finish_cb = rule(prefixes_global)
M.layer = { finish = finish_cb }

return M
129