1#! /usr/bin/env python
2## vim:set ts=4 sw=4 et: -*- coding: utf-8 -*-
3#
4#  cleanasm.py --
5#
6#  This file is part of the UPX executable compressor.
7#
8#  Copyright (C) 1996-2020 Markus Franz Xaver Johannes Oberhumer
9#  All Rights Reserved.
10#
11#  UPX and the UCL library are free software; you can redistribute them
12#  and/or modify them under the terms of the GNU General Public License as
13#  published by the Free Software Foundation; either version 2 of
14#  the License, or (at your option) any later version.
15#
16#  This program is distributed in the hope that it will be useful,
17#  but WITHOUT ANY WARRANTY; without even the implied warranty of
18#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
19#  GNU General Public License for more details.
20#
21#  You should have received a copy of the GNU General Public License
22#  along with this program; see the file COPYING.
23#  If not, write to the Free Software Foundation, Inc.,
24#  59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
25#
26#  Markus F.X.J. Oberhumer              Laszlo Molnar
27#  <markus@oberhumer.com>               <ezerotven+github@gmail.com>
28#
29
30
31import getopt, os, re, string, sys
32
33
class opts:
    """Script-wide settings, kept as class attributes (used as a namespace).

    main() adjusts ``verbose`` (-v / -q) and ``label_prefix``
    (--label-prefix); the remaining attributes are 0/1 switches that
    enable the individual optimizer rewrites.
    """
    # settings changed from the command line
    verbose = 0             # incremented by -v, decremented by -q
    label_prefix = ".L"     # prefix of the renumbered local labels
    # optimizer switches: 1 = enabled, 0 = disabled
    mov_rewrite = 1
    loop_rewrite = 1
    call_rewrite = 1
    auto_inline = 1
42
43
# External helper symbol -> [replacement macro name, max. number of calls].
# main() (pass 3) replaces a 2-byte call to one of these helpers by the
# macro when the helper is referenced at most that many times; a large
# limit such as 999 effectively means "always inline".
inline_map = dict(
    (helper, [macro, max_calls])
    for helper, macro, max_calls in (
        ("__aNNalshl", "M_aNNalshl", 1),
        ("__aNahdiff", "M_aNahdiff", 1),
        ("__PIA",      "M_PIA",      999),
        ("__PTS",      "M_PTS",      999),
        ("__PTC",      "M_PTC",      999),
        ("__U4M",      "M_U4M",      999),
    )
)
52
53
54# /***********************************************************************
55# // main
56# ************************************************************************/
57
58def main(argv):
59    shortopts, longopts = "qv", [
60        "label-prefix=", "quiet", "verbose"
61    ]
62    xopts, args = getopt.gnu_getopt(argv[1:], shortopts, longopts)
63    for opt, optarg in xopts:
64        if 0: pass
65        elif opt in ["-q", "--quiet"]: opts.verbose = opts.verbose - 1
66        elif opt in ["-v", "--verbose"]: opts.verbose = opts.verbose + 1
67        elif opt in ["--label-prefix"]: opts.label_prefix = optarg
68        else: assert 0, ("getopt problem:", opt, optarg, xopts, args)
69
70    #
71    assert opts.label_prefix
72    assert len(args) == 2
73    ifile = args[0]
74    ofile = args[1]
75    # read ifile
76    lines = open(ifile, "rb").readlines()
77    lines = filter(None, map(string.rstrip, lines))
78    #
79    #
80    def inst_has_label(inst):
81        return inst in [
82            "call", "ja", "jae", "jb", "jbe", "jcxz", "je",
83            "jg", "jge", "jl", "jle", "jmp", "jne", "loop",
84        ]
85    labels = {}
86    def parse_label(inst, args):
87        k = v = None
88        m = re.search(r"^(.*?)\b(2|R_386_PC16)\s+(__\w+)$", args)
89        if m and k is None:
90            # external 2-byte label
91            k, v = m.group(3).strip(), [1, 2, None, 0]
92        m = re.search("^0x([0-9a-z]+)$", args)
93        if m and k is None:
94            # local label
95            k, v = m.group(1).strip(), [0, 0, None, 0]
96        m = re.search("^([0-9a-z]+)\s+<", args)
97        if m and k is None:
98            # local label
99            k, v = m.group(1).strip(), [0, 0, None, 0]
100        assert k and v, (inst, args)
101        v[2] = k                # new name
102        if labels.has_key(k):
103            assert labels[k][:2] == v[:2]
104        return k, v
105    def add_label(k, v):
106        if labels.has_key(k):
107            assert labels[k][:2] == v[:2]
108        else:
109            labels[k] = v
110        labels[k][3] += 1       # usage counter
111        return k
112
113    olines = []
114    def omatch(pos, mlen, m, debug=0):
115        assert len(m) >= abs(mlen)
116        def sgn(x):
117            if x < 0: return -1
118            if x > 0: return  1
119            return 0
120        def match(a, b):
121            if b is None:
122                return False
123            if "^" in a or "*" in a or "$" in a:
124                # regexp
125                return re.search(a, b.lower())
126            else:
127                return a.lower() == b.lower()
128        mpos = []
129        while len(mpos) != abs(mlen):
130            if pos < 0 or pos >= len(olines):
131                return []
132            o = olines[pos]
133            if o[1] != "*DEL*":
134                mpos.append(pos)
135            pos += sgn(mlen)
136        if mlen < 0:
137            mpos.reverse()
138        if debug and 1: print mlen, m, [olines[x] for x in mpos]
139        dpos = []
140        i = -abs(mlen)
141        while i < 0:
142            pos = mpos[i]
143            o = olines[pos]
144            assert o[1] != "*DEL*"
145            assert len(m[i]) == 2, (i, m)
146            m0 = match(m[i][0], o[1])
147            m1 = match(m[i][1], o[2])
148            if not m0 or not m1:
149                return []
150            dpos.append([pos, m0, m1])
151            i += 1
152        assert len(dpos) == abs(mlen)
153        return dpos
154    def orewrite_inst(i, inst, args, dpos):
155        for pos, m0, m1 in dpos:
156            olines[pos][1] = "*DEL*"
157        olines[i][1] = inst
158        olines[i][2] = args
159        olines[i][3] = None
160    def orewrite_call(i, k, v, dpos):
161        for pos, m0, m1 in dpos:
162            olines[pos][1] = "*DEL*"
163        v[2] = k
164        olines[i][2] = None
165        olines[i][3] = add_label(k, v)
166
167    #
168    # pass 1
169    func = None
170    for i in range(len(lines)):
171        l = lines[i]
172        m = re.search(r"^0{8,16}\s*<(\.text\.)?(\w+)>:", l)
173        if m:
174            func = re.sub(r"^_+|_+$", "", m.group(2))
175        if not func in ["LzmaDecode"]:
176            continue
177        m = re.search(r"^(\s*[0-9a-z]+):\s+(\w+)(.*)", l)
178        if not m:
179            continue
180        label = m.group(1).strip()
181        inst = m.group(2).strip()
182        args = ""
183        if m.group(3):
184            args = m.group(3).strip()
185        if not inst_has_label(inst):
186            def hex2int(m): return str(int(m.group(0), 16))
187            args = re.sub(r"\b0x[0-9a-fA-F]+\b", hex2int, args)
188        #
189        if 1 and inst in ["movl",] and re.search(r"\b[de]s\b", args):
190            # work around a bug in objdump 2.17 (fixed in binutils 2.18)
191            inst = "mov"
192        m = re.search(r"^(.+?)\b(0|0x0)\s+(\w+):\s+(1|2|R_386_16|R_386_PC16)\s+(__\w+)$", args)
193        if m:
194            # 1 or 2 byte reloc
195            args = m.group(1) + m.group(5)
196        olines.append([label, inst, args, None])
197    #
198    # pass 2
199    for i in range(len(olines)):
200        label, inst, args, args_label = olines[i]
201        #
202        if inst == "*DEL*":
203            continue
204        #
205        if opts.call_rewrite and inst in ["call"]:
206            k, v = parse_label(inst, args)
207            if v[:2] == [1, 2]:     # external 2-byte
208                if k == "__aNahdiff":
209                    s = [
210                        ["push", "word ptr [bp+8]"],
211                        ["push", "word ptr [bp+6]"],
212                        ["push", r"word ptr \[bp([+-](\d+))\]$"],
213                        ["push", r"word ptr \[bp([+-](\d+))\]$"],
214                    ]
215                    dpos = omatch(i-1, -4, s)
216                    if dpos:
217                        orewrite_inst(i, "*DEL*", "", dpos)
218                        continue
219                if k in ["__LMUL", "__U4M",]:
220                    s1 = [
221                        ["mov",  "bx,768"],     # 0x300
222                        ["xor",  "cx,cx"],
223                    ]
224                    s2 = [
225                        ["shl",  "ax,1"],
226                        ["rcl",  "dx,1"],
227                    ]
228                    dpos1 = omatch(i-1, -2, s1)
229                    dpos2 = omatch(i+1,  2, s2)
230                    if dpos1 and dpos2:
231                        orewrite_inst(i, "M_U4M_dxax_0x0600", "", dpos1 + dpos2)
232                        continue
233                    s = [
234                        ["mov",  "bx,word ptr [bx]"],
235                        ["xor",  "cx,cx"],
236                    ]
237                    dpos = omatch(i-1, -2, s, debug=0)
238                    if 0 and dpos:
239                        orewrite_inst(i, "M_U4M_dxax_00bx_ptr", "", dpos)
240                        continue
241                    dpos = omatch(i-1, -1, s)
242                    if dpos:
243                        orewrite_inst(i, "M_U4M_dxax_00bx", "", dpos)
244                        continue
245                if k == "__PIA":
246                    s = [
247                        ["mov",  "bx,1"],
248                        ["xor",  "cx,cx"],
249                    ]
250                    dpos = omatch(i-1, -2, s)
251                    if dpos:
252                        orewrite_inst(i, "M_PIA1", "", dpos)
253                        continue
254                if k == "__PTC":
255                    s = [
256                        ["jne",  "(.*)"],
257                    ]
258                    dpos = omatch(i+1, 1, s)
259                    if dpos:
260                        olines[i][1] = "M_PTC_JNE"
261                        k, v = parse_label("jne", dpos[0][2].group(1))
262                        orewrite_call(i, k, v, dpos)
263                        continue
264        if opts.loop_rewrite and inst in ["loop"]:
265            s = [
266                ["mov",  r"^c[lx],11$"],
267                ["shr",  "dx,1"],
268                ["rcr",  "ax,1"],
269            ]
270            dpos = omatch(i-1, -3, s)
271            if dpos:
272                orewrite_inst(i, "M_shrd_11", "", dpos)
273                continue
274            s = [
275                ["mov",  r"^c[lx],8$"],
276                ["shl",  "ax,1"],
277                ["rcl",  "dx,1"],
278            ]
279            dpos = omatch(i-1, -3, s)
280            if dpos:
281                orewrite_inst(i, "M_shld_8", "", dpos)
282                continue
283            s1 = [
284                ["mov",  r"^c[lx],8$"],
285                ["shl",  "si,1"],
286                ["rcl",  "di,1"],
287            ]
288            s2 = [
289                ["les",  r"^bx,dword ptr \[bp([+-](\d+))\]$"],
290            ]
291            dpos1 = omatch(i-1, -3, s1)
292            dpos2 = omatch(i+1,  1, s2)
293            if 1 and dpos1 and dpos2:
294                # bx and cx are free for use
295                orewrite_inst(i, "M_shld_disi_8_bxcx", "", dpos1)
296                continue
297            s1 = [
298                ["mov",  "ax,si"],
299                ["mov",  r"^c[lx],8$"],
300                ["shl",  "ax,1"],
301                ["rcl",  "di,1"],
302            ]
303            s2 = [
304                ["mov",  "si,ax"],
305                ["les",  r"^bx,dword ptr \[bp([+-](\d+))\]$"],
306            ]
307            dpos1 = omatch(i-1, -4, s1)
308            dpos2 = omatch(i+1,  2, s2)
309            if 1 and dpos1 and dpos2:
310                # bx and cx are free for use
311                orewrite_inst(i, "M_shld_diax_8_bxcx", "", dpos1[-3:])
312                continue
313            s1 = [
314                ["mov",  r"^c[lx],8$"],
315                ["shl",  r"^word ptr \[bp([+-](\d+))\],1$"],
316                ["rcl",  r"^word ptr \[bp([+-](\d+))\],1$"],
317            ]
318            s2 = [
319                ["mov",  r"^dx,word ptr"],
320                ["mov",  r"^ax,word ptr"],
321            ]
322            s3 = [
323                ["mov",  r"^ax,word ptr"],
324                ["mov",  r"^dx,word ptr"],
325            ]
326            dpos1 = omatch(i-1, -3, s1)
327            dpos2 = omatch(i+1,  2, s2)
328            dpos3 = omatch(i+1,  2, s3)
329            if dpos1 and (dpos2 or dos3):
330                bp_dx, bp_ax = dpos1[-1][2].group(1), dpos1[-2][2].group(1)
331                m = "M_shld_8_bp %s %s" % (bp_dx, bp_ax)
332                orewrite_inst(i, m, "", dpos1)
333                continue
334            s1 = [
335                ["mov",  r"^word ptr \[bp([+-](\d+))\],si$"],
336                ["mov",  r"^word ptr \[bp([+-](\d+))\],di$"],
337                ["mov",  r"^c[lx],11$"],
338                ["shr",  r"^word ptr \[bp([+-](\d+))\],1$"],
339                ["rcr",  r"^word ptr \[bp([+-](\d+))\],1$"],
340            ]
341            s2 = [
342                ["mov",  r"^bx,word ptr"],
343                ["mov",  r"^bx,word ptr"],
344                ["mov",  r"^ax,word ptr \[bp([+-](\d+))\]$"],
345                ["mov",  r"^dx,word ptr \[bp([+-](\d+))\]$"],
346            ]
347            dpos1 = omatch(i-1, -5, s1)
348            dpos2 = omatch(i+1,  4, s2)
349            if dpos1 and dpos2:
350                bp_dx, bp_ax = dpos1[-2][2].group(1), dpos1[-1][2].group(1)
351                bp_di, bp_si = dpos1[-4][2].group(1), dpos1[-5][2].group(1)
352                assert bp_dx == dpos2[-1][2].group(1)
353                assert bp_ax == dpos2[-2][2].group(1)
354                assert bp_dx == bp_di
355                assert bp_ax == bp_si
356                m = "M_shrd_11_disi_bp %s %s" % (bp_dx, bp_ax)
357                orewrite_inst(i, m, "", dpos1 + dpos2[-2:])
358                continue
359        if opts.mov_rewrite and inst in ["mov"]:
360            s = [
361                ["mov",  r"^al,byte ptr \[(di|si)\]$"],
362                ["xor",  r"^ah,ah$"],
363                ["mov",  r"^word ptr \[bp([+-](\d+))\],ax$"],
364                ["mov",  r"^word ptr \[bp([+-](\d+))\],(0|1)$"],
365                ["mov",  r"^word ptr \[bp([+-](\d+))\],(0|1)$"],
366                ["mov",  r"^word ptr \[bp([+-](\d+))\],(0|1)$"],
367                ["mov",  r"^word ptr \[bp([+-](\d+))\],(0|1)$"],
368                ["mov",  r"^word ptr \[bp([+-](\d+))\],(0|1)$"],
369                ["mov",  r"^word ptr \[bp([+-](\d+))\],(0|1)$"],
370                ["mov",  r"^word ptr \[bp([+-](\d+))\],(0|1)$"],
371                ["mov",  r"^word ptr \[bp([+-](\d+))\],(0|1)$"],
372                ["mov",  r"^word ptr \[bp([+-](\d+))\],(0|1)$"],
373                ["mov",  r"^bx,word ptr \[bp([+-](\d+))\]$"],
374                ["mov",  r"^word ptr \[bx\],(0)$"],
375                ["mov",  r"^word ptr \[bx([+-](\d+))\],(0)$"],
376                ["mov",  r"^bx,word ptr \[bp([+-](\d+))\]$"],
377                ["mov",  r"^word ptr \[bx\],(0)$"],
378                ["mov",  r"^word ptr \[bx([+-](\d+))\],(0)$"],
379                ["mov",  r"^dl,byte ptr \[(di|si)([+-](\d+))\]$"],
380                ["xor",  r"^dh,dh$"],
381                ["mov",  r"^cx,ax$"],
382            ]
383            dpos = omatch(i, -len(s), s)
384            if dpos:
385                ipos, n_del = 16, 0
386                pos0 = dpos[0][0]
387                r = []
388                for pos, m0, m1 in dpos:
389                    assert pos == pos0 + len(r)
390                    r.append([olines[pos][1], olines[pos][2]])
391                z0 = r[0]; z1 = r[2]; del r[:3]
392                r.insert(0, ["xor", "ax,ax"])
393                r.insert(ipos, z0); r.insert(ipos + 1, z1)
394                i = 0
395                while i < len(r):
396                    inst, args = r[i]
397                    if inst == "mov" and args.endswith(",0"):
398                        r[i] = [inst, args[:-1] + "ax"]
399                    elif inst == "mov" and args.endswith(",1"):
400                        assert i < ipos
401                        r.insert(ipos, [inst, args[:-1] + "ax"])
402                        del r[i]; i -= 1; n_del += 1
403                    i += 1
404                assert len(r) == len(dpos)
405                pos = pos0
406                for inst, args in r:
407                    ##print pos-pos0, inst, args
408                    olines[pos][1] = inst
409                    olines[pos][2] = args
410                    pos += 1
411                if n_del:
412                    olines.insert(pos0 + ipos - n_del, [None, "inc", "ax", None])
413                continue
414        #
415        if inst_has_label(inst):
416            k, v = parse_label(inst, args)
417            olines[i][2] = None
418            olines[i][3] = add_label(k, v)
419    #
420    # pass 3
421    digits, i = 1, len(labels)
422    while i >= 10:
423        digits += 1
424        i /= 10
425    format = "%s0%dd" % ("%", digits)
426    counter = 0
427    for i in range(len(olines)):
428        label, inst, args, args_label = olines[i]
429        # rewrite local labels
430        v = labels.get(label)
431        if v is not None:
432            assert v[:3] == [0, 0, label], (label, v)
433            v[2] = opts.label_prefix + format % counter
434            counter += 1
435        # handle inlining
436        if opts.auto_inline and inst == "call":
437            v = labels[args_label]
438            if v[:2] == [1, 2]:     # external 2-byte
439                x = inline_map.get(v[2])
440                if x and v[3] <= x[1]:       # max. number of calls
441                    ##print "inline", v, x
442                    if x:
443                        olines[i][1] = x[0]
444                        olines[i][2] = "/* inlined */"
445                        olines[i][2] = ""
446                        olines[i][3] = None
447    #
448    # write ofile
449    ofp = open(ofile, "wb")
450    current_label = None
451    for label, inst, args, args_label in olines:
452        if labels.has_key(label):
453            current_label = labels[label][2]
454            if opts.verbose:
455                ofp.write("%s: /* %d */\n" % (labels[label][2], labels[label][3]))
456            else:
457                ofp.write("%s:\n" % (labels[label][2]))
458        if inst == "*DEL*":
459            continue
460        if 1 and current_label in [".Lf122", ".Lf123", ".Lf124", ".Ls122", ".Ls123", ".Ls124"]:
461            continue
462        if args_label:
463            if opts.verbose:
464                args = "%s /* %d */" % (labels[args_label][2], labels[args_label][3])
465            else:
466                args = labels[args_label][2]
467        if 0:
468            # remove unneeded "byte/word/dword ptr"
469            # [this works, but disabled for now as we gain nothing]
470            if re.search(r"\bbyte ptr ", args):
471                if re.search(r"^[abcd][hl],", args): args = args.replace("byte ptr ", "")
472                if re.search(r",[abcd][hl]$", args): args = args.replace("byte ptr ", "")
473            if re.search(r"\bword ptr ", args):
474                if re.search(r"^[abcds][ix],", args): args = args.replace("word ptr ", "")
475                if re.search(r",[abcds][ix]$", args): args = args.replace("word ptr ", "")
476            if re.search(r"\bdword ptr ", args):
477                if re.search(r"^[abcd][x],",  args): args = args.replace("dword ptr ", "")
478        l = "%8s%-7s %s" % ("", inst, args)
479        ofp.write(l.rstrip() + "\n")
480    ofp.close()
481    ##print olines
482
483
if __name__ == "__main__":
    # main() returns None, which sys.exit() reports as exit status 0.
    exit_status = main(sys.argv)
    sys.exit(exit_status)
486
487