xref: /linux/arch/riscv/net/bpf_jit_comp64.c (revision 6801b0ae)
1 // SPDX-License-Identifier: GPL-2.0
2 /* BPF JIT compiler for RV64G
3  *
4  * Copyright(c) 2019 Björn Töpel <bjorn.topel@gmail.com>
5  *
6  */
7 
8 #include <linux/bitfield.h>
9 #include <linux/bpf.h>
10 #include <linux/filter.h>
11 #include <linux/memory.h>
12 #include <linux/stop_machine.h>
13 #include <asm/patch.h>
14 #include <asm/cfi.h>
15 #include <asm/percpu.h>
16 #include "bpf_jit.h"
17 
18 #define RV_MAX_REG_ARGS 8
19 #define RV_FENTRY_NINSNS 2
20 /* imm that allows emit_imm to emit max count insns */
21 #define RV_MAX_COUNT_IMM 0x7FFF7FF7FF7FF7FF
22 
23 #define RV_REG_TCC RV_REG_A6
24 #define RV_REG_TCC_SAVED RV_REG_S6 /* Store A6 in S6 if program do calls */
25 #define RV_REG_ARENA RV_REG_S7 /* For storing arena_vm_start */
26 
27 static const int regmap[] = {
28 	[BPF_REG_0] =	RV_REG_A5,
29 	[BPF_REG_1] =	RV_REG_A0,
30 	[BPF_REG_2] =	RV_REG_A1,
31 	[BPF_REG_3] =	RV_REG_A2,
32 	[BPF_REG_4] =	RV_REG_A3,
33 	[BPF_REG_5] =	RV_REG_A4,
34 	[BPF_REG_6] =	RV_REG_S1,
35 	[BPF_REG_7] =	RV_REG_S2,
36 	[BPF_REG_8] =	RV_REG_S3,
37 	[BPF_REG_9] =	RV_REG_S4,
38 	[BPF_REG_FP] =	RV_REG_S5,
39 	[BPF_REG_AX] =	RV_REG_T0,
40 };
41 
42 static const int pt_regmap[] = {
43 	[RV_REG_A0] = offsetof(struct pt_regs, a0),
44 	[RV_REG_A1] = offsetof(struct pt_regs, a1),
45 	[RV_REG_A2] = offsetof(struct pt_regs, a2),
46 	[RV_REG_A3] = offsetof(struct pt_regs, a3),
47 	[RV_REG_A4] = offsetof(struct pt_regs, a4),
48 	[RV_REG_A5] = offsetof(struct pt_regs, a5),
49 	[RV_REG_S1] = offsetof(struct pt_regs, s1),
50 	[RV_REG_S2] = offsetof(struct pt_regs, s2),
51 	[RV_REG_S3] = offsetof(struct pt_regs, s3),
52 	[RV_REG_S4] = offsetof(struct pt_regs, s4),
53 	[RV_REG_S5] = offsetof(struct pt_regs, s5),
54 	[RV_REG_T0] = offsetof(struct pt_regs, t0),
55 };
56 
57 enum {
58 	RV_CTX_F_SEEN_TAIL_CALL =	0,
59 	RV_CTX_F_SEEN_CALL =		RV_REG_RA,
60 	RV_CTX_F_SEEN_S1 =		RV_REG_S1,
61 	RV_CTX_F_SEEN_S2 =		RV_REG_S2,
62 	RV_CTX_F_SEEN_S3 =		RV_REG_S3,
63 	RV_CTX_F_SEEN_S4 =		RV_REG_S4,
64 	RV_CTX_F_SEEN_S5 =		RV_REG_S5,
65 	RV_CTX_F_SEEN_S6 =		RV_REG_S6,
66 };
67 
68 static u8 bpf_to_rv_reg(int bpf_reg, struct rv_jit_context *ctx)
69 {
70 	u8 reg = regmap[bpf_reg];
71 
72 	switch (reg) {
73 	case RV_CTX_F_SEEN_S1:
74 	case RV_CTX_F_SEEN_S2:
75 	case RV_CTX_F_SEEN_S3:
76 	case RV_CTX_F_SEEN_S4:
77 	case RV_CTX_F_SEEN_S5:
78 	case RV_CTX_F_SEEN_S6:
79 		__set_bit(reg, &ctx->flags);
80 	}
81 	return reg;
82 };
83 
84 static bool seen_reg(int reg, struct rv_jit_context *ctx)
85 {
86 	switch (reg) {
87 	case RV_CTX_F_SEEN_CALL:
88 	case RV_CTX_F_SEEN_S1:
89 	case RV_CTX_F_SEEN_S2:
90 	case RV_CTX_F_SEEN_S3:
91 	case RV_CTX_F_SEEN_S4:
92 	case RV_CTX_F_SEEN_S5:
93 	case RV_CTX_F_SEEN_S6:
94 		return test_bit(reg, &ctx->flags);
95 	}
96 	return false;
97 }
98 
99 static void mark_fp(struct rv_jit_context *ctx)
100 {
101 	__set_bit(RV_CTX_F_SEEN_S5, &ctx->flags);
102 }
103 
104 static void mark_call(struct rv_jit_context *ctx)
105 {
106 	__set_bit(RV_CTX_F_SEEN_CALL, &ctx->flags);
107 }
108 
109 static bool seen_call(struct rv_jit_context *ctx)
110 {
111 	return test_bit(RV_CTX_F_SEEN_CALL, &ctx->flags);
112 }
113 
114 static void mark_tail_call(struct rv_jit_context *ctx)
115 {
116 	__set_bit(RV_CTX_F_SEEN_TAIL_CALL, &ctx->flags);
117 }
118 
119 static bool seen_tail_call(struct rv_jit_context *ctx)
120 {
121 	return test_bit(RV_CTX_F_SEEN_TAIL_CALL, &ctx->flags);
122 }
123 
124 static u8 rv_tail_call_reg(struct rv_jit_context *ctx)
125 {
126 	mark_tail_call(ctx);
127 
128 	if (seen_call(ctx)) {
129 		__set_bit(RV_CTX_F_SEEN_S6, &ctx->flags);
130 		return RV_REG_S6;
131 	}
132 	return RV_REG_A6;
133 }
134 
135 static bool is_32b_int(s64 val)
136 {
137 	return -(1L << 31) <= val && val < (1L << 31);
138 }
139 
140 static bool in_auipc_jalr_range(s64 val)
141 {
142 	/*
143 	 * auipc+jalr can reach any signed PC-relative offset in the range
144 	 * [-2^31 - 2^11, 2^31 - 2^11).
145 	 */
146 	return (-(1L << 31) - (1L << 11)) <= val &&
147 		val < ((1L << 31) - (1L << 11));
148 }
149 
150 /* Modify rd pointer to alternate reg to avoid corrupting original reg */
151 static void emit_sextw_alt(u8 *rd, u8 ra, struct rv_jit_context *ctx)
152 {
153 	emit_sextw(ra, *rd, ctx);
154 	*rd = ra;
155 }
156 
157 static void emit_zextw_alt(u8 *rd, u8 ra, struct rv_jit_context *ctx)
158 {
159 	emit_zextw(ra, *rd, ctx);
160 	*rd = ra;
161 }
162 
163 /* Emit fixed-length instructions for address */
164 static int emit_addr(u8 rd, u64 addr, bool extra_pass, struct rv_jit_context *ctx)
165 {
166 	/*
167 	 * Use the ro_insns(RX) to calculate the offset as the BPF program will
168 	 * finally run from this memory region.
169 	 */
170 	u64 ip = (u64)(ctx->ro_insns + ctx->ninsns);
171 	s64 off = addr - ip;
172 	s64 upper = (off + (1 << 11)) >> 12;
173 	s64 lower = off & 0xfff;
174 
175 	if (extra_pass && !in_auipc_jalr_range(off)) {
176 		pr_err("bpf-jit: target offset 0x%llx is out of range\n", off);
177 		return -ERANGE;
178 	}
179 
180 	emit(rv_auipc(rd, upper), ctx);
181 	emit(rv_addi(rd, rd, lower), ctx);
182 	return 0;
183 }
184 
185 /* Emit variable-length instructions for 32-bit and 64-bit imm */
186 static void emit_imm(u8 rd, s64 val, struct rv_jit_context *ctx)
187 {
188 	/* Note that the immediate from the add is sign-extended,
189 	 * which means that we need to compensate this by adding 2^12,
190 	 * when the 12th bit is set. A simpler way of doing this, and
191 	 * getting rid of the check, is to just add 2**11 before the
192 	 * shift. The "Loading a 32-Bit constant" example from the
193 	 * "Computer Organization and Design, RISC-V edition" book by
194 	 * Patterson/Hennessy highlights this fact.
195 	 *
196 	 * This also means that we need to process LSB to MSB.
197 	 */
198 	s64 upper = (val + (1 << 11)) >> 12;
199 	/* Sign-extend lower 12 bits to 64 bits since immediates for li, addiw,
200 	 * and addi are signed and RVC checks will perform signed comparisons.
201 	 */
202 	s64 lower = ((val & 0xfff) << 52) >> 52;
203 	int shift;
204 
205 	if (is_32b_int(val)) {
206 		if (upper)
207 			emit_lui(rd, upper, ctx);
208 
209 		if (!upper) {
210 			emit_li(rd, lower, ctx);
211 			return;
212 		}
213 
214 		emit_addiw(rd, rd, lower, ctx);
215 		return;
216 	}
217 
218 	shift = __ffs(upper);
219 	upper >>= shift;
220 	shift += 12;
221 
222 	emit_imm(rd, upper, ctx);
223 
224 	emit_slli(rd, rd, shift, ctx);
225 	if (lower)
226 		emit_addi(rd, rd, lower, ctx);
227 }
228 
229 static void __build_epilogue(bool is_tail_call, struct rv_jit_context *ctx)
230 {
231 	int stack_adjust = ctx->stack_size, store_offset = stack_adjust - 8;
232 
233 	if (seen_reg(RV_REG_RA, ctx)) {
234 		emit_ld(RV_REG_RA, store_offset, RV_REG_SP, ctx);
235 		store_offset -= 8;
236 	}
237 	emit_ld(RV_REG_FP, store_offset, RV_REG_SP, ctx);
238 	store_offset -= 8;
239 	if (seen_reg(RV_REG_S1, ctx)) {
240 		emit_ld(RV_REG_S1, store_offset, RV_REG_SP, ctx);
241 		store_offset -= 8;
242 	}
243 	if (seen_reg(RV_REG_S2, ctx)) {
244 		emit_ld(RV_REG_S2, store_offset, RV_REG_SP, ctx);
245 		store_offset -= 8;
246 	}
247 	if (seen_reg(RV_REG_S3, ctx)) {
248 		emit_ld(RV_REG_S3, store_offset, RV_REG_SP, ctx);
249 		store_offset -= 8;
250 	}
251 	if (seen_reg(RV_REG_S4, ctx)) {
252 		emit_ld(RV_REG_S4, store_offset, RV_REG_SP, ctx);
253 		store_offset -= 8;
254 	}
255 	if (seen_reg(RV_REG_S5, ctx)) {
256 		emit_ld(RV_REG_S5, store_offset, RV_REG_SP, ctx);
257 		store_offset -= 8;
258 	}
259 	if (seen_reg(RV_REG_S6, ctx)) {
260 		emit_ld(RV_REG_S6, store_offset, RV_REG_SP, ctx);
261 		store_offset -= 8;
262 	}
263 	if (ctx->arena_vm_start) {
264 		emit_ld(RV_REG_ARENA, store_offset, RV_REG_SP, ctx);
265 		store_offset -= 8;
266 	}
267 
268 	emit_addi(RV_REG_SP, RV_REG_SP, stack_adjust, ctx);
269 	/* Set return value. */
270 	if (!is_tail_call)
271 		emit_addiw(RV_REG_A0, RV_REG_A5, 0, ctx);
272 	emit_jalr(RV_REG_ZERO, is_tail_call ? RV_REG_T3 : RV_REG_RA,
273 		  is_tail_call ? (RV_FENTRY_NINSNS + 1) * 4 : 0, /* skip reserved nops and TCC init */
274 		  ctx);
275 }
276 
277 static void emit_bcc(u8 cond, u8 rd, u8 rs, int rvoff,
278 		     struct rv_jit_context *ctx)
279 {
280 	switch (cond) {
281 	case BPF_JEQ:
282 		emit(rv_beq(rd, rs, rvoff >> 1), ctx);
283 		return;
284 	case BPF_JGT:
285 		emit(rv_bltu(rs, rd, rvoff >> 1), ctx);
286 		return;
287 	case BPF_JLT:
288 		emit(rv_bltu(rd, rs, rvoff >> 1), ctx);
289 		return;
290 	case BPF_JGE:
291 		emit(rv_bgeu(rd, rs, rvoff >> 1), ctx);
292 		return;
293 	case BPF_JLE:
294 		emit(rv_bgeu(rs, rd, rvoff >> 1), ctx);
295 		return;
296 	case BPF_JNE:
297 		emit(rv_bne(rd, rs, rvoff >> 1), ctx);
298 		return;
299 	case BPF_JSGT:
300 		emit(rv_blt(rs, rd, rvoff >> 1), ctx);
301 		return;
302 	case BPF_JSLT:
303 		emit(rv_blt(rd, rs, rvoff >> 1), ctx);
304 		return;
305 	case BPF_JSGE:
306 		emit(rv_bge(rd, rs, rvoff >> 1), ctx);
307 		return;
308 	case BPF_JSLE:
309 		emit(rv_bge(rs, rd, rvoff >> 1), ctx);
310 	}
311 }
312 
313 static void emit_branch(u8 cond, u8 rd, u8 rs, int rvoff,
314 			struct rv_jit_context *ctx)
315 {
316 	s64 upper, lower;
317 
318 	if (is_13b_int(rvoff)) {
319 		emit_bcc(cond, rd, rs, rvoff, ctx);
320 		return;
321 	}
322 
323 	/* Adjust for jal */
324 	rvoff -= 4;
325 
326 	/* Transform, e.g.:
327 	 *   bne rd,rs,foo
328 	 * to
329 	 *   beq rd,rs,<.L1>
330 	 *   (auipc foo)
331 	 *   jal(r) foo
332 	 * .L1
333 	 */
334 	cond = invert_bpf_cond(cond);
335 	if (is_21b_int(rvoff)) {
336 		emit_bcc(cond, rd, rs, 8, ctx);
337 		emit(rv_jal(RV_REG_ZERO, rvoff >> 1), ctx);
338 		return;
339 	}
340 
341 	/* 32b No need for an additional rvoff adjustment, since we
342 	 * get that from the auipc at PC', where PC = PC' + 4.
343 	 */
344 	upper = (rvoff + (1 << 11)) >> 12;
345 	lower = rvoff & 0xfff;
346 
347 	emit_bcc(cond, rd, rs, 12, ctx);
348 	emit(rv_auipc(RV_REG_T1, upper), ctx);
349 	emit(rv_jalr(RV_REG_ZERO, RV_REG_T1, lower), ctx);
350 }
351 
352 static int emit_bpf_tail_call(int insn, struct rv_jit_context *ctx)
353 {
354 	int tc_ninsn, off, start_insn = ctx->ninsns;
355 	u8 tcc = rv_tail_call_reg(ctx);
356 
357 	/* a0: &ctx
358 	 * a1: &array
359 	 * a2: index
360 	 *
361 	 * if (index >= array->map.max_entries)
362 	 *	goto out;
363 	 */
364 	tc_ninsn = insn ? ctx->offset[insn] - ctx->offset[insn - 1] :
365 		   ctx->offset[0];
366 	emit_zextw(RV_REG_A2, RV_REG_A2, ctx);
367 
368 	off = offsetof(struct bpf_array, map.max_entries);
369 	if (is_12b_check(off, insn))
370 		return -1;
371 	emit(rv_lwu(RV_REG_T1, off, RV_REG_A1), ctx);
372 	off = ninsns_rvoff(tc_ninsn - (ctx->ninsns - start_insn));
373 	emit_branch(BPF_JGE, RV_REG_A2, RV_REG_T1, off, ctx);
374 
375 	/* if (--TCC < 0)
376 	 *     goto out;
377 	 */
378 	emit_addi(RV_REG_TCC, tcc, -1, ctx);
379 	off = ninsns_rvoff(tc_ninsn - (ctx->ninsns - start_insn));
380 	emit_branch(BPF_JSLT, RV_REG_TCC, RV_REG_ZERO, off, ctx);
381 
382 	/* prog = array->ptrs[index];
383 	 * if (!prog)
384 	 *     goto out;
385 	 */
386 	emit_sh3add(RV_REG_T2, RV_REG_A2, RV_REG_A1, ctx);
387 	off = offsetof(struct bpf_array, ptrs);
388 	if (is_12b_check(off, insn))
389 		return -1;
390 	emit_ld(RV_REG_T2, off, RV_REG_T2, ctx);
391 	off = ninsns_rvoff(tc_ninsn - (ctx->ninsns - start_insn));
392 	emit_branch(BPF_JEQ, RV_REG_T2, RV_REG_ZERO, off, ctx);
393 
394 	/* goto *(prog->bpf_func + 4); */
395 	off = offsetof(struct bpf_prog, bpf_func);
396 	if (is_12b_check(off, insn))
397 		return -1;
398 	emit_ld(RV_REG_T3, off, RV_REG_T2, ctx);
399 	__build_epilogue(true, ctx);
400 	return 0;
401 }
402 
403 static void init_regs(u8 *rd, u8 *rs, const struct bpf_insn *insn,
404 		      struct rv_jit_context *ctx)
405 {
406 	u8 code = insn->code;
407 
408 	switch (code) {
409 	case BPF_JMP | BPF_JA:
410 	case BPF_JMP | BPF_CALL:
411 	case BPF_JMP | BPF_EXIT:
412 	case BPF_JMP | BPF_TAIL_CALL:
413 		break;
414 	default:
415 		*rd = bpf_to_rv_reg(insn->dst_reg, ctx);
416 	}
417 
418 	if (code & (BPF_ALU | BPF_X) || code & (BPF_ALU64 | BPF_X) ||
419 	    code & (BPF_JMP | BPF_X) || code & (BPF_JMP32 | BPF_X) ||
420 	    code & BPF_LDX || code & BPF_STX)
421 		*rs = bpf_to_rv_reg(insn->src_reg, ctx);
422 }
423 
424 static int emit_jump_and_link(u8 rd, s64 rvoff, bool fixed_addr,
425 			      struct rv_jit_context *ctx)
426 {
427 	s64 upper, lower;
428 
429 	if (rvoff && fixed_addr && is_21b_int(rvoff)) {
430 		emit(rv_jal(rd, rvoff >> 1), ctx);
431 		return 0;
432 	} else if (in_auipc_jalr_range(rvoff)) {
433 		upper = (rvoff + (1 << 11)) >> 12;
434 		lower = rvoff & 0xfff;
435 		emit(rv_auipc(RV_REG_T1, upper), ctx);
436 		emit(rv_jalr(rd, RV_REG_T1, lower), ctx);
437 		return 0;
438 	}
439 
440 	pr_err("bpf-jit: target offset 0x%llx is out of range\n", rvoff);
441 	return -ERANGE;
442 }
443 
444 static bool is_signed_bpf_cond(u8 cond)
445 {
446 	return cond == BPF_JSGT || cond == BPF_JSLT ||
447 		cond == BPF_JSGE || cond == BPF_JSLE;
448 }
449 
450 static int emit_call(u64 addr, bool fixed_addr, struct rv_jit_context *ctx)
451 {
452 	s64 off = 0;
453 	u64 ip;
454 
455 	if (addr && ctx->insns && ctx->ro_insns) {
456 		/*
457 		 * Use the ro_insns(RX) to calculate the offset as the BPF
458 		 * program will finally run from this memory region.
459 		 */
460 		ip = (u64)(long)(ctx->ro_insns + ctx->ninsns);
461 		off = addr - ip;
462 	}
463 
464 	return emit_jump_and_link(RV_REG_RA, off, fixed_addr, ctx);
465 }
466 
467 static inline void emit_kcfi(u32 hash, struct rv_jit_context *ctx)
468 {
469 	if (IS_ENABLED(CONFIG_CFI_CLANG))
470 		emit(hash, ctx);
471 }
472 
473 static void emit_atomic(u8 rd, u8 rs, s16 off, s32 imm, bool is64,
474 			struct rv_jit_context *ctx)
475 {
476 	u8 r0;
477 	int jmp_offset;
478 
479 	if (off) {
480 		if (is_12b_int(off)) {
481 			emit_addi(RV_REG_T1, rd, off, ctx);
482 		} else {
483 			emit_imm(RV_REG_T1, off, ctx);
484 			emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
485 		}
486 		rd = RV_REG_T1;
487 	}
488 
489 	switch (imm) {
490 	/* lock *(u32/u64 *)(dst_reg + off16) <op>= src_reg */
491 	case BPF_ADD:
492 		emit(is64 ? rv_amoadd_d(RV_REG_ZERO, rs, rd, 0, 0) :
493 		     rv_amoadd_w(RV_REG_ZERO, rs, rd, 0, 0), ctx);
494 		break;
495 	case BPF_AND:
496 		emit(is64 ? rv_amoand_d(RV_REG_ZERO, rs, rd, 0, 0) :
497 		     rv_amoand_w(RV_REG_ZERO, rs, rd, 0, 0), ctx);
498 		break;
499 	case BPF_OR:
500 		emit(is64 ? rv_amoor_d(RV_REG_ZERO, rs, rd, 0, 0) :
501 		     rv_amoor_w(RV_REG_ZERO, rs, rd, 0, 0), ctx);
502 		break;
503 	case BPF_XOR:
504 		emit(is64 ? rv_amoxor_d(RV_REG_ZERO, rs, rd, 0, 0) :
505 		     rv_amoxor_w(RV_REG_ZERO, rs, rd, 0, 0), ctx);
506 		break;
507 	/* src_reg = atomic_fetch_<op>(dst_reg + off16, src_reg) */
508 	case BPF_ADD | BPF_FETCH:
509 		emit(is64 ? rv_amoadd_d(rs, rs, rd, 1, 1) :
510 		     rv_amoadd_w(rs, rs, rd, 1, 1), ctx);
511 		if (!is64)
512 			emit_zextw(rs, rs, ctx);
513 		break;
514 	case BPF_AND | BPF_FETCH:
515 		emit(is64 ? rv_amoand_d(rs, rs, rd, 1, 1) :
516 		     rv_amoand_w(rs, rs, rd, 1, 1), ctx);
517 		if (!is64)
518 			emit_zextw(rs, rs, ctx);
519 		break;
520 	case BPF_OR | BPF_FETCH:
521 		emit(is64 ? rv_amoor_d(rs, rs, rd, 1, 1) :
522 		     rv_amoor_w(rs, rs, rd, 1, 1), ctx);
523 		if (!is64)
524 			emit_zextw(rs, rs, ctx);
525 		break;
526 	case BPF_XOR | BPF_FETCH:
527 		emit(is64 ? rv_amoxor_d(rs, rs, rd, 1, 1) :
528 		     rv_amoxor_w(rs, rs, rd, 1, 1), ctx);
529 		if (!is64)
530 			emit_zextw(rs, rs, ctx);
531 		break;
532 	/* src_reg = atomic_xchg(dst_reg + off16, src_reg); */
533 	case BPF_XCHG:
534 		emit(is64 ? rv_amoswap_d(rs, rs, rd, 1, 1) :
535 		     rv_amoswap_w(rs, rs, rd, 1, 1), ctx);
536 		if (!is64)
537 			emit_zextw(rs, rs, ctx);
538 		break;
539 	/* r0 = atomic_cmpxchg(dst_reg + off16, r0, src_reg); */
540 	case BPF_CMPXCHG:
541 		r0 = bpf_to_rv_reg(BPF_REG_0, ctx);
542 		if (is64)
543 			emit_mv(RV_REG_T2, r0, ctx);
544 		else
545 			emit_addiw(RV_REG_T2, r0, 0, ctx);
546 		emit(is64 ? rv_lr_d(r0, 0, rd, 0, 0) :
547 		     rv_lr_w(r0, 0, rd, 0, 0), ctx);
548 		jmp_offset = ninsns_rvoff(8);
549 		emit(rv_bne(RV_REG_T2, r0, jmp_offset >> 1), ctx);
550 		emit(is64 ? rv_sc_d(RV_REG_T3, rs, rd, 0, 0) :
551 		     rv_sc_w(RV_REG_T3, rs, rd, 0, 0), ctx);
552 		jmp_offset = ninsns_rvoff(-6);
553 		emit(rv_bne(RV_REG_T3, 0, jmp_offset >> 1), ctx);
554 		emit(rv_fence(0x3, 0x3), ctx);
555 		break;
556 	}
557 }
558 
559 #define BPF_FIXUP_OFFSET_MASK   GENMASK(26, 0)
560 #define BPF_FIXUP_REG_MASK      GENMASK(31, 27)
561 #define REG_DONT_CLEAR_MARKER	0	/* RV_REG_ZERO unused in pt_regmap */
562 
563 bool ex_handler_bpf(const struct exception_table_entry *ex,
564 		    struct pt_regs *regs)
565 {
566 	off_t offset = FIELD_GET(BPF_FIXUP_OFFSET_MASK, ex->fixup);
567 	int regs_offset = FIELD_GET(BPF_FIXUP_REG_MASK, ex->fixup);
568 
569 	if (regs_offset != REG_DONT_CLEAR_MARKER)
570 		*(unsigned long *)((void *)regs + pt_regmap[regs_offset]) = 0;
571 	regs->epc = (unsigned long)&ex->fixup - offset;
572 
573 	return true;
574 }
575 
576 /* For accesses to BTF pointers, add an entry to the exception table */
577 static int add_exception_handler(const struct bpf_insn *insn,
578 				 struct rv_jit_context *ctx,
579 				 int dst_reg, int insn_len)
580 {
581 	struct exception_table_entry *ex;
582 	unsigned long pc;
583 	off_t ins_offset;
584 	off_t fixup_offset;
585 
586 	if (!ctx->insns || !ctx->ro_insns || !ctx->prog->aux->extable ||
587 	    (BPF_MODE(insn->code) != BPF_PROBE_MEM && BPF_MODE(insn->code) != BPF_PROBE_MEMSX &&
588 	     BPF_MODE(insn->code) != BPF_PROBE_MEM32))
589 		return 0;
590 
591 	if (WARN_ON_ONCE(ctx->nexentries >= ctx->prog->aux->num_exentries))
592 		return -EINVAL;
593 
594 	if (WARN_ON_ONCE(insn_len > ctx->ninsns))
595 		return -EINVAL;
596 
597 	if (WARN_ON_ONCE(!rvc_enabled() && insn_len == 1))
598 		return -EINVAL;
599 
600 	ex = &ctx->prog->aux->extable[ctx->nexentries];
601 	pc = (unsigned long)&ctx->ro_insns[ctx->ninsns - insn_len];
602 
603 	/*
604 	 * This is the relative offset of the instruction that may fault from
605 	 * the exception table itself. This will be written to the exception
606 	 * table and if this instruction faults, the destination register will
607 	 * be set to '0' and the execution will jump to the next instruction.
608 	 */
609 	ins_offset = pc - (long)&ex->insn;
610 	if (WARN_ON_ONCE(ins_offset >= 0 || ins_offset < INT_MIN))
611 		return -ERANGE;
612 
613 	/*
614 	 * Since the extable follows the program, the fixup offset is always
615 	 * negative and limited to BPF_JIT_REGION_SIZE. Store a positive value
616 	 * to keep things simple, and put the destination register in the upper
617 	 * bits. We don't need to worry about buildtime or runtime sort
618 	 * modifying the upper bits because the table is already sorted, and
619 	 * isn't part of the main exception table.
620 	 *
621 	 * The fixup_offset is set to the next instruction from the instruction
622 	 * that may fault. The execution will jump to this after handling the
623 	 * fault.
624 	 */
625 	fixup_offset = (long)&ex->fixup - (pc + insn_len * sizeof(u16));
626 	if (!FIELD_FIT(BPF_FIXUP_OFFSET_MASK, fixup_offset))
627 		return -ERANGE;
628 
629 	/*
630 	 * The offsets above have been calculated using the RO buffer but we
631 	 * need to use the R/W buffer for writes.
632 	 * switch ex to rw buffer for writing.
633 	 */
634 	ex = (void *)ctx->insns + ((void *)ex - (void *)ctx->ro_insns);
635 
636 	ex->insn = ins_offset;
637 
638 	ex->fixup = FIELD_PREP(BPF_FIXUP_OFFSET_MASK, fixup_offset) |
639 		FIELD_PREP(BPF_FIXUP_REG_MASK, dst_reg);
640 	ex->type = EX_TYPE_BPF;
641 
642 	ctx->nexentries++;
643 	return 0;
644 }
645 
646 static int gen_jump_or_nops(void *target, void *ip, u32 *insns, bool is_call)
647 {
648 	s64 rvoff;
649 	struct rv_jit_context ctx;
650 
651 	ctx.ninsns = 0;
652 	ctx.insns = (u16 *)insns;
653 
654 	if (!target) {
655 		emit(rv_nop(), &ctx);
656 		emit(rv_nop(), &ctx);
657 		return 0;
658 	}
659 
660 	rvoff = (s64)(target - ip);
661 	return emit_jump_and_link(is_call ? RV_REG_T0 : RV_REG_ZERO, rvoff, false, &ctx);
662 }
663 
664 int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type poke_type,
665 		       void *old_addr, void *new_addr)
666 {
667 	u32 old_insns[RV_FENTRY_NINSNS], new_insns[RV_FENTRY_NINSNS];
668 	bool is_call = poke_type == BPF_MOD_CALL;
669 	int ret;
670 
671 	if (!is_kernel_text((unsigned long)ip) &&
672 	    !is_bpf_text_address((unsigned long)ip))
673 		return -ENOTSUPP;
674 
675 	ret = gen_jump_or_nops(old_addr, ip, old_insns, is_call);
676 	if (ret)
677 		return ret;
678 
679 	if (memcmp(ip, old_insns, RV_FENTRY_NINSNS * 4))
680 		return -EFAULT;
681 
682 	ret = gen_jump_or_nops(new_addr, ip, new_insns, is_call);
683 	if (ret)
684 		return ret;
685 
686 	cpus_read_lock();
687 	mutex_lock(&text_mutex);
688 	if (memcmp(ip, new_insns, RV_FENTRY_NINSNS * 4))
689 		ret = patch_text(ip, new_insns, RV_FENTRY_NINSNS);
690 	mutex_unlock(&text_mutex);
691 	cpus_read_unlock();
692 
693 	return ret;
694 }
695 
696 static void store_args(int nr_arg_slots, int args_off, struct rv_jit_context *ctx)
697 {
698 	int i;
699 
700 	for (i = 0; i < nr_arg_slots; i++) {
701 		if (i < RV_MAX_REG_ARGS) {
702 			emit_sd(RV_REG_FP, -args_off, RV_REG_A0 + i, ctx);
703 		} else {
704 			/* skip slots for T0 and FP of traced function */
705 			emit_ld(RV_REG_T1, 16 + (i - RV_MAX_REG_ARGS) * 8, RV_REG_FP, ctx);
706 			emit_sd(RV_REG_FP, -args_off, RV_REG_T1, ctx);
707 		}
708 		args_off -= 8;
709 	}
710 }
711 
712 static void restore_args(int nr_reg_args, int args_off, struct rv_jit_context *ctx)
713 {
714 	int i;
715 
716 	for (i = 0; i < nr_reg_args; i++) {
717 		emit_ld(RV_REG_A0 + i, -args_off, RV_REG_FP, ctx);
718 		args_off -= 8;
719 	}
720 }
721 
722 static void restore_stack_args(int nr_stack_args, int args_off, int stk_arg_off,
723 			       struct rv_jit_context *ctx)
724 {
725 	int i;
726 
727 	for (i = 0; i < nr_stack_args; i++) {
728 		emit_ld(RV_REG_T1, -(args_off - RV_MAX_REG_ARGS * 8), RV_REG_FP, ctx);
729 		emit_sd(RV_REG_FP, -stk_arg_off, RV_REG_T1, ctx);
730 		args_off -= 8;
731 		stk_arg_off -= 8;
732 	}
733 }
734 
735 static int invoke_bpf_prog(struct bpf_tramp_link *l, int args_off, int retval_off,
736 			   int run_ctx_off, bool save_ret, struct rv_jit_context *ctx)
737 {
738 	int ret, branch_off;
739 	struct bpf_prog *p = l->link.prog;
740 	int cookie_off = offsetof(struct bpf_tramp_run_ctx, bpf_cookie);
741 
742 	if (l->cookie) {
743 		emit_imm(RV_REG_T1, l->cookie, ctx);
744 		emit_sd(RV_REG_FP, -run_ctx_off + cookie_off, RV_REG_T1, ctx);
745 	} else {
746 		emit_sd(RV_REG_FP, -run_ctx_off + cookie_off, RV_REG_ZERO, ctx);
747 	}
748 
749 	/* arg1: prog */
750 	emit_imm(RV_REG_A0, (const s64)p, ctx);
751 	/* arg2: &run_ctx */
752 	emit_addi(RV_REG_A1, RV_REG_FP, -run_ctx_off, ctx);
753 	ret = emit_call((const u64)bpf_trampoline_enter(p), true, ctx);
754 	if (ret)
755 		return ret;
756 
757 	/* store prog start time */
758 	emit_mv(RV_REG_S1, RV_REG_A0, ctx);
759 
760 	/* if (__bpf_prog_enter(prog) == 0)
761 	 *	goto skip_exec_of_prog;
762 	 */
763 	branch_off = ctx->ninsns;
764 	/* nop reserved for conditional jump */
765 	emit(rv_nop(), ctx);
766 
767 	/* arg1: &args_off */
768 	emit_addi(RV_REG_A0, RV_REG_FP, -args_off, ctx);
769 	if (!p->jited)
770 		/* arg2: progs[i]->insnsi for interpreter */
771 		emit_imm(RV_REG_A1, (const s64)p->insnsi, ctx);
772 	ret = emit_call((const u64)p->bpf_func, true, ctx);
773 	if (ret)
774 		return ret;
775 
776 	if (save_ret) {
777 		emit_sd(RV_REG_FP, -retval_off, RV_REG_A0, ctx);
778 		emit_sd(RV_REG_FP, -(retval_off - 8), regmap[BPF_REG_0], ctx);
779 	}
780 
781 	/* update branch with beqz */
782 	if (ctx->insns) {
783 		int offset = ninsns_rvoff(ctx->ninsns - branch_off);
784 		u32 insn = rv_beq(RV_REG_A0, RV_REG_ZERO, offset >> 1);
785 		*(u32 *)(ctx->insns + branch_off) = insn;
786 	}
787 
788 	/* arg1: prog */
789 	emit_imm(RV_REG_A0, (const s64)p, ctx);
790 	/* arg2: prog start time */
791 	emit_mv(RV_REG_A1, RV_REG_S1, ctx);
792 	/* arg3: &run_ctx */
793 	emit_addi(RV_REG_A2, RV_REG_FP, -run_ctx_off, ctx);
794 	ret = emit_call((const u64)bpf_trampoline_exit(p), true, ctx);
795 
796 	return ret;
797 }
798 
799 static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im,
800 					 const struct btf_func_model *m,
801 					 struct bpf_tramp_links *tlinks,
802 					 void *func_addr, u32 flags,
803 					 struct rv_jit_context *ctx)
804 {
805 	int i, ret, offset;
806 	int *branches_off = NULL;
807 	int stack_size = 0, nr_arg_slots = 0;
808 	int retval_off, args_off, nregs_off, ip_off, run_ctx_off, sreg_off, stk_arg_off;
809 	struct bpf_tramp_links *fentry = &tlinks[BPF_TRAMP_FENTRY];
810 	struct bpf_tramp_links *fexit = &tlinks[BPF_TRAMP_FEXIT];
811 	struct bpf_tramp_links *fmod_ret = &tlinks[BPF_TRAMP_MODIFY_RETURN];
812 	bool is_struct_ops = flags & BPF_TRAMP_F_INDIRECT;
813 	void *orig_call = func_addr;
814 	bool save_ret;
815 	u32 insn;
816 
817 	/* Two types of generated trampoline stack layout:
818 	 *
819 	 * 1. trampoline called from function entry
820 	 * --------------------------------------
821 	 * FP + 8	    [ RA to parent func	] return address to parent
822 	 *					  function
823 	 * FP + 0	    [ FP of parent func ] frame pointer of parent
824 	 *					  function
825 	 * FP - 8           [ T0 to traced func ] return address of traced
826 	 *					  function
827 	 * FP - 16	    [ FP of traced func ] frame pointer of traced
828 	 *					  function
829 	 * --------------------------------------
830 	 *
831 	 * 2. trampoline called directly
832 	 * --------------------------------------
833 	 * FP - 8	    [ RA to caller func ] return address to caller
834 	 *					  function
835 	 * FP - 16	    [ FP of caller func	] frame pointer of caller
836 	 *					  function
837 	 * --------------------------------------
838 	 *
839 	 * FP - retval_off  [ return value      ] BPF_TRAMP_F_CALL_ORIG or
840 	 *					  BPF_TRAMP_F_RET_FENTRY_RET
841 	 *                  [ argN              ]
842 	 *                  [ ...               ]
843 	 * FP - args_off    [ arg1              ]
844 	 *
845 	 * FP - nregs_off   [ regs count        ]
846 	 *
847 	 * FP - ip_off      [ traced func	] BPF_TRAMP_F_IP_ARG
848 	 *
849 	 * FP - run_ctx_off [ bpf_tramp_run_ctx ]
850 	 *
851 	 * FP - sreg_off    [ callee saved reg	]
852 	 *
853 	 *		    [ pads              ] pads for 16 bytes alignment
854 	 *
855 	 *		    [ stack_argN        ]
856 	 *		    [ ...               ]
857 	 * FP - stk_arg_off [ stack_arg1        ] BPF_TRAMP_F_CALL_ORIG
858 	 */
859 
860 	if (flags & (BPF_TRAMP_F_ORIG_STACK | BPF_TRAMP_F_SHARE_IPMODIFY))
861 		return -ENOTSUPP;
862 
863 	if (m->nr_args > MAX_BPF_FUNC_ARGS)
864 		return -ENOTSUPP;
865 
866 	for (i = 0; i < m->nr_args; i++)
867 		nr_arg_slots += round_up(m->arg_size[i], 8) / 8;
868 
869 	/* room of trampoline frame to store return address and frame pointer */
870 	stack_size += 16;
871 
872 	save_ret = flags & (BPF_TRAMP_F_CALL_ORIG | BPF_TRAMP_F_RET_FENTRY_RET);
873 	if (save_ret) {
874 		stack_size += 16; /* Save both A5 (BPF R0) and A0 */
875 		retval_off = stack_size;
876 	}
877 
878 	stack_size += nr_arg_slots * 8;
879 	args_off = stack_size;
880 
881 	stack_size += 8;
882 	nregs_off = stack_size;
883 
884 	if (flags & BPF_TRAMP_F_IP_ARG) {
885 		stack_size += 8;
886 		ip_off = stack_size;
887 	}
888 
889 	stack_size += round_up(sizeof(struct bpf_tramp_run_ctx), 8);
890 	run_ctx_off = stack_size;
891 
892 	stack_size += 8;
893 	sreg_off = stack_size;
894 
895 	if (nr_arg_slots - RV_MAX_REG_ARGS > 0)
896 		stack_size += (nr_arg_slots - RV_MAX_REG_ARGS) * 8;
897 
898 	stack_size = round_up(stack_size, STACK_ALIGN);
899 
900 	/* room for args on stack must be at the top of stack */
901 	stk_arg_off = stack_size;
902 
903 	if (!is_struct_ops) {
904 		/* For the trampoline called from function entry,
905 		 * the frame of traced function and the frame of
906 		 * trampoline need to be considered.
907 		 */
908 		emit_addi(RV_REG_SP, RV_REG_SP, -16, ctx);
909 		emit_sd(RV_REG_SP, 8, RV_REG_RA, ctx);
910 		emit_sd(RV_REG_SP, 0, RV_REG_FP, ctx);
911 		emit_addi(RV_REG_FP, RV_REG_SP, 16, ctx);
912 
913 		emit_addi(RV_REG_SP, RV_REG_SP, -stack_size, ctx);
914 		emit_sd(RV_REG_SP, stack_size - 8, RV_REG_T0, ctx);
915 		emit_sd(RV_REG_SP, stack_size - 16, RV_REG_FP, ctx);
916 		emit_addi(RV_REG_FP, RV_REG_SP, stack_size, ctx);
917 	} else {
918 		/* emit kcfi hash */
919 		emit_kcfi(cfi_get_func_hash(func_addr), ctx);
920 		/* For the trampoline called directly, just handle
921 		 * the frame of trampoline.
922 		 */
923 		emit_addi(RV_REG_SP, RV_REG_SP, -stack_size, ctx);
924 		emit_sd(RV_REG_SP, stack_size - 8, RV_REG_RA, ctx);
925 		emit_sd(RV_REG_SP, stack_size - 16, RV_REG_FP, ctx);
926 		emit_addi(RV_REG_FP, RV_REG_SP, stack_size, ctx);
927 	}
928 
929 	/* callee saved register S1 to pass start time */
930 	emit_sd(RV_REG_FP, -sreg_off, RV_REG_S1, ctx);
931 
932 	/* store ip address of the traced function */
933 	if (flags & BPF_TRAMP_F_IP_ARG) {
934 		emit_imm(RV_REG_T1, (const s64)func_addr, ctx);
935 		emit_sd(RV_REG_FP, -ip_off, RV_REG_T1, ctx);
936 	}
937 
938 	emit_li(RV_REG_T1, nr_arg_slots, ctx);
939 	emit_sd(RV_REG_FP, -nregs_off, RV_REG_T1, ctx);
940 
941 	store_args(nr_arg_slots, args_off, ctx);
942 
943 	/* skip to actual body of traced function */
944 	if (flags & BPF_TRAMP_F_SKIP_FRAME)
945 		orig_call += RV_FENTRY_NINSNS * 4;
946 
947 	if (flags & BPF_TRAMP_F_CALL_ORIG) {
948 		emit_imm(RV_REG_A0, ctx->insns ? (const s64)im : RV_MAX_COUNT_IMM, ctx);
949 		ret = emit_call((const u64)__bpf_tramp_enter, true, ctx);
950 		if (ret)
951 			return ret;
952 	}
953 
954 	for (i = 0; i < fentry->nr_links; i++) {
955 		ret = invoke_bpf_prog(fentry->links[i], args_off, retval_off, run_ctx_off,
956 				      flags & BPF_TRAMP_F_RET_FENTRY_RET, ctx);
957 		if (ret)
958 			return ret;
959 	}
960 
961 	if (fmod_ret->nr_links) {
962 		branches_off = kcalloc(fmod_ret->nr_links, sizeof(int), GFP_KERNEL);
963 		if (!branches_off)
964 			return -ENOMEM;
965 
966 		/* cleanup to avoid garbage return value confusion */
967 		emit_sd(RV_REG_FP, -retval_off, RV_REG_ZERO, ctx);
968 		for (i = 0; i < fmod_ret->nr_links; i++) {
969 			ret = invoke_bpf_prog(fmod_ret->links[i], args_off, retval_off,
970 					      run_ctx_off, true, ctx);
971 			if (ret)
972 				goto out;
973 			emit_ld(RV_REG_T1, -retval_off, RV_REG_FP, ctx);
974 			branches_off[i] = ctx->ninsns;
975 			/* nop reserved for conditional jump */
976 			emit(rv_nop(), ctx);
977 		}
978 	}
979 
980 	if (flags & BPF_TRAMP_F_CALL_ORIG) {
981 		restore_args(min_t(int, nr_arg_slots, RV_MAX_REG_ARGS), args_off, ctx);
982 		restore_stack_args(nr_arg_slots - RV_MAX_REG_ARGS, args_off, stk_arg_off, ctx);
983 		ret = emit_call((const u64)orig_call, true, ctx);
984 		if (ret)
985 			goto out;
986 		emit_sd(RV_REG_FP, -retval_off, RV_REG_A0, ctx);
987 		emit_sd(RV_REG_FP, -(retval_off - 8), regmap[BPF_REG_0], ctx);
988 		im->ip_after_call = ctx->ro_insns + ctx->ninsns;
989 		/* 2 nops reserved for auipc+jalr pair */
990 		emit(rv_nop(), ctx);
991 		emit(rv_nop(), ctx);
992 	}
993 
994 	/* update branches saved in invoke_bpf_mod_ret with bnez */
995 	for (i = 0; ctx->insns && i < fmod_ret->nr_links; i++) {
996 		offset = ninsns_rvoff(ctx->ninsns - branches_off[i]);
997 		insn = rv_bne(RV_REG_T1, RV_REG_ZERO, offset >> 1);
998 		*(u32 *)(ctx->insns + branches_off[i]) = insn;
999 	}
1000 
1001 	for (i = 0; i < fexit->nr_links; i++) {
1002 		ret = invoke_bpf_prog(fexit->links[i], args_off, retval_off,
1003 				      run_ctx_off, false, ctx);
1004 		if (ret)
1005 			goto out;
1006 	}
1007 
1008 	if (flags & BPF_TRAMP_F_CALL_ORIG) {
1009 		im->ip_epilogue = ctx->ro_insns + ctx->ninsns;
1010 		emit_imm(RV_REG_A0, ctx->insns ? (const s64)im : RV_MAX_COUNT_IMM, ctx);
1011 		ret = emit_call((const u64)__bpf_tramp_exit, true, ctx);
1012 		if (ret)
1013 			goto out;
1014 	}
1015 
1016 	if (flags & BPF_TRAMP_F_RESTORE_REGS)
1017 		restore_args(min_t(int, nr_arg_slots, RV_MAX_REG_ARGS), args_off, ctx);
1018 
1019 	if (save_ret) {
1020 		emit_ld(RV_REG_A0, -retval_off, RV_REG_FP, ctx);
1021 		emit_ld(regmap[BPF_REG_0], -(retval_off - 8), RV_REG_FP, ctx);
1022 	}
1023 
1024 	emit_ld(RV_REG_S1, -sreg_off, RV_REG_FP, ctx);
1025 
1026 	if (!is_struct_ops) {
1027 		/* trampoline called from function entry */
1028 		emit_ld(RV_REG_T0, stack_size - 8, RV_REG_SP, ctx);
1029 		emit_ld(RV_REG_FP, stack_size - 16, RV_REG_SP, ctx);
1030 		emit_addi(RV_REG_SP, RV_REG_SP, stack_size, ctx);
1031 
1032 		emit_ld(RV_REG_RA, 8, RV_REG_SP, ctx);
1033 		emit_ld(RV_REG_FP, 0, RV_REG_SP, ctx);
1034 		emit_addi(RV_REG_SP, RV_REG_SP, 16, ctx);
1035 
1036 		if (flags & BPF_TRAMP_F_SKIP_FRAME)
1037 			/* return to parent function */
1038 			emit_jalr(RV_REG_ZERO, RV_REG_RA, 0, ctx);
1039 		else
1040 			/* return to traced function */
1041 			emit_jalr(RV_REG_ZERO, RV_REG_T0, 0, ctx);
1042 	} else {
1043 		/* trampoline called directly */
1044 		emit_ld(RV_REG_RA, stack_size - 8, RV_REG_SP, ctx);
1045 		emit_ld(RV_REG_FP, stack_size - 16, RV_REG_SP, ctx);
1046 		emit_addi(RV_REG_SP, RV_REG_SP, stack_size, ctx);
1047 
1048 		emit_jalr(RV_REG_ZERO, RV_REG_RA, 0, ctx);
1049 	}
1050 
1051 	ret = ctx->ninsns;
1052 out:
1053 	kfree(branches_off);
1054 	return ret;
1055 }
1056 
1057 int arch_bpf_trampoline_size(const struct btf_func_model *m, u32 flags,
1058 			     struct bpf_tramp_links *tlinks, void *func_addr)
1059 {
1060 	struct bpf_tramp_image im;
1061 	struct rv_jit_context ctx;
1062 	int ret;
1063 
1064 	ctx.ninsns = 0;
1065 	ctx.insns = NULL;
1066 	ctx.ro_insns = NULL;
1067 	ret = __arch_prepare_bpf_trampoline(&im, m, tlinks, func_addr, flags, &ctx);
1068 
1069 	return ret < 0 ? ret : ninsns_rvoff(ctx.ninsns);
1070 }
1071 
1072 void *arch_alloc_bpf_trampoline(unsigned int size)
1073 {
1074 	return bpf_prog_pack_alloc(size, bpf_fill_ill_insns);
1075 }
1076 
1077 void arch_free_bpf_trampoline(void *image, unsigned int size)
1078 {
1079 	bpf_prog_pack_free(image, size);
1080 }
1081 
1082 int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *ro_image,
1083 				void *ro_image_end, const struct btf_func_model *m,
1084 				u32 flags, struct bpf_tramp_links *tlinks,
1085 				void *func_addr)
1086 {
1087 	int ret;
1088 	void *image, *res;
1089 	struct rv_jit_context ctx;
1090 	u32 size = ro_image_end - ro_image;
1091 
1092 	image = kvmalloc(size, GFP_KERNEL);
1093 	if (!image)
1094 		return -ENOMEM;
1095 
1096 	ctx.ninsns = 0;
1097 	ctx.insns = image;
1098 	ctx.ro_insns = ro_image;
1099 	ret = __arch_prepare_bpf_trampoline(im, m, tlinks, func_addr, flags, &ctx);
1100 	if (ret < 0)
1101 		goto out;
1102 
1103 	if (WARN_ON(size < ninsns_rvoff(ctx.ninsns))) {
1104 		ret = -E2BIG;
1105 		goto out;
1106 	}
1107 
1108 	res = bpf_arch_text_copy(ro_image, image, size);
1109 	if (IS_ERR(res)) {
1110 		ret = PTR_ERR(res);
1111 		goto out;
1112 	}
1113 
1114 	bpf_flush_icache(ro_image, ro_image_end);
1115 out:
1116 	kvfree(image);
1117 	return ret < 0 ? ret : size;
1118 }
1119 
1120 int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx,
1121 		      bool extra_pass)
1122 {
1123 	bool is64 = BPF_CLASS(insn->code) == BPF_ALU64 ||
1124 		    BPF_CLASS(insn->code) == BPF_JMP;
1125 	int s, e, rvoff, ret, i = insn - ctx->prog->insnsi;
1126 	struct bpf_prog_aux *aux = ctx->prog->aux;
1127 	u8 rd = -1, rs = -1, code = insn->code;
1128 	s16 off = insn->off;
1129 	s32 imm = insn->imm;
1130 
1131 	init_regs(&rd, &rs, insn, ctx);
1132 
1133 	switch (code) {
1134 	/* dst = src */
1135 	case BPF_ALU | BPF_MOV | BPF_X:
1136 	case BPF_ALU64 | BPF_MOV | BPF_X:
1137 		if (insn_is_cast_user(insn)) {
1138 			emit_mv(RV_REG_T1, rs, ctx);
1139 			emit_zextw(RV_REG_T1, RV_REG_T1, ctx);
1140 			emit_imm(rd, (ctx->user_vm_start >> 32) << 32, ctx);
1141 			emit(rv_beq(RV_REG_T1, RV_REG_ZERO, 4), ctx);
1142 			emit_or(RV_REG_T1, rd, RV_REG_T1, ctx);
1143 			emit_mv(rd, RV_REG_T1, ctx);
1144 			break;
1145 		} else if (insn_is_mov_percpu_addr(insn)) {
1146 			if (rd != rs)
1147 				emit_mv(rd, rs, ctx);
1148 #ifdef CONFIG_SMP
1149 			/* Load current CPU number in T1 */
1150 			emit_ld(RV_REG_T1, offsetof(struct thread_info, cpu),
1151 				RV_REG_TP, ctx);
1152 			/* Load address of __per_cpu_offset array in T2 */
1153 			emit_addr(RV_REG_T2, (u64)&__per_cpu_offset, extra_pass, ctx);
1154 			/* Get address of __per_cpu_offset[cpu] in T1 */
1155 			emit_sh3add(RV_REG_T1, RV_REG_T1, RV_REG_T2, ctx);
1156 			/* Load __per_cpu_offset[cpu] in T1 */
1157 			emit_ld(RV_REG_T1, 0, RV_REG_T1, ctx);
1158 			/* Add the offset to Rd */
1159 			emit_add(rd, rd, RV_REG_T1, ctx);
1160 #endif
1161 		}
1162 		if (imm == 1) {
1163 			/* Special mov32 for zext */
1164 			emit_zextw(rd, rd, ctx);
1165 			break;
1166 		}
1167 		switch (insn->off) {
1168 		case 0:
1169 			emit_mv(rd, rs, ctx);
1170 			break;
1171 		case 8:
1172 			emit_sextb(rd, rs, ctx);
1173 			break;
1174 		case 16:
1175 			emit_sexth(rd, rs, ctx);
1176 			break;
1177 		case 32:
1178 			emit_sextw(rd, rs, ctx);
1179 			break;
1180 		}
1181 		if (!is64 && !aux->verifier_zext)
1182 			emit_zextw(rd, rd, ctx);
1183 		break;
1184 
1185 	/* dst = dst OP src */
1186 	case BPF_ALU | BPF_ADD | BPF_X:
1187 	case BPF_ALU64 | BPF_ADD | BPF_X:
1188 		emit_add(rd, rd, rs, ctx);
1189 		if (!is64 && !aux->verifier_zext)
1190 			emit_zextw(rd, rd, ctx);
1191 		break;
1192 	case BPF_ALU | BPF_SUB | BPF_X:
1193 	case BPF_ALU64 | BPF_SUB | BPF_X:
1194 		if (is64)
1195 			emit_sub(rd, rd, rs, ctx);
1196 		else
1197 			emit_subw(rd, rd, rs, ctx);
1198 
1199 		if (!is64 && !aux->verifier_zext)
1200 			emit_zextw(rd, rd, ctx);
1201 		break;
1202 	case BPF_ALU | BPF_AND | BPF_X:
1203 	case BPF_ALU64 | BPF_AND | BPF_X:
1204 		emit_and(rd, rd, rs, ctx);
1205 		if (!is64 && !aux->verifier_zext)
1206 			emit_zextw(rd, rd, ctx);
1207 		break;
1208 	case BPF_ALU | BPF_OR | BPF_X:
1209 	case BPF_ALU64 | BPF_OR | BPF_X:
1210 		emit_or(rd, rd, rs, ctx);
1211 		if (!is64 && !aux->verifier_zext)
1212 			emit_zextw(rd, rd, ctx);
1213 		break;
1214 	case BPF_ALU | BPF_XOR | BPF_X:
1215 	case BPF_ALU64 | BPF_XOR | BPF_X:
1216 		emit_xor(rd, rd, rs, ctx);
1217 		if (!is64 && !aux->verifier_zext)
1218 			emit_zextw(rd, rd, ctx);
1219 		break;
1220 	case BPF_ALU | BPF_MUL | BPF_X:
1221 	case BPF_ALU64 | BPF_MUL | BPF_X:
1222 		emit(is64 ? rv_mul(rd, rd, rs) : rv_mulw(rd, rd, rs), ctx);
1223 		if (!is64 && !aux->verifier_zext)
1224 			emit_zextw(rd, rd, ctx);
1225 		break;
1226 	case BPF_ALU | BPF_DIV | BPF_X:
1227 	case BPF_ALU64 | BPF_DIV | BPF_X:
1228 		if (off)
1229 			emit(is64 ? rv_div(rd, rd, rs) : rv_divw(rd, rd, rs), ctx);
1230 		else
1231 			emit(is64 ? rv_divu(rd, rd, rs) : rv_divuw(rd, rd, rs), ctx);
1232 		if (!is64 && !aux->verifier_zext)
1233 			emit_zextw(rd, rd, ctx);
1234 		break;
1235 	case BPF_ALU | BPF_MOD | BPF_X:
1236 	case BPF_ALU64 | BPF_MOD | BPF_X:
1237 		if (off)
1238 			emit(is64 ? rv_rem(rd, rd, rs) : rv_remw(rd, rd, rs), ctx);
1239 		else
1240 			emit(is64 ? rv_remu(rd, rd, rs) : rv_remuw(rd, rd, rs), ctx);
1241 		if (!is64 && !aux->verifier_zext)
1242 			emit_zextw(rd, rd, ctx);
1243 		break;
1244 	case BPF_ALU | BPF_LSH | BPF_X:
1245 	case BPF_ALU64 | BPF_LSH | BPF_X:
1246 		emit(is64 ? rv_sll(rd, rd, rs) : rv_sllw(rd, rd, rs), ctx);
1247 		if (!is64 && !aux->verifier_zext)
1248 			emit_zextw(rd, rd, ctx);
1249 		break;
1250 	case BPF_ALU | BPF_RSH | BPF_X:
1251 	case BPF_ALU64 | BPF_RSH | BPF_X:
1252 		emit(is64 ? rv_srl(rd, rd, rs) : rv_srlw(rd, rd, rs), ctx);
1253 		if (!is64 && !aux->verifier_zext)
1254 			emit_zextw(rd, rd, ctx);
1255 		break;
1256 	case BPF_ALU | BPF_ARSH | BPF_X:
1257 	case BPF_ALU64 | BPF_ARSH | BPF_X:
1258 		emit(is64 ? rv_sra(rd, rd, rs) : rv_sraw(rd, rd, rs), ctx);
1259 		if (!is64 && !aux->verifier_zext)
1260 			emit_zextw(rd, rd, ctx);
1261 		break;
1262 
1263 	/* dst = -dst */
1264 	case BPF_ALU | BPF_NEG:
1265 	case BPF_ALU64 | BPF_NEG:
1266 		emit_sub(rd, RV_REG_ZERO, rd, ctx);
1267 		if (!is64 && !aux->verifier_zext)
1268 			emit_zextw(rd, rd, ctx);
1269 		break;
1270 
1271 	/* dst = BSWAP##imm(dst) */
1272 	case BPF_ALU | BPF_END | BPF_FROM_LE:
1273 		switch (imm) {
1274 		case 16:
1275 			emit_zexth(rd, rd, ctx);
1276 			break;
1277 		case 32:
1278 			if (!aux->verifier_zext)
1279 				emit_zextw(rd, rd, ctx);
1280 			break;
1281 		case 64:
1282 			/* Do nothing */
1283 			break;
1284 		}
1285 		break;
1286 	case BPF_ALU | BPF_END | BPF_FROM_BE:
1287 	case BPF_ALU64 | BPF_END | BPF_FROM_LE:
1288 		emit_bswap(rd, imm, ctx);
1289 		break;
1290 
1291 	/* dst = imm */
1292 	case BPF_ALU | BPF_MOV | BPF_K:
1293 	case BPF_ALU64 | BPF_MOV | BPF_K:
1294 		emit_imm(rd, imm, ctx);
1295 		if (!is64 && !aux->verifier_zext)
1296 			emit_zextw(rd, rd, ctx);
1297 		break;
1298 
1299 	/* dst = dst OP imm */
1300 	case BPF_ALU | BPF_ADD | BPF_K:
1301 	case BPF_ALU64 | BPF_ADD | BPF_K:
1302 		if (is_12b_int(imm)) {
1303 			emit_addi(rd, rd, imm, ctx);
1304 		} else {
1305 			emit_imm(RV_REG_T1, imm, ctx);
1306 			emit_add(rd, rd, RV_REG_T1, ctx);
1307 		}
1308 		if (!is64 && !aux->verifier_zext)
1309 			emit_zextw(rd, rd, ctx);
1310 		break;
1311 	case BPF_ALU | BPF_SUB | BPF_K:
1312 	case BPF_ALU64 | BPF_SUB | BPF_K:
1313 		if (is_12b_int(-imm)) {
1314 			emit_addi(rd, rd, -imm, ctx);
1315 		} else {
1316 			emit_imm(RV_REG_T1, imm, ctx);
1317 			emit_sub(rd, rd, RV_REG_T1, ctx);
1318 		}
1319 		if (!is64 && !aux->verifier_zext)
1320 			emit_zextw(rd, rd, ctx);
1321 		break;
1322 	case BPF_ALU | BPF_AND | BPF_K:
1323 	case BPF_ALU64 | BPF_AND | BPF_K:
1324 		if (is_12b_int(imm)) {
1325 			emit_andi(rd, rd, imm, ctx);
1326 		} else {
1327 			emit_imm(RV_REG_T1, imm, ctx);
1328 			emit_and(rd, rd, RV_REG_T1, ctx);
1329 		}
1330 		if (!is64 && !aux->verifier_zext)
1331 			emit_zextw(rd, rd, ctx);
1332 		break;
1333 	case BPF_ALU | BPF_OR | BPF_K:
1334 	case BPF_ALU64 | BPF_OR | BPF_K:
1335 		if (is_12b_int(imm)) {
1336 			emit(rv_ori(rd, rd, imm), ctx);
1337 		} else {
1338 			emit_imm(RV_REG_T1, imm, ctx);
1339 			emit_or(rd, rd, RV_REG_T1, ctx);
1340 		}
1341 		if (!is64 && !aux->verifier_zext)
1342 			emit_zextw(rd, rd, ctx);
1343 		break;
1344 	case BPF_ALU | BPF_XOR | BPF_K:
1345 	case BPF_ALU64 | BPF_XOR | BPF_K:
1346 		if (is_12b_int(imm)) {
1347 			emit(rv_xori(rd, rd, imm), ctx);
1348 		} else {
1349 			emit_imm(RV_REG_T1, imm, ctx);
1350 			emit_xor(rd, rd, RV_REG_T1, ctx);
1351 		}
1352 		if (!is64 && !aux->verifier_zext)
1353 			emit_zextw(rd, rd, ctx);
1354 		break;
1355 	case BPF_ALU | BPF_MUL | BPF_K:
1356 	case BPF_ALU64 | BPF_MUL | BPF_K:
1357 		emit_imm(RV_REG_T1, imm, ctx);
1358 		emit(is64 ? rv_mul(rd, rd, RV_REG_T1) :
1359 		     rv_mulw(rd, rd, RV_REG_T1), ctx);
1360 		if (!is64 && !aux->verifier_zext)
1361 			emit_zextw(rd, rd, ctx);
1362 		break;
1363 	case BPF_ALU | BPF_DIV | BPF_K:
1364 	case BPF_ALU64 | BPF_DIV | BPF_K:
1365 		emit_imm(RV_REG_T1, imm, ctx);
1366 		if (off)
1367 			emit(is64 ? rv_div(rd, rd, RV_REG_T1) :
1368 			     rv_divw(rd, rd, RV_REG_T1), ctx);
1369 		else
1370 			emit(is64 ? rv_divu(rd, rd, RV_REG_T1) :
1371 			     rv_divuw(rd, rd, RV_REG_T1), ctx);
1372 		if (!is64 && !aux->verifier_zext)
1373 			emit_zextw(rd, rd, ctx);
1374 		break;
1375 	case BPF_ALU | BPF_MOD | BPF_K:
1376 	case BPF_ALU64 | BPF_MOD | BPF_K:
1377 		emit_imm(RV_REG_T1, imm, ctx);
1378 		if (off)
1379 			emit(is64 ? rv_rem(rd, rd, RV_REG_T1) :
1380 			     rv_remw(rd, rd, RV_REG_T1), ctx);
1381 		else
1382 			emit(is64 ? rv_remu(rd, rd, RV_REG_T1) :
1383 			     rv_remuw(rd, rd, RV_REG_T1), ctx);
1384 		if (!is64 && !aux->verifier_zext)
1385 			emit_zextw(rd, rd, ctx);
1386 		break;
1387 	case BPF_ALU | BPF_LSH | BPF_K:
1388 	case BPF_ALU64 | BPF_LSH | BPF_K:
1389 		emit_slli(rd, rd, imm, ctx);
1390 
1391 		if (!is64 && !aux->verifier_zext)
1392 			emit_zextw(rd, rd, ctx);
1393 		break;
1394 	case BPF_ALU | BPF_RSH | BPF_K:
1395 	case BPF_ALU64 | BPF_RSH | BPF_K:
1396 		if (is64)
1397 			emit_srli(rd, rd, imm, ctx);
1398 		else
1399 			emit(rv_srliw(rd, rd, imm), ctx);
1400 
1401 		if (!is64 && !aux->verifier_zext)
1402 			emit_zextw(rd, rd, ctx);
1403 		break;
1404 	case BPF_ALU | BPF_ARSH | BPF_K:
1405 	case BPF_ALU64 | BPF_ARSH | BPF_K:
1406 		if (is64)
1407 			emit_srai(rd, rd, imm, ctx);
1408 		else
1409 			emit(rv_sraiw(rd, rd, imm), ctx);
1410 
1411 		if (!is64 && !aux->verifier_zext)
1412 			emit_zextw(rd, rd, ctx);
1413 		break;
1414 
1415 	/* JUMP off */
1416 	case BPF_JMP | BPF_JA:
1417 	case BPF_JMP32 | BPF_JA:
1418 		if (BPF_CLASS(code) == BPF_JMP)
1419 			rvoff = rv_offset(i, off, ctx);
1420 		else
1421 			rvoff = rv_offset(i, imm, ctx);
1422 		ret = emit_jump_and_link(RV_REG_ZERO, rvoff, true, ctx);
1423 		if (ret)
1424 			return ret;
1425 		break;
1426 
1427 	/* IF (dst COND src) JUMP off */
1428 	case BPF_JMP | BPF_JEQ | BPF_X:
1429 	case BPF_JMP32 | BPF_JEQ | BPF_X:
1430 	case BPF_JMP | BPF_JGT | BPF_X:
1431 	case BPF_JMP32 | BPF_JGT | BPF_X:
1432 	case BPF_JMP | BPF_JLT | BPF_X:
1433 	case BPF_JMP32 | BPF_JLT | BPF_X:
1434 	case BPF_JMP | BPF_JGE | BPF_X:
1435 	case BPF_JMP32 | BPF_JGE | BPF_X:
1436 	case BPF_JMP | BPF_JLE | BPF_X:
1437 	case BPF_JMP32 | BPF_JLE | BPF_X:
1438 	case BPF_JMP | BPF_JNE | BPF_X:
1439 	case BPF_JMP32 | BPF_JNE | BPF_X:
1440 	case BPF_JMP | BPF_JSGT | BPF_X:
1441 	case BPF_JMP32 | BPF_JSGT | BPF_X:
1442 	case BPF_JMP | BPF_JSLT | BPF_X:
1443 	case BPF_JMP32 | BPF_JSLT | BPF_X:
1444 	case BPF_JMP | BPF_JSGE | BPF_X:
1445 	case BPF_JMP32 | BPF_JSGE | BPF_X:
1446 	case BPF_JMP | BPF_JSLE | BPF_X:
1447 	case BPF_JMP32 | BPF_JSLE | BPF_X:
1448 	case BPF_JMP | BPF_JSET | BPF_X:
1449 	case BPF_JMP32 | BPF_JSET | BPF_X:
1450 		rvoff = rv_offset(i, off, ctx);
1451 		if (!is64) {
1452 			s = ctx->ninsns;
1453 			if (is_signed_bpf_cond(BPF_OP(code))) {
1454 				emit_sextw_alt(&rs, RV_REG_T1, ctx);
1455 				emit_sextw_alt(&rd, RV_REG_T2, ctx);
1456 			} else {
1457 				emit_zextw_alt(&rs, RV_REG_T1, ctx);
1458 				emit_zextw_alt(&rd, RV_REG_T2, ctx);
1459 			}
1460 			e = ctx->ninsns;
1461 
1462 			/* Adjust for extra insns */
1463 			rvoff -= ninsns_rvoff(e - s);
1464 		}
1465 
1466 		if (BPF_OP(code) == BPF_JSET) {
1467 			/* Adjust for and */
1468 			rvoff -= 4;
1469 			emit_and(RV_REG_T1, rd, rs, ctx);
1470 			emit_branch(BPF_JNE, RV_REG_T1, RV_REG_ZERO, rvoff, ctx);
1471 		} else {
1472 			emit_branch(BPF_OP(code), rd, rs, rvoff, ctx);
1473 		}
1474 		break;
1475 
1476 	/* IF (dst COND imm) JUMP off */
1477 	case BPF_JMP | BPF_JEQ | BPF_K:
1478 	case BPF_JMP32 | BPF_JEQ | BPF_K:
1479 	case BPF_JMP | BPF_JGT | BPF_K:
1480 	case BPF_JMP32 | BPF_JGT | BPF_K:
1481 	case BPF_JMP | BPF_JLT | BPF_K:
1482 	case BPF_JMP32 | BPF_JLT | BPF_K:
1483 	case BPF_JMP | BPF_JGE | BPF_K:
1484 	case BPF_JMP32 | BPF_JGE | BPF_K:
1485 	case BPF_JMP | BPF_JLE | BPF_K:
1486 	case BPF_JMP32 | BPF_JLE | BPF_K:
1487 	case BPF_JMP | BPF_JNE | BPF_K:
1488 	case BPF_JMP32 | BPF_JNE | BPF_K:
1489 	case BPF_JMP | BPF_JSGT | BPF_K:
1490 	case BPF_JMP32 | BPF_JSGT | BPF_K:
1491 	case BPF_JMP | BPF_JSLT | BPF_K:
1492 	case BPF_JMP32 | BPF_JSLT | BPF_K:
1493 	case BPF_JMP | BPF_JSGE | BPF_K:
1494 	case BPF_JMP32 | BPF_JSGE | BPF_K:
1495 	case BPF_JMP | BPF_JSLE | BPF_K:
1496 	case BPF_JMP32 | BPF_JSLE | BPF_K:
1497 		rvoff = rv_offset(i, off, ctx);
1498 		s = ctx->ninsns;
1499 		if (imm)
1500 			emit_imm(RV_REG_T1, imm, ctx);
1501 		rs = imm ? RV_REG_T1 : RV_REG_ZERO;
1502 		if (!is64) {
1503 			if (is_signed_bpf_cond(BPF_OP(code))) {
1504 				emit_sextw_alt(&rd, RV_REG_T2, ctx);
1505 				/* rs has been sign extended */
1506 			} else {
1507 				emit_zextw_alt(&rd, RV_REG_T2, ctx);
1508 				if (imm)
1509 					emit_zextw(rs, rs, ctx);
1510 			}
1511 		}
1512 		e = ctx->ninsns;
1513 
1514 		/* Adjust for extra insns */
1515 		rvoff -= ninsns_rvoff(e - s);
1516 		emit_branch(BPF_OP(code), rd, rs, rvoff, ctx);
1517 		break;
1518 
1519 	case BPF_JMP | BPF_JSET | BPF_K:
1520 	case BPF_JMP32 | BPF_JSET | BPF_K:
1521 		rvoff = rv_offset(i, off, ctx);
1522 		s = ctx->ninsns;
1523 		if (is_12b_int(imm)) {
1524 			emit_andi(RV_REG_T1, rd, imm, ctx);
1525 		} else {
1526 			emit_imm(RV_REG_T1, imm, ctx);
1527 			emit_and(RV_REG_T1, rd, RV_REG_T1, ctx);
1528 		}
1529 		/* For jset32, we should clear the upper 32 bits of t1, but
1530 		 * sign-extension is sufficient here and saves one instruction,
1531 		 * as t1 is used only in comparison against zero.
1532 		 */
1533 		if (!is64 && imm < 0)
1534 			emit_sextw(RV_REG_T1, RV_REG_T1, ctx);
1535 		e = ctx->ninsns;
1536 		rvoff -= ninsns_rvoff(e - s);
1537 		emit_branch(BPF_JNE, RV_REG_T1, RV_REG_ZERO, rvoff, ctx);
1538 		break;
1539 
1540 	/* function call */
1541 	case BPF_JMP | BPF_CALL:
1542 	{
1543 		bool fixed_addr;
1544 		u64 addr;
1545 
1546 		/* Inline calls to bpf_get_smp_processor_id()
1547 		 *
1548 		 * RV_REG_TP holds the address of the current CPU's task_struct and thread_info is
1549 		 * at offset 0 in task_struct.
1550 		 * Load cpu from thread_info:
1551 		 *     Set R0 to ((struct thread_info *)(RV_REG_TP))->cpu
1552 		 *
1553 		 * This replicates the implementation of raw_smp_processor_id() on RISCV
1554 		 */
1555 		if (insn->src_reg == 0 && insn->imm == BPF_FUNC_get_smp_processor_id) {
1556 			/* Load current CPU number in R0 */
1557 			emit_ld(bpf_to_rv_reg(BPF_REG_0, ctx), offsetof(struct thread_info, cpu),
1558 				RV_REG_TP, ctx);
1559 			break;
1560 		}
1561 
1562 		mark_call(ctx);
1563 		ret = bpf_jit_get_func_addr(ctx->prog, insn, extra_pass,
1564 					    &addr, &fixed_addr);
1565 		if (ret < 0)
1566 			return ret;
1567 
1568 		if (insn->src_reg == BPF_PSEUDO_KFUNC_CALL) {
1569 			const struct btf_func_model *fm;
1570 			int idx;
1571 
1572 			fm = bpf_jit_find_kfunc_model(ctx->prog, insn);
1573 			if (!fm)
1574 				return -EINVAL;
1575 
1576 			for (idx = 0; idx < fm->nr_args; idx++) {
1577 				u8 reg = bpf_to_rv_reg(BPF_REG_1 + idx, ctx);
1578 
1579 				if (fm->arg_size[idx] == sizeof(int))
1580 					emit_sextw(reg, reg, ctx);
1581 			}
1582 		}
1583 
1584 		ret = emit_call(addr, fixed_addr, ctx);
1585 		if (ret)
1586 			return ret;
1587 
1588 		if (insn->src_reg != BPF_PSEUDO_CALL)
1589 			emit_mv(bpf_to_rv_reg(BPF_REG_0, ctx), RV_REG_A0, ctx);
1590 		break;
1591 	}
1592 	/* tail call */
1593 	case BPF_JMP | BPF_TAIL_CALL:
1594 		if (emit_bpf_tail_call(i, ctx))
1595 			return -1;
1596 		break;
1597 
1598 	/* function return */
1599 	case BPF_JMP | BPF_EXIT:
1600 		if (i == ctx->prog->len - 1)
1601 			break;
1602 
1603 		rvoff = epilogue_offset(ctx);
1604 		ret = emit_jump_and_link(RV_REG_ZERO, rvoff, true, ctx);
1605 		if (ret)
1606 			return ret;
1607 		break;
1608 
1609 	/* dst = imm64 */
1610 	case BPF_LD | BPF_IMM | BPF_DW:
1611 	{
1612 		struct bpf_insn insn1 = insn[1];
1613 		u64 imm64;
1614 
1615 		imm64 = (u64)insn1.imm << 32 | (u32)imm;
1616 		if (bpf_pseudo_func(insn)) {
1617 			/* fixed-length insns for extra jit pass */
1618 			ret = emit_addr(rd, imm64, extra_pass, ctx);
1619 			if (ret)
1620 				return ret;
1621 		} else {
1622 			emit_imm(rd, imm64, ctx);
1623 		}
1624 
1625 		return 1;
1626 	}
1627 
1628 	/* LDX: dst = *(unsigned size *)(src + off) */
1629 	case BPF_LDX | BPF_MEM | BPF_B:
1630 	case BPF_LDX | BPF_MEM | BPF_H:
1631 	case BPF_LDX | BPF_MEM | BPF_W:
1632 	case BPF_LDX | BPF_MEM | BPF_DW:
1633 	case BPF_LDX | BPF_PROBE_MEM | BPF_B:
1634 	case BPF_LDX | BPF_PROBE_MEM | BPF_H:
1635 	case BPF_LDX | BPF_PROBE_MEM | BPF_W:
1636 	case BPF_LDX | BPF_PROBE_MEM | BPF_DW:
1637 	/* LDSX: dst = *(signed size *)(src + off) */
1638 	case BPF_LDX | BPF_MEMSX | BPF_B:
1639 	case BPF_LDX | BPF_MEMSX | BPF_H:
1640 	case BPF_LDX | BPF_MEMSX | BPF_W:
1641 	case BPF_LDX | BPF_PROBE_MEMSX | BPF_B:
1642 	case BPF_LDX | BPF_PROBE_MEMSX | BPF_H:
1643 	case BPF_LDX | BPF_PROBE_MEMSX | BPF_W:
1644 	/* LDX | PROBE_MEM32: dst = *(unsigned size *)(src + RV_REG_ARENA + off) */
1645 	case BPF_LDX | BPF_PROBE_MEM32 | BPF_B:
1646 	case BPF_LDX | BPF_PROBE_MEM32 | BPF_H:
1647 	case BPF_LDX | BPF_PROBE_MEM32 | BPF_W:
1648 	case BPF_LDX | BPF_PROBE_MEM32 | BPF_DW:
1649 	{
1650 		int insn_len, insns_start;
1651 		bool sign_ext;
1652 
1653 		sign_ext = BPF_MODE(insn->code) == BPF_MEMSX ||
1654 			   BPF_MODE(insn->code) == BPF_PROBE_MEMSX;
1655 
1656 		if (BPF_MODE(insn->code) == BPF_PROBE_MEM32) {
1657 			emit_add(RV_REG_T2, rs, RV_REG_ARENA, ctx);
1658 			rs = RV_REG_T2;
1659 		}
1660 
1661 		switch (BPF_SIZE(code)) {
1662 		case BPF_B:
1663 			if (is_12b_int(off)) {
1664 				insns_start = ctx->ninsns;
1665 				if (sign_ext)
1666 					emit(rv_lb(rd, off, rs), ctx);
1667 				else
1668 					emit(rv_lbu(rd, off, rs), ctx);
1669 				insn_len = ctx->ninsns - insns_start;
1670 				break;
1671 			}
1672 
1673 			emit_imm(RV_REG_T1, off, ctx);
1674 			emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
1675 			insns_start = ctx->ninsns;
1676 			if (sign_ext)
1677 				emit(rv_lb(rd, 0, RV_REG_T1), ctx);
1678 			else
1679 				emit(rv_lbu(rd, 0, RV_REG_T1), ctx);
1680 			insn_len = ctx->ninsns - insns_start;
1681 			break;
1682 		case BPF_H:
1683 			if (is_12b_int(off)) {
1684 				insns_start = ctx->ninsns;
1685 				if (sign_ext)
1686 					emit(rv_lh(rd, off, rs), ctx);
1687 				else
1688 					emit(rv_lhu(rd, off, rs), ctx);
1689 				insn_len = ctx->ninsns - insns_start;
1690 				break;
1691 			}
1692 
1693 			emit_imm(RV_REG_T1, off, ctx);
1694 			emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
1695 			insns_start = ctx->ninsns;
1696 			if (sign_ext)
1697 				emit(rv_lh(rd, 0, RV_REG_T1), ctx);
1698 			else
1699 				emit(rv_lhu(rd, 0, RV_REG_T1), ctx);
1700 			insn_len = ctx->ninsns - insns_start;
1701 			break;
1702 		case BPF_W:
1703 			if (is_12b_int(off)) {
1704 				insns_start = ctx->ninsns;
1705 				if (sign_ext)
1706 					emit(rv_lw(rd, off, rs), ctx);
1707 				else
1708 					emit(rv_lwu(rd, off, rs), ctx);
1709 				insn_len = ctx->ninsns - insns_start;
1710 				break;
1711 			}
1712 
1713 			emit_imm(RV_REG_T1, off, ctx);
1714 			emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
1715 			insns_start = ctx->ninsns;
1716 			if (sign_ext)
1717 				emit(rv_lw(rd, 0, RV_REG_T1), ctx);
1718 			else
1719 				emit(rv_lwu(rd, 0, RV_REG_T1), ctx);
1720 			insn_len = ctx->ninsns - insns_start;
1721 			break;
1722 		case BPF_DW:
1723 			if (is_12b_int(off)) {
1724 				insns_start = ctx->ninsns;
1725 				emit_ld(rd, off, rs, ctx);
1726 				insn_len = ctx->ninsns - insns_start;
1727 				break;
1728 			}
1729 
1730 			emit_imm(RV_REG_T1, off, ctx);
1731 			emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
1732 			insns_start = ctx->ninsns;
1733 			emit_ld(rd, 0, RV_REG_T1, ctx);
1734 			insn_len = ctx->ninsns - insns_start;
1735 			break;
1736 		}
1737 
1738 		ret = add_exception_handler(insn, ctx, rd, insn_len);
1739 		if (ret)
1740 			return ret;
1741 
1742 		if (BPF_SIZE(code) != BPF_DW && insn_is_zext(&insn[1]))
1743 			return 1;
1744 		break;
1745 	}
1746 	/* speculation barrier */
1747 	case BPF_ST | BPF_NOSPEC:
1748 		break;
1749 
1750 	/* ST: *(size *)(dst + off) = imm */
1751 	case BPF_ST | BPF_MEM | BPF_B:
1752 		emit_imm(RV_REG_T1, imm, ctx);
1753 		if (is_12b_int(off)) {
1754 			emit(rv_sb(rd, off, RV_REG_T1), ctx);
1755 			break;
1756 		}
1757 
1758 		emit_imm(RV_REG_T2, off, ctx);
1759 		emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
1760 		emit(rv_sb(RV_REG_T2, 0, RV_REG_T1), ctx);
1761 		break;
1762 
1763 	case BPF_ST | BPF_MEM | BPF_H:
1764 		emit_imm(RV_REG_T1, imm, ctx);
1765 		if (is_12b_int(off)) {
1766 			emit(rv_sh(rd, off, RV_REG_T1), ctx);
1767 			break;
1768 		}
1769 
1770 		emit_imm(RV_REG_T2, off, ctx);
1771 		emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
1772 		emit(rv_sh(RV_REG_T2, 0, RV_REG_T1), ctx);
1773 		break;
1774 	case BPF_ST | BPF_MEM | BPF_W:
1775 		emit_imm(RV_REG_T1, imm, ctx);
1776 		if (is_12b_int(off)) {
1777 			emit_sw(rd, off, RV_REG_T1, ctx);
1778 			break;
1779 		}
1780 
1781 		emit_imm(RV_REG_T2, off, ctx);
1782 		emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
1783 		emit_sw(RV_REG_T2, 0, RV_REG_T1, ctx);
1784 		break;
1785 	case BPF_ST | BPF_MEM | BPF_DW:
1786 		emit_imm(RV_REG_T1, imm, ctx);
1787 		if (is_12b_int(off)) {
1788 			emit_sd(rd, off, RV_REG_T1, ctx);
1789 			break;
1790 		}
1791 
1792 		emit_imm(RV_REG_T2, off, ctx);
1793 		emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
1794 		emit_sd(RV_REG_T2, 0, RV_REG_T1, ctx);
1795 		break;
1796 
1797 	case BPF_ST | BPF_PROBE_MEM32 | BPF_B:
1798 	case BPF_ST | BPF_PROBE_MEM32 | BPF_H:
1799 	case BPF_ST | BPF_PROBE_MEM32 | BPF_W:
1800 	case BPF_ST | BPF_PROBE_MEM32 | BPF_DW:
1801 	{
1802 		int insn_len, insns_start;
1803 
1804 		emit_add(RV_REG_T3, rd, RV_REG_ARENA, ctx);
1805 		rd = RV_REG_T3;
1806 
1807 		/* Load imm to a register then store it */
1808 		emit_imm(RV_REG_T1, imm, ctx);
1809 
1810 		switch (BPF_SIZE(code)) {
1811 		case BPF_B:
1812 			if (is_12b_int(off)) {
1813 				insns_start = ctx->ninsns;
1814 				emit(rv_sb(rd, off, RV_REG_T1), ctx);
1815 				insn_len = ctx->ninsns - insns_start;
1816 				break;
1817 			}
1818 
1819 			emit_imm(RV_REG_T2, off, ctx);
1820 			emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
1821 			insns_start = ctx->ninsns;
1822 			emit(rv_sb(RV_REG_T2, 0, RV_REG_T1), ctx);
1823 			insn_len = ctx->ninsns - insns_start;
1824 			break;
1825 		case BPF_H:
1826 			if (is_12b_int(off)) {
1827 				insns_start = ctx->ninsns;
1828 				emit(rv_sh(rd, off, RV_REG_T1), ctx);
1829 				insn_len = ctx->ninsns - insns_start;
1830 				break;
1831 			}
1832 
1833 			emit_imm(RV_REG_T2, off, ctx);
1834 			emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
1835 			insns_start = ctx->ninsns;
1836 			emit(rv_sh(RV_REG_T2, 0, RV_REG_T1), ctx);
1837 			insn_len = ctx->ninsns - insns_start;
1838 			break;
1839 		case BPF_W:
1840 			if (is_12b_int(off)) {
1841 				insns_start = ctx->ninsns;
1842 				emit_sw(rd, off, RV_REG_T1, ctx);
1843 				insn_len = ctx->ninsns - insns_start;
1844 				break;
1845 			}
1846 
1847 			emit_imm(RV_REG_T2, off, ctx);
1848 			emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
1849 			insns_start = ctx->ninsns;
1850 			emit_sw(RV_REG_T2, 0, RV_REG_T1, ctx);
1851 			insn_len = ctx->ninsns - insns_start;
1852 			break;
1853 		case BPF_DW:
1854 			if (is_12b_int(off)) {
1855 				insns_start = ctx->ninsns;
1856 				emit_sd(rd, off, RV_REG_T1, ctx);
1857 				insn_len = ctx->ninsns - insns_start;
1858 				break;
1859 			}
1860 
1861 			emit_imm(RV_REG_T2, off, ctx);
1862 			emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
1863 			insns_start = ctx->ninsns;
1864 			emit_sd(RV_REG_T2, 0, RV_REG_T1, ctx);
1865 			insn_len = ctx->ninsns - insns_start;
1866 			break;
1867 		}
1868 
1869 		ret = add_exception_handler(insn, ctx, REG_DONT_CLEAR_MARKER,
1870 					    insn_len);
1871 		if (ret)
1872 			return ret;
1873 
1874 		break;
1875 	}
1876 
1877 	/* STX: *(size *)(dst + off) = src */
1878 	case BPF_STX | BPF_MEM | BPF_B:
1879 		if (is_12b_int(off)) {
1880 			emit(rv_sb(rd, off, rs), ctx);
1881 			break;
1882 		}
1883 
1884 		emit_imm(RV_REG_T1, off, ctx);
1885 		emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1886 		emit(rv_sb(RV_REG_T1, 0, rs), ctx);
1887 		break;
1888 	case BPF_STX | BPF_MEM | BPF_H:
1889 		if (is_12b_int(off)) {
1890 			emit(rv_sh(rd, off, rs), ctx);
1891 			break;
1892 		}
1893 
1894 		emit_imm(RV_REG_T1, off, ctx);
1895 		emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1896 		emit(rv_sh(RV_REG_T1, 0, rs), ctx);
1897 		break;
1898 	case BPF_STX | BPF_MEM | BPF_W:
1899 		if (is_12b_int(off)) {
1900 			emit_sw(rd, off, rs, ctx);
1901 			break;
1902 		}
1903 
1904 		emit_imm(RV_REG_T1, off, ctx);
1905 		emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1906 		emit_sw(RV_REG_T1, 0, rs, ctx);
1907 		break;
1908 	case BPF_STX | BPF_MEM | BPF_DW:
1909 		if (is_12b_int(off)) {
1910 			emit_sd(rd, off, rs, ctx);
1911 			break;
1912 		}
1913 
1914 		emit_imm(RV_REG_T1, off, ctx);
1915 		emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1916 		emit_sd(RV_REG_T1, 0, rs, ctx);
1917 		break;
1918 	case BPF_STX | BPF_ATOMIC | BPF_W:
1919 	case BPF_STX | BPF_ATOMIC | BPF_DW:
1920 		emit_atomic(rd, rs, off, imm,
1921 			    BPF_SIZE(code) == BPF_DW, ctx);
1922 		break;
1923 
1924 	case BPF_STX | BPF_PROBE_MEM32 | BPF_B:
1925 	case BPF_STX | BPF_PROBE_MEM32 | BPF_H:
1926 	case BPF_STX | BPF_PROBE_MEM32 | BPF_W:
1927 	case BPF_STX | BPF_PROBE_MEM32 | BPF_DW:
1928 	{
1929 		int insn_len, insns_start;
1930 
1931 		emit_add(RV_REG_T2, rd, RV_REG_ARENA, ctx);
1932 		rd = RV_REG_T2;
1933 
1934 		switch (BPF_SIZE(code)) {
1935 		case BPF_B:
1936 			if (is_12b_int(off)) {
1937 				insns_start = ctx->ninsns;
1938 				emit(rv_sb(rd, off, rs), ctx);
1939 				insn_len = ctx->ninsns - insns_start;
1940 				break;
1941 			}
1942 
1943 			emit_imm(RV_REG_T1, off, ctx);
1944 			emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1945 			insns_start = ctx->ninsns;
1946 			emit(rv_sb(RV_REG_T1, 0, rs), ctx);
1947 			insn_len = ctx->ninsns - insns_start;
1948 			break;
1949 		case BPF_H:
1950 			if (is_12b_int(off)) {
1951 				insns_start = ctx->ninsns;
1952 				emit(rv_sh(rd, off, rs), ctx);
1953 				insn_len = ctx->ninsns - insns_start;
1954 				break;
1955 			}
1956 
1957 			emit_imm(RV_REG_T1, off, ctx);
1958 			emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1959 			insns_start = ctx->ninsns;
1960 			emit(rv_sh(RV_REG_T1, 0, rs), ctx);
1961 			insn_len = ctx->ninsns - insns_start;
1962 			break;
1963 		case BPF_W:
1964 			if (is_12b_int(off)) {
1965 				insns_start = ctx->ninsns;
1966 				emit_sw(rd, off, rs, ctx);
1967 				insn_len = ctx->ninsns - insns_start;
1968 				break;
1969 			}
1970 
1971 			emit_imm(RV_REG_T1, off, ctx);
1972 			emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1973 			insns_start = ctx->ninsns;
1974 			emit_sw(RV_REG_T1, 0, rs, ctx);
1975 			insn_len = ctx->ninsns - insns_start;
1976 			break;
1977 		case BPF_DW:
1978 			if (is_12b_int(off)) {
1979 				insns_start = ctx->ninsns;
1980 				emit_sd(rd, off, rs, ctx);
1981 				insn_len = ctx->ninsns - insns_start;
1982 				break;
1983 			}
1984 
1985 			emit_imm(RV_REG_T1, off, ctx);
1986 			emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1987 			insns_start = ctx->ninsns;
1988 			emit_sd(RV_REG_T1, 0, rs, ctx);
1989 			insn_len = ctx->ninsns - insns_start;
1990 			break;
1991 		}
1992 
1993 		ret = add_exception_handler(insn, ctx, REG_DONT_CLEAR_MARKER,
1994 					    insn_len);
1995 		if (ret)
1996 			return ret;
1997 
1998 		break;
1999 	}
2000 
2001 	default:
2002 		pr_err("bpf-jit: unknown opcode %02x\n", code);
2003 		return -EINVAL;
2004 	}
2005 
2006 	return 0;
2007 }
2008 
2009 void bpf_jit_build_prologue(struct rv_jit_context *ctx, bool is_subprog)
2010 {
2011 	int i, stack_adjust = 0, store_offset, bpf_stack_adjust;
2012 
2013 	bpf_stack_adjust = round_up(ctx->prog->aux->stack_depth, STACK_ALIGN);
2014 	if (bpf_stack_adjust)
2015 		mark_fp(ctx);
2016 
2017 	if (seen_reg(RV_REG_RA, ctx))
2018 		stack_adjust += 8;
2019 	stack_adjust += 8; /* RV_REG_FP */
2020 	if (seen_reg(RV_REG_S1, ctx))
2021 		stack_adjust += 8;
2022 	if (seen_reg(RV_REG_S2, ctx))
2023 		stack_adjust += 8;
2024 	if (seen_reg(RV_REG_S3, ctx))
2025 		stack_adjust += 8;
2026 	if (seen_reg(RV_REG_S4, ctx))
2027 		stack_adjust += 8;
2028 	if (seen_reg(RV_REG_S5, ctx))
2029 		stack_adjust += 8;
2030 	if (seen_reg(RV_REG_S6, ctx))
2031 		stack_adjust += 8;
2032 	if (ctx->arena_vm_start)
2033 		stack_adjust += 8;
2034 
2035 	stack_adjust = round_up(stack_adjust, STACK_ALIGN);
2036 	stack_adjust += bpf_stack_adjust;
2037 
2038 	store_offset = stack_adjust - 8;
2039 
2040 	/* emit kcfi type preamble immediately before the  first insn */
2041 	emit_kcfi(is_subprog ? cfi_bpf_subprog_hash : cfi_bpf_hash, ctx);
2042 
2043 	/* nops reserved for auipc+jalr pair */
2044 	for (i = 0; i < RV_FENTRY_NINSNS; i++)
2045 		emit(rv_nop(), ctx);
2046 
2047 	/* First instruction is always setting the tail-call-counter
2048 	 * (TCC) register. This instruction is skipped for tail calls.
2049 	 * Force using a 4-byte (non-compressed) instruction.
2050 	 */
2051 	emit(rv_addi(RV_REG_TCC, RV_REG_ZERO, MAX_TAIL_CALL_CNT), ctx);
2052 
2053 	emit_addi(RV_REG_SP, RV_REG_SP, -stack_adjust, ctx);
2054 
2055 	if (seen_reg(RV_REG_RA, ctx)) {
2056 		emit_sd(RV_REG_SP, store_offset, RV_REG_RA, ctx);
2057 		store_offset -= 8;
2058 	}
2059 	emit_sd(RV_REG_SP, store_offset, RV_REG_FP, ctx);
2060 	store_offset -= 8;
2061 	if (seen_reg(RV_REG_S1, ctx)) {
2062 		emit_sd(RV_REG_SP, store_offset, RV_REG_S1, ctx);
2063 		store_offset -= 8;
2064 	}
2065 	if (seen_reg(RV_REG_S2, ctx)) {
2066 		emit_sd(RV_REG_SP, store_offset, RV_REG_S2, ctx);
2067 		store_offset -= 8;
2068 	}
2069 	if (seen_reg(RV_REG_S3, ctx)) {
2070 		emit_sd(RV_REG_SP, store_offset, RV_REG_S3, ctx);
2071 		store_offset -= 8;
2072 	}
2073 	if (seen_reg(RV_REG_S4, ctx)) {
2074 		emit_sd(RV_REG_SP, store_offset, RV_REG_S4, ctx);
2075 		store_offset -= 8;
2076 	}
2077 	if (seen_reg(RV_REG_S5, ctx)) {
2078 		emit_sd(RV_REG_SP, store_offset, RV_REG_S5, ctx);
2079 		store_offset -= 8;
2080 	}
2081 	if (seen_reg(RV_REG_S6, ctx)) {
2082 		emit_sd(RV_REG_SP, store_offset, RV_REG_S6, ctx);
2083 		store_offset -= 8;
2084 	}
2085 	if (ctx->arena_vm_start) {
2086 		emit_sd(RV_REG_SP, store_offset, RV_REG_ARENA, ctx);
2087 		store_offset -= 8;
2088 	}
2089 
2090 	emit_addi(RV_REG_FP, RV_REG_SP, stack_adjust, ctx);
2091 
2092 	if (bpf_stack_adjust)
2093 		emit_addi(RV_REG_S5, RV_REG_SP, bpf_stack_adjust, ctx);
2094 
2095 	/* Program contains calls and tail calls, so RV_REG_TCC need
2096 	 * to be saved across calls.
2097 	 */
2098 	if (seen_tail_call(ctx) && seen_call(ctx))
2099 		emit_mv(RV_REG_TCC_SAVED, RV_REG_TCC, ctx);
2100 
2101 	ctx->stack_size = stack_adjust;
2102 
2103 	if (ctx->arena_vm_start)
2104 		emit_imm(RV_REG_ARENA, ctx->arena_vm_start, ctx);
2105 }
2106 
2107 void bpf_jit_build_epilogue(struct rv_jit_context *ctx)
2108 {
2109 	__build_epilogue(false, ctx);
2110 }
2111 
2112 bool bpf_jit_supports_kfunc_call(void)
2113 {
2114 	return true;
2115 }
2116 
2117 bool bpf_jit_supports_ptr_xchg(void)
2118 {
2119 	return true;
2120 }
2121 
2122 bool bpf_jit_supports_arena(void)
2123 {
2124 	return true;
2125 }
2126 
2127 bool bpf_jit_supports_percpu_insn(void)
2128 {
2129 	return true;
2130 }
2131 
2132 bool bpf_jit_inlines_helper_call(s32 imm)
2133 {
2134 	switch (imm) {
2135 	case BPF_FUNC_get_smp_processor_id:
2136 		return true;
2137 	default:
2138 		return false;
2139 	}
2140 }
2141