1 use winapi::shared::{sspi, winerror}; 2 use winapi::shared::minwindef::ULONG; 3 use winapi::um::{minschannel, schannel}; 4 use std::mem; 5 use std::ptr; 6 use std::io; 7 8 use crate::{INIT_REQUESTS, Inner, secbuf, secbuf_desc}; 9 use crate::alpn_list::AlpnList; 10 use crate::cert_context::CertContext; 11 use crate::context_buffer::ContextBuffer; 12 13 use crate::schannel_cred::SchannelCred; 14 15 pub struct SecurityContext(sspi::CtxtHandle); 16 17 impl Drop for SecurityContext { drop(&mut self)18 fn drop(&mut self) { 19 unsafe { 20 sspi::DeleteSecurityContext(&mut self.0); 21 } 22 } 23 } 24 25 impl Inner<sspi::CtxtHandle> for SecurityContext { from_inner(inner: sspi::CtxtHandle) -> SecurityContext26 unsafe fn from_inner(inner: sspi::CtxtHandle) -> SecurityContext { 27 SecurityContext(inner) 28 } 29 as_inner(&self) -> sspi::CtxtHandle30 fn as_inner(&self) -> sspi::CtxtHandle { 31 self.0 32 } 33 get_mut(&mut self) -> &mut sspi::CtxtHandle34 fn get_mut(&mut self) -> &mut sspi::CtxtHandle { 35 &mut self.0 36 } 37 } 38 39 impl SecurityContext { initialize(cred: &mut SchannelCred, accept: bool, domain: Option<&[u16]>, requested_application_protocols: &Option<Vec<Vec<u8>>>) -> io::Result<(SecurityContext, Option<ContextBuffer>)>40 pub fn initialize(cred: &mut SchannelCred, 41 accept: bool, 42 domain: Option<&[u16]>, 43 requested_application_protocols: &Option<Vec<Vec<u8>>>) 44 -> io::Result<(SecurityContext, Option<ContextBuffer>)> { 45 unsafe { 46 let mut ctxt = mem::zeroed(); 47 48 if accept { 49 // If we're performing an accept then we need to wait to call 50 // `AcceptSecurityContext` until we've actually read some data. 51 return Ok((SecurityContext(ctxt), None)) 52 } 53 54 let domain = domain.map(|b| b.as_ptr() as *mut u16).unwrap_or(ptr::null_mut()); 55 56 let mut inbufs = vec![]; 57 58 // Make sure `AlpnList` is kept alive for the duration of this function. 59 let mut alpns = requested_application_protocols.as_ref().map(|alpn| AlpnList::new(&alpn)); 60 if let Some(ref mut alpns) = alpns { 61 inbufs.push(secbuf(sspi::SECBUFFER_APPLICATION_PROTOCOLS, 62 Some(&mut alpns[..]))); 63 }; 64 65 let mut inbuf_desc = secbuf_desc(&mut inbufs[..]); 66 67 let mut outbuf = [secbuf(sspi::SECBUFFER_EMPTY, None)]; 68 let mut outbuf_desc = secbuf_desc(&mut outbuf); 69 70 let mut attributes = 0; 71 72 match sspi::InitializeSecurityContextW(&mut cred.as_inner(), 73 ptr::null_mut(), 74 domain, 75 INIT_REQUESTS, 76 0, 77 0, 78 &mut inbuf_desc, 79 0, 80 &mut ctxt, 81 &mut outbuf_desc, 82 &mut attributes, 83 ptr::null_mut()) { 84 winerror::SEC_I_CONTINUE_NEEDED => { 85 Ok((SecurityContext(ctxt), Some(ContextBuffer(outbuf[0])))) 86 } 87 err => { 88 Err(io::Error::from_raw_os_error(err as i32)) 89 } 90 } 91 } 92 } 93 attribute<T>(&self, attr: ULONG) -> io::Result<T>94 unsafe fn attribute<T>(&self, attr: ULONG) -> io::Result<T> { 95 let mut value = std::mem::zeroed(); 96 let status = sspi::QueryContextAttributesW(&self.0 as *const _ as *mut _, 97 attr, 98 &mut value as *mut _ as *mut _); 99 if status == winerror::SEC_E_OK { 100 Ok(value) 101 } else { 102 Err(io::Error::from_raw_os_error(status as i32)) 103 } 104 } 105 application_protocol(&self) -> io::Result<sspi::SecPkgContext_ApplicationProtocol>106 pub fn application_protocol(&self) -> io::Result<sspi::SecPkgContext_ApplicationProtocol> { 107 unsafe { 108 self.attribute(sspi::SECPKG_ATTR_APPLICATION_PROTOCOL) 109 } 110 } 111 session_info(&self) -> io::Result<schannel::SecPkgContext_SessionInfo>112 pub fn session_info(&self) -> io::Result<schannel::SecPkgContext_SessionInfo> { 113 unsafe { 114 self.attribute(minschannel::SECPKG_ATTR_SESSION_INFO) 115 } 116 } 117 stream_sizes(&self) -> io::Result<sspi::SecPkgContext_StreamSizes>118 pub fn stream_sizes(&self) -> io::Result<sspi::SecPkgContext_StreamSizes> { 119 unsafe { 120 self.attribute(sspi::SECPKG_ATTR_STREAM_SIZES) 121 } 122 } 123 remote_cert(&self) -> io::Result<CertContext>124 pub fn remote_cert(&self) -> io::Result<CertContext> { 125 unsafe { 126 self.attribute(minschannel::SECPKG_ATTR_REMOTE_CERT_CONTEXT) 127 .map(|p| CertContext::from_inner(p)) 128 } 129 } 130 local_cert(&self) -> io::Result<CertContext>131 pub fn local_cert(&self) -> io::Result<CertContext> { 132 unsafe { 133 self.attribute(minschannel::SECPKG_ATTR_LOCAL_CERT_CONTEXT) 134 .map(|p| CertContext::from_inner(p)) 135 } 136 } 137 } 138