1#!/usr/bin/env python
2import random
3import string
4from tests.compat import unittest, mock
5
6import boto
7
8
# Canned XML body for one page of an invalidation listing. Filled with
# %-formatting using the keys next_marker, max_items, is_truncated and
# inval_summaries (the concatenated INVAL_SUMMARY_TEMPLATE entries for the
# page).
RESPONSE_TEMPLATE = r"""
<InvalidationList>
   <Marker/>
   <NextMarker>%(next_marker)s</NextMarker>
   <MaxItems>%(max_items)s</MaxItems>
   <IsTruncated>%(is_truncated)s</IsTruncated>
   %(inval_summaries)s
</InvalidationList>
"""

# One <InvalidationSummary> element; filled with the keys cfid and status.
INVAL_SUMMARY_TEMPLATE = r"""
   <InvalidationSummary>
      <Id>%(cfid)s</Id>
      <Status>%(status)s</Status>
   </InvalidationSummary>
"""
25
26
27class CFInvalidationListTest(unittest.TestCase):
28
29    cloudfront = True
30
31    def setUp(self):
32        self.cf = boto.connect_cloudfront('aws.aws_access_key_id',
33                                          'aws.aws_secret_access_key')
34
35    def _get_random_id(self, length=14):
36        return ''.join([random.choice(string.ascii_letters) for i in
37                        range(length)])
38
39    def _group_iter(self, iterator, n):
40        accumulator = []
41        for item in iterator:
42            accumulator.append(item)
43            if len(accumulator) == n:
44                yield accumulator
45                accumulator = []
46        if len(accumulator) != 0:
47            yield accumulator
48
49    def _get_mock_responses(self, num, max_items):
50        max_items = min(max_items, 100)
51        cfid_groups = list(self._group_iter([self._get_random_id() for i in
52                                             range(num)], max_items))
53        cfg = dict(status='Completed', max_items=max_items, next_marker='')
54        responses = []
55        is_truncated = 'true'
56        for i, group in enumerate(cfid_groups):
57            next_marker = group[-1]
58            if (i + 1) == len(cfid_groups):
59                is_truncated = 'false'
60                next_marker = ''
61            invals = ''
62            cfg.update(dict(next_marker=next_marker,
63                            is_truncated=is_truncated))
64            for cfid in group:
65                cfg.update(dict(cfid=cfid))
66                invals += INVAL_SUMMARY_TEMPLATE % cfg
67            cfg.update(dict(inval_summaries=invals))
68            mock_response = mock.Mock()
69            mock_response.read.return_value = (RESPONSE_TEMPLATE % cfg).encode('utf-8')
70            mock_response.status = 200
71            responses.append(mock_response)
72        return responses
73
74    def test_manual_pagination(self, num_invals=30, max_items=4):
75        """
76        Test that paginating manually works properly
77        """
78        self.assertGreater(num_invals, max_items)
79        responses = self._get_mock_responses(num=num_invals,
80                                             max_items=max_items)
81        self.cf.make_request = mock.Mock(side_effect=responses)
82        ir = self.cf.get_invalidation_requests('dist-id-here',
83                                               max_items=max_items)
84        all_invals = list(ir)
85        self.assertEqual(len(all_invals), max_items)
86        while ir.is_truncated:
87            ir = self.cf.get_invalidation_requests('dist-id-here',
88                                                   marker=ir.next_marker,
89                                                   max_items=max_items)
90            invals = list(ir)
91            self.assertLessEqual(len(invals), max_items)
92            all_invals.extend(invals)
93        remainder = num_invals % max_items
94        if remainder != 0:
95            self.assertEqual(len(invals), remainder)
96        self.assertEqual(len(all_invals), num_invals)
97
98    def test_auto_pagination(self, num_invals=1024):
99        """
100        Test that auto-pagination works properly
101        """
102        max_items = 100
103        self.assertGreaterEqual(num_invals, max_items)
104        responses = self._get_mock_responses(num=num_invals,
105                                             max_items=max_items)
106        self.cf.make_request = mock.Mock(side_effect=responses)
107        ir = self.cf.get_invalidation_requests('dist-id-here')
108        self.assertEqual(len(ir._inval_cache), max_items)
109        self.assertEqual(len(list(ir)), num_invals)
110
if __name__ == '__main__':
    # Run this module's tests when it is executed directly as a script.
    unittest.main()
113