//! Helper functions for working with defs that don't need to be a separate
//! query, but can't be computed directly from `*Data` (i.e., they need a `db`).
3 
4 use std::iter;
5 
6 use base_db::CrateId;
7 use chalk_ir::{fold::Shift, BoundVar, DebruijnIndex};
8 use hir_def::{
9     db::DefDatabase,
10     generics::{
11         GenericParams, TypeParamData, TypeParamProvenance, WherePredicate, WherePredicateTypeTarget,
12     },
13     intern::Interned,
14     path::Path,
15     resolver::{HasResolver, TypeNs},
16     type_ref::{TraitBoundModifier, TypeRef},
17     AssocContainerId, GenericDefId, Lookup, TraitId, TypeAliasId, TypeParamId,
18 };
19 use hir_expand::name::{name, Name};
20 use rustc_hash::FxHashSet;
21 
22 use crate::{
23     db::HirDatabase, ChalkTraitId, Interner, Substitution, TraitRef, TraitRefExt, TyKind,
24     WhereClause,
25 };
26 
fn_traits(db: &dyn DefDatabase, krate: CrateId) -> impl Iterator<Item = TraitId>27 pub(crate) fn fn_traits(db: &dyn DefDatabase, krate: CrateId) -> impl Iterator<Item = TraitId> {
28     [
29         db.lang_item(krate, "fn".into()),
30         db.lang_item(krate, "fn_mut".into()),
31         db.lang_item(krate, "fn_once".into()),
32     ]
33     .into_iter()
34     .flatten()
35     .flat_map(|it| it.as_trait())
36 }
37 
direct_super_traits(db: &dyn DefDatabase, trait_: TraitId) -> Vec<TraitId>38 fn direct_super_traits(db: &dyn DefDatabase, trait_: TraitId) -> Vec<TraitId> {
39     let resolver = trait_.resolver(db);
40     // returning the iterator directly doesn't easily work because of
41     // lifetime problems, but since there usually shouldn't be more than a
42     // few direct traits this should be fine (we could even use some kind of
43     // SmallVec if performance is a concern)
44     let generic_params = db.generic_params(trait_.into());
45     let trait_self = generic_params.find_trait_self_param();
46     generic_params
47         .where_predicates
48         .iter()
49         .filter_map(|pred| match pred {
50             WherePredicate::ForLifetime { target, bound, .. }
51             | WherePredicate::TypeBound { target, bound } => match target {
52                 WherePredicateTypeTarget::TypeRef(type_ref) => match &**type_ref {
53                     TypeRef::Path(p) if p == &Path::from(name![Self]) => bound.as_path(),
54                     _ => None,
55                 },
56                 WherePredicateTypeTarget::TypeParam(local_id) if Some(*local_id) == trait_self => {
57                     bound.as_path()
58                 }
59                 _ => None,
60             },
61             WherePredicate::Lifetime { .. } => None,
62         })
63         .filter_map(|(path, bound_modifier)| match bound_modifier {
64             TraitBoundModifier::None => Some(path),
65             TraitBoundModifier::Maybe => None,
66         })
67         .filter_map(|path| match resolver.resolve_path_in_type_ns_fully(db, path.mod_path()) {
68             Some(TypeNs::TraitId(t)) => Some(t),
69             _ => None,
70         })
71         .collect()
72 }
73 
direct_super_trait_refs(db: &dyn HirDatabase, trait_ref: &TraitRef) -> Vec<TraitRef>74 fn direct_super_trait_refs(db: &dyn HirDatabase, trait_ref: &TraitRef) -> Vec<TraitRef> {
75     // returning the iterator directly doesn't easily work because of
76     // lifetime problems, but since there usually shouldn't be more than a
77     // few direct traits this should be fine (we could even use some kind of
78     // SmallVec if performance is a concern)
79     let generic_params = db.generic_params(trait_ref.hir_trait_id().into());
80     let trait_self = match generic_params.find_trait_self_param() {
81         Some(p) => TypeParamId { parent: trait_ref.hir_trait_id().into(), local_id: p },
82         None => return Vec::new(),
83     };
84     db.generic_predicates_for_param(trait_self, None)
85         .iter()
86         .filter_map(|pred| {
87             pred.as_ref().filter_map(|pred| match pred.skip_binders() {
88                 // FIXME: how to correctly handle higher-ranked bounds here?
89                 WhereClause::Implemented(tr) => Some(
90                     tr.clone()
91                         .shifted_out_to(&Interner, DebruijnIndex::ONE)
92                         .expect("FIXME unexpected higher-ranked trait bound"),
93                 ),
94                 _ => None,
95             })
96         })
97         .map(|pred| pred.substitute(&Interner, &trait_ref.substitution))
98         .collect()
99 }
100 
101 /// Returns an iterator over the whole super trait hierarchy (including the
102 /// trait itself).
all_super_traits(db: &dyn DefDatabase, trait_: TraitId) -> Vec<TraitId>103 pub fn all_super_traits(db: &dyn DefDatabase, trait_: TraitId) -> Vec<TraitId> {
104     // we need to take care a bit here to avoid infinite loops in case of cycles
105     // (i.e. if we have `trait A: B; trait B: A;`)
106     let mut result = vec![trait_];
107     let mut i = 0;
108     while i < result.len() {
109         let t = result[i];
110         // yeah this is quadratic, but trait hierarchies should be flat
111         // enough that this doesn't matter
112         for tt in direct_super_traits(db, t) {
113             if !result.contains(&tt) {
114                 result.push(tt);
115             }
116         }
117         i += 1;
118     }
119     result
120 }
121 
122 /// Given a trait ref (`Self: Trait`), builds all the implied trait refs for
123 /// super traits. The original trait ref will be included. So the difference to
124 /// `all_super_traits` is that we keep track of type parameters; for example if
125 /// we have `Self: Trait<u32, i32>` and `Trait<T, U>: OtherTrait<U>` we'll get
126 /// `Self: OtherTrait<i32>`.
all_super_trait_refs(db: &dyn HirDatabase, trait_ref: TraitRef) -> SuperTraits127 pub(super) fn all_super_trait_refs(db: &dyn HirDatabase, trait_ref: TraitRef) -> SuperTraits {
128     SuperTraits { db, seen: iter::once(trait_ref.trait_id).collect(), stack: vec![trait_ref] }
129 }
130 
131 pub(super) struct SuperTraits<'a> {
132     db: &'a dyn HirDatabase,
133     stack: Vec<TraitRef>,
134     seen: FxHashSet<ChalkTraitId>,
135 }
136 
137 impl<'a> SuperTraits<'a> {
elaborate(&mut self, trait_ref: &TraitRef)138     fn elaborate(&mut self, trait_ref: &TraitRef) {
139         let mut trait_refs = direct_super_trait_refs(self.db, trait_ref);
140         trait_refs.retain(|tr| !self.seen.contains(&tr.trait_id));
141         self.stack.extend(trait_refs);
142     }
143 }
144 
145 impl<'a> Iterator for SuperTraits<'a> {
146     type Item = TraitRef;
147 
next(&mut self) -> Option<Self::Item>148     fn next(&mut self) -> Option<Self::Item> {
149         if let Some(next) = self.stack.pop() {
150             self.elaborate(&next);
151             Some(next)
152         } else {
153             None
154         }
155     }
156 }
157 
associated_type_by_name_including_super_traits( db: &dyn HirDatabase, trait_ref: TraitRef, name: &Name, ) -> Option<(TraitRef, TypeAliasId)>158 pub(super) fn associated_type_by_name_including_super_traits(
159     db: &dyn HirDatabase,
160     trait_ref: TraitRef,
161     name: &Name,
162 ) -> Option<(TraitRef, TypeAliasId)> {
163     all_super_trait_refs(db, trait_ref).find_map(|t| {
164         let assoc_type = db.trait_data(t.hir_trait_id()).associated_type_by_name(name)?;
165         Some((t, assoc_type))
166     })
167 }
168 
generics(db: &dyn DefDatabase, def: GenericDefId) -> Generics169 pub(crate) fn generics(db: &dyn DefDatabase, def: GenericDefId) -> Generics {
170     let parent_generics = parent_generic_def(db, def).map(|def| Box::new(generics(db, def)));
171     Generics { def, params: db.generic_params(def), parent_generics }
172 }
173 
/// The generic parameters of a definition, together with the generics of its
/// parent definition (as computed by `parent_generic_def`), if there is one.
#[derive(Debug)]
pub(crate) struct Generics {
    /// The definition that owns `params`.
    def: GenericDefId,
    /// The definition's own generic parameters, excluding the parent's.
    pub(crate) params: Interned<GenericParams>,
    /// Generics of the parent definition, or `None` if `def` has no parent.
    parent_generics: Option<Box<Generics>>,
}
180 
181 impl Generics {
iter<'a>( &'a self, ) -> impl Iterator<Item = (TypeParamId, &'a TypeParamData)> + 'a182     pub(crate) fn iter<'a>(
183         &'a self,
184     ) -> impl Iterator<Item = (TypeParamId, &'a TypeParamData)> + 'a {
185         self.parent_generics
186             .as_ref()
187             .into_iter()
188             .flat_map(|it| {
189                 it.params
190                     .types
191                     .iter()
192                     .map(move |(local_id, p)| (TypeParamId { parent: it.def, local_id }, p))
193             })
194             .chain(
195                 self.params
196                     .types
197                     .iter()
198                     .map(move |(local_id, p)| (TypeParamId { parent: self.def, local_id }, p)),
199             )
200     }
201 
iter_parent<'a>( &'a self, ) -> impl Iterator<Item = (TypeParamId, &'a TypeParamData)> + 'a202     pub(crate) fn iter_parent<'a>(
203         &'a self,
204     ) -> impl Iterator<Item = (TypeParamId, &'a TypeParamData)> + 'a {
205         self.parent_generics.as_ref().into_iter().flat_map(|it| {
206             it.params
207                 .types
208                 .iter()
209                 .map(move |(local_id, p)| (TypeParamId { parent: it.def, local_id }, p))
210         })
211     }
212 
len(&self) -> usize213     pub(crate) fn len(&self) -> usize {
214         self.len_split().0
215     }
216 
217     /// (total, parents, child)
len_split(&self) -> (usize, usize, usize)218     pub(crate) fn len_split(&self) -> (usize, usize, usize) {
219         let parent = self.parent_generics.as_ref().map_or(0, |p| p.len());
220         let child = self.params.types.len();
221         (parent + child, parent, child)
222     }
223 
224     /// (parent total, self param, type param list, impl trait)
provenance_split(&self) -> (usize, usize, usize, usize)225     pub(crate) fn provenance_split(&self) -> (usize, usize, usize, usize) {
226         let parent = self.parent_generics.as_ref().map_or(0, |p| p.len());
227         let self_params = self
228             .params
229             .types
230             .iter()
231             .filter(|(_, p)| p.provenance == TypeParamProvenance::TraitSelf)
232             .count();
233         let list_params = self
234             .params
235             .types
236             .iter()
237             .filter(|(_, p)| p.provenance == TypeParamProvenance::TypeParamList)
238             .count();
239         let impl_trait_params = self
240             .params
241             .types
242             .iter()
243             .filter(|(_, p)| p.provenance == TypeParamProvenance::ArgumentImplTrait)
244             .count();
245         (parent, self_params, list_params, impl_trait_params)
246     }
247 
param_idx(&self, param: TypeParamId) -> Option<usize>248     pub(crate) fn param_idx(&self, param: TypeParamId) -> Option<usize> {
249         Some(self.find_param(param)?.0)
250     }
251 
find_param(&self, param: TypeParamId) -> Option<(usize, &TypeParamData)>252     fn find_param(&self, param: TypeParamId) -> Option<(usize, &TypeParamData)> {
253         if param.parent == self.def {
254             let (idx, (_local_id, data)) = self
255                 .params
256                 .types
257                 .iter()
258                 .enumerate()
259                 .find(|(_, (idx, _))| *idx == param.local_id)
260                 .unwrap();
261             let (_total, parent_len, _child) = self.len_split();
262             Some((parent_len + idx, data))
263         } else {
264             self.parent_generics.as_ref().and_then(|g| g.find_param(param))
265         }
266     }
267 
268     /// Returns a Substitution that replaces each parameter by a bound variable.
bound_vars_subst(&self, debruijn: DebruijnIndex) -> Substitution269     pub(crate) fn bound_vars_subst(&self, debruijn: DebruijnIndex) -> Substitution {
270         Substitution::from_iter(
271             &Interner,
272             self.iter()
273                 .enumerate()
274                 .map(|(idx, _)| TyKind::BoundVar(BoundVar::new(debruijn, idx)).intern(&Interner)),
275         )
276     }
277 
278     /// Returns a Substitution that replaces each parameter by itself (i.e. `Ty::Param`).
type_params_subst(&self, db: &dyn HirDatabase) -> Substitution279     pub(crate) fn type_params_subst(&self, db: &dyn HirDatabase) -> Substitution {
280         Substitution::from_iter(
281             &Interner,
282             self.iter().map(|(id, _)| {
283                 TyKind::Placeholder(crate::to_placeholder_idx(db, id)).intern(&Interner)
284             }),
285         )
286     }
287 }
288 
parent_generic_def(db: &dyn DefDatabase, def: GenericDefId) -> Option<GenericDefId>289 fn parent_generic_def(db: &dyn DefDatabase, def: GenericDefId) -> Option<GenericDefId> {
290     let container = match def {
291         GenericDefId::FunctionId(it) => it.lookup(db).container,
292         GenericDefId::TypeAliasId(it) => it.lookup(db).container,
293         GenericDefId::ConstId(it) => it.lookup(db).container,
294         GenericDefId::EnumVariantId(it) => return Some(it.parent.into()),
295         GenericDefId::AdtId(_) | GenericDefId::TraitId(_) | GenericDefId::ImplId(_) => return None,
296     };
297 
298     match container {
299         AssocContainerId::ImplId(it) => Some(it.into()),
300         AssocContainerId::TraitId(it) => Some(it.into()),
301         AssocContainerId::ModuleId(_) => None,
302     }
303 }
304