# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import logging
import threading
from typing import TYPE_CHECKING, Callable, List, Optional

from synapse.storage.engines import (
    BaseDatabaseEngine,
    IncorrectDatabaseSetup,
    PostgresEngine,
)
from synapse.storage.types import Connection, Cursor

if TYPE_CHECKING:
    from synapse.storage.database import LoggingDatabaseConnection

logger = logging.getLogger(__name__)


# Error raised when a postgres sequence is behind the max ID in its table.
# Interpolated with `seq`, `table` and `max_id_sql` (the query that computes
# the table's max ID, so the admin can feed it straight into `setval`).
_INCONSISTENT_SEQUENCE_ERROR = """
Postgres sequence '%(seq)s' is inconsistent with associated
table '%(table)s'. This can happen if Synapse has been downgraded and
then upgraded again, or due to a bad migration.

To fix this error, shut down Synapse (including any and all workers)
and run the following SQL:

    SELECT setval('%(seq)s', (
        %(max_id_sql)s
    ));

See docs/postgres.md for more information.
"""

# Error raised when the `stream_positions` table records a position for a
# stream that is ahead of the stream's postgres sequence. Interpolated with
# `seq` and `stream_name`.
_INCONSISTENT_STREAM_ERROR = """
Postgres sequence '%(seq)s' is inconsistent with associated stream position
of '%(stream_name)s' in the 'stream_positions' table.

This is likely a programming error and should be reported at
https://github.com/matrix-org/synapse.

A temporary workaround to fix this error is to shut down Synapse (including
any and all workers) and run the following SQL:

    DELETE FROM stream_positions WHERE stream_name = '%(stream_name)s';

This will need to be done every time the server is restarted.
"""


class SequenceGenerator(metaclass=abc.ABCMeta):
    """A class which generates a unique sequence of integers"""

    @abc.abstractmethod
    def get_next_id_txn(self, txn: Cursor) -> int:
        """Gets the next ID in the sequence

        Args:
            txn: the database cursor to run any required queries on.

        Returns:
            The newly allocated ID.
        """
        ...

    @abc.abstractmethod
    def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
        """Get the next `n` IDs in the sequence

        Args:
            txn: the database cursor to run any required queries on.
            n: how many IDs to allocate.

        Returns:
            The newly allocated IDs.
        """
        ...

    @abc.abstractmethod
    def check_consistency(
        self,
        db_conn: "LoggingDatabaseConnection",
        table: str,
        id_column: str,
        stream_name: Optional[str] = None,
        positive: bool = True,
    ) -> None:
        """Should be called during start up to test that the current value of
        the sequence is greater than or equal to the maximum ID in the table.

        This is to handle various cases where the sequence value can get out of
        sync with the table, e.g. if Synapse gets rolled back to a previous
        version and then rolled forwards again.

        If a stream name is given then this will check that any value in the
        `stream_positions` table is less than or equal to the current sequence
        value. If it isn't then it's likely that streams have been crossed
        somewhere (e.g. two ID generators have the same stream name).

        Args:
            db_conn: the database connection to run the checks on.
            table: the table whose IDs are drawn from this sequence.
            id_column: the column of `table` holding those IDs.
            stream_name: if given, the name of the stream in the
                `stream_positions` table to also check against.
            positive: False if the IDs stored in `table` are negative (i.e.
                the stream counts downwards), True otherwise.

        Raises:
            IncorrectDatabaseSetup: (implementation-dependent) if an
                inconsistency is detected.
        """
        ...


class PostgresSequenceGenerator(SequenceGenerator):
    """An implementation of SequenceGenerator which uses a postgres sequence"""

    def __init__(self, sequence_name: str):
        self._sequence_name = sequence_name

    def get_next_id_txn(self, txn: Cursor) -> int:
        txn.execute("SELECT nextval(?)", (self._sequence_name,))
        fetch_res = txn.fetchone()
        assert fetch_res is not None
        return fetch_res[0]

    def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
        txn.execute(
            "SELECT nextval(?) FROM generate_series(1, ?)", (self._sequence_name, n)
        )
        return [i for (i,) in txn]

    def check_consistency(
        self,
        db_conn: "LoggingDatabaseConnection",
        table: str,
        id_column: str,
        stream_name: Optional[str] = None,
        positive: bool = True,
    ) -> None:
        """See SequenceGenerator.check_consistency for docstring.

        Raises:
            IncorrectDatabaseSetup: if the sequence is behind the max ID in
                `table`, or behind the `stream_positions` entry for
                `stream_name`.
        """

        txn = db_conn.cursor(txn_name="sequence.check_consistency")

        # First we get the current max ID from the table. For negative
        # streams we negate MIN(id) so the comparison below works the same
        # way. GREATEST(..., 0) maps the NULL aggregate of an empty table to 0.
        table_sql = "SELECT GREATEST(%(agg)s(%(id)s), 0) FROM %(table)s" % {
            "id": id_column,
            "table": table,
            "agg": "MAX" if positive else "-MIN",
        }

        # Run all the queries inside try/finally so the cursor is closed even
        # if one of them raises.
        try:
            txn.execute(table_sql)
            row = txn.fetchone()
            if not row:
                # An aggregate query always returns a row, so this is purely
                # defensive.
                return

            # Now we fetch the current value from the sequence and compare
            # with the above.
            max_stream_id = row[0]
            txn.execute(
                "SELECT last_value, is_called FROM %(seq)s"
                % {"seq": self._sequence_name}
            )
            fetch_res = txn.fetchone()
            assert fetch_res is not None
            last_value, is_called = fetch_res

            # If we have an associated stream check the stream_positions table.
            max_in_stream_positions = None
            if stream_name:
                txn.execute(
                    "SELECT MAX(stream_id) FROM stream_positions WHERE stream_name = ?",
                    (stream_name,),
                )
                row = txn.fetchone()
                if row:
                    # MAX() yields NULL (None) when there are no rows.
                    max_in_stream_positions = row[0]
        finally:
            txn.close()

        # If `is_called` is False then `last_value` is actually the value that
        # will be generated next, so we decrement to get the true "last value".
        if not is_called:
            last_value -= 1

        if max_stream_id > last_value:
            logger.warning(
                "Postgres sequence %s is behind table %s: %d < %d",
                self._sequence_name,
                table,
                last_value,
                max_stream_id,
            )
            raise IncorrectDatabaseSetup(
                _INCONSISTENT_SEQUENCE_ERROR
                % {"seq": self._sequence_name, "table": table, "max_id_sql": table_sql}
            )

        # If we have values in the stream positions table then they have to be
        # less than or equal to `last_value`. Compare against None explicitly
        # rather than relying on int truthiness.
        if (
            max_in_stream_positions is not None
            and max_in_stream_positions > last_value
        ):
            raise IncorrectDatabaseSetup(
                _INCONSISTENT_STREAM_ERROR
                % {"seq": self._sequence_name, "stream_name": stream_name}
            )


# Callback used by LocalSequenceGenerator: given a cursor, returns the current
# maximum ID.
GetFirstCallbackType = Callable[[Cursor], int]


class LocalSequenceGenerator(SequenceGenerator):
    """An implementation of SequenceGenerator which uses local locking

    This only works reliably if there are no other worker processes generating IDs at
    the same time.
    """

    def __init__(self, get_first_callback: GetFirstCallbackType):
        """
        Args:
            get_first_callback: a callback which is called on the first call to
                get_next_id_txn or get_next_mult_txn; should return the current
                maximum id
        """
        # the callback. this is cleared after it is called, so that it can be GCed.
        self._callback: Optional[GetFirstCallbackType] = get_first_callback

        # The current max value, or None if we haven't looked in the DB yet.
        self._current_max_id: Optional[int] = None
        self._lock = threading.Lock()

    def _get_current_max_id_locked(self, txn: Cursor) -> int:
        """Return the current max ID, loading it via the callback on first use.

        Must be called with `self._lock` held.
        """
        if self._current_max_id is None:
            assert self._callback is not None
            self._current_max_id = self._callback(txn)
            self._callback = None
        return self._current_max_id

    def get_next_id_txn(self, txn: Cursor) -> int:
        # We do application locking here since if we're using sqlite then
        # we are a single process synapse.
        with self._lock:
            self._current_max_id = self._get_current_max_id_locked(txn) + 1
            return self._current_max_id

    def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
        with self._lock:
            current_max_id = self._get_current_max_id_locked(txn)
            first_id = current_max_id + 1
            self._current_max_id = current_max_id + n
            return [first_id + i for i in range(n)]

    def check_consistency(
        self,
        db_conn: Connection,
        table: str,
        id_column: str,
        stream_name: Optional[str] = None,
        positive: bool = True,
    ) -> None:
        # There is nothing to do for in memory sequences
        pass


def build_sequence_generator(
    db_conn: "LoggingDatabaseConnection",
    database_engine: BaseDatabaseEngine,
    get_first_callback: GetFirstCallbackType,
    sequence_name: str,
    table: Optional[str],
    id_column: Optional[str],
    stream_name: Optional[str] = None,
    positive: bool = True,
) -> SequenceGenerator:
    """Get the best impl of SequenceGenerator available

    This uses PostgresSequenceGenerator on postgres, and a locally-locked impl on
    sqlite.

    Args:
        db_conn: the database connection used to run `check_consistency`.
        database_engine: the database engine we are connected to
        get_first_callback: a callback which returns the current maximum ID.
            Only used if we're not on postgres (i.e. sqlite), and only called
            on the first request for an ID.
        sequence_name: the name of a postgres sequence to use.
        table, id_column, stream_name, positive: If set then `check_consistency`
            is called on the created sequence. See docstring for
            `check_consistency` details.
    """
    if isinstance(database_engine, PostgresEngine):
        seq: SequenceGenerator = PostgresSequenceGenerator(sequence_name)
    else:
        seq = LocalSequenceGenerator(get_first_callback)

    if table:
        assert id_column
        seq.check_consistency(
            db_conn=db_conn,
            table=table,
            id_column=id_column,
            stream_name=stream_name,
            positive=positive,
        )

    return seq