# -*- coding: utf-8 -*-
import time
import traceback
from contextlib import contextmanager

import django
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from django.db.backends import utils


@contextmanager
def monkey_patch_cursordebugwrapper(print_sql=None, print_sql_location=False, truncate=None, logger=print, confprefix="DJANGO_EXTENSIONS"):
    """Temporarily make Django's debug cursors log every executed query.

    While the context is active, ``utils.CursorDebugWrapper`` is replaced by a
    subclass whose ``execute`` logs the executed SQL, its execution time and
    the database alias.  On Django >= 3.0, when the PostgreSQL backend module
    can be imported, ``postgresql_base.CursorDebugWrapper`` is patched the same
    way.  Every connection in ``django.db.connections`` has
    ``force_debug_cursor`` set to True so the debug wrapper is actually used.
    The original wrapper classes and ``force_debug_cursor`` flags are restored
    on exit, including when the body of the ``with`` block raises.

    Args:
        print_sql: When falsy, the context manager does nothing.
        print_sql_location: When true, also log the Python stack at the point
            each query was executed.
        truncate: Maximum number of characters of SQL to log.  When ``None``
            the ``<confprefix>_PRINT_SQL_TRUNCATE`` setting is used (default
            1000).  A falsy value disables truncation.
        logger: Callable invoked with each piece of output (``print`` by
            default).
        confprefix: Prefix of the settings consulted here:
            ``<confprefix>_PRINT_SQL_TRUNCATE``, ``_SQLPARSE_ENABLED``,
            ``_SQLPARSE_FORMAT_KWARGS``, ``_PYGMENTS_ENABLED``,
            ``_PYGMENTS_FORMATTER`` and ``_PYGMENTS_FORMATTER_KWARGS``.
    """
    if not print_sql:
        yield
        return

    # An explicitly passed ``truncate`` wins over the setting.  The setting is
    # only a fallback; previously the argument was always overwritten.
    if truncate is None:
        truncate = getattr(settings, '%s_PRINT_SQL_TRUNCATE' % confprefix, 1000)

    # Code originally from http://gist.github.com/118990
    sqlparse = None
    if getattr(settings, '%s_SQLPARSE_ENABLED' % confprefix, True):
        try:
            import sqlparse

            sqlparse_format_kwargs_defaults = dict(
                reindent_aligned=True,
                truncate_strings=500,
            )
            sqlparse_format_kwargs = getattr(settings, '%s_SQLPARSE_FORMAT_KWARGS' % confprefix, sqlparse_format_kwargs_defaults)
        except ImportError:
            sqlparse = None

    pygments = None
    if getattr(settings, '%s_PYGMENTS_ENABLED' % confprefix, True):
        try:
            import pygments.lexers
            import pygments.formatters

            pygments_formatter = getattr(settings, '%s_PYGMENTS_FORMATTER' % confprefix, pygments.formatters.TerminalFormatter)
            pygments_formatter_kwargs = getattr(settings, '%s_PYGMENTS_FORMATTER_KWARGS' % confprefix, {})
        except ImportError:
            # ``import pygments.lexers`` may already have bound ``pygments``
            # before ``pygments.formatters`` failed.  Reset it so highlighting
            # is skipped instead of hitting an undefined ``pygments_formatter``.
            pygments = None

    class PrintQueryWrapperMixin:
        # Only ``execute`` is instrumented; ``executemany`` is not overridden.
        def execute(self, sql, params=()):
            starttime = time.time()
            try:
                # Calls the plain CursorWrapper.execute rather than super(), so
                # the debug wrapper's own ``execute`` is bypassed.
                # NOTE(review): confirm this bypass is intended.
                return utils.CursorWrapper.execute(self, sql, params)
            finally:
                # Log even when the query raised.
                execution_time = time.time() - starttime
                raw_sql = self.db.ops.last_executed_query(self.cursor, sql, params)
                if truncate:
                    raw_sql = raw_sql[:truncate]

                if sqlparse:
                    raw_sql = sqlparse.format(raw_sql, **sqlparse_format_kwargs)

                if pygments:
                    raw_sql = pygments.highlight(
                        raw_sql,
                        pygments.lexers.get_lexer_by_name("sql"),
                        pygments_formatter(**pygments_formatter_kwargs),
                    )

                logger(raw_sql)
                logger("Execution time: %.6fs [Database: %s]" % (execution_time, self.db.alias))
                if print_sql_location:
                    logger("Location of SQL Call:")
                    logger(''.join(traceback.format_stack()))

    _CursorDebugWrapper = utils.CursorDebugWrapper

    class PrintCursorQueryWrapper(PrintQueryWrapperMixin, _CursorDebugWrapper):
        pass

    # Remember each connection's force_debug_cursor flag so it can be restored.
    # This is best-effort: if connections are unavailable, only the wrapper
    # classes are patched.
    _force_debug_cursor = {}
    try:
        from django.db import connections
        for connection_name in connections:
            _force_debug_cursor[connection_name] = connections[connection_name].force_debug_cursor
    except Exception:
        connections = None

    # The PostgreSQL backend exposes its own CursorDebugWrapper, which must be
    # patched separately.  Probe for it BEFORE patching anything global, so a
    # failure here cannot leave ``utils.CursorDebugWrapper`` patched.
    postgresql_base = None
    if django.VERSION >= (3, 0):
        try:
            from django.db.backends.postgresql import base as postgresql_base
            _PostgreSQLCursorDebugWrapper = postgresql_base.CursorDebugWrapper

            class PostgreSQLPrintCursorDebugWrapper(PrintQueryWrapperMixin, _PostgreSQLCursorDebugWrapper):
                pass
        except (ImproperlyConfigured, ImportError, AttributeError, TypeError):
            postgresql_base = None

    utils.CursorDebugWrapper = PrintCursorQueryWrapper
    try:
        if postgresql_base:
            postgresql_base.CursorDebugWrapper = PostgreSQLPrintCursorDebugWrapper

        if connections is not None:
            for connection_name in _force_debug_cursor:
                connections[connection_name].force_debug_cursor = True

        yield
    finally:
        # With @contextmanager an exception from the ``with`` body is re-raised
        # at the ``yield``.  Restoring in ``finally`` guarantees the global
        # patches never outlive the context.
        utils.CursorDebugWrapper = _CursorDebugWrapper

        if postgresql_base:
            postgresql_base.CursorDebugWrapper = _PostgreSQLCursorDebugWrapper

        if connections is not None:
            for connection_name, previous in _force_debug_cursor.items():
                connections[connection_name].force_debug_cursor = previous