1--[[
2Copyright (c) 2017, Vsevolod Stakhov <vsevolod@highsecure.ru>
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15]]--
16
17local logger = require "rspamd_logger"
18local lutil = require "lua_util"
19local rspamd_util = require "rspamd_util"
20local ts = require("tableshape").types
21
22local exports = {}
23
24local E = {}
25local N = "lua_redis"
26
-- Building blocks shared by several schema fields below.
-- A time interval is either a plain number or a string that is converted
-- by `lutil.parse_time_interval`.
local time_interval_type = ts.number + ts.string / lutil.parse_time_interval
-- A server definition is either a single string or an array of strings.
local server_list_type = ts.string + ts.array_of(ts.string)

-- Options that are accepted by every flavour of Redis configuration
local common_schema = ts.shape {
  timeout = time_interval_type:is_optional(),
  db = ts.string:is_optional(),
  database = ts.string:is_optional(),
  dbname = ts.string:is_optional(),
  prefix = ts.string:is_optional(),
  password = ts.string:is_optional(),
  expand_keys = ts.boolean:is_optional(),
  sentinels = server_list_type:is_optional(),
  sentinel_watch_time = time_interval_type:is_optional(),
  sentinel_masters_pattern = ts.string:is_optional(),
  sentinel_master_maxerrors = (ts.number + ts.string / tonumber):is_optional(),
}

-- Three alternatives: separate read/write servers, `servers` or `server`;
-- each of them additionally accepts the common options.
local split_servers_schema = ts.shape({
  read_servers = server_list_type,
  write_servers = server_list_type,
}, {extra_opts = common_schema})

local servers_schema = ts.shape({
  servers = server_list_type,
}, {extra_opts = common_schema})

local server_schema = ts.shape({
  server = server_list_type,
}, {extra_opts = common_schema})

local config_schema = split_servers_schema + servers_schema + server_schema

exports.config_schema = config_schema
54
55
-- Asks one Redis Sentinel (selected round-robin from `params.sentinels`) for
-- the list of masters and, for every accepted master, for the list of its
-- slaves. Once all slave requests are finished, `params.read_servers` and
-- `params.write_servers` (and the matching `*_str` fields) are replaced when
-- the newly built lists differ from the stored ones.
-- @param ev_base event base passed to `rspamd_redis.make_request`
-- @param params redis parameters table; `sentinels` must be an upstream list
-- @param initialised accepted for compatibility, not used by this function
local function redis_query_sentinel(ev_base, params, initialised)
  -- Converts a flat reply {k1, v1, k2, v2, ...} to {k1 = v1, k2 = v2, ...}
  local function flatten_redis_table(tbl)
    local res = {}
    if type(tbl) ~= 'table' then
      return res
    end
    for i=1,#tbl,2 do
      res[tbl[i]] = tbl[i + 1]
    end

    return res
  end

  -- Wrap IPv6-addresses in brackets
  local function wrap_ipv6(ip)
    if ip:match(":") then
      return "[" .. ip .. "]"
    end
    return ip
  end

  local rspamd_redis = require "rspamd_redis"
  local upstream_list = require "rspamd_upstream_list"
  local sentinels = params.sentinels
  local addr = sentinels:get_upstream_round_robin()

  if not addr then
    -- Nothing can be queried; avoid indexing a nil upstream below
    logger.errx(rspamd_config, 'cannot select Redis Sentinel upstream to query')
    return
  end

  local host = addr:get_addr()
  local masters = {}

  -- Builds the new read/write server lists from `masters` and installs them.
  -- Defined before any request is issued so that callbacks can always call it.
  local function process_masters()
    -- We now form new strings for masters and slaves
    local read_servers_tbl, write_servers_tbl = {}, {}

    for _,master in pairs(masters) do
      local master_str = string.format('%s:%s', master.ip, master.port)
      write_servers_tbl[#write_servers_tbl + 1] = master_str
      read_servers_tbl[#read_servers_tbl + 1] = master_str

      -- Only slaves that report a healthy link to the master are readable
      for _,slave in ipairs(master.slaves) do
        if slave['master-link-status'] == 'ok' then
          read_servers_tbl[#read_servers_tbl + 1] = string.format(
              '%s:%s', slave.ip, slave.port
          )
        end
      end
    end

    -- Sort to get a stable string that can be compared with the stored one
    table.sort(read_servers_tbl)
    table.sort(write_servers_tbl)

    local read_servers_str = table.concat(read_servers_tbl, ',')
    local write_servers_str = table.concat(write_servers_tbl, ',')

    lutil.debugm(N, rspamd_config,
        'new servers list: %s read; %s write',
        read_servers_str,
        write_servers_str)

    if read_servers_str ~= params.read_servers_str then
      local read_upstreams = upstream_list.create(rspamd_config,
          read_servers_str, 6379)

      if read_upstreams then
        logger.infox(rspamd_config, 'sentinel %s: replace read servers with new list: %s',
            host:to_string(true), read_servers_str)
        params.read_servers = read_upstreams
        params.read_servers_str = read_servers_str
      end
    end

    if write_servers_str ~= params.write_servers_str then
      local write_upstreams = upstream_list.create(rspamd_config,
          write_servers_str, 6379)

      if write_upstreams then
        logger.infox(rspamd_config, 'sentinel %s: replace write servers with new list: %s',
            host:to_string(true), write_servers_str)
        params.write_servers = write_upstreams
        params.write_servers_str = write_servers_str

        local queried = false

        -- Re-query the sentinel once when the failures counter reported by
        -- the watcher exceeds `params.sentinel_master_maxerrors`
        local function monitor_failures(up, _, count)
          if count > params.sentinel_master_maxerrors and not queried then
            logger.infox(rspamd_config, 'sentinel: master with address %s, caused %s failures, try to query sentinel',
                host:to_string(true), count)
            queried = true -- Avoid multiple checks caused by this monitor
            redis_query_sentinel(ev_base, params, true)
          end
        end

        write_upstreams:add_watcher('failure', monitor_failures)
      end
    end
  end

  local function masters_cb(err, result)
    if not err and result and type(result) == 'table' then

      local pending_subrequests = 0

      for _,m in ipairs(result) do
        local master = flatten_redis_table(m)

        if type(master.ip) ~= 'string' or type(master.name) ~= 'string' then
          -- Entries without name/ip cannot be used as table key or address
          logger.warnx(rspamd_config,
              'sentinel %s: skip malformed master entry',
              host:to_string(true))
        else
          master.ip = wrap_ipv6(master.ip)

          if params.sentinel_masters_pattern then
            if master.name:match(params.sentinel_masters_pattern) then
              lutil.debugm(N, 'found master %s with ip %s and port %s',
                  master.name, master.ip, master.port)
              masters[master.name] = master
            else
              lutil.debugm(N, 'skip master %s with ip %s and port %s, pattern %s',
                  master.name, master.ip, master.port, params.sentinel_masters_pattern)
            end
          else
            lutil.debugm(N, 'found master %s with ip %s and port %s',
                master.name, master.ip, master.port)
            masters[master.name] = master
          end
        end
      end

      -- For each master we need to get a list of slaves
      for k,v in pairs(masters) do
        v.slaves = {}
        local function slaves_cb(slave_err, slave_result)
          if not slave_err and type(slave_result) == 'table' then
            for _,s in ipairs(slave_result) do
              local slave = flatten_redis_table(s)
              lutil.debugm(N, rspamd_config,
                  'found slave for master %s with ip %s and port %s',
                  v.name, slave.ip, slave.port)
              -- Slaves without a usable address are ignored
              if type(slave.ip) == 'string' then
                slave.ip = wrap_ipv6(slave.ip)
                v.slaves[#v.slaves + 1] = slave
              end
            end
          else
            logger.errx('cannot get slaves data from Redis Sentinel %s: %s',
                host:to_string(true), slave_err)
            addr:fail()
          end

          pending_subrequests = pending_subrequests - 1

          if pending_subrequests == 0 then
            -- Finalize masters and slaves
            process_masters()
          end
        end

        local ret = rspamd_redis.make_request({
          host = addr:get_addr(),
          timeout = params.timeout,
          config = rspamd_config,
          ev_base = ev_base,
          cmd = 'SENTINEL',
          args = {'slaves', k},
          no_pool = true,
          callback = slaves_cb
        })

        if not ret then
          logger.errx(rspamd_config, 'cannot connect sentinel when query slaves at address: %s',
              host:to_string(true))
          addr:fail()
        else
          pending_subrequests = pending_subrequests + 1
        end
      end

      addr:ok()
    else
      logger.errx('cannot get masters data from Redis Sentinel %s: %s',
          host:to_string(true), err)
      addr:fail()
    end
  end

  local ret = rspamd_redis.make_request({
    host = addr:get_addr(),
    timeout = params.timeout,
    config = rspamd_config,
    ev_base = ev_base,
    cmd = 'SENTINEL',
    args = {'masters'},
    no_pool = true,
    callback = masters_cb,
  })

  if not ret then
    logger.errx(rspamd_config, 'cannot connect sentinel at address: %s',
        host:to_string(true))
    addr:fail()
  end

end
253
-- Replaces `params.sentinels` with an upstream list (default port 5000),
-- fills in default values for the sentinel related settings and registers
-- an on-load hook that periodically queries sentinels in scanner workers.
-- Does nothing except logging an error if the upstream list cannot be created.
local function add_redis_sentinels(params)
  local sentinel_upstreams = require("rspamd_upstream_list").create(
      rspamd_config, params.sentinels, 5000)

  if not sentinel_upstreams then
    logger.errx(rspamd_config, 'cannot load redis sentinels string: %s',
        params.sentinels)

    return
  end

  params.sentinels = sentinel_upstreams
  -- Each minute by default
  params.sentinel_watch_time = params.sentinel_watch_time or 60
  -- Maximum number of errors before rechecking
  params.sentinel_master_maxerrors = params.sentinel_master_maxerrors or 2

  local function on_worker_load(_, ev_base, worker)
    if not worker:is_scanner() then
      return
    end

    local initialised = false

    -- The returned value is handed back to `add_periodic`
    local function poll_sentinel()
      redis_query_sentinel(ev_base, params, initialised)
      initialised = true

      return params.sentinel_watch_time
    end

    rspamd_config:add_periodic(ev_base, 0.0, poll_sentinel, false)
  end

  rspamd_config:add_on_load(on_worker_load)
end
289
-- Cache of already loaded redis parameters indexed by their hash
local cached_results = {}

-- Calculates a hash (base32 string) of all string and number values that are
-- reachable from `params`, together with their keys. Values of other types
-- (booleans, userdata, functions) are not hashed.
-- Keys of every table are visited in sorted order: `pairs` gives no ordering
-- guarantee, so without sorting two tables with equal content could produce
-- different hashes and miss the cache.
local function calculate_redis_hash(params)
  local cr = require "rspamd_cryptobox_hash"

  local h = cr.create()

  -- Total order over keys of mixed types: group by type first, then by value
  local function key_less(a, b)
    local ta, tb = type(a), type(b)
    if ta ~= tb then
      return ta < tb
    end
    if ta == 'number' or ta == 'string' then
      return a < b
    end
    return tostring(a) < tostring(b)
  end

  local function rec_hash(k, v)
    local vt = type(v)

    if vt == 'string' then
      h:update(tostring(k))
      h:update(v)
    elseif vt == 'number' then
      h:update(tostring(k))
      h:update(tostring(v))
    elseif vt == 'table' then
      local keys = {}
      for kk in pairs(v) do
        keys[#keys + 1] = kk
      end
      table.sort(keys, key_less)

      for _,kk in ipairs(keys) do
        rec_hash(kk, v[kk])
      end
    end
  end

  rec_hash('top', params)

  return h:base32()
end
315
-- Copies generic Redis settings from `options` to `redis_params`.
-- Values that are already present in `redis_params` are preserved, so this
-- function can be called several times, from the most specific options
-- section to the least specific one.
-- @param options table with raw options (module or global `redis` section)
-- @param redis_params table that accumulates the resulting settings
local function process_redis_opts(options, redis_params)
  local default_timeout = 1.0
  local default_expand_keys = false

  -- Accepts numbers, numeric strings and, as a fallback, whatever
  -- `lutil.parse_time_interval` understands; returns nil on failure
  local function to_interval(value)
    local n = tonumber(value)
    if n == nil and type(value) == 'string' then
      n = lutil.parse_time_interval(value)
    end
    return n
  end

  if not redis_params['timeout'] or redis_params['timeout'] == default_timeout then
    -- An unparsable timeout must not leave the request without any timeout
    redis_params['timeout'] = (options['timeout'] and to_interval(options['timeout']))
        or default_timeout
  end

  if options['prefix'] and not redis_params['prefix'] then
    redis_params['prefix'] = options['prefix']
  end

  if type(options['expand_keys']) == 'boolean' and
      (options['expand_keys'] or redis_params['expand_keys'] == nil) then
    -- An enabled flag always applies; a disabled one only when nothing has
    -- been decided yet
    redis_params['expand_keys'] = options['expand_keys']
  elseif redis_params['expand_keys'] == nil then
    -- Apply the default only once: a later call with a section that has no
    -- `expand_keys` must not reset a previously enabled flag
    redis_params['expand_keys'] = default_expand_keys
  end

  if not redis_params['db'] then
    -- `db`, `dbname` and `database` are synonyms, checked in this order
    local db = options['db'] or options['dbname'] or options['database']
    if db then
      redis_params['db'] = tostring(db)
    end
  end

  if options['password'] and not redis_params['password'] then
    redis_params['password'] = options['password']
  end

  if not redis_params.sentinels and options.sentinels then
    redis_params.sentinels = options.sentinels
  end

  if options['sentinel_masters_pattern'] and not redis_params['sentinel_masters_pattern'] then
    redis_params['sentinel_masters_pattern'] = options['sentinel_masters_pattern']
  end

  -- These two are read by the sentinel code from the resulting parameters,
  -- so they have to be propagated to make the options effective
  if options['sentinel_watch_time'] and not redis_params['sentinel_watch_time'] then
    redis_params['sentinel_watch_time'] = to_interval(options['sentinel_watch_time'])
  end

  if options['sentinel_master_maxerrors'] and not redis_params['sentinel_master_maxerrors'] then
    redis_params['sentinel_master_maxerrors'] = tonumber(options['sentinel_master_maxerrors'])
  end

end
360
-- Fills missing values in `redis_params` from the global `redis` section:
-- the module specific subsection (if any) is applied first, then the
-- section itself. Does nothing without a config object or `redis` section.
local function enrich_defaults(rspamd_config, module, redis_params)
  local global_opts = rspamd_config and rspamd_config:get_all_opt('redis')

  if not global_opts then
    return
  end

  local module_opts = module and global_opts[module]

  if module_opts then
    process_redis_opts(module_opts, redis_params)
  end

  process_redis_opts(global_opts, redis_params)
end
376
-- Returns previously stored parameters with the same hash if there are any.
-- Otherwise stores `redis_params` in the cache (setting its `hash` field),
-- sets up sentinels for non read-only parameters that define them and
-- returns `redis_params` itself.
local function maybe_return_cached(redis_params)
  local digest = calculate_redis_hash(redis_params)
  local known = cached_results[digest]

  if known then
    lutil.debugm(N, 'reused redis server: %s', redis_params)
    return known
  end

  redis_params.hash = digest
  cached_results[digest] = redis_params

  local needs_sentinels = redis_params.sentinels and not redis_params.read_only

  if needs_sentinels then
    add_redis_sentinels(redis_params)
  end

  lutil.debugm(N, 'loaded new redis server: %s', redis_params)
  return redis_params
end
395
396--[[[
397-- @module lua_redis
398-- This module contains helper functions for working with Redis
399--]]
-- Builds upstream lists from the servers defined in `options` and stores
-- them, together with the generic settings, in `result`.
-- Read servers are taken from `read_servers`, `servers` or `server` (in this
-- order); write servers from `write_servers` or, when `servers`/`server` is
-- used, they are the same as the read servers.
-- @param options table with redis options
-- @param rspamd_config config object or nil
-- @param result table to fill
-- @return true if read servers have been loaded, false otherwise
local function process_redis_options(options, rspamd_config, result)
  local default_port = 6379
  local upstream_list = require "rspamd_upstream_list"

  -- The config object is passed as the first argument only when available
  local function make_upstreams(definition)
    if rspamd_config then
      return upstream_list.create(rspamd_config, definition, default_port)
    end

    return upstream_list.create(definition, default_port)
  end

  local read_only = true
  local upstreams_read, upstreams_write

  -- Try to get read servers:
  local read_def = options['read_servers']

  if read_def then
    upstreams_read = make_upstreams(read_def)
  else
    -- Generic definitions are usable for both reading and writing
    read_def = options['servers'] or options['server']

    if read_def then
      upstreams_read = make_upstreams(read_def)
      read_only = false
    end
  end

  if read_def then
    result.read_servers_str = read_def
  end

  if upstreams_read then
    local write_def = options['write_servers']

    if write_def then
      upstreams_write = make_upstreams(write_def)
      result.write_servers_str = write_def
      read_only = false
    elseif not read_only then
      upstreams_write = upstreams_read
      result.write_servers_str = result.read_servers_str
    end
  end

  -- Store options
  process_redis_opts(options, result)

  if upstreams_write then
    result.read_only = false
  elseif read_only then
    result.read_only = true
  end

  if not upstreams_read then
    lutil.debugm(N, rspamd_config,
        'cannot load redis server from obj: %s, processed to %s',
        options, result)

    return false
  end

  result.read_servers = upstreams_read

  if upstreams_write then
    result.write_servers = upstreams_write
  end

  return true
end
482
--[[[
@function try_load_redis_servers(options, rspamd_config, no_fallback, module_name)
Tries to load redis servers from the specified `options` object.
Unless `no_fallback` is set, missing settings are taken from the global
`redis` section (and from its `module_name` subsection if it exists).
Returns `redis_params` table or nothing in case of failure

--]]
exports.try_load_redis_servers = function(options, rspamd_config, no_fallback, module_name)
  local redis_params = {}

  if not process_redis_options(options, rspamd_config, redis_params) then
    return
  end

  if not no_fallback then
    enrich_defaults(rspamd_config, module_name, redis_params)
  end

  return maybe_return_cached(redis_params)
end
499
--[[[
-- @function lua_redis.parse_redis_server(module_name, module_opts, no_fallback)
-- Extracts Redis server settings from configuration: module options are
-- tried first (their `redis` subsection, then the options themselves), then,
-- unless `no_fallback` is set, the global `redis` section.
-- @param {string} module_name name of module to get settings for
-- @param {table} module_opts settings for module or `nil` to fetch them from configuration
-- @param {boolean} no_fallback should be `true` if global settings must not be used
-- @return {table} redis server settings or `nil` if nothing has been found
-- @example
-- local rconfig = lua_redis.parse_redis_server('my_module')
-- -- rconfig contains upstream_list objects in ['write_servers'] and ['read_servers']
-- -- ['timeout'] contains timeout in seconds
-- -- ['expand_keys'] if true tells that redis key expansion is enabled
--]]
local function rspamd_parse_redis_server(module_name, module_opts, no_fallback)
  local result = {}

  -- Loads servers from a module level definition; on success the result is
  -- completed with global defaults (if allowed) and passed through the cache
  local function load_local(candidate)
    if not process_redis_options(candidate, rspamd_config, result) then
      return nil
    end

    if not no_fallback then
      enrich_defaults(rspamd_config, module_name, result)
    end

    return maybe_return_cached(result)
  end

  -- Try local options
  lutil.debugm(N, rspamd_config, 'try load redis config for: %s', module_name)
  local local_opts = module_opts or rspamd_config:get_all_opt(module_name)

  if local_opts then
    local loaded = local_opts.redis and load_local(local_opts.redis)

    if loaded then
      return loaded
    end

    loaded = load_local(local_opts)

    if loaded then
      return loaded
    end
  end

  if no_fallback then
    logger.infox(rspamd_config, "cannot find Redis definitions for %s and fallback is disabled",
        module_name)

    return nil
  end

  -- Try global options
  local global_opts = rspamd_config:get_all_opt('redis')

  if global_opts then
    local module_section = global_opts[module_name]

    if module_section then
      if process_redis_options(module_section, rspamd_config, result) then
        return maybe_return_cached(result)
      end
    else
      local ret = process_redis_options(global_opts, rspamd_config, result)

      -- Exclude disabled
      local disabled = global_opts['disabled_modules']

      if disabled then
        for _,name in ipairs(disabled) do
          if name == module_name then
            logger.infox(rspamd_config, "NOT using default redis server for module %s: it is disabled",
              module_name)

            return nil
          end
        end
      end

      if ret then
        logger.infox(rspamd_config, "use default Redis settings for %s",
            module_name)
        return maybe_return_cached(result)
      end
    end
  end

  if result.read_servers then
    return maybe_return_cached(result)
  end

  return nil
end

exports.rspamd_parse_redis_server = rspamd_parse_redis_server
exports.parse_redis_server = rspamd_parse_redis_server
604
-- Maps a lowercased Redis command name to a function that receives the
-- command arguments and returns a list of indexes of those arguments that
-- are keys. Every handler returns a table (possibly empty), so the result
-- can always be iterated.
local process_cmd = {}

do
  -- Returns {first, first + step, ...} up to `last` inclusive
  local function index_range(first, last, step)
    local idx_l = {}
    for i = first, last, step or 1 do
      idx_l[#idx_l + 1] = i
    end
    return idx_l
  end

  -- Keys are counted by a `numkeys` argument at position 2 and follow it
  local function counted_keys(args)
    local numkeys = tonumber(args[2])
    if numkeys and numkeys >= 1 then
      return index_range(3, numkeys + 2)
    end
    return {}
  end

  local handlers = {
    -- OP destkey key [key ...]
    bitop = function(args)
      return index_range(2, #args)
    end,
    -- key [key ...] timeout
    blpop = function(args)
      return index_range(1, #args - 1)
    end,
    -- script numkeys key [key ...] arg [arg ...]
    eval = counted_keys,
    -- The first argument is the only key
    set = function()
      return {1}
    end,
    -- All arguments are keys
    mget = function(args)
      return index_range(1, #args)
    end,
    -- key value [key value ...]
    mset = function(args)
      return index_range(1, #args, 2)
    end,
    -- source destination member
    smove = function()
      return {1, 2}
    end,
    -- destination numkeys key [key ...]: the destination is a key as well
    zstore = function(args)
      local idx_l = counted_keys(args)
      table.insert(idx_l, 1, 1)
      return idx_l
    end,
    -- No keys at all
    script = function()
      return {}
    end,
  }

  local aliases = {
    bitop = {'bitop'},
    blpop = {'blpop', 'brpop', 'brpoplpush'},
    eval = {'eval', 'evalsha'},
    mset = {'mset', 'msetnx'},
    smove = {'smove'},
    zstore = {'zinterstore', 'zunionstore'},
    mget = {
      'mget', 'del', 'exists', 'pfmerge', 'rename', 'renamenx', 'rpoplpush',
      'sdiff', 'sdiffstore', 'sinter', 'sinterstore', 'sunion', 'sunionstore',
      'touch', 'unlink', 'watch',
    },
    set = {
      'set', 'append', 'bitcount', 'bitfield', 'bitpos', 'decr', 'decrby',
      'dump', 'expire', 'expireat', 'geoadd', 'geohash', 'geopos', 'geodist',
      'georadius', 'georadiusbymember', 'get', 'getbit', 'getrange', 'getset',
      'hdel', 'hexists', 'hget', 'hgetall', 'hincrby', 'hincrbyfloat', 'hkeys',
      'hlen', 'hmget', 'hmset', 'hscan', 'hset', 'hsetnx', 'hstrlen', 'hvals',
      'incr', 'incrby', 'incrbyfloat', 'lindex', 'linsert', 'llen', 'lpop',
      'lpush', 'lpushx', 'lrange', 'lrem', 'lset', 'ltrim', 'move', 'persist',
      'pexpire', 'pexpireat', 'pfadd', 'pfcount', 'psetex', 'pttl', 'restore',
      'rpop', 'rpush', 'rpushx', 'sadd', 'scard', 'setbit', 'setex', 'setnx',
      'setrange', 'sismember', 'smembers', 'sort', 'spop', 'srandmember',
      'srem', 'sscan', 'strlen', 'ttl', 'type', 'zadd', 'zcard', 'zcount',
      'zincrby', 'zlexcount', 'zrange', 'zrangebylex', 'zrangebyscore',
      'zrank', 'zrem',
      -- Historical names kept for compatibility; the real commands follow
      'zrembylex', 'zrembyrank', 'zrembyscore',
      'zremrangebylex', 'zremrangebyrank', 'zremrangebyscore',
      'zrevrange', 'zrevrangebylex', 'zrevrangebyscore', 'zrevrank', 'zscan',
      'zscore',
    },
    script = {
      'script', 'auth', 'bgrewriteaof', 'bgsave', 'client', 'cluster',
      'command', 'config', 'dbsize', 'debug', 'discard', 'echo', 'exec',
      'flushall', 'flushdb', 'info', 'keys', 'lastsave', 'migrate', 'monitor',
      'multi', 'object', 'ping', 'psubscribe', 'pubsub', 'publish',
      'punsubscribe', 'quit', 'randomkey', 'readonly', 'readwrite', 'role',
      'save', 'scan', 'select', 'slaveof', 'slowlog', 'subscribe', 'swapdb',
      'sync', 'time', 'unsubscribe', 'unwatch', 'wait',
    },
  }

  for handler_name, commands in pairs(aliases) do
    for _, cmd in ipairs(commands) do
      process_cmd[cmd] = handlers[handler_name]
    end
  end
end
813
-- Returns a list of indexes of those elements of `args` that are keys for
-- the Redis command `cmd` (case insensitive). The result is always a table:
-- it is empty for unknown commands (a warning is logged) and for handlers
-- that return nothing, so callers can iterate it unconditionally.
local function get_key_indexes(cmd, args)
  cmd = string.lower(cmd)
  local handler = process_cmd[cmd]

  if not handler then
    logger.warnx(rspamd_config, "Don't know how to extract keys for %s Redis command", cmd)
    return {}
  end

  return handler(args) or {}
end
824
-- Generators of the values available for keys expansion. Each generator
-- receives a task and returns a value or nothing.
local gen_meta = {}

-- Returns the first address returned by `task:get_from(kind)` or an empty
-- table if there is none
local function first_from_address(task, kind)
  local addresses = task:get_from(kind) or E
  return addresses[1] or E
end

function gen_meta.principal_recipient(task)
  return task:get_principal_recipient()
end

function gen_meta.principal_recipient_domain(task)
  local rcpt = task:get_principal_recipient()

  if rcpt then
    -- Everything after the last `@`
    return string.match(rcpt, '.*@(.*)')
  end
end

function gen_meta.ip(task)
  local ip = task:get_ip()

  if not ip or not ip:is_valid() then
    return
  end

  return ip:to_string()
end

function gen_meta.from(task)
  return first_from_address(task, 'smtp')['addr']
end

function gen_meta.from_domain(task)
  return first_from_address(task, 'smtp')['domain']
end

function gen_meta.from_domain_or_helo_domain(task)
  local domain = first_from_address(task, 'smtp')['domain']

  -- Missing or empty SMTP from domain: use HELO instead
  if not domain or #domain == 0 then
    return task:get_helo()
  end

  return domain
end

function gen_meta.mime_from(task)
  return first_from_address(task, 'mime')['addr']
end

function gen_meta.mime_from_domain(task)
  return first_from_address(task, 'mime')['domain']
end

-- Wraps generator `f`: its result (if any) is passed to `rspamd_util.get_tld`
local function gen_get_esld(f)
  return function(task)
    local domain = f(task)

    if domain then
      return rspamd_util.get_tld(domain)
    end
  end
end

-- Aliases with explicit `smtp_` prefix
gen_meta.smtp_from = gen_meta.from
gen_meta.smtp_from_domain = gen_meta.from_domain
gen_meta.smtp_from_domain_or_helo_domain = gen_meta.from_domain_or_helo_domain

-- `esld_` variants of the domain generators
gen_meta.esld_principal_recipient_domain = gen_get_esld(gen_meta.principal_recipient_domain)
gen_meta.esld_from_domain = gen_get_esld(gen_meta.from_domain)
gen_meta.esld_smtp_from_domain = gen_meta.esld_from_domain
gen_meta.esld_mime_from_domain = gen_get_esld(gen_meta.mime_from_domain)
gen_meta.esld_from_domain_or_helo_domain = gen_get_esld(gen_meta.from_domain_or_helo_domain)
gen_meta.esld_smtp_from_domain_or_helo_domain = gen_meta.esld_from_domain_or_helo_domain
874
-- Returns a table whose fields are computed on first access: the requested
-- name is lowercased, looked up in the table itself and then generated by
-- the matching `gen_meta` function for `task`. Generated values are stored
-- in the table, so each generator yielding a value runs only once.
local function get_key_expansion_metadata(task)

  local function lookup(self, k)
    local name = string.lower(k)
    local stored = rawget(self, name)

    if stored then
      return stored
    end

    local generator = gen_meta[name]

    if not generator then
      return stored
    end

    local value = generator(task)
    rawset(self, name, value)

    return value
  end

  return setmetatable({}, { __index = lookup })

end
897
898-- Performs async call to redis hiding all complexity inside function
899-- task - rspamd_task
900-- redis_params - valid params returned by rspamd_parse_redis_server
901-- key - key to select upstream or nil to select round-robin/master-slave
902-- is_write - true if need to write to redis server
903-- callback - function to be called upon request is completed
904-- command - redis command
905-- args - table of arguments
906-- extra_opts - table of optional request arguments
local function rspamd_redis_make_request(task, redis_params, key, is_write,
    callback, command, args, extra_opts)
  -- Returns three values: ret, conn, addr. `ret` is false when mandatory
  -- arguments are missing, when no upstream can be selected or when the
  -- request cannot be made; `addr` is the selected upstream (if any).
  if not task or not redis_params or not callback or not command then
    return false,nil,nil
  end

  local addr

  -- Wrapper around the user callback: reports the outcome to the upstream
  -- (fail/ok) and passes the upstream to the callback as the third argument
  local function rspamd_redis_make_request_cb(err, data)
    if err then
      addr:fail()
    else
      addr:ok()
    end
    callback(err, data, addr)
  end

  local rspamd_redis = require "rspamd_redis"

  -- With a key the upstream is selected by hash, otherwise writes use the
  -- master-slave rotation and reads use round-robin
  if key then
    if is_write then
      addr = redis_params['write_servers']:get_upstream_by_hash(key)
    else
      addr = redis_params['read_servers']:get_upstream_by_hash(key)
    end
  else
    if is_write then
      addr = redis_params['write_servers']:get_upstream_master_slave(key)
    else
      addr = redis_params['read_servers']:get_upstream_round_robin(key)
    end
  end

  if not addr then
    logger.errx(task, 'cannot select server to make redis request')
    -- There is nothing to connect to: stop here instead of raising an error
    -- on `addr:get_addr()` below
    return false,nil,nil
  end

  if redis_params['expand_keys'] then
    -- Key arguments are treated as templates; note that `args` is modified
    -- in place
    local m = get_key_expansion_metadata(task)
    local indexes = get_key_indexes(command, args)
    for _, i in ipairs(indexes) do
      args[i] = lutil.template(args[i], m)
    end
  end

  local ip_addr = addr:get_addr()
  local options = {
    task = task,
    callback = rspamd_redis_make_request_cb,
    host = ip_addr,
    timeout = redis_params['timeout'],
    cmd = command,
    args = args
  }

  -- Extra options are copied verbatim and may override the defaults above
  if extra_opts then
    for k,v in pairs(extra_opts) do
      options[k] = v
    end
  end

  if redis_params['password'] then
    options['password'] = redis_params['password']
  end

  if redis_params['db'] then
    options['dbname'] = redis_params['db']
  end

  lutil.debugm(N, task, 'perform request to redis server' ..
      ' (host=%s, timeout=%s): cmd: %s', ip_addr,
      options.timeout, options.cmd)

  local ret,conn = rspamd_redis.make_request(options)

  if not ret then
    addr:fail()
    logger.warnx(task, "cannot make redis request to: %s", tostring(ip_addr))
  end

  return ret,conn,addr
end
989
--[[[
-- @function lua_redis.redis_make_request(task, redis_params, key, is_write, callback, command, args[, extra_opts])
-- Sends a request to Redis
-- @param {rspamd_task} task task object
-- @param {table} redis_params redis configuration in format returned by lua_redis.parse_redis_server()
-- @param {string} key key to use for sharding
-- @param {boolean} is_write should be `true` if we are performing a write operation
-- @param {function} callback callback function (first parameter is error if applicable, second is a 2D array (table), third is the upstream that has been used for the request)
-- @param {string} command Redis command to run
-- @param {table} args Numerically indexed table containing arguments for command
-- @param {table} extra_opts optional table of extra request options (copied into the request options as is)
-- @return {result,connection,upstream} result of making the request, connection object and the selected upstream
--]]

-- The same function is available under both names
exports.rspamd_redis_make_request = rspamd_redis_make_request
exports.redis_make_request = rspamd_redis_make_request
1004
-- Makes an asynchronous redis request when no task is available
-- ev_base - event base used for the request
-- cfg - config object (used as `config` option and as the logging context)
-- redis_params - redis parameters with `read_servers`/`write_servers` upstreams
-- key - key used to select upstream by hash (optional)
-- is_write - true if need to write to redis server
-- callback - function called as callback(err, data, upstream) on completion
-- command - redis command
-- args - table of arguments
-- extra_opts - table of optional request arguments
-- Returns ret, conn, addr; `ret` is false when mandatory arguments are
-- missing, no upstream can be selected or the request cannot be made
local function redis_make_request_taskless(ev_base, cfg, redis_params, key,
    is_write, callback, command, args, extra_opts)
  if not ev_base or not redis_params or not callback or not command then
    return false,nil,nil
  end

  local addr

  -- Wrapper around the user callback: reports the outcome to the upstream
  -- and passes the upstream to the callback as the third argument
  local function rspamd_redis_make_request_cb(err, data)
    if err then
      addr:fail()
    else
      addr:ok()
    end
    callback(err, data, addr)
  end

  local rspamd_redis = require "rspamd_redis"

  if key then
    if is_write then
      addr = redis_params['write_servers']:get_upstream_by_hash(key)
    else
      addr = redis_params['read_servers']:get_upstream_by_hash(key)
    end
  else
    if is_write then
      addr = redis_params['write_servers']:get_upstream_master_slave(key)
    else
      addr = redis_params['read_servers']:get_upstream_round_robin(key)
    end
  end

  if not addr then
    logger.errx(cfg, 'cannot select server to make redis request')
    -- There is nothing to connect to: stop here instead of raising an error
    -- on `addr:get_addr()` below
    return false,nil,nil
  end

  local options = {
    ev_base = ev_base,
    config = cfg,
    callback = rspamd_redis_make_request_cb,
    host = addr:get_addr(),
    timeout = redis_params['timeout'],
    cmd = command,
    args = args
  }

  -- Extra options are copied verbatim and may override the defaults above
  if extra_opts then
    for k,v in pairs(extra_opts) do
      options[k] = v
    end
  end

  if redis_params['password'] then
    options['password'] = redis_params['password']
  end

  if redis_params['db'] then
    options['dbname'] = redis_params['db']
  end

  lutil.debugm(N, cfg, 'perform taskless request to redis server' ..
      ' (host=%s, timeout=%s): cmd: %s', options.host:tostring(true),
      options.timeout, options.cmd)
  local ret,conn = rspamd_redis.make_request(options)
  if not ret then
    logger.errx('cannot execute redis request')
    addr:fail()
  end

  return ret,conn,addr
end
1078
--[[[
-- @function lua_redis.redis_make_request_taskless(ev_base, cfg, redis_params, key, is_write, callback, command, args[, extra_opts])
-- Sends a request to Redis in context where `task` is not available for some specific use-cases
-- Similar to redis_make_request() except that the first parameter is an `event base` object
-- and the second one is the config object
--]]

-- The same function is available under both names
exports.rspamd_redis_make_request_taskless = redis_make_request_taskless
exports.redis_make_request_taskless = redis_make_request_taskless
1087
-- Registry of redis scripts indexed by the numeric script id
local redis_scripts = {}

-- Marks the script as loaded (only when its sha is known) and calls every
-- pending waiter with the current `loaded` flag. The wait queue is detached
-- before the waiters run, so anything queued by a waiter lands in a fresh
-- queue and is not invoked during this pass.
local function script_set_loaded(script)
  if script.sha then
    script.loaded = true
  end

  local pending = {}
  for idx,waiter in ipairs(script.waitq) do
    pending[idx] = waiter
  end
  script.waitq = {}

  for _,waiter in ipairs(pending) do
    waiter(script.loaded)
  end
end
1107
-- Builds request options for uploading the script (`SCRIPT LOAD`) to every
-- known upstream: one options table per upstream of `read_servers` followed
-- by one per upstream of `write_servers`. Also sets `script.in_flight` to the
-- number of prepared requests.
-- The upstream lists are appended to each other: merging them by index would
-- overwrite read servers with write servers sharing the same position, so
-- some servers would never receive the script. An upstream that is present in
-- both lists gets the script uploaded twice.
local function prepare_redis_call(script)
  local redis_params = script.redis_params
  local servers = {}
  local options = {}

  -- Appends all upstreams of the list (if it is defined) to `servers`
  local function append_upstreams(upstream_list)
    if upstream_list then
      for _,up in ipairs(upstream_list:all_upstreams()) do
        servers[#servers + 1] = up
      end
    end
  end

  append_upstreams(redis_params.read_servers)
  append_upstreams(redis_params.write_servers)

  -- Call load script on each server, set loaded flag
  script.in_flight = #servers
  for _,s in ipairs(servers) do
    local cur_opts = {
      host = s:get_addr(),
      timeout = redis_params['timeout'],
      cmd = 'SCRIPT',
      args = {'LOAD', script.script },
      upstream = s
    }

    if redis_params['password'] then
      cur_opts['password'] = redis_params['password']
    end

    if redis_params['db'] then
      cur_opts['dbname'] = redis_params['db']
    end

    table.insert(options, cur_opts)
  end

  return options
end
1147
-- Uploads the script to all upstreams returned by prepare_redis_call() using
-- the task context. Each finished (or not started) upload decrements
-- `script.in_flight`; when it reaches zero the waiters are notified via
-- script_set_loaded().
local function load_script_task(script, task)
  local rspamd_redis = require "rspamd_redis"
  local opts = prepare_redis_call(script)

  for _,opt in ipairs(opts) do
    opt.task = task
    opt.callback = function(err, data)
      if err then
        logger.errx(task, 'cannot upload script to %s: %s; registered from: %s:%s',
            opt.upstream:get_addr():to_string(true),
            err, script.caller.short_src, script.caller.currentline)
        opt.upstream:fail()
        script.fatal_error = err
      else
        opt.upstream:ok()
        logger.infox(task,
          "uploaded redis script to %s with id %s, sha: %s",
            opt.upstream:get_addr():to_string(true),
            script.id, data)
        script.sha = data -- We assume that sha is the same on all servers
        -- Same as in the taskless loader: a successful upload clears the
        -- error recorded by a previous failure, otherwise the script stays
        -- disabled even though it has been uploaded
        script.fatal_error = nil
      end
      script.in_flight = script.in_flight - 1

      if script.in_flight == 0 then
        script_set_loaded(script)
      end
    end

    local ret = rspamd_redis.make_request(opt)

    if not ret then
      -- The callback will never be called for this request, so account it here
      logger.errx(task, 'cannot execute redis request to load script on %s',
        opt.upstream:get_addr())
      script.in_flight = script.in_flight - 1
      opt.upstream:fail()
    end

    if script.in_flight == 0 then
      script_set_loaded(script)
    end
  end
end
1190
-- Uploads the script to all upstreams returned by prepare_redis_call() using
-- the config object and the event base (no task is required). Every upload
-- that has completed, or could not be started, decrements `script.in_flight`;
-- the waiters are notified once the counter reaches zero.
local function load_script_taskless(script, cfg, ev_base)
  local rspamd_redis = require "rspamd_redis"
  local requests = prepare_redis_call(script)

  -- Notifies the waiters when no uploads are pending
  local function notify_if_done()
    if script.in_flight == 0 then
      script_set_loaded(script)
    end
  end

  for _,req in ipairs(requests) do
    req.config = cfg
    req.ev_base = ev_base

    req.callback = function(err, data)
      if not err then
        req.upstream:ok()
        logger.infox(cfg,
          "uploaded redis script to %s with id %s, sha: %s",
            req.upstream:get_addr():to_string(true), script.id, data)
        -- The sha is assumed to be the same on all servers
        script.sha = data
        script.fatal_error = nil
      else
        logger.errx(cfg, 'cannot upload script to %s: %s; registered from: %s:%s',
            req.upstream:get_addr():to_string(true),
            err, script.caller.short_src, script.caller.currentline)
        req.upstream:fail()
        script.fatal_error = err
      end

      script.in_flight = script.in_flight - 1
      notify_if_done()
    end

    if not rspamd_redis.make_request(req) then
      -- The callback is not going to be called for this request
      logger.errx('cannot execute redis request to load script on %s',
        req.upstream:get_addr())
      script.in_flight = script.in_flight - 1
      req.upstream:fail()
    end

    notify_if_done()
  end
end
1233
-- Starts a taskless upload of the script; scripts that have no redis
-- parameters attached are skipped. The last argument is ignored.
local function load_redis_script(script, cfg, ev_base, _)
  if not script.redis_params then
    return
  end

  load_script_taskless(script, cfg, ev_base)
end
1239
-- Registers a redis script and schedules its upload once the config is
-- loaded. Returns the numeric id of the script (its index in
-- `redis_scripts`), which is what exec_redis_script() expects.
local function add_redis_script(script, redis_params)
  local new_script = {
    -- Stack level 2 is the code that has called add_redis_script; it is
    -- reported in the log when the upload fails
    caller = debug.getinfo(2),
    loaded = false,
    redis_params = redis_params,
    script = script,
    waitq = {}, -- callbacks pending for script being loaded
    id = #redis_scripts + 1
  }

  -- Runs on config load: starts a periodic that keeps trying to upload the
  -- script until its sha is known
  local function schedule_upload(cfg, ev_base, worker)
    local attempts = 0.0

    local function try_upload()
      if new_script.sha then
        -- Uploaded, stop the periodic
        return false
      end

      load_redis_script(new_script, cfg, ev_base, worker)
      attempts = attempts + 1
      -- The returned value grows by one with each attempt
      return 1.0 * attempts
    end

    rspamd_config:add_periodic(ev_base, 0.0, try_upload, false)
  end

  rspamd_config:add_on_load(schedule_upload)

  redis_scripts[#redis_scripts + 1] = new_script

  return #redis_scripts
end
-- Public name: lua_redis.add_redis_script
exports.add_redis_script = add_redis_script
1271
--[[[
-- @function lua_redis.exec_redis_script(id, params, callback, keys, args)
-- Executes a previously registered script using EVALSHA
-- @param {number} id script id (index in the scripts registry, as returned by add_redis_script)
-- @param {table} params request attributes: `task` or `ev_base`, plus optional `key` and `is_write`
-- @param {function} callback function called as callback(err, data)
-- @param {table} keys numerically indexed table of keys passed to the script
-- @param {table} args optional numerically indexed table of extra script arguments
-- @return {boolean} `false` if the script id is unknown, `true` otherwise (errors are reported via the callback)
--]]
local function exec_redis_script(id, params, callback, keys, args)
  local script = redis_scripts[id]

  if not script then
    logger.errx("cannot find registered script with id %s", id)
    return false
  end

  if script.fatal_error then
    callback(script.fatal_error, nil)
    return true
  end

  if not script.redis_params then
    callback('no redis servers defined', nil)
    return true
  end

  -- EVALSHA arguments: built on the first attempt and reused on retries
  local redis_args = {}

  local function do_call(can_reload)
    local function redis_cb(err, data)
      if not err then
        callback(err, data)
      elseif string.match(err, 'NOSCRIPT') then
        -- Schedule restart
        script.sha = nil
        if can_reload then
          -- Retry this call after the script is uploaded again
          table.insert(script.waitq, do_call)
          if script.in_flight == 0 then
            -- Reload scripts if this has not been initiated yet
            if params.task then
              load_script_task(script, params.task)
            else
              load_script_taskless(script, rspamd_config, params.ev_base)
            end
          end
        else
          callback(err, data)
        end
      else
        callback(err, data)
      end
    end

    if #redis_args == 0 then
      if not script.sha then
        -- The sha is unknown (e.g. the upload has failed). Inserting nil
        -- would add nothing to the arguments, so EVALSHA would be sent with
        -- the number of keys in place of the sha; report the error instead
        callback('NOSCRIPT', nil)
        return
      end

      table.insert(redis_args, script.sha)
      table.insert(redis_args, tostring(#keys))
      for _,k in ipairs(keys) do
        table.insert(redis_args, k)
      end

      if type(args) == 'table' then
        for _, a in ipairs(args) do
          table.insert(redis_args, a)
        end
      end
    end

    if params.task then
      if not rspamd_redis_make_request(params.task, script.redis_params,
        params.key, params.is_write, redis_cb, 'EVALSHA', redis_args) then
        callback('Cannot make redis request', nil)
      end
    else
      if not redis_make_request_taskless(params.ev_base, rspamd_config,
        script.redis_params,
        params.key, params.is_write, redis_cb, 'EVALSHA', redis_args) then
        callback('Cannot make redis request', nil)
      end
    end
  end

  if script.loaded then
    do_call(true)
  else
    -- Delayed until scripts are loaded
    if not params.task then
      -- The waiter is called with the `loaded` flag, which becomes `can_reload`
      table.insert(script.waitq, do_call)
    else
      -- TODO: fix taskfull requests
      table.insert(script.waitq, function()
        if script.loaded then
          do_call(false)
        else
          callback('NOSCRIPT', nil)
        end
      end)
      load_script_task(script, params.task)
    end
  end

  return true
end
1367
-- Public name: lua_redis.exec_redis_script
exports.exec_redis_script = exec_redis_script
1369
-- Creates a synchronous redis connection
-- redis_params - redis parameters with `read_servers`/`write_servers` upstreams
-- is_write - true if the connection is going to be used for writing
-- key - key used to select upstream by hash (optional)
-- cfg - config object (falls back to the global `rspamd_config`)
-- ev_base - event base (falls back to the global `rspamadm_ev_base`)
-- Returns ret, conn, addr; on failure `ret` is false and `conn` is nil
local function redis_connect_sync(redis_params, is_write, key, cfg, ev_base)
  if not redis_params then
    return false,nil
  end

  local rspamd_redis = require "rspamd_redis"
  local addr

  if key then
    if is_write then
      addr = redis_params['write_servers']:get_upstream_by_hash(key)
    else
      addr = redis_params['read_servers']:get_upstream_by_hash(key)
    end
  else
    if is_write then
      addr = redis_params['write_servers']:get_upstream_master_slave(key)
    else
      addr = redis_params['read_servers']:get_upstream_round_robin(key)
    end
  end

  if not addr then
    logger.errx(cfg, 'cannot select server to make redis request')
    -- There is nothing to connect to: stop here instead of raising an error
    -- on `addr:get_addr()` below
    return false,nil,nil
  end

  local options = {
    host = addr:get_addr(),
    timeout = redis_params['timeout'],
    config = cfg or rspamd_config,
    ev_base = ev_base or rspamadm_ev_base,
    session = redis_params.session or rspamadm_session
  }

  -- All redis parameters are passed as connection options as well; entries
  -- with the same names override the defaults set above
  for k,v in pairs(redis_params) do
    options[k] = v
  end

  if not options.config then
    logger.errx('config is not set')
    return false,nil,addr
  end

  if not options.ev_base then
    logger.errx('ev_base is not set')
    return false,nil,addr
  end

  if not options.session then
    logger.errx('session is not set')
    return false,nil,addr
  end

  local ret,conn = rspamd_redis.connect_sync(options)
  if not ret then
    -- On failure the second returned value is passed to the log message
    logger.errx('cannot execute redis request: %s', conn)
    addr:fail()

    return false,nil,addr
  end

  if conn then
    -- Authentication and database selection are queued on the connection
    if redis_params['password'] then
      conn:add_cmd('AUTH', {redis_params['password']})
    end

    if redis_params['db'] then
      conn:add_cmd('SELECT', {tostring(redis_params['db'])})
    elseif redis_params['dbname'] then
      conn:add_cmd('SELECT', {tostring(redis_params['dbname'])})
    end
  end

  return ret,conn,addr
end
1445
-- Public name: lua_redis.redis_connect_sync
exports.redis_connect_sync = redis_connect_sync
1447
--[[[
-- @function lua_redis.request(redis_params, attrs, req)
-- Sends a request to Redis synchronously with coroutines or asynchronously using
-- a callback (modern API)
-- @param redis_params a table of redis server parameters
-- @param attrs a table of redis request attributes (e.g. task, or ev_base + config + session);
-- optional `callback` selects the asynchronous mode, optional `key` and `is_write` control upstream selection
-- @param req a request: either a command name (string) or a table with a command followed by its arguments
-- (the table is not modified)
-- @return {result,data/connection,address} in the asynchronous mode: boolean result, connection object and
-- the selected upstream; in the coroutines mode: the values returned by executing the command on success
-- or `false,nil,address` on failure
--]]

exports.request = function(redis_params, attrs, req)
  if not attrs or not redis_params or not req then
    logger.errx('invalid arguments for redis request')
    return false,nil,nil
  end

  if not (attrs.task or (attrs.config and attrs.ev_base)) then
    logger.errx('invalid attributes for redis request')
    return false,nil,nil
  end

  -- Request options are based on a copy of attributes, so `attrs` stays intact
  local opts = lutil.shallowcopy(attrs)

  local log_obj = opts.task or opts.config

  local addr

  if opts.callback then
    -- Wrap callback
    local callback = opts.callback
    local function rspamd_redis_make_request_cb(err, data)
      if err then
        addr:fail()
      else
        addr:ok()
      end
      callback(err, data, addr)
    end
    opts.callback = rspamd_redis_make_request_cb
  end

  local rspamd_redis = require "rspamd_redis"
  local is_write = opts.is_write

  if opts.key then
    if is_write then
      addr = redis_params['write_servers']:get_upstream_by_hash(attrs.key)
    else
      addr = redis_params['read_servers']:get_upstream_by_hash(attrs.key)
    end
  else
    if is_write then
      addr = redis_params['write_servers']:get_upstream_master_slave(attrs.key)
    else
      addr = redis_params['read_servers']:get_upstream_round_robin(attrs.key)
    end
  end

  if not addr then
    logger.errx(log_obj, 'cannot select server to make redis request')
    -- There is nothing to connect to: stop here instead of raising an error
    -- on `addr:get_addr()` below
    return false,nil,nil
  end

  opts.host = addr:get_addr()
  opts.timeout = redis_params.timeout

  if type(req) == 'string' then
    opts.cmd = req
  else
    -- Copy the arguments instead of removing the command from `req`, so the
    -- caller's table is left untouched and can be reused
    local cmd_args = {}
    for i = 2, #req do
      cmd_args[i - 1] = req[i]
    end
    opts.cmd = req[1]
    opts.args = cmd_args
  end

  if redis_params.password then
    opts.password = redis_params.password
  end

  if redis_params.db then
    opts.dbname = redis_params.db
  end

  -- Same calling convention as in other request helpers of this module:
  -- module name, log object, format string, arguments
  lutil.debugm(N, log_obj, 'perform generic request to redis server' ..
      ' (host=%s, timeout=%s): cmd: %s, arguments: %s', opts.host,
      opts.timeout, opts.cmd, opts.args)

  if opts.callback then
    local ret,conn = rspamd_redis.make_request(opts)
    if not ret then
      logger.errx(log_obj, 'cannot execute redis request')
      addr:fail()
    end

    return ret,conn,addr
  else
    -- Coroutines version
    local ret,conn = rspamd_redis.connect_sync(opts)
    if not ret then
      logger.errx(log_obj, 'cannot execute redis request')
      addr:fail()
    else
      conn:add_cmd(opts.cmd, opts.args)
      return conn:exec()
    end
    return false,nil,addr
  end
end
1556
--[[[
-- @function lua_redis.connect(redis_params, attrs)
-- Connects to Redis synchronously with coroutines or asynchronously using a callback (modern API)
-- @param redis_params a table of redis server parameters
-- @param attrs a table of redis request attributes (e.g. task, or ev_base + config + session);
-- optional `callback` selects the asynchronous mode, optional `key` and `is_write` control upstream selection
-- @return {result,connection,address} boolean result, connection object, selected upstream
-- (`false,nil,nil` if arguments are invalid or no upstream can be selected)
--]]

exports.connect = function(redis_params, attrs)
  if not attrs or not redis_params then
    logger.errx('invalid arguments for redis connect')
    return false,nil,nil
  end

  if not (attrs.task or (attrs.config and attrs.ev_base)) then
    logger.errx('invalid attributes for redis connect')
    return false,nil,nil
  end

  -- Connection options are based on a copy of attributes, so `attrs` stays intact
  local opts = lutil.shallowcopy(attrs)

  local log_obj = opts.task or opts.config

  local addr

  if opts.callback then
    -- Wrap callback
    local callback = opts.callback
    local function rspamd_redis_make_request_cb(err, data)
      if err then
        addr:fail()
      else
        addr:ok()
      end
      callback(err, data, addr)
    end
    opts.callback = rspamd_redis_make_request_cb
  end

  local rspamd_redis = require "rspamd_redis"
  local is_write = opts.is_write

  if opts.key then
    if is_write then
      addr = redis_params['write_servers']:get_upstream_by_hash(attrs.key)
    else
      addr = redis_params['read_servers']:get_upstream_by_hash(attrs.key)
    end
  else
    if is_write then
      addr = redis_params['write_servers']:get_upstream_master_slave(attrs.key)
    else
      addr = redis_params['read_servers']:get_upstream_round_robin(attrs.key)
    end
  end

  if not addr then
    logger.errx(log_obj, 'cannot select server to make redis connect')
    -- There is nothing to connect to: stop here instead of raising an error
    -- on `addr:get_addr()` below
    return false,nil,nil
  end

  opts.host = addr:get_addr()
  opts.timeout = redis_params.timeout

  if redis_params.password then
    opts.password = redis_params.password
  end

  if redis_params.db then
    opts.dbname = redis_params.db
  end

  if opts.callback then
    local ret,conn = rspamd_redis.connect(opts)
    if not ret then
      logger.errx(log_obj, 'cannot execute redis connect')
      addr:fail()
    end

    return ret,conn,addr
  else
    -- Coroutines version
    local ret,conn = rspamd_redis.connect_sync(opts)
    if not ret then
      logger.errx(log_obj, 'cannot execute redis connect')
      addr:fail()
    else
      return true,conn,addr
    end

    return false,nil,addr
  end
end
1651
-- Registered prefixes: prefix string -> description table
local redis_prefixes = {}

--[[[
-- @function lua_redis.register_prefix(prefix, module, description[, optional])
-- Register new redis prefix for documentation purposes
-- @param {string} prefix string prefix
-- @param {string} module module name
-- @param {string} description prefix description
-- @param {table} optional optional kv pairs (e.g. pattern)
--]]
local function register_prefix(prefix, module, description, optional)
  local entry = {
    module = module,
    description = description
  }

  -- Optional pairs are copied over the base fields, so they win on name
  -- clashes; anything that is not a table is ignored
  if type(optional) == 'table' then
    for name,value in pairs(optional) do
      entry[name] = value
    end
  end

  -- Registering the same prefix again replaces the previous entry
  redis_prefixes[prefix] = entry
end
1676
-- Public name: lua_redis.register_prefix
exports.register_prefix = register_prefix
1678
--[[[
-- @function lua_redis.prefixes([mname])
-- Returns prefixes for specific module (or all prefixes).
-- Without `mname` the registry table itself (prefix -> table) is returned.
-- With `mname` the registry is filtered by module name and collected with `fun.totable`
-- NOTE(review): the filtered result is presumably not keyed by prefix, unlike the
-- unfiltered one; verify against `fun.totable` semantics and callers
--]]
exports.prefixes = function(mname)
  if mname then
    local fun = require "fun"

    -- Keeps entries that have been registered by the requested module
    local function registered_by_module(_, data)
      return data.module == mname
    end

    return fun.totable(fun.filter(registered_by_module, redis_prefixes))
  end

  return redis_prefixes
end
1693
1694return exports
1695