1# Copyright (c) 2009-2012 testtools developers. See LICENSE for details.
2
"""Matchers that operate with knowledge of Python data structures."""

# The docstring above must be the first statement in the module to become
# ``__doc__``; a string placed after ``__all__`` is just a discarded
# expression.
__all__ = [
    'ContainsAll',
    'MatchesListwise',
    'MatchesSetwise',
    'MatchesStructure',
    ]
11
12from ..helpers import map_values
13from ._higherorder import (
14    Annotate,
15    MatchesAll,
16    MismatchesAll,
17    )
18from ._impl import Mismatch
19
20
def ContainsAll(items):
    """Make a matcher that checks every one of ``items`` is in the matchee.

    In effect this asserts that ``items`` is a subset of the matchee. One
    `Contains` matcher is built per item, and they are combined with
    `MatchesAll` using ``first_only=False``, so every missing item is
    reported rather than only the first.

    :param items: An iterable of things that must each be contained in the
        matchee.
    :return: A `MatchesAll` matcher.
    """
    # Deferred import (presumably to avoid an import cycle with ._basic).
    from ._basic import Contains
    per_item_matchers = [Contains(item) for item in items]
    return MatchesAll(*per_item_matchers, first_only=False)
30
31
class MatchesListwise(object):
    """Matches if each matcher matches the corresponding value.

    More easily explained by example than in words:

    >>> from ._basic import Equals
    >>> MatchesListwise([Equals(1)]).match([1])
    >>> MatchesListwise([Equals(1), Equals(2)]).match([1, 2])
    >>> print (MatchesListwise([Equals(1), Equals(2)]).match([2, 1]).describe())
    Differences: [
    1 != 2
    2 != 1
    ]
    >>> matcher = MatchesListwise([Equals(1), Equals(2)], first_only=True)
    >>> print (matcher.match([3, 4]).describe())
    1 != 3

    When the number of values differs from the number of matchers, a
    "Length mismatch" is reported and only the overlapping prefix of the two
    sequences is compared element by element.
    """

    def __init__(self, matchers, first_only=False):
        """Construct a MatchesListwise matcher.

        :param matchers: A list of matcher that the matched values must match.
        :param first_only: If True, return only the first element mismatch
            found (a length mismatch is then reported only when no element
            mismatches); otherwise report all of them. Defaults to False.
        """
        self.matchers = matchers
        self.first_only = first_only

    def __str__(self):
        # Give a readable description; other matchers (e.g. MatchesSetwise
        # and MatchesStructure) embed str(matcher) in their own output.
        description = '%s([%s]' % (
            self.__class__.__name__, ', '.join(map(str, self.matchers)))
        if self.first_only:
            description += ', first_only=True'
        return description + ')'

    def match(self, values):
        """Match ``values`` against the matchers, position by position.

        :param values: A sized iterable of values (``len()`` is taken).
        :return: None if the lengths agree and every matcher matches its
            value. Otherwise, if ``first_only`` is set and some element
            mismatches, that first element mismatch; else a `MismatchesAll`
            of the length mismatch (if any) followed by every element
            mismatch.
        """
        # Deferred import (presumably to avoid an import cycle with ._basic).
        from ._basic import Equals
        mismatches = []
        length_mismatch = Annotate(
            "Length mismatch", Equals(len(self.matchers))).match(len(values))
        if length_mismatch is not None:
            mismatches.append(length_mismatch)
        # zip stops at the shorter sequence; surplus entries are covered by
        # the length mismatch above.
        for matcher, value in zip(self.matchers, values):
            mismatch = matcher.match(value)
            if mismatch is not None:
                if self.first_only:
                    return mismatch
                mismatches.append(mismatch)
        if mismatches:
            return MismatchesAll(mismatches)
        return None
75
76
class MatchesStructure(object):
    """Matcher that compares attributes of an object against matchers.

    Each keyword given to the constructor names an attribute of the object
    being matched; the attribute's value must satisfy the matcher supplied
    for that keyword.

    Alternate constructors:

    * `fromExample` builds a matcher from the current attribute values of a
      prototype object; `update` then derives modified copies.
    * `byEquality` works like the constructor, but each keyword value is an
      expected value compared with `Equals`.
    * `byMatcher` is like `byEquality`, except you choose the matcher used
      to wrap each value instead of `Equals`.
    """

    def __init__(self, **kwargs):
        """Construct a `MatchesStructure`.

        :param kwargs: Maps attribute names to the matcher that attribute of
            the matchee must satisfy.
        """
        self.kws = kwargs

    @classmethod
    def byEquality(cls, **kwargs):
        """Build a matcher requiring each named attribute to equal a value.

        Shorthand for `byMatcher` with `Equals` as the matcher.

        :param kwargs: Maps attribute names to their expected values.
        """
        from ._basic import Equals
        return cls.byMatcher(Equals, **kwargs)

    @classmethod
    def byMatcher(cls, matcher, **kwargs):
        """Build a matcher by wrapping every keyword value with ``matcher``.

        :param matcher: A callable (typically a matcher class) applied to
            each keyword value to produce that attribute's matcher.
        :param kwargs: Maps attribute names to the argument for ``matcher``.
        """
        return cls(**map_values(matcher, kwargs))

    @classmethod
    def fromExample(cls, example, *attributes):
        """Build a matcher from the attribute values of a prototype object.

        :param example: The object whose attributes are read.
        :param attributes: Names of the attributes to compare; each becomes
            an `Equals` matcher on the value ``example`` currently has.
        """
        from ._basic import Equals
        return cls(**dict(
            (name, Equals(getattr(example, name))) for name in attributes))

    def update(self, **kws):
        """Return a new matcher derived from this one.

        Each keyword adds or replaces the matcher for that attribute; a
        keyword whose value is None drops the attribute instead. This
        matcher is left unchanged.
        """
        updated = dict(self.kws)
        for name, replacement in kws.items():
            if replacement is not None:
                updated[name] = replacement
            else:
                updated.pop(name, None)
        return type(self)(**updated)

    def __str__(self):
        settings = ', '.join(
            "%s=%s" % (name, matcher)
            for name, matcher in sorted(self.kws.items()))
        return "%s(%s)" % (self.__class__.__name__, settings)

    def match(self, value):
        # Compare attributes in name order, labelling each mismatch with the
        # attribute it came from.
        ordered = sorted(self.kws.items())
        annotated = [Annotate(name, matcher) for name, matcher in ordered]
        observed = [getattr(value, name) for name, _ in ordered]
        return MatchesListwise(annotated).match(observed)
150
151
class MatchesSetwise(object):
    """Matches if all the matchers match elements of the value being matched.

    That is, each element in the 'observed' set must match exactly one matcher
    from the set of matchers, with no matchers left over.

    The difference compared to `MatchesListwise` is that the order of the
    matchings does not matter.

    Matchers are paired with values by a maximum one-to-one matching, so the
    match succeeds whenever *some* valid assignment exists, regardless of the
    order in which matchers or values are given. The same matcher object may
    be passed more than once; each occurrence must be satisfied by a
    distinct value. Matchers need not be hashable.
    """

    def __init__(self, *matchers):
        self.matchers = matchers

    def match(self, observed):
        """Match ``observed`` against the matchers, ignoring order.

        :param observed: An iterable of values; it is consumed once.
        :return: None if every value pairs with a distinct matcher that
            accepts it and no matcher is left over; otherwise a mismatch
            describing the leftover matchers and/or values.
        """
        values = list(observed)
        matchers = list(self.matchers)
        # candidates[i] holds the indices of the matchers that accept
        # values[i]; every (value, matcher) pair is evaluated once.
        candidates = [
            [index for index, matcher in enumerate(matchers)
             if matcher.match(value) is None]
            for value in values]
        owner = self._pair_up(candidates, len(matchers))
        # Leftovers keep the order they were given in, so the mismatch
        # description is deterministic.
        remaining_matchers = [
            matcher for index, matcher in enumerate(matchers)
            if owner[index] is None]
        paired_values = set(
            value_index for value_index in owner if value_index is not None)
        not_matched = [
            value for index, value in enumerate(values)
            if index not in paired_values]
        if not_matched or remaining_matchers:
            return self._describe_leftovers(remaining_matchers, not_matched)
        return None

    @staticmethod
    def _pair_up(candidates, matcher_count):
        """Compute a maximum one-to-one pairing of values to matchers.

        :param candidates: ``candidates[i]`` lists the indices of the
            matchers that accept value ``i``.
        :param matcher_count: The total number of matchers.
        :return: A list ``owner`` of length ``matcher_count`` where
            ``owner[j]`` is the index of the value paired with matcher ``j``,
            or None if matcher ``j`` is unpaired.
        """
        owner = [None] * matcher_count

        def augment(value_index, visited):
            # Try to pair value_index with some matcher, re-pairing an
            # already-paired value along an alternating path if needed
            # (augmenting-path bipartite matching). Recursion depth is
            # bounded by the number of matchers.
            for matcher_index in candidates[value_index]:
                if matcher_index in visited:
                    continue
                visited.add(matcher_index)
                current = owner[matcher_index]
                if current is None or augment(current, visited):
                    owner[matcher_index] = value_index
                    return True
            return False

        for value_index in range(len(candidates)):
            augment(value_index, set())
        return owner

    @staticmethod
    def _describe_leftovers(remaining_matchers, not_matched):
        """Build the mismatch for unpaired matchers and/or values.

        Because the pairing is maximal, no leftover value is accepted by any
        leftover matcher.
        """
        # There are various cases that all should be reported somewhat
        # differently.

        # There are two trivial cases:
        # 1) There are just some matchers left over.
        # 2) There are just some values left over.

        # Then there are three more interesting cases:
        # 3) There are the same number of matchers and values left over.
        # 4) There are more matchers left over than values.
        # 5) There are more values left over than matchers.

        if len(not_matched) == 0:
            if len(remaining_matchers) > 1:
                msg = "There were %s matchers left over: " % (
                    len(remaining_matchers),)
            else:
                msg = "There was 1 matcher left over: "
            msg += ', '.join(map(str, remaining_matchers))
            return Mismatch(msg)
        elif len(remaining_matchers) == 0:
            if len(not_matched) > 1:
                return Mismatch(
                    "There were %s values left over: %s" % (
                        len(not_matched), not_matched))
            else:
                return Mismatch(
                    "There was 1 value left over: %s" % (
                        not_matched, ))
        else:
            common_length = min(len(remaining_matchers), len(not_matched))
            if common_length == 0:
                raise AssertionError("common_length can't be 0 here")
            if common_length > 1:
                msg = "There were %s mismatches" % (common_length,)
            else:
                msg = "There was 1 mismatch"
            if len(remaining_matchers) > len(not_matched):
                extra_matchers = remaining_matchers[common_length:]
                msg += " and %s extra matcher" % (len(extra_matchers), )
                if len(extra_matchers) > 1:
                    msg += "s"
                msg += ': ' + ', '.join(map(str, extra_matchers))
            elif len(not_matched) > len(remaining_matchers):
                extra_values = not_matched[common_length:]
                msg += " and %s extra value" % (len(extra_values), )
                if len(extra_values) > 1:
                    msg += "s"
                msg += ': ' + str(extra_values)
            # Pair up the first common_length leftovers so each reports
            # its own mismatch description.
            return Annotate(
                msg, MatchesListwise(remaining_matchers[:common_length])
                ).match(not_matched[:common_length])
229