1# Copyright (c) 2016 Nuxi (https://nuxi.nl/) and contributors.
2#
3# SPDX-License-Identifier: BSD-2-Clause
4
5from .abi import *
6from .generator import *
7
8
def howmany(a, b):
    """Return how many chunks of size b are needed to hold a units.

    For non-negative a and positive b this is a / b rounded up,
    computed with integer arithmetic only.
    """
    padded = a + b - 1
    return padded // b
11
12
def roundup(a, b):
    """Round a up to the nearest multiple of b."""
    chunks = howmany(a, b)
    return chunks * b
15
16
class AsmVdsoGenerator(Generator):
    """Base class for generators that print vDSO system call stubs.

    It prints the ENTRY()/END() assembler macros and wraps the body of
    every system call, as printed by generate_syscall_body() of the
    subclass, in a matching ENTRY()/END() pair.
    """

    def __init__(self, function_alignment, type_character):
        """Remember the assembler settings that differ per architecture.

        function_alignment is the operand text placed after .p2align.
        type_character is the single character printed in front of
        'function' in the .type directive.
        """
        super().__init__(comment_prefix='// ')
        self._type_character = type_character
        self._function_alignment = function_alignment

    def generate_head(self, abi):
        """Print the base class header followed by the helper macros."""
        super().generate_head(abi)

        # Macros for opening/closing function bodies. The empty entry
        # yields the blank line separating the two definitions.
        alignment = self._function_alignment + ';'
        macro_lines = (
            '#define ENTRY(name)      \\',
            '  .text;                 \\',
            '  .p2align %-13s \\' % alignment,
            '  .global name;          \\',
            '  .type name, %cfunction; \\' % self._type_character,
            'name:',
            '',
            '#define END(name) .size name, . - name',
        )
        for line in macro_lines:
            print(line)

    def generate_syscalls(self, abi, syscalls):
        """Print a stub for every entry of abi.syscalls in sorted key order.

        Note that the syscalls argument itself is not consulted.
        """
        table = abi.syscalls
        for key in sorted(table):
            self.generate_syscall(abi, table[key])

    def generate_syscall(self, abi, syscall):
        """Print one complete stub: blank line, ENTRY(), body and END()."""
        print()
        name = syscall.name
        print('ENTRY(cloudabi_sys_{})'.format(name))
        number = abi.syscall_number(syscall)
        self.generate_syscall_body(number, syscall.input.raw_members,
                                   syscall.output.raw_members,
                                   syscall.noreturn)
        print('END(cloudabi_sys_{})'.format(name))

    def generate_foot(self, abi):
        """Print the footer; nothing is added to what the base class prints."""
        super().generate_foot(abi)
50
51
class AsmVdsoCommonGenerator(AsmVdsoGenerator):
    """Stub generation shared by the targets that pass arguments natively.

    Subclasses supply the REGISTERS_PARAMS, REGISTERS_RETURNS and
    REGISTERS_SPARE tables, the register_align()/register_count()
    helpers and the print_*() hooks that emit the actual instructions.
    """

    def generate_syscall_body(self, number, args_input, args_output, noreturn):
        """Print the instructions making up a single system call stub.

        number is handed to print_syscall(). args_input and args_output
        are the input and output members of the system call. When
        noreturn is true, nothing is printed after the system call
        itself.
        """
        params = self.REGISTERS_PARAMS

        # Work out how many registers/stack slots the input arguments
        # occupy, honouring the alignment of each of them. Every output
        # argument is a pointer, which is assumed to fit in one slot.
        slots_input = 0
        for member in args_input:
            aligned = roundup(slots_input, self.register_align(member))
            slots_input = aligned + self.register_count(member)

        # Map the slots onto parameter registers. Slots for which no
        # register is left over are taken to live on the stack.
        regs_input = params[:slots_input]
        params_left = params[slots_input:]
        regs_output = params_left[:len(args_output)]

        # The register used by a system call may differ from the one
        # used by a function call (e.g., %rcx -> %r10 on x86-64).
        for reg_function, reg_syscall in regs_input:
            if reg_function != reg_syscall:
                self.print_remap_register(reg_function, reg_syscall)

        # The system call may clobber the registers holding the output
        # addresses, so save those on the stack first.
        if regs_output:
            self.print_push_addresses(
                [reg_function for reg_function, _ in regs_output])

        # Execute system call.
        self.print_syscall(number)

        if noreturn:
            return

        if args_output:
            # Restore the saved output addresses into spare registers.
            regs_spare = list(self.REGISTERS_SPARE)
            regs_popped = [regs_spare.pop(0) for _ in regs_output]
            if regs_popped:
                self.print_pop_addresses(regs_popped)

            # Skip the copying below if the system call failed.
            self.print_jump_syscall_failed('1f')

            # Store the values returned in registers through the output
            # addresses.
            regs_returns = list(self.REGISTERS_RETURNS)
            reg_scratch = None
            for i, member in enumerate(args_output):
                if i < len(regs_popped):
                    # The address was restored into a spare register.
                    reg = regs_popped[i]
                else:
                    # The address was passed on the stack. Load it into
                    # one spare register that is shared by all such
                    # outputs.
                    if reg_scratch is None:
                        reg_scratch = regs_spare.pop(0)
                    reg = reg_scratch
                    slot = slots_input - len(params) + i
                    self.print_load_address_from_stack(slot, reg)

                # A value may span more than one return register.
                for j in range(self.register_count(member)):
                    self.print_store_output(member, regs_returns.pop(0),
                                            reg, j)

            self.print_retval_success()

        self.print_return()
122
123
class AsmVdsoAarch64Generator(AsmVdsoCommonGenerator):
    """Prints system call stubs in AArch64 assembly."""

    # (function call register, system call register) pairs; identical
    # on this architecture, so nothing ever gets remapped.
    REGISTERS_PARAMS = [
        ('0', '0'), ('1', '1'), ('2', '2'), ('3', '3'),
        ('4', '4'), ('5', '5'), ('6', '6'), ('7', '7'),
    ]
    REGISTERS_RETURNS = ['0', '1']
    REGISTERS_SPARE = ['2', '3']

    def __init__(self):
        super().__init__(function_alignment='2', type_character='@')

    @staticmethod
    def register_align(member):
        """Arguments need no extra register alignment."""
        return 1

    @staticmethod
    def register_count(member):
        """Return the number of 8-byte registers needed for member."""
        size = member.type.layout.size[1]
        return howmany(size, 8)

    @staticmethod
    def print_push_addresses(regs):
        """Store one or two address registers just below the stack pointer."""
        if len(regs) != 1:
            assert len(regs) == 2
            print('  stp x{}, x{}, [sp, #-16]'.format(regs[0], regs[1]))
            return
        print('  str x{}, [sp, #-8]'.format(regs[0]))

    @staticmethod
    def print_syscall(number):
        """Load the system call number into w8 and issue svc."""
        print('  mov w8, #{}'.format(number))
        print('  svc #0')

    @staticmethod
    def print_pop_addresses(regs):
        """Reload the addresses saved by print_push_addresses() into regs."""
        if len(regs) != 1:
            assert len(regs) == 2
            print('  ldp x{}, x{}, [sp, #-16]'.format(regs[0], regs[1]))
            return
        print('  ldr x{}, [sp, #-8]'.format(regs[0]))

    @staticmethod
    def print_jump_syscall_failed(label):
        """Branch to label on failure, using a b.cs instruction."""
        print('  b.cs ' + label)

    @staticmethod
    def print_store_output(member, reg_from, reg_to, index):
        """Store return register reg_from at the address held in reg_to."""
        assert index == 0
        # 4-byte values use the w view of the register, 8-byte values
        # the x view. Other sizes are not supported.
        prefix = {4: 'w', 8: 'x'}[member.type.layout.size[1]]
        print('  str {}{}, [x{}]'.format(prefix, reg_from, reg_to))

    @staticmethod
    def print_retval_success():
        """Zero the return value and define the label jumped to on failure."""
        print('  mov w0, wzr')
        print('1:')

    @staticmethod
    def print_return():
        """Return to the caller."""
        print('  ret')
184
185
class AsmVdsoArmv6Generator(AsmVdsoCommonGenerator):
    """Prints system call stubs in ARMv6 assembly.

    No failure branch is printed. Everything that follows the system
    call uses 'cc' conditional instructions instead.
    """

    # (function call register, system call register) pairs; identical
    # on this architecture, so nothing ever gets remapped.
    REGISTERS_PARAMS = [
        ('0', '0'), ('1', '1'), ('2', '2'), ('3', '3'),
    ]
    REGISTERS_RETURNS = ['0', '1']
    REGISTERS_SPARE = ['2', '3']

    def __init__(self):
        super().__init__(function_alignment='2', type_character='%')

    @staticmethod
    def register_align(member):
        """Return the alignment of member, counted in 4-byte registers."""
        size = member.type.layout.size[0]
        return howmany(size, 4)

    @staticmethod
    def register_count(member):
        """Return the number of 4-byte registers needed for member."""
        size = member.type.layout.size[0]
        return howmany(size, 4)

    @staticmethod
    def print_push_addresses(regs):
        """Store one or two address registers just below the stack pointer."""
        single = len(regs) == 1
        if not single:
            assert len(regs) == 2
        print('  str r{}, [sp, #-4]'.format(regs[0]))
        if not single:
            print('  str r{}, [sp, #-8]'.format(regs[1]))

    @staticmethod
    def print_syscall(number):
        """Load the system call number into ip and issue swi."""
        print('  mov ip, #{}'.format(number))
        print('  swi 0')

    @staticmethod
    def print_pop_addresses(regs):
        """Conditionally reload the addresses saved by print_push_addresses()."""
        single = len(regs) == 1
        if not single:
            assert len(regs) == 2
        print('  ldrcc r{}, [sp, #-4]'.format(regs[0]))
        if not single:
            print('  ldrcc r{}, [sp, #-8]'.format(regs[1]))

    @staticmethod
    def print_jump_syscall_failed(label):
        """Print nothing: conditional execution takes the place of a branch."""
        pass

    @staticmethod
    def print_load_address_from_stack(slot, reg):
        """Conditionally load the output address in stack slot into reg."""
        print('  ldrcc r{}, [sp, #{}]'.format(reg, slot * 4))

    @staticmethod
    def print_store_output(member, reg_from, reg_to, index):
        """Conditionally store reg_from through the address in reg_to.

        index selects the 4-byte word of a value that spans several
        return registers.
        """
        size = member.type.layout.size[0]
        # Only 4 and 8 byte values are supported; both use r registers.
        prefix = {4: 'r', 8: 'r'}[size]
        if size > 4:
            offset = ', #{}'.format(index * 4)
        else:
            offset = ''
        print('  strcc {}{}, [r{}{}]'.format(prefix, reg_from, reg_to, offset))

    @staticmethod
    def print_retval_success():
        """Conditionally zero the return value."""
        print('  movcc r0, #0')

    @staticmethod
    def print_return():
        """Return to the caller."""
        print('  bx lr')
250
251
class AsmVdsoArmv6On64bitGenerator(AsmVdsoGenerator):
    """Prints ARMv6 stubs that pass arguments through a padded buffer.

    The arguments are copied into a buffer below the stack pointer in
    which every argument is padded to 64-bit words. The system call
    number goes into r0 and the buffer address into r2.
    """

    def __init__(self):
        super().__init__(function_alignment='2', type_character='%')

    @staticmethod
    def load_argument(offset):
        """Return the register holding the argument word at byte offset.

        Words at offsets below 16 are already in r0-r3. Any later word
        is on the stack; an instruction loading it into r1 is printed
        and 'r1' is returned.
        """
        if offset >= 16:
            print('  ldr r1, [sp, #{}]'.format(offset - 16))
            return 'r1'
        return 'r{}'.format(offset // 4)

    def generate_syscall_body(self, number, args_input, args_output, noreturn):
        """Print the instructions making up a single system call stub.

        number is the system call number. args_input and args_output
        are the input and output members of the system call. When
        noreturn is true, nothing is printed after the system call
        itself.
        """
        # When running on 64-bit operating systems, the system call
        # handler expects every argument padded to 64-bit words.
        #
        # Work out how many 64-bit slots have to be reserved on the
        # stack to hold both the input and the output arguments.
        #
        # TODO(ed): Unify this with AsmVdsoI686On64bitGenerator.
        slots_input_padded = sum(
            howmany(m.type.layout.size[1], 8) for m in args_input)
        slots_stack = max(slots_input_padded, 2)
        buffer_start = -8 * slots_stack

        # Copy the original arguments into the padded buffer.
        offset_in = 0
        offset_out = buffer_start
        r0_is_zero = False
        for member in args_input:
            size_narrow = member.type.layout.size[0]
            size_wide = member.type.layout.size[1]
            offset_in = roundup(offset_in, size_narrow)
            if size_narrow != size_wide:
                # Pointer or size_t. Zero-extend it to 64 bits.
                assert size_narrow == 4
                assert size_wide == 8
                register = self.load_argument(offset_in)
                print('  str {}, [sp, #{}]'.format(register, offset_out))
                if not r0_is_zero:
                    # r0 supplies the zero for the upper half from now on.
                    print('  mov r0, #0')
                    r0_is_zero = True
                print('  str r0, [sp, #{}]'.format(offset_out + 4))
            else:
                # Same size on both systems: copy it word by word.
                for delta in range(0, size_narrow, 4):
                    register = self.load_argument(offset_in + delta)
                    print('  str {}, [sp, #{}]'.format(register,
                                                       offset_out + delta))
            offset_in += roundup(size_narrow, 4)
            offset_out += roundup(size_wide, 8)
        assert offset_out <= 0

        # Record where the address of every return value can be found
        # once the system call has completed.
        slots_out = []
        offset_out = buffer_start
        reg_first = offset_in // 4
        for reg_from in range(reg_first, reg_first + len(args_output)):
            if reg_from >= 4:
                # Address is already on the stack. No need to preserve.
                slots_out.append((reg_from - 4) * 4)
            else:
                # Address is in a register. Save it below the buffer.
                offset_out -= 4
                print('  str r{}, [sp, #{}]'.format(reg_from, offset_out))
                slots_out.append(offset_out)

        # Invoke system call.
        print('  mov r0, #{}'.format(number))
        print('  sub r2, sp, #{}'.format(slots_stack * 8))
        print('  swi 0')

        if noreturn:
            return

        # Copy the return values out of the padded buffer.
        offset_in = buffer_start
        for member, slot in zip(args_output, slots_out):
            size = member.type.layout.size[0]
            size_wide = member.type.layout.size[1]
            assert size == size_wide or (size == 4 and size_wide == 8)
            assert size % 4 == 0
            print('  ldrcc r1, [sp, #{}]'.format(slot))
            for delta in range(0, size, 4):
                print('  ldrcc r2, [sp, #{}]'.format(offset_in + delta))
                print('  strcc r2, [r1, #{}]'.format(delta))
            offset_in += roundup(size_wide, 8)
        print('  bx lr')
339
340
class AsmVdsoI686Generator(AsmVdsoCommonGenerator):
    """Prints system call stubs in i686 assembly.

    REGISTERS_PARAMS is empty, so every argument, including the output
    addresses, is taken to be on the stack.
    """

    REGISTERS_PARAMS = []
    REGISTERS_RETURNS = ['ax', 'dx']
    REGISTERS_SPARE = ['cx']

    def __init__(self):
        super().__init__(function_alignment='2, 0x90', type_character='@')

    @staticmethod
    def register_align(member):
        """Arguments need no extra slot alignment."""
        return 1

    @staticmethod
    def register_count(member):
        """Return the number of 4-byte slots needed for member."""
        size = member.type.layout.size[0]
        return howmany(size, 4)

    @staticmethod
    def print_syscall(number):
        """Load the system call number into %eax and issue int $0x80."""
        print('  mov ${}, %eax'.format(number))
        print('  int $0x80')

    @staticmethod
    def print_jump_syscall_failed(label):
        """Jump to label on failure, using a jc instruction."""
        print('  jc ' + label)

    @staticmethod
    def print_load_address_from_stack(slot, reg):
        """Load the output address in stack slot into reg."""
        # The extra 4 bytes skip the word at the top of the stack.
        displacement = slot * 4 + 4
        print('  mov {}(%esp), %e{}'.format(displacement, reg))

    @staticmethod
    def print_store_output(member, reg_from, reg_to, index):
        """Store reg_from through the address in reg_to.

        index selects the 4-byte word of a value that spans several
        return registers.
        """
        size = member.type.layout.size[0]
        # Only 4 and 8 byte values are supported; both use %e registers.
        prefix = {4: '%e', 8: '%e'}[size]
        if size > 4:
            displacement = index * 4
        else:
            displacement = ''
        print('  mov {}{}, {}(%e{})'.format(prefix, reg_from, displacement,
                                            reg_to))

    @staticmethod
    def print_retval_success():
        """Zero the return value and define the label jumped to on failure."""
        print('  xor %eax, %eax')
        print('1:')

    @staticmethod
    def print_return():
        """Return to the caller."""
        print('  ret')
387
388
class AsmVdsoI686On64bitGenerator(AsmVdsoGenerator):
    """Prints i686 stubs that pass arguments through a padded buffer.

    The arguments are copied into a buffer below the frame pointer in
    which every argument is padded to 64-bit words. The system call
    number goes into %eax and the buffer address into %ecx.
    """

    def __init__(self):
        super().__init__(function_alignment='2, 0x90', type_character='@')

    def generate_syscall_body(self, number, args_input, args_output, noreturn):
        """Print the instructions making up a single system call stub.

        number is the system call number. args_input and args_output
        are the input and output members of the system call. When
        noreturn is true, nothing is printed after the system call
        itself.
        """
        print('  push %ebp')
        print('  mov %esp, %ebp')

        # When running on 64-bit operating systems, the system call
        # handler expects every argument padded to 64-bit words.
        #
        # Work out how many 64-bit slots have to be reserved on the
        # stack to hold both the input and the output arguments.
        #
        # TODO(ed): Unify this with AsmVdsoArmv6On64bitGenerator.
        slots_input_padded = sum(
            howmany(m.type.layout.size[1], 8) for m in args_input)
        slots_stack = max(slots_input_padded, 2)
        buffer_start = -8 * slots_stack

        # Copy the original arguments into the padded buffer. The first
        # argument is read from 8(%ebp).
        offset_in = 8
        offset_out = buffer_start
        for member in args_input:
            size_narrow = member.type.layout.size[0]
            size_wide = member.type.layout.size[1]
            if size_narrow != size_wide:
                # Pointer or size_t. Zero-extend it to 64 bits.
                assert size_narrow == 4
                assert size_wide == 8
                print('  mov {}(%ebp), %ecx'.format(offset_in))
                print('  mov %ecx, {}(%ebp)'.format(offset_out))
                print('  movl $0, {}(%ebp)'.format(offset_out + 4))
            else:
                # Same size on both systems: copy it word by word.
                for delta in range(0, size_narrow, 4):
                    print('  mov {}(%ebp), %ecx'.format(offset_in + delta))
                    print('  mov %ecx, {}(%ebp)'.format(offset_out + delta))
            offset_in += roundup(size_narrow, 4)
            offset_out += roundup(size_wide, 8)

        # Offset of the first output address, right after the inputs.
        args_end = 8 + sum(
            roundup(m.type.layout.size[0], 4) for m in args_input)
        assert offset_in == args_end
        assert offset_out <= 0

        # Invoke system call, setting %ecx to the padded buffer.
        print('  mov ${}, %eax'.format(number))
        print('  mov %ebp, %ecx')
        print('  sub ${}, %ecx'.format(slots_stack * 8))
        print('  int $0x80')

        if noreturn:
            return

        if args_output:
            # Leave the outputs untouched if %eax is non-zero.
            print('  test %eax, %eax')
            print('  jnz 1f')

            # Copy the return values out of the padded buffer.
            offset_in = buffer_start
            offset_out = args_end
            for member in args_output:
                size = member.type.layout.size[0]
                size_wide = member.type.layout.size[1]
                assert size == size_wide or (size == 4 and size_wide == 8)
                assert size % 4 == 0
                print('  mov {}(%ebp), %ecx'.format(offset_out))
                for delta in range(0, size, 4):
                    print('  mov {}(%ebp), %edx'.format(offset_in + delta))
                    print('  mov %edx, {}(%ecx)'.format(delta))
                offset_in += roundup(size_wide, 8)
                offset_out += 4

            print('1:')

        print('  pop %ebp')
        print('  ret')
463
464
class AsmVdsoX86_64Generator(AsmVdsoCommonGenerator):
    """Prints system call stubs in x86-64 assembly."""

    # (function call register, system call register) pairs. Only the
    # fourth pair differs: cx is remapped to 10.
    REGISTERS_PARAMS = [
        ('di', 'di'), ('si', 'si'), ('dx', 'dx'), ('cx', '10'),
        ('8', '8'), ('9', '9'),
    ]
    REGISTERS_RETURNS = ['ax', 'dx']
    REGISTERS_SPARE = ['cx', 'si', 'di', '8', '9', '10', '11']

    def __init__(self):
        super().__init__(function_alignment='4, 0x90', type_character='@')

    @staticmethod
    def register_align(member):
        """Arguments need no extra register alignment."""
        return 1

    @staticmethod
    def register_count(member):
        """Return the number of 8-byte registers needed for member."""
        size = member.type.layout.size[1]
        return howmany(size, 8)

    @staticmethod
    def print_remap_register(reg_old, reg_new):
        """Copy an argument from reg_old into reg_new."""
        print('  mov %r{}, %r{}'.format(reg_old, reg_new))

    @staticmethod
    def print_push_addresses(regs):
        """Push every address register, in the order given."""
        for reg in regs:
            print('  push %r{}'.format(reg))

    @staticmethod
    def print_syscall(number):
        """Load the system call number into %eax and issue syscall."""
        print('  mov ${}, %eax'.format(number))
        print('  syscall')

    @staticmethod
    def print_pop_addresses(regs):
        """Pop the saved addresses into regs, last pushed first."""
        for reg in regs[::-1]:
            print('  pop %r{}'.format(reg))

    @staticmethod
    def print_jump_syscall_failed(label):
        """Jump to label on failure, using a jc instruction."""
        print('  jc ' + label)

    @staticmethod
    def print_load_address_from_stack(slot, reg):
        """Load the output address in stack slot into reg."""
        # The extra 8 bytes skip the word at the top of the stack.
        displacement = slot * 8 + 8
        print('  mov {}(%rsp), %r{}'.format(displacement, reg))

    @staticmethod
    def print_store_output(member, reg_from, reg_to, index):
        """Store return register reg_from at the address held in reg_to."""
        assert index == 0
        # 4-byte values use the %e view of the register, 8-byte values
        # the %r view. Other sizes are not supported.
        prefix = {4: '%e', 8: '%r'}[member.type.layout.size[1]]
        print('  mov {}{}, (%r{})'.format(prefix, reg_from, reg_to))

    @staticmethod
    def print_retval_success():
        """Zero the return value and define the label jumped to on failure."""
        print('  xor %eax, %eax')
        print('1:')

    @staticmethod
    def print_return():
        """Return to the caller."""
        print('  ret')
527