1--[[
2Copyright (c) 2021, Vsevolod Stakhov <vsevolod@highsecure.ru>
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15]]--
16
17--[[
18external_relay plugin - sets IP/hostname from Received headers
19]]--
20
-- `confighelp` is a global provided by rspamd (presumably set while it only
-- collects configuration help); bail out before registering anything.
if confighelp then return end
24
25local lua_maps = require "lua_maps"
26local lua_util = require "lua_util"
27local rspamd_logger = require "rspamd_logger"
28local ts = require("tableshape").types
29
-- Shared empty table: a nil-safe fallback when indexing optional fields
local E = {}
-- Module name: used as the config section, debug-log module and symbol group
local N = "external_relay"

-- Module defaults; replaced by the merged configuration at load time
local settings = { rules = {} }
36
-- Validation schema for the module configuration.
-- `rules` maps a rule name to a rule table. Each rule must match one of the
-- shapes below; the shapes are told apart by their literal `strategy` value.
-- For every strategy, `priority` (prefilter priority, 20 when omitted at
-- registration) and `symbol` (defaults to the rule name) are optional.
-- NOTE(review): `enabled` is accepted here but not read anywhere in this
-- module; presumably it is consumed by rspamd's generic module switch.
local config_schema = ts.shape{
  enabled = ts.boolean:is_optional(),
  rules = ts.map_of(
    ts.string, ts.one_of{
      -- strategy 'authenticated': optional `user_map` (a set of user names)
      -- restricts the rule to the listed authenticated users
      ts.shape{
        priority = ts.number:is_optional(),
        strategy = 'authenticated',
        symbol = ts.string:is_optional(),
        user_map = lua_maps.map_schema:is_optional(),
      },
      -- strategy 'count': `count` is the 1-based index of the Received
      -- header to take the sender data from (required)
      ts.shape{
        count = ts.number,
        priority = ts.number:is_optional(),
        strategy = 'count',
        symbol = ts.string:is_optional(),
      },
      -- strategy 'local': no strategy-specific options
      ts.shape{
        priority = ts.number:is_optional(),
        strategy = 'local',
        symbol = ts.string:is_optional(),
      },
      -- strategy 'hostname_map': required map of relay hostnames; a value of
      -- 'direct' marks the relay that connects to this server
      ts.shape{
        hostname_map = lua_maps.map_schema,
        priority = ts.number:is_optional(),
        strategy = 'hostname_map',
        symbol = ts.string:is_optional(),
      },
    }
  ),
}
67
--- Override the task's sender data with values taken from one Received header.
-- The sender IP comes from `rcvd.real_ip`.
-- The hostname and HELO both come from `rcvd.from_hostname`.
-- When the header has no hostname, a bracketed IP literal (e.g. "[192.0.2.1]")
-- is used for both hostname and HELO instead.
-- @param task rspamd task to modify
-- @param rcvd table describing one Received header (an element of
--   task:get_received_headers())
-- @return true on success; nil (after logging an error) when the header
--   carries no valid IP
local function set_from_rcvd(task, rcvd)
  local rcvd_ip = rcvd.real_ip
  if not (rcvd_ip and rcvd_ip:is_valid()) then
    rspamd_logger.errx(task, 'no IP in header: %s', rcvd)
    return
  end
  task:set_from_ip(rcvd_ip)

  local hostname = rcvd.from_hostname
  if hostname then
    task:set_hostname(hostname)
    task:set_helo(hostname) -- use fake value for HELO
  else
    rspamd_logger.warnx(task, "couldn't get hostname from headers")
    -- Convert explicitly: on PUC Lua 5.1, string.format('%s') does not call
    -- __tostring and raises "string expected" for an IP object.
    local ipstr = string.format('[%s]', tostring(rcvd_ip))
    task:set_hostname(ipstr) -- returns nil from task:get_hostname()
    task:set_helo(ipstr)
  end
  return true
end
86
-- Strategy constructors, keyed by the rule's `strategy` option.
-- Each constructor is called once per configured rule with the rule table.
-- It returns the prefilter callback, or nil when the rule cannot be set up;
-- in that case no symbol is registered for the rule.
local strategies = {}

--- Strategy 'authenticated'.
-- Applies to mail from an authenticated user. When `rule.user_map` is
-- configured, only users listed in it are considered.
-- The sender data is taken from the first Received header that lacks the
-- `authenticated` flag.
strategies.authenticated = function(rule)
  local user_map = nil
  if rule.user_map then
    user_map = lua_maps.map_add_from_ucl(rule.user_map, 'set', 'external relay usernames')
    if not user_map then
      rspamd_logger.errx(rspamd_config, "couldn't add map %s; won't register symbol %s",
          rule.user_map, rule.symbol)
      return
    end
  end

  return function(task)
    local user = task:get_user()
    if not user then
      lua_util.debugm(N, task, 'sender is unauthenticated')
      return
    end
    if user_map and not user_map:get_key(user) then
      lua_util.debugm(N, task, 'sender (%s) is not in user_map', user)
      return
    end

    -- Skip the authenticated hops; the first unauthenticated one is used
    for _, hdr in ipairs(task:get_received_headers()) do
      if not hdr.flags.authenticated then
        return set_from_rcvd(task, hdr)
      end
    end

    rspamd_logger.errx(task, 'found nothing useful in Received headers')
  end
end
125
--- Strategy 'count'.
-- Takes the sender data from the Received header at index `rule.count`
-- (1-based). When the first header carries the `artificial` flag, the index
-- is reduced by one.
strategies.count = function(rule)
  return function(task)
    local rcvd_hdrs = task:get_received_headers()
    local first_flags = (rcvd_hdrs[1] or E).flags or E

    -- Reduce count by 1 if artificial header is present
    local hdr_count = rule.count
    if first_flags.artificial then
      hdr_count = hdr_count - 1
    end

    local rcvd = rcvd_hdrs[hdr_count]
    if rcvd then
      return set_from_rcvd(task, rcvd)
    end
    rspamd_logger.errx(task, 'found no received header #%s', hdr_count)
  end
end
146
--- Strategy 'hostname_map'.
-- Applies only when the connecting hostname maps to 'direct' in
-- `rule.hostname_map`.
-- Starting from that relay, it walks the Received headers, following every
-- hop whose remote hostname is itself in the map. The first header received
-- by the current relay from a host that is not in the map is used.
strategies.hostname_map = function(rule)
  local hostname_map = lua_maps.map_add_from_ucl(rule.hostname_map, 'map', 'external relay hostnames')
  if not hostname_map then
    rspamd_logger.errx(rspamd_config, "couldn't add map %s; won't register symbol %s",
        rule.hostname_map, rule.symbol)
    return
  end

  return function(task)
    local relay_hn = task:get_hostname()
    if not relay_hn then
      lua_util.debugm(N, task, 'sending hostname is missing')
      return
    end
    if hostname_map:get_key(relay_hn) ~= 'direct' then
      lua_util.debugm(N, task, 'sending hostname (%s) is not a direct relay', relay_hn)
      return
    end

    -- Only headers written by the current relay (and carrying an IP) count
    for _, hdr in ipairs(task:get_received_headers()) do
      if hdr.by_hostname == relay_hn and hdr.real_ip then
        if hostname_map:get_key(hdr.from_hostname) then
          -- Remote host is another known relay: continue from its name
          relay_hn = hdr.from_hostname
        else
          -- Remote hostname is not another relay, use this header
          return set_from_rcvd(task, hdr)
        end
      end
    end

    rspamd_logger.errx(task, 'found nothing useful in Received headers')
  end
end
184
--- Strategy 'local'.
-- Applies only when the connecting IP is local.
-- Uses the first Received header whose IP is valid and non-local. The last
-- header is accepted even when its IP is local.
strategies['local'] = function(_)
  return function(task)
    local from_ip = task:get_from_ip()
    if not from_ip then
      lua_util.debugm(N, task, 'sending IP is missing')
      return
    end
    if not from_ip:is_local() then
      lua_util.debugm(N, task, 'sending IP (%s) is non-local', from_ip)
      return
    end

    local rcvd_hdrs = task:get_received_headers()
    local last = #rcvd_hdrs
    -- Try find first non-local IP in Received headers
    for pos, hdr in ipairs(rcvd_hdrs) do
      local ip = hdr.real_ip
      if ip and ip:is_valid() and (not ip:is_local() or pos == last) then
        return set_from_rcvd(task, hdr)
      end
    end

    rspamd_logger.errx(task, 'found nothing useful in Received headers')
  end
end
213
-- Load the module configuration.
-- Rules are registered only when get_all_opt returns options for this module.
local opts = rspamd_config:get_all_opt(N)
if opts then
  settings = lua_util.override_defaults(settings, opts)

  local valid, schema_err = config_schema:transform(settings)
  if not valid then
    rspamd_logger.errx(rspamd_config, 'config schema error: %s', schema_err)
    lua_util.disable_module(N, "config")
    return
  end

  for rule_name, rule in pairs(settings.rules) do
    -- A rule's symbol defaults to its key in the `rules` table
    rule.symbol = rule.symbol or rule_name

    -- Strategy constructors return nil when the rule cannot be set up;
    -- such rules get no symbol
    local callback = strategies[rule.strategy](rule)
    if callback then
      rspamd_config:register_symbol({
        name = rule.symbol,
        type = 'prefilter',
        priority = rule.priority or 20, -- default prefilter priority
        group = N,
        callback = callback,
      })
    end
  end
end
244