1--[[
2Copyright (c) 2019, Vsevolod Stakhov <vsevolod@highsecure.ru>
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15]]--
16
17
18local rspamd_logger = require "rspamd_logger"
19local ansicolors = require "ansicolors"
20local ucl = require "ucl"
21local argparse = require "argparse"
22local fun = require "fun"
23local rspamd_http = require "rspamd_http"
24local cr = require "rspamd_cryptobox"
25
-- Command line interface of `rspamadm vault`: options shared by every
-- subcommand. The name of the selected subcommand ends up in `opts.command`.
local parser = argparse()
parser:name("rspamadm vault")
parser:description("Perform Hashicorp Vault management")
parser:help_description_margin(32)
parser:command_target("command")
parser:require_command(true)

parser:flag("-s --silent")
    :description("Do not output extra information")
parser:option("-a --addr")
    :description("Vault address (if not defined in VAULT_ADDR env)")
parser:option("-t --token")
    :description("Vault token (not recommended, better define VAULT_TOKEN env)")
parser:option("-p --path")
    :description("Path to work with in the vault")
    :default("dkim")

-- Each accepted output format maps onto itself; the chosen value is handed
-- to ucl.to_format() when printing vault data.
local output_formats = {
  ucl = "ucl",
  json = "json",
  ['json-compact'] = "json-compact",
  yaml = "yaml",
}
parser:option("-o --output")
    :description("Output format ('ucl', 'json', 'json-compact', 'yaml')")
    :argname("<type>")
    :convert(output_formats)
    :default("ucl")
52
-- `list` takes no arguments: it lists the entries stored under --path.
parser:command "list ls l"
    :description "List elements in the vault"

-- `show` prints the selectors stored for each given domain.
local show = parser:command "show get"
    :description "Extract element from the vault"
show:argument "domain"
    :description "Domain to show key(s) for"
    :args "+"

-- `delete` removes the whole vault record of each given domain.
local delete = parser:command "delete del rm remove"
    :description "Delete element from the vault"
delete:argument "domain"
    :description "Domain to delete key(s) for"
    :args "+"
67
68
-- `newkey` generates a fresh DKIM keypair per domain and stores it in the
-- vault next to any selectors of other algorithms already present.
local newkey = parser:command "newkey new create"
    :description "Add new key to the vault"
newkey:argument "domain"
    :description "Domain to create key for"
    :args "+"
newkey:option "-s --selector"
    :description "Selector to use (default: '<algorithm>-<YYYYMMDD>', UTC date)"
    :count "?"
newkey:option "-A --algorithm"
    :description "Key algorithm: 'rsa' or 'ed25519' ('eddsa' is an alias)"
    :argname("<type>")
    :convert {
      rsa = "rsa",
      ed25519 = "ed25519",
      eddsa = "ed25519",
    }
    :default "rsa"
newkey:option "-b --bits"
    :description "Number of bits for the generated key"
    :argname("<nbits>")
    :convert(tonumber)
    :default "1024"
newkey:option "-x --expire"
    :description "Expire the new key after the given number of days"
    :argname("<days>")
    :convert(tonumber)
-- NOTE(review): --rewrite is parsed, but newkey_handler never reads
-- opts.rewrite, so the flag currently has no effect.
newkey:flag "-r --rewrite"
93
-- `roll` replaces the current key(s) of each domain with new ones, keeping
-- the old keys valid for --ttl more days.
local roll = parser:command "roll rollover"
    :description "Perform keys rollover"
roll:argument "domain"
    :description "Domain to roll key(s) for"
    :args "+"
roll:option "-T --ttl"
    :description "Validity period for old keys (days)"
    :argname("<days>")
    :convert(tonumber)
    :default "1"
roll:flag "-r --remove-expired"
    :description "Remove expired keys"
roll:option "-x --expire"
    :description "Expire the new key(s) after the given number of days"
    :argname("<days>")
    :convert(tonumber)
108
--- Write one line to stdout.
-- `fmt` and the extra arguments are expanded by rspamd_logger.slog; when
-- `fmt` is nil (or false) only the line terminator is written.
local function printf(fmt, ...)
  if not fmt then
    io.write('\n')
    return
  end

  local line = rspamd_logger.slog(fmt, ...)
  io.write(line)
  io.write('\n')
end
115
--- printf() that is suppressed when the --silent flag is set.
-- @param opts parsed command line options (only `opts.silent` is read)
local function maybe_printf(opts, fmt, ...)
  if opts.silent then
    return
  end

  printf(fmt, ...)
end
121
--- Wrap `str` in an ANSI colour sequence followed by a reset.
-- @param str text to colour
-- @param color key into the ansicolors table; 'white' when omitted
local function highlight(str, color)
  local start_seq = ansicolors[color or 'white']
  return start_seq .. str .. ansicolors.reset
end
125
--- Build the vault HTTP API URL for the configured secrets path.
-- @param opts parsed options (`opts.addr` and `opts.path` are used)
-- @param path optional sub-path (e.g. a domain) appended to the base URL
-- @return '<addr>/v1/<opts.path>' or '<addr>/v1/<opts.path>/<path>'
local function vault_url(opts, path)
  local base = string.format('%s/v1/%s', opts.addr, opts.path)

  if not path then
    return base
  end

  return string.format('%s/%s', base, path)
end
133
--- Tell whether an HTTP exchange failed.
-- @param err transport error (truthy on failure)
-- @param data reply table; only consulted when `err` is falsy
-- @return `err` itself when set, otherwise true for any non-2xx status code
local function is_http_error(err, data)
  if err then
    return err
  end

  local status_class = math.floor(data.code / 100)
  return status_class ~= 2
end
137
--- Parse a raw vault reply with the UCL parser.
-- @param data reply body
-- @return parsed object and nil on success, or nil and the parser error
local function parse_vault_reply(data)
  local reply_parser = ucl.parser()
  local ok, parse_err = reply_parser:parse_string(data)

  if ok then
    return reply_parser:get_object(), nil
  end

  return nil, parse_err
end
148
--- Parse a raw vault reply and print it in the format chosen by --output.
-- @param opts parsed command line options (`opts.output` is used)
-- @param data raw reply body, or nil when nothing was received
-- @param func optional projection applied to the parsed object before
--   printing (e.g. to extract `obj.data.keys`)
local function maybe_print_vault_data(opts, data, func)
  if not data then
    printf('no data received')
    return
  end

  local res,parser_err = parse_vault_reply(data)

  if not res then
    printf('vault reply cannot be parsed: %s', parser_err)
    return
  end

  if func then
    res = func(res)
  end

  -- Pass the rendered document as an argument rather than as the format
  -- string, so '%' characters inside vault data are printed verbatim
  -- instead of being interpreted as format directives.
  printf('%s', ucl.to_format(res, opts.output))
end
166
--- Print a zone-file style TXT record that publishes a DKIM public key.
-- DNS limits every character-string inside a TXT record to 255 octets
-- (RFC 1035, section 3.3), so the record value is split into quoted chunks
-- of at most 255 characters; a value that fits stays a single string.
-- @param b64 public key data placed after the 'p=' tag
-- @param selector DKIM selector; the record name is '<selector>._domainkey'
-- @param alg value of the 'k=' tag (e.g. 'rsa' or 'ed25519')
local function print_dkim_txt_record(b64, selector, alg)
  local max_chunk = 255
  local value = string.format("v=DKIM1; k=%s; p=", alg) .. b64
  local labels = {}

  for pos = 1, #value, max_chunk do
    labels[#labels + 1] = '"' .. value:sub(pos, pos + max_chunk - 1) .. '"'
  end

  printf("%s._domainkey IN TXT ( %s )", selector,
      table.concat(labels, "\n\t"))
end
182
--- Print the selectors stored in the vault for `domain`.
-- Exits the process with status 1 when the vault request fails.
-- @param opts parsed command line options
-- @param domain domain name, used as the vault path component
local function show_handler(opts, domain)
  local uri = vault_url(opts, domain)
  local err,data = rspamd_http.request{
    config = rspamd_config,
    ev_base = rspamadm_ev_base,
    session = rspamadm_session,
    resolver = rspamadm_dns_resolver,
    url = uri,
    headers = {
      ['X-Vault-Token'] = opts.token
    }
  }

  if err then
    -- `data` cannot be relied upon when the transport itself failed
    printf('cannot get request to the vault (%s): %s', uri, err)
    os.exit(1)
  elseif is_http_error(err, data) then
    printf('cannot get request to the vault (%s), HTTP error code %s', uri, data.code)
    -- show whatever body the vault sent along with the error status
    maybe_print_vault_data(opts, data.content)
    os.exit(1)
  else
    maybe_print_vault_data(opts, data.content, function(obj)
      return obj.data.selectors
    end)
  end
end
206
--- Delete the vault record (all selectors) of `domain`.
-- Exits the process with status 1 when the vault request fails.
-- @param opts parsed command line options
-- @param domain domain name, used as the vault path component
local function delete_handler(opts, domain)
  local uri = vault_url(opts, domain)
  local err,data = rspamd_http.request{
    config = rspamd_config,
    ev_base = rspamadm_ev_base,
    session = rspamadm_session,
    resolver = rspamadm_dns_resolver,
    url = uri,
    method = 'delete',
    headers = {
      ['X-Vault-Token'] = opts.token
    }
  }

  if err then
    -- `data` cannot be relied upon when the transport itself failed
    printf('cannot get request to the vault (%s): %s', uri, err)
    os.exit(1)
  elseif is_http_error(err, data) then
    printf('cannot get request to the vault (%s), HTTP error code %s', uri, data.code)
    -- show whatever body the vault sent along with the error status
    maybe_print_vault_data(opts, data.content)
    os.exit(1)
  else
    printf('deleted key(s) for %s', domain)
  end
end
229
--- Print the entries stored under the configured vault path (--path).
-- Exits the process with status 1 when the vault request fails.
-- @param opts parsed command line options
local function list_handler(opts)
  local uri = vault_url(opts)
  local err,data = rspamd_http.request{
    config = rspamd_config,
    ev_base = rspamadm_ev_base,
    session = rspamadm_session,
    resolver = rspamadm_dns_resolver,
    url = uri .. '?list=true',
    headers = {
      ['X-Vault-Token'] = opts.token
    }
  }

  if err then
    -- `data` cannot be relied upon when the transport itself failed
    printf('cannot get request to the vault (%s): %s', uri, err)
    os.exit(1)
  elseif is_http_error(err, data) then
    printf('cannot get request to the vault (%s), HTTP error code %s', uri, data.code)
    -- show whatever body the vault sent along with the error status
    maybe_print_vault_data(opts, data.content)
    os.exit(1)
  else
    maybe_print_vault_data(opts, data.content, function(obj)
      return obj.data.keys
    end)
  end
end
253
--- Generate a DKIM keypair.
-- @param opts table with `algorithm` and `bits` fields
-- @return private key, public key (as returned by cr.gen_dkim_keypair)
local function genkey(opts)
  local alg, nbits = opts.algorithm, opts.bits
  return cr.gen_dkim_keypair(alg, nbits)
end
258
--- Generate a new DKIM key for `domain` and store it in the vault.
-- The new key becomes the first selector of the record, followed by the
-- `existing` selectors. On success the public key is printed (as a TXT
-- record, or raw when --silent is set); on failure the process exits with
-- status 1.
-- @param opts parsed options (selector, algorithm, bits, expire, token, ...)
-- @param domain domain name, also used as the vault path component
-- @param existing array of selector records to keep in the stored record
local function create_and_push_key(opts, domain, existing)
  local uri = vault_url(opts, domain)
  local sk,pk = genkey(opts)
  -- one timestamp so valid_start and valid_end are computed consistently
  local now = os.time()

  local new_sel = {
    selector = opts.selector,
    domain = domain,
    key = tostring(sk),
    pubkey = tostring(pk),
    alg = opts.algorithm,
    bits = opts.bits or 0,
    valid_start = now,
  }

  if opts.expire then
    new_sel.valid_end = now + opts.expire * 3600 * 24
  end

  local res = {
    selectors = { new_sel }
  }

  for _,sel in ipairs(existing) do
    res.selectors[#res.selectors + 1] = sel
  end

  local err,data = rspamd_http.request{
    config = rspamd_config,
    ev_base = rspamadm_ev_base,
    session = rspamadm_session,
    resolver = rspamadm_dns_resolver,
    url = uri,
    method = 'put',
    headers = {
      ['Content-Type'] = 'application/json',
      ['X-Vault-Token'] = opts.token
    },
    body = {
      ucl.to_format(res, 'json-compact')
    },
  }

  if err then
    -- `data` cannot be relied upon when the transport itself failed
    printf('cannot put request to the vault (%s): %s', uri, err)
    os.exit(1)
  elseif is_http_error(err, data) then
    printf('cannot put request to the vault (%s), HTTP error code %s', uri, data.code)
    maybe_print_vault_data(opts, data.content)
    os.exit(1)
  end

  maybe_printf(opts,'stored key for: %s, selector: %s', domain, opts.selector)
  maybe_printf(opts, 'please place the corresponding public key as following:')

  if opts.silent then
    printf('%s', new_sel.pubkey)
  else
    print_dkim_txt_record(new_sel.pubkey, opts.selector, opts.algorithm)
  end
end
316
--- Add a new key for `domain` unless one with the same algorithm exists.
-- Selectors of other algorithms already stored for the domain are kept.
-- When no selector was given, '<algorithm>-<YYYYMMDD>' (UTC date) is used;
-- note that this default is written back into `opts.selector` and is thus
-- reused for the following domains.
-- Exits the process with status 1 when a key for the algorithm already
-- exists or the vault cannot be queried.
-- @param opts parsed command line options
-- @param domain domain name, also used as the vault path component
local function newkey_handler(opts, domain)
  local uri = vault_url(opts, domain)

  if not opts.selector then
    opts.selector = string.format('%s-%s', opts.algorithm,
        os.date("!%Y%m%d"))
  end

  local err,data = rspamd_http.request{
    config = rspamd_config,
    ev_base = rspamadm_ev_base,
    session = rspamadm_session,
    resolver = rspamadm_dns_resolver,
    url = uri,
    method = 'get',
    headers = {
      ['X-Vault-Token'] = opts.token
    }
  }

  if err then
    printf('cannot get request to the vault (%s): %s', uri, err)
    os.exit(1)
  end

  if data.code == 404 then
    -- Nothing stored for this domain yet
    create_and_push_key(opts, domain, {})
    return
  elseif is_http_error(err, data) then
    -- Any other failure must not be taken for "no keys": storing a fresh
    -- record now would overwrite the selectors the domain already has.
    printf('cannot get request to the vault (%s), HTTP error code %s', uri, data.code)
    maybe_print_vault_data(opts, data.content)
    os.exit(1)
  elseif not data.content then
    create_and_push_key(opts, domain, {})
    return
  end

  local rep = parse_vault_reply(data.content)

  if not rep or not rep.data then
    printf('cannot parse reply for %s: %s', uri, data.content)
    os.exit(1)
  end

  local elts = rep.data.selectors or {}

  -- Check every stored selector first; the record is pushed exactly once,
  -- and only when no key of the requested algorithm exists.
  for _,sel in ipairs(elts) do
    if sel.alg == opts.algorithm then
      printf('key with the specific algorithm %s is already presented at %s selector for %s domain',
          opts.algorithm, sel.selector, domain)
      os.exit(1)
    end
  end

  create_and_push_key(opts, domain, elts)
end
366
--- Roll over the DKIM keys stored in the vault for `domain`.
-- Stored selectors are grouped by algorithm.
--  * Normal mode: for every algorithm a new key is generated (with the bit
--    size of the key that expires last) and every key that has not expired
--    yet gets `valid_end` set to now + --ttl days. Expired keys are kept.
--  * --remove-expired: expired keys are dropped, all other keys are kept
--    unchanged and no new keys are generated.
-- The resulting selector list replaces the vault record, and the public
-- parts of newly generated keys are printed.
-- Exits the process with status 1 on vault errors, or when an algorithm has
-- more than one key that has not expired yet.
-- @param opts parsed options (ttl, expire, remove_expired, silent, ...)
-- @param domain domain name, also used as the vault path component
local function roll_handler(opts, domain)
  local uri = vault_url(opts, domain)
  local res = {
    selectors = {}
  }
  -- Selector records generated by this run; reported after storing
  local new_keys = {}

  local err,data = rspamd_http.request{
    config = rspamd_config,
    ev_base = rspamadm_ev_base,
    session = rspamadm_session,
    resolver = rspamadm_dns_resolver,
    url = uri,
    method = 'get',
    headers = {
      ['X-Vault-Token'] = opts.token
    }
  }

  if err then
    printf('cannot get request to the vault (%s): %s', uri, err)
    os.exit(1)
  elseif is_http_error(err, data) or not data.content then
    printf("No keys to roll for domain %s", domain)
    os.exit(1)
  else
    local rep = parse_vault_reply(data.content)

    if not rep or not rep.data then
      printf('cannot parse reply for %s: %s', uri, data.content)
      os.exit(1)
    end

    local elts = rep.data.selectors

    if not elts then
      printf("No keys to roll for domain %s", domain)
      os.exit(1)
    end

    local nkeys = {} -- indexed by algorithm

    local function insert_key(sel, add_expire)
      if not nkeys[sel.alg] then
        nkeys[sel.alg] = {}
      end

      if add_expire then
        sel.valid_end = os.time() + opts.ttl * 3600 * 24
      end

      table.insert(nkeys[sel.alg], sel)
    end

    for _,sel in ipairs(elts) do
      if sel.valid_end and sel.valid_end < os.time() then
        if not opts.remove_expired then
          insert_key(sel, false)
        else
          maybe_printf(opts, 'removed expired key for %s (selector %s, expire "%s")',
              domain, sel.selector, os.date('%c', sel.valid_end))
        end
      else
        -- A still valid key is only scheduled for expiry when it is really
        -- being replaced; a remove-only run must leave it untouched,
        -- otherwise the domain would silently lose its active key.
        insert_key(sel, not opts.remove_expired)
      end
    end

    -- Now we need to ensure that all but one selectors have either expired or just a single key
    for alg,keys in pairs(nkeys) do
      -- Keys with an expiration date first, latest expiration first
      table.sort(keys, function(k1, k2)
        if k1.valid_end and k2.valid_end then
          return k1.valid_end > k2.valid_end
        elseif k1.valid_end then
          return true
        elseif k2.valid_end then
          return false
        end
        return false
      end)
      -- Exclude the key with the highest expiration date and examine the rest
      if not (#keys == 1 or fun.all(function(k)
            return k.valid_end and k.valid_end < os.time()
          end, fun.tail(keys))) then
        printf('bad keys list for %s and %s algorithm', domain, alg)
        fun.each(function(k)
          if not k.valid_end then
            printf('selector %s, algorithm %s has a key with no expire',
                k.selector, k.alg)
          elseif k.valid_end >= os.time() then
            printf('selector %s, algorithm %s has a key that not yet expired: %s',
                k.selector, k.alg, os.date('%c', k.valid_end))
          end
        end, fun.tail(keys))
        os.exit(1)
      end
      -- Do not create new keys, if we only want to remove expired keys
      if not opts.remove_expired then
        -- Store keys for each algorithm as <new_key>, <old_key(s)>
        local sk,pk = genkey({algorithm = alg, bits = keys[1].bits})
        local selector = string.format('%s-%s', alg,
            os.date("!%Y%m%d"))

        if selector == keys[1].selector then
          selector = selector .. '-1'
        end
        local nelt = {
          selector = selector,
          domain = domain,
          key = tostring(sk),
          pubkey = tostring(pk),
          alg = alg,
          bits = keys[1].bits,
          valid_start = os.time(),
        }

        if opts.expire then
          nelt.valid_end = os.time() + opts.expire * 3600 * 24
        end

        table.insert(res.selectors, nelt)
        table.insert(new_keys, nelt)
      end
      for _,k in ipairs(keys) do
        table.insert(res.selectors, k)
      end
    end
  end

  -- We can now store res in the vault
  err,data = rspamd_http.request{
    config = rspamd_config,
    ev_base = rspamadm_ev_base,
    session = rspamadm_session,
    resolver = rspamadm_dns_resolver,
    url = uri,
    method = 'put',
    headers = {
      ['Content-Type'] = 'application/json',
      ['X-Vault-Token'] = opts.token
    },
    body = {
      ucl.to_format(res, 'json-compact')
    },
  }

  if err then
    -- `data` cannot be relied upon when the transport itself failed
    printf('cannot put request to the vault (%s): %s', uri, err)
    os.exit(1)
  elseif is_http_error(err, data) then
    printf('cannot put request to the vault (%s), HTTP error code %s', uri, data.code)
    maybe_print_vault_data(opts, data.content)
    os.exit(1)
  end

  -- Report exactly the keys generated above, whatever their expiration
  for _,key in ipairs(new_keys) do
    maybe_printf(opts,'rolled key for: %s, new selector: %s', domain, key.selector)
    maybe_printf(opts, 'please place the corresponding public key as following:')

    if opts.silent then
      printf('%s', key.pubkey)
    else
      print_dkim_txt_record(key.pubkey, key.selector, key.alg)
    end
  end

  -- Old keys only got a new expiration date in a real rollover
  if not opts.remove_expired then
    maybe_printf(opts, 'your old keys will be valid until %s',
        os.date('%c', os.time() + opts.ttl * 3600 * 24))
  end
end
531
--- Entry point of `rspamadm vault`.
-- Parses `args`, falls back to the VAULT_ADDR / VAULT_TOKEN environment
-- variables for the vault address and token, and dispatches to the chosen
-- subcommand. Domain-based subcommands run once per domain argument.
-- @param args command line arguments
local function handler(args)
  local opts = parser:parse(args)

  opts.addr = opts.addr or os.getenv('VAULT_ADDR')

  if opts.token then
    maybe_printf(opts, 'defining token via command line is insecure, define it via environment variable %s',
        highlight('VAULT_TOKEN', 'red'))
  else
    opts.token = os.getenv('VAULT_TOKEN')
  end

  if not (opts.token and opts.addr) then
    printf('no token or/and vault addr has been specified, exiting')
    os.exit(1)
  end

  local command = opts.command

  if command == 'list' then
    list_handler(opts)
    return
  end

  -- Subcommands taking one or more domain arguments
  local domain_handlers = {
    show = show_handler,
    newkey = newkey_handler,
    roll = roll_handler,
    delete = delete_handler,
  }
  local domain_handler = domain_handlers[command]

  if not domain_handler then
    parser:error(string.format('command %s is not implemented', command))
    return
  end

  fun.each(function(d) domain_handler(opts, d) end, opts.domain)
end
567
-- Module descriptor: command name, help description and entry point.
return {
  name = 'vault',
  description = parser._description,
  handler = handler,
}
573