1package rardecode
2
3import (
4	"encoding/binary"
5	"errors"
6)
7
const (
	// vm flag bits, held in vm.fl
	flagC = 1          // Carry
	flagZ = 2          // Zero
	flagS = 0x80000000 // Sign: bit 31, so r&flagS copies the sign of a 32-bit result

	maxCommands = 25000000 // maximum number of commands that can be run in a program (execute stops once reached)

	vmRegs = 8          // number of registers
	vmSize = 0x40000    // memory size in bytes; a power of two so vmMask can wrap addresses
	vmMask = vmSize - 1 // wraps an address into [0, vmSize)
)
20
var (
	// errInvalidVMInstruction is returned by readCommands when a decoded
	// opcode has no entry in ops.
	errInvalidVMInstruction = errors.New("rardecode: invalid vm instruction")
)
24
// vm holds the state of one RAR virtual machine.
type vm struct {
	ip    uint32         // instruction pointer: index into the commands passed to execute
	ipMod bool           // ip was modified by the current instruction, so execute must not increment it
	fl    uint32         // flag bits (flagC, flagZ, flagS)
	r     [vmRegs]uint32 // registers; r[7] is the stack pointer used by push, pop, call and ret
	m     []byte         // memory
}
32
33func (v *vm) setIP(ip uint32) {
34	v.ip = ip
35	v.ipMod = true
36}
37
38// execute runs a list of commands on the vm.
39func (v *vm) execute(cmd []command) {
40	v.ip = 0 // reset instruction pointer
41	for n := 0; n < maxCommands; n++ {
42		ip := v.ip
43		if ip >= uint32(len(cmd)) {
44			return
45		}
46		ins := cmd[ip]
47		ins.f(v, ins.bm, ins.op) // run cpu instruction
48		if v.ipMod {
49			// command modified ip, don't increment
50			v.ipMod = false
51		} else {
52			v.ip++ // increment ip for next command
53		}
54	}
55}
56
57// newVM creates a new RAR virtual machine using the byte slice as memory.
58func newVM(mem []byte) *vm {
59	v := new(vm)
60
61	if cap(mem) < vmSize+4 {
62		v.m = make([]byte, vmSize+4)
63		copy(v.m, mem)
64	} else {
65		v.m = mem[:vmSize+4]
66		for i := len(mem); i < len(v.m); i++ {
67			v.m[i] = 0
68		}
69	}
70	v.r[7] = vmSize
71	return v
72}
73
// operand is an instruction argument that can be read and written.
// byteMode selects an 8-bit access instead of a 32-bit one; each
// implementation decides what that means for its storage.
type operand interface {
	get(v *vm, byteMode bool) uint32
	set(v *vm, byteMode bool, n uint32)
}
78
79// Immediate Operand
80type opI uint32
81
82func (op opI) get(v *vm, bm bool) uint32    { return uint32(op) }
83func (op opI) set(v *vm, bm bool, n uint32) {}
84
85// Direct Operand
86type opD uint32
87
88func (op opD) get(v *vm, byteMode bool) uint32 {
89	if byteMode {
90		return uint32(v.m[op])
91	}
92	return binary.LittleEndian.Uint32(v.m[op:])
93}
94
95func (op opD) set(v *vm, byteMode bool, n uint32) {
96	if byteMode {
97		v.m[op] = byte(n)
98	} else {
99		binary.LittleEndian.PutUint32(v.m[op:], n)
100	}
101}
102
103// Register  Operand
104type opR uint32
105
106func (op opR) get(v *vm, byteMode bool) uint32 {
107	if byteMode {
108		return v.r[op] & 0xFF
109	}
110	return v.r[op]
111}
112
113func (op opR) set(v *vm, byteMode bool, n uint32) {
114	if byteMode {
115		v.r[op] = (v.r[op] & 0xFFFFFF00) | (n & 0xFF)
116	} else {
117		v.r[op] = n
118	}
119}
120
121// Register Indirect Operand
122type opRI uint32
123
124func (op opRI) get(v *vm, byteMode bool) uint32 {
125	i := v.r[op] & vmMask
126	if byteMode {
127		return uint32(v.m[i])
128	}
129	return binary.LittleEndian.Uint32(v.m[i:])
130}
131func (op opRI) set(v *vm, byteMode bool, n uint32) {
132	i := v.r[op] & vmMask
133	if byteMode {
134		v.m[i] = byte(n)
135	} else {
136		binary.LittleEndian.PutUint32(v.m[i:], n)
137	}
138}
139
140// Base Plus Index Indirect Operand
141type opBI struct {
142	r uint32
143	i uint32
144}
145
146func (op opBI) get(v *vm, byteMode bool) uint32 {
147	i := (v.r[op.r] + op.i) & vmMask
148	if byteMode {
149		return uint32(v.m[i])
150	}
151	return binary.LittleEndian.Uint32(v.m[i:])
152}
153func (op opBI) set(v *vm, byteMode bool, n uint32) {
154	i := (v.r[op.r] + op.i) & vmMask
155	if byteMode {
156		v.m[i] = byte(n)
157	} else {
158		binary.LittleEndian.PutUint32(v.m[i:], n)
159	}
160}
161
// commandFunc executes one VM instruction. byteMode selects 8-bit operand
// access and op holds the instruction's decoded operands.
type commandFunc func(v *vm, byteMode bool, op []operand)

// command is a decoded VM instruction as produced by readCommands.
type command struct {
	f  commandFunc // instruction implementation taken from ops
	bm bool        // is byte mode
	op []operand   // decoded operands; nil for instructions without operands
}
169
var (
	// ops describes every VM instruction; the slice index is the opcode
	// decoded by readCommands (0-7 from a 4-bit code, 8-39 from a 6-bit one).
	// The order must not change.
	ops = []struct {
		f        commandFunc // instruction implementation
		byteMode bool        // supports byte mode
		nops     int         // number of operands
		jop      bool        // is a jump op: a lone immediate operand is rebased by fixJumpOp
	}{
		{mov, true, 2, false},    // 0
		{cmp, true, 2, false},    // 1
		{add, true, 2, false},    // 2
		{sub, true, 2, false},    // 3
		{jz, false, 1, true},     // 4
		{jnz, false, 1, true},    // 5
		{inc, true, 1, false},    // 6
		{dec, true, 1, false},    // 7
		{jmp, false, 1, true},    // 8
		{xor, true, 2, false},    // 9
		{and, true, 2, false},    // 10
		{or, true, 2, false},     // 11
		{test, true, 2, false},   // 12
		{js, false, 1, true},     // 13
		{jns, false, 1, true},    // 14
		{jb, false, 1, true},     // 15
		{jbe, false, 1, true},    // 16
		{ja, false, 1, true},     // 17
		{jae, false, 1, true},    // 18
		{push, false, 1, false},  // 19
		{pop, false, 1, false},   // 20
		{call, false, 1, true},   // 21
		{ret, false, 0, false},   // 22
		{not, true, 1, false},    // 23
		{shl, true, 2, false},    // 24
		{shr, true, 2, false},    // 25
		{sar, true, 2, false},    // 26
		{neg, true, 1, false},    // 27
		{pusha, false, 0, false}, // 28
		{popa, false, 0, false},  // 29
		{pushf, false, 0, false}, // 30
		{popf, false, 0, false},  // 31
		{movzx, false, 2, false}, // 32
		{movsx, false, 2, false}, // 33
		{xchg, true, 2, false},   // 34
		{mul, true, 2, false},    // 35
		{div, true, 2, false},    // 36
		{adc, true, 2, false},    // 37
		{sbb, true, 2, false},    // 38
		{print, false, 0, false}, // 39
	}
)
219
220func mov(v *vm, bm bool, op []operand) {
221	op[0].set(v, bm, op[1].get(v, bm))
222}
223
224func cmp(v *vm, bm bool, op []operand) {
225	v1 := op[0].get(v, bm)
226	r := v1 - op[1].get(v, bm)
227	if r == 0 {
228		v.fl = flagZ
229	} else {
230		v.fl = 0
231		if r > v1 {
232			v.fl = flagC
233		}
234		v.fl |= r & flagS
235	}
236}
237
238func add(v *vm, bm bool, op []operand) {
239	v1 := op[0].get(v, bm)
240	r := v1 + op[1].get(v, bm)
241	v.fl = 0
242	signBit := uint32(flagS)
243	if bm {
244		r &= 0xFF
245		signBit = 0x80
246	}
247	if r < v1 {
248		v.fl |= flagC
249	}
250	if r == 0 {
251		v.fl |= flagZ
252	} else if r&signBit > 0 {
253		v.fl |= flagS
254	}
255	op[0].set(v, bm, r)
256}
257
258func sub(v *vm, bm bool, op []operand) {
259	v1 := op[0].get(v, bm)
260	r := v1 - op[1].get(v, bm)
261	v.fl = 0
262
263	if r == 0 {
264		v.fl = flagZ
265	} else {
266		v.fl = 0
267		if r > v1 {
268			v.fl = flagC
269		}
270		v.fl |= r & flagS
271	}
272	op[0].set(v, bm, r)
273}
274
275func jz(v *vm, bm bool, op []operand) {
276	if v.fl&flagZ > 0 {
277		v.setIP(op[0].get(v, false))
278	}
279}
280
281func jnz(v *vm, bm bool, op []operand) {
282	if v.fl&flagZ == 0 {
283		v.setIP(op[0].get(v, false))
284	}
285}
286
287func inc(v *vm, bm bool, op []operand) {
288	r := op[0].get(v, bm) + 1
289	if bm {
290		r &= 0xFF
291	}
292	op[0].set(v, bm, r)
293	if r == 0 {
294		v.fl = flagZ
295	} else {
296		v.fl = r & flagS
297	}
298}
299
300func dec(v *vm, bm bool, op []operand) {
301	r := op[0].get(v, bm) - 1
302	op[0].set(v, bm, r)
303	if r == 0 {
304		v.fl = flagZ
305	} else {
306		v.fl = r & flagS
307	}
308}
309
310func jmp(v *vm, bm bool, op []operand) {
311	v.setIP(op[0].get(v, false))
312}
313
314func xor(v *vm, bm bool, op []operand) {
315	r := op[0].get(v, bm) ^ op[1].get(v, bm)
316	op[0].set(v, bm, r)
317	if r == 0 {
318		v.fl = flagZ
319	} else {
320		v.fl = r & flagS
321	}
322}
323
324func and(v *vm, bm bool, op []operand) {
325	r := op[0].get(v, bm) & op[1].get(v, bm)
326	op[0].set(v, bm, r)
327	if r == 0 {
328		v.fl = flagZ
329	} else {
330		v.fl = r & flagS
331	}
332}
333
334func or(v *vm, bm bool, op []operand) {
335	r := op[0].get(v, bm) | op[1].get(v, bm)
336	op[0].set(v, bm, r)
337	if r == 0 {
338		v.fl = flagZ
339	} else {
340		v.fl = r & flagS
341	}
342}
343
344func test(v *vm, bm bool, op []operand) {
345	r := op[0].get(v, bm) & op[1].get(v, bm)
346	if r == 0 {
347		v.fl = flagZ
348	} else {
349		v.fl = r & flagS
350	}
351}
352
353func js(v *vm, bm bool, op []operand) {
354	if v.fl&flagS > 0 {
355		v.setIP(op[0].get(v, false))
356	}
357}
358
359func jns(v *vm, bm bool, op []operand) {
360	if v.fl&flagS == 0 {
361		v.setIP(op[0].get(v, false))
362	}
363}
364
365func jb(v *vm, bm bool, op []operand) {
366	if v.fl&flagC > 0 {
367		v.setIP(op[0].get(v, false))
368	}
369}
370
371func jbe(v *vm, bm bool, op []operand) {
372	if v.fl&(flagC|flagZ) > 0 {
373		v.setIP(op[0].get(v, false))
374	}
375}
376
377func ja(v *vm, bm bool, op []operand) {
378	if v.fl&(flagC|flagZ) == 0 {
379		v.setIP(op[0].get(v, false))
380	}
381}
382
383func jae(v *vm, bm bool, op []operand) {
384	if v.fl&flagC == 0 {
385		v.setIP(op[0].get(v, false))
386	}
387}
388
389func push(v *vm, bm bool, op []operand) {
390	v.r[7] -= 4
391	opRI(7).set(v, false, op[0].get(v, false))
392
393}
394
395func pop(v *vm, bm bool, op []operand) {
396	op[0].set(v, false, opRI(7).get(v, false))
397	v.r[7] += 4
398}
399
400func call(v *vm, bm bool, op []operand) {
401	v.r[7] -= 4
402	opRI(7).set(v, false, v.ip+1)
403	v.setIP(op[0].get(v, false))
404}
405
406func ret(v *vm, bm bool, op []operand) {
407	r7 := v.r[7]
408	if r7 >= vmSize {
409		v.setIP(0xFFFFFFFF) // trigger end of program
410	} else {
411		v.setIP(binary.LittleEndian.Uint32(v.m[r7:]))
412		v.r[7] += 4
413	}
414}
415
416func not(v *vm, bm bool, op []operand) {
417	op[0].set(v, bm, ^op[0].get(v, bm))
418}
419
420func shl(v *vm, bm bool, op []operand) {
421	v1 := op[0].get(v, bm)
422	v2 := op[1].get(v, bm)
423	r := v1 << v2
424	op[0].set(v, bm, r)
425	if r == 0 {
426		v.fl = flagZ
427	} else {
428		v.fl = r & flagS
429	}
430	if (v1<<(v2-1))&0x80000000 > 0 {
431		v.fl |= flagC
432	}
433}
434
435func shr(v *vm, bm bool, op []operand) {
436	v1 := op[0].get(v, bm)
437	v2 := op[1].get(v, bm)
438	r := v1 >> v2
439	op[0].set(v, bm, r)
440	if r == 0 {
441		v.fl = flagZ
442	} else {
443		v.fl = r & flagS
444	}
445	if (v1>>(v2-1))&0x1 > 0 {
446		v.fl |= flagC
447	}
448}
449
450func sar(v *vm, bm bool, op []operand) {
451	v1 := op[0].get(v, bm)
452	v2 := op[1].get(v, bm)
453	r := uint32(int32(v1) >> v2)
454	op[0].set(v, bm, r)
455	if r == 0 {
456		v.fl = flagZ
457	} else {
458		v.fl = r & flagS
459	}
460	if (v1>>(v2-1))&0x1 > 0 {
461		v.fl |= flagC
462	}
463}
464
465func neg(v *vm, bm bool, op []operand) {
466	r := 0 - op[0].get(v, bm)
467	op[0].set(v, bm, r)
468	if r == 0 {
469		v.fl = flagZ
470	} else {
471		v.fl = r&flagS | flagC
472	}
473}
474
475func pusha(v *vm, bm bool, op []operand) {
476	sp := opD(v.r[7])
477	for _, r := range v.r {
478		sp = (sp - 4) & vmMask
479		sp.set(v, false, r)
480	}
481	v.r[7] = uint32(sp)
482}
483
484func popa(v *vm, bm bool, op []operand) {
485	sp := opD(v.r[7])
486	for i := 7; i >= 0; i-- {
487		v.r[i] = sp.get(v, false)
488		sp = (sp + 4) & vmMask
489	}
490}
491
492func pushf(v *vm, bm bool, op []operand) {
493	v.r[7] -= 4
494	opRI(7).set(v, false, v.fl)
495}
496
497func popf(v *vm, bm bool, op []operand) {
498	v.fl = opRI(7).get(v, false)
499	v.r[7] += 4
500}
501
502func movzx(v *vm, bm bool, op []operand) {
503	op[0].set(v, false, op[1].get(v, true))
504}
505
506func movsx(v *vm, bm bool, op []operand) {
507	op[0].set(v, false, uint32(int8(op[1].get(v, true))))
508}
509
510func xchg(v *vm, bm bool, op []operand) {
511	v1 := op[0].get(v, bm)
512	op[0].set(v, bm, op[1].get(v, bm))
513	op[1].set(v, bm, v1)
514}
515
516func mul(v *vm, bm bool, op []operand) {
517	r := op[0].get(v, bm) * op[1].get(v, bm)
518	op[0].set(v, bm, r)
519}
520
521func div(v *vm, bm bool, op []operand) {
522	div := op[1].get(v, bm)
523	if div != 0 {
524		r := op[0].get(v, bm) / div
525		op[0].set(v, bm, r)
526	}
527}
528
529func adc(v *vm, bm bool, op []operand) {
530	v1 := op[0].get(v, bm)
531	fc := v.fl & flagC
532	r := v1 + op[1].get(v, bm) + fc
533	if bm {
534		r &= 0xFF
535	}
536	op[0].set(v, bm, r)
537
538	if r == 0 {
539		v.fl = flagZ
540	} else {
541		v.fl = r & flagS
542	}
543	if r < v1 || (r == v1 && fc > 0) {
544		v.fl |= flagC
545	}
546}
547
548func sbb(v *vm, bm bool, op []operand) {
549	v1 := op[0].get(v, bm)
550	fc := v.fl & flagC
551	r := v1 - op[1].get(v, bm) - fc
552	if bm {
553		r &= 0xFF
554	}
555	op[0].set(v, bm, r)
556
557	if r == 0 {
558		v.fl = flagZ
559	} else {
560		v.fl = r & flagS
561	}
562	if r > v1 || (r == v1 && fc > 0) {
563		v.fl |= flagC
564	}
565}
566
// print implements VM opcode 39 (the last entry in ops). It currently does
// nothing.
//
// NOTE(review): this package-level name shadows the builtin print within
// the package.
func print(v *vm, bm bool, op []operand) {
	// TODO: ignore print for the moment
}
570
571func decodeArg(br *rarBitReader, byteMode bool) (operand, error) {
572	n, err := br.readBits(1)
573	if err != nil {
574		return nil, err
575	}
576	if n > 0 { // Register
577		n, err = br.readBits(3)
578		return opR(n), err
579	}
580	n, err = br.readBits(1)
581	if err != nil {
582		return nil, err
583	}
584	if n == 0 { // Immediate
585		if byteMode {
586			n, err = br.readBits(8)
587		} else {
588			m, err := br.readUint32()
589			return opI(m), err
590		}
591		return opI(n), err
592	}
593	n, err = br.readBits(1)
594	if err != nil {
595		return nil, err
596	}
597	if n == 0 {
598		// Register Indirect
599		n, err = br.readBits(3)
600		return opRI(n), err
601	}
602	n, err = br.readBits(1)
603	if err != nil {
604		return nil, err
605	}
606	if n == 0 {
607		// Base + Index Indirect
608		n, err = br.readBits(3)
609		if err != nil {
610			return nil, err
611		}
612		i, err := br.readUint32()
613		return opBI{r: uint32(n), i: i}, err
614	}
615	// Direct addressing
616	m, err := br.readUint32()
617	return opD(m & vmMask), err
618}
619
620func fixJumpOp(op operand, off int) operand {
621	n, ok := op.(opI)
622	if !ok {
623		return op
624	}
625	if n >= 256 {
626		return n - 256
627	}
628	if n >= 136 {
629		n -= 264
630	} else if n >= 16 {
631		n -= 8
632	} else if n >= 8 {
633		n -= 16
634	}
635	return n + opI(off)
636}
637
638func readCommands(br *rarBitReader) ([]command, error) {
639	var cmds []command
640
641	for {
642		code, err := br.readBits(4)
643		if err != nil {
644			return cmds, err
645		}
646		if code&0x08 > 0 {
647			n, err := br.readBits(2)
648			if err != nil {
649				return cmds, err
650			}
651			code = (code<<2 | n) - 24
652		}
653
654		if code >= len(ops) {
655			return cmds, errInvalidVMInstruction
656		}
657		ins := ops[code]
658
659		var com command
660
661		if ins.byteMode {
662			n, err := br.readBits(1)
663			if err != nil {
664				return cmds, err
665			}
666			com.bm = n > 0
667		}
668		com.f = ins.f
669
670		if ins.nops > 0 {
671			com.op = make([]operand, ins.nops)
672			com.op[0], err = decodeArg(br, com.bm)
673			if err != nil {
674				return cmds, err
675			}
676			if ins.nops == 2 {
677				com.op[1], err = decodeArg(br, com.bm)
678				if err != nil {
679					return cmds, err
680				}
681			} else if ins.jop {
682				com.op[0] = fixJumpOp(com.op[0], len(cmds))
683			}
684		}
685		cmds = append(cmds, com)
686	}
687}
688