1 //! Various diagnostics for expressions that are collected together in one pass
2 //! through the body using inference results: mismatched arg counts, missing
3 //! fields, etc.
4 
5 use std::{cell::RefCell, sync::Arc};
6 
7 use hir_def::{
8     expr::Statement, path::path, resolver::HasResolver, type_ref::Mutability, AssocItemId,
9     DefWithBodyId, HasModule,
10 };
11 use hir_expand::name;
12 use itertools::Either;
13 use rustc_hash::FxHashSet;
14 
15 use crate::{
16     db::HirDatabase,
17     diagnostics::match_check::{
18         self,
19         usefulness::{compute_match_usefulness, expand_pattern, MatchCheckCtx, PatternArena},
20     },
21     AdtId, InferenceResult, Interner, Ty, TyExt, TyKind,
22 };
23 
24 pub(crate) use hir_def::{
25     body::Body,
26     expr::{Expr, ExprId, MatchArm, Pat, PatId},
27     LocalFieldId, VariantId,
28 };
29 
30 pub enum BodyValidationDiagnostic {
31     RecordMissingFields {
32         record: Either<ExprId, PatId>,
33         variant: VariantId,
34         missed_fields: Vec<LocalFieldId>,
35     },
36     ReplaceFilterMapNextWithFindMap {
37         method_call_expr: ExprId,
38     },
39     MismatchedArgCount {
40         call_expr: ExprId,
41         expected: usize,
42         found: usize,
43     },
44     RemoveThisSemicolon {
45         expr: ExprId,
46     },
47     MissingOkOrSomeInTailExpr {
48         expr: ExprId,
49         required: String,
50     },
51     MissingMatchArms {
52         match_expr: ExprId,
53     },
54     AddReferenceHere {
55         arg_expr: ExprId,
56         mutability: Mutability,
57     },
58 }
59 
60 impl BodyValidationDiagnostic {
collect(db: &dyn HirDatabase, owner: DefWithBodyId) -> Vec<BodyValidationDiagnostic>61     pub fn collect(db: &dyn HirDatabase, owner: DefWithBodyId) -> Vec<BodyValidationDiagnostic> {
62         let _p = profile::span("BodyValidationDiagnostic::collect");
63         let infer = db.infer(owner);
64         let mut validator = ExprValidator::new(owner, infer);
65         validator.validate_body(db);
66         validator.diagnostics
67     }
68 }
69 
70 struct ExprValidator {
71     owner: DefWithBodyId,
72     infer: Arc<InferenceResult>,
73     pub(super) diagnostics: Vec<BodyValidationDiagnostic>,
74 }
75 
76 impl ExprValidator {
new(owner: DefWithBodyId, infer: Arc<InferenceResult>) -> ExprValidator77     fn new(owner: DefWithBodyId, infer: Arc<InferenceResult>) -> ExprValidator {
78         ExprValidator { owner, infer, diagnostics: Vec::new() }
79     }
80 
validate_body(&mut self, db: &dyn HirDatabase)81     fn validate_body(&mut self, db: &dyn HirDatabase) {
82         self.check_for_filter_map_next(db);
83 
84         let body = db.body(self.owner);
85 
86         for (id, expr) in body.exprs.iter() {
87             if let Some((variant, missed_fields, true)) =
88                 record_literal_missing_fields(db, &self.infer, id, expr)
89             {
90                 self.diagnostics.push(BodyValidationDiagnostic::RecordMissingFields {
91                     record: Either::Left(id),
92                     variant,
93                     missed_fields,
94                 });
95             }
96 
97             match expr {
98                 Expr::Match { expr, arms } => {
99                     self.validate_match(id, *expr, arms, db, self.infer.clone());
100                 }
101                 Expr::Call { .. } | Expr::MethodCall { .. } => {
102                     self.validate_call(db, id, expr);
103                 }
104                 _ => {}
105             }
106         }
107         for (id, pat) in body.pats.iter() {
108             if let Some((variant, missed_fields, true)) =
109                 record_pattern_missing_fields(db, &self.infer, id, pat)
110             {
111                 self.diagnostics.push(BodyValidationDiagnostic::RecordMissingFields {
112                     record: Either::Right(id),
113                     variant,
114                     missed_fields,
115                 });
116             }
117         }
118         let body_expr = &body[body.body_expr];
119         if let Expr::Block { statements, tail, .. } = body_expr {
120             if let Some(t) = tail {
121                 self.validate_results_in_tail_expr(body.body_expr, *t, db);
122             } else if let Some(Statement::Expr { expr: id, .. }) = statements.last() {
123                 self.validate_missing_tail_expr(body.body_expr, *id);
124             }
125         }
126 
127         let infer = &self.infer;
128         let diagnostics = &mut self.diagnostics;
129 
130         infer
131             .expr_type_mismatches()
132             .filter_map(|(expr, mismatch)| {
133                 let (expr_without_ref, mutability) =
134                     check_missing_refs(infer, expr, &mismatch.expected)?;
135 
136                 Some((expr_without_ref, mutability))
137             })
138             .for_each(|(arg_expr, mutability)| {
139                 diagnostics
140                     .push(BodyValidationDiagnostic::AddReferenceHere { arg_expr, mutability });
141             });
142     }
143 
check_for_filter_map_next(&mut self, db: &dyn HirDatabase)144     fn check_for_filter_map_next(&mut self, db: &dyn HirDatabase) {
145         // Find the FunctionIds for Iterator::filter_map and Iterator::next
146         let iterator_path = path![core::iter::Iterator];
147         let resolver = self.owner.resolver(db.upcast());
148         let iterator_trait_id = match resolver.resolve_known_trait(db.upcast(), &iterator_path) {
149             Some(id) => id,
150             None => return,
151         };
152         let iterator_trait_items = &db.trait_data(iterator_trait_id).items;
153         let filter_map_function_id =
154             match iterator_trait_items.iter().find(|item| item.0 == name![filter_map]) {
155                 Some((_, AssocItemId::FunctionId(id))) => id,
156                 _ => return,
157             };
158         let next_function_id = match iterator_trait_items.iter().find(|item| item.0 == name![next])
159         {
160             Some((_, AssocItemId::FunctionId(id))) => id,
161             _ => return,
162         };
163 
164         // Search function body for instances of .filter_map(..).next()
165         let body = db.body(self.owner);
166         let mut prev = None;
167         for (id, expr) in body.exprs.iter() {
168             if let Expr::MethodCall { receiver, .. } = expr {
169                 let function_id = match self.infer.method_resolution(id) {
170                     Some((id, _)) => id,
171                     None => continue,
172                 };
173 
174                 if function_id == *filter_map_function_id {
175                     prev = Some(id);
176                     continue;
177                 }
178 
179                 if function_id == *next_function_id {
180                     if let Some(filter_map_id) = prev {
181                         if *receiver == filter_map_id {
182                             self.diagnostics.push(
183                                 BodyValidationDiagnostic::ReplaceFilterMapNextWithFindMap {
184                                     method_call_expr: id,
185                                 },
186                             );
187                         }
188                     }
189                 }
190             }
191             prev = None;
192         }
193     }
194 
validate_call(&mut self, db: &dyn HirDatabase, call_id: ExprId, expr: &Expr)195     fn validate_call(&mut self, db: &dyn HirDatabase, call_id: ExprId, expr: &Expr) {
196         // Check that the number of arguments matches the number of parameters.
197 
198         // FIXME: Due to shortcomings in the current type system implementation, only emit this
199         // diagnostic if there are no type mismatches in the containing function.
200         if self.infer.expr_type_mismatches().next().is_some() {
201             return;
202         }
203 
204         let is_method_call = matches!(expr, Expr::MethodCall { .. });
205         let (sig, mut arg_count) = match expr {
206             Expr::Call { callee, args } => {
207                 let callee = &self.infer.type_of_expr[*callee];
208                 let sig = match callee.callable_sig(db) {
209                     Some(sig) => sig,
210                     None => return,
211                 };
212                 (sig, args.len())
213             }
214             Expr::MethodCall { receiver, args, .. } => {
215                 let receiver = &self.infer.type_of_expr[*receiver];
216                 if receiver.strip_references().is_unknown() {
217                     // if the receiver is of unknown type, it's very likely we
218                     // don't know enough to correctly resolve the method call.
219                     // This is kind of a band-aid for #6975.
220                     return;
221                 }
222 
223                 let (callee, subst) = match self.infer.method_resolution(call_id) {
224                     Some(it) => it,
225                     None => return,
226                 };
227                 let sig = db.callable_item_signature(callee.into()).substitute(Interner, &subst);
228 
229                 (sig, args.len() + 1)
230             }
231             _ => return,
232         };
233 
234         if sig.is_varargs {
235             return;
236         }
237 
238         if sig.legacy_const_generics_indices.is_empty() {
239             let mut param_count = sig.params().len();
240 
241             if arg_count != param_count {
242                 if is_method_call {
243                     param_count -= 1;
244                     arg_count -= 1;
245                 }
246                 self.diagnostics.push(BodyValidationDiagnostic::MismatchedArgCount {
247                     call_expr: call_id,
248                     expected: param_count,
249                     found: arg_count,
250                 });
251             }
252         } else {
253             // With `#[rustc_legacy_const_generics]` there are basically two parameter counts that
254             // are allowed.
255             let count_non_legacy = sig.params().len();
256             let count_legacy = sig.params().len() + sig.legacy_const_generics_indices.len();
257             if arg_count != count_non_legacy && arg_count != count_legacy {
258                 self.diagnostics.push(BodyValidationDiagnostic::MismatchedArgCount {
259                     call_expr: call_id,
260                     // Since most users will use the legacy way to call them, report against that.
261                     expected: count_legacy,
262                     found: arg_count,
263                 });
264             }
265         }
266     }
267 
validate_match( &mut self, id: ExprId, match_expr: ExprId, arms: &[MatchArm], db: &dyn HirDatabase, infer: Arc<InferenceResult>, )268     fn validate_match(
269         &mut self,
270         id: ExprId,
271         match_expr: ExprId,
272         arms: &[MatchArm],
273         db: &dyn HirDatabase,
274         infer: Arc<InferenceResult>,
275     ) {
276         let body = db.body(self.owner);
277 
278         let match_expr_ty = if infer.type_of_expr[match_expr].is_unknown() {
279             return;
280         } else {
281             &infer.type_of_expr[match_expr]
282         };
283 
284         let pattern_arena = RefCell::new(PatternArena::new());
285 
286         let mut m_arms = Vec::new();
287         let mut has_lowering_errors = false;
288         for arm in arms {
289             if let Some(pat_ty) = infer.type_of_pat.get(arm.pat) {
290                 // We only include patterns whose type matches the type
291                 // of the match expression. If we had an InvalidMatchArmPattern
292                 // diagnostic or similar we could raise that in an else
293                 // block here.
294                 //
295                 // When comparing the types, we also have to consider that rustc
296                 // will automatically de-reference the match expression type if
297                 // necessary.
298                 //
299                 // FIXME we should use the type checker for this.
300                 if (pat_ty == match_expr_ty
301                     || match_expr_ty
302                         .as_reference()
303                         .map(|(match_expr_ty, ..)| match_expr_ty == pat_ty)
304                         .unwrap_or(false))
305                     && types_of_subpatterns_do_match(arm.pat, &body, &infer)
306                 {
307                     // If we had a NotUsefulMatchArm diagnostic, we could
308                     // check the usefulness of each pattern as we added it
309                     // to the matrix here.
310                     let m_arm = match_check::MatchArm {
311                         pat: self.lower_pattern(
312                             arm.pat,
313                             &mut pattern_arena.borrow_mut(),
314                             db,
315                             &body,
316                             &mut has_lowering_errors,
317                         ),
318                         has_guard: arm.guard.is_some(),
319                     };
320                     m_arms.push(m_arm);
321                     if !has_lowering_errors {
322                         continue;
323                     }
324                 }
325             }
326 
327             // If we can't resolve the type of a pattern, or the pattern type doesn't
328             // fit the match expression, we skip this diagnostic. Skipping the entire
329             // diagnostic rather than just not including this match arm is preferred
330             // to avoid the chance of false positives.
331             cov_mark::hit!(validate_match_bailed_out);
332             return;
333         }
334 
335         let cx = MatchCheckCtx {
336             module: self.owner.module(db.upcast()),
337             match_expr,
338             infer: &infer,
339             db,
340             pattern_arena: &pattern_arena,
341         };
342         let report = compute_match_usefulness(&cx, &m_arms);
343 
344         // FIXME Report unreacheble arms
345         // https://github.com/rust-lang/rust/blob/25c15cdbe/compiler/rustc_mir_build/src/thir/pattern/check_match.rs#L200-L201
346 
347         let witnesses = report.non_exhaustiveness_witnesses;
348         // FIXME Report witnesses
349         // eprintln!("compute_match_usefulness(..) -> {:?}", &witnesses);
350         if !witnesses.is_empty() {
351             self.diagnostics.push(BodyValidationDiagnostic::MissingMatchArms { match_expr: id });
352         }
353     }
354 
lower_pattern( &self, pat: PatId, pattern_arena: &mut PatternArena, db: &dyn HirDatabase, body: &Body, have_errors: &mut bool, ) -> match_check::PatId355     fn lower_pattern(
356         &self,
357         pat: PatId,
358         pattern_arena: &mut PatternArena,
359         db: &dyn HirDatabase,
360         body: &Body,
361         have_errors: &mut bool,
362     ) -> match_check::PatId {
363         let mut patcx = match_check::PatCtxt::new(db, &self.infer, body);
364         let pattern = patcx.lower_pattern(pat);
365         let pattern = pattern_arena.alloc(expand_pattern(pattern));
366         if !patcx.errors.is_empty() {
367             *have_errors = true;
368         }
369         pattern
370     }
371 
validate_results_in_tail_expr(&mut self, body_id: ExprId, id: ExprId, db: &dyn HirDatabase)372     fn validate_results_in_tail_expr(&mut self, body_id: ExprId, id: ExprId, db: &dyn HirDatabase) {
373         // the mismatch will be on the whole block currently
374         let mismatch = match self.infer.type_mismatch_for_expr(body_id) {
375             Some(m) => m,
376             None => return,
377         };
378 
379         let core_result_path = path![core::result::Result];
380         let core_option_path = path![core::option::Option];
381 
382         let resolver = self.owner.resolver(db.upcast());
383         let core_result_enum = match resolver.resolve_known_enum(db.upcast(), &core_result_path) {
384             Some(it) => it,
385             _ => return,
386         };
387         let core_option_enum = match resolver.resolve_known_enum(db.upcast(), &core_option_path) {
388             Some(it) => it,
389             _ => return,
390         };
391 
392         let (params, required) = match mismatch.expected.kind(Interner) {
393             TyKind::Adt(AdtId(hir_def::AdtId::EnumId(enum_id)), parameters)
394                 if *enum_id == core_result_enum =>
395             {
396                 (parameters, "Ok".to_string())
397             }
398             TyKind::Adt(AdtId(hir_def::AdtId::EnumId(enum_id)), parameters)
399                 if *enum_id == core_option_enum =>
400             {
401                 (parameters, "Some".to_string())
402             }
403             _ => return,
404         };
405 
406         if params.len(Interner) > 0 && params.at(Interner, 0).ty(Interner) == Some(&mismatch.actual)
407         {
408             self.diagnostics
409                 .push(BodyValidationDiagnostic::MissingOkOrSomeInTailExpr { expr: id, required });
410         }
411     }
412 
validate_missing_tail_expr(&mut self, body_id: ExprId, possible_tail_id: ExprId)413     fn validate_missing_tail_expr(&mut self, body_id: ExprId, possible_tail_id: ExprId) {
414         let mismatch = match self.infer.type_mismatch_for_expr(body_id) {
415             Some(m) => m,
416             None => return,
417         };
418 
419         let possible_tail_ty = match self.infer.type_of_expr.get(possible_tail_id) {
420             Some(ty) => ty,
421             None => return,
422         };
423 
424         if !mismatch.actual.is_unit() || mismatch.expected != *possible_tail_ty {
425             return;
426         }
427 
428         self.diagnostics
429             .push(BodyValidationDiagnostic::RemoveThisSemicolon { expr: possible_tail_id });
430     }
431 }
432 
record_literal_missing_fields( db: &dyn HirDatabase, infer: &InferenceResult, id: ExprId, expr: &Expr, ) -> Option<(VariantId, Vec<LocalFieldId>, bool)>433 pub fn record_literal_missing_fields(
434     db: &dyn HirDatabase,
435     infer: &InferenceResult,
436     id: ExprId,
437     expr: &Expr,
438 ) -> Option<(VariantId, Vec<LocalFieldId>, /*exhaustive*/ bool)> {
439     let (fields, exhaustive) = match expr {
440         Expr::RecordLit { path: _, fields, spread } => (fields, spread.is_none()),
441         _ => return None,
442     };
443 
444     let variant_def = infer.variant_resolution_for_expr(id)?;
445     if let VariantId::UnionId(_) = variant_def {
446         return None;
447     }
448 
449     let variant_data = variant_def.variant_data(db.upcast());
450 
451     let specified_fields: FxHashSet<_> = fields.iter().map(|f| &f.name).collect();
452     let missed_fields: Vec<LocalFieldId> = variant_data
453         .fields()
454         .iter()
455         .filter_map(|(f, d)| if specified_fields.contains(&d.name) { None } else { Some(f) })
456         .collect();
457     if missed_fields.is_empty() {
458         return None;
459     }
460     Some((variant_def, missed_fields, exhaustive))
461 }
462 
record_pattern_missing_fields( db: &dyn HirDatabase, infer: &InferenceResult, id: PatId, pat: &Pat, ) -> Option<(VariantId, Vec<LocalFieldId>, bool)>463 pub fn record_pattern_missing_fields(
464     db: &dyn HirDatabase,
465     infer: &InferenceResult,
466     id: PatId,
467     pat: &Pat,
468 ) -> Option<(VariantId, Vec<LocalFieldId>, /*exhaustive*/ bool)> {
469     let (fields, exhaustive) = match pat {
470         Pat::Record { path: _, args, ellipsis } => (args, !ellipsis),
471         _ => return None,
472     };
473 
474     let variant_def = infer.variant_resolution_for_pat(id)?;
475     if let VariantId::UnionId(_) = variant_def {
476         return None;
477     }
478 
479     let variant_data = variant_def.variant_data(db.upcast());
480 
481     let specified_fields: FxHashSet<_> = fields.iter().map(|f| &f.name).collect();
482     let missed_fields: Vec<LocalFieldId> = variant_data
483         .fields()
484         .iter()
485         .filter_map(|(f, d)| if specified_fields.contains(&d.name) { None } else { Some(f) })
486         .collect();
487     if missed_fields.is_empty() {
488         return None;
489     }
490     Some((variant_def, missed_fields, exhaustive))
491 }
492 
types_of_subpatterns_do_match(pat: PatId, body: &Body, infer: &InferenceResult) -> bool493 fn types_of_subpatterns_do_match(pat: PatId, body: &Body, infer: &InferenceResult) -> bool {
494     fn walk(pat: PatId, body: &Body, infer: &InferenceResult, has_type_mismatches: &mut bool) {
495         match infer.type_mismatch_for_pat(pat) {
496             Some(_) => *has_type_mismatches = true,
497             None => {
498                 body[pat].walk_child_pats(|subpat| walk(subpat, body, infer, has_type_mismatches))
499             }
500         }
501     }
502 
503     let mut has_type_mismatches = false;
504     walk(pat, body, infer, &mut has_type_mismatches);
505     !has_type_mismatches
506 }
507 
check_missing_refs( infer: &InferenceResult, arg: ExprId, param: &Ty, ) -> Option<(ExprId, Mutability)>508 fn check_missing_refs(
509     infer: &InferenceResult,
510     arg: ExprId,
511     param: &Ty,
512 ) -> Option<(ExprId, Mutability)> {
513     let arg_ty = infer.type_of_expr.get(arg)?;
514 
515     let reference_one = arg_ty.as_reference();
516     let reference_two = param.as_reference();
517 
518     match (reference_one, reference_two) {
519         (None, Some((referenced_ty, _, mutability))) if referenced_ty == arg_ty => {
520             Some((arg, Mutability::from_mutable(matches!(mutability, chalk_ir::Mutability::Mut))))
521         }
522         (None, Some((referenced_ty, _, mutability))) => match referenced_ty.kind(Interner) {
523             TyKind::Slice(subst) if matches!(arg_ty.kind(Interner), TyKind::Array(arr_subst, _) if arr_subst == subst) => {
524                 Some((
525                     arg,
526                     Mutability::from_mutable(matches!(mutability, chalk_ir::Mutability::Mut)),
527                 ))
528             }
529             _ => None,
530         },
531         _ => None,
532     }
533 }
534