1 // Copyright 2016 conhash-rs developers
2 //
3 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4 // http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5 // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6 // option. This file may not be copied, modified, or distributed
7 // except according to those terms.
8 
9 use std::collections::{BTreeMap, HashMap};
10 
11 use md5;
12 
13 use node::Node;
14 
default_md5_hash_fn(input: &[u8]) -> Vec<u8>15 fn default_md5_hash_fn(input: &[u8]) -> Vec<u8> {
16     let digest = md5::compute(input);
17     digest.to_vec()
18 }
19 
/// Consistent Hash
///
/// Every node added to the container is stored once per replica (virtual
/// node). Each replica is keyed by the hash of `"<node name>:<replica index>"`,
/// and lookups walk the keys in ascending order to pick the node that owns a
/// given hashed key.
pub struct ConsistentHash<N: Node> {
    // Hash function applied both to replica identifiers and to lookup keys.
    hash_fn: fn(&[u8]) -> Vec<u8>,
    // Hashed replica key -> copy of the node owning that replica, kept sorted by key.
    nodes: BTreeMap<Vec<u8>, N>,
    // Node name -> number of replicas registered for it; used by `remove`
    // to regenerate the replica keys that have to be deleted.
    replicas: HashMap<String, usize>,
}
26 
27 impl<N: Node> ConsistentHash<N> {
28     /// Construct with default hash function (Md5)
new() -> ConsistentHash<N>29     pub fn new() -> ConsistentHash<N> {
30         ConsistentHash::with_hash(default_md5_hash_fn)
31     }
32 
33     /// Construct with customized hash function
with_hash(hash_fn: fn(&[u8]) -> Vec<u8>) -> ConsistentHash<N>34     pub fn with_hash(hash_fn: fn(&[u8]) -> Vec<u8>) -> ConsistentHash<N> {
35         ConsistentHash {
36             hash_fn: hash_fn,
37             nodes: BTreeMap::new(),
38             replicas: HashMap::new(),
39         }
40     }
41 
42     /// Add a new node
add(&mut self, node: &N, num_replicas: usize)43     pub fn add(&mut self, node: &N, num_replicas: usize) {
44         let node_name = node.name();
45         debug!("Adding node {:?} with {} replicas", node_name, num_replicas);
46 
47         // Remove it first
48         self.remove(&node);
49 
50         self.replicas.insert(node_name.clone(), num_replicas);
51         for replica in 0..num_replicas {
52             let node_ident = format!("{}:{}", node_name, replica);
53             let key = (self.hash_fn)(node_ident.as_bytes());
54             debug!(
55                 "Adding node {:?} of replica {}, hashed key is {:?}",
56                 node.name(),
57                 replica,
58                 key
59             );
60 
61             self.nodes.insert(key, node.clone());
62         }
63     }
64 
65     /// Get a node by key. Return `None` if no valid node inside
get<'a>(&'a self, key: &[u8]) -> Option<&'a N>66     pub fn get<'a>(&'a self, key: &[u8]) -> Option<&'a N> {
67         let hashed_key = (self.hash_fn)(key);
68         debug!("Getting key {:?}, hashed key is {:?}", key, hashed_key);
69 
70         let mut first_one = None;
71         for (k, v) in self.nodes.iter() {
72             if hashed_key <= *k {
73                 debug!("Found node {:?}", v.name());
74                 return Some(v);
75             }
76 
77             if first_one.is_none() {
78                 first_one = Some(v);
79             }
80         }
81 
82         debug!("Search to the end, coming back to the head ...");
83         match first_one {
84             Some(ref v) => debug!("Found node {:?}", v.name()),
85             None => debug!("The container is empty"),
86         }
87         // Back to the first one
88         first_one
89     }
90 
91     /// Get a node by string key
get_str<'a>(&'a self, key: &str) -> Option<&'a N>92     pub fn get_str<'a>(&'a self, key: &str) -> Option<&'a N> {
93         self.get(key.as_bytes())
94     }
95 
96     /// Get a node by key. Return `None` if no valid node inside
get_mut<'a>(&'a mut self, key: &[u8]) -> Option<&'a mut N>97     pub fn get_mut<'a>(&'a mut self, key: &[u8]) -> Option<&'a mut N> {
98         let hashed_key = (self.hash_fn)(key);
99         debug!("Getting key {:?}, hashed key is {:?}", key, hashed_key);
100 
101         let mut first_one = None;
102         for (k, v) in self.nodes.iter_mut() {
103             if hashed_key <= *k {
104                 debug!("Found node {:?}", v.name());
105                 return Some(v);
106             }
107 
108             if first_one.is_none() {
109                 first_one = Some(v);
110             }
111         }
112 
113         debug!("Search to the end, coming back to the head ...");
114         match first_one {
115             Some(ref v) => debug!("Found node {:?}", v.name()),
116             None => debug!("The container is empty"),
117         }
118         // Back to the first one
119         first_one
120     }
121 
122     /// Get a node by string key
get_str_mut<'a>(&'a mut self, key: &str) -> Option<&'a mut N>123     pub fn get_str_mut<'a>(&'a mut self, key: &str) -> Option<&'a mut N> {
124         self.get_mut(key.as_bytes())
125     }
126 
127     /// Remove a node with all replicas (virtual nodes)
remove(&mut self, node: &N)128     pub fn remove(&mut self, node: &N) {
129         let node_name = node.name();
130         debug!("Removing node {:?}", node_name);
131 
132         let num_replicas = match self.replicas.remove(&node_name) {
133             Some(val) => {
134                 debug!("Node {:?} has {} replicas", node_name, val);
135                 val
136             }
137             None => {
138                 debug!("Node {:?} not exists", node_name);
139                 return;
140             }
141         };
142 
143         debug!("Node {:?} replicas {}", node_name, num_replicas);
144 
145         for replica in 0..num_replicas {
146             let node_ident = format!("{}:{}", node.name(), replica);
147             let key = (self.hash_fn)(node_ident.as_bytes());
148             self.nodes.remove(&key);
149         }
150     }
151 
152     /// Number of nodes
len(&self) -> usize153     pub fn len(&self) -> usize {
154         self.nodes.len()
155     }
156 }
157 
#[cfg(test)]
mod test {
    use super::ConsistentHash;
    use node::Node;

    /// Minimal `Node` implementation identified by `host:port`.
    #[derive(Debug, Clone, Eq, PartialEq)]
    struct ServerNode {
        host: String,
        port: u16,
    }

    impl Node for ServerNode {
        fn name(&self) -> String {
            format!("{}:{}", self.host, self.port)
        }
    }

    impl ServerNode {
        fn new(host: &str, port: u16) -> ServerNode {
            let host = host.to_owned();
            ServerNode {
                host: host,
                port: port,
            }
        }
    }

    /// Adds nine nodes, then checks that the owner of "hello" only changes
    /// when that owner itself is removed, and that `len` tracks replicas.
    #[test]
    fn test_basic() {
        const REPLICAS: usize = 20;

        // localhost:12345 through localhost:12353
        let nodes: Vec<ServerNode> = (12345u16..12354)
            .map(|port| ServerNode::new("localhost", port))
            .collect();

        let mut ch = ConsistentHash::new();
        for node in &nodes {
            ch.add(node, REPLICAS);
        }
        assert_eq!(ch.len(), nodes.len() * REPLICAS);

        let node_for_hello = ch.get_str("hello").unwrap().clone();
        assert_eq!(node_for_hello, ServerNode::new("localhost", 12347));

        // Removing an unrelated node must not move "hello".
        let unrelated = ServerNode::new("localhost", 12350);
        ch.remove(&unrelated);
        let still_owner = ch.get_str("hello").unwrap().clone();
        assert_eq!(still_owner, node_for_hello);
        assert_eq!(ch.len(), (nodes.len() - 1) * REPLICAS);

        // Removing the owner itself must hand "hello" to another node.
        let owner = ServerNode::new("localhost", 12347);
        ch.remove(&owner);
        let new_owner = ch.get_str("hello").unwrap().clone();
        assert_ne!(new_owner, node_for_hello);
        assert_eq!(ch.len(), (nodes.len() - 2) * REPLICAS);
    }
}
222