1# Copyright 2012-2014 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License"). You
4# may not use this file except in compliance with the License. A copy of
5# the License is located at
6#
7# http://aws.amazon.com/apache2.0/
8#
9# or in the "license" file accompanying this file. This file is
10# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11# ANY KIND, either express or implied. See the License for the specific
12# language governing permissions and limitations under the License.
13
14from itertools import tee
15
16from botocore.compat import six
17
18import jmespath
19import json
20import base64
21import logging
22from botocore.exceptions import PaginationError
23from botocore.compat import zip
24from botocore.utils import set_value_from_jmespath, merge_dicts
25
26
27log = logging.getLogger(__name__)
28
29
class TokenEncoder(object):
    """Encodes dictionaries into opaque strings.

    This is mostly json dumps + base64 encoding, but it additionally
    supports dictionaries containing bytes alongside the types json can
    serialize natively.

    Intended for encoding pagination tokens, which in some cases can be
    complex structures and / or contain bytes.
    """

    def encode(self, token):
        """Encodes a dictionary to an opaque string.

        :type token: dict
        :param token: A dictionary containing pagination information,
            particularly the service pagination token(s) but also other boto
            metadata.

        :rtype: str
        :returns: An opaque string
        """
        try:
            # Fast path: nearly every token is plain json-serializable, so
            # avoid traversing the structure unless we have to.
            json_string = json.dumps(token)
        except (TypeError, UnicodeDecodeError):
            # Slow path: walk the structure, base64 encoding every byte
            # string found and remembering where each one lived.
            sanitized, encoded_paths = self._encode(token, [])

            # Record all encoded key paths so the decoder can reverse the
            # transformation. We can safely assume no service uses this key.
            sanitized['boto_encoded_keys'] = encoded_paths

            json_string = json.dumps(sanitized)

        # base64 the json payload to produce the final opaque token string.
        return base64.b64encode(json_string.encode('utf-8')).decode('utf-8')

    def _encode(self, value, path):
        """Dispatch on type; return (new_value, list of encoded paths)."""
        if isinstance(value, dict):
            return self._encode_dict(value, path)
        if isinstance(value, list):
            return self._encode_list(value, path)
        if isinstance(value, six.binary_type):
            return self._encode_bytes(value, path)
        return value, []

    def _encode_list(self, values, path):
        """Encode any bytes in a list, noting the index of what is encoded."""
        pairs = [self._encode(item, path + [index])
                 for index, item in enumerate(values)]
        new_list = [new_value for new_value, _ in pairs]
        encoded_paths = [p for _, paths in pairs for p in paths]
        return new_list, encoded_paths

    def _encode_dict(self, values, path):
        """Encode any bytes in a dict, noting the key of what is encoded."""
        pairs = [(key, self._encode(item, path + [key]))
                 for key, item in values.items()]
        new_dict = dict((key, new_value) for key, (new_value, _) in pairs)
        encoded_paths = [p for _, (_, paths) in pairs for p in paths]
        return new_dict, encoded_paths

    def _encode_bytes(self, value, path):
        """Base64 encode a byte string, reporting its path as encoded."""
        return base64.b64encode(value).decode('utf-8'), [path]
106
107
class TokenDecoder(object):
    """Decodes token strings back into dictionaries.

    This is the inverse of ``TokenEncoder``: it accepts the opaque strings
    produced there and decodes them back into a usable form.
    """

    def decode(self, token):
        """Decodes an opaque string to a dictionary.

        :type token: str
        :param token: A token string given by the botocore pagination
            interface.

        :rtype: dict
        :returns: A dictionary containing pagination information,
            particularly the service pagination token(s) but also other boto
            metadata.
        """
        json_string = base64.b64decode(token.encode('utf-8')).decode('utf-8')
        decoded = json.loads(json_string)

        # Pop the encoding metadata up front; it is only needed while
        # decoding and must not leak back to callers.
        encoded_keys = decoded.pop('boto_encoded_keys', None)
        if encoded_keys is None:
            return decoded
        return self._decode(decoded, encoded_keys)

    def _decode(self, token, encoded_keys):
        """Base64 decode the value stored at each recorded path."""
        for path in encoded_keys:
            encoded_value = self._path_get(token, path)
            self._path_set(token, path,
                           base64.b64decode(encoded_value.encode('utf-8')))
        return token

    def _path_get(self, data, path):
        """Return the nested data at the given path.

        For instance:
            data = {'foo': ['bar', 'baz']}
            path = ['foo', 0]
            ==> 'bar'
        """
        # jmespath is deliberately not used here: constructing a query for
        # arbitrary key structures is harder and more error prone than
        # simply walking the path step by step.
        current = data
        for step in path:
            current = current[step]
        return current

    def _path_set(self, data, path, value):
        """Set the value of a key in the given data.

        Example:
            data = {'foo': ['bar', 'baz']}
            path = ['foo', 1]
            value = 'bin'
            ==> data = {'foo': ['bar', 'bin']}
        """
        parent = self._path_get(data, path[:-1])
        parent[path[-1]] = value
174
175
class PaginatorModel(object):
    """Container for a service's pagination configuration."""

    def __init__(self, paginator_config):
        # Only the 'pagination' section of the config is relevant here.
        self._paginator_config = paginator_config['pagination']

    def get_paginator(self, operation_name):
        """Return the pagination config for a single operation.

        :raises ValueError: If no paginator is defined for the operation.
        """
        if operation_name not in self._paginator_config:
            raise ValueError("Paginator for operation does not exist: %s"
                             % operation_name)
        return self._paginator_config[operation_name]
187
188
class PageIterator(object):
    """An iterable over the pages of results for a pageable operation.

    Iterating yields each parsed response page, transparently issuing
    follow-up requests using the service's pagination token(s).  MaxItems
    (truncating the final page and exposing a ``resume_token``),
    StartingToken (resuming a prior pagination, possibly mid-page via a
    ``boto_truncate_amount`` index) and PageSize are all handled here.
    """
    def __init__(self, method, input_token, output_token, more_results,
                 result_keys, non_aggregate_keys, limit_key, max_items,
                 starting_token, page_size, op_kwargs):
        # Bound client method used to retrieve each page.
        self._method = method
        # Request parameter name(s) through which tokens are sent back.
        self._input_token = input_token
        # Compiled jmespath expression(s) locating the token(s) in a page.
        self._output_token = output_token
        # Optional compiled expression that signals whether more pages exist.
        self._more_results = more_results
        # Compiled expressions for the values aggregated across pages.
        self._result_keys = result_keys
        # Overall cap on items from the primary result key, or None.
        self._max_items = max_items
        # Parameter name for requesting a specific page size, or None.
        self._limit_key = limit_key
        # Opaque token (from a previous pagination) to resume from, or None.
        self._starting_token = starting_token
        self._page_size = page_size
        # Base kwargs included on every request.
        self._op_kwargs = op_kwargs
        # Encoded token for resuming after MaxItems truncation, if any.
        self._resume_token = None
        self._non_aggregate_key_exprs = non_aggregate_keys
        # Non-aggregated values captured from the first page only.
        self._non_aggregate_part = {}
        self._token_encoder = TokenEncoder()
        self._token_decoder = TokenDecoder()

    @property
    def result_keys(self):
        # Compiled jmespath expressions for the aggregated result values.
        return self._result_keys

    @property
    def resume_token(self):
        """Token to specify to resume pagination."""
        return self._resume_token

    @resume_token.setter
    def resume_token(self, value):
        # The resume token must be a dict whose keys are exactly the input
        # token names (plus, optionally, our own truncation marker).  It is
        # stored in its encoded, opaque-string form.
        if not isinstance(value, dict):
            raise ValueError("Bad starting token: %s" % value)

        if 'boto_truncate_amount' in value:
            token_keys = sorted(self._input_token + ['boto_truncate_amount'])
        else:
            token_keys = sorted(self._input_token)
        dict_keys = sorted(value.keys())

        if token_keys == dict_keys:
            self._resume_token = self._token_encoder.encode(value)
        else:
            raise ValueError("Bad starting token: %s" % value)

    @property
    def non_aggregate_part(self):
        # Values outside the result keys, recorded from the first page.
        return self._non_aggregate_part

    def __iter__(self):
        """Yield each (possibly truncated) response page in order."""
        current_kwargs = self._op_kwargs
        previous_next_token = None
        next_token = dict((key, None) for key in self._input_token)
        if self._starting_token is not None:
            # If the starting token exists, populate the next_token with the
            # values inside it. This ensures that we have the service's
            # pagination token on hand if we need to truncate after the
            # first response.
            next_token = self._parse_starting_token()[0]
        # The number of items from result_key we've seen so far.
        total_items = 0
        first_request = True
        primary_result_key = self.result_keys[0]
        starting_truncation = 0
        self._inject_starting_params(current_kwargs)
        while True:
            response = self._make_request(current_kwargs)
            parsed = self._extract_parsed_response(response)
            if first_request:
                # The first request is handled differently.  We could
                # possibly have a resume/starting token that tells us where
                # to index into the retrieved page.
                if self._starting_token is not None:
                    starting_truncation = self._handle_first_request(
                        parsed, primary_result_key, starting_truncation)
                first_request = False
                self._record_non_aggregate_key_values(parsed)
            else:
                # If this isn't the first request, we have already sliced into
                # the first request and had to make additional requests after.
                # We no longer need to add this to truncation.
                starting_truncation = 0
            current_response = primary_result_key.search(parsed)
            if current_response is None:
                current_response = []
            num_current_response = len(current_response)
            # Work out how far past MaxItems this page would take us.
            truncate_amount = 0
            if self._max_items is not None:
                truncate_amount = (total_items + num_current_response) \
                                  - self._max_items
            if truncate_amount > 0:
                # This page overflows MaxItems: trim it, record a resume
                # token, and stop.
                self._truncate_response(parsed, primary_result_key,
                                        truncate_amount, starting_truncation,
                                        next_token)
                yield response
                break
            else:
                yield response
                total_items += num_current_response
                next_token = self._get_next_token(parsed)
                if all(t is None for t in next_token.values()):
                    # No token returned: the service has no more pages.
                    break
                if self._max_items is not None and \
                        total_items == self._max_items:
                    # We're on a page boundary so we can set the current
                    # next token to be the resume token.
                    self.resume_token = next_token
                    break
                if previous_next_token is not None and \
                        previous_next_token == next_token:
                    # A repeated token would loop forever; fail loudly.
                    message = ("The same next token was received "
                               "twice: %s" % next_token)
                    raise PaginationError(message=message)
                self._inject_token_into_kwargs(current_kwargs, next_token)
                previous_next_token = next_token

    def search(self, expression):
        """Applies a JMESPath expression to a paginator

        Each page of results is searched using the provided JMESPath
        expression. If the result is not a list, it is yielded
        directly. If the result is a list, each element in the result
        is yielded individually (essentially implementing a flatmap in
        which the JMESPath search is the mapping function).

        :type expression: str
        :param expression: JMESPath expression to apply to each page.

        :return: Returns an iterator that yields the individual
            elements of applying a JMESPath expression to each page of
            results.
        """
        compiled = jmespath.compile(expression)
        for page in self:
            results = compiled.search(page)
            if isinstance(results, list):
                for element in results:
                    yield element
            else:
                # Yield result directly if it is not a list.
                yield results

    def _make_request(self, current_kwargs):
        """Issue a single request for one page of results."""
        return self._method(**current_kwargs)

    def _extract_parsed_response(self, response):
        # Hook for subclasses whose responses wrap the parsed data.
        return response

    def _record_non_aggregate_key_values(self, response):
        """Capture the non-aggregated values from the first page."""
        non_aggregate_keys = {}
        for expression in self._non_aggregate_key_exprs:
            result = expression.search(response)
            set_value_from_jmespath(non_aggregate_keys,
                                    expression.expression,
                                    result)
        self._non_aggregate_part = non_aggregate_keys

    def _inject_starting_params(self, op_kwargs):
        """Seed op_kwargs with the starting token and page size, if set."""
        # If the user has specified a starting token we need to
        # inject that into the operation's kwargs.
        if self._starting_token is not None:
            # Resume from where a previous pagination left off.
            next_token = self._parse_starting_token()[0]
            self._inject_token_into_kwargs(op_kwargs, next_token)
        if self._page_size is not None:
            # Pass the page size as the parameter name for limiting
            # page size, also known as the limit_key.
            op_kwargs[self._limit_key] = self._page_size

    def _inject_token_into_kwargs(self, op_kwargs, next_token):
        """Copy token values into op_kwargs, dropping empty entries."""
        for name, token in next_token.items():
            # 'None' (the string) can appear in old-style tokens; treat it
            # the same as a missing value.
            if (token is not None) and (token != 'None'):
                op_kwargs[name] = token
            elif name in op_kwargs:
                del op_kwargs[name]

    def _handle_first_request(self, parsed, primary_result_key,
                              starting_truncation):
        """Slice the first page to honor a mid-page starting token.

        Returns the number of items skipped at the start of the page.
        """
        # If the payload is an array or string, we need to slice into it
        # and only return the truncated amount.
        starting_truncation = self._parse_starting_token()[1]
        all_data = primary_result_key.search(parsed)
        if isinstance(all_data, (list, six.string_types)):
            data = all_data[starting_truncation:]
        else:
            data = None
        set_value_from_jmespath(
            parsed,
            primary_result_key.expression,
            data
        )
        # We also need to truncate any secondary result keys
        # because they were not truncated in the previous last
        # response.
        for token in self.result_keys:
            if token == primary_result_key:
                continue
            sample = token.search(parsed)
            # Pick an empty value matching the sampled value's type.
            if isinstance(sample, list):
                empty_value = []
            elif isinstance(sample, six.string_types):
                empty_value = ''
            elif isinstance(sample, (int, float)):
                empty_value = 0
            else:
                empty_value = None
            set_value_from_jmespath(parsed, token.expression, empty_value)
        return starting_truncation

    def _truncate_response(self, parsed, primary_result_key, truncate_amount,
                           starting_truncation, next_token):
        """Trim the final page down to MaxItems and set the resume token."""
        original = primary_result_key.search(parsed)
        if original is None:
            original = []
        amount_to_keep = len(original) - truncate_amount
        truncated = original[:amount_to_keep]
        set_value_from_jmespath(
            parsed,
            primary_result_key.expression,
            truncated
        )
        # The issue here is that even though we know how much we've truncated
        # we need to account for this globally including any starting
        # left truncation. For example:
        # Raw response: [0,1,2,3]
        # Starting index: 1
        # Max items: 1
        # Starting left truncation: [1, 2, 3]
        # End right truncation for max items: [1]
        # However, even though we only kept 1, this is post
        # left truncation so the next starting index should be 2, not 1
        # (left_truncation + amount_to_keep).
        next_token['boto_truncate_amount'] = \
            amount_to_keep + starting_truncation
        self.resume_token = next_token

    def _get_next_token(self, parsed):
        """Extract the next pagination token(s) from a parsed page.

        Returns a dict mapping input token names to values ({} when the
        'more_results' expression says pagination is complete).
        """
        if self._more_results is not None:
            if not self._more_results.search(parsed):
                return {}
        next_tokens = {}
        for output_token, input_key in \
                zip(self._output_token, self._input_token):
            next_token = output_token.search(parsed)
            # We do not want to include any empty strings as actual tokens.
            # Treat them as None.
            if next_token:
                next_tokens[input_key] = next_token
            else:
                next_tokens[input_key] = None
        return next_tokens

    def result_key_iters(self):
        """Return one ResultKeyIterator per configured result key."""
        # tee() lets each per-key iterator consume the pages independently.
        teed_results = tee(self, len(self.result_keys))
        return [ResultKeyIterator(i, result_key) for i, result_key
                in zip(teed_results, self.result_keys)]

    def build_full_result(self):
        """Consume all pages and merge them into a single response dict.

        Aggregates each result key across pages (lists are extended,
        numbers/strings are summed/concatenated), merges in the
        non-aggregated part, and attaches 'NextToken' when truncated.
        """
        complete_result = {}
        for response in self:
            page = response
            # We want to try to catch operation object pagination
            # and format correctly for those. They come in the form
            # of a tuple of two elements: (http_response, parsed_responsed).
            # We want the parsed_response as that is what the page iterator
            # uses. We can remove it though once operation objects are removed.
            if isinstance(response, tuple) and len(response) == 2:
                page = response[1]
            # We're incrementally building the full response page
            # by page.  For each page in the response we need to
            # inject the necessary components from the page
            # into the complete_result.
            for result_expression in self.result_keys:
                # In order to incrementally update a result key
                # we need to search the existing value from complete_result,
                # then we need to search the _current_ page for the
                # current result key value.  Then we append the current
                # value onto the existing value, and re-set that value
                # as the new value.
                result_value = result_expression.search(page)
                if result_value is None:
                    continue
                existing_value = result_expression.search(complete_result)
                if existing_value is None:
                    # Set the initial result
                    set_value_from_jmespath(
                        complete_result, result_expression.expression,
                        result_value)
                    continue
                # Now both result_value and existing_value contain something
                if isinstance(result_value, list):
                    existing_value.extend(result_value)
                elif isinstance(result_value, (int, float, six.string_types)):
                    # Modify the existing result with the sum or concatenation
                    set_value_from_jmespath(
                        complete_result, result_expression.expression,
                        existing_value + result_value)
        merge_dicts(complete_result, self.non_aggregate_part)
        if self.resume_token is not None:
            complete_result['NextToken'] = self.resume_token
        return complete_result

    def _parse_starting_token(self):
        """Decode the starting token into (token_dict, starting_index).

        Falls back to the deprecated '___'-separated format when the token
        is not a valid encoded dict.
        """
        if self._starting_token is None:
            return None

        # The starting token is a dict passed as a base64 encoded string.
        next_token = self._starting_token
        try:
            next_token = self._token_decoder.decode(next_token)
            index = 0
            if 'boto_truncate_amount' in next_token:
                index = next_token.get('boto_truncate_amount')
                del next_token['boto_truncate_amount']
        except (ValueError, TypeError):
            next_token, index = self._parse_starting_token_deprecated()
        return next_token, index

    def _parse_starting_token_deprecated(self):
        """
        This handles parsing of old style starting tokens, and attempts to
        coerce them into the new style.
        """
        log.debug("Attempting to fall back to old starting token parser. For "
                  "token: %s" % self._starting_token)
        if self._starting_token is None:
            return None

        parts = self._starting_token.split('___')
        next_token = []
        index = 0
        if len(parts) == len(self._input_token) + 1:
            try:
                index = int(parts.pop())
            except ValueError:
                # This doesn't look like a valid old-style token, so we're
                # passing it along as an opaque service token.
                parts = [self._starting_token]

        for part in parts:
            if part == 'None':
                next_token.append(None)
            else:
                next_token.append(part)
        return self._convert_deprecated_starting_token(next_token), index

    def _convert_deprecated_starting_token(self, deprecated_token):
        """
        This attempts to convert a deprecated starting token into the new
        style.
        """
        len_deprecated_token = len(deprecated_token)
        len_input_token = len(self._input_token)
        if len_deprecated_token > len_input_token:
            raise ValueError("Bad starting token: %s" % self._starting_token)
        elif len_deprecated_token < len_input_token:
            log.debug("Old format starting token does not contain all input "
                      "tokens. Setting the rest, in order, as None.")
            for i in range(len_input_token - len_deprecated_token):
                deprecated_token.append(None)
        return dict(zip(self._input_token, deprecated_token))
551
552
class Paginator(object):
    """Creates ``PageIterator`` objects for a pageable operation."""

    PAGE_ITERATOR_CLS = PageIterator

    def __init__(self, method, pagination_config, model):
        # Operation model; used to inspect the limit key's shape.
        self._model = model
        # Bound client method that performs the underlying operation.
        self._method = method
        self._pagination_cfg = pagination_config
        self._output_token = self._get_output_tokens(self._pagination_cfg)
        self._input_token = self._get_input_tokens(self._pagination_cfg)
        self._more_results = self._get_more_results_token(self._pagination_cfg)
        self._non_aggregate_keys = self._get_non_aggregate_keys(
            self._pagination_cfg)
        self._result_keys = self._get_result_keys(self._pagination_cfg)
        self._limit_key = self._get_limit_key(self._pagination_cfg)

    @property
    def result_keys(self):
        return self._result_keys

    def _get_non_aggregate_keys(self, config):
        """Compile the configured non-aggregate key expressions."""
        return [jmespath.compile(key)
                for key in config.get('non_aggregate_keys', [])]

    def _get_output_tokens(self, config):
        """Compile the output token expression(s) from the config."""
        output_token = config['output_token']
        if not isinstance(output_token, list):
            output_token = [output_token]
        return [jmespath.compile(expression) for expression in output_token]

    def _get_input_tokens(self, config):
        """Return the input token name(s), always as a list."""
        input_token = self._pagination_cfg['input_token']
        if isinstance(input_token, list):
            return input_token
        return [input_token]

    def _get_more_results_token(self, config):
        """Compile the 'more_results' expression, if one is configured."""
        more_results = config.get('more_results')
        if more_results is None:
            return None
        return jmespath.compile(more_results)

    def _get_result_keys(self, config):
        """Compile the configured result key expression(s), if any."""
        result_key = config.get('result_key')
        if result_key is None:
            return None
        if not isinstance(result_key, list):
            result_key = [result_key]
        return [jmespath.compile(rk) for rk in result_key]

    def _get_limit_key(self, config):
        """Return the parameter name used to limit page size, if any."""
        return config.get('limit_key')

    def paginate(self, **kwargs):
        """Create paginator object for an operation.

        This returns an iterable object.  Iterating over
        this object will yield a single page of a response
        at a time.

        """
        paging_params = self._extract_paging_params(kwargs)
        return self.PAGE_ITERATOR_CLS(
            self._method, self._input_token,
            self._output_token, self._more_results,
            self._result_keys, self._non_aggregate_keys,
            self._limit_key,
            paging_params['MaxItems'],
            paging_params['StartingToken'],
            paging_params['PageSize'],
            kwargs)

    def _extract_paging_params(self, kwargs):
        """Pop PaginationConfig from kwargs and normalize its values.

        :raises PaginationError: If a PageSize is given but the operation
            has no limit key.
        """
        pagination_config = kwargs.pop('PaginationConfig', {})
        max_items = pagination_config.get('MaxItems', None)
        if max_items is not None:
            max_items = int(max_items)
        page_size = pagination_config.get('PageSize', None)
        if page_size is not None:
            if self._limit_key is None:
                raise PaginationError(
                    message="PageSize parameter is not supported for the "
                            "pagination interface for this operation.")
            input_members = self._model.input_shape.members
            limit_key_shape = input_members.get(self._limit_key)
            # String-typed limit keys need a string page size; everything
            # else is coerced to int.
            if limit_key_shape.type_name == 'string':
                if not isinstance(page_size, six.string_types):
                    page_size = str(page_size)
            else:
                page_size = int(page_size)
        return {
            'MaxItems': max_items,
            'StartingToken': pagination_config.get('StartingToken', None),
            'PageSize': page_size,
        }
651
652
class ResultKeyIterator(object):
    """Iterates over the individual results of paginated responses.

    Each instance is tied to exactly one result key.  Iterating yields
    every element found under that result key, page after page.

    :param pages_iterator: An iterator that will give you
        pages of results (a ``PageIterator`` class).
    :param result_key: The JMESPath expression representing
        the result key.

    """

    def __init__(self, pages_iterator, result_key):
        self._pages_iterator = pages_iterator
        self.result_key = result_key

    def __iter__(self):
        for page in self._pages_iterator:
            matches = self.result_key.search(page)
            # A page with no value under the result key yields nothing.
            if matches is None:
                matches = []
            for match in matches:
                yield match
678