1# -*- coding: utf-8 -*-
2import time
3import traceback
4from contextlib import contextmanager
5
6import django
7from django.conf import settings
8from django.core.exceptions import ImproperlyConfigured
9from django.db.backends import utils
10
11
@contextmanager
def monkey_patch_cursordebugwrapper(print_sql=None, print_sql_location=False, truncate=None, logger=print, confprefix="DJANGO_EXTENSIONS"):
    """Temporarily make Django debug cursors log every executed SQL statement.

    While the context is active:

    * ``django.db.backends.utils.CursorDebugWrapper`` is replaced with a
      subclass whose ``execute`` passes each executed query and its
      execution time to ``logger``.
    * On Django >= 3.0, when importable, the PostgreSQL backend's own
      ``CursorDebugWrapper`` is replaced in the same way.
    * ``force_debug_cursor`` is set to ``True`` on every configured
      connection.

    The logged query may be truncated, reformatted with sqlparse and
    highlighted with pygments. Everything is restored on exit, including
    when the body of the ``with`` statement raises.

    Args:
        print_sql: When falsy, the context manager does nothing.
        print_sql_location: Also log the Python stack for every query.
        truncate: Maximum number of characters of SQL to log. ``None`` reads
            ``<confprefix>_PRINT_SQL_TRUNCATE`` from settings (default 1000).
            Any other falsy value disables truncation.
        logger: Callable that receives each output string.
        confprefix: Prefix of the Django settings names read here.
    """
    if not print_sql:
        yield
        return

    # Only fall back to settings when the caller did not pass a value.
    if truncate is None:
        truncate = getattr(settings, '%s_PRINT_SQL_TRUNCATE' % confprefix, 1000)

    # Code originally from http://gist.github.com/118990
    sqlparse = None
    sqlparse_format_kwargs = {}
    if getattr(settings, '%s_SQLPARSE_ENABLED' % confprefix, True):
        try:
            import sqlparse
        except ImportError:
            sqlparse = None
        else:
            sqlparse_format_kwargs = getattr(
                settings,
                '%s_SQLPARSE_FORMAT_KWARGS' % confprefix,
                dict(reindent_aligned=True, truncate_strings=500),
            )

    pygments = None
    pygments_formatter = None
    pygments_formatter_kwargs = {}
    if getattr(settings, '%s_PYGMENTS_ENABLED' % confprefix, True):
        try:
            import pygments.lexers
            import pygments.formatters
        except ImportError:
            # A partial import may already have bound ``pygments``; reset it so
            # highlighting is skipped instead of failing on an unset formatter.
            pygments = None
        else:
            pygments_formatter = getattr(settings, '%s_PYGMENTS_FORMATTER' % confprefix, pygments.formatters.TerminalFormatter)
            pygments_formatter_kwargs = getattr(settings, '%s_PYGMENTS_FORMATTER_KWARGS' % confprefix, {})

    class PrintQueryWrapperMixin:
        # ``params=None`` matches Django's CursorWrapper.execute default, so a
        # parameterless ``cursor.execute(sql)`` is forwarded unchanged.
        def execute(self, sql, params=None):
            starttime = time.perf_counter()
            try:
                # Calls the base CursorWrapper.execute directly, not the
                # (patched) debug wrapper's execute.
                return utils.CursorWrapper.execute(self, sql, params)
            finally:
                execution_time = time.perf_counter() - starttime
                raw_sql = self.db.ops.last_executed_query(self.cursor, sql, params)
                if truncate:
                    raw_sql = raw_sql[:truncate]

                if sqlparse:
                    raw_sql = sqlparse.format(raw_sql, **sqlparse_format_kwargs)

                if pygments:
                    raw_sql = pygments.highlight(
                        raw_sql,
                        pygments.lexers.get_lexer_by_name("sql"),
                        pygments_formatter(**pygments_formatter_kwargs),
                    )

                logger(raw_sql)
                logger("Execution time: %.6fs [Database: %s]" % (execution_time, self.db.alias))
                if print_sql_location:
                    logger("Location of SQL Call:")
                    logger(''.join(traceback.format_stack()))

    _CursorDebugWrapper = utils.CursorDebugWrapper

    class PrintCursorQueryWrapper(PrintQueryWrapperMixin, _CursorDebugWrapper):
        pass

    # Best effort: remember each connection's force_debug_cursor so it can be
    # restored. On any failure, connection flags are left alone.
    try:
        from django.db import connections
        _force_debug_cursor = {}
        for connection_name in connections:
            _force_debug_cursor[connection_name] = connections[connection_name].force_debug_cursor
    except Exception:
        connections = None

    # Import the PostgreSQL backend BEFORE patching utils. Otherwise a
    # first-time import here would build its CursorDebugWrapper on top of the
    # printing wrapper, and that would outlive this context.
    postgresql_base = None
    if django.VERSION >= (3, 0):
        try:
            from django.db.backends.postgresql import base as postgresql_base
            _PostgreSQLCursorDebugWrapper = postgresql_base.CursorDebugWrapper

            class PostgreSQLPrintCursorDebugWrapper(PrintQueryWrapperMixin, _PostgreSQLCursorDebugWrapper):
                pass
        except (ImproperlyConfigured, TypeError):
            postgresql_base = None

    try:
        utils.CursorDebugWrapper = PrintCursorQueryWrapper

        if postgresql_base:
            postgresql_base.CursorDebugWrapper = PostgreSQLPrintCursorDebugWrapper

        if connections is not None:
            for connection_name in _force_debug_cursor:
                connections[connection_name].force_debug_cursor = True

        yield
    finally:
        # Always undo the patches, even if the with-body raised.
        utils.CursorDebugWrapper = _CursorDebugWrapper

        if postgresql_base:
            postgresql_base.CursorDebugWrapper = _PostgreSQLCursorDebugWrapper

        if connections is not None:
            for connection_name, previous in _force_debug_cursor.items():
                connections[connection_name].force_debug_cursor = previous
114