1# Copyright 2014-2016 OpenMarket Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14import logging
15from typing import (
16    TYPE_CHECKING,
17    AsyncContextManager,
18    Awaitable,
19    Callable,
20    Dict,
21    Iterable,
22    Optional,
23)
24
25import attr
26
27from synapse.metrics.background_process_metrics import run_as_background_process
28from synapse.storage.types import Connection
29from synapse.types import JsonDict
30from synapse.util import Clock, json_encoder
31
32from . import engines
33
34if TYPE_CHECKING:
35    from synapse.server import HomeServer
36    from synapse.storage.database import DatabasePool, LoggingTransaction
37
38logger = logging.getLogger(__name__)
39
40
# Hook consulted before each batch of a background update. It is called with
# (update_name, database_name, oneshot) and returns an async context manager
# whose entered value is the target duration, in milliseconds, for the batch.
ON_UPDATE_CALLBACK = Callable[[str, str, bool], AsyncContextManager[int]]
# Hook giving the batch size to use for the first iteration of an update. It is
# called with (update_name, database_name).
DEFAULT_BATCH_SIZE_CALLBACK = Callable[[str, str], Awaitable[int]]
# Hook giving the lower bound on the batch size of an update. It is called with
# (update_name, database_name).
MIN_BATCH_SIZE_CALLBACK = Callable[[str, str], Awaitable[int]]
44
45
@attr.s(slots=True, frozen=True, auto_attribs=True)
class _BackgroundUpdateHandler:
    """A handler for a given background update.

    Attributes:
        callback: The function to call to make progress on the background
            update. It is called with the current progress dict and the batch
            size, and returns an awaitable resolving to the number of items
            processed.
        oneshot: Whether the update is likely to happen all in one go, ignoring
            the supplied target duration, e.g. index creation. This is used by
            the update controller to help correctly schedule the update.
    """

    callback: Callable[[JsonDict, int], Awaitable[int]]
    oneshot: bool = False
60
61
class _BackgroundUpdateContextManager:
    """Default pacing for background updates.

    Entering the context optionally sleeps for `BACKGROUND_UPDATE_INTERVAL_MS`
    and then yields `BACKGROUND_UPDATE_DURATION_MS` as the target duration of
    the batch. Leaving the context does nothing.
    """

    BACKGROUND_UPDATE_INTERVAL_MS = 1000
    BACKGROUND_UPDATE_DURATION_MS = 100

    def __init__(self, sleep: bool, clock: Clock):
        self._clock = clock
        self._sleep = sleep

    async def __aenter__(self) -> int:
        if not self._sleep:
            return self.BACKGROUND_UPDATE_DURATION_MS

        # The clock sleeps in seconds; the interval is configured in ms.
        pause_secs = self.BACKGROUND_UPDATE_INTERVAL_MS / 1000
        await self._clock.sleep(pause_secs)
        return self.BACKGROUND_UPDATE_DURATION_MS

    async def __aexit__(self, *exc) -> None:
        return None
78
79
class BackgroundUpdatePerformance:
    """Tracks how long a background update is taking to update its items.

    Attributes:
        name: The name of the background update being tracked.
        total_item_count: Total number of items processed so far.
        total_duration_ms: Total time spent processing, in milliseconds.
        avg_item_count: Exponential moving average of items per batch.
        avg_duration_ms: Exponential moving average of batch duration, in
            milliseconds.
    """

    def __init__(self, name: str):
        self.name = name
        self.total_item_count = 0
        self.total_duration_ms = 0.0
        self.avg_item_count = 0.0
        self.avg_duration_ms = 0.0

    def update(self, item_count: int, duration_ms: float) -> None:
        """Record the outcome of one batch.

        Args:
            item_count: Number of items processed by the batch.
            duration_ms: How long the batch took, in milliseconds.
        """
        self.total_item_count += item_count
        self.total_duration_ms += duration_ms

        # Exponential moving averages for the number of items updated and
        # the duration.
        self.avg_item_count += 0.1 * (item_count - self.avg_item_count)
        self.avg_duration_ms += 0.1 * (duration_ms - self.avg_duration_ms)

    def average_items_per_ms(self) -> Optional[float]:
        """An estimate of the recent processing rate of the update.

        Returns:
            The number of items processed per millisecond, based on the
            exponential moving averages; None if no items have been processed
            yet (so no estimate exists); 0 if items have been processed but the
            averaged duration is zero.
        """
        # This must be checked first: a brand-new tracker has both a zero item
        # count and a zero duration, and it has to report "no estimate" (None)
        # rather than a rate of 0, otherwise callers cannot tell that they
        # should fall back to a default.
        if self.total_item_count == 0:
            return None

        # Avoid dividing by zero.
        if self.avg_duration_ms == 0:
            return 0

        # Use the exponential moving average so that we can adapt to
        # changes in how long the update process takes.
        return float(self.avg_item_count) / float(self.avg_duration_ms)

    def total_items_per_ms(self) -> Optional[float]:
        """The overall processing rate of the update since tracking began.

        Returns:
            The number of items processed per millisecond over the whole run;
            0 if no time has been recorded; None if time has been recorded but
            no items have been processed.
        """
        if self.total_duration_ms == 0:
            return 0
        elif self.total_item_count == 0:
            return None
        else:
            return float(self.total_item_count) / float(self.total_duration_ms)
125
126
127class BackgroundUpdater:
128    """Background updates are updates to the database that run in the
129    background. Each update processes a batch of data at once. We attempt to
130    limit the impact of each update by monitoring how long each batch takes to
131    process and autotuning the batch size.
132    """
133
    # Lower bound on the batch size, used when no `min_batch_size` callback has
    # been registered.
    MINIMUM_BACKGROUND_BATCH_SIZE = 1
    # Batch size for the first iteration of an update, used when no
    # `default_batch_size` callback has been registered.
    DEFAULT_BACKGROUND_BATCH_SIZE = 100
136
137    def __init__(self, hs: "HomeServer", database: "DatabasePool"):
138        self._clock = hs.get_clock()
139        self.db_pool = database
140
141        self._database_name = database.name()
142
143        # if a background update is currently running, its name.
144        self._current_background_update: Optional[str] = None
145
146        self._on_update_callback: Optional[ON_UPDATE_CALLBACK] = None
147        self._default_batch_size_callback: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None
148        self._min_batch_size_callback: Optional[MIN_BATCH_SIZE_CALLBACK] = None
149
150        self._background_update_performance: Dict[str, BackgroundUpdatePerformance] = {}
151        self._background_update_handlers: Dict[str, _BackgroundUpdateHandler] = {}
152        self._all_done = False
153
154        # Whether we're currently running updates
155        self._running = False
156
157        # Whether background updates are enabled. This allows us to
158        # enable/disable background updates via the admin API.
159        self.enabled = True
160
161    def register_update_controller_callbacks(
162        self,
163        on_update: ON_UPDATE_CALLBACK,
164        default_batch_size: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None,
165        min_batch_size: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None,
166    ) -> None:
167        """Register callbacks from a module for each hook."""
168        if self._on_update_callback is not None:
169            logger.warning(
170                "More than one module tried to register callbacks for controlling"
171                " background updates. Only the callbacks registered by the first module"
172                " (in order of appearance in Synapse's configuration file) that tried to"
173                " do so will be called."
174            )
175
176            return
177
178        self._on_update_callback = on_update
179
180        if default_batch_size is not None:
181            self._default_batch_size_callback = default_batch_size
182
183        if min_batch_size is not None:
184            self._min_batch_size_callback = min_batch_size
185
186    def _get_context_manager_for_update(
187        self,
188        sleep: bool,
189        update_name: str,
190        database_name: str,
191        oneshot: bool,
192    ) -> AsyncContextManager[int]:
193        """Get a context manager to run a background update with.
194
195        If a module has registered a `update_handler` callback, use the context manager
196        it returns.
197
198        Otherwise, returns a context manager that will return a default value, optionally
199        sleeping if needed.
200
201        Args:
202            sleep: Whether we can sleep between updates.
203            update_name: The name of the update.
204            database_name: The name of the database the update is being run on.
205            oneshot: Whether the update will complete all in one go, e.g. index creation.
206                In such cases the returned target duration is ignored.
207
208        Returns:
209            The target duration in milliseconds that the background update should run for.
210
211            Note: this is a *target*, and an iteration may take substantially longer or
212            shorter.
213        """
214        if self._on_update_callback is not None:
215            return self._on_update_callback(update_name, database_name, oneshot)
216
217        return _BackgroundUpdateContextManager(sleep, self._clock)
218
219    async def _default_batch_size(self, update_name: str, database_name: str) -> int:
220        """The batch size to use for the first iteration of a new background
221        update.
222        """
223        if self._default_batch_size_callback is not None:
224            return await self._default_batch_size_callback(update_name, database_name)
225
226        return self.DEFAULT_BACKGROUND_BATCH_SIZE
227
228    async def _min_batch_size(self, update_name: str, database_name: str) -> int:
229        """A lower bound on the batch size of a new background update.
230
231        Used to ensure that progress is always made. Must be greater than 0.
232        """
233        if self._min_batch_size_callback is not None:
234            return await self._min_batch_size_callback(update_name, database_name)
235
236        return self.MINIMUM_BACKGROUND_BATCH_SIZE
237
238    def get_current_update(self) -> Optional[BackgroundUpdatePerformance]:
239        """Returns the current background update, if any."""
240
241        update_name = self._current_background_update
242        if not update_name:
243            return None
244
245        perf = self._background_update_performance.get(update_name)
246        if not perf:
247            perf = BackgroundUpdatePerformance(update_name)
248
249        return perf
250
251    def start_doing_background_updates(self) -> None:
252        if self.enabled:
253            # if we start a new background update, not all updates are done.
254            self._all_done = False
255            run_as_background_process("background_updates", self.run_background_updates)
256
257    async def run_background_updates(self, sleep: bool = True) -> None:
258        if self._running or not self.enabled:
259            return
260
261        self._running = True
262
263        try:
264            logger.info("Starting background schema updates")
265            while self.enabled:
266                try:
267                    result = await self.do_next_background_update(sleep)
268                except Exception:
269                    logger.exception("Error doing update")
270                else:
271                    if result:
272                        logger.info(
273                            "No more background updates to do."
274                            " Unscheduling background update task."
275                        )
276                        self._all_done = True
277                        return None
278        finally:
279            self._running = False
280
281    async def has_completed_background_updates(self) -> bool:
282        """Check if all the background updates have completed
283
284        Returns:
285            True if all background updates have completed
286        """
287        # if we've previously determined that there is nothing left to do, that
288        # is easy
289        if self._all_done:
290            return True
291
292        # obviously, if we are currently processing an update, we're not done.
293        if self._current_background_update:
294            return False
295
296        # otherwise, check if there are updates to be run. This is important,
297        # as we may be running on a worker which doesn't perform the bg updates
298        # itself, but still wants to wait for them to happen.
299        updates = await self.db_pool.simple_select_onecol(
300            "background_updates",
301            keyvalues=None,
302            retcol="1",
303            desc="has_completed_background_updates",
304        )
305        if not updates:
306            self._all_done = True
307            return True
308
309        return False
310
311    async def has_completed_background_update(self, update_name: str) -> bool:
312        """Check if the given background update has finished running."""
313        if self._all_done:
314            return True
315
316        if update_name == self._current_background_update:
317            return False
318
319        update_exists = await self.db_pool.simple_select_one_onecol(
320            "background_updates",
321            keyvalues={"update_name": update_name},
322            retcol="1",
323            desc="has_completed_background_update",
324            allow_none=True,
325        )
326
327        return not update_exists
328
329    async def do_next_background_update(self, sleep: bool = True) -> bool:
330        """Does some amount of work on the next queued background update
331
332        Returns once some amount of work is done.
333
334        Args:
335            sleep: Whether to limit how quickly we run background updates or
336                not.
337
338        Returns:
339            True if we have finished running all the background updates, otherwise False
340        """
341
342        def get_background_updates_txn(txn):
343            txn.execute(
344                """
345                SELECT update_name, depends_on FROM background_updates
346                ORDER BY ordering, update_name
347                """
348            )
349            return self.db_pool.cursor_to_dict(txn)
350
351        if not self._current_background_update:
352            all_pending_updates = await self.db_pool.runInteraction(
353                "background_updates",
354                get_background_updates_txn,
355            )
356            if not all_pending_updates:
357                # no work left to do
358                return True
359
360            # find the first update which isn't dependent on another one in the queue.
361            pending = {update["update_name"] for update in all_pending_updates}
362            for upd in all_pending_updates:
363                depends_on = upd["depends_on"]
364                if not depends_on or depends_on not in pending:
365                    break
366                logger.info(
367                    "Not starting on bg update %s until %s is done",
368                    upd["update_name"],
369                    depends_on,
370                )
371            else:
372                # if we get to the end of that for loop, there is a problem
373                raise Exception(
374                    "Unable to find a background update which doesn't depend on "
375                    "another: dependency cycle?"
376                )
377
378            self._current_background_update = upd["update_name"]
379
380        # We have a background update to run, otherwise we would have returned
381        # early.
382        assert self._current_background_update is not None
383        update_info = self._background_update_handlers[self._current_background_update]
384
385        async with self._get_context_manager_for_update(
386            sleep=sleep,
387            update_name=self._current_background_update,
388            database_name=self._database_name,
389            oneshot=update_info.oneshot,
390        ) as desired_duration_ms:
391            await self._do_background_update(desired_duration_ms)
392
393        return False
394
395    async def _do_background_update(self, desired_duration_ms: float) -> int:
396        assert self._current_background_update is not None
397        update_name = self._current_background_update
398        logger.info("Starting update batch on background update '%s'", update_name)
399
400        update_handler = self._background_update_handlers[update_name].callback
401
402        performance = self._background_update_performance.get(update_name)
403
404        if performance is None:
405            performance = BackgroundUpdatePerformance(update_name)
406            self._background_update_performance[update_name] = performance
407
408        items_per_ms = performance.average_items_per_ms()
409
410        if items_per_ms is not None:
411            batch_size = int(desired_duration_ms * items_per_ms)
412            # Clamp the batch size so that we always make progress
413            batch_size = max(
414                batch_size,
415                await self._min_batch_size(update_name, self._database_name),
416            )
417        else:
418            batch_size = await self._default_batch_size(
419                update_name, self._database_name
420            )
421
422        progress_json = await self.db_pool.simple_select_one_onecol(
423            "background_updates",
424            keyvalues={"update_name": update_name},
425            retcol="progress_json",
426        )
427
428        # Avoid a circular import.
429        from synapse.storage._base import db_to_json
430
431        progress = db_to_json(progress_json)
432
433        time_start = self._clock.time_msec()
434        items_updated = await update_handler(progress, batch_size)
435        time_stop = self._clock.time_msec()
436
437        duration_ms = time_stop - time_start
438
439        performance.update(items_updated, duration_ms)
440
441        logger.info(
442            "Running background update %r. Processed %r items in %rms."
443            " (total_rate=%r/ms, current_rate=%r/ms, total_updated=%r, batch_size=%r)",
444            update_name,
445            items_updated,
446            duration_ms,
447            performance.total_items_per_ms(),
448            performance.average_items_per_ms(),
449            performance.total_item_count,
450            batch_size,
451        )
452
453        return len(self._background_update_performance)
454
455    def register_background_update_handler(
456        self,
457        update_name: str,
458        update_handler: Callable[[JsonDict, int], Awaitable[int]],
459    ):
460        """Register a handler for doing a background update.
461
462        The handler should take two arguments:
463
464        * A dict of the current progress
465        * An integer count of the number of items to update in this batch.
466
467        The handler should return a deferred or coroutine which returns an integer count
468        of items updated.
469
470        The handler is responsible for updating the progress of the update.
471
472        Args:
473            update_name: The name of the update that this code handles.
474            update_handler: The function that does the update.
475        """
476        self._background_update_handlers[update_name] = _BackgroundUpdateHandler(
477            update_handler
478        )
479
480    def register_noop_background_update(self, update_name: str) -> None:
481        """Register a noop handler for a background update.
482
483        This is useful when we previously did a background update, but no
484        longer wish to do the update. In this case the background update should
485        be removed from the schema delta files, but there may still be some
486        users who have the background update queued, so this method should
487        also be called to clear the update.
488
489        Args:
490            update_name: Name of update
491        """
492
493        async def noop_update(progress: JsonDict, batch_size: int) -> int:
494            await self._end_background_update(update_name)
495            return 1
496
497        self.register_background_update_handler(update_name, noop_update)
498
499    def register_background_index_update(
500        self,
501        update_name: str,
502        index_name: str,
503        table: str,
504        columns: Iterable[str],
505        where_clause: Optional[str] = None,
506        unique: bool = False,
507        psql_only: bool = False,
508    ) -> None:
509        """Helper for store classes to do a background index addition
510
511        To use:
512
513        1. use a schema delta file to add a background update. Example:
514            INSERT INTO background_updates (update_name, progress_json) VALUES
515                ('my_new_index', '{}');
516
517        2. In the Store constructor, call this method
518
519        Args:
520            update_name: update_name to register for
521            index_name: name of index to add
522            table: table to add index to
523            columns: columns/expressions to include in index
524            unique: true to make a UNIQUE index
525            psql_only: true to only create this index on psql databases (useful
526                for virtual sqlite tables)
527        """
528
529        def create_index_psql(conn: Connection) -> None:
530            conn.rollback()
531            # postgres insists on autocommit for the index
532            conn.set_session(autocommit=True)  # type: ignore
533
534            try:
535                c = conn.cursor()
536
537                # If a previous attempt to create the index was interrupted,
538                # we may already have a half-built index. Let's just drop it
539                # before trying to create it again.
540
541                sql = "DROP INDEX IF EXISTS %s" % (index_name,)
542                logger.debug("[SQL] %s", sql)
543                c.execute(sql)
544
545                sql = (
546                    "CREATE %(unique)s INDEX CONCURRENTLY %(name)s"
547                    " ON %(table)s"
548                    " (%(columns)s) %(where_clause)s"
549                ) % {
550                    "unique": "UNIQUE" if unique else "",
551                    "name": index_name,
552                    "table": table,
553                    "columns": ", ".join(columns),
554                    "where_clause": "WHERE " + where_clause if where_clause else "",
555                }
556                logger.debug("[SQL] %s", sql)
557                c.execute(sql)
558            finally:
559                conn.set_session(autocommit=False)  # type: ignore
560
561        def create_index_sqlite(conn: Connection) -> None:
562            # Sqlite doesn't support concurrent creation of indexes.
563            #
564            # We don't use partial indices on SQLite as it wasn't introduced
565            # until 3.8, and wheezy and CentOS 7 have 3.7
566            #
567            # We assume that sqlite doesn't give us invalid indices; however
568            # we may still end up with the index existing but the
569            # background_updates not having been recorded if synapse got shut
570            # down at the wrong moment - hance we use IF NOT EXISTS. (SQLite
571            # has supported CREATE TABLE|INDEX IF NOT EXISTS since 3.3.0.)
572            sql = (
573                "CREATE %(unique)s INDEX IF NOT EXISTS %(name)s ON %(table)s"
574                " (%(columns)s)"
575            ) % {
576                "unique": "UNIQUE" if unique else "",
577                "name": index_name,
578                "table": table,
579                "columns": ", ".join(columns),
580            }
581
582            c = conn.cursor()
583            logger.debug("[SQL] %s", sql)
584            c.execute(sql)
585
586        if isinstance(self.db_pool.engine, engines.PostgresEngine):
587            runner: Optional[Callable[[Connection], None]] = create_index_psql
588        elif psql_only:
589            runner = None
590        else:
591            runner = create_index_sqlite
592
593        async def updater(progress, batch_size):
594            if runner is not None:
595                logger.info("Adding index %s to %s", index_name, table)
596                await self.db_pool.runWithConnection(runner)
597            await self._end_background_update(update_name)
598            return 1
599
600        self._background_update_handlers[update_name] = _BackgroundUpdateHandler(
601            updater, oneshot=True
602        )
603
604    async def _end_background_update(self, update_name: str) -> None:
605        """Removes a completed background update task from the queue.
606
607        Args:
608            update_name:: The name of the completed task to remove
609
610        Returns:
611            None, completes once the task is removed.
612        """
613        if update_name != self._current_background_update:
614            raise Exception(
615                "Cannot end background update %s which isn't currently running"
616                % update_name
617            )
618        self._current_background_update = None
619        await self.db_pool.simple_delete_one(
620            "background_updates", keyvalues={"update_name": update_name}
621        )
622
623    async def _background_update_progress(
624        self, update_name: str, progress: dict
625    ) -> None:
626        """Update the progress of a background update
627
628        Args:
629            update_name: The name of the background update task
630            progress: The progress of the update.
631        """
632
633        await self.db_pool.runInteraction(
634            "background_update_progress",
635            self._background_update_progress_txn,
636            update_name,
637            progress,
638        )
639
640    def _background_update_progress_txn(
641        self, txn: "LoggingTransaction", update_name: str, progress: JsonDict
642    ) -> None:
643        """Update the progress of a background update
644
645        Args:
646            txn: The transaction.
647            update_name: The name of the background update task
648            progress: The progress of the update.
649        """
650
651        progress_json = json_encoder.encode(progress)
652
653        self.db_pool.simple_update_one_txn(
654            txn,
655            "background_updates",
656            keyvalues={"update_name": update_name},
657            updatevalues={"progress_json": progress_json},
658        )
659