# Assemble Arybo IR into ASM thanks to LLVM
# Maps symbol names to registers.
# Warning: this does its best to save modified temporary registers, but
# errors are still possible. The goal is not to regenerate clean
# binaries, but to help the reverser!
6
7try:
8    import llvmlite.ir as ll
9    import llvmlite.binding as llvm
10    import ctypes
11    llvmlite_available = True
12    __llvm_initialized = False
13except ImportError:
14    llvmlite_available = False
15
import collections
import collections.abc

import six

import arybo.lib.mba_exprs as EX
from arybo.lib.exprs_passes import lower_rol_ror, CachePass
21
def IntType(n):
    """Return an ``llvmlite.ir.IntType`` of width *n* (coerced to ``int``)."""
    width = int(n)
    return ll.IntType(width)
24
class ToLLVMIr(CachePass):
    """Expression visitor that lowers an Arybo expression graph to LLVM IR.

    Inherits :class:`CachePass` so shared sub-expressions are only
    translated once. ``sym_to_value`` maps Arybo bit-vector names to the
    LLVM values that hold them; ``IRB`` is the ``llvmlite.ir.IRBuilder``
    through which every instruction is emitted.
    """

    def __init__(self, sym_to_value, IRB):
        # sym_to_value: dict of symbol name -> LLVM value (e.g. a function
        # argument or a value loaded from a register).
        # IRB: llvmlite IRBuilder already positioned where code must be
        # emitted.
        super(ToLLVMIr,self).__init__()
        self.IRB = IRB
        self.sym_to_value = sym_to_value
        self.values = {}

    def visit_wrapper(self, e, cb):
        # Normalize every visit result to a (value, basic_block) pair. Most
        # visitors return a bare value emitted in the current block;
        # visit_Cond returns the pair itself because its resulting phi node
        # lives in a newly created "endif" block.
        ret = super(ToLLVMIr, self).visit_wrapper(e, cb)
        if not isinstance(ret, tuple):
            return (ret,self.IRB.block)
        else:
            return ret

    def visit_value(self, e):
        """Visit *e* and return only the LLVM value (drop the basic block)."""
        return EX.visit(e, self)[0]

    def visit_Cst(self, e):
        # Integer constant, sized to the expression's bit width.
        return ll.Constant(IntType(e.nbits), e.n)

    def visit_BV(self, e):
        # Bit-vector symbol: resolved through the sym_to_value mapping
        # provided at construction time.
        name = e.v.name
        value = self.sym_to_value.get(name, None)
        if value is None:
            raise ValueError("unable to map BV name '%s' to an LLVM value!" % name)
        # TODO: check value bit-size
        #ret,nbits = value
        #if e.nbits != nbits:
        #    raise ValueError("bit-vector is %d bits, expected %d bits" % (e.nbits, nbits))
        return value

    def visit_Not(self, e):
        # Bitwise negation.
        return self.IRB.not_(self.visit_value(e.arg))

    def visit_ZX(self, e):
        # Zero-extend the argument to e.n bits.
        return self.IRB.zext(self.visit_value(e.arg), IntType(e.n))

    def visit_SX(self, e):
        # Sign-extend the argument to e.n bits.
        return self.IRB.sext(self.visit_value(e.arg), IntType(e.n))

    def visit_Concat(self, e):
        # Generate a suite of OR + shifts: each argument is zero-extended to
        # the final width, shifted left past the bits already placed, and
        # OR'ed in. args[0] occupies the least-significant bits.
        # TODO: pass that lowers concat
        arg0 = e.args[0]
        ret = self.visit_value(arg0)
        type_ = IntType(e.nbits)
        ret = self.IRB.zext(ret, type_)
        cur_bits = arg0.nbits
        for a in e.args[1:]:
            cur_arg = self.IRB.zext(self.visit_value(a), type_)
            ret = self.IRB.or_(ret,
                self.IRB.shl(cur_arg, ll.Constant(type_, cur_bits)))
            cur_bits += a.nbits
        return ret

    def visit_Slice(self, e):
        # Bit-slice: logical shift right to drop the low bits, then truncate
        # to the slice width.
        # TODO: pass that lowers slice
        ret = self.visit_value(e.arg)
        idxes = e.idxes
        # Support only sorted indxes for now
        if idxes != list(range(idxes[0], idxes[-1]+1)):
            raise ValueError("slice indexes must be continuous and sorted")
        if idxes[0] != 0:
            ret = self.IRB.lshr(ret, ll.Constant(IntType(e.arg.nbits), idxes[0]))
        return self.IRB.trunc(ret, IntType(len(idxes)))

    def visit_Broadcast(self, e):
        # TODO: pass that lowers broadcast
        # left-shift to get the idx as the MSB, and them use an arithmetic
        # right shift of nbits-1 to replicate that bit across the whole
        # value.
        type_ = IntType(e.nbits)
        ret = self.visit_value(e.arg)
        ret = self.IRB.zext(ret, type_)
        ret = self.IRB.shl(ret, ll.Constant(type_, e.nbits-e.idx-1))
        return self.IRB.ashr(ret, ll.Constant(type_, e.nbits-1))

    def visit_nary_args(self, e, op):
        """Apply IRBuilder operation *op* to all visited arguments of *e*."""
        return op(*(self.visit_value(a) for a in e.args))

    def visit_BinaryOp(self, e):
        # Dispatch arithmetic/shift binary operators to the matching
        # IRBuilder method.
        ops = {
            EX.ExprAdd: self.IRB.add,
            EX.ExprSub: self.IRB.sub,
            EX.ExprMul: self.IRB.mul,
            EX.ExprShl: self.IRB.shl,
            EX.ExprLShr: self.IRB.lshr,
            EX.ExprAShr: self.IRB.ashr
        }
        op = ops[type(e)]
        return self.visit_nary_args(e, op)

    def visit_Div(self, e):
        # Signedness selects between sdiv and udiv.
        return self.visit_nary_args(e, self.IRB.sdiv if e.is_signed else self.IRB.udiv)

    def visit_Rem(self, e):
        # Signedness selects between srem and urem.
        return self.visit_nary_args(e, self.IRB.srem if e.is_signed else self.IRB.urem)

    def visit_NaryOp(self, e):
        # Dispatch n-ary bitwise operators to the matching IRBuilder method.
        ops = {
            EX.ExprXor: self.IRB.xor,
            EX.ExprAnd: self.IRB.and_,
            EX.ExprOr: self.IRB.or_,
        }
        op = ops[type(e)]
        return self.visit_nary_args(e, op)

    def visit_Cmp(self, e):
        # Integer comparison; signedness selects the icmp flavour.
        f = self.IRB.icmp_signed if e.is_signed else self.IRB.icmp_unsigned
        cmp_op = {
            EX.ExprCmp.OpEq:  '==',
            EX.ExprCmp.OpNeq: '!=',
            EX.ExprCmp.OpLt:  '<',
            EX.ExprCmp.OpLte: '<=',
            EX.ExprCmp.OpGt:  '>',
            EX.ExprCmp.OpGte: '>='
        }
        return f(cmp_op[e.op], self.visit_value(e.X), self.visit_value(e.Y))

    def visit_Cond(self, e):
        # Lower a conditional into an if/else/endif diamond with a phi node.
        # Note: EX.visit is used (not visit_value) for the branches so we get
        # back the basic block each branch value was defined in, which the
        # phi node needs as its incoming block.
        cond = self.visit_value(e.cond)
        bb_name = self.IRB.basic_block.name
        ifb = self.IRB.append_basic_block(bb_name + ".if")
        elseb = self.IRB.append_basic_block(bb_name + ".else")
        endb = self.IRB.append_basic_block(bb_name + ".endif")
        self.IRB.cbranch(cond, ifb, elseb)

        self.IRB.position_at_end(ifb)
        ifv,ifb = EX.visit(e.a, self)
        self.IRB.branch(endb)

        self.IRB.position_at_end(elseb)
        elsev,elseb = EX.visit(e.b, self)
        self.IRB.branch(endb)

        self.IRB.position_at_end(endb)
        ret = self.IRB.phi(IntType(e.nbits))
        ret.add_incoming(ifv, ifb)
        ret.add_incoming(elsev, elseb)
        # Returns the pair explicitly: the phi value lives in endb, not in
        # the block that was current when visit_Cond was entered.
        return ret,endb
164
def llvm_get_target(triple_or_target=None):
    """Resolve *triple_or_target* into an ``llvm.Target``.

    Accepts an existing ``llvm.Target`` (returned unchanged), a triple
    string, or ``None`` to use the host's default triple. LLVM itself is
    initialised lazily on the first call.
    """
    global __llvm_initialized
    if not __llvm_initialized:
        # One-time LLVM bootstrap: core, targets and asm printers.
        llvm.initialize()
        llvm.initialize_all_targets()
        llvm.initialize_all_asmprinters()
        __llvm_initialized = True

    if isinstance(triple_or_target, llvm.Target):
        return triple_or_target
    if triple_or_target is None:
        return llvm.Target.from_default_triple()
    return llvm.Target.from_triple(triple_or_target)
179
def _create_execution_engine(M, target):
    """Build an MCJIT execution engine for module *M* on *target*."""
    return llvm.create_mcjit_compiler(M, target.create_target_machine())
184
def to_llvm_ir(exprs, sym_to_value, IRB):
    """Emit LLVM IR for *exprs* through the builder *IRB*.

    *exprs* may be a single expression or an iterable of expressions;
    *sym_to_value* maps Arybo symbol names to existing LLVM values.
    Returns the LLVM value of the last translated expression (or ``None``
    if *exprs* is empty).
    """
    if not llvmlite_available:
        raise RuntimeError("llvmlite module unavailable! can't assemble to LLVM IR...")

    if not isinstance(exprs, collections.abc.Iterable):
        exprs = (exprs,)

    visitor = ToLLVMIr(sym_to_value, IRB)
    ret = None
    for expr in exprs:
        # rol/ror have no direct LLVM equivalent; lower them first.
        ret = visitor.visit_value(lower_rol_ror(expr))
    return ret
198
def to_llvm_function(exprs, vars_, name="__arybo"):
    """Create an LLVM module containing one function that computes *exprs*.

    Arguments:
      * *exprs*: a single expression or an iterable of expressions; the
        generated function returns the value of the last one.
      * *vars_*: Arybo variables, in order, that become the function's
        arguments; their names are bound to the expressions' symbols.
      * *name*: name of the generated LLVM function.

    Returns the ``llvmlite.ir.Module`` holding the function.

    Raises:
      * RuntimeError: llvmlite is not installed.
      * ValueError: *exprs* is empty.
    """
    if not llvmlite_available:
        raise RuntimeError("llvmlite module unavailable! can't assemble to LLVM IR...")

    if not isinstance(exprs, collections.abc.Iterable):
        exprs = (exprs,)
    # Materialize so that generators/sets can be indexed below (exprs[-1])
    # and safely iterated again inside to_llvm_ir. Fixes a TypeError when a
    # generator was passed, and gives a clear error on empty input instead
    # of a bare IndexError.
    exprs = tuple(exprs)
    if not exprs:
        raise ValueError("at least one expression is required")

    M = ll.Module()
    args_types = [IntType(v.nbits) for v in vars_]
    # The function's return width is that of the last expression.
    fntype = ll.FunctionType(IntType(exprs[-1].nbits), args_types)
    func = ll.Function(M, fntype, name=name)
    func.attributes.add("nounwind")
    BB = func.append_basic_block()

    IRB = ll.IRBuilder()
    IRB.position_at_end(BB)

    # Name the LLVM arguments after the Arybo variables and expose them to
    # the translator.
    sym_to_value = {}
    for v, arg in zip(vars_, func.args):
        arg.name = v.name
        sym_to_value[v.name] = arg
    ret = to_llvm_ir(exprs, sym_to_value, IRB)
    IRB.ret(ret)
    return M
224
def asm_module(exprs, dst_reg, sym_to_reg, triple_or_target=None):
    '''
    Generate an LLVM module for a list of expressions

    Arguments:
      * See :meth:`arybo.lib.exprs_asm.asm_binary` for a description of the list of arguments

    Output:
      * An LLVM module with one function named "__arybo", containing the
        translated expression.

    See :meth:`arybo.lib.exprs_asm.asm_binary` for an usage example.
    '''
    if not llvmlite_available:
        raise RuntimeError("llvmlite module unavailable! can't assemble...")

    # Called for its side effect: lazily initialise LLVM and validate the
    # triple. The target itself is re-resolved by asm_binary.
    llvm_get_target(triple_or_target)

    mod = ll.Module()
    fn = ll.Function(mod, ll.FunctionType(ll.VoidType(), []), name='__arybo')
    for attr in ("naked", "nounwind"):
        fn.attributes.add(attr)

    builder = ll.IRBuilder()
    builder.position_at_end(fn.append_basic_block())

    # Seed every Arybo symbol with the value currently held by its register.
    # reg is a ("name", size_bits) tuple.
    sym_to_value = {}
    for sym, reg in six.iteritems(sym_to_reg):
        sym_to_value[sym] = builder.load_reg(IntType(reg[1]), reg[0], reg[0])

    result = to_llvm_ir(exprs, sym_to_value, builder)
    builder.store_reg(result, IntType(dst_reg[1]), dst_reg[0])
    # A naked function cannot end with a regular 'ret'.
    # See https://llvm.org/bugs/show_bug.cgi?id=15806
    builder.unreachable()

    return mod
262
def asm_binary(exprs, dst_reg, sym_to_reg, triple_or_target=None):
    '''
    Compile and assemble an expression for a given architecture.

    Arguments:
      * *exprs*: list of expressions to convert. This can represent a graph of
        expressions.
      * *dst_reg*: final register on which to store the result of the last
        expression. This is represented by a tuple ("reg_name", reg_size_bits).
        Example: ("rax", 64)
      * *sym_to_reg*: a dictionary that maps Arybo variable names to registers
        (described as tuples, see *dst_reg*). Example: {"x": ("rdi",64), "y": ("rsi", 64)}
      * *triple_or_target*: LLVM architecture triple to use. Use by default the
        host architecture. Example: "x86_64-unknown-unknown"

    Output:
      * binary stream of the assembled expression for the given target

    Here is an example that will compile and assemble "x+y" for x86_64::

        from arybo.lib import MBA
        from arybo.lib import mba_exprs
        from arybo.lib.exprs_asm import asm_binary
        mba = MBA(64)
        x = mba.var("x")
        y = mba.var("y")
        e = mba_exprs.ExprBV(x) + mba_exprs.ExprBV(y)
        code = asm_binary([e], ("rax", 64), {"x": ("rdi", 64), "y": ("rsi", 64)}, "x86_64-unknown-unknown")
        print(code.hex())

    which outputs ``488d0437`` (which is equivalent to ``lea rax,[rdi+rsi*1]``).
    '''
    if not llvmlite_available:
        raise RuntimeError("llvmlite module unavailable! can't assemble...")

    target = llvm_get_target(triple_or_target)
    mod = asm_module(exprs, dst_reg, sym_to_reg, target)

    # Let LLVM compile the '__arybo' function. Since the function is naked
    # and alone in the module, dumping the .text section of the resulting
    # object file yields exactly the assembled code.
    # No need for keystone or whatever hype stuff. llvmlite does the job.
    parsed = llvm.parse_assembly(str(mod))
    parsed.verify()
    machine = target.create_target_machine()
    obj = llvm.ObjectFileRef.from_data(machine.emit_object(parsed))

    for section in obj.sections():
        if section.is_text():
            return section.data()
    raise RuntimeError("unable to get the assembled binary!")
315