1from base64 import b64encode, urlsafe_b64encode
2import calendar
3import hashlib
4import hmac
5import logging
6import math
7import os
8import pprint
9import re
10import sys
11import time
12
13import six
14
15from .exc import (
16    BadHeaderValue,
17    HawkFail,
18    InvalidCredentials)
19
20
HAWK_VER = 1
log = logging.getLogger(__name__)
# Attribute names accepted in a Hawk Authorization header;
# parse_authorization_header rejects any other key.
allowable_header_keys = {'id', 'ts', 'tsm', 'nonce', 'hash',
                         'error', 'ext', 'mac', 'app', 'dlg'}
25
26
def validate_credentials(creds):
    """Checks that *creds* is a mapping providing the keys Hawk needs.

    :param creds: dict-like object expected to provide ``'id'``, ``'key'``
        and ``'algorithm'``.
    :raises InvalidCredentials: if *creds* has no ``__getitem__``, if a
        required key is missing, or if it cannot be indexed by string keys
        (for example a list, tuple or str).
    """
    if not hasattr(creds, '__getitem__'):
        raise InvalidCredentials('credentials must be a dict-like object')
    try:
        for required in ('id', 'key', 'algorithm'):
            creds[required]
    except (KeyError, TypeError) as exc:
        # Sequences have __getitem__ but reject string indices with
        # TypeError; report that the same way as a missing key instead of
        # leaking a raw TypeError to the caller.
        raise InvalidCredentials('{etype}: {val}'
                                 .format(etype=type(exc), val=exc))
38
39
def random_string(length):
    """Generates a random string for a given length.

    ``length`` bytes are read from ``os.urandom`` and URL-safe base64
    encoded. That encoding always yields at least ``length`` characters, and
    the result is cut down to exactly ``length`` of them.
    """
    # Each base64 character carries 6 bits, so drawing 8*length bits and
    # keeping 6*length of them is deliberately conservative: grabbing only
    # (6/8)*length bytes could lose entropy at the ends of the encoding.
    entropy = os.urandom(length)
    encoded = urlsafe_b64encode(entropy)
    return encoded[:length]
45
46
def calculate_payload_hash(payload, algorithm, content_type):
    """Calculates a hash for a given payload.

    The hashed input is four pieces, in order:
    the ``hawk.<HAWK_VER>.payload`` line, the media type from
    ``parse_content_type(content_type)`` on its own line, the payload itself
    and a final newline. Text pieces are UTF-8 encoded before hashing.

    :param payload: body as text or bytes; a falsy value hashes as ''.
    :param algorithm: algorithm name passed to ``hashlib.new`` (which raises
        ValueError for an unsupported name).
    :param content_type: raw Content-Type value; parameters are dropped.
    :returns: the base64-encoded digest (``b64encode`` output).
    """
    p_hash = hashlib.new(algorithm)

    parts = [
        'hawk.' + str(HAWK_VER) + '.payload\n',
        parse_content_type(content_type) + '\n',
        payload or '',
        '\n',
    ]
    # Make sure we are about to hash binary strings.
    parts = [p if isinstance(p, six.binary_type) else p.encode('utf8')
             for p in parts]
    for p in parts:
        p_hash.update(p)

    # pformat() of the full payload is expensive for large bodies, so only
    # build the message when debug logging is actually enabled.
    if log.isEnabledFor(logging.DEBUG):
        log.debug('calculating payload hash from:\n{parts}'
                  .format(parts=pprint.pformat(parts)))

    return b64encode(p_hash.digest())
68
69
def calculate_mac(mac_type, resource, content_hash):
    """Calculates a message authorization code (MAC).

    The MAC is an HMAC over ``normalize_string(mac_type, resource,
    content_hash)``, keyed with ``resource.credentials['key']``. The digest
    constructor is the ``hashlib`` attribute named by
    ``resource.credentials['algorithm']``.

    :returns: the base64-encoded HMAC digest (``b64encode`` output).
    """
    normalized = normalize_string(mac_type, resource, content_hash)
    log.debug(u'normalized resource for mac calc: {norm}'
              .format(norm=normalized))
    digestmod = getattr(hashlib, resource.credentials['algorithm'])

    # Make sure we are about to hash binary strings.
    if not isinstance(normalized, six.binary_type):
        normalized = normalized.encode('utf8')
    key = resource.credentials['key']
    if not isinstance(key, six.binary_type):
        # UTF-8 like the normalized string: ASCII keys yield the same bytes
        # as an ASCII encoding, and non-ASCII keys no longer raise
        # UnicodeEncodeError.
        key = key.encode('utf8')

    result = hmac.new(key, normalized, digestmod)
    return b64encode(result.digest())
87
88
def calculate_ts_mac(ts, credentials):
    """Calculates a message authorization code (MAC) for a timestamp.

    The HMAC input is ``hawk.<HAWK_VER>.ts`` followed by *ts*, each on its
    own line. It is keyed with ``credentials['key']`` and uses the
    ``hashlib`` attribute named by ``credentials['algorithm']``.

    :returns: the base64-encoded HMAC digest (``b64encode`` output).
    """
    normalized = ('hawk.{hawk_ver}.ts\n{ts}\n'
                  .format(hawk_ver=HAWK_VER, ts=ts))
    log.debug(u'normalized resource for ts mac calc: {norm}'
              .format(norm=normalized))
    digestmod = getattr(hashlib, credentials['algorithm'])

    if not isinstance(normalized, six.binary_type):
        normalized = normalized.encode('utf8')
    key = credentials['key']
    if not isinstance(key, six.binary_type):
        # UTF-8 like the normalized string: identical bytes for ASCII keys,
        # and non-ASCII keys no longer raise UnicodeEncodeError.
        key = key.encode('utf8')

    result = hmac.new(key, normalized, digestmod)
    return b64encode(result.digest())
105
106
def normalize_string(mac_type, resource, content_hash):
    """Serializes mac_type and resource into a HAWK string.

    The string has one line per field, each passed through
    normalize_header_attr, in this order:

    - the ``hawk.<HAWK_VER>.<mac_type>`` header line
    - timestamp, nonce, method, name, host and port
    - the content hash
    - ext, which is always present even when empty
    - app and dlg, only when ``resource.app`` is truthy

    The result always ends with a newline.
    """
    escape = normalize_header_attr

    # Field order (and the empty lines for missing values) must match the
    # Node Hawk library byte for byte, or the MACs will not agree.
    lines = ['hawk.' + str(HAWK_VER) + '.' + mac_type]
    lines.append(escape(resource.timestamp))
    lines.append(escape(resource.nonce))
    lines.append(escape(resource.method or ''))
    lines.append(escape(resource.name or ''))
    lines.append(escape(resource.host))
    lines.append(escape(resource.port))
    lines.append(escape(content_hash or ''))
    lines.append(escape(resource.ext or ''))

    if resource.app:
        lines.append(escape(resource.app))
        lines.append(escape(resource.dlg or ''))

    # Trailing newline after the last field.
    return '\n'.join(lines) + '\n'
135
136
def parse_content_type(content_type):
    """Cleans up content_type.

    Returns the media type with any ``;``-separated parameters removed,
    stripped of surrounding whitespace and lower-cased. For example,
    ``'Text/HTML; charset=utf-8'`` becomes ``'text/html'``. A falsy value
    yields ``''``.
    """
    if not content_type:
        return ''
    media_type, _, _params = content_type.partition(';')
    return media_type.strip().lower()
143
144
def _split_header_attributes(attributes_string):
    r"""Splits a Hawk attribute list on commas that are outside quotes.

    Inside a quoted value a backslash escapes the next character (the form
    escape_header_attr produces), so ``\"`` does not close the value. When
    the input has no double quotes the pieces equal ``split(',')``.
    """
    pieces = []
    current = []
    in_quotes = False
    escaped = False
    for char in attributes_string:
        if escaped:
            escaped = False
        elif in_quotes and char == '\\':
            escaped = True
        elif char == '"':
            in_quotes = not in_quotes
        elif char == ',' and not in_quotes:
            pieces.append(''.join(current))
            current = []
            continue
        current.append(char)
    pieces.append(''.join(current))
    return pieces


def parse_authorization_header(auth_header):
    """
    Parses a Hawk Authorization header into a dict of attributes.

    Example Authorization header:

        'Hawk id="dh37fgj492je", ts="1367076201", nonce="NPHgnG", ext="and
        welcome!", mac="CeWHy4d9kbLGhDlkyw2Nh3PJ7SDOdZDa267KH4ZaNMY="'

    :param auth_header: header value as text or UTF-8 encoded bytes.
    :returns: dict mapping attribute names to unquoted, unescaped values.
    :raises HawkFail: if the scheme is not Hawk, nothing follows the scheme,
        a key is not in ``allowable_header_keys``, or a key has no value.
    :raises BadHeaderValue: from validate_header_attr, for a value that
        contains an illegal character.
    """
    attributes = {}

    # Make sure we have a unicode object for consistency.
    if isinstance(auth_header, six.binary_type):
        auth_header = auth_header.decode('utf8')

    # Split the scheme off at the first whitespace run only, so spaces in
    # the first attribute's value (e.g. ext="a b") are preserved.
    scheme_and_rest = auth_header.split(None, 1)
    scheme = scheme_and_rest[0] if scheme_and_rest else ''
    if scheme.lower() != 'hawk':
        raise HawkFail("Unknown scheme '{scheme}' when parsing header"
                       .format(scheme=scheme.lower()))
    if len(scheme_and_rest) < 2:
        raise HawkFail('No attributes found when parsing Hawk header')

    # Commas inside quoted values must not split attributes. A naive
    # split(',') would also let a crafted value such as ext="x,mac=y"
    # smuggle in a separate attribute.
    for part in _split_header_attributes(scheme_and_rest[1]):
        key, has_value, raw_value = part.partition('=')
        key = key.strip()
        if key not in allowable_header_keys:
            raise HawkFail("Unknown Hawk key '{key}' when parsing header"
                           .format(key=key))
        if not has_value:
            raise HawkFail("Missing value for Hawk key '{key}' when parsing "
                           "header".format(key=key))

        # Chop off the surrounding quotation marks (and any whitespace
        # outside them), keeping quotes that are part of the value.
        value = raw_value.strip()
        if value.startswith('"'):
            value = value[1:]
        if value.endswith('"'):
            value = value[:-1]

        validate_header_attr(value, name=key)
        attributes[key] = unescape_header_attr(value)

    if log.isEnabledFor(logging.DEBUG):
        log.debug('parsed Hawk header: {header} into: \n{parsed}'
                  .format(header=auth_header,
                          parsed=pprint.pformat(attributes)))
    return attributes
194
195
def strings_match(a, b):
    """Compares two strings or byte strings in constant time.

    Mitigates timing side channels. It returns False straight away when the
    lengths differ. Otherwise every position is XOR-compared and the
    differences are OR-ed together without breaking early, so the loop's
    duration does not depend on where the first mismatch is.
    """
    if len(a) != len(b):
        return False

    def as_int(item):
        # Iterating bytes on Python 3 already yields ints; one-character
        # strings (text, or bytes on older Pythons) need ord().
        return item if isinstance(item, int) else ord(item)

    diff = 0
    for left, right in zip(a, b):
        diff |= as_int(left) ^ as_int(right)
    return not diff
214
215
def utc_now(offset_in_seconds=0.0):
    """Returns the current UTC Unix timestamp as a whole number of seconds.

    :param offset_in_seconds: seconds added before flooring; any value
        ``float()`` accepts.
    """
    # TODO: add support for SNTP server? See ntplib module.
    seconds_since_epoch = calendar.timegm(time.gmtime())
    adjusted = seconds_since_epoch + float(offset_in_seconds)
    return int(math.floor(adjusted))
220
221
# Allowed value characters:
# !#$%&'()*+,-./:;<=>?@[]^_`{|}~ and space, a-z, A-Z, 0-9, \, "
# Used by validate_header_attr. The pattern is anchored with \Z, not $:
# in Python, $ also matches just before a trailing newline, which would let
# a value such as 'abc\n' pass validation.
_header_attribute_chars = re.compile(
    r"^[ a-zA-Z0-9_\!#\$%&'\(\)\*\+,\-\./\:;<\=>\?@\[\]\^`\{\|\}~\"\\]*\Z")
226
227
def validate_header_attr(val, name=None):
    """Raises BadHeaderValue unless *val* matches _header_attribute_chars.

    :param val: header value text to check.
    :param name: optional attribute name, used only in the error message
        (shown as ``?`` when omitted).
    """
    if _header_attribute_chars.match(val):
        return
    raise BadHeaderValue('header value name={name} value={val} '
                         'contained an illegal character'
                         .format(name=name or '?', val=repr(val)))
233
234
def escape_header_attr(val):
    r"""Escapes a header attribute value like the hawk reference code.

    Byte strings are first decoded as UTF-8, so the result is always text.
    Each ``\`` becomes ``\\``, each ``"`` becomes ``\"`` and each newline
    becomes the two characters ``\n``.
    """
    text = val.decode('utf8') if isinstance(val, six.binary_type) else val
    # A single translate() pass never re-escapes its own output.
    return text.translate({
        ord('\\'): u'\\\\',
        ord('"'): u'\\"',
        ord('\n'): u'\\n',
    })
246
247
# One hawk escape sequence: a backslash followed by '"', '\' or 'n'.
_header_escape_sequence = re.compile(r'\\(["\\n])')


def unescape_header_attr(val):
    r"""Undoes the hawk escaping in a single left-to-right pass.

    ``\"`` becomes ``"``, ``\\`` becomes ``\`` and ``\n`` becomes a newline.
    Any other backslash sequence is left untouched.

    Decoding in one pass is required. With sequential ``replace`` calls, an
    escaped backslash followed by ``n`` (``C:\\new``) is decoded as a
    newline instead of ``C:\new``.
    """
    def _decode(match):
        char = match.group(1)
        return '\n' if char == 'n' else char

    return _header_escape_sequence.sub(_decode, val)
253
254
def prepare_header_val(val):
    """Escapes *val* for use in a Hawk header and checks it is legal.

    :returns: the escaped text produced by escape_header_attr.
    :raises BadHeaderValue: from validate_header_attr, if the escaped value
        contains a disallowed character.
    """
    escaped = escape_header_attr(val)
    validate_header_attr(escaped)
    return escaped
259
260
def normalize_header_attr(val):
    """Returns *val* as escaped text for a Hawk normalized string.

    Falsy values (None, '', 0) become ''. Values that are neither text nor
    bytes, such as an integer port or timestamp, are first converted with
    ``six.text_type``, because escape_header_attr only handles text and bytes.
    """
    if not val:
        val = ''
    elif not isinstance(val, (six.binary_type, six.text_type)):
        # e.g. port=443 -> u'443'
        val = six.text_type(val)

    # Normalize like the hawk reference code.
    return escape_header_attr(val)
268