from functools import wraps
import logging


logger = logging.getLogger('peewee')


class _QueryLogHandler(logging.Handler):
    """Logging handler that stores every record it receives in ``queries``."""

    def __init__(self, *args, **kwargs):
        self.queries = []
        logging.Handler.__init__(self, *args, **kwargs)

    def emit(self, record):
        self.queries.append(record)


def _is_select(record):
    """Return True if *record*'s SQL text starts with ``'SELECT '``.

    The SQL text is read from ``record.msg[0]``. NOTE(review): this assumes
    queries are logged as a ``(sql, params)``-style sequence; confirm
    against the emitter. Records whose message cannot be indexed that way,
    or whose first element has no compatible ``startswith``, are counted
    as non-SELECT instead of raising.
    """
    try:
        sql = record.msg[0]
    except (TypeError, IndexError, KeyError):
        return False
    try:
        return sql.startswith('SELECT ')
    except (AttributeError, TypeError):
        return False


class count_queries(object):
    """Context manager counting records logged to the ``peewee`` logger.

    While active, the logger's level is forced to ``DEBUG`` and a capturing
    handler is attached. On exit the handler is removed, the logger's
    previous level is restored and ``count`` is set.

    :param only_select: when true, only count records whose SQL text
        (``record.msg[0]``) starts with ``'SELECT '``.
    """

    def __init__(self, only_select=False):
        self.only_select = only_select
        self.count = 0
        self._handler = None
        self._previous_level = None

    def get_queries(self):
        """Return the captured log records (empty list before first use)."""
        if self._handler is None:
            return []
        return self._handler.queries

    def __enter__(self):
        self._handler = _QueryLogHandler()
        # Remember the caller's level so __exit__ can undo the DEBUG
        # override rather than leaking it into unrelated code.
        self._previous_level = logger.level
        logger.setLevel(logging.DEBUG)
        logger.addHandler(self._handler)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        logger.removeHandler(self._handler)
        logger.setLevel(self._previous_level)
        queries = self._handler.queries
        if self.only_select:
            self.count = sum(1 for q in queries if _is_select(q))
        else:
            self.count = len(queries)


class assert_query_count(count_queries):
    """Check that exactly ``expected`` queries are logged.

    Usable as a context manager (``with assert_query_count(2): ...``) or as
    a decorator (``@assert_query_count(2)``). On mismatch an
    ``AssertionError`` with message ``'<count> != <expected>'`` is raised.

    :param expected: the exact number of queries required.
    :param only_select: passed through to :class:`count_queries`.
    """

    def __init__(self, expected, only_select=False):
        super(assert_query_count, self).__init__(only_select=only_select)
        self.expected = expected

    def __call__(self, f):
        @wraps(f)
        def decorated(*args, **kwds):
            # __exit__ performs the count check on normal completion, so the
            # result is only returned once the check has passed.
            with self:
                return f(*args, **kwds)

        return decorated

    def _assert_count(self):
        if self.count != self.expected:
            # Raise explicitly instead of using ``assert`` so the check
            # still runs when Python is started with -O.
            raise AssertionError('%s != %s' % (self.count, self.expected))

    def __exit__(self, exc_type, exc_val, exc_tb):
        super(assert_query_count, self).__exit__(exc_type, exc_val, exc_tb)
        # Let an exception raised inside the block propagate unchanged
        # instead of masking it with a count-mismatch error.
        if exc_type is None:
            self._assert_count()