1from __future__ import absolute_import
2
3import os
4import unittest
5import tempfile
6
7from .Compiler import Errors
8from .CodeWriter import CodeWriter
9from .Compiler.TreeFragment import TreeFragment, strip_common_indent
10from .Compiler.Visitor import TreeVisitor, VisitorTransform
11from .Compiler import TreePath
12
13
class NodeTypeWriter(TreeVisitor):
    """Tree visitor that records one indented text line per visited node.

    Every line reads ``<label>: <NodeClassName>``. The label is ``(root)``
    for the node the visit started from. Otherwise it is the attribute name
    under which the node was reached, with ``[index]`` appended when the
    node sits inside a list attribute. Each nesting level adds two spaces of
    indentation. The lines are accumulated in ``self.result``.
    """

    def __init__(self):
        super(NodeTypeWriter, self).__init__()
        self._indents = 0
        self.result = []

    def _current_label(self):
        # Derive the label from the last access-path entry maintained by the
        # base visitor: element [1] is the attribute name, [2] the list index
        # (or None for a plain attribute).
        path = self.access_path
        if not path:
            return u"(root)"
        tip = path[-1]
        attr_name, index = tip[1], tip[2]
        if index is None:
            return attr_name
        return u"%s[%d]" % (attr_name, index)

    def visit_Node(self, node):
        label = self._current_label()
        indent = u"  " * self._indents
        self.result.append(indent + u"%s: %s" % (label, node.__class__.__name__))
        # Children are written one level deeper than their parent.
        self._indents += 1
        self.visitchildren(node)
        self._indents -= 1
35
36
def treetypes(root):
    """Return a string listing the class names of the nodes in *root*'s tree.

    The result starts and ends with a newline. That lets test cases compare
    it with a plain string comparison against a neatly laid out
    triple-quoted literal.
    """
    writer = NodeTypeWriter()
    writer.visit(root)
    # Empty first and last entries produce the leading and trailing newline.
    lines = [u""]
    lines.extend(writer.result)
    lines.append(u"")
    return u"\n".join(lines)
45
46
class CythonTest(unittest.TestCase):
    """Base class for Cython compiler unit tests.

    Disables the compiler's global error listing/echo output for the
    duration of each test. Also provides helpers for building parse trees
    from code strings and for comparing code or trees line by line.
    """

    def setUp(self):
        # Save the global error output targets and silence them so compiler
        # errors triggered by a test are not written out.
        self.listing_file = Errors.listing_file
        self.echo_file = Errors.echo_file
        Errors.listing_file = Errors.echo_file = None

    def tearDown(self):
        # Restore the error output targets saved in setUp().
        Errors.listing_file = self.listing_file
        Errors.echo_file = self.echo_file

    def assertLines(self, expected, result):
        """Check that two strings or lists of strings are equal line by line.

        String arguments are split on newlines first. Mismatching lines are
        reported individually; a differing number of lines is reported
        after that.
        """
        if not isinstance(expected, list):
            expected = expected.split(u"\n")
        if not isinstance(result, list):
            result = result.split(u"\n")
        for idx, (expected_line, result_line) in enumerate(zip(expected, result)):
            self.assertEqual(expected_line, result_line,
                             "Line %d:\nExp: %s\nGot: %s" % (idx, expected_line, result_line))
        # "Got" is the actual result and "Expected" the expectation.
        self.assertEqual(len(expected), len(result),
                         "Unmatched lines. Got:\n%s\nExpected:\n%s" % (u"\n".join(result), u"\n".join(expected)))

    def codeToLines(self, tree):
        """Serialize *tree* with CodeWriter and return its list of output lines."""
        writer = CodeWriter()
        writer.write(tree)
        return writer.result.lines

    def codeToString(self, tree):
        """Serialize *tree* with CodeWriter and return it as one string."""
        return "\n".join(self.codeToLines(tree))

    def assertCode(self, expected, result_tree):
        """Compare the serialized *result_tree* against the code string *expected*.

        Common indentation is stripped from *expected* before the
        line-by-line comparison.
        """
        result_lines = self.codeToLines(result_tree)

        expected_lines = strip_common_indent(expected.split("\n"))

        for idx, (line, expected_line) in enumerate(zip(result_lines, expected_lines)):
            self.assertEqual(expected_line, line,
                             "Line %d:\nGot: %s\nExp: %s" % (idx, line, expected_line))
        self.assertEqual(len(result_lines), len(expected_lines),
                         "Unmatched lines. Got:\n%s\nExpected:\n%s" % ("\n".join(result_lines), expected))

    def assertNodeExists(self, path, result_tree):
        """Fail unless the TreePath expression *path* matches a node in *result_tree*."""
        self.assertIsNotNone(TreePath.find_first(result_tree, path),
                             "Path '%s' not found in result tree" % path)

    def fragment(self, code, pxds=None, pipeline=None):
        """Create a TreeFragment, naming it after the test case for parse errors."""
        if pxds is None:
            pxds = {}
        if pipeline is None:
            pipeline = []
        name = self.id()
        if name.startswith("__main__."):
            name = name[len("__main__."):]
        name = name.replace(".", "_")
        return TreeFragment(code, name, pxds, pipeline=pipeline)

    def treetypes(self, root):
        """Return the node-class outline of *root* (see module-level treetypes())."""
        return treetypes(root)

    def should_fail(self, func, exc_type=Exception):
        """Call *func* and fail the test unless it raises *exc_type*.

        Any exception is accepted by default. Returns the caught exception.
        """
        try:
            func()
        except exc_type as e:
            return e
        # self.fail() must stay outside the try block. It raises
        # AssertionError, which is itself an Exception and would otherwise be
        # swallowed by the handler above when exc_type is Exception.
        self.fail("Expected an exception of type %r" % exc_type)

    def should_not_fail(self, func):
        """Call *func* and return its result, turning any exception into a test failure."""
        try:
            return func()
        except Exception as exc:
            self.fail(str(exc))
127
128
class TransformTest(CythonTest):
    """
    Utility base class for transform unit tests.

    A test builds an input tree, either explicitly or by parsing a Cython
    code string. It then runs the transform(s) under test and, typically,
    serializes the result with a customized Cython serializer (which uses
    special markup for nodes that cannot be expressed in Cython) and
    compares the output line by line as strings.

    To create a test case:
     - Call run_pipeline. The pipeline should contain at least the transform
       being tested. ``pyx`` is either a code string (parsed into a
       post-parse tree) or a node used directly as the pipeline input. An
       optional ``pxds`` dict supplies .pxd sources. The return value is the
       transformed tree.

     - Check that the tree is correct, for example with assertCode. It takes
       the expected code string and a ModuleNode as result_tree, serializes
       the ModuleNode and compares the two line by line.

    Common indentation is stripped from all code strings before they are
    used.
    """

    def run_pipeline(self, pipeline, pyx, pxds=None):
        """Build a tree from *pyx* and feed it through each step of *pipeline* in order."""
        tree = self.fragment(pyx, {} if pxds is None else pxds).root
        for transform in pipeline:
            tree = transform(tree)
        return tree
161
162
class TreeAssertVisitor(VisitorTransform):
    """Report compiler errors for failed tree-path test directives.

    For every CompilerDirectivesNode, the ``test_assert_path_exists`` and
    ``test_fail_if_path_exists`` directives are evaluated against the
    subtree rooted at that node. Each violation is reported through
    Errors.error at the node's position.
    """
    # A plain TreeVisitor would suffice for the inspection itself, but this
    # has to run as a step of the compiler pipeline, hence VisitorTransform.

    def visit_CompilerDirectivesNode(self, node):
        directives = node.directives
        # (directive name, whether finding the path is the error, message)
        checks = (
            ('test_assert_path_exists', False,
             "Expected path '%s' not found in result tree"),
            ('test_fail_if_path_exists', True,
             "Unexpected path '%s' found in result tree"),
        )
        for directive, error_if_found, template in checks:
            if directive not in directives:
                continue
            for path in directives[directive]:
                found = TreePath.find_first(node, path) is not None
                if found == error_if_found:
                    Errors.error(node.pos, template % path)
        self.visitchildren(node)
        return node

    visit_Node = VisitorTransform.recurse_to_children
185
186
def unpack_source_tree(tree_file, dir=None):
    """Split a multi-file "source tree" file into separate files on disk.

    A line starting with ``#####`` starts a new file. Its path is the text
    of that line with surrounding ``#`` characters and whitespace removed,
    relative to *dir* and using ``/`` as separator. All following lines, up
    to the next such marker, are copied into that file. Non-blank lines that
    appear before the first marker, excluding ``#`` comments and bare
    triple-quote lines, are collected as the header.

    File contents are copied byte for byte in binary mode. The result
    therefore does not depend on the locale's default encoding, which could
    otherwise fail on non-ASCII sources, or on newline translation.

    :param tree_file: path of the source tree file to unpack.
    :param dir: target directory; a fresh temporary directory is created
        when it is None.
    :returns: ``(dir, header)``, where *header* is the concatenated header
        lines as a native string (decoded as UTF-8 on Python 3).
    """
    if dir is None:
        dir = tempfile.mkdtemp()
    header = []
    cur_file = None
    with open(tree_file, 'rb') as f:
        lines = f.readlines()
    try:
        for line in lines:
            if line[:5] == b'#####':
                filename = line.strip().strip(b'#').strip().decode('utf8')
                path = os.path.join(dir, filename.replace('/', os.path.sep))
                parent = os.path.dirname(path)
                if not os.path.exists(parent):
                    os.makedirs(parent)
                if cur_file is not None:
                    # Detach before closing so the finally clause never
                    # closes the same file twice.
                    to_close, cur_file = cur_file, None
                    to_close.close()
                cur_file = open(path, 'wb')
            elif cur_file is not None:
                cur_file.write(line)
            elif line.strip() and not line.lstrip().startswith(b'#'):
                if line.strip() not in (b'"""', b"'''"):
                    header.append(line)
    finally:
        if cur_file is not None:
            cur_file.close()
    header_text = b''.join(header)
    # Keep returning the native str type: bytes is str on Python 2, while
    # Python 3 needs an explicit decode.
    if not isinstance(header_text, str):
        header_text = header_text.decode('utf8')
    return dir, header_text
218