1# Copyright 2020 The Matrix.org Foundation C.I.C.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14import abc
15import logging
16import threading
17from typing import TYPE_CHECKING, Callable, List, Optional
18
19from synapse.storage.engines import (
20    BaseDatabaseEngine,
21    IncorrectDatabaseSetup,
22    PostgresEngine,
23)
24from synapse.storage.types import Connection, Cursor
25
26if TYPE_CHECKING:
27    from synapse.storage.database import LoggingDatabaseConnection
28
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


# Message template for the IncorrectDatabaseSetup raised by
# PostgresSequenceGenerator.check_consistency when the sequence's last value is
# lower than the maximum ID found in the associated table.
#
# It is %-formatted with a mapping that must provide:
#   seq:        the name of the postgres sequence
#   table:      the name of the table the sequence feeds
#   max_id_sql: the SQL statement that computes the table's current max ID
_INCONSISTENT_SEQUENCE_ERROR = """
Postgres sequence '%(seq)s' is inconsistent with associated
table '%(table)s'. This can happen if Synapse has been downgraded and
then upgraded again, or due to a bad migration.

To fix this error, shut down Synapse (including any and all workers)
and run the following SQL:

    SELECT setval('%(seq)s', (
        %(max_id_sql)s
    ));

See docs/postgres.md for more information.
"""

# Message template for the IncorrectDatabaseSetup raised by
# PostgresSequenceGenerator.check_consistency when the `stream_positions` table
# holds a position for the stream that is ahead of the sequence's last value.
#
# It is %-formatted with a mapping that must provide:
#   seq:         the name of the postgres sequence
#   stream_name: the value of the `stream_name` column that was checked
_INCONSISTENT_STREAM_ERROR = """
Postgres sequence '%(seq)s' is inconsistent with associated stream position
of '%(stream_name)s' in the 'stream_positions' table.

This is likely a programming error and should be reported at
https://github.com/matrix-org/synapse.

A temporary workaround to fix this error is to shut down Synapse (including
any and all workers) and run the following SQL:

    DELETE FROM stream_positions WHERE stream_name = '%(stream_name)s';

This will need to be done every time the server is restarted.
"""
61
62
class SequenceGenerator(metaclass=abc.ABCMeta):
    """Abstract interface for objects that hand out a unique sequence of integers.

    Concrete subclasses must implement every method declared here.
    """

    @abc.abstractmethod
    def get_next_id_txn(self, txn: Cursor) -> int:
        """Allocate and return the next ID in the sequence.

        Args:
            txn: the database cursor to use for the allocation.
        """

    @abc.abstractmethod
    def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
        """Allocate and return the next `n` IDs in the sequence.

        Args:
            txn: the database cursor to use for the allocation.
            n: how many IDs to allocate.
        """

    @abc.abstractmethod
    def check_consistency(
        self,
        db_conn: "LoggingDatabaseConnection",
        table: str,
        id_column: str,
        stream_name: Optional[str] = None,
        positive: bool = True,
    ) -> None:
        """Verify, at start up, that the sequence has not fallen behind its table.

        The sequence's current value is expected to be greater than or equal
        to the maximum ID stored in `table`. The two can drift apart in
        various ways, e.g. when Synapse is rolled back to a previous version
        and then rolled forwards again.

        When `stream_name` is supplied, any value recorded for it in the
        `stream_positions` table is additionally expected to be less than or
        equal to the current sequence value. A larger value most likely means
        that streams have been crossed somewhere (e.g. two ID generators share
        the same stream name).

        Args:
            db_conn: the database connection to run the checks on.
            table: the table whose IDs are allocated from this sequence.
            id_column: the column of `table` holding those IDs.
            stream_name: optional name to look up in `stream_positions`.
            positive: selects how the table's largest ID is determined; the
                Postgres implementation uses the maximum of `id_column` when
                True and the negated minimum when False.
        """
98
99
class PostgresSequenceGenerator(SequenceGenerator):
    """An implementation of SequenceGenerator which uses a postgres sequence"""

    def __init__(self, sequence_name: str):
        """
        Args:
            sequence_name: the name of the postgres sequence to draw IDs from
        """
        self._sequence_name = sequence_name

    def get_next_id_txn(self, txn: Cursor) -> int:
        """Gets the next ID by calling `nextval` on the sequence."""
        txn.execute("SELECT nextval(?)", (self._sequence_name,))
        fetch_res = txn.fetchone()
        assert fetch_res is not None
        return fetch_res[0]

    def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
        """Gets the next `n` IDs in one query, calling `nextval` once per row
        produced by `generate_series(1, n)`.
        """
        txn.execute(
            "SELECT nextval(?) FROM generate_series(1, ?)", (self._sequence_name, n)
        )
        return [i for (i,) in txn]

    def check_consistency(
        self,
        db_conn: "LoggingDatabaseConnection",
        table: str,
        id_column: str,
        stream_name: Optional[str] = None,
        positive: bool = True,
    ) -> None:
        """See SequenceGenerator.check_consistency for docstring.

        Raises:
            IncorrectDatabaseSetup: if the sequence is behind the maximum ID in
                `table`, or behind the position stored in `stream_positions`
                for `stream_name`.
        """

        # The query for the current max ID in the table. It is kept in a
        # variable as it is also quoted in the error message below. For
        # "negative" streams we compare against the negated minimum instead.
        table_sql = "SELECT GREATEST(%(agg)s(%(id)s), 0) FROM %(table)s" % {
            "id": id_column,
            "table": table,
            "agg": "MAX" if positive else "-MIN",
        }

        txn = db_conn.cursor(txn_name="sequence.check_consistency")

        # Ensure the cursor is closed even if one of the queries (or asserts)
        # below raises, rather than only on the success paths.
        try:
            # First we get the current max ID from the table.
            txn.execute(table_sql)
            row = txn.fetchone()
            if not row:
                # No row came back, so there is nothing to compare against.
                return

            # Now we fetch the current value from the sequence and compare with
            # the above.
            max_stream_id = row[0]
            txn.execute(
                "SELECT last_value, is_called FROM %(seq)s"
                % {"seq": self._sequence_name}
            )
            fetch_res = txn.fetchone()
            assert fetch_res is not None
            last_value, is_called = fetch_res

            # If we have an associated stream check the stream_positions table.
            max_in_stream_positions = None
            if stream_name:
                txn.execute(
                    "SELECT MAX(stream_id) FROM stream_positions WHERE stream_name = ?",
                    (stream_name,),
                )
                row = txn.fetchone()
                if row:
                    # This is NULL (i.e. None) if there are no matching rows.
                    max_in_stream_positions = row[0]
        finally:
            txn.close()

        # If `is_called` is False then `last_value` is actually the value that
        # will be generated next, so we decrement to get the true "last value".
        if not is_called:
            last_value -= 1

        if max_stream_id > last_value:
            logger.warning(
                "Postgres sequence %s is behind table %s: %d < %d",
                self._sequence_name,
                table,
                last_value,
                max_stream_id,
            )
            raise IncorrectDatabaseSetup(
                _INCONSISTENT_SEQUENCE_ERROR
                % {"seq": self._sequence_name, "table": table, "max_id_sql": table_sql}
            )

        # If we have values in the stream positions table then they have to be
        # less than or equal to `last_value`. We compare against None explicitly
        # so that a stored position of zero is still checked.
        if (
            max_in_stream_positions is not None
            and max_in_stream_positions > last_value
        ):
            raise IncorrectDatabaseSetup(
                _INCONSISTENT_STREAM_ERROR
                % {"seq": self._sequence_name, "stream_name": stream_name}
            )
192
193
# Type of the callback used to seed a LocalSequenceGenerator: it is called with
# a database cursor and returns the current maximum ID as an int.
GetFirstCallbackType = Callable[[Cursor], int]
195
196
class LocalSequenceGenerator(SequenceGenerator):
    """An implementation of SequenceGenerator which uses local locking

    This only works reliably if there are no other worker processes generating IDs at
    the same time.
    """

    def __init__(self, get_first_callback: GetFirstCallbackType):
        """
        Args:
            get_first_callback: a callback which is called the first time an ID
                 is requested (via get_next_id_txn or get_next_mult_txn); should
                 return the current maximum id
        """
        # the callback. this is cleared after it is called, so that it can be GCed.
        self._callback: Optional[GetFirstCallbackType] = get_first_callback

        # The current max value, or None if we haven't looked in the DB yet.
        self._current_max_id: Optional[int] = None
        self._lock = threading.Lock()

    def _load_current_max_id(self, txn: Cursor) -> int:
        """Returns the current max ID, seeding it from the callback on first use.

        Must be called with `self._lock` held.
        """
        if self._current_max_id is None:
            assert self._callback is not None
            self._current_max_id = self._callback(txn)
            self._callback = None

        return self._current_max_id

    def get_next_id_txn(self, txn: Cursor) -> int:
        """Gets the next ID in the sequence"""
        # We do application locking here since if we're using sqlite then
        # we are a single process synapse.
        with self._lock:
            next_id = self._load_current_max_id(txn) + 1
            self._current_max_id = next_id
            return next_id

    def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
        """Get the next `n` IDs in the sequence

        Returns an empty list, without advancing the sequence, if `n` is zero
        or negative.
        """
        with self._lock:
            current_max_id = self._load_current_max_id(txn)

            # A non-positive `n` must never move the counter: moving it
            # backwards would cause IDs that have already been handed out to
            # be issued again.
            if n <= 0:
                return []

            self._current_max_id = current_max_id + n
            return list(range(current_max_id + 1, current_max_id + n + 1))

    def check_consistency(
        self,
        db_conn: Connection,
        table: str,
        id_column: str,
        stream_name: Optional[str] = None,
        positive: bool = True,
    ) -> None:
        """No-op: all arguments are ignored."""
        # There is nothing to do for in memory sequences
        pass
250
251
def build_sequence_generator(
    db_conn: "LoggingDatabaseConnection",
    database_engine: BaseDatabaseEngine,
    get_first_callback: GetFirstCallbackType,
    sequence_name: str,
    table: Optional[str],
    id_column: Optional[str],
    stream_name: Optional[str] = None,
    positive: bool = True,
) -> SequenceGenerator:
    """Construct the SequenceGenerator implementation suited to the database.

    A PostgresSequenceGenerator is returned when `database_engine` is a
    PostgresEngine; any other engine gets a LocalSequenceGenerator, which
    relies on in-process locking.

    Args:
        db_conn: connection handed to `check_consistency` when `table` is set.
        database_engine: the database engine we are connected to.
        get_first_callback: callback given to the LocalSequenceGenerator to
            look up the current maximum ID. Unused on postgres.
        sequence_name: name of the postgres sequence to use. Unused on other
            engines.
        table, id_column, stream_name, positive: when `table` is set,
            `check_consistency` is called on the new generator with these
            values (and `id_column` must then be set too). See the
            `check_consistency` docstring for their meaning.

    Returns:
        The newly created sequence generator.
    """
    generator: SequenceGenerator
    if not isinstance(database_engine, PostgresEngine):
        generator = LocalSequenceGenerator(get_first_callback)
    else:
        generator = PostgresSequenceGenerator(sequence_name)

    # Without a table there is nothing to check the sequence against.
    if not table:
        return generator

    assert id_column
    generator.check_consistency(
        db_conn=db_conn,
        table=table,
        id_column=id_column,
        stream_name=stream_name,
        positive=positive,
    )
    return generator
292