1 pub(crate) mod fixture;
2 mod test;
3 mod wrapper;
4 
5 use std::collections::HashMap;
6 use syn::token::Async;
7 
8 use proc_macro2::{Span, TokenStream};
9 use syn::{parse_quote, Attribute, Expr, FnArg, Ident, ItemFn, Path, ReturnType, Stmt};
10 
11 use quote::{format_ident, quote};
12 
13 use crate::utils::attr_ends_with;
14 use crate::{
15     parse::{
16         rstest::{RsTestAttributes, RsTestData, RsTestInfo},
17         testcase::TestCase,
18         vlist::ValueList,
19     },
20     utils::attr_is,
21 };
22 use crate::{
23     refident::MaybeIdent,
24     resolver::{self, Resolver},
25 };
26 use wrapper::WrapByModule;
27 
28 pub(crate) use fixture::render as fixture;
29 pub(crate) mod inject;
30 
single(mut test: ItemFn, info: RsTestInfo) -> TokenStream31 pub(crate) fn single(mut test: ItemFn, info: RsTestInfo) -> TokenStream {
32     let resolver = resolver::fixtures::get(info.data.fixtures());
33     let args = test.sig.inputs.iter().cloned().collect::<Vec<_>>();
34     let attrs = std::mem::take(&mut test.attrs);
35     let asyncness = test.sig.asyncness;
36     let generic_types = test
37         .sig
38         .generics
39         .type_params()
40         .map(|tp| &tp.ident)
41         .cloned()
42         .collect::<Vec<_>>();
43 
44     single_test_case(
45         &test.sig.ident,
46         &test.sig.ident,
47         &args,
48         &attrs,
49         &test.sig.output,
50         asyncness,
51         Some(&test),
52         resolver,
53         &info.attributes,
54         &generic_types,
55     )
56 }
57 
parametrize(test: ItemFn, info: RsTestInfo) -> TokenStream58 pub(crate) fn parametrize(test: ItemFn, info: RsTestInfo) -> TokenStream {
59     let RsTestInfo { data, attributes } = info;
60     let resolver_fixtures = resolver::fixtures::get(data.fixtures());
61 
62     let rendered_cases = cases_data(&data, test.sig.ident.span())
63         .map(|(name, attrs, resolver)| {
64             TestCaseRender::new(name, attrs, (resolver, &resolver_fixtures))
65         })
66         .map(|case| case.render(&test, &attributes))
67         .collect();
68 
69     test_group(test, rendered_cases)
70 }
71 
72 impl ValueList {
render( &self, test: &ItemFn, resolver: &dyn Resolver, attrs: &[syn::Attribute], attributes: &RsTestAttributes, ) -> TokenStream73     fn render(
74         &self,
75         test: &ItemFn,
76         resolver: &dyn Resolver,
77         attrs: &[syn::Attribute],
78         attributes: &RsTestAttributes,
79     ) -> TokenStream {
80         let span = test.sig.ident.span();
81         let test_cases = self
82             .argument_data(resolver)
83             .map(|(name, r)| TestCaseRender::new(Ident::new(&name, span), attrs, r))
84             .map(|test_case| test_case.render(test, attributes));
85 
86         quote! { #(#test_cases)* }
87     }
88 
argument_data<'a>( &'a self, resolver: &'a dyn Resolver, ) -> impl Iterator<Item = (String, Box<(&'a dyn Resolver, (String, Expr))>)> + 'a89     fn argument_data<'a>(
90         &'a self,
91         resolver: &'a dyn Resolver,
92     ) -> impl Iterator<Item = (String, Box<(&'a dyn Resolver, (String, Expr))>)> + 'a {
93         let max_len = self.values.len();
94         self.values.iter().enumerate().map(move |(index, expr)| {
95             let name = format!(
96                 "{}_{:0len$}",
97                 self.arg,
98                 index + 1,
99                 len = max_len.display_len()
100             );
101             let resolver_this = (self.arg.to_string(), expr.clone());
102             (name, Box::new((resolver, resolver_this)))
103         })
104     }
105 }
106 
_matrix_recursive<'a>( test: &ItemFn, list_values: &'a [&'a ValueList], resolver: &dyn Resolver, attrs: &'a [syn::Attribute], attributes: &RsTestAttributes, ) -> TokenStream107 fn _matrix_recursive<'a>(
108     test: &ItemFn,
109     list_values: &'a [&'a ValueList],
110     resolver: &dyn Resolver,
111     attrs: &'a [syn::Attribute],
112     attributes: &RsTestAttributes,
113 ) -> TokenStream {
114     if list_values.is_empty() {
115         return Default::default();
116     }
117     let vlist = list_values[0];
118     let list_values = &list_values[1..];
119 
120     if list_values.is_empty() {
121         vlist.render(test, resolver, attrs, attributes)
122     } else {
123         let span = test.sig.ident.span();
124         let modules = vlist.argument_data(resolver).map(move |(name, resolver)| {
125             _matrix_recursive(test, list_values, &resolver, attrs, attributes)
126                 .wrap_by_mod(&Ident::new(&name, span))
127         });
128 
129         quote! { #(#modules)* }
130     }
131 }
132 
matrix(test: ItemFn, info: RsTestInfo) -> TokenStream133 pub(crate) fn matrix(test: ItemFn, info: RsTestInfo) -> TokenStream {
134     let RsTestInfo {
135         data, attributes, ..
136     } = info;
137     let span = test.sig.ident.span();
138 
139     let cases = cases_data(&data, span).collect::<Vec<_>>();
140 
141     let resolver = resolver::fixtures::get(data.fixtures());
142     let rendered_cases = if cases.is_empty() {
143         let list_values = data.list_values().collect::<Vec<_>>();
144         _matrix_recursive(&test, &list_values, &resolver, &[], &attributes)
145     } else {
146         cases
147             .into_iter()
148             .map(|(case_name, attrs, case_resolver)| {
149                 let list_values = data.list_values().collect::<Vec<_>>();
150                 _matrix_recursive(
151                     &test,
152                     &list_values,
153                     &(case_resolver, &resolver),
154                     attrs,
155                     &attributes,
156                 )
157                 .wrap_by_mod(&case_name)
158             })
159             .collect()
160     };
161 
162     test_group(test, rendered_cases)
163 }
164 
resolve_default_test_attr(is_async: bool) -> TokenStream165 fn resolve_default_test_attr(is_async: bool) -> TokenStream {
166     if is_async {
167         quote! { #[async_std::test] }
168     } else {
169         quote! { #[test] }
170     }
171 }
172 
render_exec_call(fn_path: Path, args: &[Ident], is_async: bool) -> TokenStream173 fn render_exec_call(fn_path: Path, args: &[Ident], is_async: bool) -> TokenStream {
174     if is_async {
175         quote! {#fn_path(#(#args),*).await}
176     } else {
177         quote! {#fn_path(#(#args),*)}
178     }
179 }
180 
181 /// Render a single test case:
182 ///
183 /// * `name` - Test case name
184 /// * `testfn_name` - The name of test function to call
185 /// * `args` - The arguments of the test function
186 /// * `attrs` - The expected test attributes
187 /// * `output` - The expected test return type
188 /// * `asyncness` - The `async` fn token
189 /// * `test_impl` - If you want embed test function (should be the one called by `testfn_name`)
190 /// * `resolver` - The resolver used to resolve injected values
191 /// * `attributes` - Test attributes to select test behaviour
192 /// * `generic_types` - The genrics type used in signature
193 ///
194 // Ok I need some refactoring here but now that not a real issue
195 #[allow(clippy::too_many_arguments)]
single_test_case<'a>( name: &Ident, testfn_name: &Ident, args: &[FnArg], attrs: &[Attribute], output: &ReturnType, asyncness: Option<Async>, test_impl: Option<&ItemFn>, resolver: impl Resolver, attributes: &'a RsTestAttributes, generic_types: &[Ident], ) -> TokenStream196 fn single_test_case<'a>(
197     name: &Ident,
198     testfn_name: &Ident,
199     args: &[FnArg],
200     attrs: &[Attribute],
201     output: &ReturnType,
202     asyncness: Option<Async>,
203     test_impl: Option<&ItemFn>,
204     resolver: impl Resolver,
205     attributes: &'a RsTestAttributes,
206     generic_types: &[Ident],
207 ) -> TokenStream {
208     let (attrs, trace_me): (Vec<_>, Vec<_>) =
209         attrs.iter().cloned().partition(|a| !attr_is(a, "trace"));
210     let mut attributes = attributes.clone();
211     if !trace_me.is_empty() {
212         attributes.add_trace(format_ident!("trace"));
213     }
214     let inject = inject::resolve_aruments(args.iter(), &resolver, generic_types);
215     let args = args
216         .iter()
217         .filter_map(MaybeIdent::maybe_ident)
218         .cloned()
219         .collect::<Vec<_>>();
220     let trace_args = trace_arguments(args.iter(), &attributes);
221 
222     let is_async = asyncness.is_some();
223     // If no injected attribut provided use the default one
224     let test_attr = if attrs
225         .iter()
226         .any(|a| attr_ends_with(a, &parse_quote! {test}))
227     {
228         None
229     } else {
230         Some(resolve_default_test_attr(is_async))
231     };
232     let execute = render_exec_call(testfn_name.clone().into(), &args, is_async);
233 
234     quote! {
235         #test_attr
236         #(#attrs)*
237         #asyncness fn #name() #output {
238             #test_impl
239             #inject
240             #trace_args
241             println!("{:-^40}", " TEST START ");
242             #execute
243         }
244     }
245 }
246 
trace_arguments<'a>( args: impl Iterator<Item = &'a Ident>, attributes: &RsTestAttributes, ) -> Option<TokenStream>247 fn trace_arguments<'a>(
248     args: impl Iterator<Item = &'a Ident>,
249     attributes: &RsTestAttributes,
250 ) -> Option<TokenStream> {
251     let mut statements = args
252         .filter(|&arg| attributes.trace_me(arg))
253         .map(|arg| {
254             let s: Stmt = parse_quote! {
255                 println!("{} = {:?}", stringify!(#arg), #arg);
256             };
257             s
258         })
259         .peekable();
260     if statements.peek().is_some() {
261         Some(quote! {
262             println!("{:-^40}", " TEST ARGUMENTS ");
263             #(#statements)*
264         })
265     } else {
266         None
267     }
268 }
269 
270 struct TestCaseRender<'a> {
271     name: Ident,
272     attrs: &'a [syn::Attribute],
273     resolver: Box<dyn Resolver + 'a>,
274 }
275 
276 impl<'a> TestCaseRender<'a> {
new<R: Resolver + 'a>(name: Ident, attrs: &'a [syn::Attribute], resolver: R) -> Self277     pub fn new<R: Resolver + 'a>(name: Ident, attrs: &'a [syn::Attribute], resolver: R) -> Self {
278         TestCaseRender {
279             name,
280             attrs,
281             resolver: Box::new(resolver),
282         }
283     }
284 
render(self, testfn: &ItemFn, attributes: &RsTestAttributes) -> TokenStream285     fn render(self, testfn: &ItemFn, attributes: &RsTestAttributes) -> TokenStream {
286         let args = testfn.sig.inputs.iter().cloned().collect::<Vec<_>>();
287         let mut attrs = testfn.attrs.clone();
288         attrs.extend(self.attrs.iter().cloned());
289         let asyncness = testfn.sig.asyncness;
290         let generic_types = testfn
291             .sig
292             .generics
293             .type_params()
294             .map(|tp| &tp.ident)
295             .cloned()
296             .collect::<Vec<_>>();
297 
298         single_test_case(
299             &self.name,
300             &testfn.sig.ident,
301             &args,
302             &attrs,
303             &testfn.sig.output,
304             asyncness,
305             None,
306             self.resolver,
307             attributes,
308             &generic_types,
309         )
310     }
311 }
312 
test_group(mut test: ItemFn, rendered_cases: TokenStream) -> TokenStream313 fn test_group(mut test: ItemFn, rendered_cases: TokenStream) -> TokenStream {
314     let fname = &test.sig.ident;
315     test.attrs = vec![];
316 
317     quote! {
318         #[cfg(test)]
319         #test
320 
321         #[cfg(test)]
322         mod #fname {
323             use super::*;
324 
325             #rendered_cases
326         }
327     }
328 }
329 
/// Length in bytes of a value's `Display` rendering. It is used as the
/// zero-padding width when numbering generated test names.
trait DisplayLen {
    fn display_len(&self) -> usize;
}

impl<D: std::fmt::Display> DisplayLen for D {
    fn display_len(&self) -> usize {
        self.to_string().len()
    }
}
339 
format_case_name(case: &TestCase, index: usize, display_len: usize) -> String340 fn format_case_name(case: &TestCase, index: usize, display_len: usize) -> String {
341     let description = case
342         .description
343         .as_ref()
344         .map(|d| format!("_{}", d))
345         .unwrap_or_default();
346     format!(
347         "case_{:0len$}{d}",
348         index,
349         len = display_len,
350         d = description
351     )
352 }
353 
cases_data( data: &RsTestData, name_span: Span, ) -> impl Iterator<Item = (Ident, &[syn::Attribute], HashMap<String, &syn::Expr>)>354 fn cases_data(
355     data: &RsTestData,
356     name_span: Span,
357 ) -> impl Iterator<Item = (Ident, &[syn::Attribute], HashMap<String, &syn::Expr>)> {
358     let display_len = data.cases().count().display_len();
359     data.cases().enumerate().map({
360         move |(n, case)| {
361             let resolver_case = data
362                 .case_args()
363                 .map(|a| a.to_string())
364                 .zip(case.args.iter())
365                 .collect::<HashMap<_, _>>();
366             (
367                 Ident::new(&format_case_name(case, n + 1, display_len), name_span),
368                 case.attrs.as_slice(),
369                 resolver_case,
370             )
371         }
372     })
373 }
374