diff mbox series

[bpf-next,v5,02/17] bpf/verifier: allow kfunc to return an allocated mem

Message ID 20220518205924.399291-3-benjamin.tissoires@redhat.com
State New
Headers show
Series Introduce eBPF support for HID devices | expand

Commit Message

Benjamin Tissoires May 18, 2022, 8:59 p.m. UTC
When a kfunc is not returning a pointer to a struct but to a plain type,
we can consider it is a valid allocated memory assuming that:
- one of the arguments is called rdonly_buf_size
- or one of the arguments is called rdwr_buf_size
- and this argument is a const from the caller point of view

We can then use this parameter as the size of the allocated memory.

The memory is either read-only or read-write based on the name
of the size parameter.

Signed-off-by: Benjamin Tissoires <benjamin.tissoires@redhat.com>

---

changes in v5:
- updated PTR_TO_MEM comment in btf.c to match upstream
- make it read-only or read-write based on the name of size

new in v4
---
 include/linux/btf.h   |  7 +++++
 kernel/bpf/btf.c      | 41 +++++++++++++++++++++++-
 kernel/bpf/verifier.c | 72 +++++++++++++++++++++++++++++++++----------
 3 files changed, 102 insertions(+), 18 deletions(-)
diff mbox series

Patch

diff --git a/include/linux/btf.h b/include/linux/btf.h
index 2611cea2c2b6..2a4feafc083e 100644
--- a/include/linux/btf.h
+++ b/include/linux/btf.h
@@ -343,6 +343,13 @@  static inline struct btf_param *btf_params(const struct btf_type *t)
 	return (struct btf_param *)(t + 1);
 }
 
+struct bpf_reg_state;
+
+bool btf_is_kfunc_arg_mem_size(const struct btf *btf,
+			       const struct btf_param *arg,
+			       const struct bpf_reg_state *reg,
+			       const char *name);
+
 #ifdef CONFIG_BPF_SYSCALL
 struct bpf_prog;
 
diff --git a/kernel/bpf/btf.c b/kernel/bpf/btf.c
index 7bccaa4646e5..2d11d178807c 100644
--- a/kernel/bpf/btf.c
+++ b/kernel/bpf/btf.c
@@ -6049,6 +6049,31 @@  static bool is_kfunc_arg_mem_size(const struct btf *btf,
 	return true;
 }
 
+bool btf_is_kfunc_arg_mem_size(const struct btf *btf,
+			       const struct btf_param *arg,
+			       const struct bpf_reg_state *reg,
+			       const char *name)
+{
+	int len, target_len = strlen(name);
+	const struct btf_type *t;
+	const char *param_name;
+
+	t = btf_type_skip_modifiers(btf, arg->type, NULL);
+	if (!btf_type_is_scalar(t) || reg->type != SCALAR_VALUE)
+		return false;
+
+	param_name = btf_name_by_offset(btf, arg->name_off);
+	if (str_is_empty(param_name))
+		return false;
+	len = strlen(param_name);
+	if (len != target_len)
+		return false;
+	if (strncmp(param_name, name, target_len))
+		return false;
+
+	return true;
+}
+
 static int btf_check_func_arg_match(struct bpf_verifier_env *env,
 				    const struct btf *btf, u32 func_id,
 				    struct bpf_reg_state *regs,
@@ -6198,7 +6223,7 @@  static int btf_check_func_arg_match(struct bpf_verifier_env *env,
 			if (reg->type == PTR_TO_BTF_ID) {
 				reg_btf = reg->btf;
 				reg_ref_id = reg->btf_id;
-				/* Ensure only one argument is referenced PTR_TO_BTF_ID */
+				/* Ensure only one argument is reference PTR_TO_BTF_ID or PTR_TO_MEM */
 				if (reg->ref_obj_id) {
 					if (ref_obj_id) {
 						bpf_log(log, "verifier internal error: more than one arg with ref_obj_id R%d %u %u\n",
@@ -6258,6 +6283,20 @@  static int btf_check_func_arg_match(struct bpf_verifier_env *env,
 					i++;
 					continue;
 				}
+
+				if (rel && reg->ref_obj_id) {
+					/* Ensure only one argument is referenced PTR_TO_BTF_ID or PTR_TO_MEM */
+					if (ref_obj_id) {
+						bpf_log(log,
+							"verifier internal error: more than one arg with ref_obj_id R%d %u %u\n",
+							regno,
+							reg->ref_obj_id,
+							ref_obj_id);
+						return -EFAULT;
+					}
+					ref_regno = regno;
+					ref_obj_id = reg->ref_obj_id;
+				}
 			}
 
 			resolve_ret = btf_resolve_size(btf, ref_t, &type_size);
diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
index 9b59581026f8..084319073064 100644
--- a/kernel/bpf/verifier.c
+++ b/kernel/bpf/verifier.c
@@ -7219,13 +7219,14 @@  static int check_kfunc_call(struct bpf_verifier_env *env, struct bpf_insn *insn,
 			    int *insn_idx_p)
 {
 	const struct btf_type *t, *func, *func_proto, *ptr_type;
-	struct bpf_reg_state *regs = cur_regs(env);
+	struct bpf_reg_state *reg, *regs = cur_regs(env);
 	const char *func_name, *ptr_type_name;
-	u32 i, nargs, func_id, ptr_type_id;
+	u32 i, nargs, func_id, ptr_type_id, regno;
 	int err, insn_idx = *insn_idx_p;
 	const struct btf_param *args;
 	struct btf *desc_btf;
 	bool acq;
+	size_t reg_rw_size = 0, reg_ro_size = 0;
 
 	/* skip for now, but return error when we find this in fixup_kfunc_call */
 	if (!insn->imm)
@@ -7266,8 +7267,8 @@  static int check_kfunc_call(struct bpf_verifier_env *env, struct bpf_insn *insn,
 		}
 	}
 
-	for (i = 0; i < CALLER_SAVED_REGS; i++)
-		mark_reg_not_init(env, regs, caller_saved[i]);
+	/* reset REG_0 */
+	mark_reg_not_init(env, regs, BPF_REG_0);
 
 	/* Check return type */
 	t = btf_type_skip_modifiers(desc_btf, func_proto->type, NULL);
@@ -7277,6 +7278,9 @@  static int check_kfunc_call(struct bpf_verifier_env *env, struct bpf_insn *insn,
 		return -EINVAL;
 	}
 
+	nargs = btf_type_vlen(func_proto);
+	args = btf_params(func_proto);
+
 	if (btf_type_is_scalar(t)) {
 		mark_reg_unknown(env, regs, BPF_REG_0);
 		mark_btf_func_reg_size(env, BPF_REG_0, t->size);
@@ -7284,24 +7288,57 @@  static int check_kfunc_call(struct bpf_verifier_env *env, struct bpf_insn *insn,
 		ptr_type = btf_type_skip_modifiers(desc_btf, t->type,
 						   &ptr_type_id);
 		if (!btf_type_is_struct(ptr_type)) {
-			ptr_type_name = btf_name_by_offset(desc_btf,
-							   ptr_type->name_off);
-			verbose(env, "kernel function %s returns pointer type %s %s is not supported\n",
-				func_name, btf_type_str(ptr_type),
-				ptr_type_name);
-			return -EINVAL;
+			/* if we have an array, look for the arguments */
+			for (i = 0; i < nargs; i++) {
+				regno = i + BPF_REG_1;
+				reg = &regs[regno];
+
+				/* look for any const scalar parameter of name "rdonly_buf_size"
+				 * or "rdwr_buf_size"
+				 */
+				if (!check_reg_arg(env, regno, SRC_OP) &&
+				    tnum_is_const(regs[regno].var_off)) {
+					if (btf_is_kfunc_arg_mem_size(desc_btf, &args[i], reg,
+								      "rdonly_buf_size"))
+						reg_ro_size = regs[regno].var_off.value;
+					else if (btf_is_kfunc_arg_mem_size(desc_btf, &args[i], reg,
+									   "rdwr_buf_size"))
+						reg_rw_size = regs[regno].var_off.value;
+				}
+			}
+
+			if (!reg_rw_size && !reg_ro_size) {
+				ptr_type_name = btf_name_by_offset(desc_btf,
+								   ptr_type->name_off);
+				verbose(env,
+					"kernel function %s returns pointer type %s %s is not supported\n",
+					func_name,
+					btf_type_str(ptr_type),
+					ptr_type_name);
+				return -EINVAL;
+			}
+
+			mark_reg_known_zero(env, regs, BPF_REG_0);
+			regs[BPF_REG_0].type = PTR_TO_MEM;
+			regs[BPF_REG_0].mem_size = reg_ro_size + reg_rw_size;
+
+			if (reg_ro_size)
+				regs[BPF_REG_0].type |= MEM_RDONLY;
+		} else {
+			mark_reg_known_zero(env, regs, BPF_REG_0);
+			regs[BPF_REG_0].type = PTR_TO_BTF_ID;
+			regs[BPF_REG_0].btf = desc_btf;
+			regs[BPF_REG_0].btf_id = ptr_type_id;
+			mark_btf_func_reg_size(env, BPF_REG_0, sizeof(void *));
 		}
-		mark_reg_known_zero(env, regs, BPF_REG_0);
-		regs[BPF_REG_0].btf = desc_btf;
-		regs[BPF_REG_0].type = PTR_TO_BTF_ID;
-		regs[BPF_REG_0].btf_id = ptr_type_id;
+
 		if (btf_kfunc_id_set_contains(desc_btf, resolve_prog_type(env->prog),
 					      BTF_KFUNC_TYPE_RET_NULL, func_id)) {
 			regs[BPF_REG_0].type |= PTR_MAYBE_NULL;
 			/* For mark_ptr_or_null_reg, see 93c230e3f5bd6 */
 			regs[BPF_REG_0].id = ++env->id_gen;
 		}
-		mark_btf_func_reg_size(env, BPF_REG_0, sizeof(void *));
+
 		if (acq) {
 			int id = acquire_reference_state(env, insn_idx);
 
@@ -7312,8 +7349,9 @@  static int check_kfunc_call(struct bpf_verifier_env *env, struct bpf_insn *insn,
 		}
 	} /* else { add_kfunc_call() ensures it is btf_type_is_void(t) } */
 
-	nargs = btf_type_vlen(func_proto);
-	args = (const struct btf_param *)(func_proto + 1);
+	for (i = 1 ; i < CALLER_SAVED_REGS; i++)
+		mark_reg_not_init(env, regs, caller_saved[i]);
+
 	for (i = 0; i < nargs; i++) {
 		u32 regno = i + 1;