1 use super::*;
2 
/// The list of types named in the attribute's arguments, each one resolved
/// into a `gen::ElementType` while the attribute is being parsed.
// TODO: distinguish between COM and WinRT interfaces
struct Implements(Vec<gen::ElementType>);
5 
6 impl syn::parse::Parse for Implements {
parse(input: syn::parse::ParseStream) -> syn::parse::Result<Self>7     fn parse(input: syn::parse::ParseStream) -> syn::parse::Result<Self> {
8         let mut types = Vec::new();
9         let reader = gen::TypeReader::get();
10 
11         loop {
12             if input.is_empty() {
13                 break;
14             }
15 
16             use_tree_to_types(reader, &input.parse::<ImplementTree>()?, &mut types)?;
17 
18             if !input.is_empty() {
19                 input.parse::<syn::Token![,]>()?;
20             }
21         }
22 
23         Ok(Self(types))
24     }
25 }
26 
use_tree_to_types( reader: &'static gen::TypeReader, tree: &ImplementTree, types: &mut Vec<gen::ElementType>, ) -> syn::parse::Result<()>27 fn use_tree_to_types(
28     reader: &'static gen::TypeReader,
29     tree: &ImplementTree,
30     types: &mut Vec<gen::ElementType>,
31 ) -> syn::parse::Result<()> {
32     fn recurse(
33         reader: &'static gen::TypeReader,
34         tree: &ImplementTree,
35         types: &mut Vec<gen::ElementType>,
36         current: &mut String,
37     ) -> syn::parse::Result<()> {
38         match tree {
39             ImplementTree::Path(path) => {
40                 if !current.is_empty() {
41                     current.push('.');
42                 }
43 
44                 current.push_str(&path.ident.to_string());
45                 recurse(reader, &*path.tree, types, current)?;
46             }
47             ImplementTree::Group(group) => {
48                 let prev = current.clone();
49 
50                 for tree in &group.items {
51                     recurse(reader, tree, types, current)?;
52                     *current = prev.clone();
53                 }
54             }
55             ImplementTree::Name(name) => {
56                 let namespace = current.trim_matches('"');
57 
58                 let mut meta_name = name.ident.to_string();
59                 let generic_count = name.generics.params.len();
60 
61                 if generic_count > 0 {
62                     meta_name.push('`');
63                     meta_name.push_str(&generic_count.to_string());
64                 }
65 
66                 types.push(reader.resolve_type(namespace, &meta_name));
67 
68                 // TODO
69                 // If type is a class, add any required interfaces.
70                 // If type is an interface, add any required interfaces.
71                 // If any other kind of type, return an error.
72                 // If more than one class, return an error.
73                 // If dupe interface, produce warning but continue,
74                 //   unless warning is unavoidable (same interface required by different mentioned interfaces)
75                 // Finally, remove any dupes (TypeName can be used as key for set container)
76             }
77         }
78 
79         Ok(())
80     }
81 
82     recurse(reader, tree, types, &mut String::new())
83 }
84 
gen( attribute: proc_macro::TokenStream, original_type: proc_macro::TokenStream, ) -> proc_macro::TokenStream85 pub fn gen(
86     attribute: proc_macro::TokenStream,
87     original_type: proc_macro::TokenStream,
88 ) -> proc_macro::TokenStream {
89     let inner_type = original_type.clone();
90 
91     let implements = syn::parse_macro_input!(attribute as Implements);
92     let inner_type = syn::parse_macro_input!(inner_type as syn::ItemStruct);
93     let inner_name = inner_type.ident.to_string();
94     let inner_ident = format_ident!("{}", inner_name); // because squote doesn't know how to deal with syn::*
95     let box_ident = format_ident!("{}_box", inner_name);
96 
97     let mut tokens = TokenStream::new();
98     let mut vtable_idents = vec![];
99     let mut vtable_ordinals = vec![];
100     let mut vtable_ctors = TokenStream::new();
101     let mut shims = TokenStream::new();
102     let mut queries = TokenStream::new();
103 
104     for (interface_count, implement) in implements.0.iter().enumerate() {
105         if let gen::ElementType::Interface(t) = implement {
106             vtable_ordinals.push(Literal::usize_unsuffixed(interface_count));
107 
108             let query_interface = format_ident!("QueryInterface_abi{}", interface_count);
109             let add_ref = format_ident!("AddRef_abi{}", interface_count);
110             let release = format_ident!("Release_abi{}", interface_count);
111 
112             let mut vtable_ptrs = quote! {
113                 Self::#query_interface,
114                 Self::#add_ref,
115                 Self::#release,
116                 Self::GetIids,
117                 Self::GetRuntimeClassName,
118                 Self::GetTrustLevel,
119             };
120 
121             shims.combine(&quote! {
122                 unsafe extern "system" fn #query_interface(this: ::windows::RawPtr, iid: &::windows::Guid, interface: *mut ::windows::RawPtr) -> ::windows::HRESULT {
123                     let this = (this as *mut ::windows::RawPtr).sub(#interface_count) as *mut Self;
124                     (*this).QueryInterface(iid, interface)
125                 }
126                 unsafe extern "system" fn #add_ref(this: ::windows::RawPtr) -> u32 {
127                     let this = (this as *mut ::windows::RawPtr).sub(#interface_count) as *mut Self;
128                     (*this).AddRef()
129                 }
130                 unsafe extern "system" fn #release(this: ::windows::RawPtr) -> u32 {
131                     let this = (this as *mut ::windows::RawPtr).sub(#interface_count) as *mut Self;
132                     (*this).Release()
133                 }
134             });
135 
136             let empty = gen::TypeTree::from_namespace("");
137             let gen = gen::Gen::absolute(&empty);
138 
139             let vtable_ident = t.0.gen_abi_name(&gen);
140             let interface_ident = t.0.gen_name(&gen);
141             let interface_literal = Literal::usize_unsuffixed(interface_count);
142 
143             for (vtable_offset, method) in t.0.def.methods().enumerate() {
144                 let method_ident = gen::to_ident(&method.rust_name());
145                 let vcall_ident = format_ident!("abi{}_{}", interface_count, vtable_offset + 6);
146 
147                 vtable_ptrs.combine(&quote! {
148                     Self::#vcall_ident,
149                 });
150 
151                 let signature = method.signature(&[]);
152                 let abi_signature = signature.gen_winrt_abi(&gen);
153                 let upcall =
154                     signature.gen_winrt_upcall(quote! { (*this).inner.#method_ident }, &gen);
155 
156                 shims.combine(&quote! {
157                     unsafe extern "system" fn #vcall_ident #abi_signature {
158                         let this = (this as *mut ::windows::RawPtr).sub(#interface_count) as *mut Self;
159                         #upcall
160                     }
161                 });
162 
163                 queries.combine(&quote! {
164                     &<#interface_ident as ::windows::Interface>::IID => {
165                         &mut self.vtable.#interface_literal as *mut _ as _
166                     }
167                 });
168             }
169 
170             tokens.combine(&quote! {
171                 impl ::std::convert::From<#inner_ident> for #interface_ident {
172                     fn from(inner: #inner_ident) -> Self {
173                         let com = #box_ident::new(inner);
174 
175                         unsafe {
176                             let ptr = ::std::boxed::Box::into_raw(::std::boxed::Box::new(com));
177                             ::std::mem::transmute_copy(&::std::ptr::NonNull::new_unchecked(&mut (*ptr).vtable.#interface_literal as *mut _ as _))
178                         }
179                     }
180                 }
181             });
182 
183             vtable_ctors.combine(&quote! {
184                 #vtable_ident(
185                     #vtable_ptrs
186                 ),
187             });
188 
189             vtable_idents.push(vtable_ident);
190         }
191     }
192 
193     tokens.combine(&quote! {
194         #[repr(C)]
195         struct #box_ident {
196             vtable: (#(*const #vtable_idents,)*),
197             inner: #inner_ident,
198             count: ::windows::RefCount,
199         }
200         impl #box_ident {
201             const VTABLE: (#(#vtable_idents,)*) = (
202                 #vtable_ctors
203             );
204             fn new(inner: #inner_ident) -> Self {
205                 Self {
206                     vtable: (#(&Self::VTABLE.#vtable_ordinals,)*),
207                     inner,
208                     count: ::windows::RefCount::new()
209                 }
210             }
211             fn QueryInterface(&mut self, iid: &::windows::Guid, interface: *mut ::windows::RawPtr) -> ::windows::HRESULT {
212                 unsafe {
213                     *interface = match iid {
214                         #queries
215                         &<::windows::IUnknown as ::windows::Interface>::IID
216                         | &<::windows::Object as ::windows::Interface>::IID
217                         | &<::windows::IAgileObject as ::windows::Interface>::IID => {
218                             &mut self.vtable.0 as *mut _ as _
219                         }
220                         _ => ::std::ptr::null_mut(),
221                     };
222 
223                     if (*interface).is_null() {
224                         ::windows::HRESULT(0x8000_4002) // E_NOINTERFACE
225                     } else {
226                         self.count.add_ref();
227                         ::windows::HRESULT(0)
228                     }
229                 }
230             }
231             fn AddRef(&mut self) -> u32 {
232                 self.count.add_ref()
233             }
234             fn Release(&mut self) -> u32 {
235                 let remaining = self.count.release();
236                 if remaining == 0 {
237                     unsafe {
238                         ::std::boxed::Box::from_raw(self);
239                     }
240                 }
241                 remaining
242             }
243             unsafe extern "system" fn GetIids(
244                 _: ::windows::RawPtr,
245                 count: *mut u32,
246                 values: *mut *mut ::windows::Guid,
247             ) -> ::windows::HRESULT {
248                 // Note: even if we end up implementing this in future, it still doesn't need a this pointer
249                 // since the data to be returned is type- not instance-specific so can be shared for all
250                 // interfaces.
251                 *count = 0;
252                 *values = ::std::ptr::null_mut();
253                 ::windows::HRESULT(0)
254             }
255             unsafe extern "system" fn GetRuntimeClassName(
256                 _: ::windows::RawPtr,
257                 value: *mut ::windows::RawPtr,
258             ) -> ::windows::HRESULT {
259                 let h: ::windows::HString = "Thing".into(); // TODO: replace with class name or first interface
260                 *value = ::std::mem::transmute(h);
261                 ::windows::HRESULT(0)
262             }
263             unsafe extern "system" fn GetTrustLevel(_: ::windows::RawPtr, value: *mut i32) -> ::windows::HRESULT {
264                 // Note: even if we end up implementing this in future, it still doesn't need a this pointer
265                 // since the data to be returned is type- not instance-specific so can be shared for all
266                 // interfaces.
267                 *value = 0;
268                 ::windows::HRESULT(0)
269             }
270             #shims
271         }
272     });
273 
274     let mut tokens = tokens.parse::<proc_macro::TokenStream>().unwrap();
275     tokens.extend(std::iter::once(original_type));
276     tokens
277 }
278