1 #![allow(non_camel_case_types, non_snake_case)]
2 
3 use libc::c_void;
4 
#[cfg(target_env = "msvc")]
mod win {
    use schannel::cert_context::ValidUses;
    use schannel::cert_store::CertStore;
    use std::convert::TryFrom;
    use std::ffi::CString;
    use std::mem;
    use std::ptr;
    use winapi::ctypes::*;
    use winapi::um::libloaderapi::*;
    use winapi::um::wincrypt::*;

    /// Resolves `symbol` from the module named `module`, which must already be
    /// loaded in this process (`GetModuleHandleW` never loads a module).
    ///
    /// Returns `None` in three cases:
    /// - the module is not loaded;
    /// - the module does not export `symbol`;
    /// - `symbol` contains an interior NUL byte.
    fn lookup(module: &str, symbol: &str) -> Option<*const c_void> {
        let symbol = CString::new(symbol).ok()?;
        let wide_module: Vec<u16> = module.encode_utf16().chain(Some(0)).collect();

        // SAFETY: `wide_module` is a NUL-terminated UTF-16 string and `symbol`
        // is a NUL-terminated C string; both outlive the calls below.
        unsafe {
            let handle = GetModuleHandleW(wide_module.as_ptr());
            // A null handle means the module is not loaded. Do not hand it to
            // GetProcAddress.
            if handle.is_null() {
                return None;
            }
            let proc_addr = GetProcAddress(handle, symbol.as_ptr());
            if proc_addr.is_null() {
                None
            } else {
                Some(proc_addr as *const c_void)
            }
        }
    }

    /// Opaque OpenSSL `X509_STORE`.
    pub enum X509_STORE {}
    /// Opaque OpenSSL `X509` certificate.
    pub enum X509 {}
    /// Opaque OpenSSL `SSL_CTX`.
    pub enum SSL_CTX {}

    /// `X509 *d2i_X509(X509 **a, const unsigned char **pp, long length)`
    type d2i_X509_fn = unsafe extern "C" fn(
        a: *mut *mut X509,
        pp: *mut *const c_uchar,
        length: c_long,
    ) -> *mut X509;
    /// `void X509_free(X509 *x)`
    type X509_free_fn = unsafe extern "C" fn(x: *mut X509);
    /// `int X509_STORE_add_cert(X509_STORE *store, X509 *x)`
    type X509_STORE_add_cert_fn =
        unsafe extern "C" fn(store: *mut X509_STORE, x: *mut X509) -> c_int;
    /// `X509_STORE *SSL_CTX_get_cert_store(const SSL_CTX *ctx)`
    type SSL_CTX_get_cert_store_fn = unsafe extern "C" fn(ctx: *const SSL_CTX) -> *mut X509_STORE;

    /// OpenSSL entry points resolved at runtime from the loaded DLLs.
    struct OpenSSL {
        d2i_X509: d2i_X509_fn,
        X509_free: X509_free_fn,
        X509_STORE_add_cert: X509_STORE_add_cert_fn,
        SSL_CTX_get_cert_store: SSL_CTX_get_cert_store_fn,
    }

    /// Resolves the OpenSSL functions used by `add_certs_to_context` from the
    /// already-loaded `crypto_module` (libcrypto) and `ssl_module` (libssl).
    ///
    /// Returns `None` if any symbol cannot be found.
    ///
    /// # Safety
    ///
    /// The named modules must export these symbols with the C signatures given
    /// by the `*_fn` aliases in this module. The raw addresses are transmuted
    /// to those function-pointer types without any checking.
    unsafe fn lookup_functions(crypto_module: &str, ssl_module: &str) -> Option<OpenSSL> {
        // Each entry binds a local named after the symbol it resolves. If any
        // lookup fails, the whole function returns `None` via `?`.
        macro_rules! get {
            ($(let $sym:ident in $module:expr;)*) => ($(
                let $sym = lookup($module, stringify!($sym))?;
            )*)
        }
        get! {
            let d2i_X509 in crypto_module;
            let X509_free in crypto_module;
            let X509_STORE_add_cert in crypto_module;
            let SSL_CTX_get_cert_store in ssl_module;
        }
        Some(OpenSSL {
            d2i_X509: mem::transmute::<*const c_void, d2i_X509_fn>(d2i_X509),
            X509_free: mem::transmute::<*const c_void, X509_free_fn>(X509_free),
            X509_STORE_add_cert: mem::transmute::<*const c_void, X509_STORE_add_cert_fn>(
                X509_STORE_add_cert,
            ),
            SSL_CTX_get_cert_store: mem::transmute::<*const c_void, SSL_CTX_get_cert_store_fn>(
                SSL_CTX_get_cert_store,
            ),
        })
    }

    /// Adds the current user's Windows "ROOT" certificates that are valid for
    /// TLS server authentication to the certificate store of `ssl_ctx`.
    ///
    /// This is best-effort. It does nothing in any of these cases:
    /// - `ssl_ctx` is null;
    /// - libcurl is not using OpenSSL 1.0.2 or 1.1.x;
    /// - the OpenSSL DLLs or their symbols cannot be found;
    /// - the context has no certificate store;
    /// - the Windows store cannot be opened.
    ///
    /// Certificates that cannot be inspected or parsed are skipped.
    ///
    /// # Safety
    ///
    /// `ssl_ctx` must be null or point to a live `SSL_CTX` created by the same
    /// OpenSSL library whose DLLs are looked up here.
    pub unsafe fn add_certs_to_context(ssl_ctx: *mut c_void) {
        if ssl_ctx.is_null() {
            return;
        }

        // Check the runtime version of OpenSSL. OpenSSL 1.1.1 keeps the 1.1.0
        // ABI for every function used here, so the whole 1.1 series is
        // handled the same way.
        let openssl = match crate::version::Version::get().ssl_version() {
            Some(ssl_ver) if ssl_ver.starts_with("OpenSSL/1.1.") => {
                lookup_functions("libcrypto", "libssl")
            }
            Some(ssl_ver) if ssl_ver.starts_with("OpenSSL/1.0.2") => {
                lookup_functions("libeay32", "ssleay32")
            }
            _ => return,
        };
        let openssl = match openssl {
            Some(s) => s,
            None => return,
        };

        let openssl_store = (openssl.SSL_CTX_get_cert_store)(ssl_ctx as *const SSL_CTX);
        if openssl_store.is_null() {
            return;
        }
        let store = match CertStore::open_current_user("ROOT") {
            Ok(s) => s,
            Err(_) => return,
        };

        for cert in store.certs() {
            // Keep only certificates whose extended key usage is unrestricted
            // or includes the "Server Authentication" OID.
            match cert.valid_uses() {
                Ok(ValidUses::All) => {}
                Ok(ValidUses::Oids(oids)) => {
                    let server_auth = oids
                        .iter()
                        .any(|oid| oid.as_str() == szOID_PKIX_KP_SERVER_AUTH);
                    if !server_auth {
                        continue;
                    }
                }
                Err(_) => continue,
            }

            let der = cert.to_der();
            // `d2i_X509` takes a `long`, which is 32 bits on Windows. Skip
            // anything that does not fit rather than truncating the length.
            let len = match c_long::try_from(der.len()) {
                Ok(len) => len,
                Err(_) => continue,
            };
            // `d2i_X509` advances this cursor past the bytes it parses.
            let mut der_ptr = der.as_ptr();
            let x509 = (openssl.d2i_X509)(ptr::null_mut(), &mut der_ptr, len);
            if x509.is_null() {
                continue;
            }
            // The store takes its own reference, so ours is released right
            // after. The add result is ignored, so a certificate that cannot be
            // added (e.g. a duplicate) is simply skipped.
            // NOTE(review): a failed add may leave an entry on OpenSSL's
            // per-thread error queue. Consider clearing it.
            (openssl.X509_STORE_add_cert)(openssl_store, x509);
            (openssl.X509_free)(x509);
        }
    }
}
123 
/// Imports Windows root certificates into the OpenSSL `SSL_CTX` behind
/// `ssl_ctx`, delegating to the MSVC-only implementation in `win`.
///
/// NOTE(review): this safe function forwards a raw pointer that OpenSSL will
/// dereference. Callers are presumably passing the `SSL_CTX` handed out by
/// libcurl. Confirm against call sites.
#[cfg(target_env = "msvc")]
pub fn add_certs_to_context(ssl_ctx: *mut c_void) {
    // SAFETY: the pointer is forwarded unchanged. It is only reinterpreted as
    // winapi's `c_void` pointer type, which the inner function expects.
    unsafe { win::add_certs_to_context(ssl_ctx.cast()) }
}
130 
/// Does nothing: the Windows root-store import is only compiled for
/// `target_env = "msvc"`, so on every other target the context is left
/// untouched.
#[cfg(not(target_env = "msvc"))]
pub fn add_certs_to_context(_ssl_ctx: *mut c_void) {
    // Intentionally empty.
}
133