1"""Copyright 2008 Orbitz WorldWide
2   Copyright 2011 Chris Davis
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8   http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License."""
15
16from hashlib import md5
17from itertools import chain
18import bisect
19import sys
20
21try:
22    import mmh3
23except ImportError:
24    mmh3 = None
25
try:
    import pyhash
    hasher = pyhash.fnv1a_32()

    def fnv32a(data, seed=0x811c9dc5):
        """Return the FNV-1a 32-bit hash of *data*, computed by pyhash."""
        return hasher(data, seed=seed)
except ImportError:
    def fnv32a(data, seed=0x811c9dc5):
        """
        FNV-1a Hash (http://isthe.com/chongo/tech/comp/fnv/) in Python.
        Adapted from https://gist.github.com/vaiorabbit/5670985

        Pure-Python fallback used when pyhash is not installed.  On Python 3
        *data* must be a bytes object (iteration yields ints); on Python 2 it
        is a byte str (iteration yields 1-char strings).  The state starts at
        *seed* (the FNV-1a offset basis by default) and is reduced modulo
        2**32 after every byte.
        """
        fnv_prime = 0x01000193
        modulus = 2 ** 32
        if sys.version_info < (3, 0):
            # Python 2: convert each character to its byte value.
            octets = [ord(ch) for ch in data]
        else:
            # Python 3: bytes already iterate as integers.
            octets = data
        state = seed
        for octet in octets:
            # FNV-1a: xor the byte in first, then multiply by the prime.
            state = ((state ^ octet) * fnv_prime) % modulus
        return state
52
53
def hashRequest(request):
    """Return an md5 hex digest identifying the parameters of *request*.

    *request* must expose ``POST`` and ``GET`` objects with a ``lists()``
    method yielding ``(name, values)`` pairs (presumably Django QueryDicts —
    confirm against callers).  Parameters whose name starts with ``_`` are
    ignored.  Each remaining parameter becomes ``name=v1&v2``, and the
    entries are sorted so the digest does not depend on parameter order.
    """
    all_pairs = chain(request.POST.lists(), request.GET.lists())
    encoded = sorted(
        "%s=%s" % (name, '&'.join(vals))
        for name, vals in all_pairs
        if not name.startswith('_')
    )
    return compactHash(','.join(encoded))
63
64
def hashData(targets, startTime, endTime, xFilesFactor):
    """Return an md5 hex digest identifying a data request.

    The hashed string is ``<sorted targets joined by ','>@<start>:<end>:<xff>``.
    Times are formatted to minute resolution with ``strftime`` (so *startTime*
    and *endTime* are presumably datetime objects — confirm against callers).
    *xFilesFactor* is converted with ``str()``.
    """
    stamp = "%Y%m%d_%H%M"
    window = ':'.join([
        startTime.strftime(stamp),
        endTime.strftime(stamp),
        str(xFilesFactor),
    ])
    return compactHash(','.join(sorted(targets)) + '@' + window)
71
72
def compactHash(string):
    """Return the hexadecimal md5 digest of *string* encoded as UTF-8."""
    digest = md5()
    digest.update(string.encode('utf-8'))
    return digest.hexdigest()
75
76
def carbonHash(key, hash_type):
    """Map *key* to an integer ring position according to *hash_type*.

    - ``'fnv1a_ch'``: FNV-1a 32-bit hash of the UTF-8 encoded key, folded by
      XOR-ing its upper 16 bits with its lower 16 bits.
    - ``'mmh3_ch'``: ``mmh3.hash(key)``; raises ``Exception`` if the mmh3
      package is not installed.
    - anything else: the first four hex digits of the key's md5 digest,
      read as an integer.
    """
    if hash_type == 'fnv1a_ch':
        full = int(fnv32a(key.encode('utf-8')))
        return (full >> 16) ^ (full & 0xffff)
    if hash_type == 'mmh3_ch':
        if mmh3 is None:
            raise Exception('Install "mmh3" to use this hashing function.')
        return mmh3.hash(key)
    return int(compactHash(key)[:4], 16)
89
90
class ConsistentHashRing:
    """Consistent-hash ring mapping keys to nodes.

    ``self.ring`` is a sorted list of ``(position, node)`` tuples.  Each node
    contributes ``replica_count`` entries whose positions come from
    ``carbonHash`` using ``hash_type``.  A colliding position is bumped to the
    next free integer, so positions in the ring are unique.

    :param nodes: iterable of node keys to add initially.  With
        ``hash_type='fnv1a_ch'`` the replica key uses ``node[1]``, so nodes
        are presumably ``(server, instance)`` tuples — confirm with callers.
    :param replica_count: number of ring entries per node.
    :param hash_type: ``'fnv1a_ch'``, ``'mmh3_ch'``, or anything else for the
        md5-based scheme (default ``'carbon_ch'``).
    """

    def __init__(self, nodes, replica_count=100, hash_type='carbon_ch'):
        self.ring = []  # sorted list of (position, node) tuples
        self.ring_len = len(self.ring)
        self.nodes = set()
        self.nodes_len = len(self.nodes)
        self.replica_count = replica_count
        self.hash_type = hash_type
        for node in nodes:
            self.add_node(node)

    def compute_ring_position(self, key):
        """Return the ring position of *key* under this ring's hash type."""
        return carbonHash(key, self.hash_type)

    def add_node(self, key):
        """Insert *key* into the ring with ``replica_count`` entries."""
        self.nodes.add(key)
        self.nodes_len = len(self.nodes)
        # Collect occupied positions once rather than rescanning the whole
        # ring on every collision probe; placement is unchanged.
        taken = {entry[0] for entry in self.ring}
        for i in range(self.replica_count):
            if self.hash_type == 'fnv1a_ch':
                replica_key = "%d-%s" % (i, key[1])
            else:
                replica_key = "%s:%d" % (key, i)
            position = self.compute_ring_position(replica_key)
            # Linear probing: move to the next free integer position.
            while position in taken:
                position = position + 1
            taken.add(position)
            bisect.insort(self.ring, (position, key))
        self.ring_len = len(self.ring)

    def remove_node(self, key):
        """Remove *key* and all of its ring entries; unknown keys are ignored."""
        self.nodes.discard(key)
        self.nodes_len = len(self.nodes)
        self.ring = [entry for entry in self.ring if entry[1] != key]
        self.ring_len = len(self.ring)

    def get_node(self, key):
        """Return the node owning *key*.

        The owner is the first ring entry at or after the key's position,
        wrapping around to the start of the ring.
        """
        assert self.ring
        position = self.compute_ring_position(key)
        # (position,) sorts before every (position, node) tuple, so
        # bisect_left finds the first entry with an equal-or-greater position
        # without comparing node values against a placeholder (comparing ()
        # with a non-tuple node raises TypeError on Python 3).
        index = bisect.bisect_left(self.ring, (position,)) % self.ring_len
        return self.ring[index][1]

    def get_nodes(self, key):
        """Yield distinct nodes in ring order, starting at *key*'s position.

        Yields nothing for an empty ring and exactly one node for a
        single-node ring.
        """
        if not self.ring:
            return
        if self.nodes_len == 1:
            for node in self.nodes:
                yield node
            # Without this return the ring walk below would yield the same
            # node again.
            return
        position = self.compute_ring_position(key)
        index = bisect.bisect_left(self.ring, (position,)) % self.ring_len
        last_index = (index - 1) % self.ring_len
        seen = set()
        # NOTE(review): the walk stops on reaching last_index, so the entry at
        # last_index is never examined. Left unchanged because it affects
        # which nodes are returned — confirm intent before altering.
        while len(seen) < self.nodes_len and index != last_index:
            next_node = self.ring[index][1]
            if next_node not in seen:
                seen.add(next_node)
                yield next_node
            index = (index + 1) % self.ring_len