import asyncio
import logging
import time

from postfix_mta_sts_resolver import constants
from postfix_mta_sts_resolver.base_cache import CacheEntry
from postfix_mta_sts_resolver.resolver import STSResolver, STSFetchResult


# pylint: disable=too-many-instance-attributes
class STSProactiveFetcher:
    """Periodically re-resolve the MTA-STS policies of every cached domain.

    A background task (created by :meth:`start`) sleeps until
    ``proactive_policy_fetching.interval`` seconds have passed since the last
    recorded proactive fetch. It then walks the whole cache and refreshes each
    entry using a pool of concurrent worker tasks.
    """

    def __init__(self, cfg, loop, cache):
        """Read fetcher settings from *cfg* and bind to *loop* and *cache*.

        :param cfg: configuration mapping. Keys read here:
            ``shutdown_timeout``, ``proactive_policy_fetching.interval``,
            ``proactive_policy_fetching.concurrency_limit``,
            ``proactive_policy_fetching.grace_ratio`` and
            ``default_zone.timeout``.
        :param loop: event loop used to create the worker and periodic tasks.
        :param cache: cache backend providing ``scan``, ``safe_set``,
            ``get_proactive_fetch_ts`` and ``set_proactive_fetch_ts``.
        """
        # NOTE(review): stored but not referenced anywhere in this class.
        self._shutdown_timeout = cfg['shutdown_timeout']
        self._pf_interval = cfg['proactive_policy_fetching']['interval']
        self._pf_concurrency_limit = cfg['proactive_policy_fetching']['concurrency_limit']
        self._pf_grace_ratio = cfg['proactive_policy_fetching']['grace_ratio']
        self._logger = logging.getLogger("PF")
        self._loop = loop
        self._cache = cache
        # Set by start(); stays None until then.
        self._periodic_fetch_task = None
        self._resolver = STSResolver(loop=loop,
                                     timeout=cfg["default_zone"]["timeout"])

    async def _update_policy(self, domain, cached, timestamp):
        """Re-resolve *domain*'s policy and store the outcome in the cache.

        Outcomes by resolver status:

        * ``VALID``: stores the newly fetched policy, stamped with *timestamp*.
        * ``NOT_CHANGED``: re-stores the previously cached policy with the new
          *timestamp*, so the entry counts as fresh again.
        * Any other status: logs a warning and leaves the cache untouched.
        """
        status, policy = await self._resolver.resolve(domain, cached.pol_id)
        if status is STSFetchResult.VALID:
            pol_id, pol_body = policy
            updated = CacheEntry(timestamp, pol_id, pol_body)
        elif status is STSFetchResult.NOT_CHANGED:
            updated = CacheEntry(timestamp, cached.pol_id, cached.pol_body)
        else:
            self._logger.warning("Domain %s does not have a valid policy.", domain)
            return
        await self._cache.safe_set(domain, updated, self._logger)

    async def process_domain(self, domain_queue):
        """Worker loop: refresh ``(domain, CacheEntry)`` items from *domain_queue*.

        Runs until cancelled. Entries updated less than
        ``interval / grace_ratio`` seconds ago are skipped; all others are
        re-resolved via :meth:`_update_policy`. Every dequeued item is
        acknowledged with ``task_done()`` so that ``domain_queue.join()``
        can complete, even when processing an item fails.
        """
        while True:  # Run until cancelled
            cache_item = await domain_queue.get()
            ts = time.time()  # pylint: disable=invalid-name
            try:
                domain, cached = cache_item
                if ts - cached.ts < self._pf_interval / self._pf_grace_ratio:
                    self._logger.debug("Domain %s skipped (cache recent enough).", domain)
                else:
                    await self._update_policy(domain, cached, ts)
            except asyncio.CancelledError:  # pragma: no cover pylint: disable=try-except-raise
                raise
            except Exception as exc:  # pragma: no cover
                # A single failing domain must not kill the worker.
                self._logger.exception("Unhandled exception: %s", exc)
            finally:
                domain_queue.task_done()

    async def iterate_domains(self):
        """Refresh every cached domain once, then record the fetch timestamp.

        Spawns ``concurrency_limit`` :meth:`process_domain` workers. It then
        pages through the cache with ``scan`` in batches of
        ``constants.DOMAIN_QUEUE_LIMIT`` and waits until the queue is drained.
        The workers are always cancelled on exit. The proactive fetch
        timestamp is only updated when the whole pass completes without
        raising.
        """
        self._logger.info("Proactive policy fetching "
                          "for all domains in cache started...")

        # Create domain processor tasks
        domain_processors = []
        domain_queue = asyncio.Queue(maxsize=constants.DOMAIN_QUEUE_LIMIT)
        for _ in range(self._pf_concurrency_limit):
            domain_processor = self._loop.create_task(self.process_domain(domain_queue))
            domain_processors.append(domain_processor)

        # Produce work for domain processors
        try:
            token = None
            while True:
                token, cache_items = await self._cache.scan(token, constants.DOMAIN_QUEUE_LIMIT)
                self._logger.debug("Enqueued %d domains for processing.", len(cache_items))
                for cache_item in cache_items:
                    await domain_queue.put(cache_item)
                # A None token means the scan has reached the end of the cache.
                if token is None:
                    break

            # Wait for queue to clear
            await domain_queue.join()
        # Clean up the domain processors
        finally:
            for domain_processor in domain_processors:
                domain_processor.cancel()
            await asyncio.gather(*domain_processors, return_exceptions=True)

        # Update the proactive fetch timestamp
        await self._cache.set_proactive_fetch_ts(time.time())

        self._logger.info("Proactive policy fetching "
                          "for all domains in cache finished.")

    async def fetch_periodically(self):
        """Run :meth:`iterate_domains` once per configured interval until cancelled.

        The sleep before each pass is derived from the last recorded proactive
        fetch timestamp. It is never shorter than
        ``constants.MIN_PROACTIVE_FETCH_INTERVAL``. An unexpected error in one
        round is logged and retried after the minimum interval, instead of
        silently terminating this background task.
        """
        while True:  # Run until cancelled
            try:
                next_fetch_ts = await self._cache.get_proactive_fetch_ts() + self._pf_interval
                sleep_duration = max(constants.MIN_PROACTIVE_FETCH_INTERVAL,
                                     next_fetch_ts - time.time() + 1)

                self._logger.debug("Sleeping for %ds until next fetch.", sleep_duration)
                await asyncio.sleep(sleep_duration)
                await self.iterate_domains()
            except asyncio.CancelledError:  # pylint: disable=try-except-raise
                # Needed on Python < 3.8, where CancelledError subclasses Exception.
                raise
            except Exception as exc:  # pylint: disable=broad-except
                self._logger.exception("Proactive policy fetching failed: %s", exc)
                # Back off so a persistently failing cache cannot cause a busy loop.
                await asyncio.sleep(constants.MIN_PROACTIVE_FETCH_INTERVAL)

    async def start(self):
        """Launch :meth:`fetch_periodically` as a background task on the loop."""
        self._periodic_fetch_task = self._loop.create_task(self.fetch_periodically())

    async def stop(self):
        """Cancel the periodic fetch task and wait for it to finish.

        Does nothing if :meth:`start` has not been called.
        """
        if self._periodic_fetch_task is None:
            return

        self._periodic_fetch_task.cancel()

        try:
            self._logger.warning("Awaiting periodic fetching to finish...")
            await self._periodic_fetch_task
        except asyncio.CancelledError:  # pragma: no cover
            pass