1 pub(crate) mod fixture;
2 mod test;
3 mod wrapper;
4
5 use std::collections::HashMap;
6 use syn::token::Async;
7
8 use proc_macro2::{Span, TokenStream};
9 use syn::{parse_quote, Attribute, Expr, FnArg, Ident, ItemFn, Path, ReturnType, Stmt};
10
11 use quote::{format_ident, quote};
12
13 use crate::utils::attr_ends_with;
14 use crate::{
15 parse::{
16 rstest::{RsTestAttributes, RsTestData, RsTestInfo},
17 testcase::TestCase,
18 vlist::ValueList,
19 },
20 utils::attr_is,
21 };
22 use crate::{
23 refident::MaybeIdent,
24 resolver::{self, Resolver},
25 };
26 use wrapper::WrapByModule;
27
28 pub(crate) use fixture::render as fixture;
29 pub(crate) mod inject;
30
single(mut test: ItemFn, info: RsTestInfo) -> TokenStream31 pub(crate) fn single(mut test: ItemFn, info: RsTestInfo) -> TokenStream {
32 let resolver = resolver::fixtures::get(info.data.fixtures());
33 let args = test.sig.inputs.iter().cloned().collect::<Vec<_>>();
34 let attrs = std::mem::take(&mut test.attrs);
35 let asyncness = test.sig.asyncness;
36 let generic_types = test
37 .sig
38 .generics
39 .type_params()
40 .map(|tp| &tp.ident)
41 .cloned()
42 .collect::<Vec<_>>();
43
44 single_test_case(
45 &test.sig.ident,
46 &test.sig.ident,
47 &args,
48 &attrs,
49 &test.sig.output,
50 asyncness,
51 Some(&test),
52 resolver,
53 &info.attributes,
54 &generic_types,
55 )
56 }
57
parametrize(test: ItemFn, info: RsTestInfo) -> TokenStream58 pub(crate) fn parametrize(test: ItemFn, info: RsTestInfo) -> TokenStream {
59 let RsTestInfo { data, attributes } = info;
60 let resolver_fixtures = resolver::fixtures::get(data.fixtures());
61
62 let rendered_cases = cases_data(&data, test.sig.ident.span())
63 .map(|(name, attrs, resolver)| {
64 TestCaseRender::new(name, attrs, (resolver, &resolver_fixtures))
65 })
66 .map(|case| case.render(&test, &attributes))
67 .collect();
68
69 test_group(test, rendered_cases)
70 }
71
72 impl ValueList {
render( &self, test: &ItemFn, resolver: &dyn Resolver, attrs: &[syn::Attribute], attributes: &RsTestAttributes, ) -> TokenStream73 fn render(
74 &self,
75 test: &ItemFn,
76 resolver: &dyn Resolver,
77 attrs: &[syn::Attribute],
78 attributes: &RsTestAttributes,
79 ) -> TokenStream {
80 let span = test.sig.ident.span();
81 let test_cases = self
82 .argument_data(resolver)
83 .map(|(name, r)| TestCaseRender::new(Ident::new(&name, span), attrs, r))
84 .map(|test_case| test_case.render(test, attributes));
85
86 quote! { #(#test_cases)* }
87 }
88
argument_data<'a>( &'a self, resolver: &'a dyn Resolver, ) -> impl Iterator<Item = (String, Box<(&'a dyn Resolver, (String, Expr))>)> + 'a89 fn argument_data<'a>(
90 &'a self,
91 resolver: &'a dyn Resolver,
92 ) -> impl Iterator<Item = (String, Box<(&'a dyn Resolver, (String, Expr))>)> + 'a {
93 let max_len = self.values.len();
94 self.values.iter().enumerate().map(move |(index, expr)| {
95 let name = format!(
96 "{}_{:0len$}",
97 self.arg,
98 index + 1,
99 len = max_len.display_len()
100 );
101 let resolver_this = (self.arg.to_string(), expr.clone());
102 (name, Box::new((resolver, resolver_this)))
103 })
104 }
105 }
106
_matrix_recursive<'a>( test: &ItemFn, list_values: &'a [&'a ValueList], resolver: &dyn Resolver, attrs: &'a [syn::Attribute], attributes: &RsTestAttributes, ) -> TokenStream107 fn _matrix_recursive<'a>(
108 test: &ItemFn,
109 list_values: &'a [&'a ValueList],
110 resolver: &dyn Resolver,
111 attrs: &'a [syn::Attribute],
112 attributes: &RsTestAttributes,
113 ) -> TokenStream {
114 if list_values.is_empty() {
115 return Default::default();
116 }
117 let vlist = list_values[0];
118 let list_values = &list_values[1..];
119
120 if list_values.is_empty() {
121 vlist.render(test, resolver, attrs, attributes)
122 } else {
123 let span = test.sig.ident.span();
124 let modules = vlist.argument_data(resolver).map(move |(name, resolver)| {
125 _matrix_recursive(test, list_values, &resolver, attrs, attributes)
126 .wrap_by_mod(&Ident::new(&name, span))
127 });
128
129 quote! { #(#modules)* }
130 }
131 }
132
matrix(test: ItemFn, info: RsTestInfo) -> TokenStream133 pub(crate) fn matrix(test: ItemFn, info: RsTestInfo) -> TokenStream {
134 let RsTestInfo {
135 data, attributes, ..
136 } = info;
137 let span = test.sig.ident.span();
138
139 let cases = cases_data(&data, span).collect::<Vec<_>>();
140
141 let resolver = resolver::fixtures::get(data.fixtures());
142 let rendered_cases = if cases.is_empty() {
143 let list_values = data.list_values().collect::<Vec<_>>();
144 _matrix_recursive(&test, &list_values, &resolver, &[], &attributes)
145 } else {
146 cases
147 .into_iter()
148 .map(|(case_name, attrs, case_resolver)| {
149 let list_values = data.list_values().collect::<Vec<_>>();
150 _matrix_recursive(
151 &test,
152 &list_values,
153 &(case_resolver, &resolver),
154 attrs,
155 &attributes,
156 )
157 .wrap_by_mod(&case_name)
158 })
159 .collect()
160 };
161
162 test_group(test, rendered_cases)
163 }
164
resolve_default_test_attr(is_async: bool) -> TokenStream165 fn resolve_default_test_attr(is_async: bool) -> TokenStream {
166 if is_async {
167 quote! { #[async_std::test] }
168 } else {
169 quote! { #[test] }
170 }
171 }
172
render_exec_call(fn_path: Path, args: &[Ident], is_async: bool) -> TokenStream173 fn render_exec_call(fn_path: Path, args: &[Ident], is_async: bool) -> TokenStream {
174 if is_async {
175 quote! {#fn_path(#(#args),*).await}
176 } else {
177 quote! {#fn_path(#(#args),*)}
178 }
179 }
180
181 /// Render a single test case:
182 ///
183 /// * `name` - Test case name
184 /// * `testfn_name` - The name of test function to call
185 /// * `args` - The arguments of the test function
186 /// * `attrs` - The expected test attributes
187 /// * `output` - The expected test return type
188 /// * `asyncness` - The `async` fn token
189 /// * `test_impl` - If you want embed test function (should be the one called by `testfn_name`)
190 /// * `resolver` - The resolver used to resolve injected values
191 /// * `attributes` - Test attributes to select test behaviour
192 /// * `generic_types` - The genrics type used in signature
193 ///
194 // Ok I need some refactoring here but now that not a real issue
195 #[allow(clippy::too_many_arguments)]
single_test_case<'a>( name: &Ident, testfn_name: &Ident, args: &[FnArg], attrs: &[Attribute], output: &ReturnType, asyncness: Option<Async>, test_impl: Option<&ItemFn>, resolver: impl Resolver, attributes: &'a RsTestAttributes, generic_types: &[Ident], ) -> TokenStream196 fn single_test_case<'a>(
197 name: &Ident,
198 testfn_name: &Ident,
199 args: &[FnArg],
200 attrs: &[Attribute],
201 output: &ReturnType,
202 asyncness: Option<Async>,
203 test_impl: Option<&ItemFn>,
204 resolver: impl Resolver,
205 attributes: &'a RsTestAttributes,
206 generic_types: &[Ident],
207 ) -> TokenStream {
208 let (attrs, trace_me): (Vec<_>, Vec<_>) =
209 attrs.iter().cloned().partition(|a| !attr_is(a, "trace"));
210 let mut attributes = attributes.clone();
211 if !trace_me.is_empty() {
212 attributes.add_trace(format_ident!("trace"));
213 }
214 let inject = inject::resolve_aruments(args.iter(), &resolver, generic_types);
215 let args = args
216 .iter()
217 .filter_map(MaybeIdent::maybe_ident)
218 .cloned()
219 .collect::<Vec<_>>();
220 let trace_args = trace_arguments(args.iter(), &attributes);
221
222 let is_async = asyncness.is_some();
223 // If no injected attribut provided use the default one
224 let test_attr = if attrs
225 .iter()
226 .any(|a| attr_ends_with(a, &parse_quote! {test}))
227 {
228 None
229 } else {
230 Some(resolve_default_test_attr(is_async))
231 };
232 let execute = render_exec_call(testfn_name.clone().into(), &args, is_async);
233
234 quote! {
235 #test_attr
236 #(#attrs)*
237 #asyncness fn #name() #output {
238 #test_impl
239 #inject
240 #trace_args
241 println!("{:-^40}", " TEST START ");
242 #execute
243 }
244 }
245 }
246
trace_arguments<'a>( args: impl Iterator<Item = &'a Ident>, attributes: &RsTestAttributes, ) -> Option<TokenStream>247 fn trace_arguments<'a>(
248 args: impl Iterator<Item = &'a Ident>,
249 attributes: &RsTestAttributes,
250 ) -> Option<TokenStream> {
251 let mut statements = args
252 .filter(|&arg| attributes.trace_me(arg))
253 .map(|arg| {
254 let s: Stmt = parse_quote! {
255 println!("{} = {:?}", stringify!(#arg), #arg);
256 };
257 s
258 })
259 .peekable();
260 if statements.peek().is_some() {
261 Some(quote! {
262 println!("{:-^40}", " TEST ARGUMENTS ");
263 #(#statements)*
264 })
265 } else {
266 None
267 }
268 }
269
270 struct TestCaseRender<'a> {
271 name: Ident,
272 attrs: &'a [syn::Attribute],
273 resolver: Box<dyn Resolver + 'a>,
274 }
275
276 impl<'a> TestCaseRender<'a> {
new<R: Resolver + 'a>(name: Ident, attrs: &'a [syn::Attribute], resolver: R) -> Self277 pub fn new<R: Resolver + 'a>(name: Ident, attrs: &'a [syn::Attribute], resolver: R) -> Self {
278 TestCaseRender {
279 name,
280 attrs,
281 resolver: Box::new(resolver),
282 }
283 }
284
render(self, testfn: &ItemFn, attributes: &RsTestAttributes) -> TokenStream285 fn render(self, testfn: &ItemFn, attributes: &RsTestAttributes) -> TokenStream {
286 let args = testfn.sig.inputs.iter().cloned().collect::<Vec<_>>();
287 let mut attrs = testfn.attrs.clone();
288 attrs.extend(self.attrs.iter().cloned());
289 let asyncness = testfn.sig.asyncness;
290 let generic_types = testfn
291 .sig
292 .generics
293 .type_params()
294 .map(|tp| &tp.ident)
295 .cloned()
296 .collect::<Vec<_>>();
297
298 single_test_case(
299 &self.name,
300 &testfn.sig.ident,
301 &args,
302 &attrs,
303 &testfn.sig.output,
304 asyncness,
305 None,
306 self.resolver,
307 attributes,
308 &generic_types,
309 )
310 }
311 }
312
test_group(mut test: ItemFn, rendered_cases: TokenStream) -> TokenStream313 fn test_group(mut test: ItemFn, rendered_cases: TokenStream) -> TokenStream {
314 let fname = &test.sig.ident;
315 test.attrs = vec![];
316
317 quote! {
318 #[cfg(test)]
319 #test
320
321 #[cfg(test)]
322 mod #fname {
323 use super::*;
324
325 #rendered_cases
326 }
327 }
328 }
329
/// Length of a value's `Display` rendering.
trait DisplayLen {
    /// Return the length, in bytes, of `self` formatted with `{}`.
    ///
    /// For ASCII output, such as integers, this equals the number of
    /// characters. It is used to compute zero-padding widths for generated
    /// test names.
    fn display_len(&self) -> usize;
}

impl<D: std::fmt::Display> DisplayLen for D {
    fn display_len(&self) -> usize {
        self.to_string().len()
    }
}
339
format_case_name(case: &TestCase, index: usize, display_len: usize) -> String340 fn format_case_name(case: &TestCase, index: usize, display_len: usize) -> String {
341 let description = case
342 .description
343 .as_ref()
344 .map(|d| format!("_{}", d))
345 .unwrap_or_default();
346 format!(
347 "case_{:0len$}{d}",
348 index,
349 len = display_len,
350 d = description
351 )
352 }
353
cases_data( data: &RsTestData, name_span: Span, ) -> impl Iterator<Item = (Ident, &[syn::Attribute], HashMap<String, &syn::Expr>)>354 fn cases_data(
355 data: &RsTestData,
356 name_span: Span,
357 ) -> impl Iterator<Item = (Ident, &[syn::Attribute], HashMap<String, &syn::Expr>)> {
358 let display_len = data.cases().count().display_len();
359 data.cases().enumerate().map({
360 move |(n, case)| {
361 let resolver_case = data
362 .case_args()
363 .map(|a| a.to_string())
364 .zip(case.args.iter())
365 .collect::<HashMap<_, _>>();
366 (
367 Ident::new(&format_case_name(case, n + 1, display_len), name_span),
368 case.attrs.as_slice(),
369 resolver_case,
370 )
371 }
372 })
373 }
374