diff mbox series

[RFC] bpf, x64: add extra passes and relax size convergence check

Message ID 20201126080130.23303-1-glin@suse.com
State New
Headers show
Series [RFC] bpf, x64: add extra passes and relax size convergence check | expand

Commit Message

Gary Lin Nov. 26, 2020, 8:01 a.m. UTC
The x64 bpf jit expects bpf images converge within the given passes, but
it could fail to do so with some corner cases. For example:

      l0:     ldh [4]
      l1:     jeq #0x537d, l2, l40
      l2:     ld [0]
      l3:     jeq #0xfa163e0d, l4, l40
      l4:     ldh [12]
      l5:     ldx #0xe
      l6:     jeq #0x86dd, l41, l7
      l7:     jeq #0x800, l8, l41
      l8:     ld [x+16]
      l9:     ja 41

        [... repeated ja 41 ]

      l40:    ja 41
      l41:    ret #0
      l42:    ld #len
      l43:    ret a

This bpf program contains 32 "ja 41" instructions which are effectivly
NOPs and designed to be replaced with valid code dynamically. Ideally,
bpf jit should optimize those "ja 41" instructions out when translating
translating the bpf instrctions into x86_64 machine code. However,
do_jit() can only remove one "ja 41" for offset==0 on each pass, so it
requires at least 32 runs to eliminate those JMPs and exceeds the
current limit of passes (20). In the end, the program got rejected when
BPF_JIT_ALWAYS_ON is set even though it's legit as a classic socket
filter.

To allow the not-converged images, one possible solution is to only use
JMPs with imm32 to guarantee the correctness of jump offsets.

There are two goals of this commit:
  1. reduce the size variance by generating only jumps with imm32
  2. relax the requirement of size convergence

Since imm32 jump occupies 5 bytes compared with 2 bytes by imm8 jump,
the image size may swell. To minimize the impact, 5 extra passes are
introduced and the imm32-only rule is only applied to the extra passes,
so the bpf images converge within the original 20 passes won't be
affected.

If the image doesn't converge after the 5 extra passes, the image is
still allocated and a warning is issued to notify the user.

Signed-off-by: Gary Lin <glin@suse.com>
---
 arch/x86/net/bpf_jit_comp.c | 47 +++++++++++++++++++++++++++++--------
 1 file changed, 37 insertions(+), 10 deletions(-)

Comments

Gary Lin Nov. 27, 2020, 3:13 a.m. UTC | #1
On Thu, Nov 26, 2020 at 04:01:30PM +0800, Gary Lin wrote:
> The x64 bpf jit expects bpf images converge within the given passes, but
> it could fail to do so with some corner cases. For example:
> 
>       l0:     ldh [4]
>       l1:     jeq #0x537d, l2, l40
>       l2:     ld [0]
>       l3:     jeq #0xfa163e0d, l4, l40
>       l4:     ldh [12]
>       l5:     ldx #0xe
>       l6:     jeq #0x86dd, l41, l7
>       l7:     jeq #0x800, l8, l41
>       l8:     ld [x+16]
>       l9:     ja 41
> 
>         [... repeated ja 41 ]
> 
>       l40:    ja 41
>       l41:    ret #0
>       l42:    ld #len
>       l43:    ret a
> 
> This bpf program contains 32 "ja 41" instructions which are effectivly
> NOPs and designed to be replaced with valid code dynamically. Ideally,
> bpf jit should optimize those "ja 41" instructions out when translating
> translating the bpf instrctions into x86_64 machine code. However,
> do_jit() can only remove one "ja 41" for offset==0 on each pass, so it
> requires at least 32 runs to eliminate those JMPs and exceeds the
> current limit of passes (20). In the end, the program got rejected when
> BPF_JIT_ALWAYS_ON is set even though it's legit as a classic socket
> filter.
> 
> To allow the not-converged images, one possible solution is to only use
> JMPs with imm32 to guarantee the correctness of jump offsets.
> 
> There are two goals of this commit:
>   1. reduce the size variance by generating only jumps with imm32
>   2. relax the requirement of size convergence
I went through the code again and found that size convergence is still
necessary or the offset change could still affect the correctness of the
generated machine code.

Instead of relaxing size convergence check, we can disable the jump
optimization in the extra passes, i.e. imm32 jump only and no more
nop optimization.

Will submit another patch.

Gary Lin

> 
> Since imm32 jump occupies 5 bytes compared with 2 bytes by imm8 jump,
> the image size may swell. To minimize the impact, 5 extra passes are
> introduced and the imm32-only rule is only applied to the extra passes,
> so the bpf images converge within the original 20 passes won't be
> affected.
> 
> If the image doesn't converge after the 5 extra passes, the image is
> still allocated and a warning is issued to notify the user.
> 
> Signed-off-by: Gary Lin <glin@suse.com>
> ---
>  arch/x86/net/bpf_jit_comp.c | 47 +++++++++++++++++++++++++++++--------
>  1 file changed, 37 insertions(+), 10 deletions(-)
> 
> diff --git a/arch/x86/net/bpf_jit_comp.c b/arch/x86/net/bpf_jit_comp.c
> index 796506dcfc42..6fe933e9e8c2 100644
> --- a/arch/x86/net/bpf_jit_comp.c
> +++ b/arch/x86/net/bpf_jit_comp.c
> @@ -790,7 +790,8 @@ static void detect_reg_usage(struct bpf_insn *insn, int insn_cnt,
>  }
>  
>  static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
> -		  int oldproglen, struct jit_context *ctx)
> +		  int oldproglen, struct jit_context *ctx, bool force_jmp32,
> +		  bool allow_grow)
>  {
>  	bool tail_call_reachable = bpf_prog->aux->tail_call_reachable;
>  	struct bpf_insn *insn = bpf_prog->insnsi;
> @@ -1408,7 +1409,7 @@ xadd:			if (is_imm8(insn->off))
>  				return -EFAULT;
>  			}
>  			jmp_offset = addrs[i + insn->off] - addrs[i];
> -			if (is_imm8(jmp_offset)) {
> +			if (is_imm8(jmp_offset) && !force_jmp32) {
>  				EMIT2(jmp_cond, jmp_offset);
>  			} else if (is_simm32(jmp_offset)) {
>  				EMIT2_off32(0x0F, jmp_cond + 0x10, jmp_offset);
> @@ -1435,7 +1436,7 @@ xadd:			if (is_imm8(insn->off))
>  				/* Optimize out nop jumps */
>  				break;
>  emit_jmp:
> -			if (is_imm8(jmp_offset)) {
> +			if (is_imm8(jmp_offset) && !force_jmp32) {
>  				EMIT2(0xEB, jmp_offset);
>  			} else if (is_simm32(jmp_offset)) {
>  				EMIT1_off32(0xE9, jmp_offset);
> @@ -1476,7 +1477,7 @@ xadd:			if (is_imm8(insn->off))
>  		}
>  
>  		if (image) {
> -			if (unlikely(proglen + ilen > oldproglen)) {
> +			if (unlikely(proglen + ilen > oldproglen) && !allow_grow) {
>  				pr_err("bpf_jit: fatal error\n");
>  				return -EFAULT;
>  			}
> @@ -1972,6 +1973,9 @@ struct x64_jit_data {
>  	struct jit_context ctx;
>  };
>  
> +#define MAX_JIT_PASSES 25
> +#define JMP32_ONLY_PASSES 20
> +
>  struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
>  {
>  	struct bpf_binary_header *header = NULL;
> @@ -1981,6 +1985,8 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
>  	struct jit_context ctx = {};
>  	bool tmp_blinded = false;
>  	bool extra_pass = false;
> +	bool force_jmp32 = false;
> +	bool allow_grow = false;
>  	u8 *image = NULL;
>  	int *addrs;
>  	int pass;
> @@ -2042,8 +2048,24 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
>  	 * may converge on the last pass. In such case do one more
>  	 * pass to emit the final image.
>  	 */
> -	for (pass = 0; pass < 20 || image; pass++) {
> -		proglen = do_jit(prog, addrs, image, oldproglen, &ctx);
> +	for (pass = 0; pass < MAX_JIT_PASSES || image; pass++) {
> +		/*
> +		 * On the 21th pass, if the image still doesn't converge,
> +		 * then force_jmp32 is set afterward to make do_jit() always
> +		 * generate 32bit offest JMP to reduce the chance of size
> +		 * variance. The side effect is that the image size may grow
> +		 * since the 8bit offset JMPs are now replaced with 32bit
> +		 * offset JMPs, so allow_grow is flipped to true only for
> +		 * this pass.
> +		 */
> +		if (pass == JMP32_ONLY_PASSES && !image) {
> +			force_jmp32 = true;
> +			allow_grow = true;
> +		} else {
> +			allow_grow = false;
> +		}
> +
> +		proglen = do_jit(prog, addrs, image, oldproglen, &ctx, force_jmp32, allow_grow);
>  		if (proglen <= 0) {
>  out_image:
>  			image = NULL;
> @@ -2054,13 +2076,18 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
>  		}
>  		if (image) {
>  			if (proglen != oldproglen) {
> -				pr_err("bpf_jit: proglen=%d != oldproglen=%d\n",
> -				       proglen, oldproglen);
> -				goto out_image;
> +				if (pass < MAX_JIT_PASSES) {
> +					pr_err("bpf_jit: proglen=%d != oldproglen=%d\n",
> +					       proglen, oldproglen);
> +					goto out_image;
> +				} else {
> +					pr_warn("bpf_jit: proglen=%d != oldproglen=%d, pass=%d\n",
> +						proglen, oldproglen, pass);
> +				}
>  			}
>  			break;
>  		}
> -		if (proglen == oldproglen) {
> +		if (proglen == oldproglen || pass == (MAX_JIT_PASSES - 1)) {
>  			/*
>  			 * The number of entries in extable is the number of BPF_LDX
>  			 * insns that access kernel memory via "pointer to BTF type".
> -- 
> 2.28.0
>
diff mbox series

Patch

diff --git a/arch/x86/net/bpf_jit_comp.c b/arch/x86/net/bpf_jit_comp.c
index 796506dcfc42..6fe933e9e8c2 100644
--- a/arch/x86/net/bpf_jit_comp.c
+++ b/arch/x86/net/bpf_jit_comp.c
@@ -790,7 +790,8 @@  static void detect_reg_usage(struct bpf_insn *insn, int insn_cnt,
 }
 
 static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
-		  int oldproglen, struct jit_context *ctx)
+		  int oldproglen, struct jit_context *ctx, bool force_jmp32,
+		  bool allow_grow)
 {
 	bool tail_call_reachable = bpf_prog->aux->tail_call_reachable;
 	struct bpf_insn *insn = bpf_prog->insnsi;
@@ -1408,7 +1409,7 @@  xadd:			if (is_imm8(insn->off))
 				return -EFAULT;
 			}
 			jmp_offset = addrs[i + insn->off] - addrs[i];
-			if (is_imm8(jmp_offset)) {
+			if (is_imm8(jmp_offset) && !force_jmp32) {
 				EMIT2(jmp_cond, jmp_offset);
 			} else if (is_simm32(jmp_offset)) {
 				EMIT2_off32(0x0F, jmp_cond + 0x10, jmp_offset);
@@ -1435,7 +1436,7 @@  xadd:			if (is_imm8(insn->off))
 				/* Optimize out nop jumps */
 				break;
 emit_jmp:
-			if (is_imm8(jmp_offset)) {
+			if (is_imm8(jmp_offset) && !force_jmp32) {
 				EMIT2(0xEB, jmp_offset);
 			} else if (is_simm32(jmp_offset)) {
 				EMIT1_off32(0xE9, jmp_offset);
@@ -1476,7 +1477,7 @@  xadd:			if (is_imm8(insn->off))
 		}
 
 		if (image) {
-			if (unlikely(proglen + ilen > oldproglen)) {
+			if (unlikely(proglen + ilen > oldproglen) && !allow_grow) {
 				pr_err("bpf_jit: fatal error\n");
 				return -EFAULT;
 			}
@@ -1972,6 +1973,9 @@  struct x64_jit_data {
 	struct jit_context ctx;
 };
 
+#define MAX_JIT_PASSES 25
+#define JMP32_ONLY_PASSES 20
+
 struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
 {
 	struct bpf_binary_header *header = NULL;
@@ -1981,6 +1985,8 @@  struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
 	struct jit_context ctx = {};
 	bool tmp_blinded = false;
 	bool extra_pass = false;
+	bool force_jmp32 = false;
+	bool allow_grow = false;
 	u8 *image = NULL;
 	int *addrs;
 	int pass;
@@ -2042,8 +2048,24 @@  struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
 	 * may converge on the last pass. In such case do one more
 	 * pass to emit the final image.
 	 */
-	for (pass = 0; pass < 20 || image; pass++) {
-		proglen = do_jit(prog, addrs, image, oldproglen, &ctx);
+	for (pass = 0; pass < MAX_JIT_PASSES || image; pass++) {
+		/*
+		 * On the 21th pass, if the image still doesn't converge,
+		 * then force_jmp32 is set afterward to make do_jit() always
+		 * generate 32bit offest JMP to reduce the chance of size
+		 * variance. The side effect is that the image size may grow
+		 * since the 8bit offset JMPs are now replaced with 32bit
+		 * offset JMPs, so allow_grow is flipped to true only for
+		 * this pass.
+		 */
+		if (pass == JMP32_ONLY_PASSES && !image) {
+			force_jmp32 = true;
+			allow_grow = true;
+		} else {
+			allow_grow = false;
+		}
+
+		proglen = do_jit(prog, addrs, image, oldproglen, &ctx, force_jmp32, allow_grow);
 		if (proglen <= 0) {
 out_image:
 			image = NULL;
@@ -2054,13 +2076,18 @@  struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
 		}
 		if (image) {
 			if (proglen != oldproglen) {
-				pr_err("bpf_jit: proglen=%d != oldproglen=%d\n",
-				       proglen, oldproglen);
-				goto out_image;
+				if (pass < MAX_JIT_PASSES) {
+					pr_err("bpf_jit: proglen=%d != oldproglen=%d\n",
+					       proglen, oldproglen);
+					goto out_image;
+				} else {
+					pr_warn("bpf_jit: proglen=%d != oldproglen=%d, pass=%d\n",
+						proglen, oldproglen, pass);
+				}
 			}
 			break;
 		}
-		if (proglen == oldproglen) {
+		if (proglen == oldproglen || pass == (MAX_JIT_PASSES - 1)) {
 			/*
 			 * The number of entries in extable is the number of BPF_LDX
 			 * insns that access kernel memory via "pointer to BTF type".