1 use proc_macro2::{Span, TokenStream};
2 use quote::quote;
3 use syn::{Data, DeriveInput, Ident};
4
5 use crate::helpers::{non_enum_error, HasStrumVariantProperties, HasTypeProperties};
6
enum_iter_inner(ast: &DeriveInput) -> syn::Result<TokenStream>7 pub fn enum_iter_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
8 let name = &ast.ident;
9 let gen = &ast.generics;
10 let (impl_generics, ty_generics, where_clause) = gen.split_for_impl();
11 let vis = &ast.vis;
12 let type_properties = ast.get_type_properties()?;
13 let strum_module_path = type_properties.crate_module_path();
14
15 if gen.lifetimes().count() > 0 {
16 return Err(syn::Error::new(
17 Span::call_site(),
18 "This macro doesn't support enums with lifetimes. \
19 The resulting enums would be unbounded.",
20 ));
21 }
22
23 let phantom_data = if gen.type_params().count() > 0 {
24 let g = gen.type_params().map(|param| ¶m.ident);
25 quote! { < ( #(#g),* ) > }
26 } else {
27 quote! { < () > }
28 };
29
30 let variants = match &ast.data {
31 Data::Enum(v) => &v.variants,
32 _ => return Err(non_enum_error()),
33 };
34
35 let mut arms = Vec::new();
36 let mut idx = 0usize;
37 for variant in variants {
38 use syn::Fields::*;
39
40 if variant.get_variant_properties()?.disabled.is_some() {
41 continue;
42 }
43
44 let ident = &variant.ident;
45 let params = match &variant.fields {
46 Unit => quote! {},
47 Unnamed(fields) => {
48 let defaults = ::std::iter::repeat(quote!(::core::default::Default::default()))
49 .take(fields.unnamed.len());
50 quote! { (#(#defaults),*) }
51 }
52 Named(fields) => {
53 let fields = fields
54 .named
55 .iter()
56 .map(|field| field.ident.as_ref().unwrap());
57 quote! { {#(#fields: ::core::default::Default::default()),*} }
58 }
59 };
60
61 arms.push(quote! {#idx => ::core::option::Option::Some(#name::#ident #params)});
62 idx += 1;
63 }
64
65 let variant_count = arms.len();
66 arms.push(quote! { _ => ::core::option::Option::None });
67 let iter_name = syn::parse_str::<Ident>(&format!("{}Iter", name)).unwrap();
68
69 Ok(quote! {
70 #[allow(missing_docs)]
71 #vis struct #iter_name #ty_generics {
72 idx: usize,
73 back_idx: usize,
74 marker: ::core::marker::PhantomData #phantom_data,
75 }
76
77 impl #impl_generics #iter_name #ty_generics #where_clause {
78 fn get(&self, idx: usize) -> Option<#name #ty_generics> {
79 match idx {
80 #(#arms),*
81 }
82 }
83 }
84
85 impl #impl_generics #strum_module_path::IntoEnumIterator for #name #ty_generics #where_clause {
86 type Iterator = #iter_name #ty_generics;
87 fn iter() -> #iter_name #ty_generics {
88 #iter_name {
89 idx: 0,
90 back_idx: 0,
91 marker: ::core::marker::PhantomData,
92 }
93 }
94 }
95
96 impl #impl_generics Iterator for #iter_name #ty_generics #where_clause {
97 type Item = #name #ty_generics;
98
99 fn next(&mut self) -> Option<<Self as Iterator>::Item> {
100 self.nth(0)
101 }
102
103 fn size_hint(&self) -> (usize, Option<usize>) {
104 let t = if self.idx + self.back_idx >= #variant_count { 0 } else { #variant_count - self.idx - self.back_idx };
105 (t, Some(t))
106 }
107
108 fn nth(&mut self, n: usize) -> Option<<Self as Iterator>::Item> {
109 let idx = self.idx + n + 1;
110 if idx + self.back_idx > #variant_count {
111 // We went past the end of the iterator. Freeze idx at #variant_count
112 // so that it doesn't overflow if the user calls this repeatedly.
113 // See PR #76 for context.
114 self.idx = #variant_count;
115 None
116 } else {
117 self.idx = idx;
118 self.get(idx - 1)
119 }
120 }
121 }
122
123 impl #impl_generics ExactSizeIterator for #iter_name #ty_generics #where_clause {
124 fn len(&self) -> usize {
125 self.size_hint().0
126 }
127 }
128
129 impl #impl_generics DoubleEndedIterator for #iter_name #ty_generics #where_clause {
130 fn next_back(&mut self) -> Option<<Self as Iterator>::Item> {
131 let back_idx = self.back_idx + 1;
132
133 if self.idx + back_idx > #variant_count {
134 // We went past the end of the iterator. Freeze back_idx at #variant_count
135 // so that it doesn't overflow if the user calls this repeatedly.
136 // See PR #76 for context.
137 self.back_idx = #variant_count;
138 None
139 } else {
140 self.back_idx = back_idx;
141 self.get(#variant_count - self.back_idx)
142 }
143 }
144 }
145
146 impl #impl_generics Clone for #iter_name #ty_generics #where_clause {
147 fn clone(&self) -> #iter_name #ty_generics {
148 #iter_name {
149 idx: self.idx,
150 back_idx: self.back_idx,
151 marker: self.marker.clone(),
152 }
153 }
154 }
155 })
156 }
157