# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

from synapse.storage.engines._base import BaseDatabaseEngine, IncorrectDatabaseSetup
from synapse.storage.types import Connection

logger = logging.getLogger(__name__)


class PostgresEngine(BaseDatabaseEngine):
    """Database engine for PostgreSQL.

    `database_module` is presumably psycopg2 (the code relies on its
    `extensions` submodule and the `server_version` / `pgcode` attributes);
    verify against callers.
    """

    def __init__(self, database_module, database_config):
        super().__init__(database_module, database_config)
        self.module.extensions.register_type(self.module.extensions.UNICODE)

        # Disables passing `bytes` to txn.execute, c.f. #6186. If you do
        # actually want to use bytes then wrap it in `bytearray`.
        def _disable_bytes_adapter(_):
            raise Exception("Passing bytes to DB is disabled.")

        self.module.extensions.register_adapter(bytes, _disable_bytes_adapter)
        # When false, on_new_connection turns off synchronous_commit per session.
        self.synchronous_commit = database_config.get("synchronous_commit", True)
        # Numeric server version (e.g. 90605); populated by check_database().
        self._version = None  # unknown as yet

    @property
    def single_threaded(self) -> bool:
        """PostgreSQL connections may be used from multiple threads."""
        return False

    def check_database(self, db_conn, allow_outdated_version: bool = False):
        """Validate the server version, encoding, collation and ctype.

        Args:
            db_conn: an open database connection.
            allow_outdated_version: if True, skip the minimum-version check.

        Raises:
            RuntimeError: if the server is older than 9.6 and
                `allow_outdated_version` is False.
            IncorrectDatabaseSetup: if the server encoding is not UTF8.

        A collation or ctype other than 'C' is only logged as a warning here;
        see check_new_database for the strict variant.
        """
        # Get the version of PostgreSQL that we're using. As per the psycopg2
        # docs: The number is formed by converting the major, minor, and
        # revision numbers into two-decimal-digit numbers and appending them
        # together. For example, version 8.1.5 will be returned as 80105
        self._version = db_conn.server_version

        # Are we on a supported PostgreSQL version?
        if not allow_outdated_version and self._version < 90600:
            raise RuntimeError("Synapse requires PostgreSQL 9.6 or above.")

        with db_conn.cursor() as txn:
            txn.execute("SHOW SERVER_ENCODING")
            rows = txn.fetchall()
            if rows and rows[0][0] != "UTF8":
                raise IncorrectDatabaseSetup(
                    "Database has incorrect encoding: '%s' instead of 'UTF8'\n"
                    "See docs/postgres.md for more information." % (rows[0][0],)
                )

            txn.execute(
                "SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()"
            )
            collation, ctype = txn.fetchone()
            if collation != "C":
                logger.warning(
                    "Database has incorrect collation of %r. Should be 'C'\n"
                    "See docs/postgres.md for more information.",
                    collation,
                )

            if ctype != "C":
                logger.warning(
                    "Database has incorrect ctype of %r. Should be 'C'\n"
                    "See docs/postgres.md for more information.",
                    ctype,
                )

    def check_new_database(self, txn):
        """Gets called when setting up a brand new database. This allows us to
        apply stricter checks on new databases versus existing database.

        Raises:
            IncorrectDatabaseSetup: if the database's COLLATE or CTYPE is not
                'C'. Every offending setting is listed in the message.
        """

        txn.execute(
            "SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()"
        )
        collation, ctype = txn.fetchone()

        errors = []

        if collation != "C":
            errors.append("    - 'COLLATE' is set to %r. Should be 'C'" % (collation,))

        if ctype != "C":
            errors.append("    - 'CTYPE' is set to %r. Should be 'C'" % (ctype,))

        if errors:
            # Pass a one-element tuple so the joined text is always treated
            # as a single %s argument.
            raise IncorrectDatabaseSetup(
                "Database is incorrectly configured:\n\n%s\n\n"
                "See docs/postgres.md for more information." % ("\n".join(errors),)
            )

    def convert_param_style(self, sql):
        """Translate `?` placeholders into the `%s` style used by the driver."""
        return sql.replace("?", "%s")

    def on_new_connection(self, db_conn):
        """Configure a freshly opened connection.

        Sets REPEATABLE READ isolation and escape-format bytea output. It also
        disables synchronous_commit when the engine was configured with it
        off. The session changes are then committed.
        """
        db_conn.set_isolation_level(
            self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
        )

        # Use the cursor as a context manager (as check_database does) so it
        # is closed even if one of the statements fails.
        with db_conn.cursor() as cursor:
            # Set the bytea output to escape, vs the default of hex
            cursor.execute("SET bytea_output TO escape")

            # Asynchronous commit, don't wait for the server to call fsync before
            # ending the transaction.
            # https://www.postgresql.org/docs/current/static/wal-async-commit.html
            if not self.synchronous_commit:
                cursor.execute("SET synchronous_commit TO OFF")

        db_conn.commit()

    @property
    def can_native_upsert(self):
        """
        Can we use native UPSERTs?
        """
        return True

    @property
    def supports_using_any_list(self):
        """Do we support using `a = ANY(?)` and passing a list"""
        return True

    @property
    def supports_returning(self) -> bool:
        """Do we support the `RETURNING` clause in insert/update/delete?"""
        return True

    def is_deadlock(self, error):
        """Return True if `error` is a retryable serialization failure or deadlock."""
        if isinstance(error, self.module.DatabaseError):
            # https://www.postgresql.org/docs/current/static/errcodes-appendix.html
            # "40001" serialization_failure
            # "40P01" deadlock_detected
            return error.pgcode in ["40001", "40P01"]
        return False

    def is_connection_closed(self, conn):
        """Return True if the connection's `closed` attribute is truthy."""
        return bool(conn.closed)

    def lock_table(self, txn, table):
        """Take an EXCLUSIVE lock on `table` for the rest of the transaction.

        `table` is interpolated directly into the SQL, so it must be a trusted
        identifier, never user input.
        """
        txn.execute("LOCK TABLE %s in EXCLUSIVE MODE" % (table,))

    @property
    def server_version(self):
        """Returns a string giving the server version. For example: '8.1.5'

        Returns:
            string
        """
        # note that this is a bit of a hack because it relies on check_database
        # having been called. Still, that should be a safe bet here.
        numver = self._version
        assert numver is not None

        # Decode with integer (floor) division rather than relying on "%i"
        # to truncate floats.
        # https://www.postgresql.org/docs/current/libpq-status.html#LIBPQ-PQSERVERVERSION
        if numver >= 100000:
            # PostgreSQL 10+: major * 10000 + minor
            return "%i.%i" % (numver // 10000, numver % 10000)
        else:
            # Pre-10: major * 10000 + minor * 100 + revision
            return "%i.%i.%i" % (
                numver // 10000,
                (numver % 10000) // 100,
                numver % 100,
            )

    def in_transaction(self, conn: Connection) -> bool:
        """Return True unless the connection's status is STATUS_READY."""
        return conn.status != self.module.extensions.STATUS_READY  # type: ignore

    def attempt_to_set_autocommit(self, conn: Connection, autocommit: bool):
        """Enable or disable autocommit mode on `conn` via set_session."""
        return conn.set_session(autocommit=autocommit)  # type: ignore