1 //! Host header validation.
2 
3 use crate::matcher::{Matcher, Pattern};
4 use std::collections::HashSet;
5 use std::net::SocketAddr;
6 
/// Panic message for `expect` calls that take the first item of a `str::split`
/// iterator. `split` always yields at least one item, so those `expect`s cannot fire.
const SPLIT_PROOF: &str = "split always returns non-empty iterator.";
8 
/// Port component of a host pattern.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Port {
	/// No port given: the rendered host has no `:port` suffix.
	None,
	/// Non-numeric port text kept verbatim (for example a `*` wildcard).
	Pattern(String),
	/// A concrete numeric port.
	Fixed(u16),
}

impl From<Option<u16>> for Port {
	/// `Some(port)` becomes [`Port::Fixed`]; `None` becomes [`Port::None`].
	fn from(opt: Option<u16>) -> Self {
		opt.map_or(Port::None, Port::Fixed)
	}
}

impl From<u16> for Port {
	/// Wraps a numeric port as [`Port::Fixed`].
	fn from(port: u16) -> Port {
		Self::Fixed(port)
	}
}
34 
35 /// Host type
36 #[derive(Clone, Hash, PartialEq, Eq, Debug)]
37 pub struct Host {
38 	hostname: String,
39 	port: Port,
40 	as_string: String,
41 	matcher: Matcher,
42 }
43 
44 impl<T: AsRef<str>> From<T> for Host {
from(string: T) -> Self45 	fn from(string: T) -> Self {
46 		Host::parse(string.as_ref())
47 	}
48 }
49 
50 impl Host {
51 	/// Creates a new `Host` given hostname and port number.
new<T: Into<Port>>(hostname: &str, port: T) -> Self52 	pub fn new<T: Into<Port>>(hostname: &str, port: T) -> Self {
53 		let port = port.into();
54 		let hostname = Self::pre_process(hostname);
55 		let string = Self::to_string(&hostname, &port);
56 		let matcher = Matcher::new(&string);
57 
58 		Host {
59 			hostname,
60 			port,
61 			as_string: string,
62 			matcher,
63 		}
64 	}
65 
66 	/// Attempts to parse given string as a `Host`.
67 	/// NOTE: This method always succeeds and falls back to sensible defaults.
parse(hostname: &str) -> Self68 	pub fn parse(hostname: &str) -> Self {
69 		let hostname = Self::pre_process(hostname);
70 		let mut hostname = hostname.split(':');
71 		let host = hostname.next().expect(SPLIT_PROOF);
72 		let port = match hostname.next() {
73 			None => Port::None,
74 			Some(port) => match port.parse::<u16>().ok() {
75 				Some(num) => Port::Fixed(num),
76 				None => Port::Pattern(port.into()),
77 			},
78 		};
79 
80 		Host::new(host, port)
81 	}
82 
pre_process(host: &str) -> String83 	fn pre_process(host: &str) -> String {
84 		// Remove possible protocol definition
85 		let mut it = host.split("://");
86 		let protocol = it.next().expect(SPLIT_PROOF);
87 		let host = match it.next() {
88 			Some(data) => data,
89 			None => protocol,
90 		};
91 
92 		let mut it = host.split('/');
93 		it.next().expect(SPLIT_PROOF).to_lowercase()
94 	}
95 
to_string(hostname: &str, port: &Port) -> String96 	fn to_string(hostname: &str, port: &Port) -> String {
97 		format!(
98 			"{}{}",
99 			hostname,
100 			match *port {
101 				Port::Fixed(port) => format!(":{}", port),
102 				Port::Pattern(ref port) => format!(":{}", port),
103 				Port::None => "".into(),
104 			},
105 		)
106 	}
107 }
108 
109 impl Pattern for Host {
matches<T: AsRef<str>>(&self, other: T) -> bool110 	fn matches<T: AsRef<str>>(&self, other: T) -> bool {
111 		self.matcher.matches(other)
112 	}
113 }
114 
115 impl ::std::ops::Deref for Host {
116 	type Target = str;
deref(&self) -> &Self::Target117 	fn deref(&self) -> &Self::Target {
118 		&self.as_string
119 	}
120 }
121 
/// Specifies if domains should be validated.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum DomainsValidation<T> {
	/// Allow only domains on the list.
	AllowOnly(Vec<T>),
	/// Disable domains validation completely.
	Disabled,
}

/// `AllowOnly(list)` becomes `Some(list)`; `Disabled` becomes `None`.
///
/// This is implemented as `From` rather than a hand-written `Into`. The standard
/// library's blanket `impl<T, U: From<T>> Into<U> for T` still provides
/// `validation.into()`, and `Option::from(validation)` now works too.
impl<T> From<DomainsValidation<T>> for Option<Vec<T>> {
	fn from(validation: DomainsValidation<T>) -> Self {
		match validation {
			DomainsValidation::AllowOnly(list) => Some(list),
			DomainsValidation::Disabled => None,
		}
	}
}

/// `Some(list)` becomes `AllowOnly(list)`; `None` becomes `Disabled`.
impl<T> From<Option<Vec<T>>> for DomainsValidation<T> {
	fn from(other: Option<Vec<T>>) -> Self {
		other.map_or(DomainsValidation::Disabled, DomainsValidation::AllowOnly)
	}
}
149 
150 /// Returns `true` when `Host` header is whitelisted in `allowed_hosts`.
is_host_valid(host: Option<&str>, allowed_hosts: &Option<Vec<Host>>) -> bool151 pub fn is_host_valid(host: Option<&str>, allowed_hosts: &Option<Vec<Host>>) -> bool {
152 	match allowed_hosts.as_ref() {
153 		None => true,
154 		Some(ref allowed_hosts) => match host {
155 			None => false,
156 			Some(ref host) => allowed_hosts.iter().any(|h| h.matches(host)),
157 		},
158 	}
159 }
160 
161 /// Updates given list of hosts with the address.
update(hosts: Option<Vec<Host>>, address: &SocketAddr) -> Option<Vec<Host>>162 pub fn update(hosts: Option<Vec<Host>>, address: &SocketAddr) -> Option<Vec<Host>> {
163 	use std::net::{IpAddr, Ipv4Addr};
164 
165 	hosts.map(|current_hosts| {
166 		let mut new_hosts = current_hosts.into_iter().collect::<HashSet<_>>();
167 		let address_string = address.to_string();
168 
169 		if address.ip() == IpAddr::V4(Ipv4Addr::UNSPECIFIED) {
170 			new_hosts.insert(address_string.replace("0.0.0.0", "127.0.0.1").into());
171 			new_hosts.insert(address_string.replace("0.0.0.0", "localhost").into());
172 		} else if address.ip() == IpAddr::V4(Ipv4Addr::LOCALHOST) {
173 			new_hosts.insert(address_string.replace("127.0.0.1", "localhost").into());
174 		}
175 
176 		new_hosts.insert(address_string.into());
177 		new_hosts.into_iter().collect()
178 	})
179 }
180 
#[cfg(test)]
mod tests {
	use super::{is_host_valid, update, Host};
	use std::net::SocketAddr;

	#[test]
	fn should_parse_host() {
		assert_eq!(Host::parse("http://parity.io"), Host::new("parity.io", None));
		assert_eq!(
			Host::parse("https://parity.io:8443"),
			Host::new("parity.io", Some(8443))
		);
		assert_eq!(
			Host::parse("chrome-extension://124.0.0.1"),
			Host::new("124.0.0.1", None)
		);
		assert_eq!(Host::parse("parity.io/somepath"), Host::new("parity.io", None));
		assert_eq!(
			Host::parse("127.0.0.1:8545/somepath"),
			Host::new("127.0.0.1", Some(8545))
		);
	}

	#[test]
	fn should_reject_when_there_is_no_header() {
		assert!(!is_host_valid(None, &Some(vec![])));
	}

	// Previously named `should_reject_when_validation_is_disabled`, although it asserts acceptance.
	#[test]
	fn should_accept_when_validation_is_disabled() {
		assert!(is_host_valid(Some("any"), &None));
	}

	#[test]
	fn should_reject_if_header_not_on_the_list() {
		assert!(!is_host_valid(Some("parity.io"), &Some(vec![])));
	}

	#[test]
	fn should_accept_if_on_the_list() {
		assert!(is_host_valid(Some("parity.io"), &Some(vec!["parity.io".into()])));
	}

	#[test]
	fn should_accept_if_on_the_list_with_port() {
		assert!(is_host_valid(Some("parity.io:443"), &Some(vec!["parity.io:443".into()])));
	}

	#[test]
	fn should_support_wildcards() {
		assert!(is_host_valid(Some("parity.web3.site:8180"), &Some(vec!["*.web3.site:*".into()])));
	}

	#[test]
	fn should_add_loopback_aliases_for_unspecified_ipv4_address() {
		let address: SocketAddr = "0.0.0.0:8545".parse().unwrap();
		let hosts = update(Some(vec![]), &address).unwrap();
		for expected in &["0.0.0.0:8545", "127.0.0.1:8545", "localhost:8545"] {
			assert!(hosts.contains(&Host::from(*expected)), "missing {}", expected);
		}
		assert_eq!(hosts.len(), 3);
	}

	#[test]
	fn should_not_update_when_validation_is_disabled() {
		let address: SocketAddr = "127.0.0.1:8545".parse().unwrap();
		assert_eq!(update(None, &address), None);
	}
}
239