1"""Backport of urllib3's HTTPHeaderDict for older versions of requests.
2
3This code was originally licensed under the MIT License and is copyrighted by
4Andrey Petrov and contributors to urllib3.
5
6This version was imported from:
7https://github.com/shazow/urllib3/blob/3bd63406bef7c16d007c17563b6af14582567d4b/urllib3/_collections.py
8"""
import sys

try:  # Python 3.3+; the ABC aliases were removed from ``collections`` in 3.10
    from collections.abc import Mapping, MutableMapping
except ImportError:  # Python 2
    from collections import Mapping, MutableMapping
12
13__all__ = ('HTTPHeaderDict',)
14
15PY3 = sys.version_info >= (3, 0)
16
17
class HTTPHeaderDict(MutableMapping):
    """
    :param headers:
        An iterable of field-value pairs, a mapping, or another
        ``HTTPHeaderDict``. Entries are inserted with ``.add``, so names that
        compare equal case-insensitively are merged rather than overwritten.

    :param kwargs:
        Additional field-value pairs, also inserted with ``.add``.

    A ``dict`` like container for storing HTTP Headers.

    Field names are stored and compared case-insensitively in compliance with
    RFC 7230. Iteration provides the first case-sensitive key seen for each
    case-insensitive pair.

    Using ``__setitem__`` syntax overwrites fields that compare equal
    case-insensitively in order to maintain ``dict``'s api. For fields that
    compare equal, instead create a new ``HTTPHeaderDict`` and use ``.add``
    in a loop.

    The inherited ``.update`` goes through ``__setitem__``; if multiple fields
    that are equal case-insensitively are passed to it, only one of them is
    kept. Use ``.extend`` to merge instead.

    >>> headers = HTTPHeaderDict()
    >>> headers.add('Set-Cookie', 'foo=bar')
    >>> headers.add('set-cookie', 'baz=quxx')
    >>> headers['content-length'] = '7'
    >>> headers['SET-cookie']
    'foo=bar, baz=quxx'
    >>> headers['Content-Length']
    '7'
    """

    def __init__(self, headers=None, **kwargs):
        """Create the container and optionally populate it.

        :param headers: Optional initial content; see the class docstring.
        :param kwargs: Extra field-value pairs added after ``headers``.
        """
        super(HTTPHeaderDict, self).__init__()
        # Maps lower-cased field name -> ``(name, value)`` tuple while there
        # is a single value, or ``[name, value, value, ...]`` list once
        # ``add`` has appended further values. ``name`` keeps the casing of
        # the first insertion.
        self._container = {}
        if headers is not None:
            if isinstance(headers, HTTPHeaderDict):
                self._copy_from(headers)
            else:
                self.extend(headers)
        if kwargs:
            self.extend(kwargs)

    def __setitem__(self, key, val):
        """Set ``key`` to the single value ``val``, replacing prior values."""
        key_lower = key.lower()
        self._container[key_lower] = (key, val)
        # The return value is ignored by ``d[k] = v``; it is kept for callers
        # that invoke ``__setitem__`` directly.
        return self._container[key_lower]

    def __getitem__(self, key):
        """Return all values stored for ``key`` joined with ``', '``.

        :raises KeyError: if no field matches ``key`` case-insensitively.
        """
        val = self._container[key.lower()]
        return ', '.join(val[1:])

    def __delitem__(self, key):
        """Remove the field matching ``key`` case-insensitively.

        :raises KeyError: if no such field exists.
        """
        del self._container[key.lower()]

    def __contains__(self, key):
        """Return whether a field matches ``key`` case-insensitively.

        Keys without a ``lower`` method (e.g. ``None`` or integers) can never
        have been stored, so they are reported as absent instead of raising.
        """
        try:
            key_lower = key.lower()
        except AttributeError:
            return False
        return key_lower in self._container

    def __eq__(self, other):
        """Compare merged header values, ignoring the case of field names.

        ``other`` may be any mapping or object exposing ``keys``; anything
        else, or a mapping that cannot be interpreted as headers (non-string
        names or values), compares unequal.
        """
        if not isinstance(other, Mapping) and not hasattr(other, 'keys'):
            return False
        # Computed outside the ``try`` so problems with our own data are not
        # silently turned into "not equal".
        mine = dict((k.lower(), v) for k, v in self.itermerged())
        try:
            if not isinstance(other, type(self)):
                other = type(self)(other)
            theirs = dict((k.lower(), v) for k, v in other.itermerged())
        except (AttributeError, TypeError):
            # ``other`` holds names without ``lower`` or values that cannot
            # be joined as strings: it cannot equal a header dict.
            return False
        return mine == theirs

    def __ne__(self, other):
        """Return the negation of ``__eq__`` (needed explicitly on Python 2)."""
        return not self.__eq__(other)

    if sys.version_info < (3, 0):  # Python 2
        iterkeys = MutableMapping.iterkeys
        itervalues = MutableMapping.itervalues

    # Private sentinel so that ``None`` can be passed as a default to ``pop``.
    __marker = object()

    def __len__(self):
        """Return the number of distinct (case-insensitive) field names."""
        return len(self._container)

    def __iter__(self):
        """Iterate over field names using the casing of their first insert."""
        # Only provide the originally cased names
        for vals in self._container.values():
            yield vals[0]

    def pop(self, key, default=__marker):
        """D.pop(k[,d]) -> v, remove specified key and return the value.

        If key is not found, d is returned if given, otherwise KeyError is
        raised.
        """
        # Using the MutableMapping function directly fails due to the private
        # marker.
        # Using ordinary dict.pop would expose the internal structures.
        # So let's reinvent the wheel.
        try:
            value = self[key]
        except KeyError:
            if default is self.__marker:
                raise
            return default
        else:
            del self[key]
            return value

    def discard(self, key):
        """Remove ``key`` if present; do nothing when it is missing."""
        try:
            del self[key]
        except KeyError:
            pass

    def add(self, key, val):
        """Adds a (name, value) pair, doesn't overwrite the value if it already
        exists.

        >>> headers = HTTPHeaderDict(foo='bar')
        >>> headers.add('Foo', 'baz')
        >>> headers['foo']
        'bar, baz'
        """
        key_lower = key.lower()
        new_vals = key, val
        # Keep the common case aka no item present as fast as possible
        vals = self._container.setdefault(key_lower, new_vals)
        if new_vals is not vals:
            # new_vals was not inserted, as there was a previous one
            if isinstance(vals, list):
                # If already several items got inserted, we have a list
                vals.append(val)
            else:
                # vals should be a tuple then, i.e. only one item so far
                # Need to convert the tuple to list for further extension
                self._container[key_lower] = [vals[0], vals[1], val]

    def extend(self, *args, **kwargs):
        """Generic import function for any type of header-like object.
        Adapted version of MutableMapping.update in order to insert items
        with self.add instead of self.__setitem__

        Accepts at most one positional argument (an ``HTTPHeaderDict``, a
        mapping, an object with ``keys`` and ``__getitem__``, or an iterable
        of pairs) plus keyword arguments.

        :raises TypeError: if more than one positional argument is given.
        """
        if len(args) > 1:
            raise TypeError("extend() takes at most 1 positional "
                            "arguments ({0} given)".format(len(args)))
        other = args[0] if len(args) >= 1 else ()

        if isinstance(other, HTTPHeaderDict):
            # iteritems yields one pair per stored value, so duplicates
            # survive instead of being merged into one string.
            for key, val in other.iteritems():
                self.add(key, val)
        elif isinstance(other, Mapping):
            for key in other:
                self.add(key, other[key])
        elif hasattr(other, "keys"):
            for key in other.keys():
                self.add(key, other[key])
        else:
            for key, value in other:
                self.add(key, value)

        for key, value in kwargs.items():
            self.add(key, value)

    def getlist(self, key):
        """Returns a list of all the values for the named field. Returns an
        empty list if the key doesn't exist."""
        try:
            vals = self._container[key.lower()]
        except KeyError:
            return []
        # Works for both storage shapes (tuple or list) and always hands out
        # a fresh list, so callers cannot mutate the internal structure.
        return list(vals[1:])

    # Backwards compatibility for httplib
    getheaders = getlist
    getallmatchingheaders = getlist
    iget = getlist

    def __repr__(self):
        """Return ``ClassName({...})`` showing the merged header values."""
        return "%s(%s)" % (type(self).__name__, dict(self.itermerged()))

    def _copy_from(self, other):
        """Copy every field of ``other`` (an ``HTTPHeaderDict``) into self.

        Existing fields with the same case-insensitive name are replaced.
        """
        for key in other:
            # getlist returns a new list, so no storage is shared with other.
            self._container[key.lower()] = [key] + other.getlist(key)

    def copy(self):
        """Return a new instance of the same type with the same fields."""
        clone = type(self)()
        clone._copy_from(self)
        return clone

    def iteritems(self):
        """Iterate over all header lines, including duplicate ones."""
        for key in self:
            vals = self._container[key.lower()]
            for val in vals[1:]:
                yield vals[0], val

    def itermerged(self):
        """Iterate over all headers, merging duplicate ones together."""
        for key in self:
            val = self._container[key.lower()]
            yield val[0], ', '.join(val[1:])

    def items(self):
        """Return ``iteritems()`` as a list (one pair per stored value)."""
        return list(self.iteritems())

    @classmethod
    def from_httplib(cls, message):  # Python 2
        """Read headers from a Python 2 httplib message object.

        :param message: Object whose ``headers`` attribute is an iterable of
            raw header lines.
        :raises IndexError: if the first line is a continuation line (starts
            with a space or tab) and so has no header to attach to.
        :raises ValueError: if a non-continuation line contains no ``':'``.
        """
        # python2.7 does not expose a proper API for exporting multiheaders
        # efficiently. This function re-reads raw lines from the message
        # object and extracts the multiheaders properly.
        headers = []

        for line in message.headers:
            if line.startswith((' ', '\t')):
                # Folded continuation: append to the previous header's value,
                # keeping the line break and the leading whitespace.
                key, value = headers[-1]
                headers[-1] = (key, value + '\r\n' + line.rstrip())
                continue

            key, value = line.split(':', 1)
            headers.append((key, value.strip()))

        return cls(headers)
245