1from __future__ import absolute_import 2 3import os 4import unittest 5import tempfile 6 7from .Compiler import Errors 8from .CodeWriter import CodeWriter 9from .Compiler.TreeFragment import TreeFragment, strip_common_indent 10from .Compiler.Visitor import TreeVisitor, VisitorTransform 11from .Compiler import TreePath 12 13 14class NodeTypeWriter(TreeVisitor): 15 def __init__(self): 16 super(NodeTypeWriter, self).__init__() 17 self._indents = 0 18 self.result = [] 19 20 def visit_Node(self, node): 21 if not self.access_path: 22 name = u"(root)" 23 else: 24 tip = self.access_path[-1] 25 if tip[2] is not None: 26 name = u"%s[%d]" % tip[1:3] 27 else: 28 name = tip[1] 29 30 self.result.append(u" " * self._indents + 31 u"%s: %s" % (name, node.__class__.__name__)) 32 self._indents += 1 33 self.visitchildren(node) 34 self._indents -= 1 35 36 37def treetypes(root): 38 """Returns a string representing the tree by class names. 39 There's a leading and trailing whitespace so that it can be 40 compared by simple string comparison while still making test 41 cases look ok.""" 42 w = NodeTypeWriter() 43 w.visit(root) 44 return u"\n".join([u""] + w.result + [u""]) 45 46 47class CythonTest(unittest.TestCase): 48 49 def setUp(self): 50 self.listing_file = Errors.listing_file 51 self.echo_file = Errors.echo_file 52 Errors.listing_file = Errors.echo_file = None 53 54 def tearDown(self): 55 Errors.listing_file = self.listing_file 56 Errors.echo_file = self.echo_file 57 58 def assertLines(self, expected, result): 59 "Checks that the given strings or lists of strings are equal line by line" 60 if not isinstance(expected, list): 61 expected = expected.split(u"\n") 62 if not isinstance(result, list): 63 result = result.split(u"\n") 64 for idx, (expected_line, result_line) in enumerate(zip(expected, result)): 65 self.assertEqual(expected_line, result_line, 66 "Line %d:\nExp: %s\nGot: %s" % (idx, expected_line, result_line)) 67 self.assertEqual(len(expected), len(result), 68 "Unmatched lines. Got:\n%s\nExpected:\n%s" % ("\n".join(expected), u"\n".join(result))) 69 70 def codeToLines(self, tree): 71 writer = CodeWriter() 72 writer.write(tree) 73 return writer.result.lines 74 75 def codeToString(self, tree): 76 return "\n".join(self.codeToLines(tree)) 77 78 def assertCode(self, expected, result_tree): 79 result_lines = self.codeToLines(result_tree) 80 81 expected_lines = strip_common_indent(expected.split("\n")) 82 83 for idx, (line, expected_line) in enumerate(zip(result_lines, expected_lines)): 84 self.assertEqual(expected_line, line, 85 "Line %d:\nGot: %s\nExp: %s" % (idx, line, expected_line)) 86 self.assertEqual(len(result_lines), len(expected_lines), 87 "Unmatched lines. Got:\n%s\nExpected:\n%s" % ("\n".join(result_lines), expected)) 88 89 def assertNodeExists(self, path, result_tree): 90 self.assertNotEqual(TreePath.find_first(result_tree, path), None, 91 "Path '%s' not found in result tree" % path) 92 93 def fragment(self, code, pxds=None, pipeline=None): 94 "Simply create a tree fragment using the name of the test-case in parse errors." 95 if pxds is None: 96 pxds = {} 97 if pipeline is None: 98 pipeline = [] 99 name = self.id() 100 if name.startswith("__main__."): 101 name = name[len("__main__."):] 102 name = name.replace(".", "_") 103 return TreeFragment(code, name, pxds, pipeline=pipeline) 104 105 def treetypes(self, root): 106 return treetypes(root) 107 108 def should_fail(self, func, exc_type=Exception): 109 """Calls "func" and fails if it doesn't raise the right exception 110 (any exception by default). Also returns the exception in question. 111 """ 112 try: 113 func() 114 self.fail("Expected an exception of type %r" % exc_type) 115 except exc_type as e: 116 self.assertTrue(isinstance(e, exc_type)) 117 return e 118 119 def should_not_fail(self, func): 120 """Calls func and succeeds if and only if no exception is raised 121 (i.e. converts exception raising into a failed testcase). Returns 122 the return value of func.""" 123 try: 124 return func() 125 except Exception as exc: 126 self.fail(str(exc)) 127 128 129class TransformTest(CythonTest): 130 """ 131 Utility base class for transform unit tests. It is based around constructing 132 test trees (either explicitly or by parsing a Cython code string); running 133 the transform, serialize it using a customized Cython serializer (with 134 special markup for nodes that cannot be represented in Cython), 135 and do a string-comparison line-by-line of the result. 136 137 To create a test case: 138 - Call run_pipeline. The pipeline should at least contain the transform you 139 are testing; pyx should be either a string (passed to the parser to 140 create a post-parse tree) or a node representing input to pipeline. 141 The result will be a transformed result. 142 143 - Check that the tree is correct. If wanted, assertCode can be used, which 144 takes a code string as expected, and a ModuleNode in result_tree 145 (it serializes the ModuleNode to a string and compares line-by-line). 146 147 All code strings are first stripped for whitespace lines and then common 148 indentation. 149 150 Plans: One could have a pxd dictionary parameter to run_pipeline. 151 """ 152 153 def run_pipeline(self, pipeline, pyx, pxds=None): 154 if pxds is None: 155 pxds = {} 156 tree = self.fragment(pyx, pxds).root 157 # Run pipeline 158 for T in pipeline: 159 tree = T(tree) 160 return tree 161 162 163class TreeAssertVisitor(VisitorTransform): 164 # actually, a TreeVisitor would be enough, but this needs to run 165 # as part of the compiler pipeline 166 167 def visit_CompilerDirectivesNode(self, node): 168 directives = node.directives 169 if 'test_assert_path_exists' in directives: 170 for path in directives['test_assert_path_exists']: 171 if TreePath.find_first(node, path) is None: 172 Errors.error( 173 node.pos, 174 "Expected path '%s' not found in result tree" % path) 175 if 'test_fail_if_path_exists' in directives: 176 for path in directives['test_fail_if_path_exists']: 177 if TreePath.find_first(node, path) is not None: 178 Errors.error( 179 node.pos, 180 "Unexpected path '%s' found in result tree" % path) 181 self.visitchildren(node) 182 return node 183 184 visit_Node = VisitorTransform.recurse_to_children 185 186 187def unpack_source_tree(tree_file, dir=None): 188 if dir is None: 189 dir = tempfile.mkdtemp() 190 header = [] 191 cur_file = None 192 f = open(tree_file) 193 try: 194 lines = f.readlines() 195 finally: 196 f.close() 197 del f 198 try: 199 for line in lines: 200 if line[:5] == '#####': 201 filename = line.strip().strip('#').strip().replace('/', os.path.sep) 202 path = os.path.join(dir, filename) 203 if not os.path.exists(os.path.dirname(path)): 204 os.makedirs(os.path.dirname(path)) 205 if cur_file is not None: 206 f, cur_file = cur_file, None 207 f.close() 208 cur_file = open(path, 'w') 209 elif cur_file is not None: 210 cur_file.write(line) 211 elif line.strip() and not line.lstrip().startswith('#'): 212 if line.strip() not in ('"""', "'''"): 213 header.append(line) 214 finally: 215 if cur_file is not None: 216 cur_file.close() 217 return dir, ''.join(header) 218