1 use winapi::shared::{sspi, winerror};
2 use winapi::shared::minwindef::ULONG;
3 use winapi::um::{minschannel, schannel};
4 use std::mem;
5 use std::ptr;
6 use std::io;
7 
8 use crate::{INIT_REQUESTS, Inner, secbuf, secbuf_desc};
9 use crate::alpn_list::AlpnList;
10 use crate::cert_context::CertContext;
11 use crate::context_buffer::ContextBuffer;
12 
13 use crate::schannel_cred::SchannelCred;
14 
15 pub struct SecurityContext(sspi::CtxtHandle);
16 
17 impl Drop for SecurityContext {
drop(&mut self)18     fn drop(&mut self) {
19         unsafe {
20             sspi::DeleteSecurityContext(&mut self.0);
21         }
22     }
23 }
24 
25 impl Inner<sspi::CtxtHandle> for SecurityContext {
from_inner(inner: sspi::CtxtHandle) -> SecurityContext26     unsafe fn from_inner(inner: sspi::CtxtHandle) -> SecurityContext {
27         SecurityContext(inner)
28     }
29 
as_inner(&self) -> sspi::CtxtHandle30     fn as_inner(&self) -> sspi::CtxtHandle {
31         self.0
32     }
33 
get_mut(&mut self) -> &mut sspi::CtxtHandle34     fn get_mut(&mut self) -> &mut sspi::CtxtHandle {
35         &mut self.0
36     }
37 }
38 
39 impl SecurityContext {
initialize(cred: &mut SchannelCred, accept: bool, domain: Option<&[u16]>, requested_application_protocols: &Option<Vec<Vec<u8>>>) -> io::Result<(SecurityContext, Option<ContextBuffer>)>40     pub fn initialize(cred: &mut SchannelCred,
41                       accept: bool,
42                       domain: Option<&[u16]>,
43                       requested_application_protocols: &Option<Vec<Vec<u8>>>)
44                       -> io::Result<(SecurityContext, Option<ContextBuffer>)> {
45         unsafe {
46             let mut ctxt = mem::zeroed();
47 
48             if accept {
49                 // If we're performing an accept then we need to wait to call
50                 // `AcceptSecurityContext` until we've actually read some data.
51                 return Ok((SecurityContext(ctxt), None))
52             }
53 
54             let domain = domain.map(|b| b.as_ptr() as *mut u16).unwrap_or(ptr::null_mut());
55 
56             let mut inbufs = vec![];
57 
58             // Make sure `AlpnList` is kept alive for the duration of this function.
59             let mut alpns = requested_application_protocols.as_ref().map(|alpn| AlpnList::new(&alpn));
60             if let Some(ref mut alpns) = alpns {
61                 inbufs.push(secbuf(sspi::SECBUFFER_APPLICATION_PROTOCOLS,
62                                    Some(&mut alpns[..])));
63             };
64 
65             let mut inbuf_desc = secbuf_desc(&mut inbufs[..]);
66 
67             let mut outbuf = [secbuf(sspi::SECBUFFER_EMPTY, None)];
68             let mut outbuf_desc = secbuf_desc(&mut outbuf);
69 
70             let mut attributes = 0;
71 
72             match sspi::InitializeSecurityContextW(&mut cred.as_inner(),
73                                                    ptr::null_mut(),
74                                                    domain,
75                                                    INIT_REQUESTS,
76                                                    0,
77                                                    0,
78                                                    &mut inbuf_desc,
79                                                    0,
80                                                    &mut ctxt,
81                                                    &mut outbuf_desc,
82                                                    &mut attributes,
83                                                    ptr::null_mut()) {
84                 winerror::SEC_I_CONTINUE_NEEDED => {
85                     Ok((SecurityContext(ctxt), Some(ContextBuffer(outbuf[0]))))
86                 }
87                 err => {
88                     Err(io::Error::from_raw_os_error(err as i32))
89                 }
90             }
91         }
92     }
93 
attribute<T>(&self, attr: ULONG) -> io::Result<T>94     unsafe fn attribute<T>(&self, attr: ULONG) -> io::Result<T> {
95         let mut value = std::mem::zeroed();
96         let status = sspi::QueryContextAttributesW(&self.0 as *const _ as *mut _,
97                                                    attr,
98                                                    &mut value as *mut _ as *mut _);
99         if status == winerror::SEC_E_OK {
100             Ok(value)
101         } else {
102             Err(io::Error::from_raw_os_error(status as i32))
103         }
104     }
105 
application_protocol(&self) -> io::Result<sspi::SecPkgContext_ApplicationProtocol>106     pub fn application_protocol(&self) -> io::Result<sspi::SecPkgContext_ApplicationProtocol> {
107         unsafe {
108             self.attribute(sspi::SECPKG_ATTR_APPLICATION_PROTOCOL)
109         }
110     }
111 
session_info(&self) -> io::Result<schannel::SecPkgContext_SessionInfo>112     pub fn session_info(&self) -> io::Result<schannel::SecPkgContext_SessionInfo> {
113         unsafe {
114             self.attribute(minschannel::SECPKG_ATTR_SESSION_INFO)
115         }
116     }
117 
stream_sizes(&self) -> io::Result<sspi::SecPkgContext_StreamSizes>118     pub fn stream_sizes(&self) -> io::Result<sspi::SecPkgContext_StreamSizes> {
119         unsafe {
120             self.attribute(sspi::SECPKG_ATTR_STREAM_SIZES)
121         }
122     }
123 
remote_cert(&self) -> io::Result<CertContext>124     pub fn remote_cert(&self) -> io::Result<CertContext> {
125         unsafe {
126             self.attribute(minschannel::SECPKG_ATTR_REMOTE_CERT_CONTEXT)
127                 .map(|p| CertContext::from_inner(p))
128         }
129     }
130 
local_cert(&self) -> io::Result<CertContext>131     pub fn local_cert(&self) -> io::Result<CertContext> {
132         unsafe {
133             self.attribute(minschannel::SECPKG_ATTR_LOCAL_CERT_CONTEXT)
134                 .map(|p| CertContext::from_inner(p))
135         }
136     }
137 }
138