1 use std::alloc;
2 use std::mem;
3 use std::ptr;
4 use std::slice;
5 use winapi::ctypes;
6 use winapi::shared::sspi;
7 
// Size in bytes of the fixed-length fields that precede the variable-length protocol bytes of a
// `SEC_APPLICATION_PROTOCOL_LIST`: a `u32` (presumably the `ProtoNegoExt` discriminant) followed
// by the `c_ushort` that `ProtocolListSize` is written as.
//
// This is manually calculated here rather than using `size_of::<SEC_APPLICATION_PROTOCOL_LIST>()`,
// as the latter is 2 bytes too large because it accounts for padding at the end of the struct for
// alignment requirements, which is irrelevant in actual usage because there is a variable-length
// array at the end of the struct.
const SEC_APPLICATION_PROTOCOL_LIST_HEADER_SIZE: usize =
    mem::size_of::<u32>() + mem::size_of::<ctypes::c_ushort>();
// Size in bytes of everything that precedes the protocol bytes in the whole buffer: the `c_ulong`
// that `ProtocolListsSize` is written as, plus the list header computed above. The allocation
// made by `AlpnList::new` is this value plus the length of the ALPN wire-format bytes.
const SEC_APPLICATION_PROTOCOL_HEADER_SIZE: usize =
    mem::size_of::<ctypes::c_ulong>() + SEC_APPLICATION_PROTOCOL_LIST_HEADER_SIZE;
16 
/// An owned, heap-allocated byte buffer holding an ALPN protocol list.
///
/// The buffer is exposed as a byte slice through `Deref`/`DerefMut` and is released when the
/// value is dropped.
pub struct AlpnList {
    /// Start of the allocation; never null.
    memory: ptr::NonNull<u8>,
    /// Layout the allocation was requested with; its size is the length of the buffer and it
    /// must be handed back unchanged when deallocating.
    layout: alloc::Layout,
}
21 
22 impl Drop for AlpnList {
drop(&mut self)23     fn drop(&mut self) {
24         unsafe {
25             // Safety: `self.memory` was allocated with `self.layout` and is non-null.
26             alloc::dealloc(self.memory.as_ptr(), self.layout);
27         }
28     }
29 }
30 
31 impl AlpnList {
new(protos: &[Vec<u8>]) -> Self32     pub fn new(protos: &[Vec<u8>]) -> Self {
33         // ALPN wire format is each ALPN preceded by its length as a byte.
34         let mut alpn_wire_format = Vec::with_capacity(
35             protos.iter().map(Vec::len).sum::<usize>() + protos.len(),
36         );
37         for alpn in protos {
38             alpn_wire_format.push(alpn.len() as u8);
39             alpn_wire_format.extend(alpn);
40         }
41 
42         let size = SEC_APPLICATION_PROTOCOL_HEADER_SIZE + alpn_wire_format.len();
43         let layout = alloc::Layout::from_size_align(
44             size,
45             mem::align_of::<sspi::SEC_APPLICATION_PROTOCOLS>(),
46         ).unwrap();
47 
48         unsafe {
49             // Safety: `layout` is guaranteed to have non-zero size.
50             let memory = match ptr::NonNull::new(alloc::alloc(layout)) {
51                 Some(ptr) => ptr,
52                 None => alloc::handle_alloc_error(layout),
53             };
54 
55             // Safety: `memory` was created from `layout`.
56             let buf = slice::from_raw_parts_mut(memory.as_ptr(), layout.size());
57             let protocols = &mut *(buf.as_mut_ptr() as *mut sspi::SEC_APPLICATION_PROTOCOLS);
58             protocols.ProtocolListsSize =
59                 (SEC_APPLICATION_PROTOCOL_LIST_HEADER_SIZE + alpn_wire_format.len()) as ctypes::c_ulong;
60 
61             let protocol = &mut *protocols.ProtocolLists.as_mut_ptr();
62             protocol.ProtoNegoExt = sspi::SecApplicationProtocolNegotiationExt_ALPN;
63             protocol.ProtocolListSize = alpn_wire_format.len() as ctypes::c_ushort;
64 
65             let protocol_list_offset = protocol.ProtocolList.as_ptr() as usize - buf.as_ptr() as usize;
66             let protocol_list = &mut buf[protocol_list_offset..];
67             protocol_list.copy_from_slice(&alpn_wire_format);
68 
69             Self {
70                 layout,
71                 memory,
72             }
73         }
74     }
75 }
76 
77 impl std::ops::Deref for AlpnList {
78     type Target = [u8];
79 
deref(&self) -> &Self::Target80     fn deref(&self) -> &Self::Target {
81         unsafe {
82             // Safety: `self.memory` was created from `self.layout`.
83             slice::from_raw_parts(self.memory.as_ptr(), self.layout.size())
84         }
85     }
86 }
87 
88 impl std::ops::DerefMut for AlpnList {
deref_mut(&mut self) -> &mut Self::Target89     fn deref_mut(&mut self) -> &mut Self::Target {
90         unsafe {
91             // Safety: `self.memory` was created from `self.layout`.
92             slice::from_raw_parts_mut(self.memory.as_ptr(), self.layout.size())
93         }
94     }
95 }
96