1import os
2import sys
3import traceback
4import logging
5import threading
6import re
7import fileinput
8from collections import namedtuple
9try:
10    from pwd import getpwuid
11except ImportError:
12    pass
13from time import time
14from datetime import datetime
15from random import choice
16from io import open
17
18from pymysql import OperationalError
19from cli_helpers.tabular_output import TabularOutputFormatter
20from cli_helpers.tabular_output import preprocessors
21from cli_helpers.utils import strip_ansi
22import click
23import sqlparse
24from mycli.packages.parseutils import is_dropping_database, is_destructive
25from prompt_toolkit.completion import DynamicCompleter
26from prompt_toolkit.enums import DEFAULT_BUFFER, EditingMode
27from prompt_toolkit.key_binding.bindings.named_commands import register as prompt_register
28from prompt_toolkit.shortcuts import PromptSession, CompleteStyle
29from prompt_toolkit.document import Document
30from prompt_toolkit.filters import HasFocus, IsDone
31from prompt_toolkit.formatted_text import ANSI
32from prompt_toolkit.layout.processors import (HighlightMatchingBracketProcessor,
33                                              ConditionalProcessor)
34from prompt_toolkit.lexers import PygmentsLexer
35from prompt_toolkit.history import FileHistory
36from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
37
38from .packages.special.main import NO_QUERY
39from .packages.prompt_utils import confirm, confirm_destructive_query
40from .packages.tabular_output import sql_format
41from .packages import special
42from .packages.special.favoritequeries import FavoriteQueries
43from .sqlcompleter import SQLCompleter
44from .clitoolbar import create_toolbar_tokens_func
45from .clistyle import style_factory, style_factory_output
46from .sqlexecute import FIELD_TYPES, SQLExecute
47from .clibuffer import cli_is_multiline
48from .completion_refresher import CompletionRefresher
49from .config import (write_default_config, get_mylogin_cnf_path,
50                     open_mylogin_cnf, read_config_files, str_to_bool,
51                     strip_matching_quotes)
52from .key_bindings import mycli_bindings
53from .lexer import MyCliLexer
54from .__init__ import __version__
55from .compat import WIN
56from .packages.filepaths import dir_path_exists, guess_socket_location
57
58import itertools
59
# Set click's module-level ``disable_unicode_literals_warning`` flag.
# NOTE(review): this appears to silence a click warning about
# unicode_literals; confirm it is still needed with the click version in use.
click.disable_unicode_literals_warning = True
61
62try:
63    from urlparse import urlparse
64    from urlparse import unquote
65except ImportError:
66    from urllib.parse import urlparse
67    from urllib.parse import unquote
68
69
70try:
71    import paramiko
72except ImportError:
73    from mycli.packages.paramiko_stub import paramiko
74
# One entry of the in-memory query history: the SQL text, whether it ran
# without raising, and whether it was considered mutating.
Query = namedtuple('Query', 'query successful mutating')

# Absolute directory of this module; bundled files such as the default
# myclirc are resolved relative to it.
PACKAGE_ROOT = os.path.dirname(os.path.abspath(__file__))
79
80
81class MyCli(object):
82
83    default_prompt = '\\t \\u@\\h:\\d> '
84    max_len_prompt = 45
85    defaults_suffix = None
86
87    # In order of being loaded. Files lower in list override earlier ones.
88    cnf_files = [
89        '/etc/my.cnf',
90        '/etc/mysql/my.cnf',
91        '/usr/local/etc/my.cnf',
92        '~/.my.cnf'
93    ]
94
95    # check XDG_CONFIG_HOME exists and not an empty string
96    if os.environ.get("XDG_CONFIG_HOME"):
97        xdg_config_home = os.environ.get("XDG_CONFIG_HOME")
98    else:
99        xdg_config_home = "~/.config"
100    system_config_files = [
101        '/etc/myclirc',
102        os.path.join(os.path.expanduser(xdg_config_home), "mycli", "myclirc")
103    ]
104
105    default_config_file = os.path.join(PACKAGE_ROOT, 'myclirc')
106    pwd_config_file = os.path.join(os.getcwd(), ".myclirc")
107
108    def __init__(self, sqlexecute=None, prompt=None,
109            logfile=None, defaults_suffix=None, defaults_file=None,
110            login_path=None, auto_vertical_output=False, warn=None,
111            myclirc="~/.myclirc"):
112        self.sqlexecute = sqlexecute
113        self.logfile = logfile
114        self.defaults_suffix = defaults_suffix
115        self.login_path = login_path
116
117        # self.cnf_files is a class variable that stores the list of mysql
118        # config files to read in at launch.
119        # If defaults_file is specified then override the class variable with
120        # defaults_file.
121        if defaults_file:
122            self.cnf_files = [defaults_file]
123
124        # Load config.
125        config_files = ([self.default_config_file] + self.system_config_files +
126                        [myclirc] + [self.pwd_config_file])
127        c = self.config = read_config_files(config_files)
128        self.multi_line = c['main'].as_bool('multi_line')
129        self.key_bindings = c['main']['key_bindings']
130        special.set_timing_enabled(c['main'].as_bool('timing'))
131
132        FavoriteQueries.instance = FavoriteQueries.from_config(self.config)
133
134        self.dsn_alias = None
135        self.formatter = TabularOutputFormatter(
136            format_name=c['main']['table_format'])
137        sql_format.register_new_formatter(self.formatter)
138        self.formatter.mycli = self
139        self.syntax_style = c['main']['syntax_style']
140        self.less_chatty = c['main'].as_bool('less_chatty')
141        self.cli_style = c['colors']
142        self.output_style = style_factory_output(
143            self.syntax_style,
144            self.cli_style
145        )
146        self.wider_completion_menu = c['main'].as_bool('wider_completion_menu')
147        c_dest_warning = c['main'].as_bool('destructive_warning')
148        self.destructive_warning = c_dest_warning if warn is None else warn
149        self.login_path_as_host = c['main'].as_bool('login_path_as_host')
150
151        # read from cli argument or user config file
152        self.auto_vertical_output = auto_vertical_output or \
153                                c['main'].as_bool('auto_vertical_output')
154
155        # Write user config if system config wasn't the last config loaded.
156        if c.filename not in self.system_config_files and not os.path.exists(myclirc):
157            write_default_config(self.default_config_file, myclirc)
158
159        # audit log
160        if self.logfile is None and 'audit_log' in c['main']:
161            try:
162                self.logfile = open(os.path.expanduser(c['main']['audit_log']), 'a')
163            except (IOError, OSError) as e:
164                self.echo('Error: Unable to open the audit log file. Your queries will not be logged.',
165                          err=True, fg='red')
166                self.logfile = False
167
168        self.completion_refresher = CompletionRefresher()
169
170        self.logger = logging.getLogger(__name__)
171        self.initialize_logging()
172
173        prompt_cnf = self.read_my_cnf_files(self.cnf_files, ['prompt'])['prompt']
174        self.prompt_format = prompt or prompt_cnf or c['main']['prompt'] or \
175                             self.default_prompt
176        self.multiline_continuation_char = c['main']['prompt_continuation']
177        keyword_casing = c['main'].get('keyword_casing', 'auto')
178
179        self.query_history = []
180
181        # Initialize completer.
182        self.smart_completion = c['main'].as_bool('smart_completion')
183        self.completer = SQLCompleter(
184            self.smart_completion,
185            supported_formats=self.formatter.supported_formats,
186            keyword_casing=keyword_casing)
187        self._completer_lock = threading.Lock()
188
189        # Register custom special commands.
190        self.register_special_commands()
191
192        # Load .mylogin.cnf if it exists.
193        mylogin_cnf_path = get_mylogin_cnf_path()
194        if mylogin_cnf_path:
195            mylogin_cnf = open_mylogin_cnf(mylogin_cnf_path)
196            if mylogin_cnf_path and mylogin_cnf:
197                # .mylogin.cnf gets read last, even if defaults_file is specified.
198                self.cnf_files.append(mylogin_cnf)
199            elif mylogin_cnf_path and not mylogin_cnf:
200                # There was an error reading the login path file.
201                print('Error: Unable to read login path file.')
202
203        self.prompt_app = None
204
205    def register_special_commands(self):
206        special.register_special_command(self.change_db, 'use',
207                '\\u', 'Change to a new database.', aliases=('\\u',))
208        special.register_special_command(self.change_db, 'connect',
209                '\\r', 'Reconnect to the database. Optional database argument.',
210                aliases=('\\r', ), case_sensitive=True)
211        special.register_special_command(self.refresh_completions, 'rehash',
212                '\\#', 'Refresh auto-completions.', arg_type=NO_QUERY, aliases=('\\#',))
213        special.register_special_command(
214            self.change_table_format, 'tableformat', '\\T',
215            'Change the table format used to output results.',
216            aliases=('\\T',), case_sensitive=True)
217        special.register_special_command(self.execute_from_file, 'source', '\\. filename',
218                              'Execute commands from file.', aliases=('\\.',))
219        special.register_special_command(self.change_prompt_format, 'prompt',
220                '\\R', 'Change prompt format.', aliases=('\\R',), case_sensitive=True)
221
222    def change_table_format(self, arg, **_):
223        try:
224            self.formatter.format_name = arg
225            yield (None, None, None,
226                   'Changed table format to {}'.format(arg))
227        except ValueError:
228            msg = 'Table format {} not recognized. Allowed formats:'.format(
229                arg)
230            for table_type in self.formatter.supported_formats:
231                msg += "\n\t{}".format(table_type)
232            yield (None, None, None, msg)
233
234    def change_db(self, arg, **_):
235        if not arg:
236            click.secho(
237                "No database selected",
238                err=True, fg="red"
239            )
240            return
241
242        if arg.startswith('`') and arg.endswith('`'):
243            arg = re.sub(r'^`(.*)`$', r'\1', arg)
244            arg = re.sub(r'``', r'`', arg)
245        self.sqlexecute.change_db(arg)
246
247        yield (None, None, None, 'You are now connected to database "%s" as '
248                'user "%s"' % (self.sqlexecute.dbname, self.sqlexecute.user))
249
250    def execute_from_file(self, arg, **_):
251        if not arg:
252            message = 'Missing required argument, filename.'
253            return [(None, None, None, message)]
254        try:
255            with open(os.path.expanduser(arg)) as f:
256                query = f.read()
257        except IOError as e:
258            return [(None, None, None, str(e))]
259
260        if (self.destructive_warning and
261                confirm_destructive_query(query) is False):
262            message = 'Wise choice. Command execution stopped.'
263            return [(None, None, None, message)]
264
265        return self.sqlexecute.run(query)
266
267    def change_prompt_format(self, arg, **_):
268        """
269        Change the prompt format.
270        """
271        if not arg:
272            message = 'Missing required argument, format.'
273            return [(None, None, None, message)]
274
275        self.prompt_format = self.get_prompt(arg)
276        return [(None, None, None, "Changed prompt format to %s" % arg)]
277
278    def initialize_logging(self):
279
280        log_file = os.path.expanduser(self.config['main']['log_file'])
281        log_level = self.config['main']['log_level']
282
283        level_map = {'CRITICAL': logging.CRITICAL,
284                     'ERROR': logging.ERROR,
285                     'WARNING': logging.WARNING,
286                     'INFO': logging.INFO,
287                     'DEBUG': logging.DEBUG
288                     }
289
290        # Disable logging if value is NONE by switching to a no-op handler
291        # Set log level to a high value so it doesn't even waste cycles getting called.
292        if log_level.upper() == "NONE":
293            handler = logging.NullHandler()
294            log_level = "CRITICAL"
295        elif dir_path_exists(log_file):
296            handler = logging.FileHandler(log_file)
297        else:
298            self.echo(
299                'Error: Unable to open the log file "{}".'.format(log_file),
300                err=True, fg='red')
301            return
302
303        formatter = logging.Formatter(
304            '%(asctime)s (%(process)d/%(threadName)s) '
305            '%(name)s %(levelname)s - %(message)s')
306
307        handler.setFormatter(formatter)
308
309        root_logger = logging.getLogger('mycli')
310        root_logger.addHandler(handler)
311        root_logger.setLevel(level_map[log_level.upper()])
312
313        logging.captureWarnings(True)
314
315        root_logger.debug('Initializing mycli logging.')
316        root_logger.debug('Log file %r.', log_file)
317
318
319    def read_my_cnf_files(self, files, keys):
320        """
321        Reads a list of config files and merges them. The last one will win.
322        :param files: list of files to read
323        :param keys: list of keys to retrieve
324        :returns: tuple, with None for missing keys.
325        """
326        cnf = read_config_files(files, list_values=False)
327
328        sections = ['client', 'mysqld']
329        if self.login_path and self.login_path != 'client':
330            sections.append(self.login_path)
331
332        if self.defaults_suffix:
333            sections.extend([sect + self.defaults_suffix for sect in sections])
334
335        def get(key):
336            result = None
337            for sect in cnf:
338                if sect in sections and key in cnf[sect]:
339                    result = strip_matching_quotes(cnf[sect][key])
340            return result
341
342        return {x: get(x) for x in keys}
343
344    def merge_ssl_with_cnf(self, ssl, cnf):
345        """Merge SSL configuration dict with cnf dict"""
346
347        merged = {}
348        merged.update(ssl)
349        prefix = 'ssl-'
350        for k, v in cnf.items():
351            # skip unrelated options
352            if not k.startswith(prefix):
353                continue
354            if v is None:
355                continue
356            # special case because PyMySQL argument is significantly different
357            # from commandline
358            if k == 'ssl-verify-server-cert':
359                merged['check_hostname'] = v
360            else:
361                # use argument name just strip "ssl-" prefix
362                arg = k[len(prefix):]
363                merged[arg] = v
364
365        return merged
366
367    def connect(self, database='', user='', passwd='', host='', port='',
368                socket='', charset='', local_infile='', ssl='',
369                ssh_user='', ssh_host='', ssh_port='',
370                ssh_password='', ssh_key_filename='', init_command=''):
371
372        cnf = {'database': None,
373               'user': None,
374               'password': None,
375               'host': None,
376               'port': None,
377               'socket': None,
378               'default-character-set': None,
379               'local-infile': None,
380               'loose-local-infile': None,
381               'ssl-ca': None,
382               'ssl-cert': None,
383               'ssl-key': None,
384               'ssl-cipher': None,
385               'ssl-verify-serer-cert': None,
386        }
387
388        cnf = self.read_my_cnf_files(self.cnf_files, cnf.keys())
389
390        # Fall back to config values only if user did not specify a value.
391
392        database = database or cnf['database']
393        # Socket interface not supported for SSH connections
394        if port or (host and host != 'localhost') or (ssh_host and ssh_port):
395            socket = ''
396        else:
397            socket = socket or cnf['socket'] or guess_socket_location()
398        user = user or cnf['user'] or os.getenv('USER')
399        host = host or cnf['host']
400        port = int(port or cnf['port'] or 3306)
401        ssl = ssl or {}
402
403        passwd = passwd if isinstance(passwd, str) else cnf['password']
404        charset = charset or cnf['default-character-set'] or 'utf8'
405
406        # Favor whichever local_infile option is set.
407        for local_infile_option in (local_infile, cnf['local-infile'],
408                                    cnf['loose-local-infile'], False):
409            try:
410                local_infile = str_to_bool(local_infile_option)
411                break
412            except (TypeError, ValueError):
413                pass
414
415        ssl = self.merge_ssl_with_cnf(ssl, cnf)
416        # prune lone check_hostname=False
417        if not any(v for v in ssl.values()):
418            ssl = None
419
420        # Connect to the database.
421
422        def _connect():
423            try:
424                self.sqlexecute = SQLExecute(
425                    database, user, passwd, host, port, socket, charset,
426                    local_infile, ssl, ssh_user, ssh_host, ssh_port,
427                    ssh_password, ssh_key_filename, init_command
428                )
429            except OperationalError as e:
430                if ('Access denied for user' in e.args[1]):
431                    new_passwd = click.prompt('Password', hide_input=True,
432                                              show_default=False, type=str, err=True)
433                    self.sqlexecute = SQLExecute(
434                        database, user, new_passwd, host, port, socket,
435                        charset, local_infile, ssl, ssh_user, ssh_host,
436                        ssh_port, ssh_password, ssh_key_filename, init_command
437                    )
438                else:
439                    raise e
440
441        try:
442            if not WIN and socket:
443                socket_owner = getpwuid(os.stat(socket).st_uid).pw_name
444                self.echo(
445                    f"Connecting to socket {socket}, owned by user {socket_owner}", err=True)
446                try:
447                    _connect()
448                except OperationalError as e:
449                    # These are "Can't open socket" and 2x "Can't connect"
450                    if [code for code in (2001, 2002, 2003) if code == e.args[0]]:
451                        self.logger.debug('Database connection failed: %r.', e)
452                        self.logger.error(
453                            "traceback: %r", traceback.format_exc())
454                        self.logger.debug('Retrying over TCP/IP')
455                        self.echo(
456                            "Failed to connect to local MySQL server through socket '{}':".format(socket))
457                        self.echo(str(e), err=True)
458                        self.echo(
459                            'Retrying over TCP/IP', err=True)
460
461                        # Else fall back to TCP/IP localhost
462                        socket = ""
463                        host = 'localhost'
464                        port = 3306
465                        _connect()
466                    else:
467                        raise e
468            else:
469                host = host or 'localhost'
470                port = port or 3306
471
472                # Bad ports give particularly daft error messages
473                try:
474                    port = int(port)
475                except ValueError as e:
476                    self.echo("Error: Invalid port number: '{0}'.".format(port),
477                              err=True, fg='red')
478                    exit(1)
479
480                _connect()
481        except Exception as e:  # Connecting to a database could fail.
482            self.logger.debug('Database connection failed: %r.', e)
483            self.logger.error("traceback: %r", traceback.format_exc())
484            self.echo(str(e), err=True, fg='red')
485            exit(1)
486
487    def handle_editor_command(self, text):
488        r"""Editor command is any query that is prefixed or suffixed by a '\e'.
489        The reason for a while loop is because a user might edit a query
490        multiple times. For eg:
491
492        "select * from \e"<enter> to edit it in vim, then come
493        back to the prompt with the edited query "select * from
494        blah where q = 'abc'\e" to edit it again.
495        :param text: Document
496        :return: Document
497
498        """
499
500        while special.editor_command(text):
501            filename = special.get_filename(text)
502            query = (special.get_editor_query(text) or
503                     self.get_last_query())
504            sql, message = special.open_external_editor(filename, sql=query)
505            if message:
506                # Something went wrong. Raise an exception and bail.
507                raise RuntimeError(message)
508            while True:
509                try:
510                    text = self.prompt_app.prompt(default=sql)
511                    break
512                except KeyboardInterrupt:
513                    sql = ""
514
515            continue
516        return text
517
518    def handle_clip_command(self, text):
519        r"""A clip command is any query that is prefixed or suffixed by a
520        '\clip'.
521
522        :param text: Document
523        :return: Boolean
524
525        """
526
527        if special.clip_command(text):
528            query = (special.get_clip_query(text) or
529                     self.get_last_query())
530            message = special.copy_query_to_clipboard(sql=query)
531            if message:
532                raise RuntimeError(message)
533            return True
534        return False
535
536    def run_cli(self):
537        iterations = 0
538        sqlexecute = self.sqlexecute
539        logger = self.logger
540        self.configure_pager()
541
542        if self.smart_completion:
543            self.refresh_completions()
544
545        author_file = os.path.join(PACKAGE_ROOT, 'AUTHORS')
546        sponsor_file = os.path.join(PACKAGE_ROOT, 'SPONSORS')
547
548        history_file = os.path.expanduser(
549            os.environ.get('MYCLI_HISTFILE', '~/.mycli-history'))
550        if dir_path_exists(history_file):
551            history = FileHistory(history_file)
552        else:
553            history = None
554            self.echo(
555                'Error: Unable to open the history file "{}". '
556                'Your query history will not be saved.'.format(history_file),
557                err=True, fg='red')
558
559        key_bindings = mycli_bindings(self)
560
561        if not self.less_chatty:
562            print(' '.join(sqlexecute.server_type()))
563            print('mycli', __version__)
564            print('Chat: https://gitter.im/dbcli/mycli')
565            print('Mail: https://groups.google.com/forum/#!forum/mycli-users')
566            print('Home: http://mycli.net')
567            print('Thanks to the contributor -', thanks_picker([author_file, sponsor_file]))
568
569        def get_message():
570            prompt = self.get_prompt(self.prompt_format)
571            if self.prompt_format == self.default_prompt and len(prompt) > self.max_len_prompt:
572                prompt = self.get_prompt('\\d> ')
573            prompt = prompt.replace("\\x1b", "\x1b")
574            return ANSI(prompt)
575
576        def get_continuation(width, *_):
577            if self.multiline_continuation_char == '':
578                continuation = ''
579            elif self.multiline_continuation_char:
580                left_padding = width - len(self.multiline_continuation_char)
581                continuation = " " * \
582                    max((left_padding - 1), 0) + \
583                    self.multiline_continuation_char + " "
584            else:
585                continuation = " "
586            return [('class:continuation', continuation)]
587
588        def show_suggestion_tip():
589            return iterations < 2
590
591        def one_iteration(text=None):
592            if text is None:
593                try:
594                    text = self.prompt_app.prompt()
595                except KeyboardInterrupt:
596                    return
597
598                special.set_expanded_output(False)
599
600                try:
601                    text = self.handle_editor_command(text)
602                except RuntimeError as e:
603                    logger.error("sql: %r, error: %r", text, e)
604                    logger.error("traceback: %r", traceback.format_exc())
605                    self.echo(str(e), err=True, fg='red')
606                    return
607
608                try:
609                    if self.handle_clip_command(text):
610                        return
611                except RuntimeError as e:
612                    logger.error("sql: %r, error: %r", text, e)
613                    logger.error("traceback: %r", traceback.format_exc())
614                    self.echo(str(e), err=True, fg='red')
615                    return
616
617            if not text.strip():
618                return
619
620            if self.destructive_warning:
621                destroy = confirm_destructive_query(text)
622                if destroy is None:
623                    pass  # Query was not destructive. Nothing to do here.
624                elif destroy is True:
625                    self.echo('Your call!')
626                else:
627                    self.echo('Wise choice!')
628                    return
629            else:
630                destroy = True
631
632            # Keep track of whether or not the query is mutating. In case
633            # of a multi-statement query, the overall query is considered
634            # mutating if any one of the component statements is mutating
635            mutating = False
636
637            try:
638                logger.debug('sql: %r', text)
639
640                special.write_tee(self.get_prompt(self.prompt_format) + text)
641                if self.logfile:
642                    self.logfile.write('\n# %s\n' % datetime.now())
643                    self.logfile.write(text)
644                    self.logfile.write('\n')
645
646                successful = False
647                start = time()
648                res = sqlexecute.run(text)
649                self.formatter.query = text
650                successful = True
651                result_count = 0
652                for title, cur, headers, status in res:
653                    logger.debug("headers: %r", headers)
654                    logger.debug("rows: %r", cur)
655                    logger.debug("status: %r", status)
656                    threshold = 1000
657                    if (is_select(status) and
658                            cur and cur.rowcount > threshold):
659                        self.echo('The result set has more than {} rows.'.format(
660                            threshold), fg='red')
661                        if not confirm('Do you want to continue?'):
662                            self.echo("Aborted!", err=True, fg='red')
663                            break
664
665                    if self.auto_vertical_output:
666                        max_width = self.prompt_app.output.get_size().columns
667                    else:
668                        max_width = None
669
670                    formatted = self.format_output(
671                        title, cur, headers, special.is_expanded_output(),
672                        max_width)
673
674                    t = time() - start
675                    try:
676                        if result_count > 0:
677                            self.echo('')
678                        try:
679                            self.output(formatted, status)
680                        except KeyboardInterrupt:
681                            pass
682                        if special.is_timing_enabled():
683                            self.echo('Time: %0.03fs' % t)
684                    except KeyboardInterrupt:
685                        pass
686
687                    start = time()
688                    result_count += 1
689                    mutating = mutating or destroy or is_mutating(status)
690                special.unset_once_if_written()
691                special.unset_pipe_once_if_written()
692            except EOFError as e:
693                raise e
694            except KeyboardInterrupt:
695                # get last connection id
696                connection_id_to_kill = sqlexecute.connection_id
697                logger.debug("connection id to kill: %r", connection_id_to_kill)
698                # Restart connection to the database
699                sqlexecute.connect()
700                try:
701                    for title, cur, headers, status in sqlexecute.run('kill %s' % connection_id_to_kill):
702                        status_str = str(status).lower()
703                        if status_str.find('ok') > -1:
704                            logger.debug("cancelled query, connection id: %r, sql: %r",
705                                         connection_id_to_kill, text)
706                            self.echo("cancelled query", err=True, fg='red')
707                except Exception as e:
708                    self.echo('Encountered error while cancelling query: {}'.format(e),
709                              err=True, fg='red')
710            except NotImplementedError:
711                self.echo('Not Yet Implemented.', fg="yellow")
712            except OperationalError as e:
713                logger.debug("Exception: %r", e)
714                if (e.args[0] in (2003, 2006, 2013)):
715                    logger.debug('Attempting to reconnect.')
716                    self.echo('Reconnecting...', fg='yellow')
717                    try:
718                        sqlexecute.connect()
719                        logger.debug('Reconnected successfully.')
720                        one_iteration(text)
721                        return  # OK to just return, cuz the recursion call runs to the end.
722                    except OperationalError as e:
723                        logger.debug('Reconnect failed. e: %r', e)
724                        self.echo(str(e), err=True, fg='red')
725                        # If reconnection failed, don't proceed further.
726                        return
727                else:
728                    logger.error("sql: %r, error: %r", text, e)
729                    logger.error("traceback: %r", traceback.format_exc())
730                    self.echo(str(e), err=True, fg='red')
731            except Exception as e:
732                logger.error("sql: %r, error: %r", text, e)
733                logger.error("traceback: %r", traceback.format_exc())
734                self.echo(str(e), err=True, fg='red')
735            else:
736                if is_dropping_database(text, self.sqlexecute.dbname):
737                    self.sqlexecute.dbname = None
738                    self.sqlexecute.connect()
739
740                # Refresh the table names and column names if necessary.
741                if need_completion_refresh(text):
742                    self.refresh_completions(
743                        reset=need_completion_reset(text))
744            finally:
745                if self.logfile is False:
746                    self.echo("Warning: This query was not logged.",
747                              err=True, fg='red')
748            query = Query(text, successful, mutating)
749            self.query_history.append(query)
750
751        get_toolbar_tokens = create_toolbar_tokens_func(
752            self, show_suggestion_tip)
753        if self.wider_completion_menu:
754            complete_style = CompleteStyle.MULTI_COLUMN
755        else:
756            complete_style = CompleteStyle.COLUMN
757
758        with self._completer_lock:
759
760            if self.key_bindings == 'vi':
761                editing_mode = EditingMode.VI
762            else:
763                editing_mode = EditingMode.EMACS
764
765            self.prompt_app = PromptSession(
766                lexer=PygmentsLexer(MyCliLexer),
767                reserve_space_for_menu=self.get_reserved_space(),
768                message=get_message,
769                prompt_continuation=get_continuation,
770                bottom_toolbar=get_toolbar_tokens,
771                complete_style=complete_style,
772                input_processors=[ConditionalProcessor(
773                    processor=HighlightMatchingBracketProcessor(
774                        chars='[](){}'),
775                    filter=HasFocus(DEFAULT_BUFFER) & ~IsDone()
776                )],
777                tempfile_suffix='.sql',
778                completer=DynamicCompleter(lambda: self.completer),
779                history=history,
780                auto_suggest=AutoSuggestFromHistory(),
781                complete_while_typing=True,
782                multiline=cli_is_multiline(self),
783                style=style_factory(self.syntax_style, self.cli_style),
784                include_default_pygments_style=False,
785                key_bindings=key_bindings,
786                enable_open_in_editor=True,
787                enable_system_prompt=True,
788                enable_suspend=True,
789                editing_mode=editing_mode,
790                search_ignore_case=True
791            )
792
793        try:
794            while True:
795                one_iteration()
796                iterations += 1
797        except EOFError:
798            special.close_tee()
799            if not self.less_chatty:
800                self.echo('Goodbye!')
801
802    def log_output(self, output):
803        """Log the output in the audit log, if it's enabled."""
804        if self.logfile:
805            click.echo(output, file=self.logfile)
806
807    def echo(self, s, **kwargs):
808        """Print a message to stdout.
809
810        The message will be logged in the audit log, if enabled.
811
812        All keyword arguments are passed to click.echo().
813
814        """
815        self.log_output(s)
816        click.secho(s, **kwargs)
817
818    def get_output_margin(self, status=None):
819        """Get the output margin (number of rows for the prompt, footer and
820        timing message."""
821        margin = self.get_reserved_space() + self.get_prompt(self.prompt_format).count('\n') + 1
822        if special.is_timing_enabled():
823            margin += 1
824        if status:
825            margin += 1 + status.count('\n')
826
827        return margin
828
829
830    def output(self, output, status=None):
831        """Output text to stdout or a pager command.
832
833        The status text is not outputted to pager or files.
834
835        The message will be logged in the audit log, if enabled. The
836        message will be written to the tee file, if enabled. The
837        message will be written to the output file, if enabled.
838
839        """
840        if output:
841            size = self.prompt_app.output.get_size()
842
843            margin = self.get_output_margin(status)
844
845            fits = True
846            buf = []
847            output_via_pager = self.explicit_pager and special.is_pager_enabled()
848            for i, line in enumerate(output, 1):
849                self.log_output(line)
850                special.write_tee(line)
851                special.write_once(line)
852                special.write_pipe_once(line)
853
854                if fits or output_via_pager:
855                    # buffering
856                    buf.append(line)
857                    if len(line) > size.columns or i > (size.rows - margin):
858                        fits = False
859                        if not self.explicit_pager and special.is_pager_enabled():
860                            # doesn't fit, use pager
861                            output_via_pager = True
862
863                        if not output_via_pager:
864                            # doesn't fit, flush buffer
865                            for line in buf:
866                                click.secho(line)
867                            buf = []
868                else:
869                    click.secho(line)
870
871            if buf:
872                if output_via_pager:
873                    def newlinewrapper(text):
874                        for line in text:
875                            yield line + "\n"
876                    click.echo_via_pager(newlinewrapper(buf))
877                else:
878                    for line in buf:
879                        click.secho(line)
880
881        if status:
882            self.log_output(status)
883            click.secho(status)
884
885    def configure_pager(self):
886        # Provide sane defaults for less if they are empty.
887        if not os.environ.get('LESS'):
888            os.environ['LESS'] = '-RXF'
889
890        cnf = self.read_my_cnf_files(self.cnf_files, ['pager', 'skip-pager'])
891        if cnf['pager']:
892            special.set_pager(cnf['pager'])
893            self.explicit_pager = True
894        else:
895            self.explicit_pager = False
896
897        if cnf['skip-pager'] or not self.config['main'].as_bool('enable_pager'):
898            special.disable_pager()
899
900    def refresh_completions(self, reset=False):
901        if reset:
902            with self._completer_lock:
903                self.completer.reset_completions()
904        self.completion_refresher.refresh(
905            self.sqlexecute, self._on_completions_refreshed,
906            {'smart_completion': self.smart_completion,
907             'supported_formats': self.formatter.supported_formats,
908             'keyword_casing': self.completer.keyword_casing})
909
910        return [(None, None, None,
911                'Auto-completion refresh started in the background.')]
912
913    def _on_completions_refreshed(self, new_completer):
914        """Swap the completer object in cli with the newly created completer.
915        """
916        with self._completer_lock:
917            self.completer = new_completer
918
919        if self.prompt_app:
920            # After refreshing, redraw the CLI to clear the statusbar
921            # "Refreshing completions..." indicator
922            self.prompt_app.app.invalidate()
923
924    def get_completions(self, text, cursor_positition):
925        with self._completer_lock:
926            return self.completer.get_completions(
927                Document(text=text, cursor_position=cursor_positition), None)
928
929    def get_prompt(self, string):
930        sqlexecute = self.sqlexecute
931        host = self.login_path if self.login_path and self.login_path_as_host else sqlexecute.host
932        now = datetime.now()
933        string = string.replace('\\u', sqlexecute.user or '(none)')
934        string = string.replace('\\h', host or '(none)')
935        string = string.replace('\\d', sqlexecute.dbname or '(none)')
936        string = string.replace('\\t', sqlexecute.server_type()[0] or 'mycli')
937        string = string.replace('\\n', "\n")
938        string = string.replace('\\D', now.strftime('%a %b %d %H:%M:%S %Y'))
939        string = string.replace('\\m', now.strftime('%M'))
940        string = string.replace('\\P', now.strftime('%p'))
941        string = string.replace('\\R', now.strftime('%H'))
942        string = string.replace('\\r', now.strftime('%I'))
943        string = string.replace('\\s', now.strftime('%S'))
944        string = string.replace('\\p', str(sqlexecute.port))
945        string = string.replace('\\A', self.dsn_alias or '(none)')
946        string = string.replace('\\_', ' ')
947        return string
948
949    def run_query(self, query, new_line=True):
950        """Runs *query*."""
951        results = self.sqlexecute.run(query)
952        for result in results:
953            title, cur, headers, status = result
954            self.formatter.query = query
955            output = self.format_output(title, cur, headers)
956            for line in output:
957                click.echo(line, nl=new_line)
958
959    def format_output(self, title, cur, headers, expanded=False,
960                      max_width=None):
961        expanded = expanded or self.formatter.format_name == 'vertical'
962        output = []
963
964        output_kwargs = {
965            'dialect': 'unix',
966            'disable_numparse': True,
967            'preserve_whitespace': True,
968            'style': self.output_style
969        }
970
971        if not self.formatter.format_name in sql_format.supported_formats:
972            output_kwargs["preprocessors"] = (preprocessors.align_decimals, )
973
974        if title:  # Only print the title if it's not None.
975            output = itertools.chain(output, [title])
976
977        if cur:
978            column_types = None
979            if hasattr(cur, 'description'):
980                def get_col_type(col):
981                    col_type = FIELD_TYPES.get(col[1], str)
982                    return col_type if type(col_type) is type else str
983                column_types = [get_col_type(col) for col in cur.description]
984
985            if max_width is not None:
986                cur = list(cur)
987
988            formatted = self.formatter.format_output(
989                cur, headers, format_name='vertical' if expanded else None,
990                column_types=column_types,
991                **output_kwargs)
992
993            if isinstance(formatted, str):
994                formatted = formatted.splitlines()
995            formatted = iter(formatted)
996
997            if (not expanded and max_width and headers and cur):
998                first_line = next(formatted)
999                if len(strip_ansi(first_line)) > max_width:
1000                    formatted = self.formatter.format_output(
1001                        cur, headers, format_name='vertical', column_types=column_types, **output_kwargs)
1002                    if isinstance(formatted, str):
1003                        formatted = iter(formatted.splitlines())
1004                else:
1005                    formatted = itertools.chain([first_line], formatted)
1006
1007            output = itertools.chain(output, formatted)
1008
1009
1010        return output
1011
1012    def get_reserved_space(self):
1013        """Get the number of lines to reserve for the completion menu."""
1014        reserved_space_ratio = .45
1015        max_reserved_space = 8
1016        _, height = click.get_terminal_size()
1017        return min(int(round(height * reserved_space_ratio)), max_reserved_space)
1018
1019    def get_last_query(self):
1020        """Get the last query executed or None."""
1021        return self.query_history[-1][0] if self.query_history else None
1022
1023
@click.command()
@click.option('-h', '--host', envvar='MYSQL_HOST', help='Host address of the database.')
@click.option('-P', '--port', envvar='MYSQL_TCP_PORT', type=int, help='Port number to use for connection. Honors '
              '$MYSQL_TCP_PORT.')
@click.option('-u', '--user', help='User name to connect to the database.')
@click.option('-S', '--socket', envvar='MYSQL_UNIX_PORT', help='The socket file to use for connection.')
@click.option('-p', '--password', 'password', envvar='MYSQL_PWD', type=str,
              help='Password to connect to the database.')
@click.option('--pass', 'password', envvar='MYSQL_PWD', type=str,
              help='Password to connect to the database.')
@click.option('--ssh-user', help='User name to connect to ssh server.')
@click.option('--ssh-host', help='Host name to connect to ssh server.')
@click.option('--ssh-port', default=22, help='Port to connect to ssh server.')
@click.option('--ssh-password', help='Password to connect to ssh server.')
@click.option('--ssh-key-filename', help='Private key filename (identify file) for the ssh connection.')
@click.option('--ssh-config-path', help='Path to ssh configuration.',
              default=os.path.expanduser('~') + '/.ssh/config')
@click.option('--ssh-config-host', help='Host to connect to ssh server reading from ssh configuration.')
@click.option('--ssl-ca', help='CA file in PEM format.',
              type=click.Path(exists=True))
@click.option('--ssl-capath', help='CA directory.')
@click.option('--ssl-cert', help='X509 cert in PEM format.',
              type=click.Path(exists=True))
@click.option('--ssl-key', help='X509 key in PEM format.',
              type=click.Path(exists=True))
@click.option('--ssl-cipher', help='SSL cipher to use.')
@click.option('--ssl-verify-server-cert', is_flag=True,
              help=('Verify server\'s "Common Name" in its cert against '
                    'hostname used when connecting. This option is disabled '
                    'by default.'))
# As of 2016-02-15 the revocation list is not supported by the underlying
# PyMySQL library (--ssl-crl and --ssl-crlpath options in vanilla mysql client).
@click.option('-V', '--version', is_flag=True, help='Output mycli\'s version.')
@click.option('-v', '--verbose', is_flag=True, help='Verbose output.')
@click.option('-D', '--database', 'dbname', help='Database to use.')
@click.option('-d', '--dsn', default='', envvar='DSN',
              help='Use DSN configured into the [alias_dsn] section of myclirc file.')
@click.option('--list-dsn', 'list_dsn', is_flag=True,
              help='list of DSN configured into the [alias_dsn] section of myclirc file.')
@click.option('--list-ssh-config', 'list_ssh_config', is_flag=True,
              help='list ssh configurations in the ssh config (requires paramiko).')
@click.option('-R', '--prompt', 'prompt',
              help='Prompt format (Default: "{0}").'.format(
                  MyCli.default_prompt))
@click.option('-l', '--logfile', type=click.File(mode='a', encoding='utf-8'),
              help='Log every query and its results to a file.')
@click.option('--defaults-group-suffix', type=str,
              help='Read MySQL config groups with the specified suffix.')
@click.option('--defaults-file', type=click.Path(),
              help='Only read MySQL options from the given file.')
@click.option('--myclirc', type=click.Path(), default="~/.myclirc",
              help='Location of myclirc file.')
@click.option('--auto-vertical-output', is_flag=True,
              help='Automatically switch to vertical output mode if the result is wider than the terminal width.')
@click.option('-t', '--table', is_flag=True,
              help='Display batch output in table format.')
@click.option('--csv', is_flag=True,
              help='Display batch output in CSV format.')
@click.option('--warn/--no-warn', default=None,
              help='Warn before running a destructive query.')
@click.option('--local-infile', type=bool,
              help='Enable/disable LOAD DATA LOCAL INFILE.')
@click.option('--login-path', type=str,
              help='Read this path from the login file.')
@click.option('-e', '--execute', type=str,
              help='Execute command and quit.')
@click.option('--init-command', type=str,
              help='SQL statement to execute after connecting.')
@click.option('--charset', type=str,
              help='Character set for MySQL session.')
@click.argument('database', default='', nargs=1)
def cli(database, user, host, port, socket, password, dbname,
        version, verbose, prompt, logfile, defaults_group_suffix,
        defaults_file, login_path, auto_vertical_output, local_infile,
        ssl_ca, ssl_capath, ssl_cert, ssl_key, ssl_cipher,
        ssl_verify_server_cert, table, csv, warn, execute, myclirc, dsn,
        list_dsn, ssh_user, ssh_host, ssh_port, ssh_password,
        ssh_key_filename, list_ssh_config, ssh_config_path, ssh_config_host,
        init_command, charset):
    """A MySQL terminal client with auto-completion and syntax highlighting.

    \b
    Examples:
      - mycli my_database
      - mycli -u my_user -h my_host.com my_database
      - mycli mysql://my_user@my_host.com:3306/my_database

    """
    # NOTE: sys.exit() is used throughout. The bare exit() builtin is
    # injected by the `site` module and is absent under `python -S` or in
    # frozen/embedded interpreters.

    if version:
        print('Version:', __version__)
        sys.exit(0)

    mycli = MyCli(prompt=prompt, logfile=logfile,
                  defaults_suffix=defaults_group_suffix,
                  defaults_file=defaults_file, login_path=login_path,
                  auto_vertical_output=auto_vertical_output, warn=warn,
                  myclirc=myclirc)
    if list_dsn:
        try:
            alias_dsn = mycli.config['alias_dsn']
        except KeyError:
            click.secho('Invalid DSNs found in the config file. '
                        'Please check the "[alias_dsn]" section in myclirc.',
                        err=True, fg='red')
            sys.exit(1)
        except Exception as e:
            click.secho(str(e), err=True, fg='red')
            sys.exit(1)
        for alias, value in alias_dsn.items():
            if verbose:
                click.secho("{} : {}".format(alias, value))
            else:
                click.secho(alias)
        sys.exit(0)
    if list_ssh_config:
        ssh_config = read_ssh_config(ssh_config_path)
        for host in ssh_config.get_hostnames():
            if verbose:
                host_config = ssh_config.lookup(host)
                click.secho("{} : {}".format(
                    host, host_config.get('hostname')))
            else:
                click.secho(host)
        sys.exit(0)
    # Choose which ever one has a valid value.
    database = dbname or database

    ssl = {
            'ca': ssl_ca and os.path.expanduser(ssl_ca),
            'cert': ssl_cert and os.path.expanduser(ssl_cert),
            'key': ssl_key and os.path.expanduser(ssl_key),
            'capath': ssl_capath,
            'cipher': ssl_cipher,
            'check_hostname': ssl_verify_server_cert,
            }

    # remove empty ssl options
    ssl = {k: v for k, v in ssl.items() if v is not None}

    dsn_uri = None

    # Treat the database argument as a DSN alias if we're missing
    # other connection information. .get() is used so that a config with
    # no [alias_dsn] section disables this path instead of raising KeyError.
    if (mycli.config.get('alias_dsn') and database and '://' not in database
            and not any([user, password, host, port, login_path])):
        dsn, database = database, ''

    if database and '://' in database:
        dsn_uri, database = database, ''

    if dsn:
        try:
            dsn_uri = mycli.config['alias_dsn'][dsn]
        except KeyError:
            click.secho('Could not find the specified DSN in the config file. '
                        'Please check the "[alias_dsn]" section in your '
                        'myclirc.', err=True, fg='red')
            sys.exit(1)
        else:
            mycli.dsn_alias = dsn

    if dsn_uri:
        uri = urlparse(dsn_uri)
        if not database:
            database = uri.path[1:]  # ignore the leading fwd slash
        # A URI such as mysql://host/db has no user part: urlparse reports
        # username=None, and unquote(None) would raise TypeError. Only
        # decode the components that are actually present.
        if not user and uri.username is not None:
            user = unquote(uri.username)
        if not password and uri.password is not None:
            password = unquote(uri.password)
        if not host:
            host = uri.hostname
        if not port:
            port = uri.port

    if ssh_config_host:
        ssh_config = read_ssh_config(
            ssh_config_path
        ).lookup(ssh_config_host)
        ssh_host = ssh_host if ssh_host else ssh_config.get('hostname')
        ssh_user = ssh_user if ssh_user else ssh_config.get('user')
        if ssh_config.get('port') and ssh_port == 22:
            # port has a default value, overwrite it if it's in the config
            ssh_port = int(ssh_config.get('port'))
        ssh_key_filename = ssh_key_filename if ssh_key_filename else ssh_config.get(
            'identityfile', [None])[0]

    ssh_key_filename = ssh_key_filename and os.path.expanduser(ssh_key_filename)

    mycli.connect(
        database=database,
        user=user,
        passwd=password,
        host=host,
        port=port,
        socket=socket,
        local_infile=local_infile,
        ssl=ssl,
        ssh_user=ssh_user,
        ssh_host=ssh_host,
        ssh_port=ssh_port,
        ssh_password=ssh_password,
        ssh_key_filename=ssh_key_filename,
        init_command=init_command,
        charset=charset
    )

    mycli.logger.debug('Launch Params: \n'
                       '\tdatabase: %r'
                       '\tuser: %r'
                       '\thost: %r'
                       '\tport: %r', database, user, host, port)

    #  --execute argument
    if execute:
        try:
            if csv:
                mycli.formatter.format_name = 'csv'
            elif not table:
                mycli.formatter.format_name = 'tsv'

            mycli.run_query(execute)
            sys.exit(0)
        except Exception as e:
            click.secho(str(e), err=True, fg='red')
            sys.exit(1)

    if sys.stdin.isatty():
        mycli.run_cli()
    else:
        stdin = click.get_text_stream('stdin')
        try:
            stdin_text = stdin.read()
        except MemoryError:
            click.secho('Failed! Ran out of memory.', err=True, fg='red')
            click.secho('You might want to try the official mysql client.', err=True, fg='red')
            click.secho('Sorry... :(', err=True, fg='red')
            sys.exit(1)

        if mycli.destructive_warning and is_destructive(stdin_text):
            # Default to "not confirmed". If the TTY cannot be opened, the
            # user cannot be asked, so the destructive query must not run.
            # Without this default the name would be unbound on that path.
            warn_confirmed = False
            try:
                sys.stdin = open('/dev/tty')
                warn_confirmed = confirm_destructive_query(stdin_text)
            except (IOError, OSError):
                mycli.logger.warning('Unable to open TTY as stdin.')
            if not warn_confirmed:
                sys.exit(0)

        try:
            if csv:
                mycli.formatter.format_name = 'csv'
            elif not table:
                mycli.formatter.format_name = 'tsv'

            mycli.run_query(stdin_text, new_line=True)
            sys.exit(0)
        except Exception as e:
            click.secho(str(e), err=True, fg='red')
            sys.exit(1)
1285
1286
def need_completion_refresh(queries):
    """Determines if the completion needs a refresh by checking if the sql
    statement is an alter, create, drop or change db.

    Every statement produced by ``sqlparse.split(queries)`` is inspected.
    Returns True as soon as one starts (case-insensitively) with ``alter``,
    ``create``, ``use``, ``connect``, ``drop``, ``rename`` or the
    ``\\r`` / ``\\u`` special commands; returns False otherwise.
    """
    for query in sqlparse.split(queries):
        tokens = query.split()
        if not tokens:
            # Skip an empty fragment instead of bailing out. Returning False
            # here would hide a schema change in a later statement.
            continue
        if tokens[0].lower() in ('alter', 'create', 'use', '\\r',
                                 '\\u', 'connect', 'drop', 'rename'):
            return True
    return False
1298
1299
def need_completion_reset(queries):
    """Determines if the statement is a database switch such as 'use' or '\\u'.
    When a database is changed the existing completions must be reset before we
    start the completion refresh for the new database.

    Every statement produced by ``sqlparse.split(queries)`` is inspected;
    returns True if any starts (case-insensitively) with ``use`` or ``\\u``,
    and False otherwise.
    """
    for query in sqlparse.split(queries):
        tokens = query.split()
        if not tokens:
            # Skip an empty fragment instead of bailing out. Returning False
            # here would hide a database switch in a later statement.
            continue
        if tokens[0].lower() in ('use', '\\u'):
            return True
    return False
1312
1313
def is_mutating(status):
    """Return True if *status* begins with a data- or schema-changing verb.

    The first whitespace-separated word of the status message is compared,
    case-insensitively, with a fixed set of mutating statement keywords.
    An empty or None status is never mutating.
    """
    if not status:
        return False

    leading_word = status.split(None, 1)[0].lower()
    return leading_word in {'insert', 'update', 'delete', 'alter', 'create',
                            'drop', 'replace', 'truncate', 'load', 'rename'}
1322
1323
def is_select(status):
    """Return True when the first word of *status* is ``select`` (any case).

    An empty or None status yields False.
    """
    return bool(status) and status.split(None, 1)[0].lower() == 'select'
1329
1330
def thanks_picker(files=()):
    """Return a random name taken from bulleted lines (``* Name``) in *files*.

    A line matches when it has optional leading spaces, then ``*``, then a
    space; the rest of the line (without the newline) is a candidate.

    :param files: paths to read. Following the fileinput convention, an
        empty sequence reads from standard input.
    :raises IndexError: if no bulleted line is found.
    """
    bullet = re.compile(r'^ *\* (.*)')
    contents = []
    # A private FileInput used as a context manager closes the open file even
    # if iteration fails. It also avoids fileinput's module-global state,
    # which makes a later fileinput.input() raise while a previous one is
    # still active.
    with fileinput.FileInput(files=files) as lines:
        for line in lines:
            match = bullet.match(line)
            if match:
                contents.append(match.group(1))
    return choice(contents)
1338
1339
@prompt_register('edit-and-execute-command')
def edit_and_execute(event):
    """Open the current buffer in an external editor without running it.

    This is registered under the same command name as prompt-toolkit's
    default. Passing ``validate_and_handle=False`` keeps the edited text in
    the buffer, so the user can still decide whether to execute the query.
    """
    event.current_buffer.open_in_editor(validate_and_handle=False)
1346
1347
def read_ssh_config(ssh_config_path):
    """Parse the OpenSSH client configuration file at *ssh_config_path*.

    :returns: the populated ``paramiko.config.SSHConfig``.

    On a missing file or a parse error, an error is printed to stderr and
    the process exits with status 1.
    """
    ssh_config = paramiko.config.SSHConfig()
    try:
        with open(ssh_config_path) as f:
            ssh_config.parse(f)
    # FileNotFoundError must be handled before the generic handler below.
    # It is an Exception subclass, so listing it second left it unreachable.
    except FileNotFoundError as e:
        click.secho(str(e), err=True, fg='red')
        sys.exit(1)
    # Paramiko prior to version 2.7 raises Exception on parse errors.
    # In 2.7 it has become paramiko.ssh_exception.SSHException,
    # but let's catch everything for compatibility
    except Exception as err:
        click.secho(
            f'Could not parse SSH configuration file {ssh_config_path}:\n{err} ',
            err=True, fg='red'
        )
        sys.exit(1)
    else:
        return ssh_config
1367
1368
if __name__ == "__main__":
    # Entry point when this module is executed directly as a script.
    cli()
1371