# coding: utf-8

"""
Provides test-related code that can be used by all tests.

"""

import os

import pystache
from pystache import defaults
from pystache.tests import examples

# Save a reference to the original function to avoid recursion.
_DEFAULT_TAG_ESCAPE = defaults.TAG_ESCAPE
_TESTS_DIR = os.path.dirname(pystache.tests.__file__)

DATA_DIR = os.path.join(_TESTS_DIR, 'data')  # i.e. 'pystache/tests/data'.
EXAMPLES_DIR = os.path.dirname(examples.__file__)
PACKAGE_DIR = os.path.dirname(pystache.__file__)
PROJECT_DIR = os.path.join(PACKAGE_DIR, '..')
# TEXT_DOCTEST_PATHS: the paths to text files (i.e. non-module files)
# containing doctests.  The paths should be relative to the project directory.
TEXT_DOCTEST_PATHS = ['README.md']

UNITTEST_FILE_PREFIX = "test_"


def get_spec_test_dir(project_dir):
    """Return the path <project_dir>/ext/spec/specs."""
    return os.path.join(project_dir, 'ext', 'spec', 'specs')


def html_escape(u):
    """
    An html escape function that behaves the same in both Python 2 and 3.

    This function is needed because single quotes are escaped in Python 3
    (to '&#x27;'), but not in Python 2.

    The global defaults.TAG_ESCAPE can be set to this function in the
    setUp() and tearDown() of unittest test cases, for example, for
    consistent test results.

    """
    u = _DEFAULT_TAG_ESCAPE(u)
    # Escape any single quotes the default escape left alone.  Where the
    # default escape already converted them (Python 3, per the docstring),
    # there is nothing left for this replace to change.
    return u.replace("'", '&#x27;')


def get_data_path(file_name=None):
    """
    Return the path to a file in the test data directory.

    With no argument (or None), return the path of the data directory itself
    (os.path.join with an empty final component).

    """
    if file_name is None:
        file_name = ""
    return os.path.join(DATA_DIR, file_name)


# Functions related to get_module_names().

def _find_files(root_dir, should_include):
    """
    Return a list of paths to all modules below the given directory.

    A module is any file whose name ends in ".py".  Each path is built by
    joining a directory yielded by os.walk(root_dir) with a file name, so the
    paths are absolute only if root_dir is.

    Arguments:

      root_dir: the directory to search recursively.

      should_include: a function that accepts a file path and returns True or
        False.  It is only called for module (".py") paths.

    """
    paths = []  # Return value.

    for dir_path, _dir_names, file_names in os.walk(root_dir):
        for file_name in file_names:
            path = os.path.join(dir_path, file_name)
            if path.endswith(".py") and should_include(path):
                paths.append(path)

    return paths


def _make_module_names(package_dir, paths):
    """
    Return a list of fully-qualified module names given a list of module paths.

    Each path is expected to lie below package_dir; the leading
    len(abspath(package_dir)) characters of each absolute path are stripped
    to form the package-relative part.  The last component of package_dir
    becomes the top-level package name.

    """
    package_dir = os.path.abspath(package_dir)
    package_name = os.path.split(package_dir)[1]

    prefix_length = len(package_dir)

    module_names = []
    for path in paths:
        path = os.path.abspath(path)  # for example <path_to_package>/subpackage/module.py
        rel_path = path[prefix_length:]  # for example /subpackage/module.py
        rel_path = os.path.splitext(rel_path)[0]  # for example /subpackage/module

        # Peel path components off the end until only the root separator
        # remains (os.path.split then returns an empty tail).
        parts = []
        while True:
            (rel_path, tail) = os.path.split(rel_path)
            if not tail:
                break
            parts.insert(0, tail)
        # We now have, for example, ['subpackage', 'module'].
        parts.insert(0, package_name)
        module = ".".join(parts)
        module_names.append(module)

    return module_names


def get_module_names(package_dir=None, should_include=None):
    """
    Return a sorted list of fully-qualified module names in the given package.

    Arguments:

      package_dir: the package directory to search.  Defaults to PACKAGE_DIR.

      should_include: a function that accepts a module path and returns True
        or False.  Defaults to including every module.

    """
    if package_dir is None:
        package_dir = PACKAGE_DIR

    if should_include is None:
        should_include = lambda path: True

    paths = _find_files(package_dir, should_include)
    names = _make_module_names(package_dir, paths)
    names.sort()

    return names


class AssertStringMixin:

    """A unittest.TestCase mixin to check string equality."""

    def assertString(self, actual, expected, format=None):
        """
        Assert that the given strings are equal and have the same type.

        Arguments:

          format: a format string containing a single conversion specifier %s.
            Defaults to "%s".

        """
        if format is None:
            format = "%s"

        # Show both friendly and literal versions.  The compared values are
        # interpolated exactly once here; the reason is prepended later by
        # plain concatenation, so any "%" characters inside expected/actual
        # are never re-interpreted as conversion specifiers.
        details = ('\n\n'
                   'Expected: """%s"""\n'
                   'Actual:   """%s"""\n\n'
                   'Expected: %s\n'
                   'Actual:   %s') % (expected, actual, repr(expected), repr(actual))

        def make_message(reason):
            description = "String mismatch: " + reason + details
            return format % description

        self.assertEqual(actual, expected, make_message("different characters"))

        reason = "types different: %s != %s (actual)" % (repr(type(expected)), repr(type(actual)))
        self.assertEqual(type(expected), type(actual), make_message(reason))


class AssertIsMixin:

    """A unittest.TestCase mixin adding assertIs()."""

    # unittest.assertIs() is not available until Python 2.7:
    # http://docs.python.org/library/unittest.html#unittest.TestCase.assertIsNone
    def assertIs(self, first, second, msg=None):
        """
        Assert that first is second (identity, not equality).

        msg is optional so that this override stays call-compatible with
        unittest.TestCase.assertIs(first, second, msg=None).

        """
        if msg is None:
            msg = "%s is not %s" % (repr(first), repr(second))
        self.assertTrue(first is second, msg=msg)


class AssertExceptionMixin:

    """A unittest.TestCase mixin adding assertException()."""

    # unittest.assertRaisesRegexp() is not available until Python 2.7:
    # http://docs.python.org/library/unittest.html#unittest.TestCase.assertRaisesRegexp
    def assertException(self, exception_type, msg, callable, *args, **kwds):
        """
        Assert that callable(*args, **kwds) raises exception_type with str() == msg.

        Fails the test (via self.fail) if no exception is raised.

        """
        try:
            callable(*args, **kwds)
        except exception_type as err:
            self.assertEqual(str(err), msg)
        else:
            # Report outside the try block so the failure can never be
            # swallowed by the except clause above (e.g. when exception_type
            # is Exception or one of its bases).
            self.fail("Expected exception: %s: %s" % (exception_type, repr(msg)))


class SetupDefaults(object):

    """
    Mix this class in to a unittest.TestCase for standard defaults.

    This class allows for consistent test results across Python 2/3.

    """

    def setup_defaults(self):
        """Save the current pystache defaults, then force strict ASCII settings."""
        self.original_decode_errors = defaults.DECODE_ERRORS
        self.original_file_encoding = defaults.FILE_ENCODING
        self.original_string_encoding = defaults.STRING_ENCODING

        defaults.DECODE_ERRORS = 'strict'
        defaults.FILE_ENCODING = 'ascii'
        defaults.STRING_ENCODING = 'ascii'

    def teardown_defaults(self):
        """Restore the pystache defaults saved by setup_defaults()."""
        defaults.DECODE_ERRORS = self.original_decode_errors
        defaults.FILE_ENCODING = self.original_file_encoding
        defaults.STRING_ENCODING = self.original_string_encoding


class Attachable(object):
    """
    A class that attaches all constructor named parameters as attributes.

    For example--

    >>> obj = Attachable(foo=42, size="of the universe")
    >>> repr(obj)
    "Attachable(foo=42, size='of the universe')"
    >>> obj.foo
    42
    >>> obj.size
    'of the universe'

    """
    def __init__(self, **kwargs):
        # Keep the original keyword arguments so __repr__ can echo them.
        self.__args__ = kwargs
        for arg, value in kwargs.items():
            setattr(self, arg, value)

    def __repr__(self):
        return "%s(%s)" % (self.__class__.__name__,
                           ", ".join("%s=%s" % (k, repr(v))
                                     for k, v in self.__args__.items()))