1from __future__ import absolute_import
2
3from .Visitor import ScopeTrackingTransform
4from .Nodes import StatListNode, SingleAssignmentNode, CFuncDefNode, DefNode
5from .ExprNodes import DictNode, DictItemNode, NameNode, UnicodeNode
6from .PyrexTypes import py_object_type
7from .StringEncoding import EncodedString
8from . import Symtab
9
class AutoTestDictTransform(ScopeTrackingTransform):
    """Handle the ``autotestdict`` directive.

    When the directive is enabled, collect function/method docstrings of the
    module and append a ``__test__ = {...}`` assignment to the end of the
    module body.  Each dict entry maps ``"<path> (line <n>)"`` to the
    docstring text.
    """

    # Names of cdef-class special methods whose docstrings are never
    # collected (only checked when the enclosing scope is a 'cclass').
    blacklist = ['__cinit__', '__dealloc__', '__richcmp__',
                 '__nonzero__', '__bool__',
                 '__len__', '__contains__']

    def visit_ModuleNode(self, node):
        """Create the module-level ``__test__`` dict if the directive asks for it.

        .pxd modules, modules without the ``autotestdict`` directive, and
        modules that already declare ``__test__`` are returned unchanged.
        Reads the ``autotestdict.all`` and ``autotestdict.cdef`` directives to
        decide which docstrings ``visit_FuncDefNode`` will collect.
        """
        if node.is_pxd:
            return node
        self.scope_type = 'module'
        self.scope_node = node

        if not self.current_directives['autotestdict']:
            return node
        self.all_docstrings = self.current_directives['autotestdict.all']
        self.cdef_docstrings = self.all_docstrings or self.current_directives['autotestdict.cdef']

        assert isinstance(node.body, StatListNode)

        # A user-defined __test__ takes precedence: leave the module untouched.
        if u'__test__' in node.scope.entries:
            return node

        pos = node.pos

        self.tests = []
        self.testspos = node.pos

        test_dict_entry = node.scope.declare_var(EncodedString(u'__test__'),
                                                 py_object_type,
                                                 pos,
                                                 visibility='public')
        # The DictNode holds self.tests by reference, so the items appended by
        # add_test() during visitchildren() below end up in the generated dict.
        create_test_dict_assignment = SingleAssignmentNode(pos,
            lhs=NameNode(pos, name=EncodedString(u'__test__'),
                         entry=test_dict_entry),
            rhs=DictNode(pos, key_value_pairs=self.tests))
        self.visitchildren(node)
        node.body.stats.append(create_test_dict_assignment)
        return node

    def add_test(self, testpos, path, doctest):
        """Append one ``"<path> (line <testpos[1]>)": doctest`` item to the dict.

        All generated nodes are positioned at the module position
        (``self.testspos``), not at ``testpos``.
        """
        pos = self.testspos
        keystr = u'%s (line %d)' % (path, testpos[1])
        key = UnicodeNode(pos, value=EncodedString(keystr))
        value = UnicodeNode(pos, value=doctest)
        self.tests.append(DictItemNode(pos, key=key, value=value))

    def visit_ExprNode(self, node):
        # expressions cannot contain functions and lambda expressions
        # do not have a docstring
        return node

    def visit_FuncDefNode(self, node):
        """Record the docstring of *node* as a doctest if it qualifies.

        Children are not visited, so functions nested inside other functions
        are not collected.
        """
        # DefNodes with fused_py_func set are skipped (presumably fused
        # function specializations -- NOTE(review): confirm).
        if not node.doc or (isinstance(node, DefNode) and node.fused_py_func):
            return node
        if not self.cdef_docstrings:
            # Without autotestdict.cdef/.all, only C functions that have a
            # Python wrapper (py_func) are collected.
            if isinstance(node, CFuncDefNode) and not node.py_func:
                return node
        if not self.all_docstrings and '>>>' not in node.doc:
            return node

        path = self._test_path(node)
        if path is not None:
            self.add_test(node.pos, path, node.doc)
        return node

    def _test_path(self, node):
        """Return the dotted name keying *node*'s doctest, or None to skip it."""
        if self.scope_type == 'module':
            return node.entry.name
        if self.scope_type not in ('pyclass', 'cclass'):
            # Any other scope kind has no Python-visible path to key the
            # doctest by (presumably e.g. methods of a C++ class -- confirm).
            # Skip it instead of crashing; a bare ``assert False`` here would
            # be stripped under -O and leave the path unbound.
            return None

        if isinstance(node, CFuncDefNode):
            if node.py_func is not None:
                name = node.py_func.name
            else:
                name = node.entry.name
        else:
            name = node.name
        if self.scope_type == 'cclass' and name in self.blacklist:
            return None

        if self.scope_type == 'pyclass':
            class_name = self.scope_node.name
        else:
            class_name = self.scope_node.class_name
        if isinstance(node.entry.scope, Symtab.PropertyScope):
            # Property accessors are keyed as Class.property.accessor.
            return "%s.%s.%s" % (class_name, node.entry.scope.name,
                                 node.entry.name)
        return "%s.%s" % (class_name, node.entry.name)
100