1import unittest
2from llvmcpy import llvm
3from packaging import version
4
# Textual LLVM IR shared by the tests below. It defines three functions in
# this order: function2 (a single `ret`), function1 (two argument-less calls
# to function2) and main (which calls function1).
# test_function_count expects exactly these three functions, and
# test_translate_null_ptr_to_none / get_non_existing_basic_block rely on the
# first function (function2) having a single basic block.
# The "; Function Attrs" lines are IR comments only; no attributes are
# attached to the definitions.
# NOTE(review): the IR uses typed-pointer syntax (i32*, i8**); confirm it
# still parses on the LLVM releases this suite targets.
module_source = """; ModuleID = 'example.c'
target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
target triple = "x86_64-pc-linux-gnu"

; Function Attrs: nounwind uwtable
define i32 @function2() {
  ret i32 42
}

; Function Attrs: nounwind uwtable
define i32 @function1() {
  %1 = call i32 @function2()
  %2 = call i32 @function2()
  ret i32 %1
}

; Function Attrs: nounwind uwtable
define i32 @main(i32, i8**) {
  %3 = alloca i32, align 4
  %4 = alloca i32, align 4
  %5 = alloca i8**, align 8
  store i32 0, i32* %3, align 4
  store i32 %0, i32* %4, align 4
  store i8** %1, i8*** %5, align 8
  %6 = call i32 @function1()
  ret i32 %6
}
"""
33
# test_metadata_flags only inspects module flags on LLVM >= 7.0, so a single
# flag (!0: behavior 4, key "foo", value 42) is appended to the IR only for
# those versions.
if version.parse(llvm.version) >= version.parse("7.0"):
    module_source += (
        "\n"
        "    !llvm.module.flags = !{!0}\n"
        '    !0 = !{ i32 4, !"foo", i32 42 }\n'
        "    "
    )
39
def load_module(ir):
    """Parse the textual LLVM IR in *ir* and return the resulting module.

    The IR is copied into an LLVM memory buffer named "example", which is
    then parsed in the global LLVM context.

    NOTE(review): ``len(ir)`` is the character count of the string; it equals
    the byte count only for ASCII IR. Confirm how llvmcpy encodes ``str``
    arguments before feeding non-ASCII IR through here.
    """
    global_ctx = llvm.get_global_context()
    ir_buffer = llvm.create_memory_buffer_with_memory_range_copy(
        ir, len(ir), "example")
    return global_ctx.parse_ir(ir_buffer)
46
def get_function_number(ir):
    """Return how many functions the module parsed from *ir* contains.

    Every entry yielded by the module's ``iter_functions`` is counted.
    """
    parsed = load_module(ir)
    return sum(1 for _ in parsed.iter_functions())
50
def get_non_existing_basic_block(ir):
    """Deliberately touch the basic block that follows the first one.

    Loads *ir*, takes the first basic block of the first function, and calls
    ``first_instruction()`` on whatever ``get_next()`` returns. The suite
    (test_null_ptr) expects this to raise AttributeError. Presumably
    ``get_next()`` yields ``None`` for the single-block first function (see
    test_translate_null_ptr_to_none); confirm against llvmcpy. Returns
    nothing.
    """
    functions = list(load_module(ir).iter_functions())
    entry_block = list(functions[0].iter_basic_blocks())[0]
    following_block = entry_block.get_next()
    following_block.first_instruction()
56
class TestSuite(unittest.TestCase):
    """End-to-end checks of the llvmcpy bindings against ``module_source``."""

    def test_function_count(self):
        """The parsed module exposes its three function definitions."""
        self.assertEqual(get_function_number(module_source), 3)

    def test_null_ptr(self):
        """Walking past the last basic block ends in an AttributeError."""
        # NOTE(review): any AttributeError satisfies this check, including
        # one raised by a misspelled binding name inside the helper.
        with self.assertRaises(AttributeError):
            get_non_existing_basic_block(module_source)

    def test_resolve_enums(self):
        """``llvm.Opcode`` maps values to names and names back to values."""
        self.assertEqual(llvm.Opcode[llvm.Switch], 'Switch')
        self.assertEqual(llvm.Opcode['Switch'], llvm.Switch)

    def test_translate_null_ptr_to_none(self):
        """A NULL pointer returned by the C API surfaces as ``None``."""
        module = load_module(module_source)
        first_function = list(module.iter_functions())[0]
        first_basic_block = list(first_function.iter_basic_blocks())[0]
        # Use the explicit getter, matching test_value_as_key.
        first_instruction = first_basic_block.get_first_instruction()

        # The first function is function2, whose only instruction is `ret`.
        # That is not a binary operator, so the is-a query returns NULL.
        self.assertIsNone(first_instruction.is_a_binary_operator())

    def test_value_as_key(self):
        """Two wrappers of the same LLVM value hash and compare equal."""
        module = load_module(module_source)
        function1 = module.get_named_function("function1")
        first_basic_block = function1.get_first_basic_block()
        first_instruction = first_basic_block.get_first_instruction()
        second_instruction = first_instruction.get_next_instruction()
        # Both instructions are argument-less calls to @function2, so
        # operand 0 of each refers to the same callee value.
        operand1 = first_instruction.get_operand(0)
        operand2 = second_instruction.get_operand(0)
        dictionary = {}
        dictionary[operand1] = 42
        self.assertIn(operand2, dictionary)

    def test_sized_string_return(self):
        """Strings with embedded NULs round-trip through an MD string."""
        string = "a\0b\0c"
        value = llvm.md_string(string, len(string))
        self.assertEqual(value.get_md_string(), string)
        # encoding=None returns the raw bytes instead of a decoded str.
        self.assertEqual(value.get_md_string(encoding=None),
                         string.encode('ascii'))

    @unittest.skipIf(version.parse(llvm.version) < version.parse("7.0"),
                     "module flags are only appended for LLVM >= 7.0")
    def test_metadata_flags(self):
        """The module flag appended for LLVM >= 7.0 is readable back."""
        module = load_module(module_source)
        length = llvm.ffi.new("size_t *")
        metadata_flags = module.copy_module_flags_metadata(length)
        # NOTE(review): LLVM documents that these entries should be freed
        # with LLVMDisposeModuleFlagsMetadata. That is not done here; confirm
        # the llvmcpy binding name before adding the cleanup.
        # The IR declares exactly one module flag (!0).
        self.assertEqual(length[0], 1)
        behavior = metadata_flags.module_flag_entries_get_flag_behavior(0)
        key = metadata_flags.module_flag_entries_get_key(0)
        # The IR flag behavior 4 (Override) is numbered 3 in the C API's
        # LLVMModuleFlagBehavior enum.
        self.assertEqual(behavior, 3)
        self.assertEqual(key, "foo")
105
if __name__ == "__main__":
    # Run as a script: hand over to unittest's command-line runner for the
    # TestCase classes defined in this module.
    unittest.main(module="__main__")
108