xref: /linux/arch/loongarch/net/bpf_jit.c (revision 6c8c1406)
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * BPF JIT compiler for LoongArch
4  *
5  * Copyright (C) 2022 Loongson Technology Corporation Limited
6  */
7 #include "bpf_jit.h"
8 
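/*
 * The tail call count (TCC) is passed to a tail-called program in $a6
 * (REG_TCC). If the program also makes calls, the prologue copies it to
 * the callee-saved $s5 (TCC_SAVED), because $a6 is caller-saved.
 */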
9 #define REG_TCC		LOONGARCH_GPR_A6
10 #define TCC_SAVED	LOONGARCH_GPR_S5
11 
12 #define SAVE_RA		BIT(0)
13 #define SAVE_TCC	BIT(1)
14 
15 static const int regmap[] = {
16 	/* return value from in-kernel function, and exit value for eBPF program */
17 	[BPF_REG_0] = LOONGARCH_GPR_A5,
18 	/* arguments from eBPF program to in-kernel function */
19 	[BPF_REG_1] = LOONGARCH_GPR_A0,
20 	[BPF_REG_2] = LOONGARCH_GPR_A1,
21 	[BPF_REG_3] = LOONGARCH_GPR_A2,
22 	[BPF_REG_4] = LOONGARCH_GPR_A3,
23 	[BPF_REG_5] = LOONGARCH_GPR_A4,
24 	/* callee saved registers that in-kernel function will preserve */
25 	[BPF_REG_6] = LOONGARCH_GPR_S0,
26 	[BPF_REG_7] = LOONGARCH_GPR_S1,
27 	[BPF_REG_8] = LOONGARCH_GPR_S2,
28 	[BPF_REG_9] = LOONGARCH_GPR_S3,
29 	/* read-only frame pointer to access stack */
30 	[BPF_REG_FP] = LOONGARCH_GPR_S4,
31 	/* temporary register for blinding constants */
32 	[BPF_REG_AX] = LOONGARCH_GPR_T0,
33 };
34 
35 static void mark_call(struct jit_ctx *ctx)
36 {
37 	ctx->flags |= SAVE_RA;
38 }
39 
40 static void mark_tail_call(struct jit_ctx *ctx)
41 {
42 	ctx->flags |= SAVE_TCC;
43 }
44 
45 static bool seen_call(struct jit_ctx *ctx)
46 {
47 	return (ctx->flags & SAVE_RA);
48 }
49 
50 static bool seen_tail_call(struct jit_ctx *ctx)
51 {
52 	return (ctx->flags & SAVE_TCC);
53 }
54 
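/*
 * When the program also makes calls, the tail call count must be read from
 * the callee-saved copy, since $a6 may have been clobbered by those calls.
 */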
55 static u8 tail_call_reg(struct jit_ctx *ctx)
56 {
57 	if (seen_call(ctx))
58 		return TCC_SAVED;
59 
60 	return REG_TCC;
61 }
62 
63 /*
64  * eBPF prog stack layout:
65  *
66  *                                        high
67  * original $sp ------------> +-------------------------+ <--LOONGARCH_GPR_FP
68  *                            |           $ra           |
69  *                            +-------------------------+
70  *                            |           $fp           |
71  *                            +-------------------------+
72  *                            |           $s0           |
73  *                            +-------------------------+
74  *                            |           $s1           |
75  *                            +-------------------------+
76  *                            |           $s2           |
77  *                            +-------------------------+
78  *                            |           $s3           |
79  *                            +-------------------------+
80  *                            |           $s4           |
81  *                            +-------------------------+
82  *                            |           $s5           |
83  *                            +-------------------------+ <--BPF_REG_FP
84  *                            |  prog->aux->stack_depth |
85  *                            |        (optional)       |
86  * current $sp -------------> +-------------------------+
87  *                                        low
88  */
89 static void build_prologue(struct jit_ctx *ctx)
90 {
91 	int stack_adjust = 0, store_offset, bpf_stack_adjust;
92 
93 	bpf_stack_adjust = round_up(ctx->prog->aux->stack_depth, 16);
94 
95 	/* To store ra, fp, s0, s1, s2, s3, s4 and s5. */
96 	stack_adjust += sizeof(long) * 8;
97 
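	/* Keep the frame size 16-byte aligned, as the ABI requires for $sp */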
98 	stack_adjust = round_up(stack_adjust, 16);
99 	stack_adjust += bpf_stack_adjust;
100 
101 	/*
102 	 * First instruction initializes the tail call count (TCC).
103 	 * On a tail call we skip this instruction, and the TCC is
104 	 * passed in REG_TCC from the caller.
105 	 */
106 	emit_insn(ctx, addid, REG_TCC, LOONGARCH_GPR_ZERO, MAX_TAIL_CALL_CNT);
107 
108 	emit_insn(ctx, addid, LOONGARCH_GPR_SP, LOONGARCH_GPR_SP, -stack_adjust);
109 
110 	store_offset = stack_adjust - sizeof(long);
111 	emit_insn(ctx, std, LOONGARCH_GPR_RA, LOONGARCH_GPR_SP, store_offset);
112 
113 	store_offset -= sizeof(long);
114 	emit_insn(ctx, std, LOONGARCH_GPR_FP, LOONGARCH_GPR_SP, store_offset);
115 
116 	store_offset -= sizeof(long);
117 	emit_insn(ctx, std, LOONGARCH_GPR_S0, LOONGARCH_GPR_SP, store_offset);
118 
119 	store_offset -= sizeof(long);
120 	emit_insn(ctx, std, LOONGARCH_GPR_S1, LOONGARCH_GPR_SP, store_offset);
121 
122 	store_offset -= sizeof(long);
123 	emit_insn(ctx, std, LOONGARCH_GPR_S2, LOONGARCH_GPR_SP, store_offset);
124 
125 	store_offset -= sizeof(long);
126 	emit_insn(ctx, std, LOONGARCH_GPR_S3, LOONGARCH_GPR_SP, store_offset);
127 
128 	store_offset -= sizeof(long);
129 	emit_insn(ctx, std, LOONGARCH_GPR_S4, LOONGARCH_GPR_SP, store_offset);
130 
131 	store_offset -= sizeof(long);
132 	emit_insn(ctx, std, LOONGARCH_GPR_S5, LOONGARCH_GPR_SP, store_offset);
133 
134 	emit_insn(ctx, addid, LOONGARCH_GPR_FP, LOONGARCH_GPR_SP, stack_adjust);
135 
136 	if (bpf_stack_adjust)
137 		emit_insn(ctx, addid, regmap[BPF_REG_FP], LOONGARCH_GPR_SP, bpf_stack_adjust);
138 
139 	/*
140 	 * The program contains both calls and tail calls, so REG_TCC
141 	 * needs to be saved across calls.
142 	 */
143 	if (seen_tail_call(ctx) && seen_call(ctx))
144 		move_reg(ctx, TCC_SAVED, REG_TCC);
145 
146 	ctx->stack_size = stack_adjust;
147 }
148 
149 static void __build_epilogue(struct jit_ctx *ctx, bool is_tail_call)
150 {
151 	int stack_adjust = ctx->stack_size;
152 	int load_offset;
153 
154 	load_offset = stack_adjust - sizeof(long);
155 	emit_insn(ctx, ldd, LOONGARCH_GPR_RA, LOONGARCH_GPR_SP, load_offset);
156 
157 	load_offset -= sizeof(long);
158 	emit_insn(ctx, ldd, LOONGARCH_GPR_FP, LOONGARCH_GPR_SP, load_offset);
159 
160 	load_offset -= sizeof(long);
161 	emit_insn(ctx, ldd, LOONGARCH_GPR_S0, LOONGARCH_GPR_SP, load_offset);
162 
163 	load_offset -= sizeof(long);
164 	emit_insn(ctx, ldd, LOONGARCH_GPR_S1, LOONGARCH_GPR_SP, load_offset);
165 
166 	load_offset -= sizeof(long);
167 	emit_insn(ctx, ldd, LOONGARCH_GPR_S2, LOONGARCH_GPR_SP, load_offset);
168 
169 	load_offset -= sizeof(long);
170 	emit_insn(ctx, ldd, LOONGARCH_GPR_S3, LOONGARCH_GPR_SP, load_offset);
171 
172 	load_offset -= sizeof(long);
173 	emit_insn(ctx, ldd, LOONGARCH_GPR_S4, LOONGARCH_GPR_SP, load_offset);
174 
175 	load_offset -= sizeof(long);
176 	emit_insn(ctx, ldd, LOONGARCH_GPR_S5, LOONGARCH_GPR_SP, load_offset);
177 
178 	emit_insn(ctx, addid, LOONGARCH_GPR_SP, LOONGARCH_GPR_SP, stack_adjust);
179 
180 	if (!is_tail_call) {
181 		/* Set return value */
182 		move_reg(ctx, LOONGARCH_GPR_A0, regmap[BPF_REG_0]);
183 		/* Return to the caller */
184 		emit_insn(ctx, jirl, LOONGARCH_GPR_RA, LOONGARCH_GPR_ZERO, 0);
185 	} else {
186 		/*
187 		 * Call the next bpf prog, skipping over its first
188 		 * instruction (the TCC initialization).
189 		 */
190 		emit_insn(ctx, jirl, LOONGARCH_GPR_T3, LOONGARCH_GPR_ZERO, 1);
191 	}
192 }
193 
194 static void build_epilogue(struct jit_ctx *ctx)
195 {
196 	__build_epilogue(ctx, false);
197 }
198 
199 bool bpf_jit_supports_kfunc_call(void)
200 {
201 	return true;
202 }
203 
204 /* initialized on the first pass of build_body() */
205 static int out_offset = -1;
206 static int emit_bpf_tail_call(struct jit_ctx *ctx)
207 {
208 	int off;
209 	u8 tcc = tail_call_reg(ctx);
210 	u8 a1 = LOONGARCH_GPR_A1;
211 	u8 a2 = LOONGARCH_GPR_A2;
212 	u8 t1 = LOONGARCH_GPR_T1;
213 	u8 t2 = LOONGARCH_GPR_T2;
214 	u8 t3 = LOONGARCH_GPR_T3;
215 	const int idx0 = ctx->idx;
216 
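	/*
	 * cur_offset is the number of instructions emitted so far for this
	 * tail call sequence; jmp_offset is the distance (in instructions)
	 * to the "out" label recorded in out_offset on the first pass.
	 */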
217 #define cur_offset (ctx->idx - idx0)
218 #define jmp_offset (out_offset - (cur_offset))
219 
220 	/*
221 	 * a0: &ctx
222 	 * a1: &array
223 	 * a2: index
224 	 *
225 	 * if (index >= array->map.max_entries)
226 	 *	 goto out;
227 	 */
228 	off = offsetof(struct bpf_array, map.max_entries);
229 	emit_insn(ctx, ldwu, t1, a1, off);
230 	/* bgeu $a2, $t1, jmp_offset */
231 	if (emit_tailcall_jmp(ctx, BPF_JGE, a2, t1, jmp_offset) < 0)
232 		goto toofar;
233 
234 	/*
235 	 * if (--TCC < 0)
236 	 *	 goto out;
237 	 */
238 	emit_insn(ctx, addid, REG_TCC, tcc, -1);
239 	if (emit_tailcall_jmp(ctx, BPF_JSLT, REG_TCC, LOONGARCH_GPR_ZERO, jmp_offset) < 0)
240 		goto toofar;
241 
242 	/*
243 	 * prog = array->ptrs[index];
244 	 * if (!prog)
245 	 *	 goto out;
246 	 */
247 	emit_insn(ctx, alsld, t2, a2, a1, 2);
248 	off = offsetof(struct bpf_array, ptrs);
249 	emit_insn(ctx, ldd, t2, t2, off);
250 	/* beq $t2, $zero, jmp_offset */
251 	if (emit_tailcall_jmp(ctx, BPF_JEQ, t2, LOONGARCH_GPR_ZERO, jmp_offset) < 0)
252 		goto toofar;
253 
254 	/* goto *(prog->bpf_func + 4); */
255 	off = offsetof(struct bpf_prog, bpf_func);
256 	emit_insn(ctx, ldd, t3, t2, off);
257 	__build_epilogue(ctx, true);
258 
259 	/* out: */
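	/*
	 * All three bail-out branches above land here, so the emitted
	 * sequence must have the same length on every pass; the check
	 * below enforces that.
	 */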
260 	if (out_offset == -1)
261 		out_offset = cur_offset;
262 	if (cur_offset != out_offset) {
263 		pr_err_once("tail_call out_offset = %d, expected %d!\n",
264 			    cur_offset, out_offset);
265 		return -1;
266 	}
267 
268 	return 0;
269 
270 toofar:
271 	pr_info_once("tail_call: jump too far\n");
272 	return -1;
273 #undef cur_offset
274 #undef jmp_offset
275 }
276 
277 static void emit_atomic(const struct bpf_insn *insn, struct jit_ctx *ctx)
278 {
279 	const u8 t1 = LOONGARCH_GPR_T1;
280 	const u8 t2 = LOONGARCH_GPR_T2;
281 	const u8 t3 = LOONGARCH_GPR_T3;
282 	const u8 r0 = regmap[BPF_REG_0];
283 	const u8 src = regmap[insn->src_reg];
284 	const u8 dst = regmap[insn->dst_reg];
285 	const s16 off = insn->off;
286 	const s32 imm = insn->imm;
287 	const bool isdw = BPF_SIZE(insn->code) == BPF_DW;
288 
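	/*
	 * t1 = dst + off; src is copied to t3 so that the AM* fetch forms
	 * below can write the old memory value into the src register while
	 * still using the original operand (t3).
	 */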
289 	move_imm(ctx, t1, off, false);
290 	emit_insn(ctx, addd, t1, dst, t1);
291 	move_reg(ctx, t3, src);
292 
293 	switch (imm) {
294 	/* lock *(size *)(dst + off) <op>= src */
295 	case BPF_ADD:
296 		if (isdw)
297 			emit_insn(ctx, amaddd, t2, t1, src);
298 		else
299 			emit_insn(ctx, amaddw, t2, t1, src);
300 		break;
301 	case BPF_AND:
302 		if (isdw)
303 			emit_insn(ctx, amandd, t2, t1, src);
304 		else
305 			emit_insn(ctx, amandw, t2, t1, src);
306 		break;
307 	case BPF_OR:
308 		if (isdw)
309 			emit_insn(ctx, amord, t2, t1, src);
310 		else
311 			emit_insn(ctx, amorw, t2, t1, src);
312 		break;
313 	case BPF_XOR:
314 		if (isdw)
315 			emit_insn(ctx, amxord, t2, t1, src);
316 		else
317 			emit_insn(ctx, amxorw, t2, t1, src);
318 		break;
319 	/* src = atomic_fetch_<op>(dst + off, src) */
320 	case BPF_ADD | BPF_FETCH:
321 		if (isdw) {
322 			emit_insn(ctx, amaddd, src, t1, t3);
323 		} else {
324 			emit_insn(ctx, amaddw, src, t1, t3);
325 			emit_zext_32(ctx, src, true);
326 		}
327 		break;
328 	case BPF_AND | BPF_FETCH:
329 		if (isdw) {
330 			emit_insn(ctx, amandd, src, t1, t3);
331 		} else {
332 			emit_insn(ctx, amandw, src, t1, t3);
333 			emit_zext_32(ctx, src, true);
334 		}
335 		break;
336 	case BPF_OR | BPF_FETCH:
337 		if (isdw) {
338 			emit_insn(ctx, amord, src, t1, t3);
339 		} else {
340 			emit_insn(ctx, amorw, src, t1, t3);
341 			emit_zext_32(ctx, src, true);
342 		}
343 		break;
344 	case BPF_XOR | BPF_FETCH:
345 		if (isdw) {
346 			emit_insn(ctx, amxord, src, t1, t3);
347 		} else {
348 			emit_insn(ctx, amxorw, src, t1, t3);
349 			emit_zext_32(ctx, src, true);
350 		}
351 		break;
352 	/* src = atomic_xchg(dst + off, src); */
353 	case BPF_XCHG:
354 		if (isdw) {
355 			emit_insn(ctx, amswapd, src, t1, t3);
356 		} else {
357 			emit_insn(ctx, amswapw, src, t1, t3);
358 			emit_zext_32(ctx, src, true);
359 		}
360 		break;
361 	/* r0 = atomic_cmpxchg(dst + off, r0, src); */
362 	case BPF_CMPXCHG:
363 		move_reg(ctx, t2, r0);
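		/*
		 * LL/SC loop: bne exits once the loaded value differs from
		 * the expected one (t2); beq branches back to retry when the
		 * store-conditional fails.
		 */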
364 		if (isdw) {
365 			emit_insn(ctx, lld, r0, t1, 0);
366 			emit_insn(ctx, bne, t2, r0, 4);
367 			move_reg(ctx, t3, src);
368 			emit_insn(ctx, scd, t3, t1, 0);
369 			emit_insn(ctx, beq, t3, LOONGARCH_GPR_ZERO, -4);
370 		} else {
371 			emit_insn(ctx, llw, r0, t1, 0);
372 			emit_zext_32(ctx, t2, true);
373 			emit_zext_32(ctx, r0, true);
374 			emit_insn(ctx, bne, t2, r0, 4);
375 			move_reg(ctx, t3, src);
376 			emit_insn(ctx, scw, t3, t1, 0);
377 			emit_insn(ctx, beq, t3, LOONGARCH_GPR_ZERO, -6);
378 			emit_zext_32(ctx, r0, true);
379 		}
380 		break;
381 	}
382 }
383 
384 static bool is_signed_bpf_cond(u8 cond)
385 {
386 	return cond == BPF_JSGT || cond == BPF_JSLT ||
387 	       cond == BPF_JSGE || cond == BPF_JSLE;
388 }
389 
390 static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx, bool extra_pass)
391 {
392 	u8 tm = -1;
393 	u64 func_addr;
394 	bool func_addr_fixed;
395 	int i = insn - ctx->prog->insnsi;
396 	int ret, jmp_offset;
397 	const u8 code = insn->code;
398 	const u8 cond = BPF_OP(code);
399 	const u8 t1 = LOONGARCH_GPR_T1;
400 	const u8 t2 = LOONGARCH_GPR_T2;
401 	const u8 src = regmap[insn->src_reg];
402 	const u8 dst = regmap[insn->dst_reg];
403 	const s16 off = insn->off;
404 	const s32 imm = insn->imm;
405 	const u64 imm64 = (u64)(insn + 1)->imm << 32 | (u32)insn->imm;
406 	const bool is32 = BPF_CLASS(insn->code) == BPF_ALU || BPF_CLASS(insn->code) == BPF_JMP32;
407 
408 	switch (code) {
409 	/* dst = src */
410 	case BPF_ALU | BPF_MOV | BPF_X:
411 	case BPF_ALU64 | BPF_MOV | BPF_X:
412 		move_reg(ctx, dst, src);
413 		emit_zext_32(ctx, dst, is32);
414 		break;
415 
416 	/* dst = imm */
417 	case BPF_ALU | BPF_MOV | BPF_K:
418 	case BPF_ALU64 | BPF_MOV | BPF_K:
419 		move_imm(ctx, dst, imm, is32);
420 		break;
421 
422 	/* dst = dst + src */
423 	case BPF_ALU | BPF_ADD | BPF_X:
424 	case BPF_ALU64 | BPF_ADD | BPF_X:
425 		emit_insn(ctx, addd, dst, dst, src);
426 		emit_zext_32(ctx, dst, is32);
427 		break;
428 
429 	/* dst = dst + imm */
430 	case BPF_ALU | BPF_ADD | BPF_K:
431 	case BPF_ALU64 | BPF_ADD | BPF_K:
432 		if (is_signed_imm12(imm)) {
433 			emit_insn(ctx, addid, dst, dst, imm);
434 		} else {
435 			move_imm(ctx, t1, imm, is32);
436 			emit_insn(ctx, addd, dst, dst, t1);
437 		}
438 		emit_zext_32(ctx, dst, is32);
439 		break;
440 
441 	/* dst = dst - src */
442 	case BPF_ALU | BPF_SUB | BPF_X:
443 	case BPF_ALU64 | BPF_SUB | BPF_X:
444 		emit_insn(ctx, subd, dst, dst, src);
445 		emit_zext_32(ctx, dst, is32);
446 		break;
447 
448 	/* dst = dst - imm */
449 	case BPF_ALU | BPF_SUB | BPF_K:
450 	case BPF_ALU64 | BPF_SUB | BPF_K:
451 		if (is_signed_imm12(-imm)) {
452 			emit_insn(ctx, addid, dst, dst, -imm);
453 		} else {
454 			move_imm(ctx, t1, imm, is32);
455 			emit_insn(ctx, subd, dst, dst, t1);
456 		}
457 		emit_zext_32(ctx, dst, is32);
458 		break;
459 
460 	/* dst = dst * src */
461 	case BPF_ALU | BPF_MUL | BPF_X:
462 	case BPF_ALU64 | BPF_MUL | BPF_X:
463 		emit_insn(ctx, muld, dst, dst, src);
464 		emit_zext_32(ctx, dst, is32);
465 		break;
466 
467 	/* dst = dst * imm */
468 	case BPF_ALU | BPF_MUL | BPF_K:
469 	case BPF_ALU64 | BPF_MUL | BPF_K:
470 		move_imm(ctx, t1, imm, is32);
471 		emit_insn(ctx, muld, dst, dst, t1);
472 		emit_zext_32(ctx, dst, is32);
473 		break;
474 
475 	/* dst = dst / src */
476 	case BPF_ALU | BPF_DIV | BPF_X:
477 	case BPF_ALU64 | BPF_DIV | BPF_X:
478 		emit_zext_32(ctx, dst, is32);
479 		move_reg(ctx, t1, src);
480 		emit_zext_32(ctx, t1, is32);
481 		emit_insn(ctx, divdu, dst, dst, t1);
482 		emit_zext_32(ctx, dst, is32);
483 		break;
484 
485 	/* dst = dst / imm */
486 	case BPF_ALU | BPF_DIV | BPF_K:
487 	case BPF_ALU64 | BPF_DIV | BPF_K:
488 		move_imm(ctx, t1, imm, is32);
489 		emit_zext_32(ctx, dst, is32);
490 		emit_insn(ctx, divdu, dst, dst, t1);
491 		emit_zext_32(ctx, dst, is32);
492 		break;
493 
494 	/* dst = dst % src */
495 	case BPF_ALU | BPF_MOD | BPF_X:
496 	case BPF_ALU64 | BPF_MOD | BPF_X:
497 		emit_zext_32(ctx, dst, is32);
498 		move_reg(ctx, t1, src);
499 		emit_zext_32(ctx, t1, is32);
500 		emit_insn(ctx, moddu, dst, dst, t1);
501 		emit_zext_32(ctx, dst, is32);
502 		break;
503 
504 	/* dst = dst % imm */
505 	case BPF_ALU | BPF_MOD | BPF_K:
506 	case BPF_ALU64 | BPF_MOD | BPF_K:
507 		move_imm(ctx, t1, imm, is32);
508 		emit_zext_32(ctx, dst, is32);
509 		emit_insn(ctx, moddu, dst, dst, t1);
510 		emit_zext_32(ctx, dst, is32);
511 		break;
512 
513 	/* dst = -dst */
514 	case BPF_ALU | BPF_NEG:
515 	case BPF_ALU64 | BPF_NEG:
516 		move_imm(ctx, t1, imm, is32);
517 		emit_insn(ctx, subd, dst, LOONGARCH_GPR_ZERO, dst);
518 		emit_zext_32(ctx, dst, is32);
519 		break;
520 
521 	/* dst = dst & src */
522 	case BPF_ALU | BPF_AND | BPF_X:
523 	case BPF_ALU64 | BPF_AND | BPF_X:
524 		emit_insn(ctx, and, dst, dst, src);
525 		emit_zext_32(ctx, dst, is32);
526 		break;
527 
528 	/* dst = dst & imm */
529 	case BPF_ALU | BPF_AND | BPF_K:
530 	case BPF_ALU64 | BPF_AND | BPF_K:
531 		if (is_unsigned_imm12(imm)) {
532 			emit_insn(ctx, andi, dst, dst, imm);
533 		} else {
534 			move_imm(ctx, t1, imm, is32);
535 			emit_insn(ctx, and, dst, dst, t1);
536 		}
537 		emit_zext_32(ctx, dst, is32);
538 		break;
539 
540 	/* dst = dst | src */
541 	case BPF_ALU | BPF_OR | BPF_X:
542 	case BPF_ALU64 | BPF_OR | BPF_X:
543 		emit_insn(ctx, or, dst, dst, src);
544 		emit_zext_32(ctx, dst, is32);
545 		break;
546 
547 	/* dst = dst | imm */
548 	case BPF_ALU | BPF_OR | BPF_K:
549 	case BPF_ALU64 | BPF_OR | BPF_K:
550 		if (is_unsigned_imm12(imm)) {
551 			emit_insn(ctx, ori, dst, dst, imm);
552 		} else {
553 			move_imm(ctx, t1, imm, is32);
554 			emit_insn(ctx, or, dst, dst, t1);
555 		}
556 		emit_zext_32(ctx, dst, is32);
557 		break;
558 
559 	/* dst = dst ^ src */
560 	case BPF_ALU | BPF_XOR | BPF_X:
561 	case BPF_ALU64 | BPF_XOR | BPF_X:
562 		emit_insn(ctx, xor, dst, dst, src);
563 		emit_zext_32(ctx, dst, is32);
564 		break;
565 
566 	/* dst = dst ^ imm */
567 	case BPF_ALU | BPF_XOR | BPF_K:
568 	case BPF_ALU64 | BPF_XOR | BPF_K:
569 		if (is_unsigned_imm12(imm)) {
570 			emit_insn(ctx, xori, dst, dst, imm);
571 		} else {
572 			move_imm(ctx, t1, imm, is32);
573 			emit_insn(ctx, xor, dst, dst, t1);
574 		}
575 		emit_zext_32(ctx, dst, is32);
576 		break;
577 
578 	/* dst = dst << src (logical) */
579 	case BPF_ALU | BPF_LSH | BPF_X:
580 		emit_insn(ctx, sllw, dst, dst, src);
581 		emit_zext_32(ctx, dst, is32);
582 		break;
583 
584 	case BPF_ALU64 | BPF_LSH | BPF_X:
585 		emit_insn(ctx, slld, dst, dst, src);
586 		break;
587 
588 	/* dst = dst << imm (logical) */
589 	case BPF_ALU | BPF_LSH | BPF_K:
590 		emit_insn(ctx, slliw, dst, dst, imm);
591 		emit_zext_32(ctx, dst, is32);
592 		break;
593 
594 	case BPF_ALU64 | BPF_LSH | BPF_K:
595 		emit_insn(ctx, sllid, dst, dst, imm);
596 		break;
597 
598 	/* dst = dst >> src (logical) */
599 	case BPF_ALU | BPF_RSH | BPF_X:
600 		emit_insn(ctx, srlw, dst, dst, src);
601 		emit_zext_32(ctx, dst, is32);
602 		break;
603 
604 	case BPF_ALU64 | BPF_RSH | BPF_X:
605 		emit_insn(ctx, srld, dst, dst, src);
606 		break;
607 
608 	/* dst = dst >> imm (logical) */
609 	case BPF_ALU | BPF_RSH | BPF_K:
610 		emit_insn(ctx, srliw, dst, dst, imm);
611 		emit_zext_32(ctx, dst, is32);
612 		break;
613 
614 	case BPF_ALU64 | BPF_RSH | BPF_K:
615 		emit_insn(ctx, srlid, dst, dst, imm);
616 		break;
617 
618 	/* dst = dst >> src (arithmetic) */
619 	case BPF_ALU | BPF_ARSH | BPF_X:
620 		emit_insn(ctx, sraw, dst, dst, src);
621 		emit_zext_32(ctx, dst, is32);
622 		break;
623 
624 	case BPF_ALU64 | BPF_ARSH | BPF_X:
625 		emit_insn(ctx, srad, dst, dst, src);
626 		break;
627 
628 	/* dst = dst >> imm (arithmetic) */
629 	case BPF_ALU | BPF_ARSH | BPF_K:
630 		emit_insn(ctx, sraiw, dst, dst, imm);
631 		emit_zext_32(ctx, dst, is32);
632 		break;
633 
634 	case BPF_ALU64 | BPF_ARSH | BPF_K:
635 		emit_insn(ctx, sraid, dst, dst, imm);
636 		break;
637 
638 	/* dst = BSWAP##imm(dst) */
639 	case BPF_ALU | BPF_END | BPF_FROM_LE:
640 		switch (imm) {
641 		case 16:
642 			/* zero-extend 16 bits into 64 bits */
643 			emit_insn(ctx, bstrpickd, dst, dst, 15, 0);
644 			break;
645 		case 32:
646 			/* zero-extend 32 bits into 64 bits */
647 			emit_zext_32(ctx, dst, is32);
648 			break;
649 		case 64:
650 			/* do nothing */
651 			break;
652 		}
653 		break;
654 
655 	case BPF_ALU | BPF_END | BPF_FROM_BE:
656 		switch (imm) {
657 		case 16:
658 			emit_insn(ctx, revb2h, dst, dst);
659 			/* zero-extend 16 bits into 64 bits */
660 			emit_insn(ctx, bstrpickd, dst, dst, 15, 0);
661 			break;
662 		case 32:
663 			emit_insn(ctx, revb2w, dst, dst);
664 			/* zero-extend 32 bits into 64 bits */
665 			emit_zext_32(ctx, dst, is32);
666 			break;
667 		case 64:
668 			emit_insn(ctx, revbd, dst, dst);
669 			break;
670 		}
671 		break;
672 
673 	/* PC += off if dst cond src */
674 	case BPF_JMP | BPF_JEQ | BPF_X:
675 	case BPF_JMP | BPF_JNE | BPF_X:
676 	case BPF_JMP | BPF_JGT | BPF_X:
677 	case BPF_JMP | BPF_JGE | BPF_X:
678 	case BPF_JMP | BPF_JLT | BPF_X:
679 	case BPF_JMP | BPF_JLE | BPF_X:
680 	case BPF_JMP | BPF_JSGT | BPF_X:
681 	case BPF_JMP | BPF_JSGE | BPF_X:
682 	case BPF_JMP | BPF_JSLT | BPF_X:
683 	case BPF_JMP | BPF_JSLE | BPF_X:
684 	case BPF_JMP32 | BPF_JEQ | BPF_X:
685 	case BPF_JMP32 | BPF_JNE | BPF_X:
686 	case BPF_JMP32 | BPF_JGT | BPF_X:
687 	case BPF_JMP32 | BPF_JGE | BPF_X:
688 	case BPF_JMP32 | BPF_JLT | BPF_X:
689 	case BPF_JMP32 | BPF_JLE | BPF_X:
690 	case BPF_JMP32 | BPF_JSGT | BPF_X:
691 	case BPF_JMP32 | BPF_JSGE | BPF_X:
692 	case BPF_JMP32 | BPF_JSLT | BPF_X:
693 	case BPF_JMP32 | BPF_JSLE | BPF_X:
694 		jmp_offset = bpf2la_offset(i, off, ctx);
695 		move_reg(ctx, t1, dst);
696 		move_reg(ctx, t2, src);
697 		if (is_signed_bpf_cond(BPF_OP(code))) {
698 			emit_sext_32(ctx, t1, is32);
699 			emit_sext_32(ctx, t2, is32);
700 		} else {
701 			emit_zext_32(ctx, t1, is32);
702 			emit_zext_32(ctx, t2, is32);
703 		}
704 		if (emit_cond_jmp(ctx, cond, t1, t2, jmp_offset) < 0)
705 			goto toofar;
706 		break;
707 
708 	/* PC += off if dst cond imm */
709 	case BPF_JMP | BPF_JEQ | BPF_K:
710 	case BPF_JMP | BPF_JNE | BPF_K:
711 	case BPF_JMP | BPF_JGT | BPF_K:
712 	case BPF_JMP | BPF_JGE | BPF_K:
713 	case BPF_JMP | BPF_JLT | BPF_K:
714 	case BPF_JMP | BPF_JLE | BPF_K:
715 	case BPF_JMP | BPF_JSGT | BPF_K:
716 	case BPF_JMP | BPF_JSGE | BPF_K:
717 	case BPF_JMP | BPF_JSLT | BPF_K:
718 	case BPF_JMP | BPF_JSLE | BPF_K:
719 	case BPF_JMP32 | BPF_JEQ | BPF_K:
720 	case BPF_JMP32 | BPF_JNE | BPF_K:
721 	case BPF_JMP32 | BPF_JGT | BPF_K:
722 	case BPF_JMP32 | BPF_JGE | BPF_K:
723 	case BPF_JMP32 | BPF_JLT | BPF_K:
724 	case BPF_JMP32 | BPF_JLE | BPF_K:
725 	case BPF_JMP32 | BPF_JSGT | BPF_K:
726 	case BPF_JMP32 | BPF_JSGE | BPF_K:
727 	case BPF_JMP32 | BPF_JSLT | BPF_K:
728 	case BPF_JMP32 | BPF_JSLE | BPF_K:
729 		jmp_offset = bpf2la_offset(i, off, ctx);
730 		if (imm) {
731 			move_imm(ctx, t1, imm, false);
732 			tm = t1;
733 		} else {
734 			/* If imm is 0, simply use zero register. */
735 			tm = LOONGARCH_GPR_ZERO;
736 		}
737 		move_reg(ctx, t2, dst);
738 		if (is_signed_bpf_cond(BPF_OP(code))) {
739 			emit_sext_32(ctx, tm, is32);
740 			emit_sext_32(ctx, t2, is32);
741 		} else {
742 			emit_zext_32(ctx, tm, is32);
743 			emit_zext_32(ctx, t2, is32);
744 		}
745 		if (emit_cond_jmp(ctx, cond, t2, tm, jmp_offset) < 0)
746 			goto toofar;
747 		break;
748 
749 	/* PC += off if dst & src */
750 	case BPF_JMP | BPF_JSET | BPF_X:
751 	case BPF_JMP32 | BPF_JSET | BPF_X:
752 		jmp_offset = bpf2la_offset(i, off, ctx);
753 		emit_insn(ctx, and, t1, dst, src);
754 		emit_zext_32(ctx, t1, is32);
755 		if (emit_cond_jmp(ctx, cond, t1, LOONGARCH_GPR_ZERO, jmp_offset) < 0)
756 			goto toofar;
757 		break;
758 
759 	/* PC += off if dst & imm */
760 	case BPF_JMP | BPF_JSET | BPF_K:
761 	case BPF_JMP32 | BPF_JSET | BPF_K:
762 		jmp_offset = bpf2la_offset(i, off, ctx);
763 		move_imm(ctx, t1, imm, is32);
764 		emit_insn(ctx, and, t1, dst, t1);
765 		emit_zext_32(ctx, t1, is32);
766 		if (emit_cond_jmp(ctx, cond, t1, LOONGARCH_GPR_ZERO, jmp_offset) < 0)
767 			goto toofar;
768 		break;
769 
770 	/* PC += off */
771 	case BPF_JMP | BPF_JA:
772 		jmp_offset = bpf2la_offset(i, off, ctx);
773 		if (emit_uncond_jmp(ctx, jmp_offset) < 0)
774 			goto toofar;
775 		break;
776 
777 	/* function call */
778 	case BPF_JMP | BPF_CALL:
779 		mark_call(ctx);
780 		ret = bpf_jit_get_func_addr(ctx->prog, insn, extra_pass,
781 					    &func_addr, &func_addr_fixed);
782 		if (ret < 0)
783 			return ret;
784 
785 		move_imm(ctx, t1, func_addr, is32);
786 		emit_insn(ctx, jirl, t1, LOONGARCH_GPR_RA, 0);
787 		move_reg(ctx, regmap[BPF_REG_0], LOONGARCH_GPR_A0);
788 		break;
789 
790 	/* tail call */
791 	case BPF_JMP | BPF_TAIL_CALL:
792 		mark_tail_call(ctx);
793 		if (emit_bpf_tail_call(ctx) < 0)
794 			return -EINVAL;
795 		break;
796 
797 	/* function return */
798 	case BPF_JMP | BPF_EXIT:
799 		emit_sext_32(ctx, regmap[BPF_REG_0], true);
800 
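		/* The last insn falls through into the epilogue; no jump needed */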
801 		if (i == ctx->prog->len - 1)
802 			break;
803 
804 		jmp_offset = epilogue_offset(ctx);
805 		if (emit_uncond_jmp(ctx, jmp_offset) < 0)
806 			goto toofar;
807 		break;
808 
809 	/* dst = imm64 */
810 	case BPF_LD | BPF_IMM | BPF_DW:
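		/*
		 * A 64-bit immediate occupies two BPF insn slots; returning 1
		 * below tells build_body() to skip the second slot.
		 */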
811 		move_imm(ctx, dst, imm64, is32);
812 		return 1;
813 
814 	/* dst = *(size *)(src + off) */
815 	case BPF_LDX | BPF_MEM | BPF_B:
816 	case BPF_LDX | BPF_MEM | BPF_H:
817 	case BPF_LDX | BPF_MEM | BPF_W:
818 	case BPF_LDX | BPF_MEM | BPF_DW:
819 		switch (BPF_SIZE(code)) {
820 		case BPF_B:
821 			if (is_signed_imm12(off)) {
822 				emit_insn(ctx, ldbu, dst, src, off);
823 			} else {
824 				move_imm(ctx, t1, off, is32);
825 				emit_insn(ctx, ldxbu, dst, src, t1);
826 			}
827 			break;
828 		case BPF_H:
829 			if (is_signed_imm12(off)) {
830 				emit_insn(ctx, ldhu, dst, src, off);
831 			} else {
832 				move_imm(ctx, t1, off, is32);
833 				emit_insn(ctx, ldxhu, dst, src, t1);
834 			}
835 			break;
836 		case BPF_W:
837 			if (is_signed_imm12(off)) {
838 				emit_insn(ctx, ldwu, dst, src, off);
839 			} else if (is_signed_imm14(off)) {
840 				emit_insn(ctx, ldptrw, dst, src, off);
841 			} else {
842 				move_imm(ctx, t1, off, is32);
843 				emit_insn(ctx, ldxwu, dst, src, t1);
844 			}
845 			break;
846 		case BPF_DW:
847 			if (is_signed_imm12(off)) {
848 				emit_insn(ctx, ldd, dst, src, off);
849 			} else if (is_signed_imm14(off)) {
850 				emit_insn(ctx, ldptrd, dst, src, off);
851 			} else {
852 				move_imm(ctx, t1, off, is32);
853 				emit_insn(ctx, ldxd, dst, src, t1);
854 			}
855 			break;
856 		}
857 		break;
858 
859 	/* *(size *)(dst + off) = imm */
860 	case BPF_ST | BPF_MEM | BPF_B:
861 	case BPF_ST | BPF_MEM | BPF_H:
862 	case BPF_ST | BPF_MEM | BPF_W:
863 	case BPF_ST | BPF_MEM | BPF_DW:
864 		switch (BPF_SIZE(code)) {
865 		case BPF_B:
866 			move_imm(ctx, t1, imm, is32);
867 			if (is_signed_imm12(off)) {
868 				emit_insn(ctx, stb, t1, dst, off);
869 			} else {
870 				move_imm(ctx, t2, off, is32);
871 				emit_insn(ctx, stxb, t1, dst, t2);
872 			}
873 			break;
874 		case BPF_H:
875 			move_imm(ctx, t1, imm, is32);
876 			if (is_signed_imm12(off)) {
877 				emit_insn(ctx, sth, t1, dst, off);
878 			} else {
879 				move_imm(ctx, t2, off, is32);
880 				emit_insn(ctx, stxh, t1, dst, t2);
881 			}
882 			break;
883 		case BPF_W:
884 			move_imm(ctx, t1, imm, is32);
885 			if (is_signed_imm12(off)) {
886 				emit_insn(ctx, stw, t1, dst, off);
887 			} else if (is_signed_imm14(off)) {
888 				emit_insn(ctx, stptrw, t1, dst, off);
889 			} else {
890 				move_imm(ctx, t2, off, is32);
891 				emit_insn(ctx, stxw, t1, dst, t2);
892 			}
893 			break;
894 		case BPF_DW:
895 			move_imm(ctx, t1, imm, is32);
896 			if (is_signed_imm12(off)) {
897 				emit_insn(ctx, std, t1, dst, off);
898 			} else if (is_signed_imm14(off)) {
899 				emit_insn(ctx, stptrd, t1, dst, off);
900 			} else {
901 				move_imm(ctx, t2, off, is32);
902 				emit_insn(ctx, stxd, t1, dst, t2);
903 			}
904 			break;
905 		}
906 		break;
907 
908 	/* *(size *)(dst + off) = src */
909 	case BPF_STX | BPF_MEM | BPF_B:
910 	case BPF_STX | BPF_MEM | BPF_H:
911 	case BPF_STX | BPF_MEM | BPF_W:
912 	case BPF_STX | BPF_MEM | BPF_DW:
913 		switch (BPF_SIZE(code)) {
914 		case BPF_B:
915 			if (is_signed_imm12(off)) {
916 				emit_insn(ctx, stb, src, dst, off);
917 			} else {
918 				move_imm(ctx, t1, off, is32);
919 				emit_insn(ctx, stxb, src, dst, t1);
920 			}
921 			break;
922 		case BPF_H:
923 			if (is_signed_imm12(off)) {
924 				emit_insn(ctx, sth, src, dst, off);
925 			} else {
926 				move_imm(ctx, t1, off, is32);
927 				emit_insn(ctx, stxh, src, dst, t1);
928 			}
929 			break;
930 		case BPF_W:
931 			if (is_signed_imm12(off)) {
932 				emit_insn(ctx, stw, src, dst, off);
933 			} else if (is_signed_imm14(off)) {
934 				emit_insn(ctx, stptrw, src, dst, off);
935 			} else {
936 				move_imm(ctx, t1, off, is32);
937 				emit_insn(ctx, stxw, src, dst, t1);
938 			}
939 			break;
940 		case BPF_DW:
941 			if (is_signed_imm12(off)) {
942 				emit_insn(ctx, std, src, dst, off);
943 			} else if (is_signed_imm14(off)) {
944 				emit_insn(ctx, stptrd, src, dst, off);
945 			} else {
946 				move_imm(ctx, t1, off, is32);
947 				emit_insn(ctx, stxd, src, dst, t1);
948 			}
949 			break;
950 		}
951 		break;
952 
953 	case BPF_STX | BPF_ATOMIC | BPF_W:
954 	case BPF_STX | BPF_ATOMIC | BPF_DW:
955 		emit_atomic(insn, ctx);
956 		break;
957 
958 	default:
959 		pr_err("bpf_jit: unknown opcode %02x\n", code);
960 		return -EINVAL;
961 	}
962 
963 	return 0;
964 
965 toofar:
966 	pr_info_once("bpf_jit: opcode %02x, jump too far\n", code);
967 	return -E2BIG;
968 }
969 
970 static int build_body(struct jit_ctx *ctx, bool extra_pass)
971 {
972 	int i;
973 	const struct bpf_prog *prog = ctx->prog;
974 
975 	for (i = 0; i < prog->len; i++) {
976 		const struct bpf_insn *insn = &prog->insnsi[i];
977 		int ret;
978 
979 		if (ctx->image == NULL)
980 			ctx->offset[i] = ctx->idx;
981 
982 		ret = build_insn(insn, ctx, extra_pass);
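		/*
		 * ret > 0 means the insn took two slots (BPF_LD | BPF_IMM |
		 * BPF_DW); record an offset for the second slot and skip it.
		 */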
983 		if (ret > 0) {
984 			i++;
985 			if (ctx->image == NULL)
986 				ctx->offset[i] = ctx->idx;
987 			continue;
988 		}
989 		if (ret)
990 			return ret;
991 	}
992 
993 	if (ctx->image == NULL)
994 		ctx->offset[i] = ctx->idx;
995 
996 	return 0;
997 }
998 
999 /* Fill space with break instructions */
1000 static void jit_fill_hole(void *area, unsigned int size)
1001 {
1002 	u32 *ptr;
1003 
1004 	/* We are guaranteed to have aligned memory */
1005 	for (ptr = area; size >= sizeof(u32); size -= sizeof(u32))
1006 		*ptr++ = INSN_BREAK;
1007 }
1008 
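/*
 * jit_fill_hole() pre-filled the image with INSN_BREAK, so any break
 * instruction left in the final image means some instruction was not
 * emitted where expected.
 */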
1009 static int validate_code(struct jit_ctx *ctx)
1010 {
1011 	int i;
1012 	union loongarch_instruction insn;
1013 
1014 	for (i = 0; i < ctx->idx; i++) {
1015 		insn = ctx->image[i];
1016 		/* Check INSN_BREAK */
1017 		if (insn.word == INSN_BREAK)
1018 			return -1;
1019 	}
1020 
1021 	return 0;
1022 }
1023 
1024 struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
1025 {
1026 	bool tmp_blinded = false, extra_pass = false;
1027 	u8 *image_ptr;
1028 	int image_size;
1029 	struct jit_ctx ctx;
1030 	struct jit_data *jit_data;
1031 	struct bpf_binary_header *header;
1032 	struct bpf_prog *tmp, *orig_prog = prog;
1033 
1034 	/*
1035 	 * If BPF JIT was not enabled then we must fall back to
1036 	 * the interpreter.
1037 	 */
1038 	if (!prog->jit_requested)
1039 		return orig_prog;
1040 
1041 	tmp = bpf_jit_blind_constants(prog);
1042 	/*
1043 	 * If blinding was requested and we failed during blinding,
1044 	 * we must fall back to the interpreter. Otherwise, we save
1045 	 * the new JITed code.
1046 	 */
1047 	if (IS_ERR(tmp))
1048 		return orig_prog;
1049 
1050 	if (tmp != prog) {
1051 		tmp_blinded = true;
1052 		prog = tmp;
1053 	}
1054 
1055 	jit_data = prog->aux->jit_data;
1056 	if (!jit_data) {
1057 		jit_data = kzalloc(sizeof(*jit_data), GFP_KERNEL);
1058 		if (!jit_data) {
1059 			prog = orig_prog;
1060 			goto out;
1061 		}
1062 		prog->aux->jit_data = jit_data;
1063 	}
1064 	if (jit_data->ctx.offset) {
1065 		ctx = jit_data->ctx;
1066 		image_ptr = jit_data->image;
1067 		header = jit_data->header;
1068 		extra_pass = true;
1069 		image_size = sizeof(u32) * ctx.idx;
1070 		goto skip_init_ctx;
1071 	}
1072 
1073 	memset(&ctx, 0, sizeof(ctx));
1074 	ctx.prog = prog;
1075 
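	/*
	 * ctx.offset[i] records the JIT insn index at which BPF insn i
	 * starts; the extra slot (offset[prog->len]) marks the program end.
	 */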
1076 	ctx.offset = kvcalloc(prog->len + 1, sizeof(u32), GFP_KERNEL);
1077 	if (ctx.offset == NULL) {
1078 		prog = orig_prog;
1079 		goto out_offset;
1080 	}
1081 
1082 	/* 1. Initial fake pass to compute ctx->idx and set ctx->flags */
1083 	build_prologue(&ctx);
1084 	if (build_body(&ctx, extra_pass)) {
1085 		prog = orig_prog;
1086 		goto out_offset;
1087 	}
1088 	ctx.epilogue_offset = ctx.idx;
1089 	build_epilogue(&ctx);
1090 
1091 	/*
1092 	 * Now we know the actual image size. As each LoongArch
1093 	 * instruction is 32 bits wide, we translate the number of
1094 	 * JITed instructions into the size needed to store them.
1095 	 */
1096 	image_size = sizeof(u32) * ctx.idx;
1097 	/* Now we know how much memory to allocate for the image */
1098 	header = bpf_jit_binary_alloc(image_size, &image_ptr,
1099 				      sizeof(u32), jit_fill_hole);
1100 	if (header == NULL) {
1101 		prog = orig_prog;
1102 		goto out_offset;
1103 	}
1104 
1105 	/* 2. Now, the actual pass to generate final JIT code */
1106 	ctx.image = (union loongarch_instruction *)image_ptr;
1107 
1108 skip_init_ctx:
1109 	ctx.idx = 0;
1110 
1111 	build_prologue(&ctx);
1112 	if (build_body(&ctx, extra_pass)) {
1113 		bpf_jit_binary_free(header);
1114 		prog = orig_prog;
1115 		goto out_offset;
1116 	}
1117 	build_epilogue(&ctx);
1118 
1119 	/* 3. Extra pass to validate JITed code */
1120 	if (validate_code(&ctx)) {
1121 		bpf_jit_binary_free(header);
1122 		prog = orig_prog;
1123 		goto out_offset;
1124 	}
1125 
1126 	/* And we're done */
1127 	if (bpf_jit_enable > 1)
1128 		bpf_jit_dump(prog->len, image_size, 2, ctx.image);
1129 
1130 	/* Update the icache */
1131 	flush_icache_range((unsigned long)header, (unsigned long)(ctx.image + ctx.idx));
1132 
1133 	if (!prog->is_func || extra_pass) {
1134 		if (extra_pass && ctx.idx != jit_data->ctx.idx) {
1135 			pr_err_once("multi-func JIT bug %d != %d\n",
1136 				    ctx.idx, jit_data->ctx.idx);
1137 			bpf_jit_binary_free(header);
1138 			prog->bpf_func = NULL;
1139 			prog->jited = 0;
1140 			prog->jited_len = 0;
1141 			goto out_offset;
1142 		}
1143 		bpf_jit_binary_lock_ro(header);
1144 	} else {
1145 		jit_data->ctx = ctx;
1146 		jit_data->image = image_ptr;
1147 		jit_data->header = header;
1148 	}
1149 	prog->jited = 1;
1150 	prog->jited_len = image_size;
1151 	prog->bpf_func = (void *)ctx.image;
1152 
1153 	if (!prog->is_func || extra_pass) {
1154 		int i;
1155 
1156 		/* offset[prog->len] is the size of the program */
1157 		for (i = 0; i <= prog->len; i++)
1158 			ctx.offset[i] *= LOONGARCH_INSN_SIZE;
1159 		bpf_prog_fill_jited_linfo(prog, ctx.offset + 1);
1160 
1161 out_offset:
1162 		kvfree(ctx.offset);
1163 		kfree(jit_data);
1164 		prog->aux->jit_data = NULL;
1165 	}
1166 
1167 out:
1168 	if (tmp_blinded)
1169 		bpf_jit_prog_release_other(prog, prog == orig_prog ? tmp : orig_prog);
1170 
1171 	out_offset = -1;
1172 
1173 	return prog;
1174 }
1175