1from __future__ import unicode_literals
2import os
3import logging
4from collections import namedtuple
5
6from . import export
7from .help.commands import helpcommands
8
# Module-level logger.
log = logging.getLogger(__name__)

# How PGSpecial.execute invokes a command's handler:
#   NO_QUERY     -> handler()
#   PARSED_QUERY -> handler(cur=..., pattern=..., verbose=...)
#   RAW_QUERY    -> handler(cur=..., query=<the full input line>)
NO_QUERY, PARSED_QUERY, RAW_QUERY = 0, 1, 2

# Pager modes stored in PGSpecial.pager_config.
PAGER_OFF, PAGER_LONG_OUTPUT, PAGER_ALWAYS = 0, 1, 2

# Status text reported after the pager mode changes, keyed by pager mode.
PAGER_MSG = {
    PAGER_OFF: "Pager usage is off.",
    PAGER_LONG_OUTPUT: "Pager is used for long output.",
    PAGER_ALWAYS: "Pager is always used.",
}

# One registry entry: the callable plus the metadata used for dispatch and help.
SpecialCommand = namedtuple(
    "SpecialCommand",
    "handler syntax description arg_type hidden case_sensitive",
)
29
30
@export
class CommandNotFound(Exception):
    """Raised when an input line does not name any registered special command."""
34
35
@export
class PGSpecial(object):
    """Registry and dispatcher for backslash special commands.

    Each instance starts from a copy of the class-level ``default_commands``
    (static commands registered with the ``special_command`` decorator). It
    then registers its own stateful commands. Their handlers are bound
    methods that read and update the display settings stored on the
    instance (timing, expanded output, pager).
    """

    # Default static commands that don't rely on PGSpecial state are registered
    # via the special_command decorator and stored in default_commands
    default_commands = {}

    def __init__(self):
        # Copy so that the per-instance registrations below do not leak into
        # the class-level registry shared by every instance.
        self.commands = self.default_commands.copy()
        self.timing_enabled = False
        self.expanded_output = False
        self.auto_expand = False
        self.pager_config = PAGER_ALWAYS
        # PAGER as it was at construction time; "\pager" with no argument
        # restores this value.
        self.pager = os.environ.get("PAGER", "")

        self.register(
            self.show_help, "\\?", "\\?", "Show Commands.", arg_type=PARSED_QUERY
        )

        self.register(
            self.toggle_expanded_output,
            "\\x",
            "\\x",
            "Toggle expanded output.",
            arg_type=PARSED_QUERY,
        )

        self.register(
            self.call_pset,
            "\\pset",
            "\\pset [key] [value]",
            "A limited version of traditional \\pset",
            arg_type=PARSED_QUERY,
        )

        self.register(
            self.show_command_help,
            "\\h",
            "\\h",
            "Show SQL syntax and help.",
            arg_type=PARSED_QUERY,
        )

        self.register(
            self.toggle_timing,
            "\\timing",
            "\\timing",
            "Toggle timing of commands.",
            arg_type=NO_QUERY,
        )

        self.register(
            self.set_pager,
            "\\pager",
            "\\pager [command]",
            "Set PAGER. Print the query results via PAGER.",
            arg_type=PARSED_QUERY,
        )

    def register(self, *args, **kwargs):
        """Register a command in this instance's command table only.

        Accepts the same arguments as ``register_special_command``, except
        ``command_dict``, which is always ``self.commands``.
        """
        register_special_command(*args, command_dict=self.commands, **kwargs)

    def execute(self, cur, sql):
        """Run the special command contained in *sql* and return its result.

        The exact command name is tried first, then its lower-cased form. The
        lower-cased match is accepted only for commands registered with
        ``case_sensitive=False``.

        :param cur: database cursor, forwarded to PARSED_QUERY/RAW_QUERY handlers.
        :param sql: the raw input line, e.g. ``"\\dt+ pattern"``.
        :return: whatever the handler returns, or None for an unknown arg_type.
        :raises CommandNotFound: if no matching command is registered.
        """
        command, verbose, pattern = parse_special_command(sql)

        special_cmd = self.commands.get(command)
        if special_cmd is None:
            special_cmd = self.commands.get(command.lower())
            if special_cmd is None or special_cmd.case_sensitive:
                raise CommandNotFound("Command not found: %s" % command)

        if special_cmd.arg_type == NO_QUERY:
            return special_cmd.handler()
        elif special_cmd.arg_type == PARSED_QUERY:
            return special_cmd.handler(cur=cur, pattern=pattern, verbose=verbose)
        elif special_cmd.arg_type == RAW_QUERY:
            return special_cmd.handler(cur=cur, query=sql)

    def show_help(self, pattern, **_):
        """List all non-hidden commands, or show SQL help if *pattern* is given.

        :return: a one-element list holding ``(None, rows, headers, None)``,
            where rows are ``(syntax, description)`` pairs sorted by command name.
        """
        if pattern.strip():
            return self.show_command_help(pattern)

        headers = ["Command", "Description"]
        result = [
            (cmd.syntax, cmd.description)
            for _name, cmd in sorted(self.commands.items())
            if not cmd.hidden
        ]
        return [(None, result, headers, None)]

    def show_command_help_listing(self):
        """Return every SQL help topic name as a table, six names per row."""
        table = chunks(sorted(helpcommands.keys()), 6)
        return [(None, table, [], None)]

    def show_command_help(self, pattern, **_):
        """Show description and syntax for the SQL command named in *pattern*.

        The lookup is case-insensitive (the pattern is upper-cased). With an
        empty pattern, the list of available topics is returned instead.
        """
        command = pattern.strip().upper()
        message = ""

        if not command:
            return self.show_command_help_listing()

        if command in helpcommands:
            helpcommand = helpcommands[command]

            if "description" in helpcommand:
                message += helpcommand["description"]
            if "synopsis" in helpcommand:
                message += "\nSyntax:\n"
                message += helpcommand["synopsis"]
        else:
            message = 'No help available for "%s"' % pattern
            message += "\nTry \\h with no arguments to see available help."

        return [(None, None, None, message)]

    def toggle_expanded_output(self, pattern, **_):
        """Handle the expanded-output command: "auto", "on", "off", or toggle.

        "auto" enables automatic expansion and turns fixed expansion off.
        Any other value sets or toggles ``expanded_output`` and then makes
        ``auto_expand`` equal to it.
        """
        flag = pattern.strip()
        if flag == "auto":
            self.auto_expand = True
            self.expanded_output = False
            return [(None, None, None, "Expanded display is used automatically.")]
        elif flag == "off":
            self.expanded_output = False
        elif flag == "on":
            self.expanded_output = True
        else:
            # A plain toggle from auto mode switches expansion off.
            self.expanded_output = not (self.expanded_output or self.auto_expand)

        # NOTE(review): this also sets auto_expand=True after "on"; it is kept
        # as-is because callers may rely on it.
        self.auto_expand = self.expanded_output
        message = "Expanded display is "
        message += "on." if self.expanded_output else "off."
        return [(None, None, None, message)]

    def toggle_timing(self):
        """Flip ``timing_enabled`` and report the new state."""
        self.timing_enabled = not self.timing_enabled
        message = "Timing is "
        message += "on." if self.timing_enabled else "off."
        return [(None, None, None, message)]

    def call_pset(self, pattern, **_):
        """Dispatch ``\\pset key [value]`` to the matching ``pset_<key>`` method.

        Unknown keys produce an "is currently not supported" message.
        """
        # split() with no separator collapses runs of whitespace, so
        # "pager   always" is parsed the same as "pager always".
        parts = pattern.split()
        key = parts[0] if parts else ""
        val = parts[1] if len(parts) > 1 else ""
        handler = getattr(self, "pset_" + key, None)
        if handler is None:
            return [(None, None, None, "'%s' is currently not supported by pset" % key)]
        return handler(val)

    def pset_pager(self, value):
        """Set the pager mode from "always", "off" or "on".

        Any other value toggles between long-output-only and off.
        """
        if value == "always":
            self.pager_config = PAGER_ALWAYS
        elif value == "off":
            self.pager_config = PAGER_OFF
        elif value == "on":
            self.pager_config = PAGER_LONG_OUTPUT
        elif self.pager_config == PAGER_LONG_OUTPUT:
            self.pager_config = PAGER_OFF
        else:
            self.pager_config = PAGER_LONG_OUTPUT
        return [(None, None, None, "%s" % PAGER_MSG[self.pager_config])]

    def set_pager(self, pattern, **_):
        """Set ``os.environ["PAGER"]`` to *pattern*.

        With an empty pattern, PAGER is restored to the value captured in
        ``__init__``, or removed if there was none.
        """
        if not pattern:
            if not self.pager:
                os.environ.pop("PAGER", None)
                msg = "Pager reset to system default."
            else:
                os.environ["PAGER"] = self.pager
                msg = "Reset pager back to default. Default: %s" % self.pager
        else:
            os.environ["PAGER"] = pattern
            msg = "PAGER set to %s." % pattern

        return [(None, None, None, msg)]
217
218
@export
def content_exceeds_width(row, width):
    """Return True if *row* would not fit on a line *width* characters wide.

    The estimated line length is the total length of the cells, plus three
    separator characters per column, plus a two-character safety margin.
    """
    estimated = len(row) * 3 + 2
    for cell in row:
        estimated += len(cell)
    return estimated > width
226
227
@export
def parse_special_command(sql):
    """Split a special-command line into ``(command, verbose, argument)``.

    The first whitespace-delimited token is the command. A ``+`` anywhere in
    it sets *verbose* and is removed from the returned name. The rest of the
    line, stripped, is the argument.

    Leading whitespace and non-space separators (tabs, newlines) are tolerated.
    For example, ``"\\dt+\\tfoo"`` parses as ``("\\dt", True, "foo")``.
    An empty or all-whitespace line parses as ``("", False, "")``.
    """
    # maxsplit=1 keeps the argument intact (including inner whitespace).
    parts = sql.split(None, 1)
    if not parts:
        return ("", False, "")
    command = parts[0]
    arg = parts[1] if len(parts) > 1 else ""
    verbose = "+" in command
    return (command.replace("+", ""), verbose, arg.strip())
235
236
def special_command(
    command,
    syntax,
    description,
    arg_type=PARSED_QUERY,
    hidden=False,
    case_sensitive=True,
    aliases=(),
):
    """Build a decorator that registers a static special command.

    The decorated function is stored in ``PGSpecial.default_commands`` and
    returned unchanged. Every PGSpecial instance copies that registry when it
    is constructed.
    """

    def decorator(func):
        register_special_command(
            func,
            command,
            syntax,
            description,
            arg_type=arg_type,
            hidden=hidden,
            case_sensitive=case_sensitive,
            aliases=aliases,
            command_dict=PGSpecial.default_commands,
        )
        return func

    return decorator
263
264
def register_special_command(
    handler,
    command,
    syntax,
    description,
    arg_type=PARSED_QUERY,
    hidden=False,
    case_sensitive=True,
    aliases=(),
    command_dict=None,
):
    """Add *command* and each of its *aliases* to *command_dict*.

    Case-insensitive commands and aliases are stored under their lower-cased
    name. Aliases share the handler, syntax and description, but are always
    registered as hidden so help listings show each command once.
    *command_dict* must be a mutable mapping; it is required despite its
    ``None`` default.
    """

    def key_for(name):
        return name if case_sensitive else name.lower()

    primary_key = key_for(command)
    command_dict[primary_key] = SpecialCommand(
        handler, syntax, description, arg_type, hidden, case_sensitive
    )

    for alias in aliases:
        alias_key = key_for(alias)
        command_dict[alias_key] = SpecialCommand(
            handler, syntax, description, arg_type, True, case_sensitive
        )
291
292
def chunks(l, n):
    """Split sequence *l* into consecutive slices of length *n*.

    The final slice may be shorter. Values of *n* below 1 are treated as 1.
    Slices keep the type of *l* (a list gives lists, a str gives strs).
    """
    step = max(1, n)
    pieces = []
    start = 0
    while start < len(l):
        pieces.append(l[start : start + step])
        start += step
    return pieces
296
297
@special_command(
    "\\e", "\\e [file]", "Edit the query with external editor.", arg_type=NO_QUERY
)
@special_command(
    "\\ef",
    "\\ef [funcname [line]]",
    "Edit the contents of the query buffer.",
    arg_type=NO_QUERY,
    hidden=True,
)
@special_command(
    "\\ev",
    "\\ev [viewname [line]]",
    "Edit the contents of the query buffer.",
    arg_type=NO_QUERY,
    hidden=True,
)
def doc_only():
    """Documentation placeholder. Implemented in pgcli.main.handle_editor_command.

    The decorators only register the editor commands (``\\e``, ``\\ef``,
    ``\\ev``) in ``PGSpecial.default_commands``. Of these, ``\\e`` is not
    hidden and so appears in the help listing. The actual editing is
    presumably handled by the caller before dispatch, so this handler is not
    expected to run.

    :raises RuntimeError: always.
    """
    raise RuntimeError(
        "doc_only is a documentation placeholder; editor commands are "
        "implemented in pgcli.main.handle_editor_command"
    )
318
319
@special_command(
    "\\do", "\\do[S] [pattern]", "List operators.", arg_type=NO_QUERY, hidden=True
)
@special_command(
    "\\dp",
    "\\dp [pattern]",
    "List table, view, and sequence access privileges.",
    arg_type=NO_QUERY,
    hidden=True,
)
@special_command(
    "\\z", "\\z [pattern]", "Same as \\dp.", arg_type=NO_QUERY, hidden=True
)
def place_holder():
    """Stand-in handler for commands that are registered but not implemented.

    Registering ``\\do``, ``\\dp`` and ``\\z`` means they resolve in
    ``PGSpecial.execute`` instead of raising CommandNotFound. Invoking one
    raises NotImplementedError instead.

    :raises NotImplementedError: always.
    """
    raise NotImplementedError("This special command is not implemented yet.")
335