1 //! Database used for testing `hir_def`.
2 
3 use std::{
4     fmt, panic,
5     sync::{Arc, Mutex},
6 };
7 
8 use base_db::{
9     salsa, AnchoredPath, CrateId, FileId, FileLoader, FileLoaderDelegate, FilePosition,
10     SourceDatabase, Upcast,
11 };
12 use hir_expand::{db::AstDatabase, InFile};
13 use rustc_hash::FxHashSet;
14 use syntax::{algo, ast, AstNode};
15 
16 use crate::{
17     db::DefDatabase,
18     nameres::{DefMap, ModuleSource},
19     src::HasSource,
20     LocalModuleId, Lookup, ModuleDefId, ModuleId,
21 };
22 
/// Salsa database used by the tests of this crate.
///
/// Wires together the query-group storages listed in the attribute below and
/// can optionally record every `salsa::Event` it receives (see `events`).
#[salsa::database(
    base_db::SourceDatabaseExtStorage,
    base_db::SourceDatabaseStorage,
    hir_expand::db::AstDatabaseStorage,
    crate::db::InternDatabaseStorage,
    crate::db::DefDatabaseStorage
)]
pub(crate) struct TestDB {
    /// Backing salsa storage for all query groups of this database.
    storage: salsa::Storage<TestDB>,
    /// Event log: `None` while recording is off, `Some(events)` while it is on.
    /// `salsa_event` appends to it; `log` switches it on and drains it.
    events: Mutex<Option<Vec<salsa::Event>>>,
}
34 
35 impl Default for TestDB {
default() -> Self36     fn default() -> Self {
37         let mut this = Self { storage: Default::default(), events: Default::default() };
38         this.set_enable_proc_attr_macros(true);
39         this
40     }
41 }
42 
43 impl Upcast<dyn AstDatabase> for TestDB {
upcast(&self) -> &(dyn AstDatabase + 'static)44     fn upcast(&self) -> &(dyn AstDatabase + 'static) {
45         &*self
46     }
47 }
48 
49 impl Upcast<dyn DefDatabase> for TestDB {
upcast(&self) -> &(dyn DefDatabase + 'static)50     fn upcast(&self) -> &(dyn DefDatabase + 'static) {
51         &*self
52     }
53 }
54 
55 impl salsa::Database for TestDB {
salsa_event(&self, event: salsa::Event)56     fn salsa_event(&self, event: salsa::Event) {
57         let mut events = self.events.lock().unwrap();
58         if let Some(events) = &mut *events {
59             events.push(event);
60         }
61     }
62 }
63 
64 impl fmt::Debug for TestDB {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result65     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
66         f.debug_struct("TestDB").finish()
67     }
68 }
69 
// Marker impl: declares `TestDB` to be `RefUnwindSafe`.
// NOTE(review): presumably needed so the database can be used across
// `catch_unwind` boundaries in tests — confirm against callers.
impl panic::RefUnwindSafe for TestDB {}
71 
72 impl FileLoader for TestDB {
file_text(&self, file_id: FileId) -> Arc<String>73     fn file_text(&self, file_id: FileId) -> Arc<String> {
74         FileLoaderDelegate(self).file_text(file_id)
75     }
resolve_path(&self, path: AnchoredPath) -> Option<FileId>76     fn resolve_path(&self, path: AnchoredPath) -> Option<FileId> {
77         FileLoaderDelegate(self).resolve_path(path)
78     }
relevant_crates(&self, file_id: FileId) -> Arc<FxHashSet<CrateId>>79     fn relevant_crates(&self, file_id: FileId) -> Arc<FxHashSet<CrateId>> {
80         FileLoaderDelegate(self).relevant_crates(file_id)
81     }
82 }
83 
84 impl TestDB {
module_for_file(&self, file_id: FileId) -> ModuleId85     pub(crate) fn module_for_file(&self, file_id: FileId) -> ModuleId {
86         for &krate in self.relevant_crates(file_id).iter() {
87             let crate_def_map = self.crate_def_map(krate);
88             for (local_id, data) in crate_def_map.modules() {
89                 if data.origin.file_id() == Some(file_id) {
90                     return crate_def_map.module_id(local_id);
91                 }
92             }
93         }
94         panic!("Can't find module for file")
95     }
96 
module_at_position(&self, position: FilePosition) -> ModuleId97     pub(crate) fn module_at_position(&self, position: FilePosition) -> ModuleId {
98         let file_module = self.module_for_file(position.file_id);
99         let mut def_map = file_module.def_map(self);
100         let module = self.mod_at_position(&def_map, position);
101 
102         def_map = match self.block_at_position(&def_map, position) {
103             Some(it) => it,
104             None => return def_map.module_id(module),
105         };
106         loop {
107             let new_map = self.block_at_position(&def_map, position);
108             match new_map {
109                 Some(new_block) if !Arc::ptr_eq(&new_block, &def_map) => {
110                     def_map = new_block;
111                 }
112                 _ => {
113                     // FIXME: handle `mod` inside block expression
114                     return def_map.module_id(def_map.root());
115                 }
116             }
117         }
118     }
119 
120     /// Finds the smallest/innermost module in `def_map` containing `position`.
mod_at_position(&self, def_map: &DefMap, position: FilePosition) -> LocalModuleId121     fn mod_at_position(&self, def_map: &DefMap, position: FilePosition) -> LocalModuleId {
122         let mut size = None;
123         let mut res = def_map.root();
124         for (module, data) in def_map.modules() {
125             let src = data.definition_source(self);
126             if src.file_id != position.file_id.into() {
127                 continue;
128             }
129 
130             let range = match src.value {
131                 ModuleSource::SourceFile(it) => it.syntax().text_range(),
132                 ModuleSource::Module(it) => it.syntax().text_range(),
133                 ModuleSource::BlockExpr(it) => it.syntax().text_range(),
134             };
135 
136             if !range.contains(position.offset) {
137                 continue;
138             }
139 
140             let new_size = match size {
141                 None => range.len(),
142                 Some(size) => {
143                     if range.len() < size {
144                         range.len()
145                     } else {
146                         size
147                     }
148                 }
149             };
150 
151             if size != Some(new_size) {
152                 cov_mark::hit!(submodule_in_testdb);
153                 size = Some(new_size);
154                 res = module;
155             }
156         }
157 
158         res
159     }
160 
block_at_position(&self, def_map: &DefMap, position: FilePosition) -> Option<Arc<DefMap>>161     fn block_at_position(&self, def_map: &DefMap, position: FilePosition) -> Option<Arc<DefMap>> {
162         // Find the smallest (innermost) function in `def_map` containing the cursor.
163         let mut size = None;
164         let mut fn_def = None;
165         for (_, module) in def_map.modules() {
166             let file_id = module.definition_source(self).file_id;
167             if file_id != position.file_id.into() {
168                 continue;
169             }
170             for decl in module.scope.declarations() {
171                 if let ModuleDefId::FunctionId(it) = decl {
172                     let range = it.lookup(self).source(self).value.syntax().text_range();
173 
174                     if !range.contains(position.offset) {
175                         continue;
176                     }
177 
178                     let new_size = match size {
179                         None => range.len(),
180                         Some(size) => {
181                             if range.len() < size {
182                                 range.len()
183                             } else {
184                                 size
185                             }
186                         }
187                     };
188                     if size != Some(new_size) {
189                         size = Some(new_size);
190                         fn_def = Some(it);
191                     }
192                 }
193             }
194         }
195 
196         // Find the innermost block expression that has a `DefMap`.
197         let def_with_body = fn_def?.into();
198         let (_, source_map) = self.body_with_source_map(def_with_body);
199         let scopes = self.expr_scopes(def_with_body);
200         let root = self.parse(position.file_id);
201 
202         let scope_iter = algo::ancestors_at_offset(&root.syntax_node(), position.offset)
203             .filter_map(|node| {
204                 let block = ast::BlockExpr::cast(node)?;
205                 let expr = ast::Expr::from(block);
206                 let expr_id = source_map.node_expr(InFile::new(position.file_id.into(), &expr))?;
207                 let scope = scopes.scope_for(expr_id).unwrap();
208                 Some(scope)
209             });
210 
211         for scope in scope_iter {
212             let containing_blocks =
213                 scopes.scope_chain(Some(scope)).filter_map(|scope| scopes.block(scope));
214 
215             for block in containing_blocks {
216                 if let Some(def_map) = self.block_def_map(block) {
217                     return Some(def_map);
218                 }
219             }
220         }
221 
222         None
223     }
224 
log(&self, f: impl FnOnce()) -> Vec<salsa::Event>225     pub(crate) fn log(&self, f: impl FnOnce()) -> Vec<salsa::Event> {
226         *self.events.lock().unwrap() = Some(Vec::new());
227         f();
228         self.events.lock().unwrap().take().unwrap()
229     }
230 
log_executed(&self, f: impl FnOnce()) -> Vec<String>231     pub(crate) fn log_executed(&self, f: impl FnOnce()) -> Vec<String> {
232         let events = self.log(f);
233         events
234             .into_iter()
235             .filter_map(|e| match e.kind {
236                 // This is pretty horrible, but `Debug` is the only way to inspect
237                 // QueryDescriptor at the moment.
238                 salsa::EventKind::WillExecute { database_key } => {
239                     Some(format!("{:?}", database_key.debug(self)))
240                 }
241                 _ => None,
242             })
243             .collect()
244     }
245 }
246