1# Copyright 2014-2016 OpenMarket Ltd
2# Copyright 2018 New Vector Ltd
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16import logging
17from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Tuple
18
19from synapse.push import PusherConfig, ThrottleParams
20from synapse.storage._base import SQLBaseStore, db_to_json
21from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
22from synapse.storage.util.id_generators import StreamIdGenerator
23from synapse.types import JsonDict
24from synapse.util import json_encoder
25from synapse.util.caches.descriptors import cached
26
27if TYPE_CHECKING:
28    from synapse.server import HomeServer
29
30logger = logging.getLogger(__name__)
31
32
class PusherWorkerStore(SQLBaseStore):
    """Read-side storage for the `pushers` and `pusher_throttle` tables.

    Also registers the background updates that remove pusher rows which
    should no longer exist (deactivated users, logged-out devices, deleted
    email addresses). Write operations live on `PusherStore`.
    """

    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)
        # ID generator for the pushers replication stream. Deleting a pusher
        # also produces a row (in `deleted_pushers`) with a stream ID, so that
        # table is included when working out the current stream position.
        self._pushers_id_gen = StreamIdGenerator(
            db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
        )

        self.db_pool.updates.register_background_update_handler(
            "remove_deactivated_pushers",
            self._remove_deactivated_pushers,
        )

        self.db_pool.updates.register_background_update_handler(
            "remove_stale_pushers",
            self._remove_stale_pushers,
        )

        self.db_pool.updates.register_background_update_handler(
            "remove_deleted_email_pushers",
            self._remove_deleted_email_pushers,
        )

    def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[PusherConfig]:
        """JSON-decode the data in the rows returned from the `pushers` table.

        Drops any rows whose data cannot be decoded (logging a warning rather
        than failing the whole query).
        """
        for r in rows:
            data_json = r["data"]
            try:
                # Replace the raw JSON string with the decoded value, in place.
                r["data"] = db_to_json(data_json)
            except Exception as e:
                logger.warning(
                    "Invalid JSON in data for pusher %d: %s, %s",
                    r["id"],
                    data_json,
                    e.args[0],
                )
                continue

            yield PusherConfig(**r)

    async def user_has_pusher(self, user_id: str) -> bool:
        """Check whether the given user has at least one pusher configured."""
        ret = await self.db_pool.simple_select_one_onecol(
            "pushers", {"user_name": user_id}, "id", allow_none=True
        )
        return ret is not None

    async def get_pushers_by_app_id_and_pushkey(
        self, app_id: str, pushkey: str
    ) -> Iterator[PusherConfig]:
        """Fetch all pushers matching the given (app_id, pushkey) pair."""
        return await self.get_pushers_by({"app_id": app_id, "pushkey": pushkey})

    async def get_pushers_by_user_id(self, user_id: str) -> Iterator[PusherConfig]:
        """Fetch all pushers belonging to the given user."""
        return await self.get_pushers_by({"user_name": user_id})

    async def get_pushers_by(self, keyvalues: Dict[str, Any]) -> Iterator[PusherConfig]:
        """Fetch pushers matching the given column/value constraints.

        Args:
            keyvalues: column name -> required value, ANDed together.

        Returns:
            An iterator of decoded pusher configs; rows with undecodable
            `data` are silently dropped (see `_decode_pushers_rows`).
        """
        ret = await self.db_pool.simple_select_list(
            "pushers",
            keyvalues,
            [
                "id",
                "user_name",
                "access_token",
                "profile_tag",
                "kind",
                "app_id",
                "app_display_name",
                "device_display_name",
                "pushkey",
                "ts",
                "lang",
                "data",
                "last_stream_ordering",
                "last_success",
                "failing_since",
            ],
            desc="get_pushers_by",
        )
        return self._decode_pushers_rows(ret)

    async def get_all_pushers(self) -> Iterator[PusherConfig]:
        """Fetch every pusher known to this homeserver."""

        def get_pushers(txn):
            txn.execute("SELECT * FROM pushers")
            rows = self.db_pool.cursor_to_dict(txn)

            # NOTE: this returns a generator, consumed after the transaction
            # ends; that is fine because `rows` is already fully materialized.
            return self._decode_pushers_rows(rows)

        return await self.db_pool.runInteraction("get_all_pushers", get_pushers)

    async def get_all_updated_pushers_rows(
        self, instance_name: str, last_id: int, current_id: int, limit: int
    ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
        """Get updates for pushers replication stream.

        Args:
            instance_name: The writer we want to fetch updates from. Unused
                here since there is only ever one writer.
            last_id: The token to fetch updates from. Exclusive.
            current_id: The token to fetch updates up to. Inclusive.
            limit: The requested limit for the number of rows to return. The
                function may return more or fewer rows.

        Returns:
            A tuple consisting of: the updates, a token to use to fetch
            subsequent updates, and whether we returned fewer rows than exist
            between the requested tokens due to the limit.

            The token returned can be used in a subsequent call to this
            function to get further updates.

            The updates are a list of 2-tuples of stream ID and the row data
        """

        if last_id == current_id:
            return [], current_id, False

        def get_all_updated_pushers_rows_txn(txn):
            # Live pushers: the trailing `False` in each row marks "not deleted".
            sql = """
                SELECT id, user_name, app_id, pushkey
                FROM pushers
                WHERE ? < id AND id <= ?
                ORDER BY id ASC LIMIT ?
            """
            txn.execute(sql, (last_id, current_id, limit))
            updates = [
                (stream_id, (user_name, app_id, pushkey, False))
                for stream_id, user_name, app_id, pushkey in txn
            ]

            # Deletions: note the column is `user_id` here, but it carries the
            # same value as `user_name` above; the trailing `True` marks
            # "deleted".
            sql = """
                SELECT stream_id, user_id, app_id, pushkey
                FROM deleted_pushers
                WHERE ? < stream_id AND stream_id <= ?
                ORDER BY stream_id ASC LIMIT ?
            """
            txn.execute(sql, (last_id, current_id, limit))
            updates.extend(
                (stream_id, (user_name, app_id, pushkey, True))
                for stream_id, user_name, app_id, pushkey in txn
            )

            updates.sort()  # Sort so that they're ordered by stream id

            # Each query above may return up to `limit` rows, so the combined
            # list can exceed the requested limit. If it has hit the limit we
            # may not have covered the full (last_id, current_id] range, so
            # clamp the returned token to the last stream ID we actually saw.
            limited = False
            upper_bound = current_id
            if len(updates) >= limit:
                limited = True
                upper_bound = updates[-1][0]

            return updates, upper_bound, limited

        return await self.db_pool.runInteraction(
            "get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn
        )

    @cached(num_args=1, max_entries=15000)
    async def get_if_user_has_pusher(self, user_id: str):
        # This only exists for the cachedList decorator
        raise NotImplementedError()

    async def update_pusher_last_stream_ordering(
        self, app_id, pushkey, user_id, last_stream_ordering
    ) -> None:
        """Record the stream ordering the given pusher has processed up to.

        Raises if the pusher does not exist (simple_update_one requires
        exactly one matching row).
        """
        await self.db_pool.simple_update_one(
            "pushers",
            {"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
            {"last_stream_ordering": last_stream_ordering},
            desc="update_pusher_last_stream_ordering",
        )

    async def update_pusher_last_stream_ordering_and_success(
        self,
        app_id: str,
        pushkey: str,
        user_id: str,
        last_stream_ordering: int,
        last_success: int,
    ) -> bool:
        """Update the last stream ordering position we've processed up to for
        the given pusher, and the timestamp of the last successful push.

        Args:
            app_id
            pushkey
            user_id
            last_stream_ordering
            last_success

        Returns:
            True if the pusher still exists; False if it has been deleted.
        """
        updated = await self.db_pool.simple_update(
            table="pushers",
            keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
            updatevalues={
                "last_stream_ordering": last_stream_ordering,
                "last_success": last_success,
            },
            desc="update_pusher_last_stream_ordering_and_success",
        )

        # simple_update returns the number of rows touched; zero rows means
        # the pusher has been deleted out from under us.
        return bool(updated)

    async def update_pusher_failing_since(
        self, app_id: str, pushkey: str, user_id: str, failing_since: Optional[int]
    ) -> None:
        """Set (or clear, with None) the time the pusher started failing."""
        await self.db_pool.simple_update(
            table="pushers",
            keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
            updatevalues={"failing_since": failing_since},
            desc="update_pusher_failing_since",
        )

    async def get_throttle_params_by_room(
        self, pusher_id: str
    ) -> Dict[str, ThrottleParams]:
        """Fetch the per-room email throttling parameters for a pusher.

        Returns:
            A map from room ID to that room's throttle parameters.
        """
        res = await self.db_pool.simple_select_list(
            "pusher_throttle",
            {"pusher": pusher_id},
            ["room_id", "last_sent_ts", "throttle_ms"],
            desc="get_throttle_params_by_room",
        )

        params_by_room = {}
        for row in res:
            params_by_room[row["room_id"]] = ThrottleParams(
                row["last_sent_ts"],
                row["throttle_ms"],
            )

        return params_by_room

    async def set_throttle_params(
        self, pusher_id: str, room_id: str, params: ThrottleParams
    ) -> None:
        """Upsert the throttling parameters for a (pusher, room) pair."""
        # no need to lock because `pusher_throttle` has a primary key on
        # (pusher, room_id) so simple_upsert will retry
        await self.db_pool.simple_upsert(
            "pusher_throttle",
            {"pusher": pusher_id, "room_id": room_id},
            {"last_sent_ts": params.last_sent_ts, "throttle_ms": params.throttle_ms},
            desc="set_throttle_params",
            lock=False,
        )

    async def _remove_deactivated_pushers(self, progress: dict, batch_size: int) -> int:
        """A background update that deletes all pushers for deactivated users.

        Note that we don't proactively tell the pusherpool that we've deleted
        these (just because it's a bit of a faff to do from here), but they will
        get cleaned up at the next restart
        """

        # Resume from where the previous batch left off (user names are
        # processed in ascending order).
        last_user = progress.get("last_user", "")

        def _delete_pushers(txn) -> int:

            sql = """
                SELECT name FROM users
                WHERE deactivated = ? and name > ?
                ORDER BY name ASC
                LIMIT ?
            """

            txn.execute(sql, (1, last_user, batch_size))
            users = [row[0] for row in txn]

            self.db_pool.simple_delete_many_txn(
                txn,
                table="pushers",
                column="user_name",
                values=users,
                keyvalues={},
            )

            if users:
                self.db_pool.updates._background_update_progress_txn(
                    txn, "remove_deactivated_pushers", {"last_user": users[-1]}
                )

            # NB: this is the number of deactivated users scanned this batch,
            # not the number of pusher rows deleted.
            return len(users)

        number_deleted = await self.db_pool.runInteraction(
            "_remove_deactivated_pushers", _delete_pushers
        )

        # A short batch means we've run out of deactivated users: finished.
        if number_deleted < batch_size:
            await self.db_pool.updates._end_background_update(
                "remove_deactivated_pushers"
            )

        return number_deleted

    async def _remove_stale_pushers(self, progress: dict, batch_size: int) -> int:
        """A background update that deletes all pushers for logged out devices.

        Note that we don't proactively tell the pusherpool that we've deleted
        these (just because it's a bit of a faff to do from here), but they will
        get cleaned up at the next restart
        """

        # Resume from the highest pusher ID processed by the previous batch.
        last_pusher = progress.get("last_pusher", 0)

        def _delete_pushers(txn) -> int:

            # LEFT JOIN so that pushers whose access token row no longer
            # exists come back with a NULL token; those are the stale ones.
            sql = """
                SELECT p.id, access_token FROM pushers AS p
                LEFT JOIN access_tokens AS a ON (p.access_token = a.id)
                WHERE p.id > ?
                ORDER BY p.id ASC
                LIMIT ?
            """

            txn.execute(sql, (last_pusher, batch_size))
            pushers = [(row[0], row[1]) for row in txn]

            self.db_pool.simple_delete_many_txn(
                txn,
                table="pushers",
                column="id",
                values=[pusher_id for pusher_id, token in pushers if token is None],
                keyvalues={},
            )

            if pushers:
                self.db_pool.updates._background_update_progress_txn(
                    txn, "remove_stale_pushers", {"last_pusher": pushers[-1][0]}
                )

            # NB: the number of pushers scanned, not the number deleted.
            return len(pushers)

        number_deleted = await self.db_pool.runInteraction(
            "_remove_stale_pushers", _delete_pushers
        )

        if number_deleted < batch_size:
            await self.db_pool.updates._end_background_update("remove_stale_pushers")

        return number_deleted

    async def _remove_deleted_email_pushers(
        self, progress: dict, batch_size: int
    ) -> int:
        """A background update that deletes all pushers for deleted email addresses.

        In previous versions of synapse, when users deleted their email address, it didn't
        also delete all the pushers for that email address. This background update removes
        those to prevent unwanted emails. This should only need to be run once (when users
        upgrade to v1.42.0).

        Args:
            progress: dict used to store progress of this background update
            batch_size: the maximum number of rows to retrieve in a single select query

        Returns:
            The number of deleted rows
        """

        last_pusher = progress.get("last_pusher", 0)

        def _delete_pushers(txn) -> int:

            # Select email pushers whose pushkey (the email address) no longer
            # has a matching validated threepid for that user.
            sql = """
                SELECT p.id, p.user_name, p.app_id, p.pushkey
                FROM pushers AS p
                    LEFT JOIN user_threepids AS t
                        ON t.user_id = p.user_name
                        AND t.medium = 'email'
                        AND t.address = p.pushkey
                WHERE t.user_id is NULL
                    AND p.app_id = 'm.email'
                    AND p.id > ?
                ORDER BY p.id ASC
                LIMIT ?
            """

            txn.execute(sql, (last_pusher, batch_size))
            rows = txn.fetchall()

            last = None
            num_deleted = 0
            for row in rows:
                last = row[0]
                num_deleted += 1
                self.db_pool.simple_delete_txn(
                    txn,
                    "pushers",
                    {"user_name": row[1], "app_id": row[2], "pushkey": row[3]},
                )

            if last is not None:
                self.db_pool.updates._background_update_progress_txn(
                    txn, "remove_deleted_email_pushers", {"last_pusher": last}
                )

            return num_deleted

        number_deleted = await self.db_pool.runInteraction(
            "_remove_deleted_email_pushers", _delete_pushers
        )

        if number_deleted < batch_size:
            await self.db_pool.updates._end_background_update(
                "remove_deleted_email_pushers"
            )

        return number_deleted
446
447
class PusherStore(PusherWorkerStore):
    """Write-side storage for pushers: adding and deleting them.

    Deletions are recorded in `deleted_pushers` so that they appear on the
    pushers replication stream.
    """

    def get_pushers_stream_token(self) -> int:
        """Return the current position of the pushers replication stream."""
        return self._pushers_id_gen.get_current_token()

    async def add_pusher(
        self,
        user_id: str,
        access_token: Optional[int],
        kind: str,
        app_id: str,
        app_display_name: str,
        device_display_name: str,
        pushkey: str,
        pushkey_ts: int,
        lang: Optional[str],
        data: Optional[JsonDict],
        last_stream_ordering: int,
        profile_tag: str = "",
    ) -> None:
        """Add a pusher, replacing any existing one for the same
        (app_id, pushkey, user_name) triple.

        Also invalidates the `get_if_user_has_pusher` cache for the user if
        it isn't already known to be True.
        """
        async with self._pushers_id_gen.get_next() as stream_id:
            # no need to lock because `pushers` has a unique key on
            # (app_id, pushkey, user_name) so simple_upsert will retry
            await self.db_pool.simple_upsert(
                table="pushers",
                keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
                values={
                    "access_token": access_token,
                    "kind": kind,
                    "app_display_name": app_display_name,
                    "device_display_name": device_display_name,
                    "ts": pushkey_ts,
                    "lang": lang,
                    "data": json_encoder.encode(data),
                    "last_stream_ordering": last_stream_ordering,
                    "profile_tag": profile_tag,
                    "id": stream_id,
                },
                desc="add_pusher",
                lock=False,
            )

            # Peek at the cache without touching metrics; we only need to
            # invalidate if it could be holding a stale False/None entry.
            user_has_pusher = self.get_if_user_has_pusher.cache.get_immediate(
                (user_id,), None, update_metrics=False
            )

            if user_has_pusher is not True:
                # invalidate, since the user might not have had a pusher before
                await self.db_pool.runInteraction(
                    "add_pusher",
                    self._invalidate_cache_and_stream,  # type: ignore[attr-defined]
                    self.get_if_user_has_pusher,
                    (user_id,),
                )

    async def delete_pusher_by_app_id_pushkey_user_id(
        self, app_id: str, pushkey: str, user_id: str
    ) -> None:
        """Delete the pusher identified by (app_id, pushkey, user_id),
        recording the deletion in `deleted_pushers` for replication.
        """

        def delete_pusher_txn(txn, stream_id):
            self._invalidate_cache_and_stream(  # type: ignore[attr-defined]
                txn, self.get_if_user_has_pusher, (user_id,)
            )

            # It is expected that there is exactly one pusher to delete, but
            # if it isn't there (or there are multiple) delete them all.
            self.db_pool.simple_delete_txn(
                txn,
                "pushers",
                {"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
            )

            # it's possible for us to end up with duplicate rows for
            # (app_id, pushkey, user_id) at different stream_ids, but that
            # doesn't really matter.
            self.db_pool.simple_insert_txn(
                txn,
                table="deleted_pushers",
                values={
                    "stream_id": stream_id,
                    "app_id": app_id,
                    "pushkey": pushkey,
                    "user_id": user_id,
                },
            )

        async with self._pushers_id_gen.get_next() as stream_id:
            await self.db_pool.runInteraction(
                "delete_pusher", delete_pusher_txn, stream_id
            )

    async def delete_all_pushers_for_user(self, user_id: str) -> None:
        """Delete all pushers associated with an account."""

        # We want to generate a row in `deleted_pushers` for each pusher we're
        # deleting, so we fetch the list now so we can generate the appropriate
        # number of stream IDs.
        #
        # Note: technically there could be a race here between adding/deleting
        # pushers, but a) the worst case if we don't stop a pusher until the
        # next restart and b) this is only called when we're deactivating an
        # account.
        pushers = list(await self.get_pushers_by_user_id(user_id))

        def delete_pushers_txn(txn, stream_ids):
            self._invalidate_cache_and_stream(  # type: ignore[attr-defined]
                txn, self.get_if_user_has_pusher, (user_id,)
            )

            self.db_pool.simple_delete_txn(
                txn,
                table="pushers",
                keyvalues={"user_name": user_id},
            )

            # One `deleted_pushers` row per deleted pusher, each with its own
            # stream ID from the batch allocated below.
            self.db_pool.simple_insert_many_txn(
                txn,
                table="deleted_pushers",
                values=[
                    {
                        "stream_id": stream_id,
                        "app_id": pusher.app_id,
                        "pushkey": pusher.pushkey,
                        "user_id": user_id,
                    }
                    for stream_id, pusher in zip(stream_ids, pushers)
                ],
            )

        async with self._pushers_id_gen.get_next_mult(len(pushers)) as stream_ids:
            await self.db_pool.runInteraction(
                "delete_all_pushers_for_user", delete_pushers_txn, stream_ids
            )