1 use proc_macro2::{Span, TokenStream};
2 use quote::quote;
3 use syn::{Data, DeriveInput, Ident};
4 
5 use crate::helpers::{non_enum_error, HasStrumVariantProperties, HasTypeProperties};
6 
enum_iter_inner(ast: &DeriveInput) -> syn::Result<TokenStream>7 pub fn enum_iter_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
8     let name = &ast.ident;
9     let gen = &ast.generics;
10     let (impl_generics, ty_generics, where_clause) = gen.split_for_impl();
11     let vis = &ast.vis;
12     let type_properties = ast.get_type_properties()?;
13     let strum_module_path = type_properties.crate_module_path();
14 
15     if gen.lifetimes().count() > 0 {
16         return Err(syn::Error::new(
17             Span::call_site(),
18             "This macro doesn't support enums with lifetimes. \
19              The resulting enums would be unbounded.",
20         ));
21     }
22 
23     let phantom_data = if gen.type_params().count() > 0 {
24         let g = gen.type_params().map(|param| &param.ident);
25         quote! { < ( #(#g),* ) > }
26     } else {
27         quote! { < () > }
28     };
29 
30     let variants = match &ast.data {
31         Data::Enum(v) => &v.variants,
32         _ => return Err(non_enum_error()),
33     };
34 
35     let mut arms = Vec::new();
36     let mut idx = 0usize;
37     for variant in variants {
38         use syn::Fields::*;
39 
40         if variant.get_variant_properties()?.disabled.is_some() {
41             continue;
42         }
43 
44         let ident = &variant.ident;
45         let params = match &variant.fields {
46             Unit => quote! {},
47             Unnamed(fields) => {
48                 let defaults = ::std::iter::repeat(quote!(::core::default::Default::default()))
49                     .take(fields.unnamed.len());
50                 quote! { (#(#defaults),*) }
51             }
52             Named(fields) => {
53                 let fields = fields
54                     .named
55                     .iter()
56                     .map(|field| field.ident.as_ref().unwrap());
57                 quote! { {#(#fields: ::core::default::Default::default()),*} }
58             }
59         };
60 
61         arms.push(quote! {#idx => ::core::option::Option::Some(#name::#ident #params)});
62         idx += 1;
63     }
64 
65     let variant_count = arms.len();
66     arms.push(quote! { _ => ::core::option::Option::None });
67     let iter_name = syn::parse_str::<Ident>(&format!("{}Iter", name)).unwrap();
68 
69     Ok(quote! {
70         #[allow(missing_docs)]
71         #vis struct #iter_name #ty_generics {
72             idx: usize,
73             back_idx: usize,
74             marker: ::core::marker::PhantomData #phantom_data,
75         }
76 
77         impl #impl_generics #iter_name #ty_generics #where_clause {
78             fn get(&self, idx: usize) -> Option<#name #ty_generics> {
79                 match idx {
80                     #(#arms),*
81                 }
82             }
83         }
84 
85         impl #impl_generics #strum_module_path::IntoEnumIterator for #name #ty_generics #where_clause {
86             type Iterator = #iter_name #ty_generics;
87             fn iter() -> #iter_name #ty_generics {
88                 #iter_name {
89                     idx: 0,
90                     back_idx: 0,
91                     marker: ::core::marker::PhantomData,
92                 }
93             }
94         }
95 
96         impl #impl_generics Iterator for #iter_name #ty_generics #where_clause {
97             type Item = #name #ty_generics;
98 
99             fn next(&mut self) -> Option<<Self as Iterator>::Item> {
100                 self.nth(0)
101             }
102 
103             fn size_hint(&self) -> (usize, Option<usize>) {
104                 let t = if self.idx + self.back_idx >= #variant_count { 0 } else { #variant_count - self.idx - self.back_idx };
105                 (t, Some(t))
106             }
107 
108             fn nth(&mut self, n: usize) -> Option<<Self as Iterator>::Item> {
109                 let idx = self.idx + n + 1;
110                 if idx + self.back_idx > #variant_count {
111                     // We went past the end of the iterator. Freeze idx at #variant_count
112                     // so that it doesn't overflow if the user calls this repeatedly.
113                     // See PR #76 for context.
114                     self.idx = #variant_count;
115                     None
116                 } else {
117                     self.idx = idx;
118                     self.get(idx - 1)
119                 }
120             }
121         }
122 
123         impl #impl_generics ExactSizeIterator for #iter_name #ty_generics #where_clause {
124             fn len(&self) -> usize {
125                 self.size_hint().0
126             }
127         }
128 
129         impl #impl_generics DoubleEndedIterator for #iter_name #ty_generics #where_clause {
130             fn next_back(&mut self) -> Option<<Self as Iterator>::Item> {
131                 let back_idx = self.back_idx + 1;
132 
133                 if self.idx + back_idx > #variant_count {
134                     // We went past the end of the iterator. Freeze back_idx at #variant_count
135                     // so that it doesn't overflow if the user calls this repeatedly.
136                     // See PR #76 for context.
137                     self.back_idx = #variant_count;
138                     None
139                 } else {
140                     self.back_idx = back_idx;
141                     self.get(#variant_count - self.back_idx)
142                 }
143             }
144         }
145 
146         impl #impl_generics Clone for #iter_name #ty_generics #where_clause {
147             fn clone(&self) -> #iter_name #ty_generics {
148                 #iter_name {
149                     idx: self.idx,
150                     back_idx: self.back_idx,
151                     marker: self.marker.clone(),
152                 }
153             }
154         }
155     })
156 }
157