# Copyright 2020 Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Tuple, Union, cast

import attr

from synapse.api.constants import LoginType
from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import LoggingTransaction
from synapse.types import JsonDict
from synapse.util import json_encoder, stringutils

# How many random session IDs to try before giving up on creating a session.
_MAX_SESSION_ID_ATTEMPTS = 5


@attr.s(slots=True)
class UIAuthSessionData:
    session_id = attr.ib(type=str)
    # The dictionary from the client root level, not the 'auth' key.
    clientdict = attr.ib(type=JsonDict)
    # The URI and method the session was initiated with. These are checked at
    # each stage of the authentication to ensure that the asked for operation
    # has not changed.
    uri = attr.ib(type=str)
    method = attr.ib(type=str)
    # A string description of the operation that the current authentication is
    # authorising.
    description = attr.ib(type=str)


class UIAuthWorkerStore(SQLBaseStore):
    """
    Manage user interactive authentication sessions.
    """

    async def create_ui_auth_session(
        self,
        clientdict: JsonDict,
        uri: str,
        method: str,
        description: str,
    ) -> UIAuthSessionData:
        """
        Creates a new user interactive authentication session.

        The session can be used to track the stages necessary to authenticate a
        user across multiple HTTP requests.

        Args:
            clientdict:
                The dictionary from the client root level, not the 'auth' key.
            uri:
                The URI this session was initiated with, this is checked at each
                stage of the authentication to ensure that the asked for
                operation has not changed.
            method:
                The method this session was initiated with, this is checked at each
                stage of the authentication to ensure that the asked for
                operation has not changed.
            description:
                A string description of the operation that the current
                authentication is authorising.
        Returns:
            The newly created session.
        Raises:
            StoreError if a unique session ID cannot be generated.
        """
        # The clientdict gets stored as JSON.
        clientdict_json = json_encoder.encode(clientdict)

        # Autogenerate a session ID and try to create it. We may clash with an
        # existing ID, so try a few times until one goes through, giving up
        # eventually.
        for _ in range(_MAX_SESSION_ID_ATTEMPTS):
            session_id = stringutils.random_string(24)

            try:
                await self.db_pool.simple_insert(
                    table="ui_auth_sessions",
                    values={
                        "session_id": session_id,
                        "clientdict": clientdict_json,
                        "uri": uri,
                        "method": method,
                        "description": description,
                        "serverdict": "{}",
                        "creation_time": self.hs.get_clock().time_msec(),
                    },
                    desc="create_ui_auth_session",
                )
            except self.db_pool.engine.module.IntegrityError:
                # The insert was rejected; retry with a fresh random ID.
                continue

            return UIAuthSessionData(session_id, clientdict, uri, method, description)

        raise StoreError(500, "Couldn't generate a session ID.")

    async def get_ui_auth_session(self, session_id: str) -> UIAuthSessionData:
        """Retrieve a UI auth session.

        Args:
            session_id: The ID of the session.
        Returns:
            The stored session, with its clientdict decoded from JSON.
        Raises:
            StoreError if the session is not found.
        """
        row = await self.db_pool.simple_select_one(
            table="ui_auth_sessions",
            keyvalues={"session_id": session_id},
            retcols=("clientdict", "uri", "method", "description"),
            desc="get_ui_auth_session",
        )

        return UIAuthSessionData(
            session_id,
            clientdict=db_to_json(row["clientdict"]),
            uri=row["uri"],
            method=row["method"],
            description=row["description"],
        )

    async def mark_ui_auth_stage_complete(
        self,
        session_id: str,
        stage_type: str,
        result: Union[str, bool, JsonDict],
    ) -> None:
        """
        Mark a session stage as completed.

        Args:
            session_id: The ID of the corresponding session.
            stage_type: The completed stage type.
            result: The result of the stage verification.
        Raises:
            StoreError (400) if the database rejects the write with an
            integrity error, which is reported as an unknown session ID.
        """
        # Add (or update) the results of the current stage to the database.
        #
        # Note that we need to allow for the same stage to complete multiple
        # times here so that registration is idempotent.
        try:
            await self.db_pool.simple_upsert(
                table="ui_auth_sessions_credentials",
                keyvalues={"session_id": session_id, "stage_type": stage_type},
                values={"result": json_encoder.encode(result)},
                desc="mark_ui_auth_stage_complete",
            )
        except self.db_pool.engine.module.IntegrityError as e:
            raise StoreError(400, "Unknown session ID: %s" % (session_id,)) from e

    async def get_completed_ui_auth_stages(
        self, session_id: str
    ) -> Dict[str, Union[str, bool, JsonDict]]:
        """
        Retrieve the completed stages of a UI authentication session.

        Args:
            session_id: The ID of the session.
        Returns:
            The completed stages mapped to the result of the verification of
            that auth-type.
        """
        rows = await self.db_pool.simple_select_list(
            table="ui_auth_sessions_credentials",
            keyvalues={"session_id": session_id},
            retcols=("stage_type", "result"),
            desc="get_completed_ui_auth_stages",
        )

        # Stage results are stored as JSON, so decode each one.
        return {row["stage_type"]: db_to_json(row["result"]) for row in rows}

    async def set_ui_auth_clientdict(
        self, session_id: str, clientdict: JsonDict
    ) -> None:
        """
        Store an updated clientdict for a given session ID.

        Args:
            session_id: The ID of this session as returned from check_auth
            clientdict:
                The dictionary from the client root level, not the 'auth' key.
        """
        # The clientdict gets stored as JSON.
        clientdict_json = json_encoder.encode(clientdict)

        await self.db_pool.simple_update_one(
            table="ui_auth_sessions",
            keyvalues={"session_id": session_id},
            updatevalues={"clientdict": clientdict_json},
            desc="set_ui_auth_client_dict",
        )

    async def set_ui_auth_session_data(
        self, session_id: str, key: str, value: Any
    ) -> None:
        """
        Store a key-value pair into the sessions data associated with this
        request. This data is stored server-side and cannot be modified by
        the client.

        Args:
            session_id: The ID of this session as returned from check_auth
            key: The key to store the data under
            value: The data to store
        Raises:
            StoreError if the session cannot be found.
        """
        await self.db_pool.runInteraction(
            "set_ui_auth_session_data",
            self._set_ui_auth_session_data_txn,
            session_id,
            key,
            value,
        )

    def _set_ui_auth_session_data_txn(
        self, txn: LoggingTransaction, session_id: str, key: str, value: Any
    ) -> None:
        """Read the session's serverdict, set `key` to `value`, and write it back.

        The read and the write both use the supplied transaction.
        """
        # Get the current value.
        result = cast(
            Dict[str, Any],
            self.db_pool.simple_select_one_txn(
                txn,
                table="ui_auth_sessions",
                keyvalues={"session_id": session_id},
                retcols=("serverdict",),
            ),
        )

        # Update it and add it back to the database.
        serverdict = db_to_json(result["serverdict"])
        serverdict[key] = value

        self.db_pool.simple_update_one_txn(
            txn,
            table="ui_auth_sessions",
            keyvalues={"session_id": session_id},
            updatevalues={"serverdict": json_encoder.encode(serverdict)},
        )

    async def get_ui_auth_session_data(
        self, session_id: str, key: str, default: Optional[Any] = None
    ) -> Any:
        """
        Retrieve data stored with set_ui_auth_session_data

        Args:
            session_id: The ID of this session as returned from check_auth
            key: The key the data was stored under
            default: Value to return if the key has not been set
        Returns:
            The value stored under `key` in the session's server-side data, or
            `default` if the key is not present.
        Raises:
            StoreError if the session cannot be found.
        """
        result = await self.db_pool.simple_select_one(
            table="ui_auth_sessions",
            keyvalues={"session_id": session_id},
            retcols=("serverdict",),
            desc="get_ui_auth_session_data",
        )

        serverdict = db_to_json(result["serverdict"])

        return serverdict.get(key, default)

    async def add_user_agent_ip_to_ui_auth_session(
        self,
        session_id: str,
        user_agent: str,
        ip: str,
    ) -> None:
        """Add the given user agent / IP to the tracking table"""
        await self.db_pool.simple_upsert(
            table="ui_auth_sessions_ips",
            keyvalues={"session_id": session_id, "user_agent": user_agent, "ip": ip},
            values={},
            desc="add_user_agent_ip_to_ui_auth_session",
        )

    async def get_user_agents_ips_to_ui_auth_session(
        self,
        session_id: str,
    ) -> List[Tuple[str, str]]:
        """Get the given user agents / IPs used during the ui auth process

        Returns:
            List of user_agent/ip pairs
        """
        rows = await self.db_pool.simple_select_list(
            table="ui_auth_sessions_ips",
            keyvalues={"session_id": session_id},
            retcols=("user_agent", "ip"),
            desc="get_user_agents_ips_to_ui_auth_session",
        )
        return [(row["user_agent"], row["ip"]) for row in rows]

    async def delete_old_ui_auth_sessions(self, expiration_time: int) -> None:
        """
        Remove sessions which were created at or before the expiration time.

        Args:
            expiration_time: Sessions whose creation time is less than or
                equal to this are removed. This is an epoch time in
                milliseconds.

        """
        await self.db_pool.runInteraction(
            "delete_old_ui_auth_sessions",
            self._delete_old_ui_auth_sessions_txn,
            expiration_time,
        )

    def _delete_old_ui_auth_sessions_txn(
        self, txn: LoggingTransaction, expiration_time: int
    ) -> None:
        """Delete expired sessions and their dependent rows in one transaction.

        Before the credentials are removed, the `pending` counter of any
        registration token that was used by an expired session but never
        resulted in a completed registration is decremented.
        """
        # Get the expired sessions.
        sql = "SELECT session_id FROM ui_auth_sessions WHERE creation_time <= ?"
        txn.execute(sql, [expiration_time])
        session_ids = [r[0] for r in txn.fetchall()]

        # Nothing has expired, so there is nothing to clean up.
        if not session_ids:
            return

        # Delete the corresponding IP/user agents.
        self.db_pool.simple_delete_many_txn(
            txn,
            table="ui_auth_sessions_ips",
            column="session_id",
            values=session_ids,
            keyvalues={},
        )

        # If a registration token was used, decrement the pending counter
        # before deleting the session.
        rows = self.db_pool.simple_select_many_txn(
            txn,
            table="ui_auth_sessions_credentials",
            column="session_id",
            iterable=session_ids,
            keyvalues={"stage_type": LoginType.REGISTRATION_TOKEN},
            retcols=["result"],
        )

        # Get the tokens used and how much pending needs to be decremented by.
        token_counts: Dict[str, int] = {}
        for r in rows:
            # If registration was successfully completed, the result of the
            # registration token stage for that session will be True.
            # If a token was used to authenticate, but registration was
            # never completed, the result will be the token used.
            stage_result = db_to_json(r["result"])
            if isinstance(stage_result, str):
                token_counts[stage_result] = token_counts.get(stage_result, 0) + 1

        # Update the `pending` counters.
        if token_counts:
            token_rows = self.db_pool.simple_select_many_txn(
                txn,
                table="registration_tokens",
                column="token",
                iterable=list(token_counts),
                keyvalues={},
                retcols=["token", "pending"],
            )
            for token_row in token_rows:
                token = token_row["token"]
                new_pending = token_row["pending"] - token_counts[token]
                self.db_pool.simple_update_one_txn(
                    txn,
                    table="registration_tokens",
                    keyvalues={"token": token},
                    updatevalues={"pending": new_pending},
                )

        # Delete the corresponding completed credentials.
        self.db_pool.simple_delete_many_txn(
            txn,
            table="ui_auth_sessions_credentials",
            column="session_id",
            values=session_ids,
            keyvalues={},
        )

        # Finally, delete the sessions.
        self.db_pool.simple_delete_many_txn(
            txn,
            table="ui_auth_sessions",
            column="session_id",
            values=session_ids,
            keyvalues={},
        )


class UIAuthStore(UIAuthWorkerStore):
    pass