1from __future__ import print_function
2import sys
3import sqlparse
4from sqlparse.sql import Comparison, Identifier, Where
5from litecli.encodingutils import string_types, text_type
6from .parseutils import last_word, extract_tables, find_prev_keyword
7from .special import parse_special_command
8
9
def suggest_type(full_text, text_before_cursor):
    """Suggest what kind of entity to complete at the cursor, and its scope.

    Args:
        full_text: the complete text typed so far.
        text_before_cursor: the part of ``full_text`` before the cursor.

    Returns:
        A list of suggestion dicts. Each has a ``"type"`` key ('table',
        'column', 'keyword', ...) plus optional scope keys; for a column
        suggestion the scope is the list of tables under ``"tables"``.
    """
    word_before_cursor = last_word(text_before_cursor, include="many_punctuations")

    identifier = None

    # The try/except below is a workaround for parse errors on some partial
    # inputs; it should be removed once sqlparse has been fixed.
    try:
        # If we've partially typed a word then word_before_cursor won't be an
        # empty string. In that case we want to remove the partially typed
        # string before sending it to the sqlparser. Otherwise the last token
        # will always be the partially typed string, which renders the smart
        # completion useless because it would always return the list of
        # keywords as completion.
        if word_before_cursor:
            if word_before_cursor.endswith("(") or word_before_cursor.startswith("\\"):
                parsed = sqlparse.parse(text_before_cursor)
            else:
                parsed = sqlparse.parse(text_before_cursor[: -len(word_before_cursor)])

                # word_before_cursor may include a schema qualification, like
                # "schema_name.partial_name" or "schema_name.", so parse it
                # separately.
                p = sqlparse.parse(word_before_cursor)[0]

                if p.tokens and isinstance(p.tokens[0], Identifier):
                    identifier = p.tokens[0]
        else:
            parsed = sqlparse.parse(text_before_cursor)
    except (TypeError, AttributeError):
        return [{"type": "keyword"}]

    if len(parsed) > 1:
        # Multiple statements being edited -- isolate the current one by
        # cumulatively summing statement lengths to find the one that bounds
        # the current position.
        current_pos = len(text_before_cursor)
        stmt_start, stmt_end = 0, 0

        for statement in parsed:
            stmt_len = len(text_type(statement))
            stmt_start, stmt_end = stmt_end, stmt_end + stmt_len

            if stmt_end >= current_pos:
                break

        # When the partially typed word was cut off before parsing, the parsed
        # text is shorter than text_before_cursor and no statement reaches the
        # cursor; the loop then ends on the last statement, which is the one
        # the cursor is in. Either way, narrow both texts to that statement.
        text_before_cursor = full_text[stmt_start:current_pos]
        full_text = full_text[stmt_start:]

    elif parsed:
        # A single statement
        statement = parsed[0]
    else:
        # The empty string
        statement = None

    # Check for special commands and handle those separately.
    if statement:
        # Be careful here because trivial whitespace is parsed as a statement,
        # but the statement won't have a first token.
        tok1 = statement.token_first()
        if (tok1 and tok1.value.startswith((".", "\\", "source"))) or (
            text_before_cursor and text_before_cursor.startswith(".open ")
        ):
            return suggest_special(text_before_cursor)

    # Last non-whitespace token of the current statement, or "" if none.
    last_token = ""
    if statement:
        last_token = statement.token_prev(len(statement.tokens))[1] or ""

    return suggest_based_on_last_token(
        last_token, text_before_cursor, full_text, identifier
    )
89
90
def suggest_special(text):
    """Suggest completions for a special command and its arguments.

    ``text`` is the command line typed so far; leading whitespace is ignored.
    While the command name itself is still being typed, special commands are
    suggested; otherwise the suggestion depends on which command it is.

    Returns a list of suggestion dicts, each with a ``"type"`` key.
    """
    text = text.lstrip()
    cmd, _, arg = parse_special_command(text)

    if cmd == text:
        # Trying to complete the special command itself
        return [{"type": "special"}]

    if cmd in ("\\u", "\\r"):
        return [{"type": "database"}]

    # One-element tuple: a bare ("\\T") is just a string, which would turn
    # this into a substring test that also matches e.g. a lone backslash.
    if cmd in ("\\T",):
        return [{"type": "table_format"}]

    if cmd in ["\\f", "\\fs", "\\fd"]:
        return [{"type": "favoritequery"}]

    if cmd in ["\\d", "\\dt", "\\dt+", ".schema"]:
        return [
            {"type": "table", "schema": []},
            {"type": "view", "schema": []},
            {"type": "schema"},
        ]

    if cmd in ["\\.", "source", ".open"]:
        return [{"type": "file_name"}]

    if cmd in [".import"]:
        # Usage: .import filename table
        if _expecting_arg_idx(arg, text) == 1:
            return [{"type": "file_name"}]
        else:
            return [{"type": "table", "schema": []}]

    return [{"type": "keyword"}, {"type": "special"}]
126
127
128def _expecting_arg_idx(arg, text):
129    """Return the index of expecting argument.
130
131    >>> _expecting_arg_idx("./da", ".import ./da")
132    1
133    >>> _expecting_arg_idx("./data.csv", ".import ./data.csv")
134    1
135    >>> _expecting_arg_idx("./data.csv", ".import ./data.csv ")
136    2
137    >>> _expecting_arg_idx("./data.csv t", ".import ./data.csv t")
138    2
139    """
140    args = arg.split()
141    return len(args) + int(text[-1].isspace())
142
143
def suggest_based_on_last_token(token, text_before_cursor, full_text, identifier):
    """Suggest completion types from the last token before the cursor.

    Args:
        token: the last token before the cursor. It may be a plain string
            (e.g. the literal ``"where"`` used by a recursive call below), a
            sqlparse ``Comparison`` or ``Where`` group, another sqlparse token
            exposing ``.value``, or a falsy value (``""``/``None``) when
            nothing precedes the cursor.
        text_before_cursor: SQL text up to the cursor.
        full_text: SQL text that in-scope tables are extracted from
            (via ``extract_tables``).
        identifier: sqlparse ``Identifier`` for the partially typed word,
            used for ``parent.<name>`` qualification, or ``None``.

    Returns:
        A list of suggestion dicts. Each has a ``"type"`` key and, depending
        on the type, scope keys such as ``"tables"``, ``"schema"``,
        ``"aliases"`` or ``"drop_unique"``.
    """
    if not token:
        # Nothing before the cursor. Checked before ``token.value`` is read so
        # a falsy non-string token such as None cannot raise AttributeError.
        return [{"type": "keyword"}, {"type": "special"}]

    if isinstance(token, string_types):
        token_v = token.lower()
    elif isinstance(token, Comparison):
        # If 'token' is a Comparison type such as
        # 'select * FROM abc a JOIN def d ON a.id = d.'. Then calling
        # token.value on the comparison type will only return the lhs of the
        # comparison. In this case a.id. So we need to do token.tokens to get
        # both sides of the comparison and pick the last token out of that
        # list.
        token_v = token.tokens[-1].value.lower()
    elif isinstance(token, Where):
        # sqlparse groups all tokens from the where clause into a single token
        # list. This means that token.value may be something like
        # 'where foo > 5 and '. We need to look "inside" token.tokens to handle
        # suggestions in complicated where clauses correctly.
        prev_keyword, text_before_cursor = find_prev_keyword(text_before_cursor)
        return suggest_based_on_last_token(
            prev_keyword, text_before_cursor, full_text, identifier
        )
    else:
        token_v = token.value.lower()

    if token_v.endswith("("):
        p = sqlparse.parse(text_before_cursor)[0]

        if p.tokens and isinstance(p.tokens[-1], Where):
            # Four possibilities:
            #  1 - Parenthesized clause like "WHERE foo AND ("
            #        Suggest columns/functions
            #  2 - Function call like "WHERE foo("
            #        Suggest columns/functions
            #  3 - Subquery expression like "WHERE EXISTS ("
            #        Suggest keywords, in order to do a subquery
            #  4 - Subquery OR array comparison like "WHERE foo = ANY("
            #        Suggest columns/functions AND keywords. (If we wanted to be
            #        really fancy, we could suggest only array-typed columns)

            column_suggestions = suggest_based_on_last_token(
                "where", text_before_cursor, full_text, identifier
            )

            # Check for a subquery expression (cases 3 & 4)
            where = p.tokens[-1]
            _, prev_tok = where.token_prev(len(where.tokens) - 1)

            if isinstance(prev_tok, Comparison):
                # e.g. "SELECT foo FROM bar WHERE foo = ANY("
                prev_tok = prev_tok.tokens[-1]

            if prev_tok is not None and prev_tok.value.lower() == "exists":
                return [{"type": "keyword"}]
            return column_suggestions

        # Get the token before the parens
        _, prev_tok = p.token_prev(len(p.tokens) - 1)
        if prev_tok and prev_tok.value and prev_tok.value.lower() == "using":
            # tbl1 INNER JOIN tbl2 USING (col1, col2)
            tables = extract_tables(full_text)

            # suggest columns that are present in more than one table
            return [{"type": "column", "tables": tables, "drop_unique": True}]
        elif p.token_first().value.lower() == "select":
            # If the lparen is preceded by a space chances are we're about to
            # do a sub-select.
            if last_word(text_before_cursor, "all_punctuations").startswith("("):
                return [{"type": "keyword"}]
        elif p.token_first().value.lower() == "show":
            return [{"type": "show"}]

        # We're probably in a function argument list
        return [{"type": "column", "tables": extract_tables(full_text)}]
    elif token_v in ("set", "order by", "distinct"):
        return [{"type": "column", "tables": extract_tables(full_text)}]
    elif token_v == "as":
        # Don't suggest anything for an alias
        return []
    elif token_v == "show":
        # Exact match: ``in ("show")`` was a substring test that also matched
        # one-letter tokens such as an alias "s".
        return [{"type": "show"}]
    elif token_v == "to":
        p = sqlparse.parse(text_before_cursor)[0]
        if p.token_first().value.lower() == "change":
            return [{"type": "change"}]
        else:
            return [{"type": "user"}]
    elif token_v in ("user", "for"):
        return [{"type": "user"}]
    elif token_v in ("select", "where", "having"):
        # Check for a table alias or schema qualification
        parent = (identifier and identifier.get_parent_name()) or []

        tables = extract_tables(full_text)
        if parent:
            tables = [t for t in tables if identifies(parent, *t)]
            return [
                {"type": "column", "tables": tables},
                {"type": "table", "schema": parent},
                {"type": "view", "schema": parent},
                {"type": "function", "schema": parent},
            ]
        else:
            aliases = [alias or table for (_, table, alias) in tables]
            return [
                {"type": "column", "tables": tables},
                {"type": "function", "schema": []},
                {"type": "alias", "aliases": aliases},
                {"type": "keyword"},
            ]
    elif (token_v.endswith("join") and token.is_keyword) or (
        token_v
        in ("copy", "from", "update", "into", "describe", "truncate", "desc", "explain")
    ):
        schema = (identifier and identifier.get_parent_name()) or []

        # Suggest tables from either the currently-selected schema or the
        # public schema if no schema has been specified
        suggest = [{"type": "table", "schema": schema}]

        if not schema:
            # Suggest schemas
            suggest.insert(0, {"type": "schema"})

        # Only tables can be TRUNCATED, otherwise suggest views
        if token_v != "truncate":
            suggest.append({"type": "view", "schema": schema})

        return suggest

    elif token_v in ("table", "view", "function"):
        # E.g. 'DROP FUNCTION <funcname>', 'ALTER TABLE <tablname>'
        rel_type = token_v
        schema = (identifier and identifier.get_parent_name()) or []
        if schema:
            return [{"type": rel_type, "schema": schema}]
        else:
            return [{"type": "schema"}, {"type": rel_type, "schema": []}]
    elif token_v == "on":
        tables = extract_tables(full_text)  # [(schema, table, alias), ...]
        parent = (identifier and identifier.get_parent_name()) or []
        if parent:
            # "ON parent.<suggestion>"
            # parent can be either a schema name or table alias
            tables = [t for t in tables if identifies(parent, *t)]
            return [
                {"type": "column", "tables": tables},
                {"type": "table", "schema": parent},
                {"type": "view", "schema": parent},
                {"type": "function", "schema": parent},
            ]
        else:
            # ON <suggestion>
            # Use table alias if there is one, otherwise the table name
            aliases = [alias or table for (_, table, alias) in tables]
            suggest = [{"type": "alias", "aliases": aliases}]

            # The lists of 'aliases' could be empty if we're trying to complete
            # a GRANT query. eg: GRANT SELECT, INSERT ON <tab>
            # In that case we just suggest all tables.
            if not aliases:
                suggest.append({"type": "table", "schema": parent})
            return suggest

    elif token_v in ("use", "database", "template", "connect"):
        # "\c <db", "use <db>", "DROP DATABASE <db>",
        # "CREATE DATABASE <newdb> WITH TEMPLATE <db>"
        return [{"type": "database"}]
    elif token_v == "tableformat":
        return [{"type": "table_format"}]
    elif token_v.endswith((",", "+", "-", "*", "/")) or token_v in ("=", "and", "or"):
        # After a comma, arithmetic operator, '=' or AND/OR, the suggestion
        # depends on the keyword that started the current clause.
        prev_keyword, text_before_cursor = find_prev_keyword(text_before_cursor)
        if prev_keyword:
            return suggest_based_on_last_token(
                prev_keyword, text_before_cursor, full_text, identifier
            )
        else:
            return []
    else:
        return [{"type": "keyword"}]
328
329
def identifies(id, schema, table, alias):
    """Tell whether the name ``id`` refers to the given table.

    ``id`` matches when it equals the table's alias, its bare name, or (when
    a schema is given) its ``schema.table`` qualified name. Returns True on an
    alias/name match; otherwise returns the falsy ``schema`` itself when no
    schema is given, or the result of the qualified-name comparison.
    """
    if id == alias or id == table:
        return True
    return schema and id == schema + "." + table
332