1from __future__ import print_function
2from __future__ import absolute_import
3
4# System modules
5import os
6import textwrap
7
8# Third-party modules
9import io
10
11# LLDB modules
12import lldb
13from .lldbtest import *
14from . import configuration
15from . import lldbutil
16from .decorators import *
17
# Source-file extension -> Makefile variable listing files of that language.
# The lookup is case-sensitive and keyed on the extension only.
_SOURCE_TYPE_BY_EXTENSION = {
    '.c': 'C_SOURCES',
    '.cpp': 'CXX_SOURCES',
    '.cxx': 'CXX_SOURCES',
    '.cc': 'CXX_SOURCES',
    '.m': 'OBJC_SOURCES',
    '.mm': 'OBJCXX_SOURCES',
}


def source_type(filename):
    """Return the Makefile source-list variable for *filename*, or None.

    Only the file's extension (as split by os.path.splitext) is examined;
    unknown extensions, and names without one, yield None.
    """
    extension = os.path.splitext(filename)[1]
    return _SOURCE_TYPE_BY_EXTENSION.get(extension)
28
29
class CommandParser:
    """Collect breakpoints and Python commands from ``//%`` source comments.

    A line such as ``foo();  //% self.expect(...)`` records a breakpoint on
    that line with the text after ``//%`` as its command.  A following line
    that contains only whitespace before its ``//%`` marker appends another
    command line to the same breakpoint; if there is no breakpoint being
    collected, such a line starts a new breakpoint on its own line.  Any line
    without exactly one ``//%`` marker ends the current breakpoint.
    """

    def __init__(self):
        # Each entry is a dict with 'file_name', 'line_number', 'command'
        # and, after set_breakpoints() has run, 'breakpoint'.
        self.breakpoints = []

    def parse_one_command(self, line):
        """Extract the inline command from one source line.

        Returns a ``(command, new_breakpoint)`` tuple.  ``command`` is the
        text after the ``//%`` marker with trailing whitespace stripped, or
        None when the line has no marker or more than one.  ``new_breakpoint``
        is False only when the single marker is preceded by nothing but
        whitespace, i.e. the line may continue the previous command.
        """
        parts = line.split('//%')

        command = None
        new_breakpoint = True

        if len(parts) == 2:
            command = parts[1].rstrip()
            new_breakpoint = parts[0].strip() != ""

        return (command, new_breakpoint)

    def parse_source_files(self, source_files):
        """Parse each UTF-8 file in *source_files*, appending to self.breakpoints.

        Continuation lines are joined to their breakpoint's command with
        newlines, and every command is then dedented so it can be passed to
        exec() as a Python block.
        """
        for source_file in source_files:
            # Close the file deterministically instead of leaving the handle
            # open until the garbage collector reclaims it.
            with io.open(source_file, encoding='utf-8') as file_handle:
                lines = file_handle.readlines()
            # Non-None while we are appending continuation lines to a
            # breakpoint that has already been recorded.
            current_breakpoint = None
            for line_number, line in enumerate(lines, 1):  # 1-based
                (command, new_breakpoint) = self.parse_one_command(line)

                if new_breakpoint:
                    current_breakpoint = None

                if command is None:
                    continue

                if current_breakpoint is None:
                    current_breakpoint = {
                        'file_name': source_file,
                        'line_number': line_number,
                        'command': command,
                    }
                    self.breakpoints.append(current_breakpoint)
                else:
                    current_breakpoint['command'] += "\n" + command
        for bkpt in self.breakpoints:
            bkpt['command'] = textwrap.dedent(bkpt['command'])

    def set_breakpoints(self, target):
        """Create a breakpoint in *target* at each parsed file/line location."""
        for bkpt in self.breakpoints:
            bkpt['breakpoint'] = target.BreakpointCreateByLocation(
                bkpt['file_name'], bkpt['line_number'])

    def handle_breakpoint(self, test, breakpoint_id):
        """Run, via *test*, the command of the breakpoint with *breakpoint_id*.

        Does nothing when no parsed breakpoint has that ID.
        """
        for bkpt in self.breakpoints:
            if bkpt['breakpoint'].GetID() == breakpoint_id:
                test.execute_user_command(bkpt['command'])
                return
85
86
class InlineTest(TestBase):
    """Test case driven by ``//%`` commands embedded in the test's sources.

    Concrete subclasses are created by MakeInlineTest, which installs the
    (optionally decorated) ``_test`` method as ``test`` and sets
    ``_build_dict``.
    """

    def getBuildDirBasename(self):
        """Return the per-test build directory name '<ClassName>.<method>'."""
        return self.__class__.__name__ + "." + self.testMethodName

    def BuildMakefile(self):
        """Generate a Makefile for the test's sources in the build directory.

        Files in the source directory are grouped by source_type() into
        C_SOURCES / CXX_SOURCES / OBJC_SOURCES / OBJCXX_SOURCES variables.
        Objective-C sources add Foundation/objc link flags and C++ sources
        add -std=c++11.  An existing Makefile is left untouched.
        """
        makefilePath = self.getBuildArtifact("Makefile")
        if os.path.exists(makefilePath):
            return

        # os.listdir() order is arbitrary; sort it so the generated Makefile
        # is the same on every run and host.
        categories = {}
        for f in sorted(os.listdir(self.getSourceDir())):
            t = source_type(f)
            if t:
                categories.setdefault(t, []).append(f)

        with open(makefilePath, 'w+') as makefile:
            for t in sorted(categories):
                makefile.write(t + " := " + " ".join(categories[t]) + "\n")

            if 'OBJCXX_SOURCES' in categories or \
               'OBJC_SOURCES' in categories:
                makefile.write(
                    "LDFLAGS = $(CFLAGS) -lobjc -framework Foundation\n")

            if 'CXX_SOURCES' in categories:
                makefile.write("CXXFLAGS += -std=c++11\n")

            makefile.write("include Makefile.rules\n")

    def _test(self):
        """Write the Makefile, build the test program, then run it."""
        self.BuildMakefile()
        self.build(dictionary=self._build_dict)
        self.do_test()

    def execute_user_command(self, __command):
        """Execute one breakpoint's ``//%`` command text as Python code.

        The code runs with this module's globals and this method's locals, so
        it can refer to ``self`` (the running test case).  The text comes from
        the test's own source files, i.e. trusted test code.
        """
        exec(__command, globals(), locals())

    def _get_breakpoint_ids(self, thread):
        """Return the sorted, de-duplicated breakpoint IDs *thread* stopped at.

        For a breakpoint stop, the stop-reason data is a sequence of
        (breakpoint ID, location ID) pairs, so only even indices are read.
        Asserts that at least one ID was found.
        """
        ids = set()
        for i in range(0, thread.GetStopReasonDataCount(), 2):
            ids.add(thread.GetStopReasonDataAtIndex(i))
        self.assertGreater(len(ids), 0)
        return sorted(ids)

    def do_test(self):
        """Launch a.out and run the inline commands at every breakpoint hit.

        Fails if the target or process is invalid, or if no breakpoint was
        hit at all.
        """
        exe = self.getBuildArtifact("a.out")
        source_files = [f for f in os.listdir(self.getSourceDir())
                        if source_type(f)]
        target = self.dbg.CreateTarget(exe)
        self.assertTrue(target.IsValid(), VALID_TARGET)

        parser = CommandParser()
        # NOTE(review): these are bare file names, so they are opened relative
        # to the current working directory; this assumes the harness runs the
        # test with cwd set to the source directory.
        parser.parse_source_files(source_files)
        parser.set_breakpoints(target)

        process = target.LaunchSimple(
            None, None, self.get_process_working_directory())
        # LaunchSimple returns an SBProcess object even when the launch fails,
        # so a None check can never fire; test validity instead.
        self.assertTrue(process.IsValid(), PROCESS_IS_VALID)

        hit_breakpoints = 0

        # Fetch the stopped thread once per iteration (the loop previously
        # queried it twice).
        thread = lldbutil.get_stopped_thread(
            process, lldb.eStopReasonBreakpoint)
        while thread:
            hit_breakpoints += 1
            for bp_id in self._get_breakpoint_ids(thread):
                parser.handle_breakpoint(self, bp_id)
            process.Continue()
            thread = lldbutil.get_stopped_thread(
                process, lldb.eStopReasonBreakpoint)

        self.assertTrue(hit_breakpoints > 0,
                        "inline test did not hit a single breakpoint")
        # Either the process exited or the stepping plan is complete.
        self.assertTrue(process.GetState() in [lldb.eStateStopped,
                                               lldb.eStateExited],
                        PROCESS_EXITED)

    def check_expression(self, expression, expected_result, use_summary=True):
        """Evaluate *expression* in the current frame and check its result.

        The value's summary (default) or, with use_summary=False, its plain
        value string must equal *expected_result*.
        """
        value = self.frame().EvaluateExpression(expression)
        self.assertTrue(value.IsValid(),
                        expression + " returned a valid value")
        if self.TraceOn():
            print(value.GetSummary())
            print(value.GetValue())
        if use_summary:
            answer = value.GetSummary()
        else:
            answer = value.GetValue()
        report_str = "%s expected: %s got: %s" % (
            expression, expected_result, answer)
        self.assertTrue(answer == expected_result, report_str)
179
180
def ApplyDecoratorsToFunction(func, decorators):
    """Return *func* wrapped by *decorators*.

    *decorators* may be a list, applied first-to-last so that the last
    element ends up outermost, or a single object with a ``__call__``
    attribute, which is applied once.  For any other value (e.g. None) *func*
    is returned unchanged.
    """
    if isinstance(decorators, list):
        wrapped = func
        for apply_one in decorators:
            wrapped = apply_one(wrapped)
        return wrapped

    if hasattr(decorators, '__call__'):
        return decorators(func)

    return func
189
190
def MakeInlineTest(__file, __globals, decorators=None, name=None,
        build_dict=None):
    """Create an InlineTest subclass for a test file and register it.

    Args:
        __file: path of the test's Python file; a trailing ".pyc" is mapped
            back to ".py".
        __globals: namespace (typically the caller's globals()) into which
            the new class is inserted under *name*.
        decorators: list of decorators or a single decorator applied to the
            generated ``test`` method (see ApplyDecoratorsToFunction).
        name: class name; defaults to the file's basename without extension.
        build_dict: stored as ``_build_dict`` and passed to ``self.build()``.

    Returns:
        The newly created test class.
    """
    # Report results against the .py source, not its compiled .pyc sibling.
    if __file is not None and __file.endswith(".pyc"):
        __file = __file[:-1]

    if name is None:
        name = os.path.splitext(os.path.basename(__file))[0]

    decorated_test = ApplyDecoratorsToFunction(InlineTest._test, decorators)
    class_attrs = {
        'test': decorated_test,
        'name': name,
        '_build_dict': build_dict,
    }
    test_class = type(name, (InlineTest,), class_attrs)

    # Expose the generated class in the caller's module so it gets collected.
    __globals.update({name: test_class})

    # Remember the originating test file so results are reported against it.
    test_class.test_filename = __file
    test_class.mydir = TestBase.compute_mydir(__file)
    return test_class
217