import asyncio
import logging
import time

from postfix_mta_sts_resolver import constants
from postfix_mta_sts_resolver.base_cache import CacheEntry
from postfix_mta_sts_resolver.resolver import STSResolver, STSFetchResult


# pylint: disable=too-many-instance-attributes
class STSProactiveFetcher:
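    """Proactively refreshes the MTA-STS policy of every domain already
    present in the cache, so entries stay fresh between client lookups."""
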
    def __init__(self, cfg, loop, cache):
        self._shutdown_timeout = cfg['shutdown_timeout']
        self._pf_interval = cfg['proactive_policy_fetching']['interval']
        self._pf_concurrency_limit = cfg['proactive_policy_fetching']['concurrency_limit']
        self._pf_grace_ratio = cfg['proactive_policy_fetching']['grace_ratio']
        self._logger = logging.getLogger("PF")
        self._loop = loop
        self._cache = cache
        self._periodic_fetch_task = None
        self._resolver = STSResolver(loop=loop,
                                     timeout=cfg["default_zone"]["timeout"])

    async def process_domain(self, domain_queue):
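        """Worker coroutine: pull (domain, CacheEntry) items from
        domain_queue and refresh each policy unless the cached copy is
        still recent enough. Runs until cancelled."""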
        async def update(cached):
            status, policy = await self._resolver.resolve(domain, cached.pol_id)
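            # VALID carries a fresh (pol_id, pol_body) pair; NOT_CHANGED
            # means the cached body is still current, so only the cache
            # timestamp gets bumped.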
            if status is STSFetchResult.VALID:
                pol_id, pol_body = policy
                updated = CacheEntry(ts, pol_id, pol_body)
                await self._cache.safe_set(domain, updated, self._logger)
            elif status is STSFetchResult.NOT_CHANGED:
                updated = CacheEntry(ts, cached.pol_id, cached.pol_body)
                await self._cache.safe_set(domain, updated, self._logger)
            else:
                self._logger.warning("Domain %s does not have a valid policy.", domain)

        while True:  # Run until cancelled
            cache_item = await domain_queue.get()
            ts = time.time()  # pylint: disable=invalid-name
            try:
                domain, cached = cache_item
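                # Entries refreshed less than interval/grace_ratio seconds
                # ago are fresh enough and are skipped this round.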
                if ts - cached.ts < self._pf_interval / self._pf_grace_ratio:
                    self._logger.debug("Domain %s skipped (cache recent enough).", domain)
                else:
                    await update(cached)
            except asyncio.CancelledError:  # pragma: no cover pylint: disable=try-except-raise
                raise
            except Exception as exc:  # pragma: no cover
                self._logger.exception("Unhandled exception: %s", exc)
            finally:
                domain_queue.task_done()

    async def iterate_domains(self):
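        """Run one full proactive-fetch pass over every domain in the cache."""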
        self._logger.info("Proactive policy fetching "
                          "for all domains in cache started...")

        # Create domain processor tasks
        domain_processors = []
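        # Bounded queue: put() blocks once DOMAIN_QUEUE_LIMIT items are
        # pending, throttling the producer to the workers' pace.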
        domain_queue = asyncio.Queue(maxsize=constants.DOMAIN_QUEUE_LIMIT)
        for _ in range(self._pf_concurrency_limit):
            domain_processor = self._loop.create_task(self.process_domain(domain_queue))
            domain_processors.append(domain_processor)

        # Produce work for domain processors
        try:
            token = None
            while True:
                token, cache_items = await self._cache.scan(token, constants.DOMAIN_QUEUE_LIMIT)
                self._logger.debug("Enqueued %d domains for processing.", len(cache_items))
                for cache_item in cache_items:
                    await domain_queue.put(cache_item)
                if token is None:
                    break

            # Wait for queue to clear
            await domain_queue.join()
        # Clean up the domain processors
        finally:
            for domain_processor in domain_processors:
                domain_processor.cancel()
            await asyncio.gather(*domain_processors, return_exceptions=True)

        # Update the proactive fetch timestamp
        await self._cache.set_proactive_fetch_ts(time.time())

        self._logger.info("Proactive policy fetching "
                          "for all domains in cache finished.")

    async def fetch_periodically(self):
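        """Sleep until the next scheduled pass, then walk the entire cache;
        repeats until the task is cancelled."""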
        while True:  # Run until cancelled
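            # The next pass is due one interval after the last recorded
            # fetch; the extra second nudges the wakeup past that boundary.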
            next_fetch_ts = await self._cache.get_proactive_fetch_ts() + self._pf_interval
            sleep_duration = max(constants.MIN_PROACTIVE_FETCH_INTERVAL,
                                 next_fetch_ts - time.time() + 1)

            self._logger.debug("Sleeping for %ds until next fetch.", sleep_duration)
            await asyncio.sleep(sleep_duration)
            await self.iterate_domains()

    async def start(self):
        self._periodic_fetch_task = self._loop.create_task(self.fetch_periodically())

    async def stop(self):
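        """Cancel the periodic fetch task and wait for it to wind down."""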
        self._periodic_fetch_task.cancel()

        try:
            self._logger.warning("Awaiting periodic fetching to finish...")
            await self._periodic_fetch_task
        except asyncio.CancelledError:  # pragma: no cover
            pass
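
# Minimal usage sketch (illustrative only: the cfg values below are example
# settings, and `cache` stands in for a cache backend constructed elsewhere
# in postfix-mta-sts-resolver):
#
#     cfg = {
#         'shutdown_timeout': 20,
#         'proactive_policy_fetching': {'interval': 86400,
#                                       'concurrency_limit': 100,
#                                       'grace_ratio': 2.0},
#         'default_zone': {'timeout': 4},
#     }
#     loop = asyncio.new_event_loop()
#     fetcher = STSProactiveFetcher(cfg, loop, cache)
#     loop.run_until_complete(fetcher.start())
#     try:
#         loop.run_forever()  # periodic fetching happens while the loop runs
#     finally:
#         loop.run_until_complete(fetcher.stop())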