1package asm
2
3import (
4	"encoding/binary"
5	"errors"
6	"fmt"
7	"io"
8	"math"
9	"strings"
10)
11
// InstructionSize is the size of a single encoded BPF instruction in bytes.
//
// Instruction.Unmarshal and Instruction.Marshal report sizes as multiples of
// this value: a 64-bit immediate load occupies two such slots.
const InstructionSize = 8
14
// Instruction is a single eBPF instruction.
//
// Fields:
//   - OpCode selects the operation; Marshal rejects InvalidOpCode.
//   - Dst and Src are the destination and source registers. For 64-bit
//     immediate loads Src may instead hold PseudoMapFD or PseudoMapValue to
//     mark the instruction as a load from a map.
//   - Offset is the signed 16-bit offset. For jumps, a value of -1 is treated
//     by Instructions.Marshal as a placeholder to be resolved via Reference.
//   - Constant is the immediate. Only the low 32 bits are encoded, except for
//     64-bit immediate loads, which encode all 64 bits. For map loads the low
//     32 bits hold the map fd and the high 32 bits the offset of a direct load.
//   - Reference names the symbol this instruction refers to, if any.
//   - Symbol is the label attached to this instruction, if any.
type Instruction struct {
	OpCode    OpCode
	Dst       Register
	Src       Register
	Offset    int16
	Constant  int64
	Reference string
	Symbol    string
}
25
26// Sym creates a symbol.
27func (ins Instruction) Sym(name string) Instruction {
28	ins.Symbol = name
29	return ins
30}
31
32// Unmarshal decodes a BPF instruction.
33func (ins *Instruction) Unmarshal(r io.Reader, bo binary.ByteOrder) (uint64, error) {
34	var bi bpfInstruction
35	err := binary.Read(r, bo, &bi)
36	if err != nil {
37		return 0, err
38	}
39
40	ins.OpCode = bi.OpCode
41	ins.Offset = bi.Offset
42	ins.Constant = int64(bi.Constant)
43	ins.Dst, ins.Src, err = bi.Registers.Unmarshal(bo)
44	if err != nil {
45		return 0, fmt.Errorf("can't unmarshal registers: %s", err)
46	}
47
48	if !bi.OpCode.isDWordLoad() {
49		return InstructionSize, nil
50	}
51
52	var bi2 bpfInstruction
53	if err := binary.Read(r, bo, &bi2); err != nil {
54		// No Wrap, to avoid io.EOF clash
55		return 0, errors.New("64bit immediate is missing second half")
56	}
57	if bi2.OpCode != 0 || bi2.Offset != 0 || bi2.Registers != 0 {
58		return 0, errors.New("64bit immediate has non-zero fields")
59	}
60	ins.Constant = int64(uint64(uint32(bi2.Constant))<<32 | uint64(uint32(bi.Constant)))
61
62	return 2 * InstructionSize, nil
63}
64
65// Marshal encodes a BPF instruction.
66func (ins Instruction) Marshal(w io.Writer, bo binary.ByteOrder) (uint64, error) {
67	if ins.OpCode == InvalidOpCode {
68		return 0, errors.New("invalid opcode")
69	}
70
71	isDWordLoad := ins.OpCode.isDWordLoad()
72
73	cons := int32(ins.Constant)
74	if isDWordLoad {
75		// Encode least significant 32bit first for 64bit operations.
76		cons = int32(uint32(ins.Constant))
77	}
78
79	regs, err := newBPFRegisters(ins.Dst, ins.Src, bo)
80	if err != nil {
81		return 0, fmt.Errorf("can't marshal registers: %s", err)
82	}
83
84	bpfi := bpfInstruction{
85		ins.OpCode,
86		regs,
87		ins.Offset,
88		cons,
89	}
90
91	if err := binary.Write(w, bo, &bpfi); err != nil {
92		return 0, err
93	}
94
95	if !isDWordLoad {
96		return InstructionSize, nil
97	}
98
99	bpfi = bpfInstruction{
100		Constant: int32(ins.Constant >> 32),
101	}
102
103	if err := binary.Write(w, bo, &bpfi); err != nil {
104		return 0, err
105	}
106
107	return 2 * InstructionSize, nil
108}
109
110// RewriteMapPtr changes an instruction to use a new map fd.
111//
112// Returns an error if the instruction doesn't load a map.
113func (ins *Instruction) RewriteMapPtr(fd int) error {
114	if !ins.OpCode.isDWordLoad() {
115		return fmt.Errorf("%s is not a 64 bit load", ins.OpCode)
116	}
117
118	if ins.Src != PseudoMapFD && ins.Src != PseudoMapValue {
119		return errors.New("not a load from a map")
120	}
121
122	// Preserve the offset value for direct map loads.
123	offset := uint64(ins.Constant) & (math.MaxUint32 << 32)
124	rawFd := uint64(uint32(fd))
125	ins.Constant = int64(offset | rawFd)
126	return nil
127}
128
129func (ins *Instruction) mapPtr() uint32 {
130	return uint32(uint64(ins.Constant) & math.MaxUint32)
131}
132
133// RewriteMapOffset changes the offset of a direct load from a map.
134//
135// Returns an error if the instruction is not a direct load.
136func (ins *Instruction) RewriteMapOffset(offset uint32) error {
137	if !ins.OpCode.isDWordLoad() {
138		return fmt.Errorf("%s is not a 64 bit load", ins.OpCode)
139	}
140
141	if ins.Src != PseudoMapValue {
142		return errors.New("not a direct load from a map")
143	}
144
145	fd := uint64(ins.Constant) & math.MaxUint32
146	ins.Constant = int64(uint64(offset)<<32 | fd)
147	return nil
148}
149
150func (ins *Instruction) mapOffset() uint32 {
151	return uint32(uint64(ins.Constant) >> 32)
152}
153
154func (ins *Instruction) isLoadFromMap() bool {
155	return ins.OpCode == LoadImmOp(DWord) && (ins.Src == PseudoMapFD || ins.Src == PseudoMapValue)
156}
157
158// Format implements fmt.Formatter.
159func (ins Instruction) Format(f fmt.State, c rune) {
160	if c != 'v' {
161		fmt.Fprintf(f, "{UNRECOGNIZED: %c}", c)
162		return
163	}
164
165	op := ins.OpCode
166
167	if op == InvalidOpCode {
168		fmt.Fprint(f, "INVALID")
169		return
170	}
171
172	// Omit trailing space for Exit
173	if op.JumpOp() == Exit {
174		fmt.Fprint(f, op)
175		return
176	}
177
178	if ins.isLoadFromMap() {
179		fd := int32(ins.mapPtr())
180		switch ins.Src {
181		case PseudoMapFD:
182			fmt.Fprintf(f, "LoadMapPtr dst: %s fd: %d", ins.Dst, fd)
183
184		case PseudoMapValue:
185			fmt.Fprintf(f, "LoadMapValue dst: %s, fd: %d off: %d", ins.Dst, fd, ins.mapOffset())
186		}
187
188		goto ref
189	}
190
191	fmt.Fprintf(f, "%v ", op)
192	switch cls := op.Class(); cls {
193	case LdClass, LdXClass, StClass, StXClass:
194		switch op.Mode() {
195		case ImmMode:
196			fmt.Fprintf(f, "dst: %s imm: %d", ins.Dst, ins.Constant)
197		case AbsMode:
198			fmt.Fprintf(f, "imm: %d", ins.Constant)
199		case IndMode:
200			fmt.Fprintf(f, "dst: %s src: %s imm: %d", ins.Dst, ins.Src, ins.Constant)
201		case MemMode:
202			fmt.Fprintf(f, "dst: %s src: %s off: %d imm: %d", ins.Dst, ins.Src, ins.Offset, ins.Constant)
203		case XAddMode:
204			fmt.Fprintf(f, "dst: %s src: %s", ins.Dst, ins.Src)
205		}
206
207	case ALU64Class, ALUClass:
208		fmt.Fprintf(f, "dst: %s ", ins.Dst)
209		if op.ALUOp() == Swap || op.Source() == ImmSource {
210			fmt.Fprintf(f, "imm: %d", ins.Constant)
211		} else {
212			fmt.Fprintf(f, "src: %s", ins.Src)
213		}
214
215	case JumpClass:
216		switch jop := op.JumpOp(); jop {
217		case Call:
218			if ins.Src == PseudoCall {
219				// bpf-to-bpf call
220				fmt.Fprint(f, ins.Constant)
221			} else {
222				fmt.Fprint(f, BuiltinFunc(ins.Constant))
223			}
224
225		default:
226			fmt.Fprintf(f, "dst: %s off: %d ", ins.Dst, ins.Offset)
227			if op.Source() == ImmSource {
228				fmt.Fprintf(f, "imm: %d", ins.Constant)
229			} else {
230				fmt.Fprintf(f, "src: %s", ins.Src)
231			}
232		}
233	}
234
235ref:
236	if ins.Reference != "" {
237		fmt.Fprintf(f, " <%s>", ins.Reference)
238	}
239}
240
// Instructions is an eBPF program.
//
// It is an ordered list of instructions. SymbolOffsets and ReferenceOffsets
// report positions as indices into this slice, whereas Marshal resolves
// jumps and calls in units of encoded instruction slots, where a single
// instruction may take up more than one slot.
type Instructions []Instruction
243
244func (insns Instructions) String() string {
245	return fmt.Sprint(insns)
246}
247
248// RewriteMapPtr rewrites all loads of a specific map pointer to a new fd.
249//
250// Returns an error if the symbol isn't used, see IsUnreferencedSymbol.
251func (insns Instructions) RewriteMapPtr(symbol string, fd int) error {
252	if symbol == "" {
253		return errors.New("empty symbol")
254	}
255
256	found := false
257	for i := range insns {
258		ins := &insns[i]
259		if ins.Reference != symbol {
260			continue
261		}
262
263		if err := ins.RewriteMapPtr(fd); err != nil {
264			return err
265		}
266
267		found = true
268	}
269
270	if !found {
271		return &unreferencedSymbolError{symbol}
272	}
273
274	return nil
275}
276
277// SymbolOffsets returns the set of symbols and their offset in
278// the instructions.
279func (insns Instructions) SymbolOffsets() (map[string]int, error) {
280	offsets := make(map[string]int)
281
282	for i, ins := range insns {
283		if ins.Symbol == "" {
284			continue
285		}
286
287		if _, ok := offsets[ins.Symbol]; ok {
288			return nil, fmt.Errorf("duplicate symbol %s", ins.Symbol)
289		}
290
291		offsets[ins.Symbol] = i
292	}
293
294	return offsets, nil
295}
296
297// ReferenceOffsets returns the set of references and their offset in
298// the instructions.
299func (insns Instructions) ReferenceOffsets() map[string][]int {
300	offsets := make(map[string][]int)
301
302	for i, ins := range insns {
303		if ins.Reference == "" {
304			continue
305		}
306
307		offsets[ins.Reference] = append(offsets[ins.Reference], i)
308	}
309
310	return offsets
311}
312
313func (insns Instructions) marshalledOffsets() (map[string]int, error) {
314	symbols := make(map[string]int)
315
316	marshalledPos := 0
317	for _, ins := range insns {
318		currentPos := marshalledPos
319		marshalledPos += ins.OpCode.marshalledInstructions()
320
321		if ins.Symbol == "" {
322			continue
323		}
324
325		if _, ok := symbols[ins.Symbol]; ok {
326			return nil, fmt.Errorf("duplicate symbol %s", ins.Symbol)
327		}
328
329		symbols[ins.Symbol] = currentPos
330	}
331
332	return symbols, nil
333}
334
335// Format implements fmt.Formatter.
336//
337// You can control indentation of symbols by
338// specifying a width. Setting a precision controls the indentation of
339// instructions.
340// The default character is a tab, which can be overriden by specifying
341// the ' ' space flag.
342func (insns Instructions) Format(f fmt.State, c rune) {
343	if c != 's' && c != 'v' {
344		fmt.Fprintf(f, "{UNKNOWN FORMAT '%c'}", c)
345		return
346	}
347
348	// Precision is better in this case, because it allows
349	// specifying 0 padding easily.
350	padding, ok := f.Precision()
351	if !ok {
352		padding = 1
353	}
354
355	indent := strings.Repeat("\t", padding)
356	if f.Flag(' ') {
357		indent = strings.Repeat(" ", padding)
358	}
359
360	symPadding, ok := f.Width()
361	if !ok {
362		symPadding = padding - 1
363	}
364	if symPadding < 0 {
365		symPadding = 0
366	}
367
368	symIndent := strings.Repeat("\t", symPadding)
369	if f.Flag(' ') {
370		symIndent = strings.Repeat(" ", symPadding)
371	}
372
373	// Figure out how many digits we need to represent the highest
374	// offset.
375	highestOffset := 0
376	for _, ins := range insns {
377		highestOffset += ins.OpCode.marshalledInstructions()
378	}
379	offsetWidth := int(math.Ceil(math.Log10(float64(highestOffset))))
380
381	offset := 0
382	for _, ins := range insns {
383		if ins.Symbol != "" {
384			fmt.Fprintf(f, "%s%s:\n", symIndent, ins.Symbol)
385		}
386		fmt.Fprintf(f, "%s%*d: %v\n", indent, offsetWidth, offset, ins)
387		offset += ins.OpCode.marshalledInstructions()
388	}
389
390	return
391}
392
393// Marshal encodes a BPF program into the kernel format.
394func (insns Instructions) Marshal(w io.Writer, bo binary.ByteOrder) error {
395	absoluteOffsets, err := insns.marshalledOffsets()
396	if err != nil {
397		return err
398	}
399
400	num := 0
401	for i, ins := range insns {
402		switch {
403		case ins.OpCode.JumpOp() == Call && ins.Src == PseudoCall && ins.Constant == -1:
404			// Rewrite bpf to bpf call
405			offset, ok := absoluteOffsets[ins.Reference]
406			if !ok {
407				return fmt.Errorf("instruction %d: reference to missing symbol %s", i, ins.Reference)
408			}
409
410			ins.Constant = int64(offset - num - 1)
411
412		case ins.OpCode.Class() == JumpClass && ins.Offset == -1:
413			// Rewrite jump to label
414			offset, ok := absoluteOffsets[ins.Reference]
415			if !ok {
416				return fmt.Errorf("instruction %d: reference to missing symbol %s", i, ins.Reference)
417			}
418
419			ins.Offset = int16(offset - num - 1)
420		}
421
422		n, err := ins.Marshal(w, bo)
423		if err != nil {
424			return fmt.Errorf("instruction %d: %w", i, err)
425		}
426
427		num += int(n / InstructionSize)
428	}
429	return nil
430}
431
// bpfInstruction is the representation of a single instruction slot that is
// read and written with encoding/binary.
//
// The field order defines the encoded layout, and Instruction.Marshal
// builds the value positionally, so fields must not be reordered.
// A 64-bit immediate load is encoded as two of these: the second one has
// all fields zero except Constant, which holds the upper 32 bits.
type bpfInstruction struct {
	OpCode    OpCode
	Registers bpfRegisters
	Offset    int16
	Constant  int32
}
438
// bpfRegisters packs the destination and source registers into a single
// byte, four bits each. Which half holds which register depends on the byte
// order, see newBPFRegisters and bpfRegisters.Unmarshal.
type bpfRegisters uint8
440
441func newBPFRegisters(dst, src Register, bo binary.ByteOrder) (bpfRegisters, error) {
442	switch bo {
443	case binary.LittleEndian:
444		return bpfRegisters((src << 4) | (dst & 0xF)), nil
445	case binary.BigEndian:
446		return bpfRegisters((dst << 4) | (src & 0xF)), nil
447	default:
448		return 0, fmt.Errorf("unrecognized ByteOrder %T", bo)
449	}
450}
451
452func (r bpfRegisters) Unmarshal(bo binary.ByteOrder) (dst, src Register, err error) {
453	switch bo {
454	case binary.LittleEndian:
455		return Register(r & 0xF), Register(r >> 4), nil
456	case binary.BigEndian:
457		return Register(r >> 4), Register(r & 0xf), nil
458	default:
459		return 0, 0, fmt.Errorf("unrecognized ByteOrder %T", bo)
460	}
461}
462
// unreferencedSymbolError is returned when a symbol that should be rewritten
// isn't referenced by any instruction.
type unreferencedSymbolError struct {
	symbol string
}

// Error implements the error interface.
func (use *unreferencedSymbolError) Error() string {
	return fmt.Sprintf("unreferenced symbol %s", use.symbol)
}

// IsUnreferencedSymbol returns true if err was caused by
// an unreferenced symbol.
//
// The whole error chain is inspected, so the result is also true if the
// error has been wrapped, for example with fmt.Errorf and %w. A nil error
// yields false.
func IsUnreferencedSymbol(err error) bool {
	var use *unreferencedSymbolError
	return errors.As(err, &use)
}
477