# this module should be kept Python 2.7 compatible (it relies on argparse)
2import re
3import sys
4import time
5import inspect
6import textwrap
7import functools
8import argparse
9from datetime import datetime, date
10from gettext import gettext as _
11
version = sys.version_info[:2]  # (major, minor) of the running interpreter

# Compare version tuples rather than the sys.version string: string
# comparison is lexicographic and only works by accident for 1-digit majors.
if version >= (3, 0):
    from inspect import getfullargspec
else:
    class getfullargspec(object):
        "A quick and dirty replacement for getfullargspec for Python 2.X"
        def __init__(self, f):
            self.args, self.varargs, self.varkw, self.defaults = \
                inspect.getargspec(f)
            # Python 2 has no keyword-only arguments; these attributes exist
            # so callers can treat this object like inspect.FullArgSpec.
            self.kwonlyargs = []
            self.kwonlydefaults = None
            self.annotations = getattr(f, '__annotations__', {})
23
24
def to_date(s):
    """Parse a 'YYYY-MM-DD' string into a datetime.date.

    Raises ValueError if s does not match the format.
    """
    parsed = time.strptime(s, "%Y-%m-%d")
    return date(parsed.tm_year, parsed.tm_mon, parsed.tm_mday)
28
29
def to_datetime(s):
    """Parse a 'YYYY-MM-DD HH-MM-SS' string (dashes in the time part too)
    into a datetime.datetime.

    Raises ValueError if s does not match the format.
    """
    fields = time.strptime(s, "%Y-%m-%d %H-%M-%S")
    return datetime(fields.tm_year, fields.tm_mon, fields.tm_mday,
                    fields.tm_hour, fields.tm_min, fields.tm_sec)
33
34
def getargspec(callableobj):
    """
    Return the argument specification of callableobj, as produced by
    getfullargspec (attributes .args, .varargs, .varkw, .defaults, ...).

    Plain functions are inspected directly. For bound methods, classes
    (via their __init__) and instances with a __call__ method, the first
    positional argument (the one bound to self/cls) is removed from .args.

    Raises TypeError when callableobj is none of the above.
    """
    if inspect.isfunction(callableobj):
        return getfullargspec(callableobj)
    if inspect.ismethod(callableobj):
        target = callableobj
    elif inspect.isclass(callableobj):
        init = callableobj.__init__
        # object.__init__ is not a Python function; inspect a stand-in with
        # only a self argument to avoid an introspection error
        target = (lambda self: None) if init is object.__init__ else init
    elif hasattr(callableobj, '__call__'):
        target = callableobj.__call__
    else:
        raise TypeError(_('Could not determine the signature of ') +
                        str(callableobj))
    spec = getfullargspec(target)
    del spec.args[0]  # drop the implicit self/cls argument
    return spec
57
58
def annotations(**ann):
    """
    Return a decorator that sets ``f.__annotations__`` to the given
    ``name=annotation`` mapping. This emulates Python 3 function
    annotations on Python 2.X.

    The decorator raises NameError if a keyword does not name a positional
    argument, the *varargs name or the **varkw name of the decorated
    function.
    """
    def annotate(f):
        spec = getfullargspec(f)
        known = list(spec.args)
        known.extend(extra for extra in (spec.varargs, spec.varkw) if extra)
        for argname in ann:
            if argname not in known:
                raise NameError(
                    _('Annotating non-existing argument: %s') % argname)
        f.__annotations__ = ann
        return f
    return annotate
78
79
80def _annotate(arg, ann, f):
81    try:
82        f.__annotations__[arg] = ann
83    except AttributeError:  # Python 2.7
84        f.__annotations__ = {arg: ann}
85    return f
86
87
def pos(arg, help=None, type=None, choices=None, metavar=None):
    """
    Decorator factory annotating ``arg`` as a positional argument.

    The returned callable stores the tuple
    (help, 'positional', None, type, choices, metavar) as the annotation
    of ``arg`` on the decorated function and returns that function.
    """
    annotation = (help, 'positional', None, type, choices, metavar)
    return functools.partial(_annotate, arg, annotation)
94
95
def opt(arg, help=None, type=None, abbrev=None, choices=None, metavar=None):
    """
    Decorator factory annotating ``arg`` as an option.

    The short name defaults to the first character of ``arg`` when no
    ``abbrev`` is given. The returned callable stores the tuple
    (help, 'option', abbrev, type, choices, metavar) as the annotation
    of ``arg`` on the decorated function and returns that function.
    """
    short = abbrev if abbrev else arg[0]
    annotation = (help, 'option', short, type, choices, metavar)
    return functools.partial(_annotate, arg, annotation)
103
104
def flg(arg, help=None, abbrev=None):
    """
    Decorator factory annotating ``arg`` as a boolean flag.

    The short name defaults to the first character of ``arg`` when no
    ``abbrev`` is given. Flags carry no type, choices or metavar.
    """
    short = abbrev if abbrev else arg[0]
    annotation = (help, 'flag', short, None, None, None)
    return functools.partial(_annotate, arg, annotation)
111
112
def is_annotation(obj):
    """
    Duck-typing test: return True if obj exposes all of the attributes
    help, kind, abbrev, type, choices and metavar, False otherwise.
    """
    required = ('help', 'kind', 'abbrev', 'type', 'choices', 'metavar')
    return all(hasattr(obj, attr) for attr in required)
121
122
class Annotation(object):
    """
    How one function argument maps to a command-line argument: help text,
    kind ('positional', 'option' or 'flag'), abbreviation, type converter,
    allowed choices and metavar.
    """
    def __init__(self, help=None, kind="positional", abbrev=None, type=None,
                 choices=None, metavar=None):
        assert kind in ('positional', 'option', 'flag'), kind
        if kind == "positional":
            # positional arguments have no short form
            assert abbrev is None, abbrev
        self.help, self.kind, self.abbrev = help, kind, abbrev
        self.type, self.choices, self.metavar = type, choices, metavar

    @classmethod
    def from_(cls, obj):
        "Helper to convert an object into an annotation, if needed"
        if is_annotation(obj):
            return obj  # already annotation-like: returned unchanged
        if inspect.isclass(obj):
            # a class is used as plain help text, never unpacked
            return cls(str(obj))
        if iterable(obj):
            return cls(*obj)
        return cls(obj)
146
147
NONE = object()  # sentinel used to signal the absence of a default

# The keyword arguments accepted by argparse.ArgumentParser (everything after
# ``self``). Keyword-only parameters are reported by getfullargspec in
# .kwonlyargs rather than .args, so they are appended explicitly; the Python 2
# replacement of getfullargspec may lack that attribute.
_parser_spec = getfullargspec(argparse.ArgumentParser.__init__)
PARSER_CFG = _parser_spec.args[1:] + list(
    getattr(_parser_spec, 'kwonlyargs', None) or [])
del _parser_spec
152
153
def pconf(obj):
    """
    Build the keyword arguments for the ArgumentParser created from obj.

    The description is obj's dedented docstring (None when it has none) and
    the formatter is RawDescriptionHelpFormatter, so the docstring layout is
    kept. Every attribute of obj whose name is in PARSER_CFG is added too,
    overriding these defaults if the names coincide.
    """
    doc = obj.__doc__
    if doc:
        description = textwrap.dedent(doc.rstrip())
    else:
        description = None
    cfg = {'description': description,
           'formatter_class': argparse.RawDescriptionHelpFormatter}
    cfg.update((attr, getattr(obj, attr))
               for attr in dir(obj) if attr in PARSER_CFG)
    return cfg
165
166
167_parser_registry = {}
168
169
def parser_from(obj, **confparams):
    """
    obj can be a callable or an object with a .commands attribute.
    Returns an ArgumentParser.

    Parsers are cached in _parser_registry: a later call with the same obj
    returns the cached parser and ignores confparams. confparams override
    the configuration extracted from obj by pconf and are passed to the
    ArgumentParser constructor. The exception is case_sensitive, which is
    stored as the parser's .case_sensitive attribute. It defaults to
    obj.case_sensitive, or True.
    """
    try:  # the underlying parser has been generated already
        return _parser_registry[obj]
    except KeyError:  # generate a new parser
        pass
    # case_sensitive is a plac setting, not an argparse one: remove it
    # before building the parser, otherwise argparse rejects the keyword
    case_sensitive = confparams.pop(
        'case_sensitive', getattr(obj, 'case_sensitive', True))
    conf = pconf(obj).copy()
    conf.update(confparams)
    _parser_registry[obj] = parser = ArgumentParser(**conf)
    parser.obj = obj
    parser.case_sensitive = case_sensitive
    if hasattr(obj, 'commands') and not inspect.isclass(obj):
        # a command container instance
        parser.addsubcommands(obj.commands, obj, 'subcommands')
    else:
        parser.populate_from(obj)
    return parser
191
192
193def _extract_kwargs(args):
194    """
195    Returns two lists: regular args and name=value args
196    """
197    arglist = []
198    kwargs = {}
199    for arg in args:
200        match = re.match(r'([a-zA-Z_]\w*)=', arg)
201        if match:
202            name = match.group(1)
203            kwargs[name] = arg[len(name)+1:]
204        else:
205            arglist.append(arg)
206    return arglist, kwargs
207
208
209def _match_cmd(abbrev, commands, case_sensitive=True):
210    """
211    Extract the command name from an abbreviation or raise a NameError
212    """
213    if not case_sensitive:
214        abbrev = abbrev.upper()
215        commands = [c.upper() for c in commands]
216    perfect_matches = [name for name in commands if name == abbrev]
217    if len(perfect_matches) == 1:
218        return perfect_matches[0]
219    matches = [name for name in commands if name.startswith(abbrev)]
220    n = len(matches)
221    if n == 1:
222        return matches[0]
223    elif n > 1:
224        raise NameError(
225            _('Ambiguous command %r: matching %s' % (abbrev, matches)))
226
227
class ArgumentParser(argparse.ArgumentParser):
    """
    An ArgumentParser with .func and .argspec attributes, and possibly
    .commands and .subparsers.

    parser_from() builds these parsers and also sets .obj (the object the
    parser was built from) and .case_sensitive on them.
    """
    # Whether subcommand names are matched case-sensitively; parser_from
    # sets it per instance.
    case_sensitive = True

    if version < (3, 10):
        def __init__(self, *args, **kwargs):
            super(ArgumentParser, self).__init__(*args, **kwargs)
            # argparse before 3.10 titles the default optional-argument
            # group "optional arguments"; use the newer "options" title.
            if self._action_groups[1].title == _('optional arguments'):
                self._action_groups[1].title = _('options')

    def alias(self, arg):
        "Can be overridden to preprocess command-line arguments"
        return arg

    def consume(self, args):
        """
        Call the underlying function with the args. Works also for
        command containers, by dispatching to the right subparser.

        Returns a pair (cmd, result). cmd is the subcommand name, or None
        when no subcommand was dispatched. result is the value returned by
        the called function, or by self.missing for unknown subcommands.
        May raise SystemExit (argparse errors, --help, or a container
        parser without a function of its own).
        """
        arglist = [self.alias(a) for a in args]
        cmd = None
        parser = self
        if hasattr(self, 'subparsers'):
            subp, cmd = self._extract_subparser_cmd(arglist)
            if subp is None and cmd is not None:
                return cmd, self.missing(cmd)
            elif subp is not None:  # use the subparser
                parser = subp
        if hasattr(parser, 'argspec') and parser.argspec.varargs:
            # the function takes *varargs: ignore unrecognized arguments
            ns, extraopts = parser.parse_known_args(arglist)
        else:
            ns, extraopts = parser.parse_args(arglist), []  # may raise an exit
        if not hasattr(parser, 'argspec'):
            raise SystemExit
        argspec = parser.argspec
        kwargs = {}
        if argspec.varkw:
            v = argspec.varargs
            varkw = argspec.varkw
            lst = []  # remaining arguments that are not name=value pairs
            if v in ns.__dict__:
                lst, kwargs = _extract_kwargs(ns.__dict__.pop(v))
                ns.__dict__[v] = lst
            elif varkw in ns.__dict__:
                lst, kwargs = _extract_kwargs(ns.__dict__.pop(varkw))
                ns.__dict__[varkw] = lst
            if lst and not v:
                parser.error(_('Unrecognized arguments: %s') % arglist)
        collision = set(argspec.args) & set(kwargs)
        if collision:
            parser.error(_('colliding keyword arguments: %s') %
                         ' '.join(sorted(collision)))
        # Options declared with trailing underscores (e.g. ``class_``) are
        # stored by argparse without them; positionals keep their exact name.
        args = [getattr(ns, name if hasattr(ns, name) else name.rstrip('_'))
                for name in argspec.args]
        varargs = getattr(ns, argspec.varargs or '', [])
        return cmd, parser.func(*(args + varargs + extraopts), **kwargs)

    def _extract_subparser_cmd(self, arglist):
        """
        Extract the right subparser from the first recognized argument.

        The first argument that does not start with the option prefix is
        removed from arglist (in place) and matched against the subcommand
        names. Returns (subparser, name): subparser is None when nothing
        matched (name is then the raw argument). Returns (None, None) when
        there is no such argument.
        """
        optprefix = self.prefix_chars[0]
        name_parser_map = self.subparsers._name_parser_map
        for i, arg in enumerate(arglist):
            if not arg.startswith(optprefix):
                cmd = _match_cmd(arg, name_parser_map, self.case_sensitive)
                del arglist[i]
                return name_parser_map.get(cmd), cmd or arg
        return None, None

    def addsubcommands(self, commands, obj, title=None, cmdprefix=''):
        """
        Extract a list of subcommands from obj and add them to the parser.

        Each name in commands, with the length of obj.cmdprefix (if any)
        stripped from its front, is looked up as an attribute of obj and
        becomes a subparser populated from that attribute.

        Raises ValueError if the command prefix (obj.cmdprefix, else the
        cmdprefix argument) starts with one of this parser's option prefix
        characters, since such commands would be taken for options.
        """
        prefix = getattr(obj, 'cmdprefix', cmdprefix)
        if prefix and prefix[0] in self.prefix_chars:
            raise ValueError(_('The prefix %r is already taken!') % prefix)
        if not hasattr(self, 'subparsers'):
            self.subparsers = self.add_subparsers(title=title)
        elif title:
            self.add_argument_group(title=title)  # populate ._action_groups
        # NOTE(review): only an explicit obj.cmdprefix is stripped from the
        # command names; the cmdprefix argument affects the check above only.
        prefixlen = len(getattr(obj, 'cmdprefix', ''))
        add_help = getattr(obj, 'add_help', True)
        for cmd in commands:
            func = getattr(obj, cmd[prefixlen:])  # strip the prefix
            doc = (textwrap.dedent(func.__doc__.rstrip())
                   if func.__doc__ else None)
            self.subparsers.add_parser(
                cmd, add_help=add_help, help=doc, **pconf(func)
                ).populate_from(func)

    def _set_func_argspec(self, obj):
        """
        Extracts the signature from a callable object and adds an .argspec
        attribute to the parser. Also adds a .func reference to the object.
        """
        self.func = obj
        self.argspec = getargspec(obj)
        _parser_registry[obj] = self

    def populate_from(self, func):
        """
        Extract the arguments from the attributes of the passed function
        and return a populated ArgumentParser instance (self).

        Each regular argument becomes a positional, an option or a flag
        according to its annotation; *varargs and **varkw become extra
        positionals with nargs='*'. A default value supplies the help text
        ("[default]") and, when no type is annotated, the type converter.

        Raises TypeError if a flag has a default other than False.
        """
        self._set_func_argspec(func)
        f = self.argspec
        defaults = f.defaults or ()
        n_args = len(f.args)
        n_defaults = len(defaults)
        # NONE marks the arguments that have no default value
        alldefaults = (NONE,) * (n_args - n_defaults) + defaults
        prefix = self.prefix = getattr(func, 'prefix_chars', '-')[0]
        for name, default in zip(f.args, alldefaults):
            ann = f.annotations.get(name, ())
            a = Annotation.from_(ann)
            metavar = a.metavar
            if default is NONE:
                dflt = None
            else:
                dflt = default
                if a.help is None:
                    # argparse %-formats help strings, so a literal '%' in
                    # the default (e.g. a strftime format) must be doubled;
                    # str() because dflt can be a tuple
                    a.help = '[%s]' % str(dflt).replace('%', '%%')
                if a.type is None:
                    # try to infer the type from the default argument
                    if isinstance(default, datetime):
                        a.type = to_datetime
                    elif isinstance(default, date):
                        a.type = to_date
                    elif default is not None:
                        a.type = type(default)
                if not metavar and default == '':
                    metavar = "''"
            if a.kind in ('option', 'flag'):
                if name.endswith('_'):
                    # allows reserved words to be specified with underscores
                    suffix = name.rstrip('_')
                else:
                    # convert underscores to dashes
                    suffix = name.replace('_', '-')
                if a.abbrev:
                    shortlong = (prefix + a.abbrev, prefix * 2 + suffix)
                else:
                    shortlong = (prefix + suffix,)
                if a.kind == 'option':
                    if default is not NONE:
                        metavar = metavar or str(default)
                    self.add_argument(
                        *shortlong, help=a.help, default=dflt, type=a.type,
                        choices=a.choices, metavar=metavar)
                else:  # flag
                    if default is not NONE and default is not False:
                        raise TypeError(
                            _('Flag %r wants default False, got %r') %
                            (name, default))
                    self.add_argument(*shortlong, action='store_true',
                                      help=a.help)
            elif default is NONE:  # required argument
                self.add_argument(name, help=a.help, type=a.type,
                                  choices=a.choices, metavar=metavar)
            else:  # default argument
                self.add_argument(
                    name, nargs='?', help=a.help, default=dflt,
                    type=a.type, choices=a.choices, metavar=metavar)
        if f.varargs:
            a = Annotation.from_(f.annotations.get(f.varargs, ()))
            self.add_argument(f.varargs, nargs='*', help=a.help, default=[],
                              type=a.type, metavar=a.metavar)
        if f.varkw:
            a = Annotation.from_(f.annotations.get(f.varkw, ()))
            self.add_argument(f.varkw, nargs='*', help=a.help, default={},
                              type=a.type, metavar=a.metavar)
        return self

    def missing(self, name):
        """
        Handle an unknown subcommand name: call self.obj.__missing__(name)
        if defined, otherwise report the error via self.error.
        May raise a SystemExit.
        """
        miss = getattr(self.obj, '__missing__', lambda name:
                       self.error('No command %r' % name))
        return miss(name)

    def print_actions(self):
        "Useful for debugging"
        print(self)
        for a in self._actions:
            print(a)
415
416
def iterable(obj):
    """
    Return True for objects with an __iter__ method, except strings,
    bytes and classes; False otherwise.
    """
    if isinstance(obj, (str, bytes)) or inspect.isclass(obj):
        return False
    return hasattr(obj, '__iter__')
420
421
def call(obj, arglist=None, eager=True, version=None):
    """
    If obj is a function or a bound method, parse the given arglist
    by using the parser inferred from the annotations of obj
    and call obj with the parsed arguments.
    If obj is an object with attribute .commands, dispatch to the
    associated subparser.

    arglist defaults to sys.argv[1:]. If version is given, a --version/-v
    option printing it is added to the parser. If eager is true and the
    result is iterable (but not a string), it is returned as a list.
    """
    if arglist is None:
        arglist = sys.argv[1:]
    parser = parser_from(obj)
    # parser_from caches parsers, so on a repeated call --version is already
    # there and adding it again would raise an argparse conflict error.
    # NOTE(review): the version string of the first call is kept.
    if version and '--version' not in parser._option_string_actions:
        parser.add_argument(
            '--version', '-v', action='version', version=version)
    result = parser.consume(arglist)[1]  # consume returns (cmd, result)
    if eager and iterable(result):  # listify the result
        return list(result)
    return result
440