1""" 2Transaction context managers returned by Connection.transaction() 3""" 4 5# Copyright (C) 2020-2021 The Psycopg Team 6 7import logging 8 9from types import TracebackType 10from typing import Generic, Optional, Type, Union, TYPE_CHECKING 11 12from . import pq 13from . import sql 14from .pq import TransactionStatus 15from .abc import ConnectionType, PQGen 16from .pq.abc import PGresult 17 18if TYPE_CHECKING: 19 from typing import Any 20 from .connection import Connection 21 from .connection_async import AsyncConnection 22 23logger = logging.getLogger(__name__) 24 25 26class Rollback(Exception): 27 """ 28 Exit the current `Transaction` context immediately and rollback any changes 29 made within this context. 30 31 If a transaction context is specified in the constructor, rollback 32 enclosing transactions contexts up to and including the one specified. 33 """ 34 35 __module__ = "psycopg" 36 37 def __init__( 38 self, 39 transaction: Union["Transaction", "AsyncTransaction", None] = None, 40 ): 41 self.transaction = transaction 42 43 def __repr__(self) -> str: 44 return f"{self.__class__.__qualname__}({self.transaction!r})" 45 46 47class BaseTransaction(Generic[ConnectionType]): 48 def __init__( 49 self, 50 connection: ConnectionType, 51 savepoint_name: Optional[str] = None, 52 force_rollback: bool = False, 53 ): 54 self._conn = connection 55 self._savepoint_name = savepoint_name or "" 56 self.force_rollback = force_rollback 57 self._entered = self._exited = False 58 59 @property 60 def savepoint_name(self) -> Optional[str]: 61 """ 62 The name of the savepoint; `!None` if handling the main transaction. 63 """ 64 # Yes, it may change on __enter__. No, I don't care, because the 65 # un-entered state is outside the public interface. 66 return self._savepoint_name 67 68 def __repr__(self) -> str: 69 cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}" 70 info = pq.misc.connection_summary(self._conn.pgconn) 71 if not self._entered: 72 status = "inactive" 73 elif not self._exited: 74 status = "active" 75 else: 76 status = "terminated" 77 78 sp = f"{self.savepoint_name!r} " if self.savepoint_name else "" 79 return f"<{cls} {sp}({status}) {info} at 0x{id(self):x}>" 80 81 def _enter_gen(self) -> PQGen[PGresult]: 82 if self._entered: 83 raise TypeError("transaction blocks can be used only once") 84 self._entered = True 85 86 self._outer_transaction = ( 87 self._conn.pgconn.transaction_status == TransactionStatus.IDLE 88 ) 89 if self._outer_transaction: 90 # outer transaction: if no name it's only a begin, else 91 # there will be an additional savepoint 92 assert not self._conn._savepoints 93 else: 94 # inner transaction: it always has a name 95 if not self._savepoint_name: 96 self._savepoint_name = ( 97 f"_pg3_{len(self._conn._savepoints) + 1}" 98 ) 99 100 commands = [] 101 if self._outer_transaction: 102 assert not self._conn._savepoints, self._conn._savepoints 103 commands.append(self._conn._get_tx_start_command()) 104 105 if self._savepoint_name: 106 commands.append( 107 sql.SQL("SAVEPOINT {}") 108 .format(sql.Identifier(self._savepoint_name)) 109 .as_bytes(self._conn) 110 ) 111 112 self._conn._savepoints.append(self._savepoint_name) 113 return self._conn._exec_command(b"; ".join(commands)) 114 115 def _exit_gen( 116 self, 117 exc_type: Optional[Type[BaseException]], 118 exc_val: Optional[BaseException], 119 exc_tb: Optional[TracebackType], 120 ) -> PQGen[bool]: 121 if not exc_val and not self.force_rollback: 122 yield from self._commit_gen() 123 return False 124 else: 125 # try to rollback, but if there are problems (connection in a bad 126 # state) just warn without clobbering the exception bubbling up. 127 try: 128 return (yield from self._rollback_gen(exc_val)) 129 except Exception as exc2: 130 logger.warning( 131 "error ignored in rollback of %s: %s", 132 self, 133 exc2, 134 ) 135 return False 136 137 def _commit_gen(self) -> PQGen[PGresult]: 138 assert self._conn._savepoints[-1] == self._savepoint_name 139 self._conn._savepoints.pop() 140 self._exited = True 141 142 commands = [] 143 if self._savepoint_name and not self._outer_transaction: 144 commands.append( 145 sql.SQL("RELEASE {}") 146 .format(sql.Identifier(self._savepoint_name)) 147 .as_bytes(self._conn) 148 ) 149 150 if self._outer_transaction: 151 assert not self._conn._savepoints 152 commands.append(b"COMMIT") 153 154 return self._conn._exec_command(b"; ".join(commands)) 155 156 def _rollback_gen(self, exc_val: Optional[BaseException]) -> PQGen[bool]: 157 if isinstance(exc_val, Rollback): 158 logger.debug( 159 f"{self._conn}: Explicit rollback from: ", exc_info=True 160 ) 161 162 assert self._conn._savepoints[-1] == self._savepoint_name 163 self._conn._savepoints.pop() 164 165 commands = [] 166 if self._savepoint_name and not self._outer_transaction: 167 commands.append( 168 sql.SQL("ROLLBACK TO {n}; RELEASE {n}") 169 .format(n=sql.Identifier(self._savepoint_name)) 170 .as_bytes(self._conn) 171 ) 172 173 if self._outer_transaction: 174 assert not self._conn._savepoints 175 commands.append(b"ROLLBACK") 176 177 # Also clear the prepared statements cache. 178 cmd = self._conn._prepared.clear() 179 if cmd: 180 commands.append(cmd) 181 182 yield from self._conn._exec_command(b"; ".join(commands)) 183 184 if isinstance(exc_val, Rollback): 185 if not exc_val.transaction or exc_val.transaction is self: 186 return True # Swallow the exception 187 188 return False 189 190 191class Transaction(BaseTransaction["Connection[Any]"]): 192 """ 193 Returned by `Connection.transaction()` to handle a transaction block. 194 """ 195 196 __module__ = "psycopg" 197 198 @property 199 def connection(self) -> "Connection[Any]": 200 """The connection the object is managing.""" 201 return self._conn 202 203 def __enter__(self) -> "Transaction": 204 with self._conn.lock: 205 self._conn.wait(self._enter_gen()) 206 return self 207 208 def __exit__( 209 self, 210 exc_type: Optional[Type[BaseException]], 211 exc_val: Optional[BaseException], 212 exc_tb: Optional[TracebackType], 213 ) -> bool: 214 with self._conn.lock: 215 return self._conn.wait(self._exit_gen(exc_type, exc_val, exc_tb)) 216 217 218class AsyncTransaction(BaseTransaction["AsyncConnection[Any]"]): 219 """ 220 Returned by `AsyncConnection.transaction()` to handle a transaction block. 221 """ 222 223 __module__ = "psycopg" 224 225 @property 226 def connection(self) -> "AsyncConnection[Any]": 227 return self._conn 228 229 async def __aenter__(self) -> "AsyncTransaction": 230 async with self._conn.lock: 231 await self._conn.wait(self._enter_gen()) 232 return self 233 234 async def __aexit__( 235 self, 236 exc_type: Optional[Type[BaseException]], 237 exc_val: Optional[BaseException], 238 exc_tb: Optional[TracebackType], 239 ) -> bool: 240 async with self._conn.lock: 241 return await self._conn.wait( 242 self._exit_gen(exc_type, exc_val, exc_tb) 243 ) 244