1# coding: utf-8
2
3"""
4Provides test-related code that can be used by all tests.
5
6"""
7
8import os
9
10import pystache
11from pystache import defaults
12from pystache.tests import examples
13
# Keep a reference to pystache's original escape function.  Tests may rebind
# defaults.TAG_ESCAPE to html_escape() (defined below in this module), which
# must call this saved original rather than defaults.TAG_ESCAPE to avoid
# infinite recursion.
_DEFAULT_TAG_ESCAPE = defaults.TAG_ESCAPE

# Directory locations used throughout the test suite.
PACKAGE_DIR = os.path.dirname(pystache.__file__)
PROJECT_DIR = os.path.join(PACKAGE_DIR, os.pardir)
_TESTS_DIR = os.path.dirname(pystache.tests.__file__)
DATA_DIR = os.path.join(_TESTS_DIR, 'data')  # i.e. 'pystache/tests/data'.
EXAMPLES_DIR = os.path.dirname(examples.__file__)

# Text files (i.e. non-module files) that contain doctests.  Each path is
# given relative to the project directory.
TEXT_DOCTEST_PATHS = ['README.md']

# File-name prefix for unittest modules.
UNITTEST_FILE_PREFIX = "test_"
27
28
def get_spec_test_dir(project_dir):
    """
    Return the directory holding the Mustache spec test files.

    The result is project_dir/ext/spec/specs, built with os.path.join().

    """
    spec_subpath = ('ext', 'spec', 'specs')
    return os.path.join(project_dir, *spec_subpath)
31
32
def html_escape(u):
    """
    An html escape function that behaves the same in both Python 2 and 3.

    This function is needed because single quotes are escaped in Python 3
    (to '&#x27;'), but not in Python 2.

    The global defaults.TAG_ESCAPE can be set to this function in the
    setUp() and tearDown() of unittest test cases, for example, for
    consistent test results.

    Arguments:

      u: the string to escape.

    Returns the string escaped by the original defaults.TAG_ESCAPE, with any
    remaining single quotes replaced by the '&#x27;' character reference.

    """
    # Call the saved original rather than defaults.TAG_ESCAPE, which may have
    # been rebound to this very function (infinite recursion otherwise).
    u = _DEFAULT_TAG_ESCAPE(u)
    return u.replace("'", '&#x27;')
47
48
def get_data_path(file_name=None):
    """
    Return the path to file_name inside the test data directory (DATA_DIR).

    When file_name is omitted or None, an empty final component is joined
    instead, so the result is DATA_DIR ending in a path separator.

    """
    relative_name = "" if file_name is None else file_name
    return os.path.join(DATA_DIR, relative_name)
54
55
56# Functions related to get_module_names().
57
def _find_files(root_dir, should_include):
    """
    Return the paths of all Python modules (".py" files) beneath root_dir.

    The directory tree is traversed recursively with os.walk(); each path is
    the walked directory joined with the file name.

    Arguments:

      root_dir: the directory to search.

      should_include: a function that accepts a file path and returns True or
        False.  It is consulted only for paths ending in ".py".

    """
    module_paths = []

    for current_dir, _subdirs, entries in os.walk(root_dir):
        for entry in entries:
            candidate = os.path.join(current_dir, entry)
            if not candidate.endswith(".py"):
                continue
            if should_include(candidate):
                module_paths.append(candidate)

    return module_paths
80
81
def _make_module_names(package_dir, paths):
    """
    Return a list of fully-qualified module names given a list of module paths.

    Each path is made absolute and loses the (absolute) package directory
    prefix and its file extension.  The remaining path components are then
    joined with dots under the package's own name.  For example, with
    package_dir "/src/pkg", the path "/src/pkg/sub/mod.py" becomes
    "pkg.sub.mod".

    The paths are assumed to lie inside package_dir: the prefix is removed by
    length, not checked.

    """
    root = os.path.abspath(package_dir)
    root_name = os.path.basename(root)
    prefix_len = len(root)

    def components(rel_path):
        # Peel components off the end with os.path.split() until nothing is
        # left, then restore their original left-to-right order.
        collected = []
        head, tail = os.path.split(rel_path)
        while tail:
            collected.append(tail)
            head, tail = os.path.split(head)
        collected.reverse()
        return collected

    names = []
    for path in paths:
        relative = os.path.abspath(path)[prefix_len:]  # e.g. "/sub/mod.py"
        stem = os.path.splitext(relative)[0]           # e.g. "/sub/mod"
        names.append(".".join([root_name] + components(stem)))
    return names
110
111
def get_module_names(package_dir=None, should_include=None):
    """
    Return the sorted, fully-qualified names of the modules in a package.

    Arguments:

      package_dir: the package directory to scan.  Defaults to PACKAGE_DIR.

      should_include: a function that accepts a module file path and returns
        True or False.  Defaults to accepting every module.

    """
    def include_everything(path):
        return True

    root = PACKAGE_DIR if package_dir is None else package_dir
    keep = include_everything if should_include is None else should_include

    module_paths = _find_files(root, keep)
    return sorted(_make_module_names(root, module_paths))
128
129
class AssertStringMixin:

    """A unittest.TestCase mixin to check string equality."""

    def assertString(self, actual, expected, format=None):
        """
        Assert that the given strings are equal and have the same type.

        Arguments:

          actual: the string produced by the code under test.

          expected: the string the test expects.

          format: a format string containing a single conversion specifier %s.
            Defaults to "%s".  The failure description is substituted into it.

        """
        if format is None:
            format = "%s"

        # Show both friendly and literal versions.  All values, including the
        # failure reason, are interpolated in a single pass.  Interpolating
        # expected/actual first and the reason afterwards would re-parse any
        # '%' inside the compared strings as a conversion specifier, and then
        # crash (even for equal strings, since the message is built eagerly).
        details = """String mismatch: %s

        Expected: \"""%s\"""
        Actual:   \"""%s\"""

        Expected: %s
        Actual:   %s"""

        def make_message(reason):
            description = details % (reason, expected, actual,
                                     repr(expected), repr(actual))
            return format % description

        self.assertEqual(actual, expected, make_message("different characters"))

        reason = "types different: %s != %s (actual)" % (repr(type(expected)), repr(type(actual)))
        self.assertEqual(type(expected), type(actual), make_message(reason))
164
165
class AssertIsMixin:

    """A unittest.TestCase mixin providing an identity assertion, assertIs()."""

    # Older unittest versions (before Python 2.7) lack assertIs(); see
    #   http://docs.python.org/library/unittest.html#unittest.TestCase.assertIsNone
    def assertIs(self, first, second):
        """Fail (via self.assertTrue) unless first and second are the same object."""
        same_object = first is second
        failure_message = "%r is not %r" % (first, second)
        self.assertTrue(same_object, msg=failure_message)
174
175
class AssertExceptionMixin:

    """A unittest.TestCase mixin adding assertException()."""

    # unittest.assertRaisesRegexp() is not available until Python 2.7:
    #   http://docs.python.org/library/unittest.html#unittest.TestCase.assertRaisesRegexp
    def assertException(self, exception_type, msg, callable, *args, **kwds):
        """
        Assert that callable(*args, **kwds) raises exception_type with str() == msg.

        Arguments:

          exception_type: the exception class (or tuple of classes) expected.

          msg: the expected str() of the raised exception.

          callable: the function to call; args and kwds are passed to it.

        Raises Exception if the call returns without raising.  An exception
        that does not match exception_type propagates unchanged.

        """
        try:
            callable(*args, **kwds)
        except exception_type as err:
            self.assertEqual(str(err), msg)
        else:
            # Raised from the else clause so the except clause above cannot
            # swallow it, even when exception_type is Exception itself.
            raise Exception("Expected exception: %s: %s" % (exception_type, repr(msg)))
188
189
class SetupDefaults(object):

    """
    Mix this class in to a unittest.TestCase for standard defaults.

    Call setup_defaults() to pin pystache's decoding/encoding defaults to
    strict ASCII for consistent results across Python 2/3, and call
    teardown_defaults() to restore the values that were in effect before.

    """

    def setup_defaults(self):
        """Save the current defaults, then switch to strict ASCII handling."""
        saved = (defaults.DECODE_ERRORS,
                 defaults.FILE_ENCODING,
                 defaults.STRING_ENCODING)
        (self.original_decode_errors,
         self.original_file_encoding,
         self.original_string_encoding) = saved

        defaults.DECODE_ERRORS = 'strict'
        defaults.FILE_ENCODING = defaults.STRING_ENCODING = 'ascii'

    def teardown_defaults(self):
        """Restore the defaults saved by setup_defaults()."""
        (defaults.DECODE_ERRORS,
         defaults.FILE_ENCODING,
         defaults.STRING_ENCODING) = (self.original_decode_errors,
                                      self.original_file_encoding,
                                      self.original_string_encoding)
212
213
class Attachable(object):
    """
    A class that attaches all constructor named parameters as attributes.

    For example--

    >>> obj = Attachable(foo=42, size="of the universe")
    >>> repr(obj)
    "Attachable(foo=42, size='of the universe')"
    >>> obj.foo
    42
    >>> obj.size
    'of the universe'

    """
    def __init__(self, **kwargs):
        # Keep the original keyword mapping so __repr__ can replay it.
        self.__args__ = kwargs
        for name in kwargs:
            setattr(self, name, kwargs[name])

    def __repr__(self):
        pairs = ["%s=%r" % (key, val) for key, val in self.__args__.items()]
        return "%s(%s)" % (self.__class__.__name__, ", ".join(pairs))
238