diff mbox series

[bpf-next] bpf, riscv: Fix stack layout of JITed code on RV32

Message ID 20200430005127.2205-1-luke.r.nels@gmail.com
State New
Headers show
Series [bpf-next] bpf, riscv: Fix stack layout of JITed code on RV32 | expand

Commit Message

Luke Nelson April 30, 2020, 12:51 a.m. UTC
This patch fixes issues with stackframe unwinding and alignment in the
current stack layout for BPF programs on RV32.

In the current layout, RV32 fp points to the JIT scratch registers, rather
than to the callee-saved registers. This breaks stackframe unwinding,
which expects fp to point just above the saved ra and fp registers.

This patch fixes the issue by moving the callee-saved registers to be
stored on the top of the stack, pointed to by fp. This satisfies the
assumptions of stackframe unwinding.

This patch also fixes an issue with the old layout that the stack was
not aligned to 16 bytes.

Stacktrace from JITed code using the old stack layout:

  [   12.196249 ] [<c0402200>] walk_stackframe+0x0/0x96

Stacktrace using the new stack layout:

  [   13.062888 ] [<c0402200>] walk_stackframe+0x0/0x96
  [   13.063028 ] [<c04023c6>] show_stack+0x28/0x32
  [   13.063253 ] [<a403e778>] bpf_prog_82b916b2dfa00464+0x80/0x908
  [   13.063417 ] [<c09270b2>] bpf_test_run+0x124/0x39a
  [   13.063553 ] [<c09276c0>] bpf_prog_test_run_skb+0x234/0x448
  [   13.063704 ] [<c048510e>] __do_sys_bpf+0x766/0x13b4
  [   13.063840 ] [<c0485d82>] sys_bpf+0xc/0x14
  [   13.063961 ] [<c04010f0>] ret_from_syscall+0x0/0x2

The new code is also simpler to understand and includes an ASCII diagram
of the stack layout.

Tested on riscv32 QEMU virt machine.

Signed-off-by: Luke Nelson <luke.r.nels@gmail.com>
---
 arch/riscv/net/bpf_jit_comp32.c | 98 ++++++++++++++++++++++-----------
 1 file changed, 65 insertions(+), 33 deletions(-)
diff mbox series

Patch

diff --git a/arch/riscv/net/bpf_jit_comp32.c b/arch/riscv/net/bpf_jit_comp32.c
index 11083d4d5f2d..b198eaa74456 100644
--- a/arch/riscv/net/bpf_jit_comp32.c
+++ b/arch/riscv/net/bpf_jit_comp32.c
@@ -13,8 +13,35 @@ 
 #include <linux/filter.h>
 #include "bpf_jit.h"
 
+/*
+ * Stack layout during BPF program execution:
+ *
+ *                     high
+ *     RV32 fp =>  +----------+
+ *                 | saved ra |
+ *                 | saved fp | RV32 callee-saved registers
+ *                 |   ...    |
+ *                 +----------+ <= (fp - 4 * NR_SAVED_REGISTERS)
+ *                 |  hi(R6)  |
+ *                 |  lo(R6)  |
+ *                 |  hi(R7)  | JIT scratch space for BPF registers
+ *                 |  lo(R7)  |
+ *                 |   ...    |
+ *  BPF_REG_FP =>  +----------+ <= (fp - 4 * NR_SAVED_REGISTERS
+ *                 |          |        - 4 * BPF_JIT_SCRATCH_REGS)
+ *                 |          |
+ *                 |   ...    | BPF program stack
+ *                 |          |
+ *     RV32 sp =>  +----------+
+ *                 |          |
+ *                 |   ...    | Function call stack
+ *                 |          |
+ *                 +----------+
+ *                     low
+ */
+
 enum {
-	/* Stack layout - these are offsets from (top of stack - 4). */
+	/* Stack layout - these are offsets from top of JIT scratch space. */
 	BPF_R6_HI,
 	BPF_R6_LO,
 	BPF_R7_HI,
@@ -29,7 +56,11 @@  enum {
 	BPF_JIT_SCRATCH_REGS,
 };
 
-#define STACK_OFFSET(k) (-4 - ((k) * 4))
+/* Number of callee-saved registers stored to stack: ra, fp, s1--s7. */
+#define NR_SAVED_REGISTERS	9
+
+/* Offset from fp for BPF registers stored on stack. */
+#define STACK_OFFSET(k)	(-4 - (4 * NR_SAVED_REGISTERS) - (4 * (k)))
 
 #define TMP_REG_1	(MAX_BPF_JIT_REG + 0)
 #define TMP_REG_2	(MAX_BPF_JIT_REG + 1)
@@ -111,11 +142,9 @@  static void emit_imm64(const s8 *rd, s32 imm_hi, s32 imm_lo,
 
 static void __build_epilogue(bool is_tail_call, struct rv_jit_context *ctx)
 {
-	int stack_adjust = ctx->stack_size, store_offset = stack_adjust - 4;
+	int stack_adjust = ctx->stack_size;
 	const s8 *r0 = bpf2rv32[BPF_REG_0];
 
-	store_offset -= 4 * BPF_JIT_SCRATCH_REGS;
-
 	/* Set return value if not tail call. */
 	if (!is_tail_call) {
 		emit(rv_addi(RV_REG_A0, lo(r0), 0), ctx);
@@ -123,15 +152,15 @@  static void __build_epilogue(bool is_tail_call, struct rv_jit_context *ctx)
 	}
 
 	/* Restore callee-saved registers. */
-	emit(rv_lw(RV_REG_RA, store_offset - 0, RV_REG_SP), ctx);
-	emit(rv_lw(RV_REG_FP, store_offset - 4, RV_REG_SP), ctx);
-	emit(rv_lw(RV_REG_S1, store_offset - 8, RV_REG_SP), ctx);
-	emit(rv_lw(RV_REG_S2, store_offset - 12, RV_REG_SP), ctx);
-	emit(rv_lw(RV_REG_S3, store_offset - 16, RV_REG_SP), ctx);
-	emit(rv_lw(RV_REG_S4, store_offset - 20, RV_REG_SP), ctx);
-	emit(rv_lw(RV_REG_S5, store_offset - 24, RV_REG_SP), ctx);
-	emit(rv_lw(RV_REG_S6, store_offset - 28, RV_REG_SP), ctx);
-	emit(rv_lw(RV_REG_S7, store_offset - 32, RV_REG_SP), ctx);
+	emit(rv_lw(RV_REG_RA, stack_adjust - 4, RV_REG_SP), ctx);
+	emit(rv_lw(RV_REG_FP, stack_adjust - 8, RV_REG_SP), ctx);
+	emit(rv_lw(RV_REG_S1, stack_adjust - 12, RV_REG_SP), ctx);
+	emit(rv_lw(RV_REG_S2, stack_adjust - 16, RV_REG_SP), ctx);
+	emit(rv_lw(RV_REG_S3, stack_adjust - 20, RV_REG_SP), ctx);
+	emit(rv_lw(RV_REG_S4, stack_adjust - 24, RV_REG_SP), ctx);
+	emit(rv_lw(RV_REG_S5, stack_adjust - 28, RV_REG_SP), ctx);
+	emit(rv_lw(RV_REG_S6, stack_adjust - 32, RV_REG_SP), ctx);
+	emit(rv_lw(RV_REG_S7, stack_adjust - 36, RV_REG_SP), ctx);
 
 	emit(rv_addi(RV_REG_SP, RV_REG_SP, stack_adjust), ctx);
 
@@ -1260,17 +1289,20 @@  int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx,
 
 void bpf_jit_build_prologue(struct rv_jit_context *ctx)
 {
-	/* Make space to save 9 registers: ra, fp, s1--s7. */
-	int stack_adjust = 9 * sizeof(u32), store_offset, bpf_stack_adjust;
 	const s8 *fp = bpf2rv32[BPF_REG_FP];
 	const s8 *r1 = bpf2rv32[BPF_REG_1];
-
-	bpf_stack_adjust = round_up(ctx->prog->aux->stack_depth, 16);
+	int stack_adjust = 0;
+	int bpf_stack_adjust =
+		round_up(ctx->prog->aux->stack_depth, STACK_ALIGN);
+
+	/* Make space for callee-saved registers. */
+	stack_adjust += NR_SAVED_REGISTERS * sizeof(u32);
+	/* Make space for BPF registers on stack. */
+	stack_adjust += BPF_JIT_SCRATCH_REGS * sizeof(u32);
+	/* Make space for BPF stack. */
 	stack_adjust += bpf_stack_adjust;
-
-	store_offset = stack_adjust - 4;
-
-	stack_adjust += 4 * BPF_JIT_SCRATCH_REGS;
+	/* Round up for stack alignment. */
+	stack_adjust = round_up(stack_adjust, STACK_ALIGN);
 
 	/*
 	 * The first instruction sets the tail-call-counter (TCC) register.
@@ -1281,24 +1313,24 @@  void bpf_jit_build_prologue(struct rv_jit_context *ctx)
 	emit(rv_addi(RV_REG_SP, RV_REG_SP, -stack_adjust), ctx);
 
 	/* Save callee-save registers. */
-	emit(rv_sw(RV_REG_SP, store_offset - 0, RV_REG_RA), ctx);
-	emit(rv_sw(RV_REG_SP, store_offset - 4, RV_REG_FP), ctx);
-	emit(rv_sw(RV_REG_SP, store_offset - 8, RV_REG_S1), ctx);
-	emit(rv_sw(RV_REG_SP, store_offset - 12, RV_REG_S2), ctx);
-	emit(rv_sw(RV_REG_SP, store_offset - 16, RV_REG_S3), ctx);
-	emit(rv_sw(RV_REG_SP, store_offset - 20, RV_REG_S4), ctx);
-	emit(rv_sw(RV_REG_SP, store_offset - 24, RV_REG_S5), ctx);
-	emit(rv_sw(RV_REG_SP, store_offset - 28, RV_REG_S6), ctx);
-	emit(rv_sw(RV_REG_SP, store_offset - 32, RV_REG_S7), ctx);
+	emit(rv_sw(RV_REG_SP, stack_adjust - 4, RV_REG_RA), ctx);
+	emit(rv_sw(RV_REG_SP, stack_adjust - 8, RV_REG_FP), ctx);
+	emit(rv_sw(RV_REG_SP, stack_adjust - 12, RV_REG_S1), ctx);
+	emit(rv_sw(RV_REG_SP, stack_adjust - 16, RV_REG_S2), ctx);
+	emit(rv_sw(RV_REG_SP, stack_adjust - 20, RV_REG_S3), ctx);
+	emit(rv_sw(RV_REG_SP, stack_adjust - 24, RV_REG_S4), ctx);
+	emit(rv_sw(RV_REG_SP, stack_adjust - 28, RV_REG_S5), ctx);
+	emit(rv_sw(RV_REG_SP, stack_adjust - 32, RV_REG_S6), ctx);
+	emit(rv_sw(RV_REG_SP, stack_adjust - 36, RV_REG_S7), ctx);
 
 	/* Set fp: used as the base address for stacked BPF registers. */
 	emit(rv_addi(RV_REG_FP, RV_REG_SP, stack_adjust), ctx);
 
-	/* Set up BPF stack pointer. */
+	/* Set up BPF frame pointer. */
 	emit(rv_addi(lo(fp), RV_REG_SP, bpf_stack_adjust), ctx);
 	emit(rv_addi(hi(fp), RV_REG_ZERO, 0), ctx);
 
-	/* Set up context pointer. */
+	/* Set up BPF context pointer. */
 	emit(rv_addi(lo(r1), RV_REG_A0, 0), ctx);
 	emit(rv_addi(hi(r1), RV_REG_ZERO, 0), ctx);