diff mbox series

[v2,bpf-next,04/14] bpf: Support kernel function call in x86-32

Message ID 20210325015149.1545267-1-kafai@fb.com
State New
Headers show
Series [v2,bpf-next,01/14] bpf: Simplify freeing logic in linfo and jited_linfo | expand

Commit Message

Martin KaFai Lau March 25, 2021, 1:51 a.m. UTC
This patch adds kernel function call support to the x86-32 bpf jit.

Signed-off-by: Martin KaFai Lau <kafai@fb.com>
---
 arch/x86/net/bpf_jit_comp32.c | 198 ++++++++++++++++++++++++++++++++++
 1 file changed, 198 insertions(+)
diff mbox series

Patch

diff --git a/arch/x86/net/bpf_jit_comp32.c b/arch/x86/net/bpf_jit_comp32.c
index d17b67c69f89..0a7a2870f111 100644
--- a/arch/x86/net/bpf_jit_comp32.c
+++ b/arch/x86/net/bpf_jit_comp32.c
@@ -1390,6 +1390,19 @@  static inline void emit_push_r64(const u8 src[], u8 **pprog)
 	*pprog = prog;
 }
 
+static void emit_push_r32(const u8 src[], u8 **pprog)
+{
+	u8 *prog = *pprog;
+	int cnt = 0;
+
+	/* mov ecx,dword ptr [ebp+off] */
+	EMIT3(0x8B, add_2reg(0x40, IA32_EBP, IA32_ECX), STACK_VAR(src_lo));
+	/* push ecx */
+	EMIT1(0x51);
+
+	*pprog = prog;
+}
+
 static u8 get_cond_jmp_opcode(const u8 op, bool is_cmp_lo)
 {
 	u8 jmp_cond;
@@ -1459,6 +1472,174 @@  static u8 get_cond_jmp_opcode(const u8 op, bool is_cmp_lo)
 	return jmp_cond;
 }
 
+/* i386 kernel compiles with "-mregparm=3".  From gcc document:
+ *
+ * ==== snippet ====
+ * regparm (number)
+ *	On x86-32 targets, the regparm attribute causes the compiler
+ *	to pass arguments number one to (number) if they are of integral
+ *	type in registers EAX, EDX, and ECX instead of on the stack.
+ *	Functions that take a variable number of arguments continue
+ *	to be passed all of their arguments on the stack.
+ * ==== snippet ====
+ *
+ * The first three args of a function will be considered for
+ * putting into the 32bit register EAX, EDX, and ECX.
+ *
+ * Two 32bit registers are used to pass a 64bit arg.
+ *
+ * For example,
+ * void foo(u32 a, u32 b, u32 c, u32 d):
+ *	u32 a: EAX
+ *	u32 b: EDX
+ *	u32 c: ECX
+ *	u32 d: stack
+ *
+ * void foo(u64 a, u32 b, u32 c):
+ *	u64 a: EAX (lo32) EDX (hi32)
+ *	u32 b: ECX
+ *	u32 c: stack
+ *
+ * void foo(u32 a, u64 b, u32 c):
+ *	u32 a: EAX
+ *	u64 b: EDX (lo32) ECX (hi32)
+ *	u32 c: stack
+ *
+ * void foo(u32 a, u32 b, u64 c):
+ *	u32 a: EAX
+ *	u32 b: EDX
+ *	u64 c: stack
+ *
+ * The return value will be stored in the EAX (and EDX for 64bit value).
+ *
+ * For example,
+ * u32 foo(u32 a, u32 b, u32 c):
+ *	return value: EAX
+ *
+ * u64 foo(u32 a, u32 b, u32 c):
+ *	return value: EAX (lo32) EDX (hi32)
+ *
+ * Notes:
+ *	The verifier only accepts function having integer and pointers
+ *	as its args and return value, so it does not have
+ *	struct-by-value.
+ *
+ * emit_kfunc_call() finds out the btf_func_model by calling
+ * bpf_jit_find_kfunc_model().  A btf_func_model
+ * has the details about the number of args, size of each arg,
+ * and the size of the return value.
+ *
+ * It first decides how many args can be passed by EAX, EDX, and ECX.
+ * That will decide what args should be pushed to the stack:
+ * [first_stack_regno, last_stack_regno] are the bpf regnos
+ * that should be pushed to the stack.
+ *
+ * It will first push all args to the stack because the push
+ * will need to use ECX.  Then, it moves
+ * [BPF_REG_1, first_stack_regno) to EAX, EDX, and ECX.
+ *
+ * When emitting a call (0xE8), it needs to figure out
+ * the jmp_offset relative to the jit-insn address immediately
+ * following the call (0xE8) instruction.  At this point, it knows
+ * the end of the jit-insn address after completely translated the
+ * current (BPF_JMP | BPF_CALL) bpf-insn.  It is passed as "end_addr"
+ * to the emit_kfunc_call().  Thus, it can learn the "immediate-follow-call"
+ * address by figuring out how many jit-insn is generated between
+ * the call (0xE8) and the end_addr:
+ *	- 0-1 jit-insn (3 bytes each) to restore the esp pointer if there
+ *	  is arg pushed to the stack.
+ *	- 0-2 jit-insns (3 bytes each) to handle the return value.
+ */
+static int emit_kfunc_call(const struct bpf_prog *bpf_prog, u8 *end_addr,
+			   const struct bpf_insn *insn, u8 **pprog)
+{
+	const u8 arg_regs[] = { IA32_EAX, IA32_EDX, IA32_ECX };
+	int i, cnt = 0, first_stack_regno, last_stack_regno;
+	int free_arg_regs = ARRAY_SIZE(arg_regs);
+	const struct btf_func_model *fm;
+	int bytes_in_stack = 0;
+	const u8 *cur_arg_reg;
+	u8 *prog = *pprog;
+	s64 jmp_offset;
+
+	fm = bpf_jit_find_kfunc_model(bpf_prog, insn);
+	if (!fm)
+		return -EINVAL;
+
+	first_stack_regno = BPF_REG_1;
+	for (i = 0; i < fm->nr_args; i++) {
+		int regs_needed = fm->arg_size[i] > sizeof(u32) ? 2 : 1;
+
+		if (regs_needed > free_arg_regs)
+			break;
+
+		free_arg_regs -= regs_needed;
+		first_stack_regno++;
+	}
+
+	/* Push the args to the stack */
+	last_stack_regno = BPF_REG_0 + fm->nr_args;
+	for (i = last_stack_regno; i >= first_stack_regno; i--) {
+		if (fm->arg_size[i - 1] > sizeof(u32)) {
+			emit_push_r64(bpf2ia32[i], &prog);
+			bytes_in_stack += 8;
+		} else {
+			emit_push_r32(bpf2ia32[i], &prog);
+			bytes_in_stack += 4;
+		}
+	}
+
+	cur_arg_reg = &arg_regs[0];
+	for (i = BPF_REG_1; i < first_stack_regno; i++) {
+		/* mov e[adc]x,dword ptr [ebp+off] */
+		EMIT3(0x8B, add_2reg(0x40, IA32_EBP, *cur_arg_reg++),
+		      STACK_VAR(bpf2ia32[i][0]));
+		if (fm->arg_size[i - 1] > sizeof(u32))
+			/* mov e[adc]x,dword ptr [ebp+off] */
+			EMIT3(0x8B, add_2reg(0x40, IA32_EBP, *cur_arg_reg++),
+			      STACK_VAR(bpf2ia32[i][1]));
+	}
+
+	if (bytes_in_stack)
+		/* add esp,"bytes_in_stack" */
+		end_addr -= 3;
+
+	/* mov dword ptr [ebp+off],edx */
+	if (fm->ret_size > sizeof(u32))
+		end_addr -= 3;
+
+	/* mov dword ptr [ebp+off],eax */
+	if (fm->ret_size)
+		end_addr -= 3;
+
+	jmp_offset = (u8 *)__bpf_call_base + insn->imm - end_addr;
+	if (!is_simm32(jmp_offset)) {
+		pr_err("unsupported BPF kernel function jmp_offset:%lld\n",
+		       jmp_offset);
+		return -EINVAL;
+	}
+
+	EMIT1_off32(0xE8, jmp_offset);
+
+	if (fm->ret_size)
+		/* mov dword ptr [ebp+off],eax */
+		EMIT3(0x89, add_2reg(0x40, IA32_EBP, IA32_EAX),
+		      STACK_VAR(bpf2ia32[BPF_REG_0][0]));
+
+	if (fm->ret_size > sizeof(u32))
+		/* mov dword ptr [ebp+off],edx */
+		EMIT3(0x89, add_2reg(0x40, IA32_EBP, IA32_EDX),
+		      STACK_VAR(bpf2ia32[BPF_REG_0][1]));
+
+	if (bytes_in_stack)
+		/* add esp,"bytes_in_stack" */
+		EMIT3(0x83, add_1reg(0xC0, IA32_ESP), bytes_in_stack);
+
+	*pprog = prog;
+
+	return 0;
+}
+
 static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
 		  int oldproglen, struct jit_context *ctx)
 {
@@ -1888,6 +2069,18 @@  static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
 			if (insn->src_reg == BPF_PSEUDO_CALL)
 				goto notyet;
 
+			if (insn->src_reg == BPF_PSEUDO_KFUNC_CALL) {
+				int err;
+
+				err = emit_kfunc_call(bpf_prog,
+						      image + addrs[i],
+						      insn, &prog);
+
+				if (err)
+					return err;
+				break;
+			}
+
 			func = (u8 *) __bpf_call_base + imm32;
 			jmp_offset = func - (image + addrs[i]);
 
@@ -2393,3 +2586,8 @@  struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
 					   tmp : orig_prog);
 	return prog;
 }
+
+bool bpf_jit_supports_kfunc_call(void)
+{
+	return true;
+}