1# Copyright (c) 2009-2012 testtools developers. See LICENSE for details.
2
# Names exported by ``from <this module> import *``.
# NOTE(review): ``MatchesPredicate`` and ``MatchesPredicateWithParams`` are
# defined in this module but are not listed here -- confirm that is intended.
__all__ = [
    'AfterPreprocessing',
    'AllMatch',
    'Annotate',
    'AnyMatch',
    'MatchesAny',
    'MatchesAll',
    'Not',
    ]
12
13import types
14
15from ._impl import (
16    Matcher,
17    Mismatch,
18    MismatchDecorator,
19    )
20
21
class MatchesAny:
    """Succeed as soon as one of the wrapped matchers succeeds.

    The matchers are tried in the order they were given; the first one that
    reports no mismatch ends the search.
    """

    def __init__(self, *matchers):
        """Store the candidate matchers.

        :param matchers: The matchers to try, in order.
        """
        self.matchers = matchers

    def __str__(self):
        return "MatchesAny(%s)" % ', '.join(map(str, self.matchers))

    def match(self, matchee):
        """Try each matcher against ``matchee``.

        :param matchee: The value handed to every matcher's ``match()``.
        :return: None if any matcher returned None, otherwise a
            ``MismatchesAll`` holding every mismatch that was returned.
        """
        failures = []
        for candidate in self.matchers:
            failure = candidate.match(matchee)
            if failure is not None:
                failures.append(failure)
                continue
            # One success is enough; the collected failures are discarded.
            return None
        return MismatchesAll(failures)
40
41
class MatchesAll:
    """Succeed only when every wrapped matcher succeeds."""

    def __init__(self, *matchers, **options):
        """Construct a MatchesAll matcher.

        :param matchers: The component matchers, given positionally.
        :param options: Keyword options. Only ``first_only`` is read: pass
            ``first_only=True`` to have just the first mismatch reported.
            By default (``False``) all mismatches are reported.
        """
        self.matchers = matchers
        self.first_only = options.get('first_only', False)

    def match(self, matchee):
        """Run every matcher against ``matchee``.

        :param matchee: The value handed to every matcher's ``match()``.
        :return: None if no matcher returned a mismatch. Otherwise the
            first mismatch itself when ``first_only`` is set, or a
            ``MismatchesAll`` wrapping all of the mismatches.
        """
        failures = []
        for matcher in self.matchers:
            failure = matcher.match(matchee)
            if failure is None:
                continue
            if self.first_only:
                # Stop at the first problem; later matchers are not run.
                return failure
            failures.append(failure)
        return MismatchesAll(failures) if failures else None

    def __str__(self):
        rendered = ', '.join(str(matcher) for matcher in self.matchers)
        return 'MatchesAll(%s)' % rendered
71
72
class MismatchesAll(Mismatch):
    """A mismatch made up of several child mismatches."""

    def __init__(self, mismatches, wrap=True):
        """Store the child mismatches.

        :param mismatches: The child mismatches to report together.
        :param wrap: If true, the description is bracketed by a
            ``Differences: [`` line and a closing ``]`` line.
        """
        self.mismatches = mismatches
        self._wrap = wrap

    def describe(self):
        """Return the child descriptions joined one per line."""
        lines = [child.describe() for child in self.mismatches]
        if self._wrap:
            lines = ["Differences: [", *lines, "]"]
        return '\n'.join(lines)
89
90
class Not:
    """Invert a matcher: succeed exactly when the wrapped matcher fails."""

    def __init__(self, matcher):
        """Store the matcher whose result is to be inverted.

        :param matcher: The matcher to invert.
        """
        self.matcher = matcher

    def __str__(self):
        inner = self.matcher
        return 'Not({})'.format(inner)

    def match(self, other):
        """Match ``other`` against the inverse of the wrapped matcher.

        :param other: The value handed to the wrapped matcher.
        :return: None if the wrapped matcher reported a mismatch, otherwise
            a ``MatchedUnexpectedly`` mismatch.
        """
        if self.matcher.match(other) is not None:
            return None
        return MatchedUnexpectedly(self.matcher, other)
106
107
class MatchedUnexpectedly(Mismatch):
    """Reports that a value matched a matcher it was not meant to match."""

    def __init__(self, matcher, other):
        """Record what matched what.

        :param matcher: The matcher that (unexpectedly) matched.
        :param other: The value that was matched.
        """
        self.other = other
        self.matcher = matcher

    def describe(self):
        """Return ``'<repr of value> matches <matcher>'``."""
        matchee, matcher = self.other, self.matcher
        return "{!r} matches {}".format(matchee, matcher)
117
118
class Annotate:
    """Attach a descriptive string to another matcher.

    When the wrapped matcher reports a mismatch, that mismatch is wrapped in
    an ``AnnotatedMismatch`` carrying the annotation.
    """

    def __init__(self, annotation, matcher):
        """Store the annotation and the matcher it applies to.

        :param annotation: The descriptive string.
        :param matcher: The matcher being annotated.
        """
        self.annotation = annotation
        self.matcher = matcher

    @classmethod
    def if_message(cls, annotation, matcher):
        """Annotate ``matcher`` only if ``annotation`` is non-empty.

        :return: ``matcher`` itself when ``annotation`` is falsy, otherwise
            a new instance of this class wrapping it.
        """
        return cls(annotation, matcher) if annotation else matcher

    def __str__(self):
        note, inner = self.annotation, self.matcher
        return 'Annotate({!r}, {})'.format(note, inner)

    def match(self, other):
        """Match ``other`` with the wrapped matcher.

        :return: None on a match, otherwise the wrapped matcher's mismatch
            decorated with the annotation.
        """
        mismatch = self.matcher.match(other)
        if mismatch is None:
            return None
        return AnnotatedMismatch(self.annotation, mismatch)
143
144
class PostfixedMismatch(MismatchDecorator):
    """A mismatch whose description has an annotation appended.

    ``describe()`` yields ``'<original description>: <annotation>'``.
    """

    def __init__(self, annotation, mismatch):
        """Wrap ``mismatch`` and remember the annotation.

        :param annotation: The text placed after the original description.
        :param mismatch: The mismatch being decorated. It is passed to the
            base-class initialiser and also kept as ``self.mismatch``.
        """
        super().__init__(mismatch)
        self.mismatch = mismatch
        self.annotation = annotation

    def describe(self):
        # ``self.original`` is presumably set by ``MismatchDecorator`` --
        # its definition is not in this module.
        described = self.original.describe()
        return '{}: {}'.format(described, self.annotation)


# ``AnnotatedMismatch`` is a second name for the same class.
AnnotatedMismatch = PostfixedMismatch
158
159
class PrefixedMismatch(MismatchDecorator):
    """A mismatch whose description has a prefix put in front of it.

    ``describe()`` yields ``'<prefix>: <original description>'``.
    """

    def __init__(self, prefix, mismatch):
        """Wrap ``mismatch`` and remember the prefix.

        :param prefix: The text placed before the original description.
        :param mismatch: The mismatch being decorated.
        """
        super().__init__(mismatch)
        self.prefix = prefix

    def describe(self):
        # ``self.original`` is presumably set by ``MismatchDecorator`` --
        # its definition is not in this module.
        prefix = self.prefix
        return '{}: {}'.format(prefix, self.original.describe())
168
169
class AfterPreprocessing:
    """Matches if the value matches after passing through a function.

    This makes it easy to build simple matchers out of plain functions, for
    example::

      def PathHasFileContent(content):
          def _read(path):
              return open(path).read()
          return AfterPreprocessing(_read, Equals(content))
    """

    def __init__(self, preprocessor, matcher, annotate=True):
        """Create an AfterPreprocessing matcher.

        :param preprocessor: A function called with the matchee before
            matching.
        :param matcher: What to match the preprocessed matchee against.
        :param annotate: Whether or not to annotate the matcher with
            something explaining how we transformed the matchee. Defaults
            to True.
        """
        self.preprocessor = preprocessor
        self.matcher = matcher
        self.annotate = annotate

    def _str_preprocessor(self):
        """Return a short textual name for the preprocessor.

        Plain Python functions are shown as ``<function NAME>``; anything
        else is shown using ``str()``.
        """
        func = self.preprocessor
        if not isinstance(func, types.FunctionType):
            return str(func)
        return '<function %s>' % func.__name__

    def __str__(self):
        shown = self._str_preprocessor()
        return "AfterPreprocessing({}, {})".format(shown, self.matcher)

    def match(self, value):
        """Preprocess ``value`` and match the result.

        :param value: The raw matchee; it is passed to the preprocessor.
        :return: Whatever the (optionally annotated) matcher returns for
            the preprocessed value.
        """
        processed = self.preprocessor(value)
        if not self.annotate:
            return self.matcher.match(processed)
        # The annotation records which preprocessor ran on which raw value.
        note = "after {} on {!r}".format(self._str_preprocessor(), value)
        return Annotate(note, self.matcher).match(processed)
214
215
# This is the old, deprecated spelling of the name (note the doubled "c"),
# kept for backwards compatibility.
AfterPreproccessing = AfterPreprocessing
219
220
class AllMatch:
    """Succeed when every one of the given values matches the matcher."""

    def __init__(self, matcher):
        """Store the matcher applied to each value.

        :param matcher: The matcher every value must satisfy.
        """
        self.matcher = matcher

    def __str__(self):
        inner = self.matcher
        return 'AllMatch({})'.format(inner)

    def match(self, values):
        """Apply the matcher to each element of ``values``.

        :param values: An iterable of values to check.
        :return: None if no element produced a truthy mismatch, otherwise a
            ``MismatchesAll`` wrapping the mismatches in iteration order.
        """
        outcomes = (self.matcher.match(value) for value in values)
        # Only truthy results count as mismatches.
        failures = [outcome for outcome in outcomes if outcome]
        if not failures:
            return None
        return MismatchesAll(failures)
238
239
class AnyMatch:
    """Succeed when at least one of the given values matches the matcher."""

    def __init__(self, matcher):
        """Store the matcher applied to each value.

        :param matcher: The matcher at least one value must satisfy.
        """
        self.matcher = matcher

    def __str__(self):
        inner = self.matcher
        return 'AnyMatch({})'.format(inner)

    def match(self, values):
        """Apply the matcher to the elements of ``values`` until one matches.

        :param values: An iterable of values to check.
        :return: None as soon as an element yields a falsy result from the
            matcher; otherwise a ``MismatchesAll`` wrapping every mismatch.
        """
        failures = []
        for value in values:
            failure = self.matcher.match(value)
            if not failure:
                # First matching value wins; remaining values are not tried.
                return None
            failures.append(failure)
        return MismatchesAll(failures)
258
259
class MatchesPredicate(Matcher):
    """Match if a given function returns True.

    It is reasonably common to want to make a very simple matcher based on a
    function that you already have that returns True or False given a single
    argument (i.e. a predicate function).  This matcher makes it very easy to
    do so. e.g.::

      IsEven = MatchesPredicate(lambda x: x % 2 == 0, '%s is not even')
      self.assertThat(4, IsEven)
    """

    def __init__(self, predicate, message):
        """Create a ``MatchesPredicate`` matcher.

        :param predicate: A function that takes a single argument and returns
            a value that will be interpreted as a boolean.
        :param message: A message to describe a mismatch.  It will be formatted
            with '%' and be given whatever was passed to ``match()``. Thus, it
            needs to contain exactly one thing like '%s', '%d' or '%f'.
        """
        self.predicate = predicate
        self.message = message

    def __str__(self):
        return '{}({!r}, {!r})'.format(
            self.__class__.__name__, self.predicate, self.message)

    def match(self, x):
        """Check ``x`` against the predicate.

        :param x: The value to test.
        :return: None if ``predicate(x)`` is truthy, otherwise a ``Mismatch``
            whose description is ``message`` formatted with ``x``.
        """
        if self.predicate(x):
            return None
        try:
            # Wrap the matchee in a 1-tuple so that a tuple matchee is
            # formatted as a single value instead of being unpacked into
            # several format arguments (``'%s' % (1, 2)`` raises TypeError).
            description = self.message % (x,)
        except TypeError:
            # Fall back to the historical behaviour for messages that relied
            # on it, e.g. several specifiers fed from a tuple matchee or
            # ``%(key)s`` specifiers fed from a mapping matchee.
            description = self.message % x
        return Mismatch(description)
291
292
def MatchesPredicateWithParams(predicate, message, name=None):
    """Build a matcher factory from a parameterised predicate.

    Use this when you already have a function that returns True or False
    given the object to check plus some extra arguments, e.g.::

      HasLength = MatchesPredicateWithParams(
          lambda x, y: len(x) == y, 'len({0}) is not {1}')
      # This assertion will fail, as 'len([1, 2]) == 3' is False.
      self.assertThat([1, 2], HasLength(3))

    Unlike ``MatchesPredicate``, this does not return a matcher directly: it
    returns a factory, and calling that factory with the extra parameters
    constructs the actual matcher.

    The predicate receives the object to match as its first parameter. The
    parameters given when constructing a matcher from the factory are passed
    to the predicate as additional parameters when checking for a match.

    :param predicate: The predicate function.
    :param message: A format string for describing mis-matches.
    :param name: Optional replacement name for the matcher.
    :return: A callable that accepts the extra parameters and returns a
        matcher.
    """
    def construct_matcher(*params, **named_params):
        # Fixed arguments first, then whatever this particular matcher was
        # parameterised with.
        matcher = _MatchesPredicateWithParams(
            predicate, message, name, *params, **named_params)
        return matcher

    return construct_matcher
322
323
class _MatchesPredicateWithParams(Matcher):
    """Matcher that checks a value with a predicate plus stored parameters."""

    def __init__(self, predicate, message, name, *args, **kwargs):
        """Create a ``MatchesPredicateWithParams`` matcher.

        :param predicate: A function that takes an object to match and
            additional params as given in ``*args`` and ``**kwargs``. The
            result of the function will be interpreted as a boolean to
            determine a match.
        :param message: A message to describe a mismatch.  It is formatted
            with ``.format()``: the positional arguments are whatever was
            passed to ``match()`` followed by ``*args``, and the keyword
            arguments are ``**kwargs``.

            For instance, to format a single parameter::

                "{0} is not a {1}"

            To format a keyword arg::

                "{0} is not a {type_to_check}"
        :param name: What name to use for the matcher class. Pass None to use
            the default.
        """
        self.predicate = predicate
        self.message = message
        self.name = name
        self.args = args
        self.kwargs = kwargs

    def __str__(self):
        # Positional parameters first, then ``key=value`` pairs.
        rendered = [str(arg) for arg in self.args]
        rendered.extend("%s=%s" % item for item in self.kwargs.items())
        if self.name is not None:
            label = self.name
        else:
            label = 'MatchesPredicateWithParams({!r}, {!r})'.format(
                self.predicate, self.message)
        return '{}({})'.format(label, ", ".join(rendered))

    def match(self, x):
        """Check ``x`` with the predicate and the stored parameters.

        :param x: The value to test; passed as the predicate's first argument.
        :return: None if the predicate result is truthy, otherwise a
            ``Mismatch`` described by the formatted message.
        """
        if self.predicate(x, *self.args, **self.kwargs):
            return None
        return Mismatch(self.message.format(x, *self.args, **self.kwargs))
369