1# DExTer : Debugging Experience Tester
2# ~~~~~~   ~         ~~         ~   ~~
3#
4# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5# See https://llvm.org/LICENSE.txt for license information.
6# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7"""Extended Argument Parser. Extends the argparse module with some extra
8functionality, to hopefully aid user-friendliness.
9"""
10
11import argparse
12import difflib
13import unittest
14
15from dex.utils import PrettyOutput
16from dex.utils.Exceptions import Error
17
# Re-export every public name of argparse (ArgumentParser, Namespace,
# SUPPRESS, ...) so callers can reach them through this module as well.
# Writing to globals() explicitly (rather than relying on vars() happening to
# return the module dict at top level) and using a comprehension keeps no
# loop variable behind in the module namespace.
globals().update(
    {name: getattr(argparse, name) for name in argparse.__all__})
21
22
23def _did_you_mean(val, possibles):
24    close_matches = difflib.get_close_matches(val, possibles)
25    did_you_mean = ''
26    if close_matches:
27        did_you_mean = 'did you mean {}?'.format(' or '.join(
28            "<y>'{}'</>".format(c) for c in close_matches[:2]))
29    return did_you_mean
30
31
32def _colorize(message):
33    lines = message.splitlines()
34    for i, line in enumerate(lines):
35        lines[i] = lines[i].replace('usage:', '<g>usage:</>')
36        if line.endswith(':'):
37            lines[i] = '<g>{}</>'.format(line)
38    return '\n'.join(lines)
39
40
class ExtArgumentParser(argparse.ArgumentParser):
    """argparse.ArgumentParser with DExTer-specific errors and help output.

    Differences from the stock parser:
      * errors raise a dex ``Error`` (carrying the usage text) instead of
        exiting;
      * usage and help text are passed through ``_colorize``;
      * leftover command-line arguments are reported with a "did you mean"
        suggestion where one is available;
      * ``add_argument`` appends the default value to the help text.
    """

    def __init__(self, context, *args, **kwargs):
        """Store ``context`` and initialise the underlying argparse parser.

        Args:
            context: object whose ``o.auto(message, stream)`` is used to print
                messages (see ``_print_message``). The unit tests pass None,
                which is fine as long as nothing is printed.
            *args, **kwargs: forwarded unchanged to argparse.ArgumentParser.
        """
        self.context = context
        super(ExtArgumentParser, self).__init__(*args, **kwargs)

    def error(self, message):
        """Use the Dexception Error mechanism (including auto-colored output).

        Raises:
            Error: always; ``message`` followed by a blank line and the usage.
        """
        raise Error('{}\n\n{}'.format(message, self.format_usage()))

    # pylint: disable=redefined-builtin
    def _print_message(self, message, file=None):
        """Print ``message`` through the context's output object.

        Messages aimed at stdout go to ``PrettyOutput.stdout``; everything
        else (including ``file=None``) goes to ``PrettyOutput.stderr``.
        Empty messages are ignored.
        """
        if not message:
            return
        # Not every file-like object has a ``name`` attribute (io.StringIO,
        # for one), so look it up defensively instead of raising.
        if file and getattr(file, 'name', None) == '<stdout>':
            stream = PrettyOutput.stdout
        else:
            stream = PrettyOutput.stderr
        self.context.o.auto(message, stream)

    # pylint: enable=redefined-builtin

    def format_usage(self):
        """Return argparse's usage text with color markup added."""
        return _colorize(super(ExtArgumentParser, self).format_usage())

    def format_help(self):
        """Return argparse's full help text with color markup added."""
        return _colorize(super(ExtArgumentParser, self).format_help() + '\n\n')

    @property
    def _valid_visible_options(self):
        """A list of all non-suppressed command line flags."""
        return [
            option
            for action in self._actions
            for option in action.option_strings
            if action.help != argparse.SUPPRESS
        ]

    def parse_args(self, args=None, namespace=None):
        """Parse like argparse, but add 'did you mean' output to errors.

        Every argument argparse could not consume is reported in one Error.
        A leftover that is a known visible flag is "unexpected". Anything else
        is "unrecognized", followed by up to two close matches when difflib
        finds any.

        Returns:
            The populated namespace.

        Raises:
            Error: if any arguments are left over (via ``self.error``).
        """
        parsed, leftovers = self.parse_known_args(args, namespace)
        if leftovers:
            # The option list does not change while reporting, so build it
            # once rather than twice per leftover argument.
            visible = self._valid_visible_options
            errors = []
            for arg in leftovers:
                if arg in visible:
                    error = "unexpected argument: <y>'{}'</>".format(arg)
                else:
                    error = "unrecognized argument: <y>'{}'</>".format(arg)
                    dym = _did_you_mean(arg, visible)
                    if dym:
                        error += '  ({})'.format(dym)
                errors.append(error)
            # Continuation lines get a fixed 7-space indent; presumably to
            # line up after an "error: " prefix added by the caller - confirm.
            self.error('\n       '.join(errors))

        return parsed

    def add_argument(self, *args, **kwargs):
        """Add an argument, appending its default value to the help text.

        Accepts everything argparse's ``add_argument`` accepts, plus:
            display_default: text to show as the default in the help output
                when the real ``default`` is None or not given. It is
                consumed here and never forwarded to argparse.

        Only truthy str/int/float defaults are shown, so 0, 0.0, '' and False
        are not displayed. A displayed default must be one of ``choices``
        when ``choices`` is given (checked with assert).

        Returns:
            The argparse Action created for the argument, as argparse does.
        """
        # Always strip our private keyword: argparse's built-in actions reject
        # unknown keyword arguments with TypeError.
        display_default = kwargs.pop('display_default', None)
        default = kwargs.get('default')
        if default is None:
            default = display_default

        if (default and isinstance(default, (str, int, float))
                and default != argparse.SUPPRESS):
            assert (
                'choices' not in kwargs or default in kwargs['choices']), (
                    "default value '{}' is not one of allowed choices: {}".
                    format(default, kwargs['choices']))
            help_text = kwargs.get('help')
            # argparse accepts help=None; leave such arguments alone.
            if help_text is not None and help_text != argparse.SUPPRESS:
                assert isinstance(help_text, str), type(help_text)
                kwargs['help'] = '{} (default:{})'.format(help_text, default)

        return super(ExtArgumentParser, self).add_argument(*args, **kwargs)
115
116
class TestExtArgumentParser(unittest.TestCase):
    """Unit tests for the 'did you mean' error reporting of ExtArgumentParser."""

    # Every error message is followed by the colorized usage text.
    _USAGE = r"\s*<g>usage:</>"

    def _check_error(self, parser, argv, pattern):
        """Assert that parsing ``argv`` raises an Error matching ``pattern``."""
        with self.assertRaisesRegex(Error, pattern):
            parser.parse_args(argv)

    def test_did_you_mean(self):
        p = ExtArgumentParser(None)
        p.add_argument('--foo')
        p.add_argument('--qoo', help=argparse.SUPPRESS)
        p.add_argument('jam', nargs='?')

        # A well-formed command line parses without raising.
        p.parse_args(['--foo', '0'])

        # '--qoo' is suppressed, so only '--foo' is offered as a suggestion.
        self._check_error(
            p, ['--doo'],
            r"^unrecognized argument\: <y>'\-\-doo'</>\s+"
            r"\(did you mean <y>'\-\-foo'</>\?\)\n" + self._USAGE)

        # With a second visible near-miss, both candidates are suggested.
        p.add_argument('--noo')
        self._check_error(
            p, ['--doo'],
            r"^unrecognized argument\: <y>'\-\-doo'</>\s+"
            r"\(did you mean <y>'\-\-noo'</> or <y>'\-\-foo'</>\?\)\n" +
            self._USAGE)

        # Nothing is close enough, so no suggestion is appended.
        self._check_error(
            p, ['--bar'],
            r"^unrecognized argument\: <y>'\-\-bar'</>\n" + self._USAGE)

        # A known flag left over by argparse (here, after '--') is reported
        # as unexpected rather than unrecognized.
        self._check_error(
            p, ['--', 'x', '--foo'],
            r"^unexpected argument\: <y>'\-\-foo'</>\n" + self._USAGE)
149