1 use proc_macro2;
2 use proc_macro2::Span;
3 use syn;
4 use syn::fold::Fold;
5 use syn::spanned::Spanned;
6 
7 use diagnostic_shim::*;
8 use meta::*;
9 use model::*;
10 use util::*;
11 
derive(item: syn::DeriveInput) -> Result<proc_macro2::TokenStream, Diagnostic>12 pub fn derive(item: syn::DeriveInput) -> Result<proc_macro2::TokenStream, Diagnostic> {
13     let model = Model::from_item(&item)?;
14     let tokens = MetaItem::all_with_name(&item.attrs, "belongs_to")
15         .into_iter()
16         .filter_map(
17             |attr| match derive_belongs_to(&model, &item.generics, attr) {
18                 Ok(t) => Some(t),
19                 Err(e) => {
20                     e.emit();
21                     None
22                 }
23             },
24         );
25 
26     Ok(wrap_in_dummy_mod(
27         model.dummy_mod_name("associations"),
28         quote!(#(#tokens)*),
29     ))
30 }
31 
derive_belongs_to( model: &Model, generics: &syn::Generics, meta: MetaItem, ) -> Result<proc_macro2::TokenStream, Diagnostic>32 fn derive_belongs_to(
33     model: &Model,
34     generics: &syn::Generics,
35     meta: MetaItem,
36 ) -> Result<proc_macro2::TokenStream, Diagnostic> {
37     let AssociationOptions {
38         parent_struct,
39         foreign_key,
40     } = AssociationOptions::from_meta(meta)?;
41     let (_, ty_generics, _) = generics.split_for_impl();
42 
43     let foreign_key_field = model.find_column(&foreign_key)?;
44     let struct_name = &model.name;
45     let foreign_key_access = foreign_key_field.name.access();
46     let foreign_key_ty = inner_of_option_ty(&foreign_key_field.ty);
47     let table_name = model.table_name();
48 
49     let mut generics = generics.clone();
50 
51     let parent_struct = ReplacePathLifetimes::new(|i, span| {
52         let letter = char::from(b'b' + i as u8);
53         let lifetime = syn::Lifetime::new(&format!("'__{}", letter), span);
54         generics.params.push(parse_quote!(#lifetime));
55         lifetime
56     })
57     .fold_type_path(parent_struct);
58 
59     // TODO: Remove this special casing as soon as we bump our minimal supported
60     // rust version to >= 1.30.0 because this version will add
61     // `impl<'a, T> From<&'a Option<T>> for Option<&'a T>` to the std-lib
62     let (foreign_key_expr, foreign_key_ty) = if is_option_ty(&foreign_key_field.ty) {
63         (
64             quote!(self#foreign_key_access.as_ref()),
65             quote!(#foreign_key_ty),
66         )
67     } else {
68         generics.params.push(parse_quote!(__FK));
69         {
70             let where_clause = generics.where_clause.get_or_insert(parse_quote!(where));
71             where_clause
72                 .predicates
73                 .push(parse_quote!(__FK: std::hash::Hash + std::cmp::Eq));
74             where_clause.predicates.push(
75                 parse_quote!(for<'__a> &'__a #foreign_key_ty: std::convert::Into<::std::option::Option<&'__a __FK>>),
76             );
77             where_clause.predicates.push(
78                 parse_quote!(for<'__a> &'__a #parent_struct: diesel::associations::Identifiable<Id = &'__a __FK>),
79             );
80         }
81 
82         (
83             quote!(std::convert::Into::into(&self#foreign_key_access)),
84             quote!(__FK),
85         )
86     };
87 
88     let (impl_generics, _, where_clause) = generics.split_for_impl();
89 
90     Ok(quote! {
91         impl #impl_generics diesel::associations::BelongsTo<#parent_struct>
92             for #struct_name #ty_generics
93         #where_clause
94         {
95             type ForeignKey = #foreign_key_ty;
96             type ForeignKeyColumn = #table_name::#foreign_key;
97 
98             fn foreign_key(&self) -> std::option::Option<&Self::ForeignKey> {
99                 #foreign_key_expr
100             }
101 
102             fn foreign_key_column() -> Self::ForeignKeyColumn {
103                 #table_name::#foreign_key
104             }
105         }
106     })
107 }
108 
109 struct AssociationOptions {
110     parent_struct: syn::TypePath,
111     foreign_key: syn::Ident,
112 }
113 
114 impl AssociationOptions {
from_meta(meta: MetaItem) -> Result<Self, Diagnostic>115     fn from_meta(meta: MetaItem) -> Result<Self, Diagnostic> {
116         let parent_struct = meta
117             .nested()?
118             .find(|m| m.path().is_ok() || m.name().is_ident("parent"))
119             .ok_or_else(|| meta.span())
120             .and_then(|m| {
121                 m.path()
122                     .map(|i| parse_quote!(#i))
123                     .or_else(|_| m.ty_value())
124                     .map_err(|_| m.span())
125             })
126             .and_then(|ty| match ty {
127                 syn::Type::Path(ty_path) => Ok(ty_path),
128                 _ => Err(ty.span()),
129             })
130             .map_err(|span| {
131                 span.error("Expected a struct name")
132                     .help("e.g. `#[belongs_to(User)]` or `#[belongs_to(parent = \"User<'_>\")]")
133             })?;
134         let foreign_key = {
135             let parent_struct_name = parent_struct
136                 .path
137                 .segments
138                 .last()
139                 .expect("paths always have at least one segment");
140             meta.nested_item("foreign_key")?
141                 .map(|i| i.ident_value())
142                 .unwrap_or_else(|| Ok(infer_foreign_key(&parent_struct_name.ident)))?
143         };
144 
145         let unrecognized_options = meta
146             .nested()?
147             .skip(1)
148             .filter(|n| !n.name().is_ident("foreign_key"));
149         for ignored in unrecognized_options {
150             ignored
151                 .span()
152                 .warning(format!(
153                     "Unrecognized option {}",
154                     path_to_string(&ignored.name())
155                 ))
156                 .emit();
157         }
158 
159         Ok(Self {
160             parent_struct,
161             foreign_key,
162         })
163     }
164 }
165 
infer_foreign_key(name: &syn::Ident) -> syn::Ident166 fn infer_foreign_key(name: &syn::Ident) -> syn::Ident {
167     let snake_case = camel_to_snake(&name.to_string());
168     syn::Ident::new(&format!("{}_id", snake_case), name.span())
169 }
170 
171 struct ReplacePathLifetimes<F> {
172     count: usize,
173     f: F,
174 }
175 
176 impl<F> ReplacePathLifetimes<F> {
new(f: F) -> Self177     fn new(f: F) -> Self {
178         Self { count: 0, f }
179     }
180 }
181 
182 impl<F> Fold for ReplacePathLifetimes<F>
183 where
184     F: FnMut(usize, Span) -> syn::Lifetime,
185 {
fold_lifetime(&mut self, mut lt: syn::Lifetime) -> syn::Lifetime186     fn fold_lifetime(&mut self, mut lt: syn::Lifetime) -> syn::Lifetime {
187         if lt.ident == "_" {
188             lt = (self.f)(self.count, lt.span());
189             self.count += 1;
190         }
191         lt
192     }
193 }
194