1 use std::iter;
2 
3 use proc_macro2::{Span, TokenStream};
4 use quote::{quote, ToTokens};
5 use syn::{parse::Result, DeriveInput, Ident, Index};
6 
7 use crate::utils::{
8     add_where_clauses_for_new_ident, AttrParams, DeriveType, HashMap, MultiFieldData,
9     RefType, State,
10 };
11 
12 /// Provides the hook to expand `#[derive(From)]` into an implementation of `From`
expand(input: &DeriveInput, trait_name: &'static str) -> Result<TokenStream>13 pub fn expand(input: &DeriveInput, trait_name: &'static str) -> Result<TokenStream> {
14     let state = State::with_attr_params(
15         input,
16         trait_name,
17         quote!(::core::convert),
18         trait_name.to_lowercase(),
19         AttrParams {
20             enum_: vec!["forward", "ignore"],
21             variant: vec!["forward", "ignore", "types"],
22             struct_: vec!["forward", "types"],
23             field: vec!["forward"],
24         },
25     )?;
26     if state.derive_type == DeriveType::Enum {
27         Ok(enum_from(input, state))
28     } else {
29         Ok(struct_from(input, &state))
30     }
31 }
32 
struct_from(input: &DeriveInput, state: &State) -> TokenStream33 pub fn struct_from(input: &DeriveInput, state: &State) -> TokenStream {
34     let multi_field_data = state.enabled_fields_data();
35     let MultiFieldData {
36         fields,
37         variant_info,
38         infos,
39         input_type,
40         trait_path,
41         ..
42     } = multi_field_data.clone();
43 
44     let additional_types = variant_info.additional_types(RefType::No);
45     let mut impls = Vec::with_capacity(additional_types.len() + 1);
46     for explicit_type in iter::once(None).chain(additional_types.iter().map(Some)) {
47         let mut new_generics = input.generics.clone();
48 
49         let mut initializers = Vec::with_capacity(infos.len());
50         let mut from_types = Vec::with_capacity(infos.len());
51         for (i, (info, field)) in infos.iter().zip(fields.iter()).enumerate() {
52             let field_type = &field.ty;
53             let variable = if fields.len() == 1 {
54                 quote! { original }
55             } else {
56                 let tuple_index = Index::from(i);
57                 quote! { original.#tuple_index }
58             };
59             if let Some(type_) = explicit_type {
60                 initializers.push(quote! {
61                     <#field_type as #trait_path<#type_>>::from(#variable)
62                 });
63                 from_types.push(quote! { #type_ });
64             } else if info.forward {
65                 let type_param =
66                     &Ident::new(&format!("__FromT{}", i), Span::call_site());
67                 let sub_trait_path = quote! { #trait_path<#type_param> };
68                 let type_where_clauses = quote! {
69                     where #field_type: #sub_trait_path
70                 };
71                 new_generics = add_where_clauses_for_new_ident(
72                     &new_generics,
73                     &[field],
74                     type_param,
75                     type_where_clauses,
76                     true,
77                 );
78                 let casted_trait = quote! { <#field_type as #sub_trait_path> };
79                 initializers.push(quote! { #casted_trait::from(#variable) });
80                 from_types.push(quote! { #type_param });
81             } else {
82                 initializers.push(variable);
83                 from_types.push(quote! { #field_type });
84             }
85         }
86 
87         let body = multi_field_data.initializer(&initializers);
88         let (impl_generics, _, where_clause) = new_generics.split_for_impl();
89         let (_, ty_generics, _) = input.generics.split_for_impl();
90 
91         impls.push(quote! {
92             #[automatically_derived]
93             impl#impl_generics #trait_path<(#(#from_types),*)> for
94                 #input_type#ty_generics #where_clause {
95 
96                 #[inline]
97                 fn from(original: (#(#from_types),*)) -> #input_type#ty_generics {
98                     #body
99                 }
100             }
101         });
102     }
103 
104     quote! { #( #impls )* }
105 }
106 
enum_from(input: &DeriveInput, state: State) -> TokenStream107 fn enum_from(input: &DeriveInput, state: State) -> TokenStream {
108     let mut tokens = TokenStream::new();
109 
110     let mut variants_per_types = HashMap::default();
111     for variant_state in state.enabled_variant_data().variant_states {
112         let multi_field_data = variant_state.enabled_fields_data();
113         let MultiFieldData { field_types, .. } = multi_field_data.clone();
114         variants_per_types
115             .entry(field_types.clone())
116             .or_insert_with(Vec::new)
117             .push(variant_state);
118     }
119     for (ref field_types, ref variant_states) in variants_per_types {
120         for variant_state in variant_states {
121             let multi_field_data = variant_state.enabled_fields_data();
122             let MultiFieldData {
123                 variant_info,
124                 infos,
125                 ..
126             } = multi_field_data.clone();
127             // If there would be a conflict on a empty tuple derive, ignore the
128             // variants that are not explicitly enabled or have explicitly enabled
129             // or disabled fields
130             if field_types.is_empty()
131                 && variant_states.len() > 1
132                 && !std::iter::once(variant_info)
133                     .chain(infos)
134                     .any(|info| info.info.enabled.is_some())
135             {
136                 continue;
137             }
138             struct_from(input, variant_state).to_tokens(&mut tokens);
139         }
140     }
141     tokens
142 }
143