1 use proc_macro2::{Span, TokenStream};
2 use syn::{parse_quote, Ident, ItemFn};
3 
4 use quote::quote;
5 
6 use super::{inject, render_exec_call};
7 use crate::resolver::{self, Resolver};
8 use crate::utils::{fn_args, fn_args_idents};
9 use crate::{parse::fixture::FixtureInfo, utils::generics_clean_up};
10 
render(fixture: ItemFn, info: FixtureInfo) -> TokenStream11 pub(crate) fn render(fixture: ItemFn, info: FixtureInfo) -> TokenStream {
12     let name = &fixture.sig.ident;
13     let asyncness = &fixture.sig.asyncness.clone();
14     let vargs = fn_args_idents(&fixture).cloned().collect::<Vec<_>>();
15     let args = &vargs;
16     let orig_args = &fixture.sig.inputs;
17     let orig_attrs = &fixture.attrs;
18     let generics = &fixture.sig.generics;
19     let default_output = info
20         .attributes
21         .extract_default_type()
22         .unwrap_or_else(|| fixture.sig.output.clone());
23     let default_generics =
24         generics_clean_up(&fixture.sig.generics, std::iter::empty(), &default_output);
25     let default_where_clause = &default_generics.where_clause;
26     let where_clause = &fixture.sig.generics.where_clause;
27     let output = &fixture.sig.output;
28     let visibility = &fixture.vis;
29     let resolver = (
30         resolver::fixtures::get(info.data.fixtures()),
31         resolver::values::get(info.data.values()),
32     );
33     let generics_idents = generics
34         .type_params()
35         .map(|tp| &tp.ident)
36         .cloned()
37         .collect::<Vec<_>>();
38     let inject = inject::resolve_aruments(fixture.sig.inputs.iter(), &resolver, &generics_idents);
39     let partials =
40         (1..=orig_args.len()).map(|n| render_partial_impl(&fixture, n, &resolver, &info));
41 
42     let call_get = render_exec_call(parse_quote! { Self::get }, args, asyncness.is_some());
43     let call_impl = render_exec_call(parse_quote! { #name }, args, asyncness.is_some());
44 
45     quote! {
46         #[allow(non_camel_case_types)]
47         #visibility struct #name {}
48 
49         impl #name {
50             #(#orig_attrs)*
51             #[allow(unused_mut)]
52             pub #asyncness fn get #generics (#orig_args) #output #where_clause {
53                 #call_impl
54             }
55 
56             pub #asyncness fn default #default_generics () #default_output #default_where_clause {
57                 #inject
58                 #call_get
59             }
60 
61             #(#partials)*
62         }
63 
64         #[allow(dead_code)]
65         #fixture
66     }
67 }
68 
render_partial_impl( fixture: &ItemFn, n: usize, resolver: &impl Resolver, info: &FixtureInfo, ) -> TokenStream69 fn render_partial_impl(
70     fixture: &ItemFn,
71     n: usize,
72     resolver: &impl Resolver,
73     info: &FixtureInfo,
74 ) -> TokenStream {
75     let output = info
76         .attributes
77         .extract_partial_type(n)
78         .unwrap_or_else(|| fixture.sig.output.clone());
79 
80     let generics = generics_clean_up(&fixture.sig.generics, fn_args(fixture).take(n), &output);
81     let where_clause = &generics.where_clause;
82     let asyncness = &fixture.sig.asyncness;
83 
84     let genercs_idents = generics
85         .type_params()
86         .map(|tp| &tp.ident)
87         .cloned()
88         .collect::<Vec<_>>();
89     let inject =
90         inject::resolve_aruments(fixture.sig.inputs.iter().skip(n), resolver, &genercs_idents);
91 
92     let sign_args = fn_args(fixture).take(n);
93     let fixture_args = fn_args_idents(fixture).cloned().collect::<Vec<_>>();
94     let name = Ident::new(&format!("partial_{}", n), Span::call_site());
95 
96     let call_get = render_exec_call(
97         parse_quote! { Self::get },
98         &fixture_args,
99         asyncness.is_some(),
100     );
101 
102     quote! {
103         #[allow(unused_mut)]
104         pub #asyncness fn #name #generics (#(#sign_args),*) #output #where_clause {
105             #inject
106             #call_get
107         }
108     }
109 }
110 
#[cfg(test)]
mod should {
    use syn::{
        parse::{Parse, ParseStream},
        parse2, parse_str, ItemFn, ItemImpl, ItemStruct, Result,
    };

    use crate::parse::{Attribute, Attributes};

    use super::*;
    use crate::test::assert_eq;
    use mytest::*;
    use rstest_reuse::*;

    /// The items emitted by `render`: the fixture struct, its `impl` block and the
    /// original function.
    #[derive(Clone)]
    struct FixtureOutput {
        orig: ItemFn,
        fixture: ItemStruct,
        core_impl: ItemImpl,
    }

    impl Parse for FixtureOutput {
        fn parse(input: ParseStream) -> Result<Self> {
            // Struct-expression fields are evaluated in source order, so this parses
            // struct, impl, fn: the same order `render` emits them.
            Ok(FixtureOutput {
                fixture: input.parse()?,
                core_impl: input.parse()?,
                orig: input.parse()?,
            })
        }
    }

    /// Parses `code` as a function, renders it with default `FixtureInfo` and returns
    /// both the input function and the parsed output.
    fn parse_fixture<S: AsRef<str>>(code: S) -> (ItemFn, FixtureOutput) {
        let item_fn = parse_str::<ItemFn>(code.as_ref()).unwrap();

        let tokens = render(item_fn.clone(), Default::default());
        (item_fn, parse2(tokens).unwrap())
    }

    /// Asserts that both the generated struct and the re-emitted function keep the
    /// visibility of the input function.
    fn test_maintains_function_visibility(code: &str) {
        let (item_fn, out) = parse_fixture(code);

        assert_eq!(item_fn.vis, out.fixture.vis);
        assert_eq!(item_fn.vis, out.orig.vis);
    }

    /// Returns the method called `name` from `impl_code`, if present.
    fn select_method<S: AsRef<str>>(impl_code: ItemImpl, name: S) -> Option<syn::ImplItemMethod> {
        impl_code
            .items
            .into_iter()
            .filter_map(|ii| match ii {
                syn::ImplItem::Method(f) => Some(f),
                _ => None,
            })
            .find(|f| f.sig.ident == name.as_ref())
    }

    #[test]
    fn maintains_pub_visibility() {
        test_maintains_function_visibility(r#"pub fn test() { }"#);
    }

    #[test]
    fn maintains_no_pub_visibility() {
        test_maintains_function_visibility(r#"fn test() { }"#);
    }

    #[test]
    fn implement_a_get_method_with_input_fixture_signature() {
        let (item_fn, out) = parse_fixture(
            r#"
                    pub fn test<R: AsRef<str>, B>(mut s: String, v: &u32, a: &mut [i32], r: R) -> (u32, B, String, &str)
                            where B: Borrow<u32>
                    { }
                    "#,
        );

        let mut signature = select_method(out.core_impl, "get").unwrap().sig;

        // `get` must mirror the fixture signature exactly, apart from its name.
        signature.ident = item_fn.sig.ident.clone();

        assert_eq!(item_fn.sig, signature);
    }

    // Every generated method, for both an async and a sync fixture.
    #[template]
    #[rstest(
        method => ["default", "get", "partial_1", "partial_2", "partial_3"]
    )]
    #[case::async_fn(true)]
    #[case::not_async_fn(false)]
    fn async_fixture_cases(#[case] is_async: bool, method: &str) {}

    #[apply(async_fixture_cases)]
    fn fixture_method_should_be_async_if_fixture_function_is_async(
        #[case] is_async: bool,
        method: &str,
    ) {
        let prefix = if is_async { "async" } else { "" };
        let (_, out) = parse_fixture(&format!(
            r#"
                    pub {} fn test(mut s: String, v: &u32, a: &mut [i32]) -> u32
                            where B: Borrow<u32>
                    {{ }}
                    "#,
            prefix
        ));

        let signature = select_method(out.core_impl, method).unwrap().sig;

        assert_eq!(is_async, signature.asyncness.is_some());
    }

    #[apply(async_fixture_cases)]
    fn fixture_method_should_use_await_if_fixture_function_is_async(
        #[case] is_async: bool,
        method: &str,
    ) {
        let prefix = if is_async { "async" } else { "" };
        let (_, out) = parse_fixture(&format!(
            r#"
                    pub {} fn test(mut s: String, v: &u32, a: &mut [i32]) -> u32
                    {{ }}
                    "#,
            prefix
        ));

        let body = select_method(out.core_impl, method).unwrap().block;
        let last_statement = body.stmts.last().unwrap();
        // The forwarded call is the tail expression; it is `.await`ed only for async fixtures.
        let is_await = matches!(last_statement, syn::Stmt::Expr(syn::Expr::Await(_)));

        assert_eq!(is_async, is_await);
    }

    #[test]
    fn implement_a_default_method_with_input_cleaned_fixture_signature_and_no_args() {
        let (item_fn, out) = parse_fixture(
            r#"
                    pub fn test<R: AsRef<str>, B, F, H: Iterator<Item=u32>>(mut s: String, v: &u32, a: &mut [i32], r: R) -> (H, B, String, &str)
                        where F: ToString,
                        B: Borrow<u32>

                    { }
                    "#,
        );

        let default_decl = select_method(out.core_impl, "default").unwrap().sig;

        let expected = parse_str::<ItemFn>(
            r#"
                    pub fn default<B, H: Iterator<Item=u32>>() -> (H, B, String, &str)
                            where B: Borrow<u32>
                    { }
                    "#,
        )
        .unwrap();

        assert_eq!(expected.sig.generics, default_decl.generics);
        assert_eq!(item_fn.sig.output, default_decl.output);
        assert!(default_decl.inputs.is_empty());
    }

    #[test]
    fn use_default_return_type_if_any() {
        let item_fn = parse_str::<ItemFn>(
            r#"
                    pub fn test<R: AsRef<str>, B, F, H: Iterator<Item=u32>>() -> (H, B)
                            where F: ToString,
                            B: Borrow<u32>
                    { }
                    "#,
        )
        .unwrap();

        let tokens = render(
            item_fn.clone(),
            FixtureInfo {
                attributes: Attributes {
                    attributes: vec![Attribute::Type(
                        parse_str("default").unwrap(),
                        parse_str("(impl Iterator<Item=u32>, B)").unwrap(),
                    )],
                }
                .into(),
                ..Default::default()
            },
        );
        let out: FixtureOutput = parse2(tokens).unwrap();

        let expected = parse_str::<syn::ItemFn>(
            r#"
                    pub fn default<B>() -> (impl Iterator<Item=u32>, B)
                            where B: Borrow<u32>
                    { }
                    "#,
        )
        .unwrap();

        let default_decl = select_method(out.core_impl, "default").unwrap().sig;

        assert_eq!(expected.sig, default_decl);
    }

    #[test]
    fn implement_partial_methods() {
        let (item_fn, out) = parse_fixture(
            r#"
                    pub fn test(mut s: String, v: &u32, a: &mut [i32]) -> usize
                    { }
                    "#,
        );

        let partials = (1..=3)
            .map(|n| {
                select_method(out.core_impl.clone(), format!("partial_{}", n))
                    .unwrap()
                    .sig
            })
            .collect::<Vec<_>>();

        // `partial_1..=partial_3` exist (the `unwrap`s above would panic otherwise);
        // there is exactly one partial per argument, so no `partial_4`.
        assert!(select_method(out.core_impl, "partial_4").is_none());

        let expected_1 = parse_str::<ItemFn>(
            r#"
                    pub fn partial_1(mut s: String) -> usize
                    { }
                    "#,
        )
        .unwrap();

        assert_eq!(expected_1.sig, partials[0]);
        for p in partials {
            assert_eq!(item_fn.sig.output, p.output);
        }
    }

    #[rstest]
    #[case::base("fn test<S: AsRef<str>, U: AsRef<u32>, F: ToString>(mut s: S, v: U) -> F {}",
        vec![
            "fn default<F: ToString>() -> F {}",
            "fn partial_1<S: AsRef<str>, F: ToString>(mut s: S) -> F {}",
            "fn partial_2<S: AsRef<str>, U: AsRef<u32>, F: ToString>(mut s: S, v: U) -> F {}",
        ]
    )]
    #[case::associated_type("fn test<T: IntoIterator>(mut i: T) where T::Item: Copy {}",
        vec![
            "fn default() {}",
            "fn partial_1<T: IntoIterator>(mut i: T) where T::Item: Copy {}",
        ]
    )]
    #[case::not_remove_const_generics("fn test<const N:usize>(v: [u32; N]) -> [i32; N] {}",
        vec![
            "fn default<const N:usize>() -> [i32; N] {}",
            "fn partial_1<const N:usize>(v: [u32; N]) -> [i32; N] {}",
        ]
    )]
    #[case::remove_const_generics("fn test<const N:usize>(a: i32, v: [u32; N]) {}",
        vec![
            "fn default() {}",
            "fn partial_1(a:i32) {}",
            "fn partial_2<const N:usize>(a:i32, v: [u32; N]) {}",
        ]
    )]
    fn clean_generics(#[case] code: &str, #[case] expected: Vec<&str>) {
        let (item_fn, out) = parse_fixture(code);
        let n_args = item_fn.sig.inputs.len();

        // `default` followed by `partial_1..=partial_<n_args>`, in that order.
        let mut signatures = vec![select_method(out.core_impl.clone(), "default").unwrap().sig];
        signatures.extend((1..=n_args).map(|n| {
            select_method(out.core_impl.clone(), format!("partial_{}", n))
                .unwrap()
                .sig
        }));

        let expected = expected
            .into_iter()
            .map(parse_str::<ItemFn>)
            .map(|f| f.unwrap().sig)
            .collect::<Vec<_>>();

        assert_eq!(expected, signatures);
    }

    #[test]
    fn use_partial_return_type_if_any() {
        let item_fn = parse_str::<ItemFn>(
            r#"
                    pub fn test<R: AsRef<str>, B, F, H: Iterator<Item=u32>>(h: H, b: B) -> (H, B)
                            where F: ToString,
                            B: Borrow<u32>
                    { }
                     "#,
        )
        .unwrap();

        let tokens = render(
            item_fn.clone(),
            FixtureInfo {
                attributes: Attributes {
                    attributes: vec![Attribute::Type(
                        parse_str("partial_1").unwrap(),
                        parse_str("(H, impl Iterator<Item=u32>)").unwrap(),
                    )],
                }
                .into(),
                ..Default::default()
            },
        );
        let out: FixtureOutput = parse2(tokens).unwrap();

        let expected = parse_str::<syn::ItemFn>(
            r#"
                    pub fn partial_1<H: Iterator<Item=u32>>(h: H) -> (H, impl Iterator<Item=u32>)
                    { }
                    "#,
        )
        .unwrap();

        let partial = select_method(out.core_impl, "partial_1").unwrap();

        assert_eq!(expected.sig, partial.sig);
    }
}
438