1import platform
2import warnings
3from os.path import expanduser
4
5from configobj import ConfigObj, ParseError
6from pgspecial.namedqueries import NamedQueries
7from .config import skip_initial_comment
8
9warnings.filterwarnings("ignore", category=UserWarning, module="psycopg2")
10
11import os
12import re
13import sys
14import traceback
15import logging
16import threading
17import shutil
18import functools
19import pendulum
20import datetime as dt
21import itertools
22import platform
23from time import time, sleep
24
25keyring = None  # keyring will be loaded later
26
27from cli_helpers.tabular_output import TabularOutputFormatter
28from cli_helpers.tabular_output.preprocessors import align_decimals, format_numbers
29from cli_helpers.utils import strip_ansi
30import click
31
32try:
33    import setproctitle
34except ImportError:
35    setproctitle = None
36from prompt_toolkit.completion import DynamicCompleter, ThreadedCompleter
37from prompt_toolkit.enums import DEFAULT_BUFFER, EditingMode
38from prompt_toolkit.shortcuts import PromptSession, CompleteStyle
39from prompt_toolkit.document import Document
40from prompt_toolkit.filters import HasFocus, IsDone
41from prompt_toolkit.formatted_text import ANSI
42from prompt_toolkit.lexers import PygmentsLexer
43from prompt_toolkit.layout.processors import (
44    ConditionalProcessor,
45    HighlightMatchingBracketProcessor,
46    TabsProcessor,
47)
48from prompt_toolkit.history import FileHistory
49from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
50from pygments.lexers.sql import PostgresLexer
51
52from pgspecial.main import PGSpecial, NO_QUERY, PAGER_OFF, PAGER_LONG_OUTPUT
53import pgspecial as special
54
55from .pgcompleter import PGCompleter
56from .pgtoolbar import create_toolbar_tokens_func
57from .pgstyle import style_factory, style_factory_output
58from .pgexecute import PGExecute
59from .completion_refresher import CompletionRefresher
60from .config import (
61    get_casing_file,
62    load_config,
63    config_location,
64    ensure_dir_exists,
65    get_config,
66    get_config_filename,
67)
68from .key_bindings import pgcli_bindings
69from .packages.prompt_utils import confirm_destructive_query
70from .__init__ import __version__
71
72click.disable_unicode_literals_warning = True
73
74try:
75    from urlparse import urlparse, unquote, parse_qs
76except ImportError:
77    from urllib.parse import urlparse, unquote, parse_qs
78
79from getpass import getuser
80from psycopg2 import OperationalError, InterfaceError
81import psycopg2
82
83from collections import namedtuple
84
85from textwrap import dedent
86
87# Ref: https://stackoverflow.com/questions/30425105/filter-special-chars-such-as-color-codes-from-shell-output
88COLOR_CODE_REGEX = re.compile(r"\x1b(\[.*?[@-~]|\].*?(\x07|\x1b\\))")
89
90# Query tuples are used for maintaining history
91MetaQuery = namedtuple(
92    "Query",
93    [
94        "query",  # The entire text of the command
95        "successful",  # True If all subqueries were successful
96        "total_time",  # Time elapsed executing the query and formatting results
97        "execution_time",  # Time elapsed executing the query
98        "meta_changed",  # True if any subquery executed create/alter/drop
99        "db_changed",  # True if any subquery changed the database
100        "path_changed",  # True if any subquery changed the search path
101        "mutated",  # True if any subquery executed insert/update/delete
102        "is_special",  # True if the query is a special command
103    ],
104)
105MetaQuery.__new__.__defaults__ = ("", False, 0, 0, False, False, False, False)
106
107OutputSettings = namedtuple(
108    "OutputSettings",
109    "table_format dcmlfmt floatfmt missingval expanded max_width case_function style_output",
110)
111OutputSettings.__new__.__defaults__ = (
112    None,
113    None,
114    None,
115    "<null>",
116    False,
117    None,
118    lambda x: x,
119    None,
120)
121
122
class PgCliQuitError(Exception):
    r"""Raised by the \q / :q / quit / exit handlers to leave the main loop."""
125
126
127class PGCli:
    # Prompt template used when neither the `prompt` argument nor the config
    # file's `prompt` option supplies one (see __init__).
    default_prompt = "\\u@\\h:\\d> "
    # If the prompt rendered from default_prompt is longer than this many
    # characters, _build_cli falls back to the shorter "\d> " template.
    max_len_prompt = 30
130
131    def set_default_pager(self, config):
132        configured_pager = config["main"].get("pager")
133        os_environ_pager = os.environ.get("PAGER")
134
135        if configured_pager:
136            self.logger.info(
137                'Default pager found in config file: "%s"', configured_pager
138            )
139            os.environ["PAGER"] = configured_pager
140        elif os_environ_pager:
141            self.logger.info(
142                'Default pager found in PAGER environment variable: "%s"',
143                os_environ_pager,
144            )
145            os.environ["PAGER"] = os_environ_pager
146        else:
147            self.logger.info(
148                "No default pager found in environment. Using os default pager"
149            )
150
151        # Set default set of less recommended options, if they are not already set.
152        # They are ignored if pager is different than less.
153        if not os.environ.get("LESS"):
154            os.environ["LESS"] = "-SRXF"
155
    def __init__(
        self,
        force_passwd_prompt=False,
        never_passwd_prompt=False,
        pgexecute=None,
        pgclirc_file=None,
        row_limit=None,
        single_connection=False,
        less_chatty=None,
        prompt=None,
        prompt_dsn=None,
        auto_vertical_output=False,
        warn=None,
    ):
        """Load configuration and set up pgcli's state, logging and completer.

        Explicit arguments take precedence over the config file where both
        exist.

        :param force_passwd_prompt: in ``connect``, prompt for a password up
            front when none was supplied (the -W behaviour).
        :param never_passwd_prompt: in ``connect``, never prompt for a password
            after a failed attempt (the -w behaviour).
        :param pgexecute: an existing executor; ``connect`` replaces it.
        :param pgclirc_file: config file path given to ``get_config`` and
            ``get_config_filename``.
        :param row_limit: overrides the config ``row_limit`` when not None.
        :param single_connection: forwarded into the completer settings.
        :param less_chatty: if truthy, enables less_chatty regardless of config.
        :param prompt: prompt template; overrides the config ``prompt`` when
            not None.
        :param prompt_dsn: prompt template used instead of *prompt* while
            ``self.dsn_alias`` is set.
        :param auto_vertical_output: if truthy, enables ``auto_expand``
            regardless of config.
        :param warn: destructive-warning mode; overrides the config value when
            truthy.
        """

        self.force_passwd_prompt = force_passwd_prompt
        self.never_passwd_prompt = never_passwd_prompt
        self.pgexecute = pgexecute
        # When set (and prompt_dsn_format is not None) the prompt builder uses
        # prompt_dsn_format instead of prompt_format.
        self.dsn_alias = None
        self.watch_command = None

        # Load config.
        c = self.config = get_config(pgclirc_file)

        # at this point, config should be written to pgclirc_file if it did not exist. Read it.
        self.config_writer = load_config(get_config_filename(pgclirc_file))

        # make sure to use self.config_writer, not self.config
        NamedQueries.instance = NamedQueries.from_config(self.config_writer)

        # The logger must exist before set_default_pager() and
        # initialize_keyring(), both of which log through it.
        self.logger = logging.getLogger(__name__)
        self.initialize_logging()

        self.set_default_pager(c)
        # None means results go to the pager/terminal instead of a file (\o).
        self.output_file = None
        self.pgspecial = PGSpecial()

        self.multi_line = c["main"].as_bool("multi_line")
        self.multiline_mode = c["main"].get("multi_line_mode", "psql")
        self.vi_mode = c["main"].as_bool("vi")
        self.auto_expand = auto_vertical_output or c["main"].as_bool("auto_expand")
        self.expanded_output = c["main"].as_bool("expand")
        self.pgspecial.timing_enabled = c["main"].as_bool("timing")
        if row_limit is not None:
            self.row_limit = row_limit
        else:
            self.row_limit = c["main"].as_int("row_limit")

        self.min_num_menu_lines = c["main"].as_int("min_num_menu_lines")
        self.multiline_continuation_char = c["main"]["multiline_continuation_char"]
        self.table_format = c["main"]["table_format"]
        self.syntax_style = c["main"]["syntax_style"]
        self.cli_style = c["colors"]
        self.wider_completion_menu = c["main"].as_bool("wider_completion_menu")
        # A truthy `warn` argument wins over the config file.
        self.destructive_warning = warn or c["main"]["destructive_warning"]
        # Also accept boolean spellings: "true" -> "all", "false" -> "off";
        # any other value is kept as given.
        self.destructive_warning = {"true": "all", "false": "off"}.get(
            self.destructive_warning.lower(), self.destructive_warning
        )
        self.less_chatty = bool(less_chatty) or c["main"].as_bool("less_chatty")
        self.null_string = c["main"].get("null_string", "<null>")
        self.prompt_format = (
            prompt
            if prompt is not None
            else c["main"].get("prompt", self.default_prompt)
        )
        self.prompt_dsn_format = prompt_dsn
        self.on_error = c["main"]["on_error"].upper()
        self.decimal_format = c["data_formats"]["decimal"]
        self.float_format = c["data_formats"]["float"]
        self.initialize_keyring()
        self.show_bottom_toolbar = c["main"].as_bool("show_bottom_toolbar")

        # Equivalent to: "on" if enable_pager else "off".
        self.pgspecial.pset_pager(
            self.config["main"].as_bool("enable_pager") and "on" or "off"
        )

        self.style_output = style_factory_output(self.syntax_style, c["colors"])

        self.now = dt.datetime.today()

        self.completion_refresher = CompletionRefresher()

        # MetaQuery entries appended by run_cli() after each command.
        self.query_history = []

        # Initialize completer
        smart_completion = c["main"].as_bool("smart_completion")
        keyword_casing = c["main"]["keyword_casing"]
        self.settings = {
            "casing_file": get_casing_file(c),
            "generate_casing_file": c["main"].as_bool("generate_casing_file"),
            "generate_aliases": c["main"].as_bool("generate_aliases"),
            "asterisk_column_order": c["main"]["asterisk_column_order"],
            "qualify_columns": c["main"]["qualify_columns"],
            "case_column_headers": c["main"].as_bool("case_column_headers"),
            "search_path_filter": c["main"].as_bool("search_path_filter"),
            "single_connection": single_connection,
            # NOTE(review): this is the raw `less_chatty` argument, not the
            # merged self.less_chatty computed above — confirm that is intended.
            "less_chatty": less_chatty,
            "keyword_casing": keyword_casing,
        }

        completer = PGCompleter(
            smart_completion, pgspecial=self.pgspecial, settings=self.settings
        )
        self.completer = completer
        # Held by execute_command(), run_cli() and _build_cli() around
        # completer access.
        self._completer_lock = threading.Lock()
        self.register_special_commands()

        # Built later by run_cli() via _build_cli().
        self.prompt_app = None
265
266    def quit(self):
267        raise PgCliQuitError
268
269    def register_special_commands(self):
270
271        self.pgspecial.register(
272            self.change_db,
273            "\\c",
274            "\\c[onnect] database_name",
275            "Change to a new database.",
276            aliases=("use", "\\connect", "USE"),
277        )
278
279        refresh_callback = lambda: self.refresh_completions(persist_priorities="all")
280
281        self.pgspecial.register(
282            self.quit,
283            "\\q",
284            "\\q",
285            "Quit pgcli.",
286            arg_type=NO_QUERY,
287            case_sensitive=True,
288            aliases=(":q",),
289        )
290        self.pgspecial.register(
291            self.quit,
292            "quit",
293            "quit",
294            "Quit pgcli.",
295            arg_type=NO_QUERY,
296            case_sensitive=False,
297            aliases=("exit",),
298        )
299        self.pgspecial.register(
300            refresh_callback,
301            "\\#",
302            "\\#",
303            "Refresh auto-completions.",
304            arg_type=NO_QUERY,
305        )
306        self.pgspecial.register(
307            refresh_callback,
308            "\\refresh",
309            "\\refresh",
310            "Refresh auto-completions.",
311            arg_type=NO_QUERY,
312        )
313        self.pgspecial.register(
314            self.execute_from_file, "\\i", "\\i filename", "Execute commands from file."
315        )
316        self.pgspecial.register(
317            self.write_to_file,
318            "\\o",
319            "\\o [filename]",
320            "Send all query results to file.",
321        )
322        self.pgspecial.register(
323            self.info_connection, "\\conninfo", "\\conninfo", "Get connection details"
324        )
325        self.pgspecial.register(
326            self.change_table_format,
327            "\\T",
328            "\\T [format]",
329            "Change the table format used to output results",
330        )
331
332    def change_table_format(self, pattern, **_):
333        try:
334            if pattern not in TabularOutputFormatter().supported_formats:
335                raise ValueError()
336            self.table_format = pattern
337            yield (None, None, None, f"Changed table format to {pattern}")
338        except ValueError:
339            msg = f"Table format {pattern} not recognized. Allowed formats:"
340            for table_type in TabularOutputFormatter().supported_formats:
341                msg += f"\n\t{table_type}"
342            msg += "\nCurrently set to: %s" % self.table_format
343            yield (None, None, None, msg)
344
345    def info_connection(self, **_):
346        if self.pgexecute.host.startswith("/"):
347            host = 'socket "%s"' % self.pgexecute.host
348        else:
349            host = 'host "%s"' % self.pgexecute.host
350
351        yield (
352            None,
353            None,
354            None,
355            'You are connected to database "%s" as user '
356            '"%s" on %s at port "%s".'
357            % (self.pgexecute.dbname, self.pgexecute.user, host, self.pgexecute.port),
358        )
359
360    def change_db(self, pattern, **_):
361        if pattern:
362            # Get all the parameters in pattern, handling double quotes if any.
363            infos = re.findall(r'"[^"]*"|[^"\'\s]+', pattern)
364            # Now removing quotes.
365            list(map(lambda s: s.strip('"'), infos))
366
367            infos.extend([None] * (4 - len(infos)))
368            db, user, host, port = infos
369            try:
370                self.pgexecute.connect(
371                    database=db,
372                    user=user,
373                    host=host,
374                    port=port,
375                    **self.pgexecute.extra_args,
376                )
377            except OperationalError as e:
378                click.secho(str(e), err=True, fg="red")
379                click.echo("Previous connection kept")
380        else:
381            self.pgexecute.connect()
382
383        yield (
384            None,
385            None,
386            None,
387            'You are now connected to database "%s" as '
388            'user "%s"' % (self.pgexecute.dbname, self.pgexecute.user),
389        )
390
391    def execute_from_file(self, pattern, **_):
392        if not pattern:
393            message = "\\i: missing required argument"
394            return [(None, None, None, message, "", False, True)]
395        try:
396            with open(os.path.expanduser(pattern), encoding="utf-8") as f:
397                query = f.read()
398        except OSError as e:
399            return [(None, None, None, str(e), "", False, True)]
400
401        if (
402            self.destructive_warning != "off"
403            and confirm_destructive_query(query, self.destructive_warning) is False
404        ):
405            message = "Wise choice. Command execution stopped."
406            return [(None, None, None, message)]
407
408        on_error_resume = self.on_error == "RESUME"
409        return self.pgexecute.run(
410            query, self.pgspecial, on_error_resume=on_error_resume
411        )
412
413    def write_to_file(self, pattern, **_):
414        if not pattern:
415            self.output_file = None
416            message = "File output disabled"
417            return [(None, None, None, message, "", True, True)]
418        filename = os.path.abspath(os.path.expanduser(pattern))
419        if not os.path.isfile(filename):
420            try:
421                open(filename, "w").close()
422            except OSError as e:
423                self.output_file = None
424                message = str(e) + "\nFile output disabled"
425                return [(None, None, None, message, "", False, True)]
426        self.output_file = filename
427        message = 'Writing to file "%s"' % self.output_file
428        return [(None, None, None, message, "", True, True)]
429
430    def initialize_logging(self):
431
432        log_file = self.config["main"]["log_file"]
433        if log_file == "default":
434            log_file = config_location() + "log"
435        ensure_dir_exists(log_file)
436        log_level = self.config["main"]["log_level"]
437
438        # Disable logging if value is NONE by switching to a no-op handler.
439        # Set log level to a high value so it doesn't even waste cycles getting called.
440        if log_level.upper() == "NONE":
441            handler = logging.NullHandler()
442        else:
443            handler = logging.FileHandler(os.path.expanduser(log_file))
444
445        level_map = {
446            "CRITICAL": logging.CRITICAL,
447            "ERROR": logging.ERROR,
448            "WARNING": logging.WARNING,
449            "INFO": logging.INFO,
450            "DEBUG": logging.DEBUG,
451            "NONE": logging.CRITICAL,
452        }
453
454        log_level = level_map[log_level.upper()]
455
456        formatter = logging.Formatter(
457            "%(asctime)s (%(process)d/%(threadName)s) "
458            "%(name)s %(levelname)s - %(message)s"
459        )
460
461        handler.setFormatter(formatter)
462
463        root_logger = logging.getLogger("pgcli")
464        root_logger.addHandler(handler)
465        root_logger.setLevel(log_level)
466
467        root_logger.debug("Initializing pgcli logging.")
468        root_logger.debug("Log file %r.", log_file)
469
470        pgspecial_logger = logging.getLogger("pgspecial")
471        pgspecial_logger.addHandler(handler)
472        pgspecial_logger.setLevel(log_level)
473
474    def initialize_keyring(self):
475        global keyring
476
477        keyring_enabled = self.config["main"].as_bool("keyring")
478        if keyring_enabled:
479            # Try best to load keyring (issue #1041).
480            import importlib
481
482            try:
483                keyring = importlib.import_module("keyring")
484            except Exception as e:  # ImportError for Python 2, ModuleNotFoundError for Python 3
485                self.logger.warning("import keyring failed: %r.", e)
486
487    def connect_dsn(self, dsn, **kwargs):
488        self.connect(dsn=dsn, **kwargs)
489
490    def connect_service(self, service, user):
491        service_config, file = parse_service_info(service)
492        if service_config is None:
493            click.secho(
494                f"service '{service}' was not found in {file}", err=True, fg="red"
495            )
496            exit(1)
497        self.connect(
498            database=service_config.get("dbname"),
499            host=service_config.get("host"),
500            user=user or service_config.get("user"),
501            port=service_config.get("port"),
502            passwd=service_config.get("password"),
503        )
504
505    def connect_uri(self, uri):
506        kwargs = psycopg2.extensions.parse_dsn(uri)
507        remap = {"dbname": "database", "password": "passwd"}
508        kwargs = {remap.get(k, k): v for k, v in kwargs.items()}
509        self.connect(**kwargs)
510
511    def connect(
512        self, database="", host="", user="", port="", passwd="", dsn="", **kwargs
513    ):
514        # Connect to the database.
515
516        if not user:
517            user = getuser()
518
519        if not database:
520            database = user
521
522        kwargs.setdefault("application_name", "pgcli")
523
524        # If password prompt is not forced but no password is provided, try
525        # getting it from environment variable.
526        if not self.force_passwd_prompt and not passwd:
527            passwd = os.environ.get("PGPASSWORD", "")
528
529        # Find password from store
530        key = f"{user}@{host}"
531        keyring_error_message = dedent(
532            """\
533            {}
534            {}
535            To remove this message do one of the following:
536            - prepare keyring as described at: https://keyring.readthedocs.io/en/stable/
537            - uninstall keyring: pip uninstall keyring
538            - disable keyring in our configuration: add keyring = False to [main]"""
539        )
540        if not passwd and keyring:
541
542            try:
543                passwd = keyring.get_password("pgcli", key)
544            except (RuntimeError, keyring.errors.InitError) as e:
545                click.secho(
546                    keyring_error_message.format(
547                        "Load your password from keyring returned:", str(e)
548                    ),
549                    err=True,
550                    fg="red",
551                )
552
553        # Prompt for a password immediately if requested via the -W flag. This
554        # avoids wasting time trying to connect to the database and catching a
555        # no-password exception.
556        # If we successfully parsed a password from a URI, there's no need to
557        # prompt for it, even with the -W flag
558        if self.force_passwd_prompt and not passwd:
559            passwd = click.prompt(
560                "Password for %s" % user, hide_input=True, show_default=False, type=str
561            )
562
563        def should_ask_for_password(exc):
564            # Prompt for a password after 1st attempt to connect
565            # fails. Don't prompt if the -w flag is supplied
566            if self.never_passwd_prompt:
567                return False
568            error_msg = exc.args[0]
569            if "no password supplied" in error_msg:
570                return True
571            if "password authentication failed" in error_msg:
572                return True
573            return False
574
575        # Attempt to connect to the database.
576        # Note that passwd may be empty on the first attempt. If connection
577        # fails because of a missing or incorrect password, but we're allowed to
578        # prompt for a password (no -w flag), prompt for a passwd and try again.
579        try:
580            try:
581                pgexecute = PGExecute(database, user, passwd, host, port, dsn, **kwargs)
582            except (OperationalError, InterfaceError) as e:
583                if should_ask_for_password(e):
584                    passwd = click.prompt(
585                        "Password for %s" % user,
586                        hide_input=True,
587                        show_default=False,
588                        type=str,
589                    )
590                    pgexecute = PGExecute(
591                        database, user, passwd, host, port, dsn, **kwargs
592                    )
593                else:
594                    raise e
595            if passwd and keyring:
596                try:
597                    keyring.set_password("pgcli", key, passwd)
598                except (RuntimeError, keyring.errors.KeyringError) as e:
599                    click.secho(
600                        keyring_error_message.format(
601                            "Set password in keyring returned:", str(e)
602                        ),
603                        err=True,
604                        fg="red",
605                    )
606
607        except Exception as e:  # Connecting to a database could fail.
608            self.logger.debug("Database connection failed: %r.", e)
609            self.logger.error("traceback: %r", traceback.format_exc())
610            click.secho(str(e), err=True, fg="red")
611            exit(1)
612
613        self.pgexecute = pgexecute
614
615    def handle_editor_command(self, text):
616        r"""
617        Editor command is any query that is prefixed or suffixed
618        by a '\e'. The reason for a while loop is because a user
619        might edit a query multiple times.
620        For eg:
621        "select * from \e"<enter> to edit it in vim, then come
622        back to the prompt with the edited query "select * from
623        blah where q = 'abc'\e" to edit it again.
624        :param text: Document
625        :return: Document
626        """
627        editor_command = special.editor_command(text)
628        while editor_command:
629            if editor_command == "\\e":
630                filename = special.get_filename(text)
631                query = special.get_editor_query(text) or self.get_last_query()
632            else:  # \ev or \ef
633                filename = None
634                spec = text.split()[1]
635                if editor_command == "\\ev":
636                    query = self.pgexecute.view_definition(spec)
637                elif editor_command == "\\ef":
638                    query = self.pgexecute.function_definition(spec)
639            sql, message = special.open_external_editor(filename, sql=query)
640            if message:
641                # Something went wrong. Raise an exception and bail.
642                raise RuntimeError(message)
643            while True:
644                try:
645                    text = self.prompt_app.prompt(default=sql)
646                    break
647                except KeyboardInterrupt:
648                    sql = ""
649
650            editor_command = special.editor_command(text)
651        return text
652
653    def execute_command(self, text):
654        logger = self.logger
655
656        query = MetaQuery(query=text, successful=False)
657
658        try:
659            if self.destructive_warning != "off":
660                destroy = confirm = confirm_destructive_query(
661                    text, self.destructive_warning
662                )
663                if destroy is False:
664                    click.secho("Wise choice!")
665                    raise KeyboardInterrupt
666                elif destroy:
667                    click.secho("Your call!")
668            output, query = self._evaluate_command(text)
669        except KeyboardInterrupt:
670            # Restart connection to the database
671            self.pgexecute.connect()
672            logger.debug("cancelled query, sql: %r", text)
673            click.secho("cancelled query", err=True, fg="red")
674        except NotImplementedError:
675            click.secho("Not Yet Implemented.", fg="yellow")
676        except OperationalError as e:
677            logger.error("sql: %r, error: %r", text, e)
678            logger.error("traceback: %r", traceback.format_exc())
679            self._handle_server_closed_connection(text)
680        except (PgCliQuitError, EOFError) as e:
681            raise
682        except Exception as e:
683            logger.error("sql: %r, error: %r", text, e)
684            logger.error("traceback: %r", traceback.format_exc())
685            click.secho(str(e), err=True, fg="red")
686        else:
687            try:
688                if self.output_file and not text.startswith(("\\o ", "\\? ")):
689                    try:
690                        with open(self.output_file, "a", encoding="utf-8") as f:
691                            click.echo(text, file=f)
692                            click.echo("\n".join(output), file=f)
693                            click.echo("", file=f)  # extra newline
694                    except OSError as e:
695                        click.secho(str(e), err=True, fg="red")
696                else:
697                    if output:
698                        self.echo_via_pager("\n".join(output))
699            except KeyboardInterrupt:
700                pass
701
702            if self.pgspecial.timing_enabled:
703                # Only add humanized time display if > 1 second
704                if query.total_time > 1:
705                    print(
706                        "Time: %0.03fs (%s), executed in: %0.03fs (%s)"
707                        % (
708                            query.total_time,
709                            pendulum.Duration(seconds=query.total_time).in_words(),
710                            query.execution_time,
711                            pendulum.Duration(seconds=query.execution_time).in_words(),
712                        )
713                    )
714                else:
715                    print("Time: %0.03fs" % query.total_time)
716
717            # Check if we need to update completions, in order of most
718            # to least drastic changes
719            if query.db_changed:
720                with self._completer_lock:
721                    self.completer.reset_completions()
722                self.refresh_completions(persist_priorities="keywords")
723            elif query.meta_changed:
724                self.refresh_completions(persist_priorities="all")
725            elif query.path_changed:
726                logger.debug("Refreshing search path")
727                with self._completer_lock:
728                    self.completer.set_search_path(self.pgexecute.search_path())
729                logger.debug("Search path: %r", self.completer.search_path)
730        return query
731
732    def run_cli(self):
733        logger = self.logger
734
735        history_file = self.config["main"]["history_file"]
736        if history_file == "default":
737            history_file = config_location() + "history"
738        history = FileHistory(os.path.expanduser(history_file))
739        self.refresh_completions(history=history, persist_priorities="none")
740
741        self.prompt_app = self._build_cli(history)
742
743        if not self.less_chatty:
744            print("Server: PostgreSQL", self.pgexecute.server_version)
745            print("Version:", __version__)
746            print("Home: http://pgcli.com")
747
748        try:
749            while True:
750                try:
751                    text = self.prompt_app.prompt()
752                except KeyboardInterrupt:
753                    continue
754
755                try:
756                    text = self.handle_editor_command(text)
757                except RuntimeError as e:
758                    logger.error("sql: %r, error: %r", text, e)
759                    logger.error("traceback: %r", traceback.format_exc())
760                    click.secho(str(e), err=True, fg="red")
761                    continue
762
763                # Initialize default metaquery in case execution fails
764                self.watch_command, timing = special.get_watch_command(text)
765                if self.watch_command:
766                    while self.watch_command:
767                        try:
768                            query = self.execute_command(self.watch_command)
769                            click.echo(f"Waiting for {timing} seconds before repeating")
770                            sleep(timing)
771                        except KeyboardInterrupt:
772                            self.watch_command = None
773                else:
774                    query = self.execute_command(text)
775
776                self.now = dt.datetime.today()
777
778                # Allow PGCompleter to learn user's preferred keywords, etc.
779                with self._completer_lock:
780                    self.completer.extend_query_history(text)
781
782                self.query_history.append(query)
783
784        except (PgCliQuitError, EOFError):
785            if not self.less_chatty:
786                print("Goodbye!")
787
788    def _build_cli(self, history):
789        key_bindings = pgcli_bindings(self)
790
791        def get_message():
792            if self.dsn_alias and self.prompt_dsn_format is not None:
793                prompt_format = self.prompt_dsn_format
794            else:
795                prompt_format = self.prompt_format
796
797            prompt = self.get_prompt(prompt_format)
798
799            if (
800                prompt_format == self.default_prompt
801                and len(prompt) > self.max_len_prompt
802            ):
803                prompt = self.get_prompt("\\d> ")
804
805            prompt = prompt.replace("\\x1b", "\x1b")
806            return ANSI(prompt)
807
808        def get_continuation(width, line_number, is_soft_wrap):
809            continuation = self.multiline_continuation_char * (width - 1) + " "
810            return [("class:continuation", continuation)]
811
812        get_toolbar_tokens = create_toolbar_tokens_func(self)
813
814        if self.wider_completion_menu:
815            complete_style = CompleteStyle.MULTI_COLUMN
816        else:
817            complete_style = CompleteStyle.COLUMN
818
819        with self._completer_lock:
820            prompt_app = PromptSession(
821                lexer=PygmentsLexer(PostgresLexer),
822                reserve_space_for_menu=self.min_num_menu_lines,
823                message=get_message,
824                prompt_continuation=get_continuation,
825                bottom_toolbar=get_toolbar_tokens if self.show_bottom_toolbar else None,
826                complete_style=complete_style,
827                input_processors=[
828                    # Highlight matching brackets while editing.
829                    ConditionalProcessor(
830                        processor=HighlightMatchingBracketProcessor(chars="[](){}"),
831                        filter=HasFocus(DEFAULT_BUFFER) & ~IsDone(),
832                    ),
833                    # Render \t as 4 spaces instead of "^I"
834                    TabsProcessor(char1=" ", char2=" "),
835                ],
836                auto_suggest=AutoSuggestFromHistory(),
837                tempfile_suffix=".sql",
838                # N.b. pgcli's multi-line mode controls submit-on-Enter (which
839                # overrides the default behaviour of prompt_toolkit) and is
840                # distinct from prompt_toolkit's multiline mode here, which
841                # controls layout/display of the prompt/buffer
842                multiline=True,
843                history=history,
844                completer=ThreadedCompleter(DynamicCompleter(lambda: self.completer)),
845                complete_while_typing=True,
846                style=style_factory(self.syntax_style, self.cli_style),
847                include_default_pygments_style=False,
848                key_bindings=key_bindings,
849                enable_open_in_editor=True,
850                enable_system_prompt=True,
851                enable_suspend=True,
852                editing_mode=EditingMode.VI if self.vi_mode else EditingMode.EMACS,
853                search_ignore_case=True,
854            )
855
856            return prompt_app
857
858    def _should_limit_output(self, sql, cur):
859        """returns True if the output should be truncated, False otherwise."""
860        if not is_select(sql):
861            return False
862
863        return (
864            not self._has_limit(sql)
865            and self.row_limit != 0
866            and cur
867            and cur.rowcount > self.row_limit
868        )
869
870    def _has_limit(self, sql):
871        if not sql:
872            return False
873        return "limit " in sql.lower()
874
875    def _limit_output(self, cur):
876        limit = min(self.row_limit, cur.rowcount)
877        new_cur = itertools.islice(cur, limit)
878        new_status = "SELECT " + str(limit)
879        click.secho("The result was limited to %s rows" % limit, fg="red")
880
881        return new_cur, new_status
882
    def _evaluate_command(self, text):
        """Used to run a command entered by the user during CLI operation
        (Puts the E in REPL)

        Executes *text* via ``self.pgexecute.run`` and formats every result
        it yields with ``format_output``, collecting all formatted lines.

        :param text: the raw input entered by the user. ``run`` may yield
                     more than one result for it (presumably one per
                     statement).

        returns (results, MetaQuery)
            results   -- list of formatted output lines for all results.
            MetaQuery -- built from *text*, whether every result succeeded,
                         the timings, the meta/db/path/mutation change
                         flags, and the last result's ``is_special`` flag.
        """
        logger = self.logger
        logger.debug("sql: %r", text)

        all_success = True
        meta_changed = False  # CREATE, ALTER, DROP, etc
        mutated = False  # INSERT, DELETE, etc
        db_changed = False
        path_changed = False
        output = []
        total = 0
        execution = 0

        # Run the query.
        start = time()
        # Passed through to run(); NOTE(review): presumably lets execution
        # continue past a failing statement when on_error is "RESUME" -- confirm
        # in PGExecute.run.
        on_error_resume = self.on_error == "RESUME"
        res = self.pgexecute.run(
            text, self.pgspecial, exception_formatter, on_error_resume
        )

        # Keeps is_special defined when run() yields nothing; otherwise the
        # loop below leaves it holding the last result's value.
        is_special = None

        for title, cur, headers, status, sql, success, is_special in res:
            logger.debug("headers: %r", headers)
            logger.debug("rows: %r", cur)
            logger.debug("status: %r", status)

            # Truncate SELECTs without LIMIT that exceed the row limit.
            if self._should_limit_output(sql, cur):
                cur, status = self._limit_output(cur)

            # With auto-expand on, format_output switches to vertical layout
            # when the first formatted line is wider than the terminal.
            if self.pgspecial.auto_expand or self.auto_expand:
                max_width = self.prompt_app.output.get_size().columns
            else:
                max_width = None

            # Settings are rebuilt for every result. NOTE(review): presumably so
            # that toggles made by an earlier special command in the same input
            # take effect -- confirm.
            expanded = self.pgspecial.expanded_output or self.expanded_output
            settings = OutputSettings(
                table_format=self.table_format,
                dcmlfmt=self.decimal_format,
                floatfmt=self.float_format,
                missingval=self.null_string,
                expanded=expanded,
                max_width=max_width,
                case_function=(
                    self.completer.case
                    if self.settings["case_column_headers"]
                    else lambda x: x
                ),
                style_output=self.style_output,
            )
            # Both timings are measured from `start` and overwritten per
            # result, so they end up describing the last result:
            # `execution` is taken just before this result is formatted,
            # `total` after its formatted lines have been consumed.
            execution = time() - start
            formatted = format_output(title, cur, headers, status, settings)

            output.extend(formatted)
            total = time() - start

            # Keep track of whether any of the queries are mutating or changing
            # the database
            if success:
                mutated = mutated or is_mutating(status)
                db_changed = db_changed or has_change_db_cmd(sql)
                meta_changed = meta_changed or has_meta_cmd(sql)
                path_changed = path_changed or has_change_path_cmd(sql)
            else:
                all_success = False

        meta_query = MetaQuery(
            text,
            all_success,
            total,
            execution,
            meta_changed,
            db_changed,
            path_changed,
            mutated,
            is_special,
        )

        return output, meta_query
967
968    def _handle_server_closed_connection(self, text):
969        """Used during CLI execution."""
970        try:
971            click.secho("Reconnecting...", fg="green")
972            self.pgexecute.connect()
973            click.secho("Reconnected!", fg="green")
974            self.execute_command(text)
975        except OperationalError as e:
976            click.secho("Reconnect Failed", fg="red")
977            click.secho(str(e), err=True, fg="red")
978
979    def refresh_completions(self, history=None, persist_priorities="all"):
980        """Refresh outdated completions
981
982        :param history: A prompt_toolkit.history.FileHistory object. Used to
983                        load keyword and identifier preferences
984
985        :param persist_priorities: 'all' or 'keywords'
986        """
987
988        callback = functools.partial(
989            self._on_completions_refreshed, persist_priorities=persist_priorities
990        )
991        return self.completion_refresher.refresh(
992            self.pgexecute,
993            self.pgspecial,
994            callback,
995            history=history,
996            settings=self.settings,
997        )
998
999    def _on_completions_refreshed(self, new_completer, persist_priorities):
1000        self._swap_completer_objects(new_completer, persist_priorities)
1001
1002        if self.prompt_app:
1003            # After refreshing, redraw the CLI to clear the statusbar
1004            # "Refreshing completions..." indicator
1005            self.prompt_app.app.invalidate()
1006
1007    def _swap_completer_objects(self, new_completer, persist_priorities):
1008        """Swap the completer object with the newly created completer.
1009
1010        persist_priorities is a string specifying how the old completer's
1011        learned prioritizer should be transferred to the new completer.
1012
1013          'none'     - The new prioritizer is left in a new/clean state
1014
1015          'all'      - The new prioritizer is updated to exactly reflect
1016                       the old one
1017
1018          'keywords' - The new prioritizer is updated with old keyword
1019                       priorities, but not any other.
1020
1021        """
1022        with self._completer_lock:
1023            old_completer = self.completer
1024            self.completer = new_completer
1025
1026            if persist_priorities == "all":
1027                # Just swap over the entire prioritizer
1028                new_completer.prioritizer = old_completer.prioritizer
1029            elif persist_priorities == "keywords":
1030                # Swap over the entire prioritizer, but clear name priorities,
1031                # leaving learned keyword priorities alone
1032                new_completer.prioritizer = old_completer.prioritizer
1033                new_completer.prioritizer.clear_names()
1034            elif persist_priorities == "none":
1035                # Leave the new prioritizer as is
1036                pass
1037            self.completer = new_completer
1038
1039    def get_completions(self, text, cursor_positition):
1040        with self._completer_lock:
1041            return self.completer.get_completions(
1042                Document(text=text, cursor_position=cursor_positition), None
1043            )
1044
1045    def get_prompt(self, string):
1046        # should be before replacing \\d
1047        string = string.replace("\\dsn_alias", self.dsn_alias or "")
1048        string = string.replace("\\t", self.now.strftime("%x %X"))
1049        string = string.replace("\\u", self.pgexecute.user or "(none)")
1050        string = string.replace("\\H", self.pgexecute.host or "(none)")
1051        string = string.replace("\\h", self.pgexecute.short_host or "(none)")
1052        string = string.replace("\\d", self.pgexecute.dbname or "(none)")
1053        string = string.replace(
1054            "\\p",
1055            str(self.pgexecute.port) if self.pgexecute.port is not None else "5432",
1056        )
1057        string = string.replace("\\i", str(self.pgexecute.pid) or "(none)")
1058        string = string.replace("\\#", "#" if self.pgexecute.superuser else ">")
1059        string = string.replace("\\n", "\n")
1060        return string
1061
1062    def get_last_query(self):
1063        """Get the last query executed or None."""
1064        return self.query_history[-1][0] if self.query_history else None
1065
1066    def is_too_wide(self, line):
1067        """Will this line be too wide to fit into terminal?"""
1068        if not self.prompt_app:
1069            return False
1070        return (
1071            len(COLOR_CODE_REGEX.sub("", line))
1072            > self.prompt_app.output.get_size().columns
1073        )
1074
1075    def is_too_tall(self, lines):
1076        """Are there too many lines to fit into terminal?"""
1077        if not self.prompt_app:
1078            return False
1079        return len(lines) >= (self.prompt_app.output.get_size().rows - 4)
1080
1081    def echo_via_pager(self, text, color=None):
1082        if self.pgspecial.pager_config == PAGER_OFF or self.watch_command:
1083            click.echo(text, color=color)
1084        elif (
1085            self.pgspecial.pager_config == PAGER_LONG_OUTPUT
1086            and self.table_format != "csv"
1087        ):
1088            lines = text.split("\n")
1089
1090            # The last 4 lines are reserved for the pgcli menu and padding
1091            if self.is_too_tall(lines) or any(self.is_too_wide(l) for l in lines):
1092                click.echo_via_pager(text, color=color)
1093            else:
1094                click.echo(text, color=color)
1095        else:
1096            click.echo_via_pager(text, color)
1097
1098
@click.command()
# Default host is '' so psycopg2 can default to either localhost or unix socket
@click.option(
    "-h",
    "--host",
    default="",
    envvar="PGHOST",
    help="Host address of the postgres database.",
)
@click.option(
    "-p",
    "--port",
    default=5432,
    help="Port number at which the " "postgres instance is listening.",
    envvar="PGPORT",
    type=click.INT,
)
@click.option(
    "-U",
    "--username",
    "username_opt",
    help="Username to connect to the postgres database.",
)
@click.option(
    "-u", "--user", "username_opt", help="Username to connect to the postgres database."
)
@click.option(
    "-W",
    "--password",
    "prompt_passwd",
    is_flag=True,
    default=False,
    help="Force password prompt.",
)
@click.option(
    "-w",
    "--no-password",
    "never_prompt",
    is_flag=True,
    default=False,
    help="Never prompt for password.",
)
@click.option(
    "--single-connection",
    "single_connection",
    is_flag=True,
    default=False,
    help="Do not use a separate connection for completions.",
)
@click.option("-v", "--version", is_flag=True, help="Version of pgcli.")
@click.option("-d", "--dbname", "dbname_opt", help="database name to connect to.")
@click.option(
    "--pgclirc",
    default=config_location() + "config",
    envvar="PGCLIRC",
    help="Location of pgclirc file.",
    type=click.Path(dir_okay=False),
)
@click.option(
    "-D",
    "--dsn",
    default="",
    envvar="DSN",
    help="Use DSN configured into the [alias_dsn] section of pgclirc file.",
)
@click.option(
    "--list-dsn",
    "list_dsn",
    is_flag=True,
    help="list of DSN configured into the [alias_dsn] section of pgclirc file.",
)
@click.option(
    "--row-limit",
    default=None,
    envvar="PGROWLIMIT",
    type=click.INT,
    help="Set threshold for row limit prompt. Use 0 to disable prompt.",
)
@click.option(
    "--less-chatty",
    "less_chatty",
    is_flag=True,
    default=False,
    help="Skip intro on startup and goodbye on exit.",
)
@click.option("--prompt", help='Prompt format (Default: "\\u@\\h:\\d> ").')
@click.option(
    "--prompt-dsn",
    help='Prompt format for connections using DSN aliases (Default: "\\u@\\h:\\d> ").',
)
@click.option(
    "-l",
    "--list",
    "list_databases",
    is_flag=True,
    help="list " "available databases, then exit.",
)
@click.option(
    "--auto-vertical-output",
    is_flag=True,
    help="Automatically switch to vertical output mode if the result is wider than the terminal width.",
)
@click.option(
    "--warn",
    default=None,
    type=click.Choice(["all", "moderate", "off"]),
    help="Warn before running a destructive query.",
)
@click.argument("dbname", default=lambda: None, envvar="PGDATABASE", nargs=1)
@click.argument("username", default=lambda: None, envvar="PGUSER", nargs=1)
def cli(
    dbname,
    username_opt,
    host,
    port,
    prompt_passwd,
    never_prompt,
    single_connection,
    dbname_opt,
    username,
    version,
    pgclirc,
    dsn,
    row_limit,
    less_chatty,
    prompt,
    prompt_dsn,
    list_databases,
    auto_vertical_output,
    list_dsn,
    warn,
):
    # Command-line entry point: parse options, connect, then start the REPL
    # (or list DSNs / databases and exit).
    # NOTE: deliberately no docstring -- click would show it as --help text.
    # All exits use sys.exit(): the bare `exit()` builtin is injected by the
    # `site` module and is missing under `python -S` or in frozen builds.
    if version:
        print("Version:", __version__)
        sys.exit(0)

    config_dir = os.path.dirname(config_location())
    if not os.path.exists(config_dir):
        os.makedirs(config_dir)

    # Migrate the config file from old location.
    config_full_path = config_location() + "config"
    if os.path.exists(os.path.expanduser("~/.pgclirc")):
        if not os.path.exists(config_full_path):
            shutil.move(os.path.expanduser("~/.pgclirc"), config_full_path)
            print("Config file (~/.pgclirc) moved to new location", config_full_path)
        else:
            print("Config file is now located at", config_full_path)
            print(
                "Please move the existing config file ~/.pgclirc to",
                config_full_path,
            )
    if list_dsn:
        try:
            cfg = load_config(pgclirc, config_full_path)
            for alias in cfg["alias_dsn"]:
                click.secho(alias + " : " + cfg["alias_dsn"][alias])
            # SystemExit derives from BaseException, so the handler below
            # does not intercept this successful exit.
            sys.exit(0)
        except Exception:
            click.secho(
                "Invalid DSNs found in the config file. "
                'Please check the "[alias_dsn]" section in pgclirc.',
                err=True,
                fg="red",
            )
            sys.exit(1)

    pgcli = PGCli(
        prompt_passwd,
        never_prompt,
        pgclirc_file=pgclirc,
        row_limit=row_limit,
        single_connection=single_connection,
        less_chatty=less_chatty,
        prompt=prompt,
        prompt_dsn=prompt_dsn,
        auto_vertical_output=auto_vertical_output,
        warn=warn,
    )

    # Choose which ever one has a valid value.
    if dbname_opt and dbname:
        # work as psql: when database is given as option and argument use the argument as user
        username = dbname
    database = dbname_opt or dbname or ""
    user = username_opt or username
    service = None
    if database.startswith("service="):
        service = database[8:]  # strip the 8-character "service=" prefix
    elif os.getenv("PGSERVICE") is not None:
        service = os.getenv("PGSERVICE")
    # because option --list or -l are not supposed to have a db name
    if list_databases:
        database = "postgres"

    if dsn != "":
        try:
            cfg = load_config(pgclirc, config_full_path)
            dsn_config = cfg["alias_dsn"][dsn]
        except KeyError:
            click.secho(
                f"Could not find a DSN with alias {dsn}. "
                'Please check the "[alias_dsn]" section in pgclirc.',
                err=True,
                fg="red",
            )
            sys.exit(1)
        except Exception:
            click.secho(
                "Invalid DSNs found in the config file. "
                'Please check the "[alias_dsn]" section in pgclirc.',
                err=True,
                fg="red",
            )
            sys.exit(1)
        pgcli.connect_uri(dsn_config)
        pgcli.dsn_alias = dsn
    elif "://" in database:
        pgcli.connect_uri(database)
    elif "=" in database and service is None:
        pgcli.connect_dsn(database, user=user)
    elif service is not None:
        pgcli.connect_service(service, user)
    else:
        pgcli.connect(database, host, user, port)

    if list_databases:
        cur, headers, status = pgcli.pgexecute.full_databases()

        title = "List of databases"
        settings = OutputSettings(table_format="ascii", missingval="<null>")
        formatted = format_output(title, cur, headers, status, settings)
        pgcli.echo_via_pager("\n".join(formatted))

        sys.exit(0)

    pgcli.logger.debug(
        "Launch Params: \n" "\tdatabase: %r" "\tuser: %r" "\thost: %r" "\tport: %r",
        database,
        user,
        host,
        port,
    )

    # Hide any password passed on the command line from the process title.
    if setproctitle:
        obfuscate_process_password()

    pgcli.run_cli()
1347
1348
def obfuscate_process_password():
    """Replace any password in the current process title with ``xxxx``.

    Two shapes are handled: URI-style titles (``scheme://user:secret@host``)
    and keyword/value DSNs (``password=secret host=...``). Titles matching
    neither are written back unchanged.
    """
    title = setproctitle.getproctitle()
    if "://" in title:
        masked = re.sub(r":(.*):(.*)@", r":\1:xxxx@", title)
    elif "=" in title:
        masked = re.sub(
            r"password=(.+?)((\s[a-zA-Z]+=)|$)", r"password=xxxx\2", title
        )
    else:
        masked = title

    setproctitle.setproctitle(masked)
1359
1360
def has_meta_cmd(query):
    """Return True if *query* begins with ALTER, CREATE, DROP, COMMIT or ROLLBACK.

    Such statements mean the completions need a refresh. Input that cannot
    be split into words (None, empty or blank text) yields False.
    """
    try:
        leading_word = query.split()[0]
    except Exception:
        return False
    return leading_word.lower() in ("alter", "create", "drop", "commit", "rollback")
1372
1373
def has_change_db_cmd(query):
    """Return True if *query* switches database ('use', '\\c' or '\\connect').

    Input that cannot be split into words (None, empty or blank text)
    yields False.
    """
    try:
        leading_word = query.split()[0]
    except Exception:
        return False
    return leading_word.lower() in ("use", "\\c", "\\connect")
1384
1385
def has_change_path_cmd(sql):
    """Determines if the search_path should be refreshed by checking if the
    sql sets search_path.

    Matching is case-insensitive and tolerant of any whitespace (including
    newlines), and also accepts ``SET SESSION search_path`` and
    ``SET LOCAL search_path``. The pattern is unanchored on the left, so
    ``RESET search_path`` is detected as well, as it was before.
    None or empty input yields False, consistent with the other
    ``has_*_cmd`` helpers.
    """
    if not sql:
        return False
    return (
        re.search(
            r"set\s+(?:(?:session|local)\s+)?search_path", sql, re.IGNORECASE
        )
        is not None
    )
1390
1391
def is_mutating(status):
    """Determines if the statement is mutating based on the status.

    Returns True when the first word of *status* (e.g. ``"INSERT 0 1"``) is
    INSERT, UPDATE or DELETE, case-insensitively. None, empty or
    whitespace-only input yields False instead of raising IndexError.
    """
    if not status:
        return False

    words = status.split(None, 1)
    if not words:
        return False
    return words[0].lower() in {"insert", "update", "delete"}
1399
1400
def is_select(status):
    """Returns true if the first word in status is 'select'.

    Also used on SQL text (see ``_should_limit_output``). The comparison is
    case-insensitive. None, empty or whitespace-only input yields False
    instead of raising IndexError.
    """
    if not status:
        return False
    words = status.split(None, 1)
    if not words:
        return False
    return words[0].lower() == "select"
1406
1407
def exception_formatter(e):
    """Render exception *e* as red-styled text for terminal output."""
    message = str(e)
    return click.style(message, fg="red")
1410
1411
def format_output(title, cur, headers, status, settings):
    """Format a query result as an iterable of output lines.

    :param title: optional heading emitted first (skipped when falsy).
    :param cur: the result rows (e.g. a cursor or a list). Nothing is
                tabulated when it is falsy.
    :param headers: column names, passed through ``settings.case_function``.
    :param status: optional status line appended last (omitted for CSV).
    :param settings: an ``OutputSettings`` describing format, number
                     formats, missing-value text, expansion and max width.
    :returns: an iterable of lines (a list or an ``itertools.chain``).
    """
    output = []
    expanded = settings.expanded or settings.table_format == "vertical"
    table_format = "vertical" if settings.expanded else settings.table_format
    max_width = settings.max_width
    case_function = settings.case_function
    formatter = TabularOutputFormatter(format_name=table_format)

    def format_array(val):
        # Render (possibly nested) lists in PostgreSQL's "{a,b,c}" style.
        if val is None:
            return settings.missingval
        if not isinstance(val, list):
            return val
        return "{" + ",".join(str(format_array(e)) for e in val) + "}"

    def format_arrays(data, headers, **_):
        # Build fresh rows instead of assigning into them, so this works
        # whether the previous preprocessor yields lists or tuples.
        data = [
            [format_array(val) if isinstance(val, list) else val for val in row]
            for row in data
        ]
        return data, headers

    output_kwargs = {
        "sep_title": "RECORD {n}",
        "sep_character": "-",
        "sep_length": (1, 25),
        "missing_value": settings.missingval,
        "integer_format": settings.dcmlfmt,
        "float_format": settings.floatfmt,
        "preprocessors": (format_numbers, format_arrays),
        "disable_numparse": True,
        "preserve_whitespace": True,
        "style": settings.style_output,
    }
    if not settings.floatfmt:
        # Without an explicit float format, align decimal points instead.
        # Array rendering is independent of number formatting, so keep it.
        output_kwargs["preprocessors"] = (align_decimals, format_arrays)

    if table_format == "csv":
        # The default CSV dialect is "excel" which is not handling newline values correctly
        # Nevertheless, we want to keep on using "excel" on Windows since it uses '\r\n'
        # as the line terminator
        # https://github.com/dbcli/pgcli/issues/1102
        dialect = "excel" if platform.system() == "Windows" else "unix"
        output_kwargs["dialect"] = dialect

    if title:  # Only print the title if it's not None.
        output.append(title)

    if cur:
        headers = [case_function(x) for x in headers]
        if max_width is not None:
            # Materialize the rows so they can be formatted a second time
            # (vertically) if the table turns out to be too wide.
            cur = list(cur)

        formatted = formatter.format_output(cur, headers, **output_kwargs)
        if isinstance(formatted, str):
            formatted = iter(formatted.splitlines())
        # Peek at the first line to measure the table width. A formatter that
        # produces no lines must not raise StopIteration here.
        first_line = next(formatted, None)
        if first_line is not None:
            formatted = itertools.chain([first_line], formatted)
            if (
                not expanded
                and max_width
                and len(strip_ansi(first_line)) > max_width
                and headers
            ):
                # Too wide for the terminal: fall back to vertical layout.
                formatted = formatter.format_output(
                    cur,
                    headers,
                    format_name="vertical",
                    column_types=None,
                    **output_kwargs,
                )
                if isinstance(formatted, str):
                    formatted = iter(formatted.splitlines())

        output = itertools.chain(output, formatted)

    # Only print the status if it's not None and we are not producing CSV
    if status and table_format != "csv":
        output = itertools.chain(output, [status])

    return output
1507
1508
def parse_service_info(service):
    """Look up a connection service definition in the service file.

    :param service: service name; falls back to ``$PGSERVICE`` when falsy.
    :returns: ``(service_conf, service_file)``. *service_file* is the path
              consulted. *service_conf* is that service's section from the
              file, or None when there is no service name, the file does not
              exist, or the service is not defined in it.
    :raises ParseError: if the file is malformed. The line number is
                        adjusted for the leading comment lines skipped
                        before parsing.
    """
    service = service or os.getenv("PGSERVICE")
    service_file = os.getenv("PGSERVICEFILE")
    if not service_file:
        sysconfdir = os.getenv("PGSYSCONFDIR")
        if platform.system() == "Windows":
            appdata = os.getenv("APPDATA")
            if sysconfdir:
                service_file = sysconfdir + "\\pg_service.conf"
            elif appdata:
                # libpq's per-user service file location on Windows. The old
                # code crashed here with None + str when PGSYSCONFDIR was unset.
                service_file = os.path.join(appdata, "postgresql", ".pg_service.conf")
            else:
                service_file = expanduser("~/.pg_service.conf")
        elif sysconfdir:
            # NOTE(review): libpq reads "pg_service.conf" (no leading dot) from
            # PGSYSCONFDIR; the dotted name is kept for compatibility with
            # existing pgcli setups.
            service_file = os.path.join(sysconfdir, ".pg_service.conf")
        else:
            # try ~/.pg_service.conf (if that exists)
            service_file = expanduser("~/.pg_service.conf")
    if not service or not os.path.exists(service_file):
        # nothing to do
        return None, service_file
    with open(service_file, newline="") as f:
        skipped_lines = skip_initial_comment(f)
        try:
            service_file_config = ConfigObj(f)
        except ParseError as err:
            # Report the line number relative to the whole file.
            err.line_number += skipped_lines
            raise err
    if service not in service_file_config:
        return None, service_file
    service_conf = service_file_config.get(service)
    return service_conf, service_file
1534
1535
# Run the click command-line entry point when this module is executed as a script.
if __name__ == "__main__":
    cli()
1538