1# Copyright (c) 2016 Nuxi (https://nuxi.nl/) and contributors.
2#
3# SPDX-License-Identifier: BSD-2-Clause
4
5from .abi import *
6from .generator import *
7
8
class CGenerator(Generator):
    """Base class for generators that print C code.

    Args:
        naming: naming scheme used to spell C type, value, variable and
            system call names (its ``typename``/``vardecl``/``valname``/
            ``syscallname`` helpers are used by the subclasses below).
        header_guard: if not None, the macro used for an
            ``#ifndef``/``#define`` ... ``#endif`` include guard.
        machine_dep: if not None, subclasses skip definitions whose
            ``machine_dep`` differs from this value.
        md_type: if not None, the type that :meth:`mi_type` substitutes for
            pointer types and the ``size`` type.
        preamble: text printed right after the include guard, unless empty.
        postamble: text printed right before the closing ``#endif``,
            unless empty.
    """

    def __init__(self,
                 naming,
                 header_guard=None,
                 machine_dep=None,
                 md_type=None,
                 preamble='',
                 postamble=''):
        super().__init__(comment_prefix='// ')
        self.naming = naming
        self.header_guard = header_guard
        self.machine_dep = machine_dep
        self.md_type = md_type
        self.preamble = preamble
        self.postamble = postamble

    def generate_head(self, abi):
        """Print the base-class head, then the include guard and preamble."""
        super().generate_head(abi)
        guard = self.header_guard
        if guard is not None:
            for directive in ('#ifndef', '#define'):
                print('{} {}'.format(directive, guard))
            print()
        if self.preamble != '':
            print(self.preamble)

    def generate_foot(self, abi):
        """Print the postamble, close the include guard, then the base foot."""
        if self.postamble != '':
            print(self.postamble)
        if self.header_guard is not None:
            print('#endif')
        super().generate_foot(abi)

    def mi_type(self, mtype):
        """Return the type to declare in place of ``mtype``.

        Without ``md_type`` the type is returned unchanged. Otherwise pointer
        types and the type named ``size`` become ``md_type``, and array
        element types and atomic target types are rewritten recursively.
        (Presumably "mi" stands for machine-independent; verify.)
        """
        replacement = self.md_type
        if replacement is None:
            return mtype
        if isinstance(mtype, PointerType) or mtype.name == 'size':
            return replacement
        if isinstance(mtype, ArrayType):
            return ArrayType(mtype.count, self.mi_type(mtype.element_type))
        if isinstance(mtype, AtomicType):
            return AtomicType(self.mi_type(mtype.target_type))
        return mtype

    def syscall_params(self, syscall):
        """Return the C parameter declarations of ``syscall``.

        Input members are declared with their own type; output members follow
        them and are declared through an ``OutputPointerType`` wrapper.
        """
        vardecl = self.naming.vardecl
        inputs = [vardecl(p.type, p.name) for p in syscall.input.raw_members]
        outputs = [
            vardecl(OutputPointerType(p.type), p.name)
            for p in syscall.output.raw_members
        ]
        return inputs + outputs
59
60
class CSyscalldefsGenerator(CGenerator):
    """Print C type definitions (typedefs, value macros and structs).

    Every struct is followed by ``_Static_assert`` checks on its member
    offsets, size and alignment. System calls are not emitted.
    """

    def generate_struct_members(self, abi, type, indent=''):
        """Print the member declarations of struct ``type``.

        Every printed line starts with ``indent``. A variant member becomes a
        ``union``: named variants are wrapped in a named inner ``struct``,
        unnamed variants have their fields spliced into the union directly.

        Raises:
            Exception: if a member is of an unrecognised class.
        """
        for member in type.raw_members:
            if isinstance(member, SimpleStructMember):
                mtype = self.mi_type(member.type)
                align = mtype.layout.align
                # Pin the alignment only when both entries of the alignment
                # pair agree.
                alignas = ('_Alignas({}) '.format(align[0])
                           if align[0] == align[1] else '')
                decl = self.naming.vardecl(mtype, member.name)
                print('{}{}{};'.format(indent, alignas, decl))
            elif isinstance(member, VariantStructMember):
                inner = indent + '  '
                print('{}union {{'.format(indent))
                for variant in member.members:
                    if variant.name is not None:
                        # Named variant: its fields go in a named struct.
                        print('{}struct {{'.format(inner))
                        self.generate_struct_members(abi, variant.type,
                                                     inner + '  ')
                        print('{}}} {};'.format(inner, variant.name))
                    else:
                        # Anonymous variant: fields sit directly in the union.
                        self.generate_struct_members(abi, variant.type, inner)
                print('{}}};'.format(indent))
            else:
                raise Exception('Unknown struct member: {}'.format(member))

    def generate_type(self, abi, type):
        """Print the C definition of ``type`` followed by a blank line.

        When ``machine_dep`` is set, types whose layout ``machine_dep``
        differs from it are skipped and nothing is printed.

        Raises:
            Exception: if ``type`` is of an unrecognised class.
        """
        if (self.machine_dep is not None
                and type.layout.machine_dep != self.machine_dep):
            return

        naming = self.naming
        if isinstance(type, IntLikeType):
            print('typedef {};'.format(
                naming.vardecl(type.int_type, naming.typename(type))))
            values = type.values
            if len(values) > 0:
                names = [naming.valname(type, v) for v in values]
                width = max(len(n) for n in names)
                if isinstance(type, (FlagsType, OpaqueType)):
                    # Zero-padded hex ('0x' prefix included in the width),
                    # two digits per byte of the type, unless the only
                    # value is 0.
                    only_zero = len(values) == 1 and values[0].value == 0
                    val_format = ('d' if only_zero else
                                  '#0{}x'.format(type.layout.size[0] * 2 + 2))
                else:
                    # Decimal, right-aligned on the widest value.
                    val_format = '{}d'.format(
                        max(len(str(v.value)) for v in values))
                for vname, v in zip(names, values):
                    print('#define {name:{width}} {val:{val_format}}'.format(
                        name=vname,
                        width=width,
                        val=v.value,
                        val_format=val_format))

        elif isinstance(type, FunctionType):
            arglist = ', '.join(
                naming.vardecl(self.mi_type(p.type), p.name)
                for p in type.parameters.raw_members)
            declarator = '{}({})'.format(naming.typename(type), arglist)
            print('typedef {};'.format(
                naming.vardecl(self.mi_type(type.return_type),
                               declarator,
                               array_need_parens=True)))

        elif isinstance(type, StructType):
            typename = naming.typename(type)
            print('typedef struct {')
            self.generate_struct_members(abi, type, '  ')
            print('}} {};'.format(typename))
            self.generate_offset_asserts(typename, type.raw_members)
            self.generate_size_assert(typename, type.layout.size)
            self.generate_align_assert(typename, type.layout.align)

        else:
            raise Exception('Unknown class of type: {}'.format(type))

        print()

    def generate_offset_asserts(self,
                                type_name,
                                members,
                                prefix='',
                                offset=(0, 0)):
        """Print offset assertions for ``members`` of ``type_name``.

        ``prefix`` is the dotted access path of the enclosing named variants
        and ``offset`` the (pair of) base offsets added to every member
        offset. Members whose ``offset`` is None are skipped.
        """
        for m in members:
            if isinstance(m, VariantMember):
                nested = prefix if m.name is None else prefix + m.name + '.'
                self.generate_offset_asserts(type_name, m.type.members,
                                             nested, offset)
                continue
            if m.offset is None:
                continue
            moffset = (offset[0] + m.offset[0], offset[1] + m.offset[1])
            if isinstance(m, VariantStructMember):
                self.generate_offset_asserts(type_name, m.members, prefix,
                                             moffset)
            else:
                self.generate_offset_assert(type_name, prefix + m.name,
                                            moffset)

    def generate_offset_assert(self, type_name, member_name, offset):
        """Assert that ``member_name`` lies at ``offset`` in ``type_name``."""
        expression = 'offsetof({}, {})'.format(type_name, member_name)
        self.generate_layout_assert(expression, offset)

    def generate_size_assert(self, type_name, size):
        """Assert that ``sizeof(type_name)`` matches ``size``."""
        expression = 'sizeof({})'.format(type_name)
        self.generate_layout_assert(expression, size)

    def generate_align_assert(self, type_name, align):
        """Assert that ``_Alignof(type_name)`` matches ``align``."""
        expression = '_Alignof({})'.format(type_name)
        self.generate_layout_assert(expression, align)

    def generate_layout_assert(self, expression, value):
        """Print ``_Static_assert``(s) checking ``expression`` against ``value``.

        ``value`` is a pair: entry 0 applies when pointers are 4 bytes wide,
        entry 1 when they are 8 bytes wide. A single unconditional assert is
        printed when both entries agree or when ``md_type`` has a fixed 4- or
        8-byte size (selecting the matching entry). Otherwise, one assert per
        pointer size is printed, each guarded by a ``sizeof`` check.
        """
        template = '_Static_assert({}, "Incorrect layout");'
        md_size = None if self.md_type is None else self.md_type.layout.size
        if value[0] == value[1] or md_size in ((4, 4), (8, 8)):
            expected = value[0] if md_size == (4, 4) else value[1]
            print(template.format('{} == {}'.format(expression, expected)))
        else:
            voidptr = self.naming.typename(PointerType())
            for ptr_size, expected in ((4, value[0]), (8, value[1])):
                print(template.format('sizeof({}) != {} || {} == {}'.format(
                    voidptr, ptr_size, expression, expected)))

    def generate_syscalls(self, abi, syscalls):
        """Emit nothing: this generator only produces type definitions."""
198
199
class CSyscallsInfoGenerator(CGenerator):
    """Print preprocessor tables describing every system call.

    For each system call (in sorted name order) this emits macros with its
    parameter declarations, its bare parameter names, and ``(yes, no)``
    selector macros telling whether it has parameters and whether it
    returns. No type definitions are emitted.
    """

    def print_with_line_continuation(self, text):
        """Print multi-line ``text`` as one preprocessor directive.

        Every line but the last is padded to the width of the longest line
        and terminated with `` \\`` so the lines are joined into a single
        logical line by the C preprocessor.
        """
        lines = text.splitlines()
        width = max(len(line) for line in lines)
        for line in lines[:-1]:
            print('{} \\'.format(line.ljust(width)))
        print(lines[-1])

    def generate_syscalls(self, abi, syscalls):
        """Print all system call information macros.

        Reads ``abi.syscalls`` (keyed by system call name) directly; the
        ``syscalls`` argument is not used.
        """
        prefix = self.naming.prefix.upper()
        names = sorted(abi.syscalls)
        # Compute each syscall's C parameter list once; it is needed both for
        # the PARAMETERS and the HAS_PARAMETERS macros.
        params = {
            name: self.syscall_params(abi.syscalls[name])
            for name in names
        }

        # X-macro listing every system call name.
        self.print_with_line_continuation(
            '#define {}SYSCALL_NAMES(SYSCALL)\n'.format(prefix) +
            '\n'.join('  SYSCALL({})'.format(name) for name in names))
        print()

        # Full C parameter declarations, comma separated.
        for name in names:
            self.print_with_line_continuation(
                '#define {}SYSCALL_PARAMETERS_{}\n'.format(prefix, name) +
                ',\n'.join('  {}'.format(p) for p in params[name]))
            print()

        # Bare parameter names (inputs followed by outputs).
        for name in names:
            syscall = abi.syscalls[name]
            param_names = ([p.name for p in syscall.input.raw_members] +
                           [p.name for p in syscall.output.raw_members])
            print('#define {}SYSCALL_PARAMETER_NAMES_{}'.format(prefix, name),
                  end='')
            if param_names:
                print(' \\\n  ' + ', '.join(param_names))
            else:
                print()
            print()

        # Selector expanding to `yes` or `no` depending on parameter presence.
        for name in names:
            print('#define {}SYSCALL_HAS_PARAMETERS_{}(yes, no) {}'.format(
                prefix, name, 'yes' if params[name] else 'no'))
        print()

        # Selector expanding to `no` for noreturn system calls, else `yes`.
        for name in names:
            print('#define {}SYSCALL_RETURNS_{}(yes, no) {}'.format(
                prefix, name, 'no' if abi.syscalls[name].noreturn else 'yes'))
        print()

    def generate_types(self, abi, types):
        """Emit nothing: this generator only describes system calls."""
246
247
class CSyscallsGenerator(CGenerator):
    """Print a C function declaration for every system call.

    Subclasses may override :meth:`generate_syscall_keywords` (text before
    the return type) and :meth:`generate_syscall_body` (text after the
    parameter list; a plain ``;`` here).
    """

    def generate_syscall(self, abi, syscall):
        """Print the declaration of ``syscall`` followed by a blank line.

        Skipped when ``machine_dep`` is set and differs from the syscall's.
        Noreturn system calls return ``void``; all others return the ABI's
        ``errno`` type. An empty parameter list is spelled ``void``.
        """
        if (self.machine_dep is not None
                and syscall.machine_dep != self.machine_dep):
            return

        self.generate_syscall_keywords(syscall)
        return_type = VoidType() if syscall.noreturn else abi.types['errno']
        print(self.naming.typename(return_type))
        print(self.naming.syscallname(syscall))
        print('(')
        params = self.syscall_params(syscall)
        print(','.join(params) if params else 'void')
        print(')')
        self.generate_syscall_body(abi, syscall)
        print()

    def generate_syscall_keywords(self, syscall):
        """Print ``_Noreturn`` in front of noreturn system calls."""
        if not syscall.noreturn:
            return
        print('_Noreturn')

    def generate_syscall_body(self, abi, syscall):
        """Terminate the declaration with a semicolon (prototype only)."""
        print(';')

    def generate_types(self, abi, types):
        """Emit nothing: this generator only declares system calls."""
280
281
class CLinuxSyscallsGenerator(CSyscallsGenerator):
    """Variant of CSyscallsGenerator that prints no leading keywords.

    The base class prints ``_Noreturn`` before noreturn system calls; this
    subclass suppresses it.
    """

    def generate_syscall_keywords(self, syscall):
        # Deliberately prints nothing.
        return None
285
286
class CLinuxSyscallTableGenerator(CGenerator):
    """Print per-syscall dispatch wrappers and a table of them.

    Each wrapper ``do_<name>(const void *in, void *out)`` overlays structs on
    ``in`` and ``out`` and forwards their members to the system call
    implementation. The table lists the wrappers in sorted system call name
    order. Requires ``md_type`` to be set (it is dereferenced in the head).
    """

    def generate_head(self, abi):
        """Print the base head plus the MEMBER (and PAD) helper macros."""
        super().generate_head(abi)

        # Macro for placing the system call argument at the right spot
        # within a register, depending on the system's endianness.
        regalign = self.md_type.layout.align[0]
        regtype = self.naming.typename(self.md_type)
        print('\n'.join((
            '#ifdef __LITTLE_ENDIAN',
            '#define MEMBER(type, name) _Alignas({}) type name'.format(
                regalign),
            '#else',
            '#define PAD(type) \\',
            '    ((sizeof({0}) - (sizeof(type) % sizeof({0}))) % sizeof({0}))'
            .format(regtype),
            '#define MEMBER(type, name) char name##_pad[PAD(type)]; type name',
            '#endif',
            '',
        )))

    def generate_syscall(self, abi, syscall):
        """Print the ``do_<name>`` wrapper for ``syscall``.

        Noreturn system calls are invoked without ``return`` and the wrapper
        then returns 0; otherwise the implementation's result is returned.
        """
        errno = self.naming.typename(abi.types['errno'])
        print('static {} do_{}(const void *in, void *out) {{'.format(
            errno, syscall.name))

        # Map structures over the system call input and output registers.
        overlays = (
            (syscall.input.raw_members, 'const struct {', '} *vin = in;'),
            (syscall.output.raw_members, 'struct {', '} *vout = out;'),
        )
        for members, opener, closer in overlays:
            if not members:
                continue
            print(opener)
            for p in members:
                print('MEMBER({}, {});'.format(
                    self.naming.typename(p.type), p.name))
            print(closer)

        # Invoke the system call implementation function.
        if not syscall.noreturn:
            print('return')
        print(self.naming.syscallname(syscall))
        args = (['vin->' + p.name for p in syscall.input.raw_members] +
                ['&vout->' + p.name for p in syscall.output.raw_members])
        print('(', ', '.join(args), ');')
        if syscall.noreturn:
            print('return 0;')
        print('}\n')

    def generate_foot(self, abi):
        """Print the ``syscalls`` table of wrappers, then the base foot."""
        # Emit the actual system call table.
        errno = self.naming.typename(abi.types['errno'])
        print('static {} (*syscalls[])(const void *, void *) = {{'.format(
            errno))
        for key in sorted(abi.syscalls):
            print('do_{},'.format(abi.syscalls[key].name))
        print('};')

        super().generate_foot(abi)
346