# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum
from typing import (
    TYPE_CHECKING,
    Any,
    Collection,
    Dict,
    Iterable,
    List,
    Optional,
    Tuple,
    Union,
    cast,
)

from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
    DatabasePool,
    LoggingDatabaseConnection,
    LoggingTransaction,
)
from synapse.types import JsonDict, UserID

if TYPE_CHECKING:
    from synapse.server import HomeServer

BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD = (
    "media_repository_drop_index_wo_method"
)
BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2 = (
    "media_repository_drop_index_wo_method_2"
)


class MediaSortOrder(Enum):
    """
    Enum to define the sorting method used when returning media with
    get_local_media_by_user_paginate

    Each value is used verbatim as the column name in the ORDER BY clause of
    that query.
    """

    MEDIA_ID = "media_id"
    UPLOAD_NAME = "upload_name"
    CREATED_TS = "created_ts"
    LAST_ACCESS_TS = "last_access_ts"
    MEDIA_LENGTH = "media_length"
    MEDIA_TYPE = "media_type"
    QUARANTINED_BY = "quarantined_by"
    SAFE_FROM_QUARANTINE = "safe_from_quarantine"


class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
    """Registers the background updates that apply to the media tables."""

    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)

        self.db_pool.updates.register_background_index_update(
            update_name="local_media_repository_url_idx",
            index_name="local_media_repository_url_idx",
            table="local_media_repository",
            columns=["created_ts"],
            where_clause="url_cache IS NOT NULL",
        )

        # The following updates add the method to the unique constraint of
        # the thumbnail databases. That fixes an issue, where thumbnails of the
        # same resolution, but different methods could overwrite one another.
        # This can happen with custom thumbnail configs or with dynamic thumbnailing.
        self.db_pool.updates.register_background_index_update(
            update_name="local_media_repository_thumbnails_method_idx",
            index_name="local_media_repository_thumbn_media_id_width_height_method_key",
            table="local_media_repository_thumbnails",
            columns=[
                "media_id",
                "thumbnail_width",
                "thumbnail_height",
                "thumbnail_type",
                "thumbnail_method",
            ],
            unique=True,
        )

        self.db_pool.updates.register_background_index_update(
            update_name="remote_media_repository_thumbnails_method_idx",
            index_name="remote_media_repository_thumbn_media_origin_id_width_height_method_key",
            table="remote_media_cache_thumbnails",
            columns=[
                "media_origin",
                "media_id",
                "thumbnail_width",
                "thumbnail_height",
                "thumbnail_type",
                "thumbnail_method",
            ],
            unique=True,
        )

        # the original impl of _drop_media_index_without_method was broken (see
        # https://github.com/matrix-org/synapse/issues/8649), so we replace the original
        # impl with a no-op and run the fixed migration as
        # media_repository_drop_index_wo_method_2.
        self.db_pool.updates.register_noop_background_update(
            BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD
        )
        self.db_pool.updates.register_background_update_handler(
            BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2,
            self._drop_media_index_without_method,
        )

    async def _drop_media_index_without_method(
        self, progress: JsonDict, batch_size: int
    ) -> int:
        """background update handler which removes the old constraints.

        Note that this is only run on postgres.

        Args:
            progress: unused; the update completes in a single pass.
            batch_size: unused; the update completes in a single pass.

        Returns:
            Always 1.
        """

        def f(txn: LoggingTransaction) -> None:
            txn.execute(
                "ALTER TABLE local_media_repository_thumbnails DROP CONSTRAINT IF EXISTS local_media_repository_thumbn_media_id_thumbnail_width_thum_key"
            )
            txn.execute(
                "ALTER TABLE remote_media_cache_thumbnails DROP CONSTRAINT IF EXISTS remote_media_cache_thumbnails_media_origin_media_id_thumbna_key"
            )

        await self.db_pool.runInteraction("drop_media_indices_without_method", f)
        # Both constraints are dropped in one transaction, so the update is
        # marked as finished straight away.
        await self.db_pool.updates._end_background_update(
            BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2
        )
        return 1


class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
    """Persistence for attachments and avatars"""

    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)
        # Used to build the "mxc://<server_name>/" prefix of local media URIs.
        self.server_name = hs.hostname

    async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]:
        """Get the metadata for a local piece of media

        Args:
            media_id: the ID of the media to look up

        Returns:
            None if the media_id doesn't exist.
        """
        return await self.db_pool.simple_select_one(
            "local_media_repository",
            {"media_id": media_id},
            (
                "media_type",
                "media_length",
                "upload_name",
                "created_ts",
                "quarantined_by",
                "url_cache",
                "safe_from_quarantine",
            ),
            allow_none=True,
            desc="get_local_media",
        )

    async def get_local_media_by_user_paginate(
        self,
        start: int,
        limit: int,
        user_id: str,
        order_by: str = MediaSortOrder.CREATED_TS.value,
        direction: str = "f",
    ) -> Tuple[List[Dict[str, Any]], int]:
        """Get a paginated list of metadata for a local piece of media
        which an user_id has uploaded

        Args:
            start: offset in the list
            limit: maximum amount of media_ids to retrieve
            user_id: fully-qualified user id
            order_by: the sort order of the returned list; must be one of the
                values of MediaSortOrder
            direction: sort ascending or descending; "b" sorts descending,
                any other value sorts ascending
        Returns:
            A paginated list of all metadata of user's media,
            plus the total count of all the user's media
        Raises:
            ValueError: if order_by is not a value of MediaSortOrder
        """

        def get_local_media_by_user_paginate_txn(
            txn: LoggingTransaction,
        ) -> Tuple[List[Dict[str, Any]], int]:

            # Set ordering. Round-tripping through the enum guarantees that
            # only a known column name is interpolated into the SQL below.
            order_by_column = MediaSortOrder(order_by).value

            if direction == "b":
                order = "DESC"
            else:
                order = "ASC"

            args: List[Union[str, int]] = [user_id]
            sql = """
                SELECT COUNT(*) as total_media
                FROM local_media_repository
                WHERE user_id = ?
            """
            txn.execute(sql, args)
            count = cast(Tuple[int], txn.fetchone())[0]

            # media_id is appended as a secondary sort key so that the order
            # stays stable when the primary column has duplicate values.
            sql = """
                SELECT
                    "media_id",
                    "media_type",
                    "media_length",
                    "upload_name",
                    "created_ts",
                    "last_access_ts",
                    "quarantined_by",
                    "safe_from_quarantine"
                FROM local_media_repository
                WHERE user_id = ?
                ORDER BY {order_by_column} {order}, media_id ASC
                LIMIT ? OFFSET ?
            """.format(
                order_by_column=order_by_column,
                order=order,
            )

            args += [limit, start]
            txn.execute(sql, args)
            media = self.db_pool.cursor_to_dict(txn)
            return media, count

        return await self.db_pool.runInteraction(
            "get_local_media_by_user_paginate_txn", get_local_media_by_user_paginate_txn
        )

    async def get_local_media_before(
        self,
        before_ts: int,
        size_gt: int,
        keep_profiles: bool,
    ) -> List[str]:
        """Get the IDs of local media that has not been accessed recently

        Args:
            before_ts: select media whose last_access_ts is below this value,
                or which has never been accessed and whose created_ts is
                below this value
            size_gt: select only media whose media_length is greater than this
            keep_profiles: if True, skip media that is referenced as an avatar
                in the profiles, groups, room_memberships, user_directory or
                room_stats_state tables

        Returns:
            The media_ids of the matching rows.
        """

        # to find files that have never been accessed (last_access_ts IS NULL)
        # compare with `created_ts`
        sql = """
            SELECT media_id
            FROM local_media_repository AS lmr
            WHERE
                ( last_access_ts < ?
                OR ( created_ts < ? AND last_access_ts IS NULL ) )
                AND media_length > ?
        """

        if keep_profiles:
            # NOTE(review): the prefix is interpolated directly into the SQL.
            # It is built from the homeserver's own hostname rather than from
            # request data, so it is presumably trusted.
            sql_keep = """
                AND (
                    NOT EXISTS
                        (SELECT 1
                         FROM profiles
                         WHERE profiles.avatar_url = '{media_prefix}' || lmr.media_id)
                    AND NOT EXISTS
                        (SELECT 1
                         FROM groups
                         WHERE groups.avatar_url = '{media_prefix}' || lmr.media_id)
                    AND NOT EXISTS
                        (SELECT 1
                         FROM room_memberships
                         WHERE room_memberships.avatar_url = '{media_prefix}' || lmr.media_id)
                    AND NOT EXISTS
                        (SELECT 1
                         FROM user_directory
                         WHERE user_directory.avatar_url = '{media_prefix}' || lmr.media_id)
                    AND NOT EXISTS
                        (SELECT 1
                         FROM room_stats_state
                         WHERE room_stats_state.avatar = '{media_prefix}' || lmr.media_id)
                )
            """.format(
                media_prefix="mxc://%s/" % (self.server_name,),
            )
            sql += sql_keep

        def _get_local_media_before_txn(txn: LoggingTransaction) -> List[str]:
            txn.execute(sql, (before_ts, before_ts, size_gt))
            return [row[0] for row in txn]

        return await self.db_pool.runInteraction(
            "get_local_media_before", _get_local_media_before_txn
        )

    async def store_local_media(
        self,
        media_id: str,
        media_type: str,
        time_now_ms: int,
        upload_name: Optional[str],
        media_length: int,
        user_id: UserID,
        url_cache: Optional[str] = None,
    ) -> None:
        """Insert a row describing a local piece of media

        Args:
            media_id: the ID of the media
            media_type: stored in the media_type column
            time_now_ms: stored as the created_ts of the media
            upload_name: stored in the upload_name column, may be None
            media_length: stored in the media_length column
            user_id: the user the media is recorded against
            url_cache: stored in the url_cache column, may be None
        """
        await self.db_pool.simple_insert(
            "local_media_repository",
            {
                "media_id": media_id,
                "media_type": media_type,
                "created_ts": time_now_ms,
                "upload_name": upload_name,
                "media_length": media_length,
                "user_id": user_id.to_string(),
                "url_cache": url_cache,
            },
            desc="store_local_media",
        )

    async def mark_local_media_as_safe(self, media_id: str, safe: bool = True) -> None:
        """Mark a local media as safe or unsafe from quarantining."""
        await self.db_pool.simple_update_one(
            table="local_media_repository",
            keyvalues={"media_id": media_id},
            updatevalues={"safe_from_quarantine": safe},
            desc="mark_local_media_as_safe",
        )

    async def get_url_cache(self, url: str, ts: int) -> Optional[Dict[str, Any]]:
        """Get the media_id and ts for a cached URL as of the given timestamp

        Args:
            url: the URL to look up
            ts: the timestamp to look up the cache entry for

        Returns:
            None if the URL isn't cached. Otherwise a dict with the keys
            response_code, etag, expires_ts, og, media_id and download_ts.
        """

        def get_url_cache_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
            # get the most recently cached result (relative to the given ts)
            sql = (
                "SELECT response_code, etag, expires_ts, og, media_id, download_ts"
                " FROM local_media_repository_url_cache"
                " WHERE url = ? AND download_ts <= ?"
                " ORDER BY download_ts DESC LIMIT 1"
            )
            txn.execute(sql, (url, ts))
            row = txn.fetchone()

            if not row:
                # ...or if we've requested a timestamp older than the oldest
                # copy in the cache, return the oldest copy (if any)
                sql = (
                    "SELECT response_code, etag, expires_ts, og, media_id, download_ts"
                    " FROM local_media_repository_url_cache"
                    " WHERE url = ? AND download_ts > ?"
                    " ORDER BY download_ts ASC LIMIT 1"
                )
                txn.execute(sql, (url, ts))
                row = txn.fetchone()

            if not row:
                return None

            return dict(
                zip(
                    (
                        "response_code",
                        "etag",
                        "expires_ts",
                        "og",
                        "media_id",
                        "download_ts",
                    ),
                    row,
                )
            )

        return await self.db_pool.runInteraction("get_url_cache", get_url_cache_txn)

    async def store_url_cache(
        self,
        url: str,
        response_code: int,
        etag: Optional[str],
        expires_ts: int,
        og: Optional[str],
        media_id: str,
        download_ts: int,
    ) -> None:
        """Insert a row into the URL cache

        Each argument is stored in the column of the same name in
        local_media_repository_url_cache.
        """
        await self.db_pool.simple_insert(
            "local_media_repository_url_cache",
            {
                "url": url,
                "response_code": response_code,
                "etag": etag,
                "expires_ts": expires_ts,
                "og": og,
                "media_id": media_id,
                "download_ts": download_ts,
            },
            desc="store_url_cache",
        )

    async def get_local_media_thumbnails(self, media_id: str) -> List[Dict[str, Any]]:
        """Get the thumbnail rows stored for a local piece of media

        Returns:
            One dict per thumbnail, with the keys thumbnail_width,
            thumbnail_height, thumbnail_method, thumbnail_type and
            thumbnail_length.
        """
        return await self.db_pool.simple_select_list(
            "local_media_repository_thumbnails",
            {"media_id": media_id},
            (
                "thumbnail_width",
                "thumbnail_height",
                "thumbnail_method",
                "thumbnail_type",
                "thumbnail_length",
            ),
            desc="get_local_media_thumbnails",
        )

    async def store_local_thumbnail(
        self,
        media_id: str,
        thumbnail_width: int,
        thumbnail_height: int,
        thumbnail_type: str,
        thumbnail_method: str,
        thumbnail_length: int,
    ) -> None:
        """Upsert the thumbnail row for a local piece of media

        The row is keyed on media_id, width, height, method and type; only
        thumbnail_length is written as a value.
        """
        await self.db_pool.simple_upsert(
            table="local_media_repository_thumbnails",
            keyvalues={
                "media_id": media_id,
                "thumbnail_width": thumbnail_width,
                "thumbnail_height": thumbnail_height,
                "thumbnail_method": thumbnail_method,
                "thumbnail_type": thumbnail_type,
            },
            values={"thumbnail_length": thumbnail_length},
            desc="store_local_thumbnail",
        )

    async def get_cached_remote_media(
        self, origin: str, media_id: str
    ) -> Optional[Dict[str, Any]]:
        """Get the metadata for a cached piece of remote media

        Args:
            origin: the media_origin the media is stored under
            media_id: the ID of the media on that origin

        Returns:
            None if there is no matching row in remote_media_cache.
        """
        return await self.db_pool.simple_select_one(
            "remote_media_cache",
            {"media_origin": origin, "media_id": media_id},
            (
                "media_type",
                "media_length",
                "upload_name",
                "created_ts",
                "filesystem_id",
                "quarantined_by",
            ),
            allow_none=True,
            desc="get_cached_remote_media",
        )

    async def store_cached_remote_media(
        self,
        origin: str,
        media_id: str,
        media_type: str,
        media_length: int,
        time_now_ms: int,
        upload_name: Optional[str],
        filesystem_id: str,
    ) -> None:
        """Insert a row describing a cached piece of remote media

        time_now_ms is stored as both the created_ts and the initial
        last_access_ts of the row.
        """
        await self.db_pool.simple_insert(
            "remote_media_cache",
            {
                "media_origin": origin,
                "media_id": media_id,
                "media_type": media_type,
                "media_length": media_length,
                "created_ts": time_now_ms,
                "upload_name": upload_name,
                "filesystem_id": filesystem_id,
                "last_access_ts": time_now_ms,
            },
            desc="store_cached_remote_media",
        )

    async def update_cached_last_access_time(
        self,
        local_media: Iterable[str],
        remote_media: Iterable[Tuple[str, str]],
        time_ms: int,
    ) -> None:
        """Updates the last access time of the given media

        Args:
            local_media: Set of media_ids
            remote_media: Set of (server_name, media_id)
            time_ms: Current time in milliseconds
        """

        def update_cache_txn(txn: LoggingTransaction) -> None:
            sql = (
                "UPDATE remote_media_cache SET last_access_ts = ?"
                " WHERE media_origin = ? AND media_id = ?"
            )

            txn.execute_batch(
                sql,
                (
                    (time_ms, media_origin, media_id)
                    for media_origin, media_id in remote_media
                ),
            )

            sql = (
                "UPDATE local_media_repository SET last_access_ts = ?"
                " WHERE media_id = ?"
            )

            txn.execute_batch(sql, ((time_ms, media_id) for media_id in local_media))

        await self.db_pool.runInteraction(
            "update_cached_last_access_time", update_cache_txn
        )

    async def get_remote_media_thumbnails(
        self, origin: str, media_id: str
    ) -> List[Dict[str, Any]]:
        """Get the thumbnail rows stored for a cached piece of remote media

        Returns:
            One dict per thumbnail, with the keys thumbnail_width,
            thumbnail_height, thumbnail_method, thumbnail_type,
            thumbnail_length and filesystem_id.
        """
        return await self.db_pool.simple_select_list(
            "remote_media_cache_thumbnails",
            {"media_origin": origin, "media_id": media_id},
            (
                "thumbnail_width",
                "thumbnail_height",
                "thumbnail_method",
                "thumbnail_type",
                "thumbnail_length",
                "filesystem_id",
            ),
            desc="get_remote_media_thumbnails",
        )

    async def get_remote_media_thumbnail(
        self,
        origin: str,
        media_id: str,
        t_width: int,
        t_height: int,
        t_type: str,
        t_method: Optional[str] = None,
    ) -> Optional[Dict[str, Any]]:
        """Fetch the thumbnail info of given width, height and type.

        Args:
            origin: the media_origin the media is stored under
            media_id: the ID of the media on that origin
            t_width: the thumbnail width to match
            t_height: the thumbnail height to match
            t_type: the thumbnail type to match
            t_method: if given, also match on the thumbnail method. The unique
                index on this table includes thumbnail_method, so several
                rows can share the same width, height and type; passing the
                method makes the lookup unambiguous.

        Returns:
            None if there is no matching thumbnail.
        """
        keyvalues: Dict[str, Any] = {
            "media_origin": origin,
            "media_id": media_id,
            "thumbnail_width": t_width,
            "thumbnail_height": t_height,
            "thumbnail_type": t_type,
        }
        if t_method is not None:
            keyvalues["thumbnail_method"] = t_method

        return await self.db_pool.simple_select_one(
            table="remote_media_cache_thumbnails",
            keyvalues=keyvalues,
            retcols=(
                "thumbnail_width",
                "thumbnail_height",
                "thumbnail_method",
                "thumbnail_type",
                "thumbnail_length",
                "filesystem_id",
            ),
            allow_none=True,
            desc="get_remote_media_thumbnail",
        )

    async def store_remote_media_thumbnail(
        self,
        origin: str,
        media_id: str,
        filesystem_id: str,
        thumbnail_width: int,
        thumbnail_height: int,
        thumbnail_type: str,
        thumbnail_method: str,
        thumbnail_length: int,
    ) -> None:
        """Upsert the thumbnail row for a cached piece of remote media

        The row is keyed on origin, media_id, width, height, method and type.
        thumbnail_length is passed as the upsert value, while filesystem_id
        is passed as an insertion-only value.
        """
        await self.db_pool.simple_upsert(
            table="remote_media_cache_thumbnails",
            keyvalues={
                "media_origin": origin,
                "media_id": media_id,
                "thumbnail_width": thumbnail_width,
                "thumbnail_height": thumbnail_height,
                "thumbnail_method": thumbnail_method,
                "thumbnail_type": thumbnail_type,
            },
            values={"thumbnail_length": thumbnail_length},
            insertion_values={"filesystem_id": filesystem_id},
            desc="store_remote_media_thumbnail",
        )

    async def get_remote_media_before(self, before_ts: int) -> List[Dict[str, str]]:
        """Get the cached remote media last accessed before the given time

        Returns:
            One dict per row, with the keys media_origin, media_id and
            filesystem_id.
        """
        sql = (
            "SELECT media_origin, media_id, filesystem_id"
            " FROM remote_media_cache"
            " WHERE last_access_ts < ?"
        )

        return await self.db_pool.execute(
            "get_remote_media_before", self.db_pool.cursor_to_dict, sql, before_ts
        )

    async def delete_remote_media(self, media_origin: str, media_id: str) -> None:
        """Delete a cached piece of remote media and its thumbnail rows

        Both deletions happen in the same transaction.
        """

        def delete_remote_media_txn(txn: LoggingTransaction) -> None:
            self.db_pool.simple_delete_txn(
                txn,
                "remote_media_cache",
                keyvalues={"media_origin": media_origin, "media_id": media_id},
            )
            self.db_pool.simple_delete_txn(
                txn,
                "remote_media_cache_thumbnails",
                keyvalues={"media_origin": media_origin, "media_id": media_id},
            )

        await self.db_pool.runInteraction(
            "delete_remote_media", delete_remote_media_txn
        )

    async def get_expired_url_cache(self, now_ts: int) -> List[str]:
        """Get the media_ids of URL cache entries which expired before now_ts

        At most 500 IDs are returned, oldest expiry first.
        """
        sql = (
            "SELECT media_id FROM local_media_repository_url_cache"
            " WHERE expires_ts < ?"
            " ORDER BY expires_ts ASC"
            " LIMIT 500"
        )

        def _get_expired_url_cache_txn(txn: LoggingTransaction) -> List[str]:
            txn.execute(sql, (now_ts,))
            return [row[0] for row in txn]

        return await self.db_pool.runInteraction(
            "get_expired_url_cache", _get_expired_url_cache_txn
        )

    async def delete_url_cache(self, media_ids: Collection[str]) -> None:
        """Delete the URL cache entries for the given media_ids

        Does nothing if media_ids is empty.
        """
        if not media_ids:
            return

        sql = "DELETE FROM local_media_repository_url_cache WHERE media_id = ?"

        def _delete_url_cache_txn(txn: LoggingTransaction) -> None:
            txn.execute_batch(sql, [(media_id,) for media_id in media_ids])

        await self.db_pool.runInteraction("delete_url_cache", _delete_url_cache_txn)

    async def get_url_cache_media_before(self, before_ts: int) -> List[str]:
        """Get the media_ids of URL-cache media created before before_ts

        Only rows with a non-NULL url_cache are considered. At most 500 IDs
        are returned, oldest first.
        """
        sql = (
            "SELECT media_id FROM local_media_repository"
            " WHERE created_ts < ? AND url_cache IS NOT NULL"
            " ORDER BY created_ts ASC"
            " LIMIT 500"
        )

        def _get_url_cache_media_before_txn(txn: LoggingTransaction) -> List[str]:
            txn.execute(sql, (before_ts,))
            return [row[0] for row in txn]

        return await self.db_pool.runInteraction(
            "get_url_cache_media_before", _get_url_cache_media_before_txn
        )

    async def delete_url_cache_media(self, media_ids: Collection[str]) -> None:
        """Delete the given media and its thumbnail rows from the local tables

        Both deletions happen in the same transaction. Does nothing if
        media_ids is empty.
        """
        if not media_ids:
            return

        def _delete_url_cache_media_txn(txn: LoggingTransaction) -> None:
            sql = "DELETE FROM local_media_repository WHERE media_id = ?"

            txn.execute_batch(sql, [(media_id,) for media_id in media_ids])

            sql = "DELETE FROM local_media_repository_thumbnails WHERE media_id = ?"

            txn.execute_batch(sql, [(media_id,) for media_id in media_ids])

        await self.db_pool.runInteraction(
            "delete_url_cache_media", _delete_url_cache_media_txn
        )