-- The four classes are created inside the `do` block below; they are
-- declared up front so that the methods defined after it can refer to them.
local Parser
local Command
local Argument
local Option
2
3do -- Create classes with setters
4   local class = require "30log"
5
6   local function add_setters(cl, fields)
7      for field, setter in pairs(fields) do
8         cl[field] = function(self, value)
9            setter(self, value)
10            self["_"..field] = value
11            return self
12         end
13      end
14
15      cl.__init = function(self, ...)
16         return self(...)
17      end
18
19      cl.__call = function(self, ...)
20         local name_or_options
21
22         for i=1, select("#", ...) do
23            name_or_options = select(i, ...)
24
25            if type(name_or_options) == "string" then
26               if self._aliases then
27                  table.insert(self._aliases, name_or_options)
28               end
29
30               if not self._aliases or not self._name then
31                  self._name = name_or_options
32               end
33            elseif type(name_or_options) == "table" then
34               for field, setter in pairs(fields) do
35                  if name_or_options[field] ~= nil then
36                     self[field](self, name_or_options[field])
37                  end
38               end
39            end
40         end
41
42         return self
43      end
44
45      return cl
46   end
47
48   local typecheck = setmetatable({}, {
49      __index = function(self, type_)
50         local typechecker_factory = function(field)
51            return function(_, value)
52               if type(value) ~= type_ then
53                  error(("bad field '%s' (%s expected, got %s)"):format(field, type_, type(value)))
54               end
55            end
56         end
57
58         self[type_] = typechecker_factory
59         return typechecker_factory
60      end
61   })
62
63   local function aliased_name(self, name)
64      typecheck.string "name" (self, name)
65
66      table.insert(self._aliases, name)
67   end
68
69   local function aliased_aliases(self, aliases)
70      typecheck.table "aliases" (self, aliases)
71
72      if not self._name then
73         self._name = aliases[1]
74      end
75   end
76
77   local function parse_boundaries(boundaries)
78      if tonumber(boundaries) then
79         return tonumber(boundaries), tonumber(boundaries)
80      end
81
82      if boundaries == "*" then
83         return 0, math.huge
84      end
85
86      if boundaries == "+" then
87         return 1, math.huge
88      end
89
90      if boundaries == "?" then
91         return 0, 1
92      end
93
94      if boundaries:match "^%d+%-%d+$" then
95         local min, max = boundaries:match "^(%d+)%-(%d+)$"
96         return tonumber(min), tonumber(max)
97      end
98
99      if boundaries:match "^%d+%+$" then
100         local min = boundaries:match "^(%d+)%+$"
101         return tonumber(min), math.huge
102      end
103   end
104
105   local function boundaries(field)
106      return function(self, value)
107         local min, max = parse_boundaries(value)
108
109         if not min then
110            error(("bad field '%s'"):format(field))
111         end
112
113         self["_min"..field], self["_max"..field] = min, max
114      end
115   end
116
117   local function convert(self, value)
118      if type(value) ~= "function" then
119         if type(value) ~= "table" then
120            error(("bad field 'convert' (function or table expected, got %s)"):format(type(value)))
121         end
122      end
123   end
124
125   local function argname(self, value)
126      if type(value) ~= "string" then
127         if type(value) ~= "table" then
128            error(("bad field 'argname' (string or table expected, got %s)"):format(type(value)))
129         end
130      end
131   end
132
133   local function add_help(self, param)
134      if self._has_help then
135         table.remove(self._options)
136         self._has_help = false
137      end
138
139      if param then
140         local help = self:flag()
141            :description "Show this help message and exit. "
142            :action(function()
143               io.stdout:write(self:get_help() .. "\r\n")
144               os.exit(0)
145            end)(param)
146
147         if not help._name then
148            help "-h" "--help"
149         end
150
151         self._has_help = true
152      end
153   end
154
155   Parser = add_setters(class {
156      __name = "Parser",
157      _arguments = {},
158      _options = {},
159      _commands = {},
160      _mutexes = {},
161      _require_command = true
162   }, {
163      name = typecheck.string "name",
164      description = typecheck.string "description",
165      epilog = typecheck.string "epilog",
166      require_command = typecheck.boolean "require_command",
167      usage = typecheck.string "usage",
168      help = typecheck.string "help",
169      add_help = add_help
170   })
171
172   Command = add_setters(Parser:extends {
173      __name = "Command",
174      _aliases = {}
175   }, {
176      name = aliased_name,
177      aliases = aliased_aliases,
178      description = typecheck.string "description",
179      epilog = typecheck.string "epilog",
180      target = typecheck.string "target",
181      require_command = typecheck.boolean "require_command",
182      action = typecheck["function"] "action",
183      usage = typecheck.string "usage",
184      help = typecheck.string "help",
185      add_help = add_help
186   })
187
188   Argument = add_setters(class {
189      __name = "Argument",
190      _minargs = 1,
191      _maxargs = 1,
192      _mincount = 1,
193      _maxcount = 1,
194      _defmode = "unused"
195   }, {
196      name = typecheck.string "name",
197      description = typecheck.string "description",
198      target = typecheck.string "target",
199      args = boundaries "args",
200      default = typecheck.string "default",
201      defmode = typecheck.string "defmode",
202      convert = convert,
203      argname = argname
204   })
205
206   Option = add_setters(Argument:extends {
207      __name = "Option",
208      _aliases = {},
209      _mincount = 0,
210      _overwrite = true
211   }, {
212      name = aliased_name,
213      aliases = aliased_aliases,
214      description = typecheck.string "description",
215      target = typecheck.string "target",
216      args = boundaries "args",
217      count = boundaries "count",
218      default = typecheck.string "default",
219      defmode = typecheck.string "defmode",
220      convert = convert,
221      overwrite = typecheck.boolean "overwrite",
222      action = typecheck["function"] "action",
223      argname = argname
224   })
225end
226
--- Build the list of value placeholders shown in usage and help lines.
-- At most three placeholders are spelled out. Required ones appear bare
-- (bracketed when a default may fill them in "arg" mode), optional ones in
-- brackets, and "..." is appended when more values are accepted than shown.
-- @return array of placeholder strings
function Argument:_get_argument_list()
   local buf = {}
   local i = 1
   local shown_required = math.min(self._minargs, 3)
   local shown_total = math.min(self._maxargs, 3)
   -- Required values that a default can supply are displayed as optional.
   local default_fills = self._default and self._defmode:find "a"

   while i <= shown_required do
      local argname = self:_get_argname_i(i)

      if default_fills then
         argname = "[" .. argname .. "]"
      end

      table.insert(buf, argname)
      i = i+1
   end

   while i <= shown_total do
      table.insert(buf, "[" .. self:_get_argname_i(i) .. "]")
      i = i+1

      -- For an unbounded argument one optional placeholder is enough.
      if self._maxargs == math.huge then
         break
      end
   end

   -- `i` is the position of the first placeholder not shown; if a value can
   -- still be passed there, signal the truncation. (Previously `<` dropped
   -- the marker when exactly one accepted value remained hidden.)
   if i <= self._maxargs then
      table.insert(buf, "...")
   end

   return buf
end
257
--- Usage fragment for a positional argument: its placeholders joined by
-- spaces, wrapped in brackets when a default applies in "unused" mode and
-- the argument takes several values or one value not filled in "arg" mode.
function Argument:_get_usage()
   local usage = table.concat(self:_get_argument_list(), " ")
   local default_when_unused = self._default and self._defmode:find "u"

   if not default_when_unused then
      return usage
   end

   local takes_several = self._maxargs > 1
   local single_required = self._minargs == 1 and not self._defmode:find "a"

   if takes_several or single_required then
      return "[" .. usage .. "]"
   end

   return usage
end
269
--- Classify how parsed values of this element are stored in the result:
-- "flag", "arg" or "multiarg" for elements used at most once; "counter",
-- "multicount" or "twodimensional" for elements that may repeat.
function Argument:_get_type()
   local maxargs = self._maxargs

   if self._maxcount ~= 1 then
      if maxargs == 0 then
         return "counter"
      end

      if maxargs == 1 and self._minargs == 1 then
         return "multicount"
      end

      return "twodimensional"
   end

   if maxargs == 0 then
      return "flag"
   end

   if maxargs == 1 and (self._minargs == 1 or self._mincount == 1) then
      return "arg"
   end

   return "multiarg"
end
289
--- Placeholder text for the `i`-th value of this element.
-- When `argname` is a list, entry `i` is used. A list shorter than the
-- number of placeholders being displayed repeats its last entry; previously
-- it yielded nil and usage generation failed on concatenation.
-- @param i 1-based value position
-- @return placeholder string
function Argument:_get_argname_i(i)
   local argname = self:_get_argname()

   if type(argname) ~= "table" then
      return argname
   end

   return argname[i] or argname[#argname]
end
299
--- Placeholder(s) for a positional argument's values: the explicit
-- `argname` if set, otherwise "<name>".
function Argument:_get_argname()
   if self._argname then
      return self._argname
   end

   return "<" .. self._name .. ">"
end
303
--- Placeholder(s) for an option's values: the explicit `argname` if set,
-- otherwise "<target>".
function Option:_get_argname()
   local argname = self._argname

   if argname then
      return argname
   end

   return "<" .. self:_get_target() .. ">"
end
307
--- Left-column label in the help listing: the argument's name.
function Argument:_get_label()
   local label = self._name
   return label
end
311
--- Left-column label in the help listing: every alias followed by the
-- option's value placeholders, with variants separated by ", ".
function Option:_get_label()
   local placeholders = self:_get_argument_list()
   local suffix = ""

   if #placeholders > 0 then
      suffix = " " .. table.concat(placeholders, " ")
   end

   local variants = {}

   for index, alias in ipairs(self._aliases) do
      variants[index] = alias .. suffix
   end

   return table.concat(variants, ", ")
end
324
--- Left-column label in the help listing: all aliases separated by ", ".
function Command:_get_label()
   local aliases = self._aliases
   return table.concat(aliases, ", ")
end
328
--- Right-column text in the help listing: the description, annotated with
-- the default value when one is set; "" when there is neither.
function Argument:_get_description()
   local description, default = self._description, self._default

   if not default then
      return description or ""
   end

   if description then
      return ("%s (default: %s)"):format(description, default)
   end

   return ("default: %s"):format(default)
end
340
--- Right-column text in the help listing: the description, or "" if unset.
function Command:_get_description()
   local description = self._description

   if not description then
      return ""
   end

   return description
end
344
--- Usage fragment for an option: its name followed by value placeholders,
-- bracketed when the option may be omitted or has a default.
function Option:_get_usage()
   local parts = self:_get_argument_list()
   table.insert(parts, 1, self._name)
   local usage = table.concat(parts, " ")

   if self._mincount ~= 0 and not self._default then
      return usage
   end

   return "[" .. usage .. "]"
end
356
--- Key under which this option's values are stored in the parse result.
-- Uses the explicit target if set; otherwise the first alias whose first two
-- characters match (e.g. "--verbose" -> "verbose"), stripped of them;
-- otherwise the name without its first character.
function Option:_get_target()
   local target = self._target

   if target then
      return target
   end

   for _, alias in ipairs(self._aliases) do
      local prefix_char = alias:sub(1, 1)

      if alias:sub(2, 2) == prefix_char then
         return alias:sub(3)
      end
   end

   return self._name:sub(2)
end
370
--- Space-separated names from the root parser down to this parser/command.
function Parser:_get_fullname()
   local names = {self._name}
   local ancestor = self._parent

   while ancestor do
      table.insert(names, 1, ancestor._name)
      ancestor = ancestor._parent
   end

   return table.concat(names, " ")
end
382
--- Collect, into the set `charset`, the first character of every option
-- alias of this parser and all of its sub-commands (recursively).
-- @param charset optional existing set to extend
-- @return the set
function Parser:_update_charset(charset)
   if not charset then
      charset = {}
   end

   for _, option in ipairs(self._options) do
      for _, alias in ipairs(option._aliases) do
         local first_char = alias:sub(1, 1)
         charset[first_char] = true
      end
   end

   for _, command in ipairs(self._commands) do
      command:_update_charset(charset)
   end

   return charset
end
398
--- Create a positional argument from `...`, register it and return it.
function Parser:argument(...)
   local created = Argument:new(...)
   local list = self._arguments
   list[#list + 1] = created
   return created
end
404
--- Create an option from `...`, register it and return it.
-- When a help flag was added, the new option goes before it so the help
-- flag stays last (add_help relies on that when removing it).
function Parser:option(...)
   local created = Option:new(...)
   local list = self._options
   local position = self._has_help and #list or #list + 1
   table.insert(list, position, created)
   return created
end
416
--- Create an option taking no values (a flag), apply `...` and return it.
function Parser:flag(...)
   local created = self:option()
   created:args(0)
   return created(...)
end
420
--- Create a sub-command with a help flag, apply `...`, attach it to this
-- parser and return it.
function Parser:command(...)
   local created = Command:new()
   created:add_help(true)
   created(...)
   created._parent = self

   local list = self._commands
   list[#list + 1] = created
   return created
end
427
--- Declare the given options mutually exclusive.
-- Raises when an argument is not an Option instance.
-- @return self
function Parser:mutex(...)
   local group = {...}

   for position, element in ipairs(group) do
      local is_option = getmetatable(element) == Option
      assert(is_option, ("bad argument #%d to 'mutex' (Option expected)"):format(position))
   end

   local mutexes = self._mutexes
   mutexes[#mutexes + 1] = group
   return self
end
438
-- Usage lines are wrapped to this many characters; continuation lines are
-- indented to align with the text following the "Usage: " prefix.
local max_usage_width, usage_welcome = 70, "Usage: "
441
--- Build the usage message (or return the custom `usage` if set).
-- Lists mutex groups first, then the remaining options and positional
-- arguments, then a command placeholder, wrapping at `max_usage_width`.
function Parser:get_usage()
   if self._usage then
      return self._usage
   end

   local indent = (" "):rep(#usage_welcome)
   local lines = {usage_welcome .. self:_get_fullname()}

   -- Append `chunk` to the current line, or start a new indented line when
   -- it would not fit.
   local function add(chunk)
      local last = #lines
      local current = lines[last]

      if #current + 1 + #chunk > max_usage_width then
         lines[last + 1] = indent .. chunk
      else
         lines[last] = current .. " " .. chunk
      end
   end

   -- Options already shown inside a mutex group.
   local in_mutex = {}

   for _, mutex in ipairs(self._mutexes) do
      local alternatives = {}

      for index, option in ipairs(mutex) do
         alternatives[index] = option:_get_usage()
         in_mutex[option] = true
      end

      add("(" .. table.concat(alternatives, " | ") .. ")")
   end

   local function add_remaining(elements)
      for _, element in ipairs(elements) do
         if not in_mutex[element] then
            add(element:_get_usage())
         end
      end
   end

   add_remaining(self._options)
   add_remaining(self._arguments)

   if #self._commands > 0 then
      add(self._require_command and "<command>" or "[<command>]")
      add("...")
   end

   return table.concat(lines, "\r\n")
end
491
-- Help entries start after `margin_len` spaces; descriptions start at
-- column `margin_len2`.
local margin_len = 3
local margin_len2 = 25
local margin = string.rep(" ", margin_len)
local margin2 = string.rep(" ", margin_len2)

--- Lay out a label and its description as two help columns.
-- Line breaks inside the description are re-indented to the second column;
-- a single break ("\r", "\n" or "\r\n") continues it, any other pair of
-- break characters becomes an empty line. A label too wide for the first
-- column pushes the description onto the next line.
-- @param s1 label
-- @param s2 description ("" for none)
-- @return formatted text
local function make_two_columns(s1, s2)
   if s2 == "" then
      return margin .. s1
   end

   local description = s2:gsub("[\r\n][\r\n]?", function(line_break)
      local single = #line_break == 1 or line_break == "\r\n"
      return (single and "\r\n" or "\r\n\r\n") .. margin2
   end)

   local column_width = margin_len2 - margin_len
   local separator

   if #s1 < column_width then
      separator = string.rep(" ", column_width - #s1)
   else
      separator = "\r\n" .. margin2
   end

   return margin .. s1 .. separator .. description
end
516
--- Build the full help text (or return the custom `help` if set): usage,
-- description, sections for arguments, options and commands, then epilog,
-- separated by blank lines.
function Parser:get_help()
   if self._help then
      return self._help
   end

   local blocks = {self:get_usage()}

   if self._description then
      blocks[#blocks + 1] = self._description
   end

   local sections = {
      {"Arguments: ", self._arguments},
      {"Options: ", self._options},
      {"Commands: ", self._commands}
   }

   for _, section in ipairs(sections) do
      local label, elements = section[1], section[2]

      if #elements > 0 then
         local section_lines = {label}

         for _, element in ipairs(elements) do
            section_lines[#section_lines + 1] = make_two_columns(element:_get_label(), element:_get_description())
         end

         blocks[#blocks + 1] = table.concat(section_lines, "\r\n")
      end
   end

   if self._epilog then
      blocks[#blocks + 1] = self._epilog
   end

   return table.concat(blocks, "\r\n\r\n")
end
548
--- Suggest known names that are one edit away from `wrong_name`.
-- Every known name is indexed under each of its single-character
-- deletions. A candidate matches when deleting one character of
-- `wrong_name` yields a known name, or when `wrong_name` (or one of its
-- deletions) equals a deletion of a known name.
-- @param context table whose keys are the known names
-- @param wrong_name the unrecognized name
-- @return "" or a "\r\nDid you mean ..." hint
local function get_tip(context, wrong_name)
   local by_deletion = {}

   for name in pairs(context) do
      for i = 1, #name do
         local key = name:sub(1, i - 1) .. name:sub(i + 1)
         local bucket = by_deletion[key]

         if bucket == nil then
            bucket = {}
            by_deletion[key] = bucket
         end

         bucket[#bucket + 1] = name
      end
   end

   local candidates = {}

   -- i == #wrong_name + 1 deletes nothing, covering names one char longer.
   for i = 1, #wrong_name + 1 do
      local key = wrong_name:sub(1, i - 1) .. wrong_name:sub(i + 1)

      if context[key] then
         candidates[key] = true
      elseif by_deletion[key] then
         for _, name in ipairs(by_deletion[key]) do
            candidates[name] = true
         end
      end
   end

   local quoted = {}

   for name in pairs(candidates) do
      quoted[#quoted + 1] = "'" .. name .. "'"
   end

   if #quoted == 0 then
      return ""
   end

   if #quoted == 1 then
      return "\r\nDid you mean " .. quoted[1] .. "?"
   end

   table.sort(quoted)
   return "\r\nDid you mean one of these: " .. table.concat(quoted, " ") .. "?"
end
596
--- English plural suffix for a count: "" for exactly 1, "s" otherwise.
local function plural(x)
   return x == 1 and "" or "s"
end
604
-- Arguments parsed when no explicit list is given: the global `arg` table
-- when present, otherwise an empty list.
local default_cmdline = arg
if not default_cmdline then
   default_cmdline = {}
end
606
--- Parse `args` against this parser and return the table of results.
-- Each string is dispatched to an option, the current positional argument,
-- or a sub-command (which then becomes the active parser). Afterwards,
-- pending elements are closed, defaults are applied and counts checked.
-- @param args array of strings; defaults to `default_cmdline`
-- @param errhandler function(parser, message) called on usage errors; the
--    handlers used in this file raise or exit rather than return
-- @return table mapping each element's target to its parsed value
function Parser:_parse(args, errhandler)
   args = args or default_cmdline
   local parser                -- parser or command currently active
   local charset               -- set of first characters of option aliases
   local options = {}          -- options of every parser switched to
   local arguments = {}        -- positional arguments, in order
   local commands              -- sub-commands of the active parser
   local option_mutexes = {}   -- option -> list of mutexes containing it
   local used_mutexes = {}     -- mutex -> option already used from it
   local opt_context = {}      -- alias -> option
   local com_context           -- alias -> command of the active parser
   local result = {}
   local invocations = {}      -- element -> number of invocations
   local passed = {}           -- element -> values passed to current invocation
   local cur_option            -- option currently receiving values
   local cur_arg_i = 1
   local cur_arg               -- positional argument currently receiving values
   local targets = {}          -- element -> key in `result`

   local function error_(fmt, ...)
      return errhandler(parser, fmt:format(...))
   end

   local function assert_(assertion, ...)
      return assertion or error_(...)
   end

   -- Apply the element's `convert` function or lookup table to `data`.
   local function convert(element, data)
      if element._convert then
         local ok, err

         if type(element._convert) == "function" then
            ok, err = element._convert(data)
         else
            ok, err = element._convert[data]
         end

         assert_(ok ~= nil, "%s", err or "malformed argument '" .. data .. "'")
         data = ok
      end

      return data
   end

   local invoke, pass, close

   -- Start a new invocation of `element`, preparing its result slot.
   function invoke(element)
      local overwrite = false

      if invocations[element] == element._maxcount then
         if element._overwrite then
            overwrite = true
         else
            error_("option '%s' must be used at most %d time%s", element._name, element._maxcount, plural(element._maxcount))
         end
      else
         invocations[element] = invocations[element]+1
      end

      passed[element] = 0
      local type_ = element:_get_type()
      local target = targets[element]

      if type_ == "flag" then
         result[target] = true
      elseif type_ == "multiarg" then
         result[target] = {}
      elseif type_ == "counter" then
         if not overwrite then
            result[target] = result[target]+1
         end
      elseif type_ == "multicount" then
         if overwrite then
            table.remove(result[target], 1)
         end
      elseif type_ == "twodimensional" then
         table.insert(result[target], {})

         if overwrite then
            table.remove(result[target], 1)
         end
      end

      if element._maxargs == 0 then
         close(element)
      end
   end

   -- Feed one value to the current invocation of `element`.
   function pass(element, data)
      passed[element] = passed[element]+1
      data = convert(element, data)
      local type_ = element:_get_type()
      local target = targets[element]

      if type_ == "arg" then
         result[target] = data
      elseif type_ == "multiarg" or type_ == "multicount" then
         table.insert(result[target], data)
      elseif type_ == "twodimensional" then
         table.insert(result[target][#result[target]], data)
      end

      if passed[element] == element._maxargs then
         close(element)
      end
   end

   -- Pass the default value until the element has its minimum of values.
   local function complete_invocation(element)
      while passed[element] < element._minargs do
         pass(element, element._default)
      end
   end

   -- Finish the current invocation of `element` (filling missing values from
   -- the default in "arg" mode) and stop feeding values to it. The advance
   -- is guarded by identity because filling may already have closed the
   -- element through `pass`. Previously it was skipped after filling, which
   -- left `cur_option`/`cur_arg` stuck when minargs < maxargs.
   function close(element)
      if passed[element] < element._minargs then
         if element._default and element._defmode:find "a" then
            complete_invocation(element)
         else
            return error_("too few arguments")
         end
      end

      if element == cur_option then
         cur_option = nil
      elseif element == cur_arg then
         cur_arg_i = cur_arg_i+1
         cur_arg = arguments[cur_arg_i]
      end
   end

   -- Make `p` the active parser and register its elements.
   local function switch(p)
      parser = p

      for _, option in ipairs(parser._options) do
         table.insert(options, option)

         for _, alias in ipairs(option._aliases) do
            opt_context[alias] = option
         end

         local type_ = option:_get_type()
         targets[option] = option:_get_target()

         if type_ == "counter" then
            result[targets[option]] = 0
         elseif type_ == "multicount" or type_ == "twodimensional" then
            result[targets[option]] = {}
         end

         invocations[option] = 0
      end

      for _, mutex in ipairs(parser._mutexes) do
         for _, option in ipairs(mutex) do
            if not option_mutexes[option] then
               option_mutexes[option] = {mutex}
            else
               table.insert(option_mutexes[option], mutex)
            end
         end
      end

      for _, argument in ipairs(parser._arguments) do
         table.insert(arguments, argument)
         invocations[argument] = 0
         targets[argument] = argument._target or argument._name
         invoke(argument)
      end

      cur_arg = arguments[cur_arg_i]
      commands = parser._commands
      com_context = {}

      for _, command in ipairs(commands) do
         targets[command] = command._target or command._name

         for _, alias in ipairs(command._aliases) do
            com_context[alias] = command
         end
      end
   end

   local function get_option(name)
      return assert_(opt_context[name], "unknown option '%s'%s", name, get_tip(opt_context, name))
   end

   local function do_action(element)
      if element._action then
         element._action()
      end
   end

   -- A plain value goes to the current option, else the current positional
   -- argument, else it must name a sub-command.
   local function handle_argument(data)
      if cur_option then
         pass(cur_option, data)
      elseif cur_arg then
         pass(cur_arg, data)
      else
         local com = com_context[data]

         if not com then
            if #commands > 0 then
               error_("unknown command '%s'%s", data, get_tip(com_context, data))
            else
               error_("too many arguments")
            end
         else
            result[targets[com]] = true
            do_action(com)
            switch(com)
         end
      end
   end

   -- Start using the option named `data`, enforcing mutual exclusion.
   local function handle_option(data)
      if cur_option then
         close(cur_option)
      end

      cur_option = opt_context[data]

      if option_mutexes[cur_option] then
         for _, mutex in ipairs(option_mutexes[cur_option]) do
            if used_mutexes[mutex] and used_mutexes[mutex] ~= cur_option then
               error_("option '%s' can not be used together with option '%s'", data, used_mutexes[mutex]._name)
            else
               used_mutexes[mutex] = cur_option
            end
         end
      end

      do_action(cur_option)
      invoke(cur_option)
   end

   -- Classify every string: "--" ends option processing, "--name[=value]"
   -- is a long option, other option-prefixed strings are bundled short
   -- options, anything else is a plain value.
   local function mainloop()
      local handle_options = true

      for _, data in ipairs(args) do
         local plain = true
         local first, name, option

         if handle_options then
            first = data:sub(1, 1)
            if charset[first] then
               if #data > 1 then
                  plain = false
                  if data:sub(2, 2) == first then
                     if #data == 2 then
                        if cur_option then
                           close(cur_option)
                        end

                        handle_options = false
                     else
                        local equal = data:find "="
                        if equal then
                           name = data:sub(1, equal-1)
                           option = get_option(name)
                           assert_(option._maxargs > 0, "option '%s' does not take arguments", name)

                           handle_option(data:sub(1, equal-1))
                           handle_argument(data:sub(equal+1))
                        else
                           get_option(data)
                           handle_option(data)
                        end
                     end
                  else
                     for i = 2, #data do
                        name = first .. data:sub(i, i)
                        option = get_option(name)
                        handle_option(name)

                        if i ~= #data and option._minargs > 0 then
                           handle_argument(data:sub(i+1))
                           break
                        end
                     end
                  end
               end
            end
         end

         if plain then
            handle_argument(data)
         end
      end
   end

   switch(self)
   charset = parser:_update_charset()
   mainloop()

   if cur_option then
      close(cur_option)
   end

   -- Close the remaining positional arguments. `close` is always called on
   -- the argument captured here: completing an invocation only closes it
   -- when its maximum is reached, so relying on that alone looped forever
   -- (e.g. for a "?" argument with a default in "unused" mode).
   while cur_arg do
      local current = cur_arg

      if passed[current] == 0 and current._default and current._defmode:find "u" then
         complete_invocation(current)
      end

      close(current)
   end

   if parser._require_command and #commands > 0 then
      error_("a command is required")
   end

   for _, option in ipairs(options) do
      if invocations[option] == 0 then
         if option._default and option._defmode:find "u" then
            invoke(option)
            complete_invocation(option)
            close(option)
         end
      end

      if invocations[option] < option._mincount then
         if option._default and option._defmode:find "a" then
            while invocations[option] < option._mincount do
               invoke(option)
               close(option)
            end
         else
            error_("option '%s' must be used at least %d time%s", option._name, option._mincount, plural(option._mincount))
         end
      end
   end

   return result
end
940
--- Default usage-error handler.
-- With the global `_TEST` set, raises `msg`; otherwise writes the usage and
-- the message to stderr and exits with status 1.
function Parser:error(msg)
   if not _TEST then
      local report = ("%s\r\n\r\nError: %s\r\n"):format(self:get_usage(), msg)
      io.stderr:write(report)
      os.exit(1)
   else
      error(msg)
   end
end
949
--- Parse `args` (or the script's command line), reporting usage errors
-- through `Parser.error`.
function Parser:parse(args)
   local on_error = Parser.error
   return self:_parse(args, on_error)
end
953
--- Protected parse: never exits on usage errors.
-- @param args array of strings, or nil for the script's command line
-- @return true, result on success; false, message on a usage error.
-- Errors not produced by the parser itself (e.g. raised by a user-supplied
-- convert function or action) are re-raised with their original value.
function Parser:pparse(args)
   local errmsg
   -- Unique marker raised by the handler so parser errors can be told apart
   -- from any other error. Raising nil and re-raising through `assert`
   -- mangled non-string error objects on Lua 5.1.
   local parse_error = {}

   local ok, result = pcall(function()
      return self:_parse(args, function(_, err)
         errmsg = err
         return error(parse_error, 0)
      end)
   end)

   if ok then
      return true, result
   elseif result == parse_error then
      return false, errmsg
   else
      error(result, 0)
   end
end
970
--- Module entry point: create a top-level parser named `default_cmdline[0]`
-- (the script name when the global `arg` is available), add the help flag,
-- then apply `...` (a name and/or a table of properties) to it.
return function(...)
   local parser = Parser(default_cmdline[0])
   parser:add_help(true)
   return parser(...)
end
974