1package asm
2
3import (
4	"encoding/binary"
5	"errors"
6	"fmt"
7	"io"
8	"math"
9	"strings"
10)
11
// InstructionSize is the number of bytes occupied by one raw BPF instruction.
const InstructionSize = 8

// RawInstructionOffset is a position measured in raw BPF instructions rather
// than in bytes or in elements of an Instructions slice.
type RawInstructionOffset uint64

// Bytes converts the offset from raw-instruction units into a byte offset.
func (off RawInstructionOffset) Bytes() uint64 {
	const bytesPerInstruction = uint64(InstructionSize)
	return bytesPerInstruction * uint64(off)
}
22
// Instruction is a single eBPF instruction.
//
// A 64-bit immediate load is represented by one Instruction even though it
// encodes to two raw instructions (see Marshal and Unmarshal).
type Instruction struct {
	OpCode    OpCode
	Dst       Register
	// Src is the source register. For map loads and bpf-to-bpf calls it
	// instead carries a pseudo value such as PseudoMapFD, PseudoMapValue or
	// PseudoCall.
	Src       Register
	Offset    int16
	// Constant is the immediate operand. For 64-bit immediate loads it holds
	// the full 64-bit value. Map loads keep the map fd in the low 32 bits and,
	// for direct value loads, the value offset in the high 32 bits.
	Constant  int64
	// Reference names a symbol this instruction refers to; for example,
	// Instructions.RewriteMapPtr patches every instruction referencing a
	// given map symbol.
	Reference string
	// Symbol, if set, labels this instruction; SymbolOffsets reports the
	// slice index at which each label appears.
	Symbol    string
}
33
34// Sym creates a symbol.
35func (ins Instruction) Sym(name string) Instruction {
36	ins.Symbol = name
37	return ins
38}
39
40// Unmarshal decodes a BPF instruction.
41func (ins *Instruction) Unmarshal(r io.Reader, bo binary.ByteOrder) (uint64, error) {
42	var bi bpfInstruction
43	err := binary.Read(r, bo, &bi)
44	if err != nil {
45		return 0, err
46	}
47
48	ins.OpCode = bi.OpCode
49	ins.Offset = bi.Offset
50	ins.Constant = int64(bi.Constant)
51	ins.Dst, ins.Src, err = bi.Registers.Unmarshal(bo)
52	if err != nil {
53		return 0, fmt.Errorf("can't unmarshal registers: %s", err)
54	}
55
56	if !bi.OpCode.isDWordLoad() {
57		return InstructionSize, nil
58	}
59
60	var bi2 bpfInstruction
61	if err := binary.Read(r, bo, &bi2); err != nil {
62		// No Wrap, to avoid io.EOF clash
63		return 0, errors.New("64bit immediate is missing second half")
64	}
65	if bi2.OpCode != 0 || bi2.Offset != 0 || bi2.Registers != 0 {
66		return 0, errors.New("64bit immediate has non-zero fields")
67	}
68	ins.Constant = int64(uint64(uint32(bi2.Constant))<<32 | uint64(uint32(bi.Constant)))
69
70	return 2 * InstructionSize, nil
71}
72
73// Marshal encodes a BPF instruction.
74func (ins Instruction) Marshal(w io.Writer, bo binary.ByteOrder) (uint64, error) {
75	if ins.OpCode == InvalidOpCode {
76		return 0, errors.New("invalid opcode")
77	}
78
79	isDWordLoad := ins.OpCode.isDWordLoad()
80
81	cons := int32(ins.Constant)
82	if isDWordLoad {
83		// Encode least significant 32bit first for 64bit operations.
84		cons = int32(uint32(ins.Constant))
85	}
86
87	regs, err := newBPFRegisters(ins.Dst, ins.Src, bo)
88	if err != nil {
89		return 0, fmt.Errorf("can't marshal registers: %s", err)
90	}
91
92	bpfi := bpfInstruction{
93		ins.OpCode,
94		regs,
95		ins.Offset,
96		cons,
97	}
98
99	if err := binary.Write(w, bo, &bpfi); err != nil {
100		return 0, err
101	}
102
103	if !isDWordLoad {
104		return InstructionSize, nil
105	}
106
107	bpfi = bpfInstruction{
108		Constant: int32(ins.Constant >> 32),
109	}
110
111	if err := binary.Write(w, bo, &bpfi); err != nil {
112		return 0, err
113	}
114
115	return 2 * InstructionSize, nil
116}
117
118// RewriteMapPtr changes an instruction to use a new map fd.
119//
120// Returns an error if the instruction doesn't load a map.
121func (ins *Instruction) RewriteMapPtr(fd int) error {
122	if !ins.OpCode.isDWordLoad() {
123		return fmt.Errorf("%s is not a 64 bit load", ins.OpCode)
124	}
125
126	if ins.Src != PseudoMapFD && ins.Src != PseudoMapValue {
127		return errors.New("not a load from a map")
128	}
129
130	// Preserve the offset value for direct map loads.
131	offset := uint64(ins.Constant) & (math.MaxUint32 << 32)
132	rawFd := uint64(uint32(fd))
133	ins.Constant = int64(offset | rawFd)
134	return nil
135}
136
137func (ins *Instruction) mapPtr() uint32 {
138	return uint32(uint64(ins.Constant) & math.MaxUint32)
139}
140
141// RewriteMapOffset changes the offset of a direct load from a map.
142//
143// Returns an error if the instruction is not a direct load.
144func (ins *Instruction) RewriteMapOffset(offset uint32) error {
145	if !ins.OpCode.isDWordLoad() {
146		return fmt.Errorf("%s is not a 64 bit load", ins.OpCode)
147	}
148
149	if ins.Src != PseudoMapValue {
150		return errors.New("not a direct load from a map")
151	}
152
153	fd := uint64(ins.Constant) & math.MaxUint32
154	ins.Constant = int64(uint64(offset)<<32 | fd)
155	return nil
156}
157
158func (ins *Instruction) mapOffset() uint32 {
159	return uint32(uint64(ins.Constant) >> 32)
160}
161
162func (ins *Instruction) isLoadFromMap() bool {
163	return ins.OpCode == LoadImmOp(DWord) && (ins.Src == PseudoMapFD || ins.Src == PseudoMapValue)
164}
165
166// IsFunctionCall returns true if the instruction calls another BPF function.
167//
168// This is not the same thing as a BPF helper call.
169func (ins *Instruction) IsFunctionCall() bool {
170	return ins.OpCode.JumpOp() == Call && ins.Src == PseudoCall
171}
172
173// Format implements fmt.Formatter.
174func (ins Instruction) Format(f fmt.State, c rune) {
175	if c != 'v' {
176		fmt.Fprintf(f, "{UNRECOGNIZED: %c}", c)
177		return
178	}
179
180	op := ins.OpCode
181
182	if op == InvalidOpCode {
183		fmt.Fprint(f, "INVALID")
184		return
185	}
186
187	// Omit trailing space for Exit
188	if op.JumpOp() == Exit {
189		fmt.Fprint(f, op)
190		return
191	}
192
193	if ins.isLoadFromMap() {
194		fd := int32(ins.mapPtr())
195		switch ins.Src {
196		case PseudoMapFD:
197			fmt.Fprintf(f, "LoadMapPtr dst: %s fd: %d", ins.Dst, fd)
198
199		case PseudoMapValue:
200			fmt.Fprintf(f, "LoadMapValue dst: %s, fd: %d off: %d", ins.Dst, fd, ins.mapOffset())
201		}
202
203		goto ref
204	}
205
206	fmt.Fprintf(f, "%v ", op)
207	switch cls := op.Class(); cls {
208	case LdClass, LdXClass, StClass, StXClass:
209		switch op.Mode() {
210		case ImmMode:
211			fmt.Fprintf(f, "dst: %s imm: %d", ins.Dst, ins.Constant)
212		case AbsMode:
213			fmt.Fprintf(f, "imm: %d", ins.Constant)
214		case IndMode:
215			fmt.Fprintf(f, "dst: %s src: %s imm: %d", ins.Dst, ins.Src, ins.Constant)
216		case MemMode:
217			fmt.Fprintf(f, "dst: %s src: %s off: %d imm: %d", ins.Dst, ins.Src, ins.Offset, ins.Constant)
218		case XAddMode:
219			fmt.Fprintf(f, "dst: %s src: %s", ins.Dst, ins.Src)
220		}
221
222	case ALU64Class, ALUClass:
223		fmt.Fprintf(f, "dst: %s ", ins.Dst)
224		if op.ALUOp() == Swap || op.Source() == ImmSource {
225			fmt.Fprintf(f, "imm: %d", ins.Constant)
226		} else {
227			fmt.Fprintf(f, "src: %s", ins.Src)
228		}
229
230	case JumpClass:
231		switch jop := op.JumpOp(); jop {
232		case Call:
233			if ins.Src == PseudoCall {
234				// bpf-to-bpf call
235				fmt.Fprint(f, ins.Constant)
236			} else {
237				fmt.Fprint(f, BuiltinFunc(ins.Constant))
238			}
239
240		default:
241			fmt.Fprintf(f, "dst: %s off: %d ", ins.Dst, ins.Offset)
242			if op.Source() == ImmSource {
243				fmt.Fprintf(f, "imm: %d", ins.Constant)
244			} else {
245				fmt.Fprintf(f, "src: %s", ins.Src)
246			}
247		}
248	}
249
250ref:
251	if ins.Reference != "" {
252		fmt.Fprintf(f, " <%s>", ins.Reference)
253	}
254}
255
// Instructions is an eBPF program: its instructions in program order.
//
// A 64-bit immediate load occupies a single element even though it encodes
// to two raw instructions, so slice indices and raw instruction offsets can
// differ (see InstructionIterator).
type Instructions []Instruction
258
259func (insns Instructions) String() string {
260	return fmt.Sprint(insns)
261}
262
263// RewriteMapPtr rewrites all loads of a specific map pointer to a new fd.
264//
265// Returns an error if the symbol isn't used, see IsUnreferencedSymbol.
266func (insns Instructions) RewriteMapPtr(symbol string, fd int) error {
267	if symbol == "" {
268		return errors.New("empty symbol")
269	}
270
271	found := false
272	for i := range insns {
273		ins := &insns[i]
274		if ins.Reference != symbol {
275			continue
276		}
277
278		if err := ins.RewriteMapPtr(fd); err != nil {
279			return err
280		}
281
282		found = true
283	}
284
285	if !found {
286		return &unreferencedSymbolError{symbol}
287	}
288
289	return nil
290}
291
292// SymbolOffsets returns the set of symbols and their offset in
293// the instructions.
294func (insns Instructions) SymbolOffsets() (map[string]int, error) {
295	offsets := make(map[string]int)
296
297	for i, ins := range insns {
298		if ins.Symbol == "" {
299			continue
300		}
301
302		if _, ok := offsets[ins.Symbol]; ok {
303			return nil, fmt.Errorf("duplicate symbol %s", ins.Symbol)
304		}
305
306		offsets[ins.Symbol] = i
307	}
308
309	return offsets, nil
310}
311
312// ReferenceOffsets returns the set of references and their offset in
313// the instructions.
314func (insns Instructions) ReferenceOffsets() map[string][]int {
315	offsets := make(map[string][]int)
316
317	for i, ins := range insns {
318		if ins.Reference == "" {
319			continue
320		}
321
322		offsets[ins.Reference] = append(offsets[ins.Reference], i)
323	}
324
325	return offsets
326}
327
328// Format implements fmt.Formatter.
329//
330// You can control indentation of symbols by
331// specifying a width. Setting a precision controls the indentation of
332// instructions.
333// The default character is a tab, which can be overriden by specifying
334// the ' ' space flag.
335func (insns Instructions) Format(f fmt.State, c rune) {
336	if c != 's' && c != 'v' {
337		fmt.Fprintf(f, "{UNKNOWN FORMAT '%c'}", c)
338		return
339	}
340
341	// Precision is better in this case, because it allows
342	// specifying 0 padding easily.
343	padding, ok := f.Precision()
344	if !ok {
345		padding = 1
346	}
347
348	indent := strings.Repeat("\t", padding)
349	if f.Flag(' ') {
350		indent = strings.Repeat(" ", padding)
351	}
352
353	symPadding, ok := f.Width()
354	if !ok {
355		symPadding = padding - 1
356	}
357	if symPadding < 0 {
358		symPadding = 0
359	}
360
361	symIndent := strings.Repeat("\t", symPadding)
362	if f.Flag(' ') {
363		symIndent = strings.Repeat(" ", symPadding)
364	}
365
366	// Guess how many digits we need at most, by assuming that all instructions
367	// are double wide.
368	highestOffset := len(insns) * 2
369	offsetWidth := int(math.Ceil(math.Log10(float64(highestOffset))))
370
371	iter := insns.Iterate()
372	for iter.Next() {
373		if iter.Ins.Symbol != "" {
374			fmt.Fprintf(f, "%s%s:\n", symIndent, iter.Ins.Symbol)
375		}
376		fmt.Fprintf(f, "%s%*d: %v\n", indent, offsetWidth, iter.Offset, iter.Ins)
377	}
378
379	return
380}
381
382// Marshal encodes a BPF program into the kernel format.
383func (insns Instructions) Marshal(w io.Writer, bo binary.ByteOrder) error {
384	for i, ins := range insns {
385		_, err := ins.Marshal(w, bo)
386		if err != nil {
387			return fmt.Errorf("instruction %d: %w", i, err)
388		}
389	}
390	return nil
391}
392
393// Iterate allows iterating a BPF program while keeping track of
394// various offsets.
395//
396// Modifying the instruction slice will lead to undefined behaviour.
397func (insns Instructions) Iterate() *InstructionIterator {
398	return &InstructionIterator{insns: insns}
399}
400
// InstructionIterator iterates over a BPF program.
//
// Obtain one from Instructions.Iterate and advance it with Next. Ins is nil
// until Next has returned true for the first time.
type InstructionIterator struct {
	// The instructions not yet visited; Next consumes them from the front.
	insns Instructions
	// The instruction in question.
	Ins *Instruction
	// The index of the instruction in the original instruction slice.
	Index int
	// The offset of the instruction in raw BPF instructions. This accounts
	// for double-wide instructions.
	Offset RawInstructionOffset
}
412
413// Next returns true as long as there are any instructions remaining.
414func (iter *InstructionIterator) Next() bool {
415	if len(iter.insns) == 0 {
416		return false
417	}
418
419	if iter.Ins != nil {
420		iter.Offset += RawInstructionOffset(iter.Ins.OpCode.rawInstructions())
421	}
422	iter.Ins = &iter.insns[0]
423	iter.insns = iter.insns[1:]
424	return true
425}
426
// bpfInstruction is the raw, fixed-size encoding of one BPF instruction, read
// and written directly with encoding/binary.
//
// Field order and types define the wire layout, so they must not change.
// Instruction.Marshal also builds this struct with a positional literal.
// NOTE(review): assumes OpCode is one byte so the struct totals
// InstructionSize (8) bytes; TODO confirm against the OpCode definition.
type bpfInstruction struct {
	OpCode    OpCode
	// Destination and source register numbers packed into one byte.
	Registers bpfRegisters
	Offset    int16
	// 32-bit immediate; for 64-bit loads, one half of the full constant.
	Constant  int32
}
433
// bpfRegisters holds the destination and source register numbers of a raw
// instruction, one per 4-bit nibble. Which nibble holds which register
// depends on the byte order (see newBPFRegisters and bpfRegisters.Unmarshal).
type bpfRegisters uint8
435
436func newBPFRegisters(dst, src Register, bo binary.ByteOrder) (bpfRegisters, error) {
437	switch bo {
438	case binary.LittleEndian:
439		return bpfRegisters((src << 4) | (dst & 0xF)), nil
440	case binary.BigEndian:
441		return bpfRegisters((dst << 4) | (src & 0xF)), nil
442	default:
443		return 0, fmt.Errorf("unrecognized ByteOrder %T", bo)
444	}
445}
446
447func (r bpfRegisters) Unmarshal(bo binary.ByteOrder) (dst, src Register, err error) {
448	switch bo {
449	case binary.LittleEndian:
450		return Register(r & 0xF), Register(r >> 4), nil
451	case binary.BigEndian:
452		return Register(r >> 4), Register(r & 0xf), nil
453	default:
454		return 0, 0, fmt.Errorf("unrecognized ByteOrder %T", bo)
455	}
456}
457
// unreferencedSymbolError is returned by Instructions.RewriteMapPtr when no
// instruction references the requested symbol.
type unreferencedSymbolError struct {
	symbol string
}

func (use *unreferencedSymbolError) Error() string {
	return fmt.Sprintf("unreferenced symbol %s", use.symbol)
}

// IsUnreferencedSymbol returns true if err, or any error it wraps, was
// caused by an unreferenced symbol.
func IsUnreferencedSymbol(err error) bool {
	// errors.As also matches errors wrapped with fmt.Errorf("...: %w", err),
	// which a plain type assertion would miss. A nil err yields false.
	var use *unreferencedSymbolError
	return errors.As(err, &use)
}
472