1// Copyright 2016 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package bpf
6
7import "fmt"
8
9// An Instruction is one instruction executed by the BPF virtual
10// machine.
11type Instruction interface {
12	// Assemble assembles the Instruction into a RawInstruction.
13	Assemble() (RawInstruction, error)
14}
15
// RawInstruction is the raw, encoded form of a BPF virtual machine
// instruction.
type RawInstruction struct {
	// Op is the operation to execute.
	Op uint16
	// Jt and Jf apply to conditional jump instructions: they hold the
	// number of instructions to skip when the condition is true (Jt)
	// or false (Jf).
	Jt, Jf uint8
	// K is a constant parameter whose meaning depends on Op.
	K uint32
}

// Assemble implements the Instruction Assemble method. A RawInstruction
// is already assembled, so it is returned unchanged with a nil error.
func (ri RawInstruction) Assemble() (RawInstruction, error) {
	return ri, nil
}
30
31// Disassemble parses ri into an Instruction and returns it. If ri is
32// not recognized by this package, ri itself is returned.
33func (ri RawInstruction) Disassemble() Instruction {
34	switch ri.Op & opMaskCls {
35	case opClsLoadA, opClsLoadX:
36		reg := Register(ri.Op & opMaskLoadDest)
37		sz := 0
38		switch ri.Op & opMaskLoadWidth {
39		case opLoadWidth4:
40			sz = 4
41		case opLoadWidth2:
42			sz = 2
43		case opLoadWidth1:
44			sz = 1
45		default:
46			return ri
47		}
48		switch ri.Op & opMaskLoadMode {
49		case opAddrModeImmediate:
50			if sz != 4 {
51				return ri
52			}
53			return LoadConstant{Dst: reg, Val: ri.K}
54		case opAddrModeScratch:
55			if sz != 4 || ri.K > 15 {
56				return ri
57			}
58			return LoadScratch{Dst: reg, N: int(ri.K)}
59		case opAddrModeAbsolute:
60			if ri.K > extOffset+0xffffffff {
61				return LoadExtension{Num: Extension(-extOffset + ri.K)}
62			}
63			return LoadAbsolute{Size: sz, Off: ri.K}
64		case opAddrModeIndirect:
65			return LoadIndirect{Size: sz, Off: ri.K}
66		case opAddrModePacketLen:
67			if sz != 4 {
68				return ri
69			}
70			return LoadExtension{Num: ExtLen}
71		case opAddrModeMemShift:
72			return LoadMemShift{Off: ri.K}
73		default:
74			return ri
75		}
76
77	case opClsStoreA:
78		if ri.Op != opClsStoreA || ri.K > 15 {
79			return ri
80		}
81		return StoreScratch{Src: RegA, N: int(ri.K)}
82
83	case opClsStoreX:
84		if ri.Op != opClsStoreX || ri.K > 15 {
85			return ri
86		}
87		return StoreScratch{Src: RegX, N: int(ri.K)}
88
89	case opClsALU:
90		switch op := ALUOp(ri.Op & opMaskOperator); op {
91		case ALUOpAdd, ALUOpSub, ALUOpMul, ALUOpDiv, ALUOpOr, ALUOpAnd, ALUOpShiftLeft, ALUOpShiftRight, ALUOpMod, ALUOpXor:
92			switch operand := opOperand(ri.Op & opMaskOperand); operand {
93			case opOperandX:
94				return ALUOpX{Op: op}
95			case opOperandConstant:
96				return ALUOpConstant{Op: op, Val: ri.K}
97			default:
98				return ri
99			}
100		case aluOpNeg:
101			return NegateA{}
102		default:
103			return ri
104		}
105
106	case opClsJump:
107		switch op := jumpOp(ri.Op & opMaskOperator); op {
108		case opJumpAlways:
109			return Jump{Skip: ri.K}
110		case opJumpEqual, opJumpGT, opJumpGE, opJumpSet:
111			cond, skipTrue, skipFalse := jumpOpToTest(op, ri.Jt, ri.Jf)
112			switch operand := opOperand(ri.Op & opMaskOperand); operand {
113			case opOperandX:
114				return JumpIfX{Cond: cond, SkipTrue: skipTrue, SkipFalse: skipFalse}
115			case opOperandConstant:
116				return JumpIf{Cond: cond, Val: ri.K, SkipTrue: skipTrue, SkipFalse: skipFalse}
117			default:
118				return ri
119			}
120		default:
121			return ri
122		}
123
124	case opClsReturn:
125		switch ri.Op {
126		case opClsReturn | opRetSrcA:
127			return RetA{}
128		case opClsReturn | opRetSrcConstant:
129			return RetConstant{Val: ri.K}
130		default:
131			return ri
132		}
133
134	case opClsMisc:
135		switch ri.Op {
136		case opClsMisc | opMiscTAX:
137			return TAX{}
138		case opClsMisc | opMiscTXA:
139			return TXA{}
140		default:
141			return ri
142		}
143
144	default:
145		panic("unreachable") // switch is exhaustive on the bit pattern
146	}
147}
148
149func jumpOpToTest(op jumpOp, skipTrue uint8, skipFalse uint8) (JumpTest, uint8, uint8) {
150	var test JumpTest
151
152	// Decode "fake" jump conditions that don't appear in machine code
153	// Ensures the Assemble -> Disassemble stage recreates the same instructions
154	// See https://github.com/golang/go/issues/18470
155	if skipTrue == 0 {
156		switch op {
157		case opJumpEqual:
158			test = JumpNotEqual
159		case opJumpGT:
160			test = JumpLessOrEqual
161		case opJumpGE:
162			test = JumpLessThan
163		case opJumpSet:
164			test = JumpBitsNotSet
165		}
166
167		return test, skipFalse, 0
168	}
169
170	switch op {
171	case opJumpEqual:
172		test = JumpEqual
173	case opJumpGT:
174		test = JumpGreaterThan
175	case opJumpGE:
176		test = JumpGreaterOrEqual
177	case opJumpSet:
178		test = JumpBitsSet
179	}
180
181	return test, skipTrue, skipFalse
182}
183
184// LoadConstant loads Val into register Dst.
185type LoadConstant struct {
186	Dst Register
187	Val uint32
188}
189
190// Assemble implements the Instruction Assemble method.
191func (a LoadConstant) Assemble() (RawInstruction, error) {
192	return assembleLoad(a.Dst, 4, opAddrModeImmediate, a.Val)
193}
194
195// String returns the instruction in assembler notation.
196func (a LoadConstant) String() string {
197	switch a.Dst {
198	case RegA:
199		return fmt.Sprintf("ld #%d", a.Val)
200	case RegX:
201		return fmt.Sprintf("ldx #%d", a.Val)
202	default:
203		return fmt.Sprintf("unknown instruction: %#v", a)
204	}
205}
206
207// LoadScratch loads scratch[N] into register Dst.
208type LoadScratch struct {
209	Dst Register
210	N   int // 0-15
211}
212
213// Assemble implements the Instruction Assemble method.
214func (a LoadScratch) Assemble() (RawInstruction, error) {
215	if a.N < 0 || a.N > 15 {
216		return RawInstruction{}, fmt.Errorf("invalid scratch slot %d", a.N)
217	}
218	return assembleLoad(a.Dst, 4, opAddrModeScratch, uint32(a.N))
219}
220
221// String returns the instruction in assembler notation.
222func (a LoadScratch) String() string {
223	switch a.Dst {
224	case RegA:
225		return fmt.Sprintf("ld M[%d]", a.N)
226	case RegX:
227		return fmt.Sprintf("ldx M[%d]", a.N)
228	default:
229		return fmt.Sprintf("unknown instruction: %#v", a)
230	}
231}
232
233// LoadAbsolute loads packet[Off:Off+Size] as an integer value into
234// register A.
235type LoadAbsolute struct {
236	Off  uint32
237	Size int // 1, 2 or 4
238}
239
240// Assemble implements the Instruction Assemble method.
241func (a LoadAbsolute) Assemble() (RawInstruction, error) {
242	return assembleLoad(RegA, a.Size, opAddrModeAbsolute, a.Off)
243}
244
245// String returns the instruction in assembler notation.
246func (a LoadAbsolute) String() string {
247	switch a.Size {
248	case 1: // byte
249		return fmt.Sprintf("ldb [%d]", a.Off)
250	case 2: // half word
251		return fmt.Sprintf("ldh [%d]", a.Off)
252	case 4: // word
253		if a.Off > extOffset+0xffffffff {
254			return LoadExtension{Num: Extension(a.Off + 0x1000)}.String()
255		}
256		return fmt.Sprintf("ld [%d]", a.Off)
257	default:
258		return fmt.Sprintf("unknown instruction: %#v", a)
259	}
260}
261
262// LoadIndirect loads packet[X+Off:X+Off+Size] as an integer value
263// into register A.
264type LoadIndirect struct {
265	Off  uint32
266	Size int // 1, 2 or 4
267}
268
269// Assemble implements the Instruction Assemble method.
270func (a LoadIndirect) Assemble() (RawInstruction, error) {
271	return assembleLoad(RegA, a.Size, opAddrModeIndirect, a.Off)
272}
273
274// String returns the instruction in assembler notation.
275func (a LoadIndirect) String() string {
276	switch a.Size {
277	case 1: // byte
278		return fmt.Sprintf("ldb [x + %d]", a.Off)
279	case 2: // half word
280		return fmt.Sprintf("ldh [x + %d]", a.Off)
281	case 4: // word
282		return fmt.Sprintf("ld [x + %d]", a.Off)
283	default:
284		return fmt.Sprintf("unknown instruction: %#v", a)
285	}
286}
287
288// LoadMemShift multiplies the first 4 bits of the byte at packet[Off]
289// by 4 and stores the result in register X.
290//
291// This instruction is mainly useful to load into X the length of an
292// IPv4 packet header in a single instruction, rather than have to do
293// the arithmetic on the header's first byte by hand.
294type LoadMemShift struct {
295	Off uint32
296}
297
298// Assemble implements the Instruction Assemble method.
299func (a LoadMemShift) Assemble() (RawInstruction, error) {
300	return assembleLoad(RegX, 1, opAddrModeMemShift, a.Off)
301}
302
303// String returns the instruction in assembler notation.
304func (a LoadMemShift) String() string {
305	return fmt.Sprintf("ldx 4*([%d]&0xf)", a.Off)
306}
307
308// LoadExtension invokes a linux-specific extension and stores the
309// result in register A.
310type LoadExtension struct {
311	Num Extension
312}
313
314// Assemble implements the Instruction Assemble method.
315func (a LoadExtension) Assemble() (RawInstruction, error) {
316	if a.Num == ExtLen {
317		return assembleLoad(RegA, 4, opAddrModePacketLen, 0)
318	}
319	return assembleLoad(RegA, 4, opAddrModeAbsolute, uint32(extOffset+a.Num))
320}
321
322// String returns the instruction in assembler notation.
323func (a LoadExtension) String() string {
324	switch a.Num {
325	case ExtLen:
326		return "ld #len"
327	case ExtProto:
328		return "ld #proto"
329	case ExtType:
330		return "ld #type"
331	case ExtPayloadOffset:
332		return "ld #poff"
333	case ExtInterfaceIndex:
334		return "ld #ifidx"
335	case ExtNetlinkAttr:
336		return "ld #nla"
337	case ExtNetlinkAttrNested:
338		return "ld #nlan"
339	case ExtMark:
340		return "ld #mark"
341	case ExtQueue:
342		return "ld #queue"
343	case ExtLinkLayerType:
344		return "ld #hatype"
345	case ExtRXHash:
346		return "ld #rxhash"
347	case ExtCPUID:
348		return "ld #cpu"
349	case ExtVLANTag:
350		return "ld #vlan_tci"
351	case ExtVLANTagPresent:
352		return "ld #vlan_avail"
353	case ExtVLANProto:
354		return "ld #vlan_tpid"
355	case ExtRand:
356		return "ld #rand"
357	default:
358		return fmt.Sprintf("unknown instruction: %#v", a)
359	}
360}
361
362// StoreScratch stores register Src into scratch[N].
363type StoreScratch struct {
364	Src Register
365	N   int // 0-15
366}
367
368// Assemble implements the Instruction Assemble method.
369func (a StoreScratch) Assemble() (RawInstruction, error) {
370	if a.N < 0 || a.N > 15 {
371		return RawInstruction{}, fmt.Errorf("invalid scratch slot %d", a.N)
372	}
373	var op uint16
374	switch a.Src {
375	case RegA:
376		op = opClsStoreA
377	case RegX:
378		op = opClsStoreX
379	default:
380		return RawInstruction{}, fmt.Errorf("invalid source register %v", a.Src)
381	}
382
383	return RawInstruction{
384		Op: op,
385		K:  uint32(a.N),
386	}, nil
387}
388
389// String returns the instruction in assembler notation.
390func (a StoreScratch) String() string {
391	switch a.Src {
392	case RegA:
393		return fmt.Sprintf("st M[%d]", a.N)
394	case RegX:
395		return fmt.Sprintf("stx M[%d]", a.N)
396	default:
397		return fmt.Sprintf("unknown instruction: %#v", a)
398	}
399}
400
401// ALUOpConstant executes A = A <Op> Val.
402type ALUOpConstant struct {
403	Op  ALUOp
404	Val uint32
405}
406
407// Assemble implements the Instruction Assemble method.
408func (a ALUOpConstant) Assemble() (RawInstruction, error) {
409	return RawInstruction{
410		Op: opClsALU | uint16(opOperandConstant) | uint16(a.Op),
411		K:  a.Val,
412	}, nil
413}
414
415// String returns the instruction in assembler notation.
416func (a ALUOpConstant) String() string {
417	switch a.Op {
418	case ALUOpAdd:
419		return fmt.Sprintf("add #%d", a.Val)
420	case ALUOpSub:
421		return fmt.Sprintf("sub #%d", a.Val)
422	case ALUOpMul:
423		return fmt.Sprintf("mul #%d", a.Val)
424	case ALUOpDiv:
425		return fmt.Sprintf("div #%d", a.Val)
426	case ALUOpMod:
427		return fmt.Sprintf("mod #%d", a.Val)
428	case ALUOpAnd:
429		return fmt.Sprintf("and #%d", a.Val)
430	case ALUOpOr:
431		return fmt.Sprintf("or #%d", a.Val)
432	case ALUOpXor:
433		return fmt.Sprintf("xor #%d", a.Val)
434	case ALUOpShiftLeft:
435		return fmt.Sprintf("lsh #%d", a.Val)
436	case ALUOpShiftRight:
437		return fmt.Sprintf("rsh #%d", a.Val)
438	default:
439		return fmt.Sprintf("unknown instruction: %#v", a)
440	}
441}
442
443// ALUOpX executes A = A <Op> X
444type ALUOpX struct {
445	Op ALUOp
446}
447
448// Assemble implements the Instruction Assemble method.
449func (a ALUOpX) Assemble() (RawInstruction, error) {
450	return RawInstruction{
451		Op: opClsALU | uint16(opOperandX) | uint16(a.Op),
452	}, nil
453}
454
455// String returns the instruction in assembler notation.
456func (a ALUOpX) String() string {
457	switch a.Op {
458	case ALUOpAdd:
459		return "add x"
460	case ALUOpSub:
461		return "sub x"
462	case ALUOpMul:
463		return "mul x"
464	case ALUOpDiv:
465		return "div x"
466	case ALUOpMod:
467		return "mod x"
468	case ALUOpAnd:
469		return "and x"
470	case ALUOpOr:
471		return "or x"
472	case ALUOpXor:
473		return "xor x"
474	case ALUOpShiftLeft:
475		return "lsh x"
476	case ALUOpShiftRight:
477		return "rsh x"
478	default:
479		return fmt.Sprintf("unknown instruction: %#v", a)
480	}
481}
482
483// NegateA executes A = -A.
484type NegateA struct{}
485
486// Assemble implements the Instruction Assemble method.
487func (a NegateA) Assemble() (RawInstruction, error) {
488	return RawInstruction{
489		Op: opClsALU | uint16(aluOpNeg),
490	}, nil
491}
492
493// String returns the instruction in assembler notation.
494func (a NegateA) String() string {
495	return fmt.Sprintf("neg")
496}
497
498// Jump skips the following Skip instructions in the program.
499type Jump struct {
500	Skip uint32
501}
502
503// Assemble implements the Instruction Assemble method.
504func (a Jump) Assemble() (RawInstruction, error) {
505	return RawInstruction{
506		Op: opClsJump | uint16(opJumpAlways),
507		K:  a.Skip,
508	}, nil
509}
510
511// String returns the instruction in assembler notation.
512func (a Jump) String() string {
513	return fmt.Sprintf("ja %d", a.Skip)
514}
515
516// JumpIf skips the following Skip instructions in the program if A
517// <Cond> Val is true.
518type JumpIf struct {
519	Cond      JumpTest
520	Val       uint32
521	SkipTrue  uint8
522	SkipFalse uint8
523}
524
525// Assemble implements the Instruction Assemble method.
526func (a JumpIf) Assemble() (RawInstruction, error) {
527	return jumpToRaw(a.Cond, opOperandConstant, a.Val, a.SkipTrue, a.SkipFalse)
528}
529
530// String returns the instruction in assembler notation.
531func (a JumpIf) String() string {
532	return jumpToString(a.Cond, fmt.Sprintf("#%d", a.Val), a.SkipTrue, a.SkipFalse)
533}
534
535// JumpIfX skips the following Skip instructions in the program if A
536// <Cond> X is true.
537type JumpIfX struct {
538	Cond      JumpTest
539	SkipTrue  uint8
540	SkipFalse uint8
541}
542
543// Assemble implements the Instruction Assemble method.
544func (a JumpIfX) Assemble() (RawInstruction, error) {
545	return jumpToRaw(a.Cond, opOperandX, 0, a.SkipTrue, a.SkipFalse)
546}
547
548// String returns the instruction in assembler notation.
549func (a JumpIfX) String() string {
550	return jumpToString(a.Cond, "x", a.SkipTrue, a.SkipFalse)
551}
552
553// jumpToRaw assembles a jump instruction into a RawInstruction
554func jumpToRaw(test JumpTest, operand opOperand, k uint32, skipTrue, skipFalse uint8) (RawInstruction, error) {
555	var (
556		cond jumpOp
557		flip bool
558	)
559	switch test {
560	case JumpEqual:
561		cond = opJumpEqual
562	case JumpNotEqual:
563		cond, flip = opJumpEqual, true
564	case JumpGreaterThan:
565		cond = opJumpGT
566	case JumpLessThan:
567		cond, flip = opJumpGE, true
568	case JumpGreaterOrEqual:
569		cond = opJumpGE
570	case JumpLessOrEqual:
571		cond, flip = opJumpGT, true
572	case JumpBitsSet:
573		cond = opJumpSet
574	case JumpBitsNotSet:
575		cond, flip = opJumpSet, true
576	default:
577		return RawInstruction{}, fmt.Errorf("unknown JumpTest %v", test)
578	}
579	jt, jf := skipTrue, skipFalse
580	if flip {
581		jt, jf = jf, jt
582	}
583	return RawInstruction{
584		Op: opClsJump | uint16(cond) | uint16(operand),
585		Jt: jt,
586		Jf: jf,
587		K:  k,
588	}, nil
589}
590
591// jumpToString converts a jump instruction to assembler notation
592func jumpToString(cond JumpTest, operand string, skipTrue, skipFalse uint8) string {
593	switch cond {
594	// K == A
595	case JumpEqual:
596		return conditionalJump(operand, skipTrue, skipFalse, "jeq", "jneq")
597	// K != A
598	case JumpNotEqual:
599		return fmt.Sprintf("jneq %s,%d", operand, skipTrue)
600	// K > A
601	case JumpGreaterThan:
602		return conditionalJump(operand, skipTrue, skipFalse, "jgt", "jle")
603	// K < A
604	case JumpLessThan:
605		return fmt.Sprintf("jlt %s,%d", operand, skipTrue)
606	// K >= A
607	case JumpGreaterOrEqual:
608		return conditionalJump(operand, skipTrue, skipFalse, "jge", "jlt")
609	// K <= A
610	case JumpLessOrEqual:
611		return fmt.Sprintf("jle %s,%d", operand, skipTrue)
612	// K & A != 0
613	case JumpBitsSet:
614		if skipFalse > 0 {
615			return fmt.Sprintf("jset %s,%d,%d", operand, skipTrue, skipFalse)
616		}
617		return fmt.Sprintf("jset %s,%d", operand, skipTrue)
618	// K & A == 0, there is no assembler instruction for JumpBitNotSet, use JumpBitSet and invert skips
619	case JumpBitsNotSet:
620		return jumpToString(JumpBitsSet, operand, skipFalse, skipTrue)
621	default:
622		return fmt.Sprintf("unknown JumpTest %#v", cond)
623	}
624}
625
// conditionalJump renders a jump that has both a positive and a
// negative mnemonic. When skipTrue is zero the negative mnemonic is
// printed with skipFalse; otherwise the positive mnemonic is printed
// with skipTrue, followed by skipFalse only when it is non-zero.
func conditionalJump(operand string, skipTrue, skipFalse uint8, positiveJump, negativeJump string) string {
	switch {
	case skipTrue == 0:
		return fmt.Sprintf("%s %s,%d", negativeJump, operand, skipFalse)
	case skipFalse == 0:
		return fmt.Sprintf("%s %s,%d", positiveJump, operand, skipTrue)
	default:
		return fmt.Sprintf("%s %s,%d,%d", positiveJump, operand, skipTrue, skipFalse)
	}
}
635
636// RetA exits the BPF program, returning the value of register A.
637type RetA struct{}
638
639// Assemble implements the Instruction Assemble method.
640func (a RetA) Assemble() (RawInstruction, error) {
641	return RawInstruction{
642		Op: opClsReturn | opRetSrcA,
643	}, nil
644}
645
646// String returns the instruction in assembler notation.
647func (a RetA) String() string {
648	return fmt.Sprintf("ret a")
649}
650
651// RetConstant exits the BPF program, returning a constant value.
652type RetConstant struct {
653	Val uint32
654}
655
656// Assemble implements the Instruction Assemble method.
657func (a RetConstant) Assemble() (RawInstruction, error) {
658	return RawInstruction{
659		Op: opClsReturn | opRetSrcConstant,
660		K:  a.Val,
661	}, nil
662}
663
664// String returns the instruction in assembler notation.
665func (a RetConstant) String() string {
666	return fmt.Sprintf("ret #%d", a.Val)
667}
668
669// TXA copies the value of register X to register A.
670type TXA struct{}
671
672// Assemble implements the Instruction Assemble method.
673func (a TXA) Assemble() (RawInstruction, error) {
674	return RawInstruction{
675		Op: opClsMisc | opMiscTXA,
676	}, nil
677}
678
679// String returns the instruction in assembler notation.
680func (a TXA) String() string {
681	return fmt.Sprintf("txa")
682}
683
684// TAX copies the value of register A to register X.
685type TAX struct{}
686
687// Assemble implements the Instruction Assemble method.
688func (a TAX) Assemble() (RawInstruction, error) {
689	return RawInstruction{
690		Op: opClsMisc | opMiscTAX,
691	}, nil
692}
693
694// String returns the instruction in assembler notation.
695func (a TAX) String() string {
696	return fmt.Sprintf("tax")
697}
698
699func assembleLoad(dst Register, loadSize int, mode uint16, k uint32) (RawInstruction, error) {
700	var (
701		cls uint16
702		sz  uint16
703	)
704	switch dst {
705	case RegA:
706		cls = opClsLoadA
707	case RegX:
708		cls = opClsLoadX
709	default:
710		return RawInstruction{}, fmt.Errorf("invalid target register %v", dst)
711	}
712	switch loadSize {
713	case 1:
714		sz = opLoadWidth1
715	case 2:
716		sz = opLoadWidth2
717	case 4:
718		sz = opLoadWidth4
719	default:
720		return RawInstruction{}, fmt.Errorf("invalid load byte length %d", sz)
721	}
722	return RawInstruction{
723		Op: cls | sz | mode,
724		K:  k,
725	}, nil
726}
727