1 use super::*;
2 use std::collections::BTreeMap;
3 
4 /// A reader of type information from Windows Metadata
5 pub struct TypeReader {
6     types: BTreeMap<&'static str, BTreeMap<&'static str, TypeRow>>,
7     nested: BTreeMap<tables::TypeDef, BTreeMap<&'static str, tables::TypeDef>>,
8 }
9 
10 #[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Debug)]
11 enum TypeRow {
12     TypeDef(tables::TypeDef),
13     Function(tables::MethodDef),
14     Constant(tables::Field),
15 }
16 
17 impl TypeReader {
get() -> &'static Self18     pub fn get() -> &'static Self {
19         use std::{mem::MaybeUninit, sync::Once};
20         static ONCE: Once = Once::new();
21         static mut VALUE: MaybeUninit<TypeReader> = MaybeUninit::uninit();
22 
23         ONCE.call_once(|| {
24             // This is safe because `Once` provides thread-safe one-time initialization
25             unsafe { VALUE = MaybeUninit::new(Self::new()) }
26         });
27 
28         // This is safe because `call_once` has already been called.
29         unsafe { &*VALUE.as_ptr() }
30     }
31 
32     /// Insert WinRT metadata at the given paths
33     ///
34     /// # Panics
35     ///
36     /// This function panics if the if the files where the windows metadata are stored cannot be read.
new() -> Self37     fn new() -> Self {
38         let files = workspace_winmds();
39 
40         let mut types = BTreeMap::<&'static str, BTreeMap<&'static str, TypeRow>>::default();
41 
42         let mut nested =
43             BTreeMap::<tables::TypeDef, BTreeMap<&'static str, tables::TypeDef>>::new();
44 
45         for file in files {
46             let row_count = file.type_def_table().row_count;
47 
48             for row in 0..row_count {
49                 let def = tables::TypeDef(Row::new(row, TableIndex::TypeDef, file));
50                 let namespace = def.namespace();
51                 let name = trim_tick(def.name());
52 
53                 if namespace.is_empty() {
54                     continue;
55                 }
56 
57                 let flags = def.flags();
58                 let extends = def.extends();
59 
60                 if extends == ("System", "Attribute") {
61                     continue;
62                 }
63 
64                 types
65                     .entry(namespace)
66                     .or_default()
67                     .entry(name)
68                     .or_insert_with(|| TypeRow::TypeDef(def));
69 
70                 if flags.interface() || flags.windows_runtime() {
71                     continue;
72                 }
73 
74                 if extends == ("", "") {
75                     continue;
76                 }
77 
78                 if extends != ("System", "Object") {
79                     continue;
80                 }
81 
82                 for field in def.fields() {
83                     let name = field.name();
84 
85                     types
86                         .entry(namespace)
87                         .or_default()
88                         .entry(name)
89                         .or_insert_with(|| TypeRow::Constant(field));
90                 }
91 
92                 for method in def.methods() {
93                     let name = method.name();
94 
95                     types
96                         .entry(namespace)
97                         .or_default()
98                         .entry(name)
99                         .or_insert_with(|| TypeRow::Function(method));
100                 }
101             }
102 
103             let row_count = file.nested_class_table().row_count;
104 
105             for row in 0..row_count {
106                 let row = tables::NestedClass(Row::new(row, TableIndex::NestedClass, file));
107                 let enclosed = row.nested_type();
108                 let enclosing = row.enclosing_type();
109                 let name = enclosed.name();
110 
111                 nested.entry(enclosing).or_default().insert(name, enclosed);
112             }
113         }
114 
115         let exclude = &[
116             ("Windows.Foundation", "HResult"),
117             ("Windows.Win32.Com", "HRESULT"),
118             ("Windows.Win32.Com", "IUnknown"),
119             ("Windows.Win32.WinRT", "HSTRING"),
120             ("Windows.Win32.WinRT", "IActivationFactory"),
121             ("Windows.Win32.Direct2D", "D2D_MATRIX_3X2_F"),
122             ("Windows.Win32.SystemServices", "LARGE_INTEGER"),
123             ("Windows.Win32.SystemServices", "ULARGE_INTEGER"),
124         ];
125 
126         for (namespace, name) in exclude {
127             if let Some(value) = types.get_mut(*namespace) {
128                 value.remove(*name);
129             }
130         }
131 
132         Self { types, nested }
133     }
134 
resolve_namespace(&'static self, find: &str) -> &'static str135     pub fn resolve_namespace(&'static self, find: &str) -> &'static str {
136         self.types
137             .keys()
138             .find(|namespace| *namespace == &find)
139             .expect(&format!("Could not find namespace `{}`", find))
140     }
141 
142     /// Get all the namespace names that the [`TypeReader`] knows about
namespaces(&'static self) -> impl Iterator<Item = &'static str>143     pub fn namespaces(&'static self) -> impl Iterator<Item = &'static str> {
144         self.types.keys().copied()
145     }
146 
147     /// Get all types for a given namespace
148     ///
149     /// # Panics
150     ///
151     /// Panics if the namespace does not exist
namespace_types( &'static self, namespace: &str, ) -> impl Iterator<Item = ElementType> + '_152     pub fn namespace_types(
153         &'static self,
154         namespace: &str,
155     ) -> impl Iterator<Item = ElementType> + '_ {
156         self.types[namespace]
157             .values()
158             .map(move |row| self.to_element_type(row))
159     }
160 
161     // TODO: how to make this return an iterator?
nested_types(&'static self, enclosing: &tables::TypeDef) -> Vec<tables::TypeDef>162     pub fn nested_types(&'static self, enclosing: &tables::TypeDef) -> Vec<tables::TypeDef> {
163         self.nested
164             .get(enclosing)
165             .iter()
166             .flat_map(|t| t.values())
167             .copied()
168             .collect()
169     }
170 
resolve_type(&'static self, namespace: &str, name: &str) -> ElementType171     pub fn resolve_type(&'static self, namespace: &str, name: &str) -> ElementType {
172         if let Some(types) = self.types.get(namespace) {
173             if let Some(row) = types.get(trim_tick(name)) {
174                 return self.to_element_type(row);
175             }
176         }
177 
178         panic!("Could not find type `{}.{}`", namespace, name);
179     }
180 
to_element_type(&'static self, row: &TypeRow) -> ElementType181     fn to_element_type(&'static self, row: &TypeRow) -> ElementType {
182         match row {
183             TypeRow::TypeDef(row) => ElementType::from_type_def(*row, Vec::new()),
184             TypeRow::Function(row) => ElementType::Function(types::Function(*row)),
185             TypeRow::Constant(row) => ElementType::Constant(types::Constant(*row)),
186         }
187     }
188 
resolve_type_def(&'static self, namespace: &str, name: &str) -> tables::TypeDef189     pub fn resolve_type_def(&'static self, namespace: &str, name: &str) -> tables::TypeDef {
190         if let Some(types) = self.types.get(namespace) {
191             if let Some(TypeRow::TypeDef(row)) = types.get(trim_tick(name)) {
192                 return *row;
193             }
194         }
195 
196         panic!("Could not find type def `{}.{}`", namespace, name);
197     }
198 
resolve_type_ref(&'static self, type_ref: &tables::TypeRef) -> tables::TypeDef199     pub fn resolve_type_ref(&'static self, type_ref: &tables::TypeRef) -> tables::TypeDef {
200         if let ResolutionScope::TypeRef(scope) = type_ref.scope() {
201             self.nested[&scope.resolve()][type_ref.name()]
202         } else {
203             self.resolve_type_def(type_ref.namespace(), type_ref.name())
204         }
205     }
206 
207     #[cfg(test)]
get_class(namespace: &str, name: &str) -> types::Class208     pub fn get_class(namespace: &str, name: &str) -> types::Class {
209         if let ElementType::Class(value) = Self::get().resolve_type(namespace, name) {
210             value.clone()
211         } else {
212             unexpected!();
213         }
214     }
215 
216     #[cfg(test)]
get_struct(namespace: &str, name: &str) -> types::Struct217     pub fn get_struct(namespace: &str, name: &str) -> types::Struct {
218         if let ElementType::Struct(value) = Self::get().resolve_type(namespace, name) {
219             value.clone()
220         } else {
221             unexpected!();
222         }
223     }
224 
225     #[cfg(test)]
get_enum(namespace: &str, name: &str) -> types::Enum226     pub fn get_enum(namespace: &str, name: &str) -> types::Enum {
227         if let ElementType::Enum(value) = Self::get().resolve_type(namespace, name) {
228             value.clone()
229         } else {
230             unexpected!();
231         }
232     }
233 
234     #[cfg(test)]
get_interface(namespace: &str, name: &str) -> types::Interface235     pub fn get_interface(namespace: &str, name: &str) -> types::Interface {
236         if let ElementType::Interface(value) = Self::get().resolve_type(namespace, name) {
237             value.clone()
238         } else {
239             unexpected!();
240         }
241     }
242 }
243 
/// Strips a two-character backtick suffix from a metadata type name, such as the generic
/// arity marker in ``IVector`1``, returning the name unchanged otherwise.
///
/// Only a single character after the backtick is removed. Names shorter than two bytes,
/// including the empty string, are returned as-is instead of underflowing `len() - 2`.
fn trim_tick(name: &str) -> &str {
    match name.len().checked_sub(2) {
        // The backtick is ASCII, so `tick` is always a valid char boundary for slicing.
        Some(tick) if name.as_bytes()[tick] == b'`' => &name[..tick],
        _ => name,
    }
}
250