1# Copyright 2014-2016 OpenMarket Ltd
2# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15from enum import Enum
16from typing import (
17    TYPE_CHECKING,
18    Any,
19    Collection,
20    Dict,
21    Iterable,
22    List,
23    Optional,
24    Tuple,
25    Union,
26    cast,
27)
28
29from synapse.storage._base import SQLBaseStore
30from synapse.storage.database import (
31    DatabasePool,
32    LoggingDatabaseConnection,
33    LoggingTransaction,
34)
35from synapse.types import JsonDict, UserID
36
37if TYPE_CHECKING:
38    from synapse.server import HomeServer
39
# Name of the original background update for dropping the thumbnail
# constraints that lack `thumbnail_method`. It is registered as a no-op by
# MediaRepositoryBackgroundUpdateStore.
BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD = (
    "media_repository_drop_index_wo_method"
)

# Name of the replacement background update, which is handled by
# MediaRepositoryBackgroundUpdateStore._drop_media_index_without_method.
BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2 = (
    "media_repository_drop_index_wo_method_2"
)
46
47
class MediaSortOrder(Enum):
    """
    The columns by which get_local_media_by_user_paginate can sort its
    results.

    Each member's value is the name of a `local_media_repository` column; the
    value is interpolated into the ORDER BY clause of the pagination query.
    """

    MEDIA_ID = "media_id"
    UPLOAD_NAME = "upload_name"
    CREATED_TS = "created_ts"
    LAST_ACCESS_TS = "last_access_ts"
    MEDIA_LENGTH = "media_length"
    MEDIA_TYPE = "media_type"
    QUARANTINED_BY = "quarantined_by"
    SAFE_FROM_QUARANTINE = "safe_from_quarantine"
62
63
64class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
65    def __init__(
66        self,
67        database: DatabasePool,
68        db_conn: LoggingDatabaseConnection,
69        hs: "HomeServer",
70    ):
71        super().__init__(database, db_conn, hs)
72
73        self.db_pool.updates.register_background_index_update(
74            update_name="local_media_repository_url_idx",
75            index_name="local_media_repository_url_idx",
76            table="local_media_repository",
77            columns=["created_ts"],
78            where_clause="url_cache IS NOT NULL",
79        )
80
81        # The following the updates add the method to the unique constraint of
82        # the thumbnail databases. That fixes an issue, where thumbnails of the
83        # same resolution, but different methods could overwrite one another.
84        # This can happen with custom thumbnail configs or with dynamic thumbnailing.
85        self.db_pool.updates.register_background_index_update(
86            update_name="local_media_repository_thumbnails_method_idx",
87            index_name="local_media_repository_thumbn_media_id_width_height_method_key",
88            table="local_media_repository_thumbnails",
89            columns=[
90                "media_id",
91                "thumbnail_width",
92                "thumbnail_height",
93                "thumbnail_type",
94                "thumbnail_method",
95            ],
96            unique=True,
97        )
98
99        self.db_pool.updates.register_background_index_update(
100            update_name="remote_media_repository_thumbnails_method_idx",
101            index_name="remote_media_repository_thumbn_media_origin_id_width_height_method_key",
102            table="remote_media_cache_thumbnails",
103            columns=[
104                "media_origin",
105                "media_id",
106                "thumbnail_width",
107                "thumbnail_height",
108                "thumbnail_type",
109                "thumbnail_method",
110            ],
111            unique=True,
112        )
113
114        # the original impl of _drop_media_index_without_method was broken (see
115        # https://github.com/matrix-org/synapse/issues/8649), so we replace the original
116        # impl with a no-op and run the fixed migration as
117        # media_repository_drop_index_wo_method_2.
118        self.db_pool.updates.register_noop_background_update(
119            BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD
120        )
121        self.db_pool.updates.register_background_update_handler(
122            BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2,
123            self._drop_media_index_without_method,
124        )
125
126    async def _drop_media_index_without_method(
127        self, progress: JsonDict, batch_size: int
128    ) -> int:
129        """background update handler which removes the old constraints.
130
131        Note that this is only run on postgres.
132        """
133
134        def f(txn: LoggingTransaction) -> None:
135            txn.execute(
136                "ALTER TABLE local_media_repository_thumbnails DROP CONSTRAINT IF EXISTS local_media_repository_thumbn_media_id_thumbnail_width_thum_key"
137            )
138            txn.execute(
139                "ALTER TABLE remote_media_cache_thumbnails DROP CONSTRAINT IF EXISTS remote_media_cache_thumbnails_media_origin_media_id_thumbna_key"
140            )
141
142        await self.db_pool.runInteraction("drop_media_indices_without_method", f)
143        await self.db_pool.updates._end_background_update(
144            BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2
145        )
146        return 1
147
148
149class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
150    """Persistence for attachments and avatars"""
151
152    def __init__(
153        self,
154        database: DatabasePool,
155        db_conn: LoggingDatabaseConnection,
156        hs: "HomeServer",
157    ):
158        super().__init__(database, db_conn, hs)
159        self.server_name = hs.hostname
160
161    async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]:
162        """Get the metadata for a local piece of media
163
164        Returns:
165            None if the media_id doesn't exist.
166        """
167        return await self.db_pool.simple_select_one(
168            "local_media_repository",
169            {"media_id": media_id},
170            (
171                "media_type",
172                "media_length",
173                "upload_name",
174                "created_ts",
175                "quarantined_by",
176                "url_cache",
177                "safe_from_quarantine",
178            ),
179            allow_none=True,
180            desc="get_local_media",
181        )
182
183    async def get_local_media_by_user_paginate(
184        self,
185        start: int,
186        limit: int,
187        user_id: str,
188        order_by: str = MediaSortOrder.CREATED_TS.value,
189        direction: str = "f",
190    ) -> Tuple[List[Dict[str, Any]], int]:
191        """Get a paginated list of metadata for a local piece of media
192        which an user_id has uploaded
193
194        Args:
195            start: offset in the list
196            limit: maximum amount of media_ids to retrieve
197            user_id: fully-qualified user id
198            order_by: the sort order of the returned list
199            direction: sort ascending or descending
200        Returns:
201            A paginated list of all metadata of user's media,
202            plus the total count of all the user's media
203        """
204
205        def get_local_media_by_user_paginate_txn(
206            txn: LoggingTransaction,
207        ) -> Tuple[List[Dict[str, Any]], int]:
208
209            # Set ordering
210            order_by_column = MediaSortOrder(order_by).value
211
212            if direction == "b":
213                order = "DESC"
214            else:
215                order = "ASC"
216
217            args: List[Union[str, int]] = [user_id]
218            sql = """
219                SELECT COUNT(*) as total_media
220                FROM local_media_repository
221                WHERE user_id = ?
222            """
223            txn.execute(sql, args)
224            count = cast(Tuple[int], txn.fetchone())[0]
225
226            sql = """
227                SELECT
228                    "media_id",
229                    "media_type",
230                    "media_length",
231                    "upload_name",
232                    "created_ts",
233                    "last_access_ts",
234                    "quarantined_by",
235                    "safe_from_quarantine"
236                FROM local_media_repository
237                WHERE user_id = ?
238                ORDER BY {order_by_column} {order}, media_id ASC
239                LIMIT ? OFFSET ?
240            """.format(
241                order_by_column=order_by_column,
242                order=order,
243            )
244
245            args += [limit, start]
246            txn.execute(sql, args)
247            media = self.db_pool.cursor_to_dict(txn)
248            return media, count
249
250        return await self.db_pool.runInteraction(
251            "get_local_media_by_user_paginate_txn", get_local_media_by_user_paginate_txn
252        )
253
254    async def get_local_media_before(
255        self,
256        before_ts: int,
257        size_gt: int,
258        keep_profiles: bool,
259    ) -> List[str]:
260
261        # to find files that have never been accessed (last_access_ts IS NULL)
262        # compare with `created_ts`
263        sql = """
264            SELECT media_id
265            FROM local_media_repository AS lmr
266            WHERE
267                ( last_access_ts < ?
268                OR ( created_ts < ? AND last_access_ts IS NULL ) )
269                AND media_length > ?
270        """
271
272        if keep_profiles:
273            sql_keep = """
274                AND (
275                    NOT EXISTS
276                        (SELECT 1
277                         FROM profiles
278                         WHERE profiles.avatar_url = '{media_prefix}' || lmr.media_id)
279                    AND NOT EXISTS
280                        (SELECT 1
281                         FROM groups
282                         WHERE groups.avatar_url = '{media_prefix}' || lmr.media_id)
283                    AND NOT EXISTS
284                        (SELECT 1
285                         FROM room_memberships
286                         WHERE room_memberships.avatar_url = '{media_prefix}' || lmr.media_id)
287                    AND NOT EXISTS
288                        (SELECT 1
289                         FROM user_directory
290                         WHERE user_directory.avatar_url = '{media_prefix}' || lmr.media_id)
291                    AND NOT EXISTS
292                        (SELECT 1
293                         FROM room_stats_state
294                         WHERE room_stats_state.avatar = '{media_prefix}' || lmr.media_id)
295                )
296            """.format(
297                media_prefix="mxc://%s/" % (self.server_name,),
298            )
299            sql += sql_keep
300
301        def _get_local_media_before_txn(txn: LoggingTransaction) -> List[str]:
302            txn.execute(sql, (before_ts, before_ts, size_gt))
303            return [row[0] for row in txn]
304
305        return await self.db_pool.runInteraction(
306            "get_local_media_before", _get_local_media_before_txn
307        )
308
309    async def store_local_media(
310        self,
311        media_id: str,
312        media_type: str,
313        time_now_ms: int,
314        upload_name: Optional[str],
315        media_length: int,
316        user_id: UserID,
317        url_cache: Optional[str] = None,
318    ) -> None:
319        await self.db_pool.simple_insert(
320            "local_media_repository",
321            {
322                "media_id": media_id,
323                "media_type": media_type,
324                "created_ts": time_now_ms,
325                "upload_name": upload_name,
326                "media_length": media_length,
327                "user_id": user_id.to_string(),
328                "url_cache": url_cache,
329            },
330            desc="store_local_media",
331        )
332
333    async def mark_local_media_as_safe(self, media_id: str, safe: bool = True) -> None:
334        """Mark a local media as safe or unsafe from quarantining."""
335        await self.db_pool.simple_update_one(
336            table="local_media_repository",
337            keyvalues={"media_id": media_id},
338            updatevalues={"safe_from_quarantine": safe},
339            desc="mark_local_media_as_safe",
340        )
341
342    async def get_url_cache(self, url: str, ts: int) -> Optional[Dict[str, Any]]:
343        """Get the media_id and ts for a cached URL as of the given timestamp
344        Returns:
345            None if the URL isn't cached.
346        """
347
348        def get_url_cache_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
349            # get the most recently cached result (relative to the given ts)
350            sql = (
351                "SELECT response_code, etag, expires_ts, og, media_id, download_ts"
352                " FROM local_media_repository_url_cache"
353                " WHERE url = ? AND download_ts <= ?"
354                " ORDER BY download_ts DESC LIMIT 1"
355            )
356            txn.execute(sql, (url, ts))
357            row = txn.fetchone()
358
359            if not row:
360                # ...or if we've requested a timestamp older than the oldest
361                # copy in the cache, return the oldest copy (if any)
362                sql = (
363                    "SELECT response_code, etag, expires_ts, og, media_id, download_ts"
364                    " FROM local_media_repository_url_cache"
365                    " WHERE url = ? AND download_ts > ?"
366                    " ORDER BY download_ts ASC LIMIT 1"
367                )
368                txn.execute(sql, (url, ts))
369                row = txn.fetchone()
370
371            if not row:
372                return None
373
374            return dict(
375                zip(
376                    (
377                        "response_code",
378                        "etag",
379                        "expires_ts",
380                        "og",
381                        "media_id",
382                        "download_ts",
383                    ),
384                    row,
385                )
386            )
387
388        return await self.db_pool.runInteraction("get_url_cache", get_url_cache_txn)
389
390    async def store_url_cache(
391        self, url, response_code, etag, expires_ts, og, media_id, download_ts
392    ) -> None:
393        await self.db_pool.simple_insert(
394            "local_media_repository_url_cache",
395            {
396                "url": url,
397                "response_code": response_code,
398                "etag": etag,
399                "expires_ts": expires_ts,
400                "og": og,
401                "media_id": media_id,
402                "download_ts": download_ts,
403            },
404            desc="store_url_cache",
405        )
406
407    async def get_local_media_thumbnails(self, media_id: str) -> List[Dict[str, Any]]:
408        return await self.db_pool.simple_select_list(
409            "local_media_repository_thumbnails",
410            {"media_id": media_id},
411            (
412                "thumbnail_width",
413                "thumbnail_height",
414                "thumbnail_method",
415                "thumbnail_type",
416                "thumbnail_length",
417            ),
418            desc="get_local_media_thumbnails",
419        )
420
421    async def store_local_thumbnail(
422        self,
423        media_id: str,
424        thumbnail_width: int,
425        thumbnail_height: int,
426        thumbnail_type: str,
427        thumbnail_method: str,
428        thumbnail_length: int,
429    ) -> None:
430        await self.db_pool.simple_upsert(
431            table="local_media_repository_thumbnails",
432            keyvalues={
433                "media_id": media_id,
434                "thumbnail_width": thumbnail_width,
435                "thumbnail_height": thumbnail_height,
436                "thumbnail_method": thumbnail_method,
437                "thumbnail_type": thumbnail_type,
438            },
439            values={"thumbnail_length": thumbnail_length},
440            desc="store_local_thumbnail",
441        )
442
443    async def get_cached_remote_media(
444        self, origin, media_id: str
445    ) -> Optional[Dict[str, Any]]:
446        return await self.db_pool.simple_select_one(
447            "remote_media_cache",
448            {"media_origin": origin, "media_id": media_id},
449            (
450                "media_type",
451                "media_length",
452                "upload_name",
453                "created_ts",
454                "filesystem_id",
455                "quarantined_by",
456            ),
457            allow_none=True,
458            desc="get_cached_remote_media",
459        )
460
461    async def store_cached_remote_media(
462        self,
463        origin: str,
464        media_id: str,
465        media_type: str,
466        media_length: int,
467        time_now_ms: int,
468        upload_name: Optional[str],
469        filesystem_id: str,
470    ) -> None:
471        await self.db_pool.simple_insert(
472            "remote_media_cache",
473            {
474                "media_origin": origin,
475                "media_id": media_id,
476                "media_type": media_type,
477                "media_length": media_length,
478                "created_ts": time_now_ms,
479                "upload_name": upload_name,
480                "filesystem_id": filesystem_id,
481                "last_access_ts": time_now_ms,
482            },
483            desc="store_cached_remote_media",
484        )
485
486    async def update_cached_last_access_time(
487        self,
488        local_media: Iterable[str],
489        remote_media: Iterable[Tuple[str, str]],
490        time_ms: int,
491    ) -> None:
492        """Updates the last access time of the given media
493
494        Args:
495            local_media: Set of media_ids
496            remote_media: Set of (server_name, media_id)
497            time_ms: Current time in milliseconds
498        """
499
500        def update_cache_txn(txn: LoggingTransaction) -> None:
501            sql = (
502                "UPDATE remote_media_cache SET last_access_ts = ?"
503                " WHERE media_origin = ? AND media_id = ?"
504            )
505
506            txn.execute_batch(
507                sql,
508                (
509                    (time_ms, media_origin, media_id)
510                    for media_origin, media_id in remote_media
511                ),
512            )
513
514            sql = (
515                "UPDATE local_media_repository SET last_access_ts = ?"
516                " WHERE media_id = ?"
517            )
518
519            txn.execute_batch(sql, ((time_ms, media_id) for media_id in local_media))
520
521        await self.db_pool.runInteraction(
522            "update_cached_last_access_time", update_cache_txn
523        )
524
525    async def get_remote_media_thumbnails(
526        self, origin: str, media_id: str
527    ) -> List[Dict[str, Any]]:
528        return await self.db_pool.simple_select_list(
529            "remote_media_cache_thumbnails",
530            {"media_origin": origin, "media_id": media_id},
531            (
532                "thumbnail_width",
533                "thumbnail_height",
534                "thumbnail_method",
535                "thumbnail_type",
536                "thumbnail_length",
537                "filesystem_id",
538            ),
539            desc="get_remote_media_thumbnails",
540        )
541
542    async def get_remote_media_thumbnail(
543        self,
544        origin: str,
545        media_id: str,
546        t_width: int,
547        t_height: int,
548        t_type: str,
549    ) -> Optional[Dict[str, Any]]:
550        """Fetch the thumbnail info of given width, height and type."""
551
552        return await self.db_pool.simple_select_one(
553            table="remote_media_cache_thumbnails",
554            keyvalues={
555                "media_origin": origin,
556                "media_id": media_id,
557                "thumbnail_width": t_width,
558                "thumbnail_height": t_height,
559                "thumbnail_type": t_type,
560            },
561            retcols=(
562                "thumbnail_width",
563                "thumbnail_height",
564                "thumbnail_method",
565                "thumbnail_type",
566                "thumbnail_length",
567                "filesystem_id",
568            ),
569            allow_none=True,
570            desc="get_remote_media_thumbnail",
571        )
572
573    async def store_remote_media_thumbnail(
574        self,
575        origin: str,
576        media_id: str,
577        filesystem_id: str,
578        thumbnail_width: int,
579        thumbnail_height: int,
580        thumbnail_type: str,
581        thumbnail_method: str,
582        thumbnail_length: int,
583    ) -> None:
584        await self.db_pool.simple_upsert(
585            table="remote_media_cache_thumbnails",
586            keyvalues={
587                "media_origin": origin,
588                "media_id": media_id,
589                "thumbnail_width": thumbnail_width,
590                "thumbnail_height": thumbnail_height,
591                "thumbnail_method": thumbnail_method,
592                "thumbnail_type": thumbnail_type,
593            },
594            values={"thumbnail_length": thumbnail_length},
595            insertion_values={"filesystem_id": filesystem_id},
596            desc="store_remote_media_thumbnail",
597        )
598
599    async def get_remote_media_before(self, before_ts: int) -> List[Dict[str, str]]:
600        sql = (
601            "SELECT media_origin, media_id, filesystem_id"
602            " FROM remote_media_cache"
603            " WHERE last_access_ts < ?"
604        )
605
606        return await self.db_pool.execute(
607            "get_remote_media_before", self.db_pool.cursor_to_dict, sql, before_ts
608        )
609
610    async def delete_remote_media(self, media_origin: str, media_id: str) -> None:
611        def delete_remote_media_txn(txn):
612            self.db_pool.simple_delete_txn(
613                txn,
614                "remote_media_cache",
615                keyvalues={"media_origin": media_origin, "media_id": media_id},
616            )
617            self.db_pool.simple_delete_txn(
618                txn,
619                "remote_media_cache_thumbnails",
620                keyvalues={"media_origin": media_origin, "media_id": media_id},
621            )
622
623        await self.db_pool.runInteraction(
624            "delete_remote_media", delete_remote_media_txn
625        )
626
627    async def get_expired_url_cache(self, now_ts: int) -> List[str]:
628        sql = (
629            "SELECT media_id FROM local_media_repository_url_cache"
630            " WHERE expires_ts < ?"
631            " ORDER BY expires_ts ASC"
632            " LIMIT 500"
633        )
634
635        def _get_expired_url_cache_txn(txn: LoggingTransaction) -> List[str]:
636            txn.execute(sql, (now_ts,))
637            return [row[0] for row in txn]
638
639        return await self.db_pool.runInteraction(
640            "get_expired_url_cache", _get_expired_url_cache_txn
641        )
642
643    async def delete_url_cache(self, media_ids: Collection[str]) -> None:
644        if len(media_ids) == 0:
645            return
646
647        sql = "DELETE FROM local_media_repository_url_cache WHERE media_id = ?"
648
649        def _delete_url_cache_txn(txn: LoggingTransaction) -> None:
650            txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
651
652        await self.db_pool.runInteraction("delete_url_cache", _delete_url_cache_txn)
653
654    async def get_url_cache_media_before(self, before_ts: int) -> List[str]:
655        sql = (
656            "SELECT media_id FROM local_media_repository"
657            " WHERE created_ts < ? AND url_cache IS NOT NULL"
658            " ORDER BY created_ts ASC"
659            " LIMIT 500"
660        )
661
662        def _get_url_cache_media_before_txn(txn: LoggingTransaction) -> List[str]:
663            txn.execute(sql, (before_ts,))
664            return [row[0] for row in txn]
665
666        return await self.db_pool.runInteraction(
667            "get_url_cache_media_before", _get_url_cache_media_before_txn
668        )
669
670    async def delete_url_cache_media(self, media_ids: Collection[str]) -> None:
671        if len(media_ids) == 0:
672            return
673
674        def _delete_url_cache_media_txn(txn: LoggingTransaction) -> None:
675            sql = "DELETE FROM local_media_repository WHERE media_id = ?"
676
677            txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
678
679            sql = "DELETE FROM local_media_repository_thumbnails WHERE media_id = ?"
680
681            txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
682
683        await self.db_pool.runInteraction(
684            "delete_url_cache_media", _delete_url_cache_media_txn
685        )
686