1 use std::fs;
2 use std::io::{Read, Write};
3 use std::net::{TcpListener, TcpStream};
4 use std::process::{Command, Stdio};
5 use std::string::String;
6 use std::thread;
7 
8 use super::*;
9 
/// Evaluates a `Result`-typed expression and yields its `Ok` value.
///
/// On `Err`, panics with the error's `Debug` rendering as the whole message
/// (no `called Result::unwrap()` prefix, unlike `.unwrap()`).
macro_rules! p {
    ($e:expr) => {
        match $e {
            Err(err) => panic!("{:?}", err),
            Ok(value) => value,
        }
    };
}
18 
#[test]
fn connect_google() {
    let connector = p!(TlsConnector::new());
    let tcp = p!(TcpStream::connect("google.com:443"));
    let mut tls = p!(connector.connect("google.com", tcp));

    // Plain HTTP/1.0 request; the whole response is read until EOF.
    p!(tls.write_all(b"GET / HTTP/1.0\r\n\r\n"));
    let mut response = Vec::new();
    p!(tls.read_to_end(&mut response));

    println!("{}", String::from_utf8_lossy(&response));
    assert!(response.starts_with(b"HTTP/1.0"));
    assert!(response.ends_with(b"</HTML>\r\n") || response.ends_with(b"</html>"));
}
33 
#[test]
fn connect_bad_hostname() {
    // The TCP peer is google.com, but the handshake claims "goggle.com";
    // hostname verification must reject it.
    let connector = p!(TlsConnector::new());
    let tcp = p!(TcpStream::connect("google.com:443"));
    let handshake = connector.connect("goggle.com", tcp);
    handshake.unwrap_err();
}
40 
#[test]
fn connect_bad_hostname_ignored() {
    // Same mismatch as `connect_bad_hostname`, but with hostname checks
    // switched off the handshake is expected to succeed.
    let mut config = TlsConnector::builder();
    config.danger_accept_invalid_hostnames(true);
    let connector = p!(config.build());

    let tcp = p!(TcpStream::connect("google.com:443"));
    let handshake = connector.connect("goggle.com", tcp);
    handshake.unwrap();
}
49 
#[test]
fn connect_no_root_certs() {
    // Built-in roots disabled and no extra roots added: nothing can anchor
    // the server's chain, so the handshake must fail.
    let mut config = TlsConnector::builder();
    config.disable_built_in_roots(true);
    let connector = p!(config.build());

    let tcp = p!(TcpStream::connect("google.com:443"));
    let outcome = connector.connect("google.com", tcp);
    assert!(outcome.is_err());
}
56 
/// A connector with the built-in roots disabled can still reach a server whose
/// certificate chains to a root that was added explicitly.
///
/// The server thread reads "hello" and replies "world"; the client checks the
/// reply and then joins the server thread so its assertions are surfaced.
#[test]
fn server_no_root_certs() {
    let keys = test_cert_gen::keys();

    let identity = p!(Identity::from_pkcs12(
        &keys.server.pkcs12,
        &keys.server.pkcs12_password
    ));
    let builder = p!(TlsAcceptor::new(identity));

    // Listen on loopback only. The client connects via "localhost", and binding
    // 0.0.0.0 would expose the test server on every interface (and may trigger
    // OS firewall prompts).
    let listener = p!(TcpListener::bind("127.0.0.1:0"));
    let port = p!(listener.local_addr()).port();

    let j = thread::spawn(move || {
        let socket = p!(listener.accept()).0;
        let mut socket = p!(builder.accept(socket));

        let mut buf = [0; 5];
        p!(socket.read_exact(&mut buf));
        assert_eq!(&buf, b"hello");

        p!(socket.write_all(b"world"));
    });

    // The generated root that the client trusts.
    let root_ca = Certificate::from_der(&keys.client.cert_der).unwrap();

    let socket = p!(TcpStream::connect(("localhost", port)));
    let builder = p!(TlsConnector::builder()
        .disable_built_in_roots(true)
        .add_root_certificate(root_ca)
        .build());
    let mut socket = p!(builder.connect("localhost", socket));

    p!(socket.write_all(b"hello"));
    let mut buf = vec![];
    p!(socket.read_to_end(&mut buf));
    assert_eq!(buf, b"world");

    p!(j.join());
}
97 
/// Basic TLS round trip between a `TlsAcceptor` and a `TlsConnector`.
///
/// The acceptor serves the generated server identity. The connector trusts the
/// generated root in addition to its default roots. The client sends "hello"
/// and expects "world" back.
#[test]
fn server() {
    let keys = test_cert_gen::keys();

    let identity = p!(Identity::from_pkcs12(
        &keys.server.pkcs12,
        &keys.server.pkcs12_password
    ));
    let builder = p!(TlsAcceptor::new(identity));

    // Listen on loopback only. The client connects via "localhost", and binding
    // 0.0.0.0 would expose the test server on every interface (and may trigger
    // OS firewall prompts).
    let listener = p!(TcpListener::bind("127.0.0.1:0"));
    let port = p!(listener.local_addr()).port();

    let j = thread::spawn(move || {
        let socket = p!(listener.accept()).0;
        let mut socket = p!(builder.accept(socket));

        let mut buf = [0; 5];
        p!(socket.read_exact(&mut buf));
        assert_eq!(&buf, b"hello");

        p!(socket.write_all(b"world"));
    });

    let root_ca = Certificate::from_der(&keys.client.cert_der).unwrap();

    let socket = p!(TcpStream::connect(("localhost", port)));
    let builder = p!(TlsConnector::builder()
        .add_root_certificate(root_ca)
        .build());
    let mut socket = p!(builder.connect("localhost", socket));

    p!(socket.write_all(b"hello"));
    let mut buf = vec![];
    p!(socket.read_to_end(&mut buf));
    assert_eq!(buf, b"world");

    // Join so any assertion failure inside the server thread fails this test.
    p!(j.join());
}
137 
/// Converts the generated DER certificate to PEM with the `openssl x509` CLI.
/// Then checks that `Certificate::from_pem` parses the result back to the
/// original DER bytes.
///
/// Requires an `openssl` executable on `PATH`.
#[test]
fn certificate_from_pem() {
    let dir = tempfile::tempdir().unwrap();
    let keys = test_cert_gen::keys();

    let der_path = dir.path().join("cert.der");
    fs::write(&der_path, &keys.client.cert_der).unwrap();
    let output = Command::new("openssl")
        .arg("x509")
        .arg("-in")
        .arg(&der_path)
        .arg("-inform")
        .arg("der")
        .stderr(Stdio::piped())
        .output()
        .expect("failed to run `openssl`; is it installed and on PATH?");

    // Show openssl's own diagnostics instead of a bare "assertion failed".
    assert!(
        output.status.success(),
        "openssl x509 failed ({}): {}",
        output.status,
        String::from_utf8_lossy(&output.stderr)
    );

    let cert = Certificate::from_pem(&output.stdout).unwrap();
    assert_eq!(cert.to_der().unwrap(), keys.client.cert_der);
}
160 
/// `peer_certificate` on the client returns the server's certificate.
///
/// On the server it returns `None`, because this client presents no
/// certificate of its own.
#[test]
fn peer_certificate() {
    let keys = test_cert_gen::keys();

    let identity = p!(Identity::from_pkcs12(
        &keys.server.pkcs12,
        &keys.server.pkcs12_password
    ));
    let builder = p!(TlsAcceptor::new(identity));

    // Listen on loopback only. The client connects via "localhost", and binding
    // 0.0.0.0 would expose the test server on every interface (and may trigger
    // OS firewall prompts).
    let listener = p!(TcpListener::bind("127.0.0.1:0"));
    let port = p!(listener.local_addr()).port();

    let j = thread::spawn(move || {
        let socket = p!(listener.accept()).0;
        let socket = p!(builder.accept(socket));
        assert!(socket.peer_certificate().unwrap().is_none());
    });

    let root_ca = Certificate::from_der(&keys.client.cert_der).unwrap();

    let socket = p!(TcpStream::connect(("localhost", port)));
    let builder = p!(TlsConnector::builder()
        .add_root_certificate(root_ca)
        .build());
    let socket = p!(builder.connect("localhost", socket));

    let cert = socket.peer_certificate().unwrap().unwrap();
    assert_eq!(cert.to_der().unwrap(), keys.client.cert_der);

    p!(j.join());
}
193 
/// Both ends are pinned to exactly TLS 1.1 (min == max), and a normal
/// "hello"/"world" exchange must still succeed.
///
/// NOTE(review): some TLS backends disable TLS 1.1 by default. Confirm this
/// test is expected to pass on every backend it runs against.
#[test]
fn server_tls11_only() {
    let keys = test_cert_gen::keys();

    let identity = p!(Identity::from_pkcs12(
        &keys.server.pkcs12,
        &keys.server.pkcs12_password
    ));
    let builder = p!(TlsAcceptor::builder(identity)
        .min_protocol_version(Some(Protocol::Tlsv11))
        .max_protocol_version(Some(Protocol::Tlsv11))
        .build());

    // Listen on loopback only. The client connects via "localhost", and binding
    // 0.0.0.0 would expose the test server on every interface (and may trigger
    // OS firewall prompts).
    let listener = p!(TcpListener::bind("127.0.0.1:0"));
    let port = p!(listener.local_addr()).port();

    let j = thread::spawn(move || {
        let socket = p!(listener.accept()).0;
        let mut socket = p!(builder.accept(socket));

        let mut buf = [0; 5];
        p!(socket.read_exact(&mut buf));
        assert_eq!(&buf, b"hello");

        p!(socket.write_all(b"world"));
    });

    let root_ca = Certificate::from_der(&keys.client.cert_der).unwrap();

    let socket = p!(TcpStream::connect(("localhost", port)));
    let builder = p!(TlsConnector::builder()
        .add_root_certificate(root_ca)
        .min_protocol_version(Some(Protocol::Tlsv11))
        .max_protocol_version(Some(Protocol::Tlsv11))
        .build());
    let mut socket = p!(builder.connect("localhost", socket));

    p!(socket.write_all(b"hello"));
    let mut buf = vec![];
    p!(socket.read_to_end(&mut buf));
    assert_eq!(buf, b"world");

    p!(j.join());
}
238 
/// The server requires at least TLS 1.2 while the client only speaks TLS 1.1.
/// Both sides of the handshake must therefore fail.
#[test]
fn server_no_shared_protocol() {
    let keys = test_cert_gen::keys();

    let identity = p!(Identity::from_pkcs12(
        &keys.server.pkcs12,
        &keys.server.pkcs12_password
    ));
    let builder = p!(TlsAcceptor::builder(identity)
        .min_protocol_version(Some(Protocol::Tlsv12))
        .build());

    // Listen on loopback only. The client connects via "localhost", and binding
    // 0.0.0.0 would expose the test server on every interface (and may trigger
    // OS firewall prompts).
    let listener = p!(TcpListener::bind("127.0.0.1:0"));
    let port = p!(listener.local_addr()).port();

    let j = thread::spawn(move || {
        let socket = p!(listener.accept()).0;
        assert!(builder.accept(socket).is_err());
    });

    let root_ca = Certificate::from_der(&keys.client.cert_der).unwrap();

    let socket = p!(TcpStream::connect(("localhost", port)));
    let builder = p!(TlsConnector::builder()
        .add_root_certificate(root_ca)
        .min_protocol_version(Some(Protocol::Tlsv11))
        .max_protocol_version(Some(Protocol::Tlsv11))
        .build());
    assert!(builder.connect("localhost", socket).is_err());

    p!(j.join());
}
271 
/// A default connector, which does not trust the generated test root, must
/// reject the server's certificate.
#[test]
fn server_untrusted() {
    let keys = test_cert_gen::keys();

    let identity = p!(Identity::from_pkcs12(
        &keys.server.pkcs12,
        &keys.server.pkcs12_password
    ));
    let builder = p!(TlsAcceptor::new(identity));

    // Listen on loopback only. The client connects via "localhost", and binding
    // 0.0.0.0 would expose the test server on every interface (and may trigger
    // OS firewall prompts).
    let listener = p!(TcpListener::bind("127.0.0.1:0"));
    let port = p!(listener.local_addr()).port();

    let j = thread::spawn(move || {
        let socket = p!(listener.accept()).0;
        // FIXME should assert error
        // https://github.com/steffengy/schannel-rs/issues/20
        let _ = builder.accept(socket);
    });

    let socket = p!(TcpStream::connect(("localhost", port)));
    let builder = p!(TlsConnector::new());
    builder.connect("localhost", socket).unwrap_err();

    p!(j.join());
}
298 
/// With `danger_accept_invalid_certs(true)`, a connector that does not trust
/// the generated root still completes the handshake and exchanges data.
#[test]
fn server_untrusted_unverified() {
    let keys = test_cert_gen::keys();

    let identity = p!(Identity::from_pkcs12(
        &keys.server.pkcs12,
        &keys.server.pkcs12_password
    ));
    let builder = p!(TlsAcceptor::new(identity));

    // Listen on loopback only. The client connects via "localhost", and binding
    // 0.0.0.0 would expose the test server on every interface (and may trigger
    // OS firewall prompts).
    let listener = p!(TcpListener::bind("127.0.0.1:0"));
    let port = p!(listener.local_addr()).port();

    let j = thread::spawn(move || {
        let socket = p!(listener.accept()).0;
        let mut socket = p!(builder.accept(socket));

        let mut buf = [0; 5];
        p!(socket.read_exact(&mut buf));
        assert_eq!(&buf, b"hello");

        p!(socket.write_all(b"world"));
    });

    let socket = p!(TcpStream::connect(("localhost", port)));
    let builder = p!(TlsConnector::builder()
        .danger_accept_invalid_certs(true)
        .build());
    let mut socket = p!(builder.connect("localhost", socket));

    p!(socket.write_all(b"hello"));
    let mut buf = vec![];
    p!(socket.read_to_end(&mut buf));
    assert_eq!(buf, b"world");

    p!(j.join());
}
336 
#[test]
fn import_same_identity_multiple_times() {
    // Importing the same PKCS#12 archive twice in a row must succeed both times.
    let keys = test_cert_gen::keys();
    let archive = &keys.server.pkcs12;
    let password = &keys.server.pkcs12_password;

    for _attempt in 0..2 {
        let _ = p!(Identity::from_pkcs12(archive, password));
    }
}
350 
/// The client sends "hello" and then calls `shutdown()`. The server must
/// observe EOF (a zero-length read) after the payload and can then shut down
/// its own side cleanly.
#[test]
fn shutdown() {
    let keys = test_cert_gen::keys();

    let identity = p!(Identity::from_pkcs12(
        &keys.server.pkcs12,
        &keys.server.pkcs12_password
    ));
    let builder = p!(TlsAcceptor::new(identity));

    // Listen on loopback only. The client connects via "localhost", and binding
    // 0.0.0.0 would expose the test server on every interface (and may trigger
    // OS firewall prompts).
    let listener = p!(TcpListener::bind("127.0.0.1:0"));
    let port = p!(listener.local_addr()).port();

    let j = thread::spawn(move || {
        let socket = p!(listener.accept()).0;
        let mut socket = p!(builder.accept(socket));

        let mut buf = [0; 5];
        p!(socket.read_exact(&mut buf));
        assert_eq!(&buf, b"hello");

        // After the client's shutdown, further reads report EOF.
        assert_eq!(p!(socket.read(&mut buf)), 0);
        p!(socket.shutdown());
    });

    let root_ca = Certificate::from_der(&keys.client.cert_der).unwrap();

    let socket = p!(TcpStream::connect(("localhost", port)));
    let builder = p!(TlsConnector::builder()
        .add_root_certificate(root_ca)
        .build());
    let mut socket = p!(builder.connect("localhost", socket));

    p!(socket.write_all(b"hello"));
    p!(socket.shutdown());

    p!(j.join());
}
389 
#[test]
#[cfg(feature = "alpn")]
fn alpn_google_h2() {
    // Offering only "h2" to google.com should negotiate exactly "h2".
    let mut config = TlsConnector::builder();
    config.request_alpns(&["h2"]);
    let connector = p!(config.build());

    let tcp = p!(TcpStream::connect("google.com:443"));
    let tls = p!(connector.connect("google.com", tcp));
    let negotiated = p!(tls.negotiated_alpn());
    assert_eq!(negotiated, Some(b"h2".to_vec()));
}
399 
#[test]
#[cfg(feature = "alpn")]
fn alpn_google_invalid() {
    // Offering only "h2c" is expected to produce no negotiated protocol.
    let mut config = TlsConnector::builder();
    config.request_alpns(&["h2c"]);
    let connector = p!(config.build());

    let tcp = p!(TcpStream::connect("google.com:443"));
    let tls = p!(connector.connect("google.com", tcp));
    let negotiated = p!(tls.negotiated_alpn());
    assert_eq!(negotiated, None);
}
409 
#[test]
#[cfg(feature = "alpn")]
fn alpn_google_none() {
    // A connector that requests no ALPN protocols negotiates none.
    let connector = p!(TlsConnector::new());
    let tcp = p!(TcpStream::connect("google.com:443"));
    let tls = p!(connector.connect("google.com", tcp));
    let negotiated = p!(tls.negotiated_alpn());
    assert_eq!(negotiated, None);
}
419