1 use super::*;
2 use std::collections::BTreeMap;
3
4 /// A reader of type information from Windows Metadata
5 pub struct TypeReader {
6 types: BTreeMap<&'static str, BTreeMap<&'static str, TypeRow>>,
7 nested: BTreeMap<tables::TypeDef, BTreeMap<&'static str, tables::TypeDef>>,
8 }
9
10 #[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Debug)]
11 enum TypeRow {
12 TypeDef(tables::TypeDef),
13 Function(tables::MethodDef),
14 Constant(tables::Field),
15 }
16
17 impl TypeReader {
get() -> &'static Self18 pub fn get() -> &'static Self {
19 use std::{mem::MaybeUninit, sync::Once};
20 static ONCE: Once = Once::new();
21 static mut VALUE: MaybeUninit<TypeReader> = MaybeUninit::uninit();
22
23 ONCE.call_once(|| {
24 // This is safe because `Once` provides thread-safe one-time initialization
25 unsafe { VALUE = MaybeUninit::new(Self::new()) }
26 });
27
28 // This is safe because `call_once` has already been called.
29 unsafe { &*VALUE.as_ptr() }
30 }
31
32 /// Insert WinRT metadata at the given paths
33 ///
34 /// # Panics
35 ///
36 /// This function panics if the if the files where the windows metadata are stored cannot be read.
new() -> Self37 fn new() -> Self {
38 let files = workspace_winmds();
39
40 let mut types = BTreeMap::<&'static str, BTreeMap<&'static str, TypeRow>>::default();
41
42 let mut nested =
43 BTreeMap::<tables::TypeDef, BTreeMap<&'static str, tables::TypeDef>>::new();
44
45 for file in files {
46 let row_count = file.type_def_table().row_count;
47
48 for row in 0..row_count {
49 let def = tables::TypeDef(Row::new(row, TableIndex::TypeDef, file));
50 let namespace = def.namespace();
51 let name = trim_tick(def.name());
52
53 if namespace.is_empty() {
54 continue;
55 }
56
57 let flags = def.flags();
58 let extends = def.extends();
59
60 if extends == ("System", "Attribute") {
61 continue;
62 }
63
64 types
65 .entry(namespace)
66 .or_default()
67 .entry(name)
68 .or_insert_with(|| TypeRow::TypeDef(def));
69
70 if flags.interface() || flags.windows_runtime() {
71 continue;
72 }
73
74 if extends == ("", "") {
75 continue;
76 }
77
78 if extends != ("System", "Object") {
79 continue;
80 }
81
82 for field in def.fields() {
83 let name = field.name();
84
85 types
86 .entry(namespace)
87 .or_default()
88 .entry(name)
89 .or_insert_with(|| TypeRow::Constant(field));
90 }
91
92 for method in def.methods() {
93 let name = method.name();
94
95 types
96 .entry(namespace)
97 .or_default()
98 .entry(name)
99 .or_insert_with(|| TypeRow::Function(method));
100 }
101 }
102
103 let row_count = file.nested_class_table().row_count;
104
105 for row in 0..row_count {
106 let row = tables::NestedClass(Row::new(row, TableIndex::NestedClass, file));
107 let enclosed = row.nested_type();
108 let enclosing = row.enclosing_type();
109 let name = enclosed.name();
110
111 nested.entry(enclosing).or_default().insert(name, enclosed);
112 }
113 }
114
115 let exclude = &[
116 ("Windows.Foundation", "HResult"),
117 ("Windows.Win32.Com", "HRESULT"),
118 ("Windows.Win32.Com", "IUnknown"),
119 ("Windows.Win32.WinRT", "HSTRING"),
120 ("Windows.Win32.WinRT", "IActivationFactory"),
121 ("Windows.Win32.Direct2D", "D2D_MATRIX_3X2_F"),
122 ("Windows.Win32.SystemServices", "LARGE_INTEGER"),
123 ("Windows.Win32.SystemServices", "ULARGE_INTEGER"),
124 ];
125
126 for (namespace, name) in exclude {
127 if let Some(value) = types.get_mut(*namespace) {
128 value.remove(*name);
129 }
130 }
131
132 Self { types, nested }
133 }
134
resolve_namespace(&'static self, find: &str) -> &'static str135 pub fn resolve_namespace(&'static self, find: &str) -> &'static str {
136 self.types
137 .keys()
138 .find(|namespace| *namespace == &find)
139 .expect(&format!("Could not find namespace `{}`", find))
140 }
141
142 /// Get all the namespace names that the [`TypeReader`] knows about
namespaces(&'static self) -> impl Iterator<Item = &'static str>143 pub fn namespaces(&'static self) -> impl Iterator<Item = &'static str> {
144 self.types.keys().copied()
145 }
146
147 /// Get all types for a given namespace
148 ///
149 /// # Panics
150 ///
151 /// Panics if the namespace does not exist
namespace_types( &'static self, namespace: &str, ) -> impl Iterator<Item = ElementType> + '_152 pub fn namespace_types(
153 &'static self,
154 namespace: &str,
155 ) -> impl Iterator<Item = ElementType> + '_ {
156 self.types[namespace]
157 .values()
158 .map(move |row| self.to_element_type(row))
159 }
160
161 // TODO: how to make this return an iterator?
nested_types(&'static self, enclosing: &tables::TypeDef) -> Vec<tables::TypeDef>162 pub fn nested_types(&'static self, enclosing: &tables::TypeDef) -> Vec<tables::TypeDef> {
163 self.nested
164 .get(enclosing)
165 .iter()
166 .flat_map(|t| t.values())
167 .copied()
168 .collect()
169 }
170
resolve_type(&'static self, namespace: &str, name: &str) -> ElementType171 pub fn resolve_type(&'static self, namespace: &str, name: &str) -> ElementType {
172 if let Some(types) = self.types.get(namespace) {
173 if let Some(row) = types.get(trim_tick(name)) {
174 return self.to_element_type(row);
175 }
176 }
177
178 panic!("Could not find type `{}.{}`", namespace, name);
179 }
180
to_element_type(&'static self, row: &TypeRow) -> ElementType181 fn to_element_type(&'static self, row: &TypeRow) -> ElementType {
182 match row {
183 TypeRow::TypeDef(row) => ElementType::from_type_def(*row, Vec::new()),
184 TypeRow::Function(row) => ElementType::Function(types::Function(*row)),
185 TypeRow::Constant(row) => ElementType::Constant(types::Constant(*row)),
186 }
187 }
188
resolve_type_def(&'static self, namespace: &str, name: &str) -> tables::TypeDef189 pub fn resolve_type_def(&'static self, namespace: &str, name: &str) -> tables::TypeDef {
190 if let Some(types) = self.types.get(namespace) {
191 if let Some(TypeRow::TypeDef(row)) = types.get(trim_tick(name)) {
192 return *row;
193 }
194 }
195
196 panic!("Could not find type def `{}.{}`", namespace, name);
197 }
198
resolve_type_ref(&'static self, type_ref: &tables::TypeRef) -> tables::TypeDef199 pub fn resolve_type_ref(&'static self, type_ref: &tables::TypeRef) -> tables::TypeDef {
200 if let ResolutionScope::TypeRef(scope) = type_ref.scope() {
201 self.nested[&scope.resolve()][type_ref.name()]
202 } else {
203 self.resolve_type_def(type_ref.namespace(), type_ref.name())
204 }
205 }
206
207 #[cfg(test)]
get_class(namespace: &str, name: &str) -> types::Class208 pub fn get_class(namespace: &str, name: &str) -> types::Class {
209 if let ElementType::Class(value) = Self::get().resolve_type(namespace, name) {
210 value.clone()
211 } else {
212 unexpected!();
213 }
214 }
215
216 #[cfg(test)]
get_struct(namespace: &str, name: &str) -> types::Struct217 pub fn get_struct(namespace: &str, name: &str) -> types::Struct {
218 if let ElementType::Struct(value) = Self::get().resolve_type(namespace, name) {
219 value.clone()
220 } else {
221 unexpected!();
222 }
223 }
224
225 #[cfg(test)]
get_enum(namespace: &str, name: &str) -> types::Enum226 pub fn get_enum(namespace: &str, name: &str) -> types::Enum {
227 if let ElementType::Enum(value) = Self::get().resolve_type(namespace, name) {
228 value.clone()
229 } else {
230 unexpected!();
231 }
232 }
233
234 #[cfg(test)]
get_interface(namespace: &str, name: &str) -> types::Interface235 pub fn get_interface(namespace: &str, name: &str) -> types::Interface {
236 if let ElementType::Interface(value) = Self::get().resolve_type(namespace, name) {
237 value.clone()
238 } else {
239 unexpected!();
240 }
241 }
242 }
243
/// Strips a trailing generic arity suffix from a metadata type name.
///
/// If the second-to-last byte is a backtick, the last two bytes are removed
/// (``IVector`1`` becomes `IVector`); otherwise `name` is returned unchanged. Only a
/// single character after the backtick is handled.
///
/// Names shorter than two bytes (including the empty string) are returned as-is. The
/// previous `name.len() - 2` underflowed for such names and panicked in debug builds.
fn trim_tick(name: &str) -> &str {
    match name.len().checked_sub(2) {
        // The backtick is ASCII, so `tick` is always a valid char boundary for slicing.
        Some(tick) if name.as_bytes()[tick] == b'`' => &name[..tick],
        _ => name,
    }
}
250