1from functools import wraps
2import logging
3
4
# Shared handle on the 'peewee' logger; query-capturing handlers are
# attached to and detached from this object.
logger = logging.getLogger(name='peewee')
6
7
class _QueryLogHandler(logging.Handler):
    """Logging handler that keeps every record it receives in memory.

    Records are stored unformatted, in arrival order, in ``self.queries``.
    Positional and keyword arguments are forwarded to ``logging.Handler``.
    """

    def __init__(self, *args, **kwargs):
        self.queries = []
        super(_QueryLogHandler, self).__init__(*args, **kwargs)

    def emit(self, record):
        # Store the raw LogRecord without formatting it.
        self.queries.append(record)
15
16
class count_queries(object):
    """Context manager that counts records logged to the ``peewee`` logger.

    On entry a fresh ``_QueryLogHandler`` is attached to ``logger`` and the
    logger level is set to ``DEBUG`` so debug-level records are captured.
    On exit the handler is detached, the logger's previous level is
    restored, and :attr:`count` is set to the number of captured records.

    :param only_select: when true, count only records whose first message
        element is SQL text starting with ``'SELECT '``.
    """

    def __init__(self, only_select=False):
        self.only_select = only_select
        self.count = 0

    def get_queries(self):
        """Return the list of captured ``logging.LogRecord`` objects.

        Only valid once the context has been entered (the handler is
        created in ``__enter__``).
        """
        return self._handler.queries

    def __enter__(self):
        self._handler = _QueryLogHandler()
        # Remember the level we are overriding so __exit__ can put it back;
        # otherwise DEBUG logging on this logger would stay enabled for the
        # rest of the process after the block ends.
        self._previous_level = logger.level
        logger.setLevel(logging.DEBUG)
        logger.addHandler(self._handler)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        logger.removeHandler(self._handler)
        logger.setLevel(self._previous_level)
        records = self._handler.queries
        if self.only_select:
            self.count = sum(1 for record in records
                             if self._is_select(record))
        else:
            self.count = len(records)

    @staticmethod
    def _is_select(record):
        """Return True if *record* carries SQL text starting with 'SELECT '."""
        # Query records are expected to carry the SQL text as ``msg[0]``
        # (e.g. an ``(sql, params)`` tuple). Other message shapes, such as
        # plain strings, None or empty tuples, are treated as non-SELECT
        # instead of raising.
        msg = record.msg
        if not isinstance(msg, (tuple, list)) or not msg:
            return False
        sql = msg[0]
        return hasattr(sql, 'startswith') and sql.startswith('SELECT ')
38
39
class assert_query_count(count_queries):
    """Check that a block or function issues exactly ``expected`` queries.

    Usable as a context manager::

        with assert_query_count(2):
            ...

    or as a decorator::

        @assert_query_count(1)
        def test_something(self):
            ...

    On a normal exit the counted queries (all of them, or only SELECTs when
    ``only_select`` is true) are compared with ``expected``. A mismatch
    raises ``AssertionError``. If the wrapped code itself raised, that
    exception propagates unchanged and the count is not checked.
    """

    def __init__(self, expected, only_select=False):
        super(assert_query_count, self).__init__(only_select=only_select)
        self.expected = expected

    def __call__(self, f):
        @wraps(f)
        def decorated(*args, **kwds):
            # __exit__ performs the count check, so no second check is needed.
            with self:
                return f(*args, **kwds)

        return decorated

    def _assert_count(self):
        # Raise explicitly rather than using an ``assert`` statement so the
        # check still runs when Python is started with -O.
        if self.count != self.expected:
            raise AssertionError('%s != %s' % (self.count, self.expected))

    def __exit__(self, exc_type, exc_val, exc_tb):
        super(assert_query_count, self).__exit__(exc_type, exc_val, exc_tb)
        # Only check on success: asserting while another exception is in
        # flight would replace that (more informative) exception.
        if exc_type is None:
            self._assert_count()
63