1import collections.abc
2import inspect
3import typing
4import warnings
5from typing import Any
6from typing import Callable
7from typing import Iterable
8from typing import Iterator
9from typing import List
10from typing import Mapping
11from typing import NamedTuple
12from typing import Optional
13from typing import Sequence
14from typing import Set
15from typing import Tuple
16from typing import TypeVar
17from typing import Union
18
19import attr
20
21from .._code import getfslineno
22from ..compat import ascii_escaped
23from ..compat import final
24from ..compat import NOTSET
25from ..compat import NotSetType
26from ..compat import overload
27from ..compat import TYPE_CHECKING
28from _pytest.config import Config
29from _pytest.outcomes import fail
30from _pytest.warning_types import PytestUnknownMarkWarning
31
32if TYPE_CHECKING:
33    from typing import Type
34
35    from ..nodes import Node
36
37
# Name of the ini option read by get_empty_parameterset_mark() to decide what
# happens when parametrize receives an empty list of argvalues.
EMPTY_PARAMETERSET_OPTION = "empty_parameter_set_mark"
39
40
def istestfunc(func) -> bool:
    """Tell whether *func* can carry a mark when passed to a MarkDecorator.

    The object must expose ``__call__`` and must not be a lambda.  Objects
    without a ``__name__`` are treated the same as lambdas and rejected.
    """
    if not hasattr(func, "__call__"):
        return False
    name = getattr(func, "__name__", "<lambda>")
    return name != "<lambda>"
46
47
def get_empty_parameterset_mark(
    config: Config, argnames: Sequence[str], func
) -> "MarkDecorator":
    """Return the mark to apply when ``parametrize`` received no argvalues.

    The behavior is selected by the ``empty_parameter_set_mark`` ini option:

    * ``""``, ``None`` or ``"skip"``: a ``skip`` mark.
    * ``"xfail"``: an ``xfail`` mark with ``run=False``.
    * ``"fail_at_collect"``: raise ``Collector.CollectError`` right away.

    :param config: Config the ini option is read from.
    :param argnames: The parametrize argnames, included in the reason.
    :param func: The parametrized test function.
    :raises Collector.CollectError: For ``"fail_at_collect"``.
    :raises LookupError: For any other option value.
    """
    from ..nodes import Collector

    # Compute the location once; every branch below reuses it.
    fs, lineno = getfslineno(func)
    f_name = func.__name__
    reason = "got empty parameter set %r, function %s at %s:%d" % (
        argnames,
        f_name,
        fs,
        lineno,
    )

    requested_mark = config.getini(EMPTY_PARAMETERSET_OPTION)
    if requested_mark in ("", None, "skip"):
        mark = MARK_GEN.skip(reason=reason)
    elif requested_mark == "xfail":
        mark = MARK_GEN.xfail(reason=reason, run=False)
    elif requested_mark == "fail_at_collect":
        # ``lineno`` is presumably 0-based, hence the ``+ 1`` in this
        # user-facing message. NOTE(review): ``reason`` above reports the
        # raw value without ``+ 1``; confirm which form is intended there.
        raise Collector.CollectError(
            "Empty parameter set in '%s' at line %d" % (f_name, lineno + 1)
        )
    else:
        raise LookupError(requested_mark)
    return mark
75
76
class ParameterSet(
    NamedTuple(
        "ParameterSet",
        [
            ("values", Sequence[Union[object, NotSetType]]),
            ("marks", "typing.Collection[Union[MarkDecorator, Mark]]"),
            ("id", Optional[str]),
        ],
    )
):
    """One set of arguments for a parametrized test.

    ``values`` holds one value per argname, ``marks`` holds the marks that
    apply to this set only, and ``id`` is an optional explicit id.
    """

    @classmethod
    def param(
        cls,
        *values: object,
        marks: "Union[MarkDecorator, typing.Collection[Union[MarkDecorator, Mark]]]" = (),
        id: Optional[str] = None
    ) -> "ParameterSet":
        """Build a ParameterSet from explicit values, marks and id.

        :param values: The parameter values, one per argname.
        :param marks: A single MarkDecorator, or a sequence or set of marks.
        :param id: Optional id for this parameter set. It must be a ``str``
            and is passed through ``ascii_escaped``.
        :raises TypeError: If ``marks`` is not a MarkDecorator, sequence or
            set, or if ``id`` is not a string.
        """
        if isinstance(marks, MarkDecorator):
            marks = (marks,)
        # TODO(py36): Change to collections.abc.Collection.
        elif not isinstance(marks, (collections.abc.Sequence, collections.abc.Set)):
            # Use an explicit raise rather than ``assert``, because asserts are
            # stripped under ``python -O``. Check ``collections.abc.Set`` rather
            # than ``set`` so that a frozenset of marks is accepted as well.
            raise TypeError(
                "Expected marks to be a MarkDecorator or a collection of marks, "
                "got {}: {!r}".format(type(marks), marks)
            )

        if id is not None:
            if not isinstance(id, str):
                raise TypeError(
                    "Expected id to be a string, got {}: {!r}".format(type(id), id)
                )
            id = ascii_escaped(id)
        return cls(values, marks, id)

    @classmethod
    def extract_from(
        cls,
        parameterset: Union["ParameterSet", Sequence[object], object],
        force_tuple: bool = False,
    ) -> "ParameterSet":
        """Extract from an object or objects.

        :param parameterset:
            A legacy style parameterset that may or may not be a tuple,
            and may or may not be wrapped into a mess of mark objects.

        :param force_tuple:
            Enforce tuple wrapping so single argument tuple values
            don't get decomposed and break tests.
        """

        if isinstance(parameterset, cls):
            return parameterset
        if force_tuple:
            return cls.param(parameterset)
        else:
            # TODO: Refactor to fix this type-ignore. Currently the following
            # type-checks but crashes:
            #
            #   @pytest.mark.parametrize(('x', 'y'), [1, 2])
            #   def test_foo(x, y): pass
            return cls(parameterset, marks=[], id=None)  # type: ignore[arg-type]

    @staticmethod
    def _parse_parametrize_args(
        argnames: Union[str, List[str], Tuple[str, ...]],
        argvalues: Iterable[Union["ParameterSet", Sequence[object], object]],
        *args,
        **kwargs
    ) -> Tuple[Union[List[str], Tuple[str, ...]], bool]:
        """Normalize ``argnames`` and decide whether values must be tuple-wrapped.

        A string is split on commas into stripped, non-empty names. A list
        or tuple is returned unchanged. ``force_tuple`` is True only when a
        string yields exactly one name. In that case each argvalue is taken
        whole as that single value instead of being unpacked.
        """
        if not isinstance(argnames, (tuple, list)):
            argnames = [x.strip() for x in argnames.split(",") if x.strip()]
            force_tuple = len(argnames) == 1
        else:
            force_tuple = False
        return argnames, force_tuple

    @staticmethod
    def _parse_parametrize_parameters(
        argvalues: Iterable[Union["ParameterSet", Sequence[object], object]],
        force_tuple: bool,
    ) -> List["ParameterSet"]:
        """Convert every raw argvalue into a ParameterSet via ``extract_from``."""
        return [
            ParameterSet.extract_from(x, force_tuple=force_tuple) for x in argvalues
        ]

    @classmethod
    def _for_parametrize(
        cls,
        argnames: Union[str, List[str], Tuple[str, ...]],
        argvalues: Iterable[Union["ParameterSet", Sequence[object], object]],
        func,
        config: Config,
        nodeid: str,
    ) -> Tuple[Union[List[str], Tuple[str, ...]], List["ParameterSet"]]:
        """Normalize parametrize arguments into argnames and ParameterSets.

        Fails (via ``fail``) if any parameter set has a different number of
        values than there are argnames. If no argvalues were given, a single
        ParameterSet of ``NOTSET`` values is returned instead. It carries the
        mark chosen by ``get_empty_parameterset_mark``.
        """
        argnames, force_tuple = cls._parse_parametrize_args(argnames, argvalues)
        parameters = cls._parse_parametrize_parameters(argvalues, force_tuple)
        # argvalues may be a one-shot iterable that is already consumed.
        # Drop the name so it cannot be accidentally reused below.
        del argvalues

        if parameters:
            # Check all parameter sets have the correct number of values.
            for param in parameters:
                if len(param.values) != len(argnames):
                    msg = (
                        '{nodeid}: in "parametrize" the number of names ({names_len}):\n'
                        "  {names}\n"
                        "must be equal to the number of values ({values_len}):\n"
                        "  {values}"
                    )
                    fail(
                        msg.format(
                            nodeid=nodeid,
                            values=param.values,
                            names=argnames,
                            names_len=len(argnames),
                            values_len=len(param.values),
                        ),
                        pytrace=False,
                    )
        else:
            # Empty parameter set (likely computed at runtime): create a single
            # parameter set with NOTSET values, with the "empty parameter set" mark applied to it.
            mark = get_empty_parameterset_mark(config, argnames, func)
            parameters.append(
                cls(values=(NOTSET,) * len(argnames), marks=[mark], id=None)
            )
        return argnames, parameters
200        return argnames, parameters
201
202
@final
@attr.s(frozen=True)
class Mark:
    """An immutable record of a single mark: its name, args and kwargs."""

    #: Name of the mark.
    name = attr.ib(type=str)
    #: Positional arguments of the mark decorator.
    args = attr.ib(type=Tuple[Any, ...])
    #: Keyword arguments of the mark decorator.
    kwargs = attr.ib(type=Mapping[str, Any])

    #: Source Mark for ids with parametrize Marks.
    _param_ids_from = attr.ib(type=Optional["Mark"], default=None, repr=False)
    #: Resolved/generated ids with parametrize Marks.
    _param_ids_generated = attr.ib(
        type=Optional[Sequence[str]], default=None, repr=False
    )

    def _has_param_ids(self) -> bool:
        # ``ids`` may be given by keyword, or positionally as the fourth
        # argument (anything past index 3 implies it was supplied).
        if "ids" in self.kwargs:
            return True
        return len(self.args) >= 4

    def combined_with(self, other: "Mark") -> "Mark":
        """Merge this Mark with another Mark of the same name.

        Positional args are concatenated (this mark's first). Keyword args
        are merged, with ``other``'s values winning on conflicts.

        :param Mark other: The mark to combine with.
        :rtype: Mark
        """
        assert self.name == other.name

        # For parametrize marks, remember which mark supplied the ids.
        # ``other`` takes precedence over ``self``.
        ids_source = None  # type: Optional[Mark]
        if self.name == "parametrize":
            for candidate in (other, self):
                if candidate._has_param_ids():
                    ids_source = candidate
                    break

        joined_args = self.args + other.args
        joined_kwargs = dict(self.kwargs, **other.kwargs)
        return Mark(
            self.name,
            joined_args,
            joined_kwargs,
            param_ids_from=ids_source,
        )
248
249
# A generic parameter designating an object to which a Mark may
# be applied -- a test function (callable) or class.
# Used by the MarkDecorator.__call__ overloads, so that decorating an object
# returns a value of that object's own type.
# Note: a lambda is not allowed (istestfunc rejects it at runtime), but this
# can't be represented in the type system.
_Markable = TypeVar("_Markable", bound=Union[Callable[..., object], type])
254
255
@attr.s
class MarkDecorator:
    """A decorator for applying a mark on test functions and classes.

    MarkDecorators are created with ``pytest.mark``::

        mark1 = pytest.mark.NAME              # Simple MarkDecorator
        mark2 = pytest.mark.NAME(name1=value) # Parametrized MarkDecorator

    and can then be applied as decorators to test functions::

        @mark2
        def test_function():
            pass

    Calling a MarkDecorator behaves as follows:

    1. With a single class as the only positional argument and no keyword
       arguments, the mark is attached to the class. It is then applied
       automatically to all test cases found in that class.

    2. With a single function as the only positional argument and no
       keyword arguments, the mark, including all arguments accumulated
       so far, is attached to the function.

    3. In every other case, a new MarkDecorator is returned whose mark
       combines the stored arguments with the ones passed to this call.

    Consequently, a plain call can never store a lone function or class as
    the only positional argument. Use `with_args()` for that.
    """

    mark = attr.ib(type=Mark, validator=attr.validators.instance_of(Mark))

    @property
    def name(self) -> str:
        """Alias for mark.name."""
        return self.mark.name

    @property
    def args(self) -> Tuple[Any, ...]:
        """Alias for mark.args."""
        return self.mark.args

    @property
    def kwargs(self) -> Mapping[str, Any]:
        """Alias for mark.kwargs."""
        return self.mark.kwargs

    @property
    def markname(self) -> str:
        # Kept for backward compatibility: version 2.4.1 exposed this attribute.
        return self.name

    def __repr__(self) -> str:
        # NOTE(review): ``@attr.s`` defaults to repr=True and may install its
        # own __repr__ over this one; confirm which representation is in effect.
        return "<MarkDecorator {!r}>".format(self.mark)

    def with_args(self, *args: object, **kwargs: object) -> "MarkDecorator":
        """Return a new MarkDecorator whose mark has the extra arguments merged in.

        Unlike calling the MarkDecorator, with_args() can be used even
        if the sole argument is a callable/class.

        :rtype: MarkDecorator
        """
        extra = Mark(self.name, args, kwargs)
        return type(self)(self.mark.combined_with(extra))

    # Type ignored because the overloads overlap with an incompatible
    # return type. Not much we can do about that. Thankfully mypy picks
    # the first match so it works out even if we break the rules.
    @overload
    def __call__(self, arg: _Markable) -> _Markable:  # type: ignore[misc]
        pass

    @overload  # noqa: F811
    def __call__(  # noqa: F811
        self, *args: object, **kwargs: object
    ) -> "MarkDecorator":
        pass

    def __call__(self, *args: object, **kwargs: object):  # noqa: F811
        """Call the MarkDecorator."""
        # Decorator form: exactly one positional argument that is a
        # non-lambda callable or a class, and no keyword arguments.
        if len(args) == 1 and not kwargs:
            target = args[0]
            if istestfunc(target) or inspect.isclass(target):
                store_mark(target, self.mark)
                return target
        # Anything else accumulates arguments into a new decorator.
        return self.with_args(*args, **kwargs)
349
350
def get_unpacked_marks(obj) -> List[Mark]:
    """Return the Mark objects stored in ``obj.pytestmark``.

    A missing attribute yields no marks. A single non-list value is treated
    as a one-element list. MarkDecorators are unwrapped via
    ``normalize_mark_list``.
    """
    stored = getattr(obj, "pytestmark", [])
    candidates = stored if isinstance(stored, list) else [stored]
    return normalize_mark_list(candidates)
357
358
def normalize_mark_list(mark_list: Iterable[Union[Mark, MarkDecorator]]) -> List[Mark]:
    """Normalize marker decorating helpers to mark objects.

    Items with a ``.mark`` attribute (MarkDecorators) are unwrapped to that
    attribute. Other items are taken as-is and must already be Marks.

    :param mark_list: Marks and/or MarkDecorators to normalize.
    :returns: The corresponding Mark objects, in their original order.
    :raises TypeError: If an item does not resolve to a Mark.
    """
    marks = []  # type: List[Mark]
    for item in mark_list:
        mark = getattr(item, "mark", item)  # unpack MarkDecorator
        if not isinstance(mark, Mark):
            raise TypeError("got {!r} instead of Mark".format(mark))
        marks.append(mark)
    return marks
372
373
def store_mark(obj, mark: Mark) -> None:
    """Append *mark* to the marks stored on *obj*.

    This is what MarkDecorator uses when applied as a decorator.
    """
    assert isinstance(mark, Mark), mark
    # Build a brand-new list and rebind ``pytestmark`` instead of mutating
    # the value read from ``obj``. That value may be borrowed from
    # elsewhere (e.g. found through attribute lookup rather than owned by obj).
    obj.pytestmark = [*get_unpacked_marks(obj), mark]
383
384
# Typing for builtin pytest marks. This is cheating; it gives builtin marks
# special privilege, and breaks modularity. But practicality beats purity...
# The classes below exist only for the type checker (this block is guarded by
# TYPE_CHECKING). Each narrows MarkDecorator.__call__ to the call signature of
# one builtin mark. MarkGenerator declares its builtin mark attributes with
# these types under the same guard.
if TYPE_CHECKING:
    from _pytest.fixtures import _Scope

    # skip: applied bare as a decorator, or called with an optional reason.
    class _SkipMarkDecorator(MarkDecorator):
        @overload  # type: ignore[override,misc]
        def __call__(self, arg: _Markable) -> _Markable:
            ...

        @overload  # noqa: F811
        def __call__(self, reason: str = ...) -> "MarkDecorator":  # noqa: F811
            ...

    # skipif: one or more conditions (str or bool) plus a keyword-only reason.
    class _SkipifMarkDecorator(MarkDecorator):
        def __call__(  # type: ignore[override]
            self,
            condition: Union[str, bool] = ...,
            *conditions: Union[str, bool],
            reason: str = ...
        ) -> MarkDecorator:
            ...

    # xfail: applied bare as a decorator, or called with conditions and the
    # keyword-only reason/run/raises/strict options.
    class _XfailMarkDecorator(MarkDecorator):
        @overload  # type: ignore[override,misc]
        def __call__(self, arg: _Markable) -> _Markable:
            ...

        @overload  # noqa: F811
        def __call__(  # noqa: F811
            self,
            condition: Union[str, bool] = ...,
            *conditions: Union[str, bool],
            reason: str = ...,
            run: bool = ...,
            raises: Union[
                "Type[BaseException]", Tuple["Type[BaseException]", ...]
            ] = ...,
            strict: bool = ...
        ) -> MarkDecorator:
            ...

    # parametrize: argnames/argvalues positionally. Everything else is
    # keyword-only.
    class _ParametrizeMarkDecorator(MarkDecorator):
        def __call__(  # type: ignore[override]
            self,
            argnames: Union[str, List[str], Tuple[str, ...]],
            argvalues: Iterable[Union[ParameterSet, Sequence[object], object]],
            *,
            indirect: Union[bool, Sequence[str]] = ...,
            ids: Optional[
                Union[
                    Iterable[Union[None, str, float, int, bool]],
                    Callable[[Any], Optional[object]],
                ]
            ] = ...,
            scope: Optional[_Scope] = ...
        ) -> MarkDecorator:
            ...

    # usefixtures: any number of fixture names.
    class _UsefixturesMarkDecorator(MarkDecorator):
        def __call__(  # type: ignore[override]
            self, *fixtures: str
        ) -> MarkDecorator:
            ...

    # filterwarnings: any number of filter strings.
    class _FilterwarningsMarkDecorator(MarkDecorator):
        def __call__(  # type: ignore[override]
            self, *filters: str
        ) -> MarkDecorator:
            ...
455
456
@final
class MarkGenerator:
    """Factory for :class:`MarkDecorator` objects - exposed as
    a ``pytest.mark`` singleton instance.

    Example::

         import pytest

         @pytest.mark.slowtest
         def test_function():
            pass

    applies a 'slowtest' :class:`Mark` on ``test_function``.
    """

    _config = None  # type: Optional[Config]
    _markers = set()  # type: Set[str]
    # The config whose "markers" ini value ``_markers`` was filled from.
    # Tracked so that the cache is rebuilt whenever ``_config`` changes.
    _markers_config = None  # type: Optional[Config]

    # See TYPE_CHECKING above.
    if TYPE_CHECKING:
        # TODO(py36): Change to builtin annotation syntax.
        skip = _SkipMarkDecorator(Mark("skip", (), {}))
        skipif = _SkipifMarkDecorator(Mark("skipif", (), {}))
        xfail = _XfailMarkDecorator(Mark("xfail", (), {}))
        parametrize = _ParametrizeMarkDecorator(Mark("parametrize", (), {}))
        usefixtures = _UsefixturesMarkDecorator(Mark("usefixtures", (), {}))
        filterwarnings = _FilterwarningsMarkDecorator(Mark("filterwarnings", (), {}))

    def __getattr__(self, name: str) -> MarkDecorator:
        """Return a new MarkDecorator for the mark called *name*.

        When a config is attached, *name* is checked against the registered
        ``markers`` ini lines. An unknown name fails under ``strict_markers``,
        fails for common misspellings of "parametrize", and otherwise emits
        a ``PytestUnknownMarkWarning``.

        :raises AttributeError: If *name* starts with an underscore.
        """
        if name[0] == "_":
            raise AttributeError("Marker name must NOT start with underscore")

        config = self._config
        if config is not None:
            if self._markers_config is not config:
                # The cache was built for another config (or none yet). Start a
                # fresh per-instance set so names registered elsewhere don't
                # count as known here, and never mutate the class-level default.
                self._markers = set()
                self._markers_config = config

            # We store a set of markers as a performance optimisation - if a mark
            # name is in the set we definitely know it, but a mark may be known and
            # not in the set.  We therefore start by updating the set!
            if name not in self._markers:
                for line in config.getini("markers"):
                    # example lines: "skipif(condition): skip the given test if..."
                    # or "hypothesis: tests which use Hypothesis", so to get the
                    # marker name we split on both `:` and `(`.
                    marker = line.split(":")[0].split("(")[0].strip()
                    self._markers.add(marker)

            # If the name is not in the set of known marks after updating,
            # then it really is time to issue a warning or an error.
            if name not in self._markers:
                if config.option.strict_markers:
                    fail(
                        "{!r} not found in `markers` configuration option".format(name),
                        pytrace=False,
                    )

                # Raise a specific error for common misspellings of "parametrize".
                if name in ["parameterize", "parametrise", "parameterise"]:
                    __tracebackhide__ = True
                    fail("Unknown '{}' mark, did you mean 'parametrize'?".format(name))

                warnings.warn(
                    "Unknown pytest.mark.%s - is this a typo?  You can register "
                    "custom marks to avoid this warning - for details, see "
                    "https://docs.pytest.org/en/stable/mark.html" % name,
                    PytestUnknownMarkWarning,
                    2,
                )

        return MarkDecorator(Mark(name, (), {}))
525
526
# Shared MarkGenerator instance. Within this module it builds the skip/xfail
# marks for empty parameter sets (see get_empty_parameterset_mark).
MARK_GEN = MarkGenerator()
528
529
# TODO(py36): inherit from typing.MutableMapping[str, Any].
@final
class NodeKeywords(collections.abc.MutableMapping):  # type: ignore[type-arg]
    """Keyword mapping of a node that falls back to its parent's keywords.

    Lookups try this node's own entries first and then delegate to
    ``node.parent.keywords``. Keys may be set but never deleted.
    """

    def __init__(self, node: "Node") -> None:
        self.node = node
        self.parent = node.parent
        # A node's own name is always one of its keywords.
        self._markers = {node.name: True}

    def __getitem__(self, key: str) -> Any:
        try:
            return self._markers[key]
        except KeyError:
            parent = self.parent
            if parent is None:
                raise
            # Not set on this node: defer to the parent chain.
            return parent.keywords[key]

    def __setitem__(self, key: str, value: Any) -> None:
        self._markers[key] = value

    def __delitem__(self, key: str) -> None:
        raise ValueError("cannot delete key in keywords dict")

    def __iter__(self) -> Iterator[str]:
        return iter(self._seen())

    def _seen(self) -> Set[str]:
        # Union of this node's keys with every key visible on the parent.
        keys = set(self._markers)
        parent = self.parent
        if parent is not None:
            keys.update(parent.keywords)
        return keys

    def __len__(self) -> int:
        return len(self._seen())

    def __repr__(self) -> str:
        return "<NodeKeywords for node {}>".format(self.node)
567