1// Copyright 2016 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package bpf
6
7import "fmt"
8
// An Instruction is one instruction executed by the BPF virtual
// machine.
//
// Every instruction type declared in this file implements it by
// converting itself into the RawInstruction encoding.
type Instruction interface {
	// Assemble assembles the Instruction into a RawInstruction. It
	// returns a non-nil error when the instruction cannot be encoded,
	// for example because one of its fields is out of range.
	Assemble() (RawInstruction, error)
}
15
// A RawInstruction is a raw BPF virtual machine instruction: the
// already-encoded form that the other Instruction types assemble into.
type RawInstruction struct {
	// Op is the operation to execute.
	Op uint16
	// Jt and Jf apply to conditional jump instructions: the number of
	// instructions to skip if the condition is true (Jt) or false (Jf).
	Jt, Jf uint8
	// K is a constant parameter. Its meaning depends on Op.
	K uint32
}

// Assemble implements the Instruction Assemble method. A
// RawInstruction is already in assembled form, so it is returned
// unchanged and the error is always nil.
func (ri RawInstruction) Assemble() (RawInstruction, error) {
	return ri, nil
}
30
31// Disassemble parses ri into an Instruction and returns it. If ri is
32// not recognized by this package, ri itself is returned.
33func (ri RawInstruction) Disassemble() Instruction {
34	switch ri.Op & opMaskCls {
35	case opClsLoadA, opClsLoadX:
36		reg := Register(ri.Op & opMaskLoadDest)
37		sz := 0
38		switch ri.Op & opMaskLoadWidth {
39		case opLoadWidth4:
40			sz = 4
41		case opLoadWidth2:
42			sz = 2
43		case opLoadWidth1:
44			sz = 1
45		default:
46			return ri
47		}
48		switch ri.Op & opMaskLoadMode {
49		case opAddrModeImmediate:
50			if sz != 4 {
51				return ri
52			}
53			return LoadConstant{Dst: reg, Val: ri.K}
54		case opAddrModeScratch:
55			if sz != 4 || ri.K > 15 {
56				return ri
57			}
58			return LoadScratch{Dst: reg, N: int(ri.K)}
59		case opAddrModeAbsolute:
60			if ri.K > extOffset+0xffffffff {
61				return LoadExtension{Num: Extension(-extOffset + ri.K)}
62			}
63			return LoadAbsolute{Size: sz, Off: ri.K}
64		case opAddrModeIndirect:
65			return LoadIndirect{Size: sz, Off: ri.K}
66		case opAddrModePacketLen:
67			if sz != 4 {
68				return ri
69			}
70			return LoadExtension{Num: ExtLen}
71		case opAddrModeMemShift:
72			return LoadMemShift{Off: ri.K}
73		default:
74			return ri
75		}
76
77	case opClsStoreA:
78		if ri.Op != opClsStoreA || ri.K > 15 {
79			return ri
80		}
81		return StoreScratch{Src: RegA, N: int(ri.K)}
82
83	case opClsStoreX:
84		if ri.Op != opClsStoreX || ri.K > 15 {
85			return ri
86		}
87		return StoreScratch{Src: RegX, N: int(ri.K)}
88
89	case opClsALU:
90		switch op := ALUOp(ri.Op & opMaskOperator); op {
91		case ALUOpAdd, ALUOpSub, ALUOpMul, ALUOpDiv, ALUOpOr, ALUOpAnd, ALUOpShiftLeft, ALUOpShiftRight, ALUOpMod, ALUOpXor:
92			if ri.Op&opMaskOperandSrc != 0 {
93				return ALUOpX{Op: op}
94			}
95			return ALUOpConstant{Op: op, Val: ri.K}
96		case aluOpNeg:
97			return NegateA{}
98		default:
99			return ri
100		}
101
102	case opClsJump:
103		if ri.Op&opMaskJumpConst != opClsJump {
104			return ri
105		}
106		switch ri.Op & opMaskJumpCond {
107		case opJumpAlways:
108			return Jump{Skip: ri.K}
109		case opJumpEqual:
110			if ri.Jt == 0 {
111				return JumpIf{
112					Cond:      JumpNotEqual,
113					Val:       ri.K,
114					SkipTrue:  ri.Jf,
115					SkipFalse: 0,
116				}
117			}
118			return JumpIf{
119				Cond:      JumpEqual,
120				Val:       ri.K,
121				SkipTrue:  ri.Jt,
122				SkipFalse: ri.Jf,
123			}
124		case opJumpGT:
125			if ri.Jt == 0 {
126				return JumpIf{
127					Cond:      JumpLessOrEqual,
128					Val:       ri.K,
129					SkipTrue:  ri.Jf,
130					SkipFalse: 0,
131				}
132			}
133			return JumpIf{
134				Cond:      JumpGreaterThan,
135				Val:       ri.K,
136				SkipTrue:  ri.Jt,
137				SkipFalse: ri.Jf,
138			}
139		case opJumpGE:
140			if ri.Jt == 0 {
141				return JumpIf{
142					Cond:      JumpLessThan,
143					Val:       ri.K,
144					SkipTrue:  ri.Jf,
145					SkipFalse: 0,
146				}
147			}
148			return JumpIf{
149				Cond:      JumpGreaterOrEqual,
150				Val:       ri.K,
151				SkipTrue:  ri.Jt,
152				SkipFalse: ri.Jf,
153			}
154		case opJumpSet:
155			return JumpIf{
156				Cond:      JumpBitsSet,
157				Val:       ri.K,
158				SkipTrue:  ri.Jt,
159				SkipFalse: ri.Jf,
160			}
161		default:
162			return ri
163		}
164
165	case opClsReturn:
166		switch ri.Op {
167		case opClsReturn | opRetSrcA:
168			return RetA{}
169		case opClsReturn | opRetSrcConstant:
170			return RetConstant{Val: ri.K}
171		default:
172			return ri
173		}
174
175	case opClsMisc:
176		switch ri.Op {
177		case opClsMisc | opMiscTAX:
178			return TAX{}
179		case opClsMisc | opMiscTXA:
180			return TXA{}
181		default:
182			return ri
183		}
184
185	default:
186		panic("unreachable") // switch is exhaustive on the bit pattern
187	}
188}
189
190// LoadConstant loads Val into register Dst.
191type LoadConstant struct {
192	Dst Register
193	Val uint32
194}
195
196// Assemble implements the Instruction Assemble method.
197func (a LoadConstant) Assemble() (RawInstruction, error) {
198	return assembleLoad(a.Dst, 4, opAddrModeImmediate, a.Val)
199}
200
201// String returns the instruction in assembler notation.
202func (a LoadConstant) String() string {
203	switch a.Dst {
204	case RegA:
205		return fmt.Sprintf("ld #%d", a.Val)
206	case RegX:
207		return fmt.Sprintf("ldx #%d", a.Val)
208	default:
209		return fmt.Sprintf("unknown instruction: %#v", a)
210	}
211}
212
213// LoadScratch loads scratch[N] into register Dst.
214type LoadScratch struct {
215	Dst Register
216	N   int // 0-15
217}
218
219// Assemble implements the Instruction Assemble method.
220func (a LoadScratch) Assemble() (RawInstruction, error) {
221	if a.N < 0 || a.N > 15 {
222		return RawInstruction{}, fmt.Errorf("invalid scratch slot %d", a.N)
223	}
224	return assembleLoad(a.Dst, 4, opAddrModeScratch, uint32(a.N))
225}
226
227// String returns the instruction in assembler notation.
228func (a LoadScratch) String() string {
229	switch a.Dst {
230	case RegA:
231		return fmt.Sprintf("ld M[%d]", a.N)
232	case RegX:
233		return fmt.Sprintf("ldx M[%d]", a.N)
234	default:
235		return fmt.Sprintf("unknown instruction: %#v", a)
236	}
237}
238
239// LoadAbsolute loads packet[Off:Off+Size] as an integer value into
240// register A.
241type LoadAbsolute struct {
242	Off  uint32
243	Size int // 1, 2 or 4
244}
245
246// Assemble implements the Instruction Assemble method.
247func (a LoadAbsolute) Assemble() (RawInstruction, error) {
248	return assembleLoad(RegA, a.Size, opAddrModeAbsolute, a.Off)
249}
250
251// String returns the instruction in assembler notation.
252func (a LoadAbsolute) String() string {
253	switch a.Size {
254	case 1: // byte
255		return fmt.Sprintf("ldb [%d]", a.Off)
256	case 2: // half word
257		return fmt.Sprintf("ldh [%d]", a.Off)
258	case 4: // word
259		if a.Off > extOffset+0xffffffff {
260			return LoadExtension{Num: Extension(a.Off + 0x1000)}.String()
261		}
262		return fmt.Sprintf("ld [%d]", a.Off)
263	default:
264		return fmt.Sprintf("unknown instruction: %#v", a)
265	}
266}
267
268// LoadIndirect loads packet[X+Off:X+Off+Size] as an integer value
269// into register A.
270type LoadIndirect struct {
271	Off  uint32
272	Size int // 1, 2 or 4
273}
274
275// Assemble implements the Instruction Assemble method.
276func (a LoadIndirect) Assemble() (RawInstruction, error) {
277	return assembleLoad(RegA, a.Size, opAddrModeIndirect, a.Off)
278}
279
280// String returns the instruction in assembler notation.
281func (a LoadIndirect) String() string {
282	switch a.Size {
283	case 1: // byte
284		return fmt.Sprintf("ldb [x + %d]", a.Off)
285	case 2: // half word
286		return fmt.Sprintf("ldh [x + %d]", a.Off)
287	case 4: // word
288		return fmt.Sprintf("ld [x + %d]", a.Off)
289	default:
290		return fmt.Sprintf("unknown instruction: %#v", a)
291	}
292}
293
294// LoadMemShift multiplies the first 4 bits of the byte at packet[Off]
295// by 4 and stores the result in register X.
296//
297// This instruction is mainly useful to load into X the length of an
298// IPv4 packet header in a single instruction, rather than have to do
299// the arithmetic on the header's first byte by hand.
300type LoadMemShift struct {
301	Off uint32
302}
303
304// Assemble implements the Instruction Assemble method.
305func (a LoadMemShift) Assemble() (RawInstruction, error) {
306	return assembleLoad(RegX, 1, opAddrModeMemShift, a.Off)
307}
308
309// String returns the instruction in assembler notation.
310func (a LoadMemShift) String() string {
311	return fmt.Sprintf("ldx 4*([%d]&0xf)", a.Off)
312}
313
314// LoadExtension invokes a linux-specific extension and stores the
315// result in register A.
316type LoadExtension struct {
317	Num Extension
318}
319
320// Assemble implements the Instruction Assemble method.
321func (a LoadExtension) Assemble() (RawInstruction, error) {
322	if a.Num == ExtLen {
323		return assembleLoad(RegA, 4, opAddrModePacketLen, 0)
324	}
325	return assembleLoad(RegA, 4, opAddrModeAbsolute, uint32(extOffset+a.Num))
326}
327
328// String returns the instruction in assembler notation.
329func (a LoadExtension) String() string {
330	switch a.Num {
331	case ExtLen:
332		return "ld #len"
333	case ExtProto:
334		return "ld #proto"
335	case ExtType:
336		return "ld #type"
337	case ExtPayloadOffset:
338		return "ld #poff"
339	case ExtInterfaceIndex:
340		return "ld #ifidx"
341	case ExtNetlinkAttr:
342		return "ld #nla"
343	case ExtNetlinkAttrNested:
344		return "ld #nlan"
345	case ExtMark:
346		return "ld #mark"
347	case ExtQueue:
348		return "ld #queue"
349	case ExtLinkLayerType:
350		return "ld #hatype"
351	case ExtRXHash:
352		return "ld #rxhash"
353	case ExtCPUID:
354		return "ld #cpu"
355	case ExtVLANTag:
356		return "ld #vlan_tci"
357	case ExtVLANTagPresent:
358		return "ld #vlan_avail"
359	case ExtVLANProto:
360		return "ld #vlan_tpid"
361	case ExtRand:
362		return "ld #rand"
363	default:
364		return fmt.Sprintf("unknown instruction: %#v", a)
365	}
366}
367
368// StoreScratch stores register Src into scratch[N].
369type StoreScratch struct {
370	Src Register
371	N   int // 0-15
372}
373
374// Assemble implements the Instruction Assemble method.
375func (a StoreScratch) Assemble() (RawInstruction, error) {
376	if a.N < 0 || a.N > 15 {
377		return RawInstruction{}, fmt.Errorf("invalid scratch slot %d", a.N)
378	}
379	var op uint16
380	switch a.Src {
381	case RegA:
382		op = opClsStoreA
383	case RegX:
384		op = opClsStoreX
385	default:
386		return RawInstruction{}, fmt.Errorf("invalid source register %v", a.Src)
387	}
388
389	return RawInstruction{
390		Op: op,
391		K:  uint32(a.N),
392	}, nil
393}
394
395// String returns the instruction in assembler notation.
396func (a StoreScratch) String() string {
397	switch a.Src {
398	case RegA:
399		return fmt.Sprintf("st M[%d]", a.N)
400	case RegX:
401		return fmt.Sprintf("stx M[%d]", a.N)
402	default:
403		return fmt.Sprintf("unknown instruction: %#v", a)
404	}
405}
406
407// ALUOpConstant executes A = A <Op> Val.
408type ALUOpConstant struct {
409	Op  ALUOp
410	Val uint32
411}
412
413// Assemble implements the Instruction Assemble method.
414func (a ALUOpConstant) Assemble() (RawInstruction, error) {
415	return RawInstruction{
416		Op: opClsALU | opALUSrcConstant | uint16(a.Op),
417		K:  a.Val,
418	}, nil
419}
420
421// String returns the instruction in assembler notation.
422func (a ALUOpConstant) String() string {
423	switch a.Op {
424	case ALUOpAdd:
425		return fmt.Sprintf("add #%d", a.Val)
426	case ALUOpSub:
427		return fmt.Sprintf("sub #%d", a.Val)
428	case ALUOpMul:
429		return fmt.Sprintf("mul #%d", a.Val)
430	case ALUOpDiv:
431		return fmt.Sprintf("div #%d", a.Val)
432	case ALUOpMod:
433		return fmt.Sprintf("mod #%d", a.Val)
434	case ALUOpAnd:
435		return fmt.Sprintf("and #%d", a.Val)
436	case ALUOpOr:
437		return fmt.Sprintf("or #%d", a.Val)
438	case ALUOpXor:
439		return fmt.Sprintf("xor #%d", a.Val)
440	case ALUOpShiftLeft:
441		return fmt.Sprintf("lsh #%d", a.Val)
442	case ALUOpShiftRight:
443		return fmt.Sprintf("rsh #%d", a.Val)
444	default:
445		return fmt.Sprintf("unknown instruction: %#v", a)
446	}
447}
448
449// ALUOpX executes A = A <Op> X
450type ALUOpX struct {
451	Op ALUOp
452}
453
454// Assemble implements the Instruction Assemble method.
455func (a ALUOpX) Assemble() (RawInstruction, error) {
456	return RawInstruction{
457		Op: opClsALU | opALUSrcX | uint16(a.Op),
458	}, nil
459}
460
461// String returns the instruction in assembler notation.
462func (a ALUOpX) String() string {
463	switch a.Op {
464	case ALUOpAdd:
465		return "add x"
466	case ALUOpSub:
467		return "sub x"
468	case ALUOpMul:
469		return "mul x"
470	case ALUOpDiv:
471		return "div x"
472	case ALUOpMod:
473		return "mod x"
474	case ALUOpAnd:
475		return "and x"
476	case ALUOpOr:
477		return "or x"
478	case ALUOpXor:
479		return "xor x"
480	case ALUOpShiftLeft:
481		return "lsh x"
482	case ALUOpShiftRight:
483		return "rsh x"
484	default:
485		return fmt.Sprintf("unknown instruction: %#v", a)
486	}
487}
488
489// NegateA executes A = -A.
490type NegateA struct{}
491
492// Assemble implements the Instruction Assemble method.
493func (a NegateA) Assemble() (RawInstruction, error) {
494	return RawInstruction{
495		Op: opClsALU | uint16(aluOpNeg),
496	}, nil
497}
498
499// String returns the instruction in assembler notation.
500func (a NegateA) String() string {
501	return fmt.Sprintf("neg")
502}
503
504// Jump skips the following Skip instructions in the program.
505type Jump struct {
506	Skip uint32
507}
508
509// Assemble implements the Instruction Assemble method.
510func (a Jump) Assemble() (RawInstruction, error) {
511	return RawInstruction{
512		Op: opClsJump | opJumpAlways,
513		K:  a.Skip,
514	}, nil
515}
516
517// String returns the instruction in assembler notation.
518func (a Jump) String() string {
519	return fmt.Sprintf("ja %d", a.Skip)
520}
521
522// JumpIf skips the following Skip instructions in the program if A
523// <Cond> Val is true.
524type JumpIf struct {
525	Cond      JumpTest
526	Val       uint32
527	SkipTrue  uint8
528	SkipFalse uint8
529}
530
531// Assemble implements the Instruction Assemble method.
532func (a JumpIf) Assemble() (RawInstruction, error) {
533	var (
534		cond uint16
535		flip bool
536	)
537	switch a.Cond {
538	case JumpEqual:
539		cond = opJumpEqual
540	case JumpNotEqual:
541		cond, flip = opJumpEqual, true
542	case JumpGreaterThan:
543		cond = opJumpGT
544	case JumpLessThan:
545		cond, flip = opJumpGE, true
546	case JumpGreaterOrEqual:
547		cond = opJumpGE
548	case JumpLessOrEqual:
549		cond, flip = opJumpGT, true
550	case JumpBitsSet:
551		cond = opJumpSet
552	case JumpBitsNotSet:
553		cond, flip = opJumpSet, true
554	default:
555		return RawInstruction{}, fmt.Errorf("unknown JumpTest %v", a.Cond)
556	}
557	jt, jf := a.SkipTrue, a.SkipFalse
558	if flip {
559		jt, jf = jf, jt
560	}
561	return RawInstruction{
562		Op: opClsJump | cond,
563		Jt: jt,
564		Jf: jf,
565		K:  a.Val,
566	}, nil
567}
568
569// String returns the instruction in assembler notation.
570func (a JumpIf) String() string {
571	switch a.Cond {
572	// K == A
573	case JumpEqual:
574		return conditionalJump(a, "jeq", "jneq")
575	// K != A
576	case JumpNotEqual:
577		return fmt.Sprintf("jneq #%d,%d", a.Val, a.SkipTrue)
578	// K > A
579	case JumpGreaterThan:
580		return conditionalJump(a, "jgt", "jle")
581	// K < A
582	case JumpLessThan:
583		return fmt.Sprintf("jlt #%d,%d", a.Val, a.SkipTrue)
584	// K >= A
585	case JumpGreaterOrEqual:
586		return conditionalJump(a, "jge", "jlt")
587	// K <= A
588	case JumpLessOrEqual:
589		return fmt.Sprintf("jle #%d,%d", a.Val, a.SkipTrue)
590	// K & A != 0
591	case JumpBitsSet:
592		if a.SkipFalse > 0 {
593			return fmt.Sprintf("jset #%d,%d,%d", a.Val, a.SkipTrue, a.SkipFalse)
594		}
595		return fmt.Sprintf("jset #%d,%d", a.Val, a.SkipTrue)
596	// K & A == 0, there is no assembler instruction for JumpBitNotSet, use JumpBitSet and invert skips
597	case JumpBitsNotSet:
598		return JumpIf{Cond: JumpBitsSet, SkipTrue: a.SkipFalse, SkipFalse: a.SkipTrue, Val: a.Val}.String()
599	default:
600		return fmt.Sprintf("unknown instruction: %#v", a)
601	}
602}
603
604func conditionalJump(inst JumpIf, positiveJump, negativeJump string) string {
605	if inst.SkipTrue > 0 {
606		if inst.SkipFalse > 0 {
607			return fmt.Sprintf("%s #%d,%d,%d", positiveJump, inst.Val, inst.SkipTrue, inst.SkipFalse)
608		}
609		return fmt.Sprintf("%s #%d,%d", positiveJump, inst.Val, inst.SkipTrue)
610	}
611	return fmt.Sprintf("%s #%d,%d", negativeJump, inst.Val, inst.SkipFalse)
612}
613
614// RetA exits the BPF program, returning the value of register A.
615type RetA struct{}
616
617// Assemble implements the Instruction Assemble method.
618func (a RetA) Assemble() (RawInstruction, error) {
619	return RawInstruction{
620		Op: opClsReturn | opRetSrcA,
621	}, nil
622}
623
624// String returns the instruction in assembler notation.
625func (a RetA) String() string {
626	return fmt.Sprintf("ret a")
627}
628
629// RetConstant exits the BPF program, returning a constant value.
630type RetConstant struct {
631	Val uint32
632}
633
634// Assemble implements the Instruction Assemble method.
635func (a RetConstant) Assemble() (RawInstruction, error) {
636	return RawInstruction{
637		Op: opClsReturn | opRetSrcConstant,
638		K:  a.Val,
639	}, nil
640}
641
642// String returns the instruction in assembler notation.
643func (a RetConstant) String() string {
644	return fmt.Sprintf("ret #%d", a.Val)
645}
646
647// TXA copies the value of register X to register A.
648type TXA struct{}
649
650// Assemble implements the Instruction Assemble method.
651func (a TXA) Assemble() (RawInstruction, error) {
652	return RawInstruction{
653		Op: opClsMisc | opMiscTXA,
654	}, nil
655}
656
657// String returns the instruction in assembler notation.
658func (a TXA) String() string {
659	return fmt.Sprintf("txa")
660}
661
662// TAX copies the value of register A to register X.
663type TAX struct{}
664
665// Assemble implements the Instruction Assemble method.
666func (a TAX) Assemble() (RawInstruction, error) {
667	return RawInstruction{
668		Op: opClsMisc | opMiscTAX,
669	}, nil
670}
671
672// String returns the instruction in assembler notation.
673func (a TAX) String() string {
674	return fmt.Sprintf("tax")
675}
676
677func assembleLoad(dst Register, loadSize int, mode uint16, k uint32) (RawInstruction, error) {
678	var (
679		cls uint16
680		sz  uint16
681	)
682	switch dst {
683	case RegA:
684		cls = opClsLoadA
685	case RegX:
686		cls = opClsLoadX
687	default:
688		return RawInstruction{}, fmt.Errorf("invalid target register %v", dst)
689	}
690	switch loadSize {
691	case 1:
692		sz = opLoadWidth1
693	case 2:
694		sz = opLoadWidth2
695	case 4:
696		sz = opLoadWidth4
697	default:
698		return RawInstruction{}, fmt.Errorf("invalid load byte length %d", sz)
699	}
700	return RawInstruction{
701		Op: cls | sz | mode,
702		K:  k,
703	}, nil
704}
705