1"""
2Transaction context managers returned by Connection.transaction()
3"""
4
5# Copyright (C) 2020-2021 The Psycopg Team
6
7import logging
8
9from types import TracebackType
10from typing import Generic, Optional, Type, Union, TYPE_CHECKING
11
12from . import pq
13from . import sql
14from .pq import TransactionStatus
15from .abc import ConnectionType, PQGen
16from .pq.abc import PGresult
17
18if TYPE_CHECKING:
19    from typing import Any
20    from .connection import Connection
21    from .connection_async import AsyncConnection
22
# Module logger: used to warn about errors ignored while rolling back a
# block and, at debug level, to trace explicit `Rollback` exits.
logger = logging.getLogger(__name__)
24
25
class Rollback(Exception):
    """
    Leave the current `Transaction` block at once, discarding the changes
    made inside it.

    When a transaction context is passed to the constructor, every
    enclosing context up to and including that one is rolled back.
    """

    __module__ = "psycopg"

    def __init__(
        self,
        transaction: Union["Transaction", "AsyncTransaction", None] = None,
    ):
        # The block the exception is aimed at; None means the innermost one,
        # which will swallow it.
        self.transaction = transaction

    def __repr__(self) -> str:
        name = type(self).__qualname__
        return "{}({!r})".format(name, self.transaction)
45
46
class BaseTransaction(Generic[ConnectionType]):
    """
    Logic shared by the sync and async transaction context managers.

    The ``*_gen`` methods build the generators that the connection's
    ``wait()`` drives to send BEGIN/SAVEPOINT on enter and
    RELEASE/COMMIT/ROLLBACK on exit; subclasses only add the
    context-manager protocol and the locking around those calls.
    """

    def __init__(
        self,
        connection: ConnectionType,
        savepoint_name: Optional[str] = None,
        force_rollback: bool = False,
    ):
        self._conn = connection
        # Stored as "" when missing so the code below can test truthiness:
        # an empty name means "no explicit savepoint".
        self._savepoint_name = savepoint_name or ""
        self.force_rollback = force_rollback
        self._entered = self._exited = False
        # Computed for real in _enter_gen(): True if this block issued BEGIN.
        self._outer_transaction = False

    @property
    def savepoint_name(self) -> Optional[str]:
        """
        The name of the savepoint; `!None` if handling the main transaction.
        """
        # Yes, it may change on __enter__ (nested blocks get an automatic
        # name). No, I don't care, because the un-entered state is outside
        # the public interface. The internal "" is mapped back to None to
        # honour the documented contract.
        return self._savepoint_name or None

    def __repr__(self) -> str:
        cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
        info = pq.misc.connection_summary(self._conn.pgconn)
        if not self._entered:
            status = "inactive"
        elif not self._exited:
            status = "active"
        else:
            status = "terminated"

        sp = f"{self.savepoint_name!r} " if self.savepoint_name else ""
        return f"<{cls} {sp}({status}) {info} at 0x{id(self):x}>"

    def _enter_gen(self) -> PQGen[PGresult]:
        """
        Start the block, issuing BEGIN and/or SAVEPOINT as needed.

        :raise TypeError: if the object has already been entered once.
        """
        if self._entered:
            raise TypeError("transaction blocks can be used only once")
        self._entered = True

        # An idle connection means this block owns the whole transaction
        # and must send BEGIN; otherwise it nests inside an existing one.
        self._outer_transaction = (
            self._conn.pgconn.transaction_status == TransactionStatus.IDLE
        )

        commands = []
        if self._outer_transaction:
            # outer transaction: if no name it's only a begin, else
            # there will be an additional savepoint
            assert not self._conn._savepoints, self._conn._savepoints
            commands.append(self._conn._get_tx_start_command())
        elif not self._savepoint_name:
            # inner transaction: it always has a name
            self._savepoint_name = f"_pg3_{len(self._conn._savepoints) + 1}"

        if self._savepoint_name:
            commands.append(
                sql.SQL("SAVEPOINT {}")
                .format(sql.Identifier(self._savepoint_name))
                .as_bytes(self._conn)
            )

        result = yield from self._conn._exec_command(b"; ".join(commands))

        # Register the block only once the commands succeeded. If BEGIN or
        # SAVEPOINT fails, __exit__ is never called to pop the entry, and a
        # stale name would break the stack assertion in an enclosing
        # block's exit, preventing its ROLLBACK from being sent.
        self._conn._savepoints.append(self._savepoint_name)
        return result

    def _exit_gen(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> PQGen[bool]:
        """
        Terminate the block: commit on success, roll back on exception or
        when `force_rollback` is set.

        Return True if the exception being handled must be suppressed.
        """
        if not exc_val and not self.force_rollback:
            yield from self._commit_gen()
            return False

        # try to rollback, but if there are problems (connection in a bad
        # state) just warn without clobbering the exception bubbling up.
        try:
            return (yield from self._rollback_gen(exc_val))
        except Exception as exc2:
            logger.warning(
                "error ignored in rollback of %s: %s",
                self,
                exc2,
            )
            return False

    def _commit_gen(self) -> PQGen[PGresult]:
        """
        Pop the block from the savepoint stack and return the generator
        releasing its savepoint (nested block) or committing (outer block).
        """
        assert self._conn._savepoints[-1] == self._savepoint_name
        self._conn._savepoints.pop()
        self._exited = True

        commands = []
        if self._savepoint_name and not self._outer_transaction:
            commands.append(
                sql.SQL("RELEASE {}")
                .format(sql.Identifier(self._savepoint_name))
                .as_bytes(self._conn)
            )

        if self._outer_transaction:
            assert not self._conn._savepoints
            commands.append(b"COMMIT")

        return self._conn._exec_command(b"; ".join(commands))

    def _rollback_gen(self, exc_val: Optional[BaseException]) -> PQGen[bool]:
        """
        Roll the block back and decide whether to swallow `exc_val`.

        Return True only for a `Rollback` aimed at this block (or at no
        block in particular); anything else keeps propagating, so that a
        `Rollback` targeting an outer block unwinds up to it.
        """
        if isinstance(exc_val, Rollback):
            logger.debug(
                "%s: Explicit rollback from: ", self._conn, exc_info=exc_val
            )

        assert self._conn._savepoints[-1] == self._savepoint_name
        self._conn._savepoints.pop()
        # Like in _commit_gen(): the block is terminated from now on.
        self._exited = True

        commands = []
        if self._savepoint_name and not self._outer_transaction:
            commands.append(
                sql.SQL("ROLLBACK TO {n}; RELEASE {n}")
                .format(n=sql.Identifier(self._savepoint_name))
                .as_bytes(self._conn)
            )

        if self._outer_transaction:
            assert not self._conn._savepoints
            commands.append(b"ROLLBACK")

        # Also clear the prepared statements cache.
        cmd = self._conn._prepared.clear()
        if cmd:
            commands.append(cmd)

        yield from self._conn._exec_command(b"; ".join(commands))

        if isinstance(exc_val, Rollback):
            if exc_val.transaction is None or exc_val.transaction is self:
                return True  # Swallow the exception

        return False
189
190
class Transaction(BaseTransaction["Connection[Any]"]):
    """
    Context manager returned by `Connection.transaction()` to handle a
    transaction block.
    """

    __module__ = "psycopg"

    @property
    def connection(self) -> "Connection[Any]":
        """The connection the object is managing."""
        return self._conn

    def __enter__(self) -> "Transaction":
        # Send BEGIN/SAVEPOINT while holding the connection lock, so it is
        # serialised with the other operations taking the same lock.
        conn = self._conn
        with conn.lock:
            conn.wait(self._enter_gen())
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> bool:
        # A true result makes Python swallow the exception (e.g. a Rollback
        # aimed at this very block).
        conn = self._conn
        with conn.lock:
            suppress = conn.wait(self._exit_gen(exc_type, exc_val, exc_tb))
        return suppress
216
217
class AsyncTransaction(BaseTransaction["AsyncConnection[Any]"]):
    """
    Context manager returned by `AsyncConnection.transaction()` to handle a
    transaction block.
    """

    __module__ = "psycopg"

    @property
    def connection(self) -> "AsyncConnection[Any]":
        """The connection the object is managing."""
        return self._conn

    async def __aenter__(self) -> "AsyncTransaction":
        # Send BEGIN/SAVEPOINT while holding the connection lock, so it is
        # serialised with the other operations taking the same lock.
        conn = self._conn
        async with conn.lock:
            await conn.wait(self._enter_gen())
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> bool:
        # A true result makes Python swallow the exception (e.g. a Rollback
        # aimed at this very block).
        conn = self._conn
        async with conn.lock:
            gen = self._exit_gen(exc_type, exc_val, exc_tb)
            suppress = await conn.wait(gen)
        return suppress
244