1#!/usr/bin/env python3
2
3"""Updates FileCheck checks in MIR tests.
4
5This script is a utility to update MIR based tests with new FileCheck
6patterns.
7
8The checks added by this script will cover the entire body of each
9function it handles. Virtual registers used are given names via
10FileCheck patterns, so if you do want to check a subset of the body it
11should be straightforward to trim out the irrelevant parts. None of
12the YAML metadata will be checked, other than function names.
13
14If there are multiple llc commands in a test, the full set of checks
15will be repeated for each different check pattern. Checks for patterns
16that are common between different commands will be left as-is by
17default, or removed if the --remove-common-prefixes flag is provided.
18"""
19
20from __future__ import print_function
21
22import argparse
23import collections
24import glob
25import os
26import re
27import subprocess
28import sys
29
30from UpdateTestChecks import common
31
# Matches the `name:` key of a MIR function and captures the function name.
MIR_FUNC_NAME_RE = re.compile(r' *name: *(?P<func>[A-Za-z0-9_.-]+)')
# Matches the `body: |` key that introduces a function's body.
MIR_BODY_BEGIN_RE = re.compile(r' *body: *\|')
# Matches a basic block label line such as `  bb.0.entry:`.
MIR_BASIC_BLOCK_RE = re.compile(r' *bb\.[0-9]+.*:$')
# Matches a virtual register (`%N`, captured as group 1) together with an
# optional `:suffix` and an optional parenthesised suffix, e.g. `%0:gpr(s32)`.
VREG_RE = re.compile(r'(%[0-9]+)(?::[a-z0-9_]+)?(?:\([<>a-z0-9 ]+\))?')
# Instruction flags that may sit between the `=` and the opcode.
MI_FLAGS_STR = (
    r'(frame-setup |frame-destroy |nnan |ninf |nsz |arcp |contract |afn '
    r'|reassoc |nuw |nsw |exact |fpexcept )*')
# Matches an instruction that defines one or more virtual registers and
# captures the comma separated definitions (`vregs`) and the `opcode`.
VREG_DEF_RE = re.compile(
    r'^ *(?P<vregs>{0}(?:, {0})*) = '
    r'{1}(?P<opcode>[A-Zt][A-Za-z0-9_]+)'.format(VREG_RE.pattern, MI_FLAGS_STR))
# Matches the lines that may precede the first instruction of a body:
# comments, a basic block label, a lowercase `key:` line, or a blank line.
# The dot after `bb` is escaped so only real `bb.N` labels qualify, in line
# with MIR_BASIC_BLOCK_RE.
MIR_PREFIX_DATA_RE = re.compile(r'^ *(;|bb\.[0-9].*: *$|[a-z]+:( |$)|$)')

# Matches the `define` line of an LLVM IR function and captures its name.
IR_FUNC_NAME_RE = re.compile(
    r'^\s*define\s+(?:internal\s+)?[^@]*@(?P<func>[A-Za-z0-9_.]+)\s*\(')
# Matches comment or blank lines that may precede an IR function's body.
IR_PREFIX_DATA_RE = re.compile(r'^ *(;|$)')

# Matches one whole MIR function document (from `---` to `...`) in tool
# output and captures the function name and the text of its body.
MIR_FUNC_RE = re.compile(
    r'^---$'
    r'\n'
    r'^ *name: *(?P<func>[A-Za-z0-9_.-]+)$'
    r'.*?'
    r'^ *body: *\|\n'
    r'(?P<body>.*?)\n'
    r'^\.\.\.$',
    flags=(re.M | re.S))
57
58
class LLC:
    """Callable wrapper that runs an llc binary through the shell."""

    def __init__(self, bin):
        # Command used to start llc; it is interpolated into a shell string.
        self.bin = bin

    def __call__(self, args, ir):
        """Run llc with `args`, feeding the file named `ir` on stdin.

        Files ending in '.mir' get '-x mir' appended to the arguments.
        Returns the captured stdout with CRLF line endings turned into LF.
        """
        if ir.endswith('.mir'):
            args = '{} -x mir'.format(args)
        command = '{} {}'.format(self.bin, args)
        with open(ir) as ir_file:
            output = subprocess.check_output(command, shell=True,
                                             stdin=ir_file)
        # check_output yields bytes on Python 3; the rest of the script
        # works on text.
        if sys.version_info[0] > 2:
            output = output.decode()
        # Normalise line endings to unix LF style.
        return output.replace('\r\n', '\n')
74
75
class Run:
    """The FileCheck prefixes, llc arguments and triple of one RUN line."""

    def __init__(self, prefixes, cmd_args, triple):
        self.prefixes, self.cmd_args, self.triple = prefixes, cmd_args, triple

    def __getitem__(self, index):
        # Positional access lets a Run be unpacked like a 3-item sequence:
        #   prefixes, cmd_args, triple = run
        fields = [self.prefixes, self.cmd_args, self.triple]
        return fields[index]
84
85
def log(msg, verbose=True):
    """Print `msg` to stderr; do nothing when `verbose` is false."""
    if not verbose:
        return
    print(msg, file=sys.stderr)
89
90
def find_triple_in_ir(lines, verbose=False):
    """Return group 1 of the first line matching common.TRIPLE_IR_RE.

    Returns None when no line matches. `verbose` is accepted but unused.
    """
    matches = (common.TRIPLE_IR_RE.match(line) for line in lines)
    return next((found.group(1) for found in matches if found), None)
97
98
def build_run_list(test, run_lines, verbose=False):
    """Turn the RUN lines of `test` into a list of Run objects.

    Only lines of the form `llc ... | FileCheck ...` are kept; every other
    line is skipped with a warning that names the test file.

    Returns a tuple `(run_list, common_prefixes)`. `common_prefixes` is the
    set of check prefixes that occur more than once across the accepted RUN
    lines; those prefixes are removed from every Run's `prefixes`.
    `verbose` is accepted but unused.
    """
    run_list = []
    prefix_counts = collections.Counter()
    for line in run_lines:
        if '|' not in line:
            common.warn('Skipping unparseable RUN line: ' + line,
                        test_file=test)
            continue

        # The guard above guarantees that the split yields exactly two parts.
        llc_cmd, filecheck_cmd = [cmd.strip() for cmd in line.split('|', 1)]
        common.verify_filecheck_prefixes(filecheck_cmd)

        if not llc_cmd.startswith('llc '):
            common.warn('Skipping non-llc RUN line: {}'.format(line),
                        test_file=test)
            continue
        if not filecheck_cmd.startswith('FileCheck '):
            common.warn('Skipping non-FileChecked RUN line: {}'.format(line),
                        test_file=test)
            continue

        triple = None
        m = common.TRIPLE_ARG_RE.search(llc_cmd)
        if m:
            triple = m.group(1)
        # If we find -march but not -mtriple, use that.
        m = common.MARCH_ARG_RE.search(llc_cmd)
        if m and not triple:
            triple = '{}--'.format(m.group(1))

        # Drop the leading 'llc' and the input file placeholders; the input
        # is supplied separately when the command is run.
        cmd_args = llc_cmd[len('llc'):].strip()
        cmd_args = cmd_args.replace('< %s', '').replace('%s', '').strip()

        check_prefixes = [
            item
            for m in common.CHECK_PREFIX_RE.finditer(filecheck_cmd)
            for item in m.group(1).split(',')]
        if not check_prefixes:
            check_prefixes = ['CHECK']
        prefix_counts.update(check_prefixes)

        run_list.append(Run(check_prefixes, cmd_args, triple))

    # Remove any common prefixes. We'll just leave those entirely alone.
    common_prefixes = {prefix for prefix, count in prefix_counts.items()
                       if count > 1}
    for run in run_list:
        run.prefixes = [p for p in run.prefixes if p not in common_prefixes]

    return run_list, common_prefixes
149
150
def find_functions_with_one_bb(lines, verbose=False):
    """Return the names of MIR functions that have exactly one basic block.

    A function starts at a line matching MIR_FUNC_NAME_RE, and every line
    matching MIR_BASIC_BLOCK_RE up to the next such line counts as one of
    its blocks. `verbose` is accepted but unused.
    """
    single_bb_funcs = []
    current = None
    block_count = 0
    for text in lines:
        name_match = MIR_FUNC_NAME_RE.match(text)
        if name_match:
            # A new function begins: settle the one seen before it.
            if block_count == 1:
                single_bb_funcs.append(current)
            current, block_count = name_match.group('func'), 0
        if MIR_BASIC_BLOCK_RE.match(text):
            block_count += 1
    # Settle whatever was being counted when the input ended.
    if block_count == 1:
        single_bb_funcs.append(current)
    return single_bb_funcs
168
169
def build_function_body_dictionary(test, raw_tool_output, triple, prefixes,
                                   func_dict, verbose):
    """Record the body of every MIR function found in `raw_tool_output`.

    Each body is stored as `func_dict[prefix][func]` for every prefix in
    `prefixes`. A warning is emitted when a different body was already
    stored for that prefix and function; the new body replaces it.
    `triple` is accepted but unused.
    """
    for match in MIR_FUNC_RE.finditer(raw_tool_output):
        func, body = match.group('func', 'body')
        if verbose:
            log('Processing function: {}'.format(func))
            for body_line in body.splitlines():
                log('  {}'.format(body_line))
        for prefix in prefixes:
            known = func_dict[prefix]
            if func in known and known[func] != body:
                common.warn('Found conflicting asm for prefix: {}'.format(prefix),
                            test_file=test)
            known[func] = body
184
185
def add_checks_for_function(test, output_lines, run_list, func_dict, func_name,
                            single_bb, verbose=False):
    """Append the check lines for `func_name` to `output_lines`.

    For each run, the first prefix that has not been printed yet and has a
    non-empty body recorded for `func_name` is used. A prefix with no entry
    for the function is treated like one with an empty body and skipped,
    rather than raising KeyError.

    Returns `output_lines`, which is modified in place.
    """
    printed_prefixes = set()
    for run in run_list:
        for prefix in run.prefixes:
            if prefix in printed_prefixes:
                continue
            body = func_dict.get(prefix, {}).get(func_name)
            if not body:
                continue
            printed_prefixes.add(prefix)
            log('Adding {} lines for {}'.format(prefix, func_name), verbose)
            add_check_lines(test, output_lines, prefix, func_name, single_bb,
                            body.splitlines())
            # Only one prefix per run gets checks.
            break
    return output_lines
204
205
def add_check_lines(test, output_lines, prefix, func_name, single_bb,
                    func_body):
    """Append FileCheck lines covering `func_body` to `output_lines`.

    A `<prefix>-LABEL: name: <func_name>` line is emitted first, then one
    `<prefix>:` line per non-blank body line. Virtual registers are replaced
    by FileCheck variables: `[[NAME:%[0-9]+]]` where they are defined and
    `[[NAME]]` where they are used later.

    `func_body` is a list of lines and is not modified. When `single_bb` is
    true its first line (the basic block label) is not checked. A warning is
    emitted and nothing is appended if no lines remain.
    """
    if single_bb:
        # Don't bother checking the basic block label for a single BB. Slice
        # rather than pop so the caller's list stays intact and an empty
        # body reaches the warning below instead of raising IndexError.
        func_body = func_body[1:]

    if not func_body:
        common.warn('Function has no instructions to check: {}'.format(func_name),
                    test_file=test)
        return

    first_line = func_body[0]
    indent = len(first_line) - len(first_line.lstrip(' '))
    # A check comment, indented the appropriate amount
    check = '{:>{}}; {}'.format('', indent, prefix)

    output_lines.append('{}-LABEL: name: {}'.format(check, func_name))

    # Maps a vreg such as '%0' to the FileCheck variable name given to it.
    vreg_map = {}
    for func_line in func_body:
        if not func_line.strip():
            continue
        m = VREG_DEF_RE.match(func_line)
        if m:
            # Name each register defined by this instruction and turn its
            # first occurrence into a variable definition.
            for vreg in VREG_RE.finditer(m.group('vregs')):
                name = mangle_vreg(m.group('opcode'), vreg_map.values())
                vreg_map[vreg.group(1)] = name
                func_line = func_line.replace(
                    vreg.group(1), '[[{}:%[0-9]+]]'.format(name), 1)
        # Every remaining mention of a known register becomes a variable
        # use; the trailing \b keeps '%1' from matching inside '%10'.
        for number, name in vreg_map.items():
            func_line = re.sub(r'{}\b'.format(number), '[[{}]]'.format(name),
                               func_line)
        check_line = '{}: {}'.format(check, func_line[indent:]).rstrip()
        output_lines.append(check_line)
240
241
def mangle_vreg(opcode, current_names):
    """Derive a FileCheck variable name from the defining `opcode`.

    A leading 'G_' and a trailing '_PSEUDO' are stripped, a few long opcode
    names are abbreviated, and '_' is appended when the result ends in a
    digit. If `current_names` already holds names with the same base (after
    ignoring trailing digits), their count is appended as a suffix.
    """
    base = opcode
    # Simplify some common prefixes and suffixes
    if opcode.startswith('G_'):
        base = base[len('G_'):]
    if opcode.endswith('_PSEUDO'):
        # Negative slice bound: drop the suffix, not everything after the
        # first len('_PSEUDO') characters.
        base = base[:-len('_PSEUDO')]
    # Shorten some common opcodes with long-ish names
    base = dict(IMPLICIT_DEF='DEF',
                GLOBAL_VALUE='GV',
                CONSTANT='C',
                FCONSTANT='C',
                MERGE_VALUES='MV',
                UNMERGE_VALUES='UV',
                INTRINSIC='INT',
                INTRINSIC_W_SIDE_EFFECTS='INT',
                INSERT_VECTOR_ELT='IVEC',
                EXTRACT_VECTOR_ELT='EVEC',
                SHUFFLE_VECTOR='SHUF').get(base, base)
    # Avoid ambiguity when opcodes end in numbers
    if len(base.rstrip('0123456789')) < len(base):
        base += '_'

    # Count the names already derived from the same base.
    taken = sum(1 for name in current_names
                if name.rstrip('0123456789') == base)
    if taken:
        return '{}{}'.format(base, taken)
    return base
272
273
def should_add_line_to_output(input_line, prefix_set):
    """Return False for check lines whose prefix is in `prefix_set`.

    Those lines are regenerated by this script, so the existing copies are
    dropped. Every other line is kept.
    """
    found = common.CHECK_RE.match(input_line)
    is_handled_check = bool(found) and found.group(1) in prefix_set
    return not is_handled_check
280
281
def update_test_file(args, test):
    """Regenerate the FileCheck lines of the test file `test` in place.

    `args` is the parsed command line; this function reads `update_only`,
    `verbose`, `llc` and `remove_common_prefixes` from it.

    The file is skipped with a warning when its first line says it was
    autogenerated by a different script, when `update_only` is set and the
    file is not autogenerated, or when a RUN line has no triple and the
    file declares none either. Otherwise llc is run once per accepted RUN
    line and the file is rewritten with fresh check lines.
    """
    with open(test) as fd:
        input_lines = [l.rstrip() for l in fd]

    script_name = os.path.basename(__file__)
    first_line = input_lines[0] if input_lines else ""
    if 'autogenerated' in first_line and script_name not in first_line:
        common.warn("Skipping test which wasn't autogenerated by " +
                    script_name + ": " + test)
        return

    if args.update_only and 'autogenerated' not in first_line:
        common.warn("Skipping test which isn't autogenerated: " + test)
        return

    triple_in_ir = find_triple_in_ir(input_lines, args.verbose)
    run_lines = common.find_run_lines(test, input_lines)
    run_list, common_prefixes = build_run_list(test, run_lines, args.verbose)

    simple_functions = find_functions_with_one_bb(input_lines, args.verbose)

    # One {function name: body} dictionary per check prefix.
    func_dict = {prefix: {} for run in run_list for prefix in run.prefixes}
    for prefixes, llc_args, triple_in_cmd in run_list:
        log('Extracted LLC cmd: llc {}'.format(llc_args), args.verbose)
        log('Extracted FileCheck prefixes: {}'.format(prefixes), args.verbose)

        # Look for a triple before running llc, so the tool is not invoked
        # for output that would be thrown away.
        triple = triple_in_cmd or triple_in_ir
        if not triple:
            common.warn('No triple found: skipping file', test_file=test)
            return

        raw_tool_output = args.llc(llc_args, test)
        build_function_body_dictionary(test, raw_tool_output, triple,
                                       prefixes, func_dict, args.verbose)

    state = 'toplevel'
    func_name = None
    prefix_set = {prefix for run in run_list for prefix in run.prefixes}
    log('Rewriting FileCheck prefixes: {}'.format(prefix_set), args.verbose)

    if args.remove_common_prefixes:
        prefix_set.update(common_prefixes)
    elif common_prefixes:
        common.warn('Ignoring common prefixes: {}'.format(common_prefixes),
                    test_file=test)

    comment_char = '#' if test.endswith('.mir') else ';'
    autogenerated_note = ('{} NOTE: Assertions have been autogenerated by '
                          'utils/{}'.format(comment_char, script_name))
    output_lines = []
    output_lines.append(autogenerated_note)

    # Walk the input with a small state machine. Existing check lines for
    # the prefixes in prefix_set are dropped, and fresh ones are inserted
    # where each function's body begins.
    for input_line in input_lines:
        if input_line == autogenerated_note:
            # The note was already written at the top of the output.
            continue

        if state == 'toplevel':
            m = IR_FUNC_NAME_RE.match(input_line)
            if m:
                state = 'ir function prefix'
                func_name = m.group('func')
            if input_line.rstrip('| \r\n') == '---':
                state = 'document'
            output_lines.append(input_line)
        elif state == 'document':
            m = MIR_FUNC_NAME_RE.match(input_line)
            if m:
                state = 'mir function metadata'
                func_name = m.group('func')
            if input_line.strip() == '...':
                state = 'toplevel'
                func_name = None
            if should_add_line_to_output(input_line, prefix_set):
                output_lines.append(input_line)
        elif state == 'mir function metadata':
            if should_add_line_to_output(input_line, prefix_set):
                output_lines.append(input_line)
            m = MIR_BODY_BEGIN_RE.match(input_line)
            if m:
                if func_name in simple_functions:
                    # If there's only one block, put the checks inside it
                    state = 'mir function prefix'
                    continue
                state = 'mir function body'
                add_checks_for_function(test, output_lines, run_list,
                                        func_dict, func_name, single_bb=False,
                                        verbose=args.verbose)
        elif state == 'mir function prefix':
            m = MIR_PREFIX_DATA_RE.match(input_line)
            if not m:
                # First line that is not prefix data: the checks go right
                # before it.
                state = 'mir function body'
                add_checks_for_function(test, output_lines, run_list,
                                        func_dict, func_name, single_bb=True,
                                        verbose=args.verbose)

            if should_add_line_to_output(input_line, prefix_set):
                output_lines.append(input_line)
        elif state == 'mir function body':
            if input_line.strip() == '...':
                state = 'toplevel'
                func_name = None
            if should_add_line_to_output(input_line, prefix_set):
                output_lines.append(input_line)
        elif state == 'ir function prefix':
            m = IR_PREFIX_DATA_RE.match(input_line)
            if not m:
                state = 'ir function body'
                add_checks_for_function(test, output_lines, run_list,
                                        func_dict, func_name, single_bb=False,
                                        verbose=args.verbose)

            if should_add_line_to_output(input_line, prefix_set):
                output_lines.append(input_line)
        elif state == 'ir function body':
            if input_line.strip() == '}':
                state = 'toplevel'
                func_name = None
            if should_add_line_to_output(input_line, prefix_set):
                output_lines.append(input_line)

    log('Writing {} lines to {}...'.format(len(output_lines), test), args.verbose)

    with open(test, 'wb') as fd:
        fd.writelines(['{}\n'.format(l).encode('utf-8') for l in output_lines])
411
412
def main():
    """Parse the command line and update every test file it names.

    Each positional argument is expanded as a glob pattern. If updating a
    file raises, a warning naming the file is emitted and the exception is
    re-raised.
    """
    parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument('--llc-binary', dest='llc', default='llc', type=LLC,
                        help='The "llc" binary to generate the test case with')
    parser.add_argument('--remove-common-prefixes', action='store_true',
                        help='Remove existing check lines whose prefixes are '
                             'shared between multiple commands')
    parser.add_argument('tests', nargs='+')
    args = common.parse_commandline_args(parser)

    # Expand every pattern before any file is touched.
    test_paths = []
    for pattern in args.tests:
        test_paths.extend(glob.glob(pattern))

    for test in test_paths:
        try:
            update_test_file(args, test)
        except Exception:
            common.warn('Error processing file', test_file=test)
            raise


if __name__ == '__main__':
    main()
435