1 #![allow(non_camel_case_types, non_snake_case)]
2
3 use libc::c_void;
4
#[cfg(target_env = "msvc")]
mod win {
    extern crate winapi;
    use self::winapi::ctypes::*;
    use self::winapi::um::libloaderapi::*;
    use self::winapi::um::wincrypt::*;
    use schannel::cert_context::ValidUses;
    use schannel::cert_store::CertStore;
    use std::ffi::CString;
    use std::mem;
    use std::ptr;

    /// Resolves `symbol` from a DLL named `module` that is already loaded in
    /// this process.
    ///
    /// Returns `None` when the module is not loaded or does not export
    /// `symbol`. `GetModuleHandleW` only finds modules that are already
    /// loaded, so this never pulls a new DLL into the process.
    fn lookup(module: &str, symbol: &str) -> Option<*const c_void> {
        // An interior NUL can't name an exported symbol; treat it as "not found".
        let symbol = CString::new(symbol).ok()?;
        let mut mod_buf: Vec<u16> = module.encode_utf16().collect();
        mod_buf.push(0);

        // SAFETY: `mod_buf` is a NUL-terminated UTF-16 string alive for the call.
        let handle = unsafe { GetModuleHandleW(mod_buf.as_ptr()) };
        if handle.is_null() {
            // Module isn't loaded: bail out rather than handing a NULL
            // HMODULE on to GetProcAddress.
            return None;
        }

        // SAFETY: `handle` is non-null and `symbol` is NUL-terminated.
        let n = unsafe { GetProcAddress(handle, symbol.as_ptr()) };
        if n.is_null() {
            None
        } else {
            Some(n as *const c_void)
        }
    }

    // Opaque stand-ins for OpenSSL structs; they are only ever handled
    // through raw pointers.
    pub enum X509_STORE {}
    pub enum X509 {}
    pub enum SSL_CTX {}

    // Signatures of the OpenSSL entry points resolved at runtime.
    type d2i_X509_fn = unsafe extern "C" fn(
        a: *mut *mut X509,
        pp: *mut *const c_uchar,
        length: c_long,
    ) -> *mut X509;
    type X509_free_fn = unsafe extern "C" fn(x: *mut X509);
    type X509_STORE_add_cert_fn =
        unsafe extern "C" fn(store: *mut X509_STORE, x: *mut X509) -> c_int;
    type SSL_CTX_get_cert_store_fn = unsafe extern "C" fn(ctx: *const SSL_CTX) -> *mut X509_STORE;

    /// Function pointers resolved from the loaded OpenSSL DLLs.
    struct OpenSSL {
        d2i_X509: d2i_X509_fn,
        X509_free: X509_free_fn,
        X509_STORE_add_cert: X509_STORE_add_cert_fn,
        SSL_CTX_get_cert_store: SSL_CTX_get_cert_store_fn,
    }

    /// Resolves the four OpenSSL functions this module needs.
    ///
    /// `crypto_module` and `ssl_module` are module names as accepted by
    /// `GetModuleHandleW`. Returns `None` if any symbol cannot be found.
    ///
    /// # Safety
    ///
    /// The raw addresses are transmuted to the `*_fn` signatures above. The
    /// caller must ensure the named modules are an OpenSSL build whose
    /// exports really have those signatures.
    unsafe fn lookup_functions(crypto_module: &str, ssl_module: &str) -> Option<OpenSSL> {
        // Binds each identifier to the address of the export of the same name,
        // returning `None` from this function as soon as one is missing.
        macro_rules! get {
            ($(let $sym:ident in $module:expr;)*) => ($(
                let $sym = lookup($module, stringify!($sym))?;
            )*)
        }
        get! {
            let d2i_X509 in crypto_module;
            let X509_free in crypto_module;
            let X509_STORE_add_cert in crypto_module;
            let SSL_CTX_get_cert_store in ssl_module;
        }
        Some(OpenSSL {
            d2i_X509: mem::transmute(d2i_X509),
            X509_free: mem::transmute(X509_free),
            X509_STORE_add_cert: mem::transmute(X509_STORE_add_cert),
            SSL_CTX_get_cert_store: mem::transmute(SSL_CTX_get_cert_store),
        })
    }

    /// Copies server-authentication certificates from the current user's
    /// Windows "ROOT" store into the `X509_STORE` of the OpenSSL `SSL_CTX`
    /// at `ssl_ctx`.
    ///
    /// This is best-effort. It silently does nothing if any of these hold:
    /// - libcurl reports an unsupported SSL backend or version.
    /// - The OpenSSL DLLs aren't loaded under the expected names.
    /// - The context has no certificate store.
    /// - The Windows store can't be opened.
    ///
    /// # Safety
    ///
    /// `ssl_ctx` must point to a live `SSL_CTX` belonging to the OpenSSL that
    /// libcurl is using at runtime.
    pub unsafe fn add_certs_to_context(ssl_ctx: *mut c_void) {
        // Choose DLL names from the OpenSSL version libcurl reports at runtime.
        // OpenSSL 1.1.1 keeps the 1.1.0 ABI and DLL naming, so every 1.1.x
        // release is handled the same way.
        let openssl = match ::version::Version::get().ssl_version() {
            Some(ssl_ver) if ssl_ver.starts_with("OpenSSL/1.1.") => {
                lookup_functions("libcrypto", "libssl")
            }
            Some(ssl_ver) if ssl_ver.starts_with("OpenSSL/1.0.2") => {
                lookup_functions("libeay32", "ssleay32")
            }
            _ => return,
        };
        let openssl = match openssl {
            Some(s) => s,
            None => return,
        };

        let openssl_store = (openssl.SSL_CTX_get_cert_store)(ssl_ctx as *const SSL_CTX);
        if openssl_store.is_null() {
            // Nothing to add certificates to.
            return;
        }
        let store = match CertStore::open_current_user("ROOT") {
            Ok(s) => s,
            Err(_) => return,
        };

        for cert in store.certs() {
            // Keep only certificates usable for TLS server authentication:
            // either no EKU restriction at all, or an explicit
            // "Server Authentication" OID. Comparing against the `&str`
            // constant avoids allocating a `String` per certificate.
            let allowed = match cert.valid_uses() {
                Ok(ValidUses::All) => true,
                Ok(ValidUses::Oids(ref oids)) => {
                    oids.iter().any(|oid| oid == szOID_PKIX_KP_SERVER_AUTH)
                }
                Err(_) => false,
            };
            if !allowed {
                continue;
            }

            let der = cert.to_der();
            // `d2i_X509` advances the pointer it is given, so hand it a copy.
            let mut der_ptr = der.as_ptr();
            // Certificate encodings are far smaller than `c_long::MAX`, so
            // this cast does not truncate in practice.
            let x509 = (openssl.d2i_X509)(ptr::null_mut(), &mut der_ptr, der.len() as c_long);
            if x509.is_null() {
                continue;
            }
            // NOTE(review): the return value is ignored. On some OpenSSL
            // versions, adding a duplicate certificate fails and may leave an
            // entry on the OpenSSL error queue. Confirm whether it should be
            // cleared.
            (openssl.X509_STORE_add_cert)(openssl_store, x509);
            // X509_STORE_add_cert takes its own reference; release ours.
            (openssl.X509_free)(x509);
        }
    }
}
124
/// Adds certificates from the current user's Windows "ROOT" store to the
/// OpenSSL `SSL_CTX` behind `ssl_ctx`. This variant is compiled for MSVC
/// builds only.
///
/// NOTE(review): this safe `fn` forwards an arbitrary raw pointer into
/// `unsafe` FFI code. Callers must pass a valid `SSL_CTX*`; consider whether
/// this should be an `unsafe fn`.
#[cfg(target_env = "msvc")]
pub fn add_certs_to_context(ssl_ctx: *mut c_void) {
    // SAFETY: relies on the caller supplying a live OpenSSL `SSL_CTX*`. The
    // cast only converts between the `libc` and `winapi` spellings of `c_void`.
    unsafe { win::add_certs_to_context(ssl_ctx as *mut _) }
}
131
/// No-op on non-MSVC targets. The Windows certificate store is only consulted
/// on MSVC builds, so the SSL context is left untouched here.
#[cfg(not(target_env = "msvc"))]
pub fn add_certs_to_context(_ssl_ctx: *mut c_void) {
    // Intentionally empty: nothing to import off MSVC.
}
134