diff mbox series

[bpf-next,v3,07/11] bpf: Fix a false rejection caused by AND operation

Message ID 20240411122752.2873562-8-xukuohai@huaweicloud.com
State New
Headers show
Series Add check for bpf lsm return value | expand

Commit Message

Xu Kuohai April 11, 2024, 12:27 p.m. UTC
From: Xu Kuohai <xukuohai@huawei.com>

With lsm return value check, the no-alu32 version test_libbpf_get_fd_by_id_opts
is rejected by the verifier, and the log says:

  0: R1=ctx() R10=fp0
  ; int BPF_PROG(check_access, struct bpf_map *map, fmode_t fmode) @ test_libbpf_get_fd_by_id_opts.c:27
  0: (b7) r0 = 0                        ; R0_w=0
  1: (79) r2 = *(u64 *)(r1 +0)
  func 'bpf_lsm_bpf_map' arg0 has btf_id 916 type STRUCT 'bpf_map'
  2: R1=ctx() R2_w=trusted_ptr_bpf_map()
  ; if (map != (struct bpf_map *)&data_input) @ test_libbpf_get_fd_by_id_opts.c:29
  2: (18) r3 = 0xffff9742c0951a00       ; R3_w=map_ptr(map=data_input,ks=4,vs=4)
  4: (5d) if r2 != r3 goto pc+4         ; R2_w=trusted_ptr_bpf_map() R3_w=map_ptr(map=data_input,ks=4,vs=4)
  ; int BPF_PROG(check_access, struct bpf_map *map, fmode_t fmode) @ test_libbpf_get_fd_by_id_opts.c:27
  5: (79) r0 = *(u64 *)(r1 +8)          ; R0_w=scalar() R1=ctx()
  ; if (fmode & FMODE_WRITE) @ test_libbpf_get_fd_by_id_opts.c:32
  6: (67) r0 <<= 62                     ; R0_w=scalar(smax=0x4000000000000000,umax=0xc000000000000000,smin32=0,smax32=umax32=0,var_off=(0x0; 0xc000000000000000))
  7: (c7) r0 s>>= 63                    ; R0_w=scalar(smin=smin32=-1,smax=smax32=0)
  ;  @ test_libbpf_get_fd_by_id_opts.c:0
  8: (57) r0 &= -13                     ; R0_w=scalar(smax=0x7ffffffffffffff3,umax=0xfffffffffffffff3,smax32=0x7ffffff3,umax32=0xfffffff3,var_off=(0x0; 0xfffffffffffffff3))
  ; int BPF_PROG(check_access, struct bpf_map *map, fmode_t fmode) @ test_libbpf_get_fd_by_id_opts.c:27
  9: (95) exit

And here is the C code of the prog.

SEC("lsm/bpf_map")
int BPF_PROG(check_access, struct bpf_map *map, fmode_t fmode)
{
	if (map != (struct bpf_map *)&data_input)
		return 0;

	if (fmode & FMODE_WRITE)
		return -EACCES;

	return 0;
}

It is clear that the prog can only return either 0 or -EACCESS, and both
values are legal.

So why is it rejected by the verifier?

The verifier log shows that the second if and return value setting
statements in the prog is optimized to bitwise operations "r0 s>>= 63"
and "r0 &= -13". The verifier correctly deduces that the the value of
r0 is in the range [-1, 0] after verifing instruction "r0 s>>= 63".
But when the verifier proceeds to verify instruction "r0 &= -13", it
fails to deduce the correct value range of r0.

7: (c7) r0 s>>= 63                    ; R0_w=scalar(smin=smin32=-1,smax=smax32=0)
8: (57) r0 &= -13                     ; R0_w=scalar(smax=0x7ffffffffffffff3,umax=0xfffffffffffffff3,smax32=0x7ffffff3,umax32=0xfffffff3,var_off=(0x0; 0xfffffffffffffff3))

So why the verifier fails to deduce the result of 'r0 &= -13'?

The verifier uses tnum to track values, and the two ranges "[-1, 0]" and
"[0, -1ULL]" are encoded to the same tnum. When verifing instruction
"r0 &= -13", the verifier erroneously deduces the result from
"[0, -1ULL] AND -13", which is out of the expected return range
[-4095, 0].

To fix it, this patch simply adds a special SCALAR32 case for the
verifier. That is, when the source operand of the AND instruction is
a constant and the destination operand changes from negative to
non-negative and falls in range [-256, 256], deduce the result range
by enumerating all possible AND results.

Signed-off-by: Xu Kuohai <xukuohai@huawei.com>
---
 kernel/bpf/verifier.c | 23 +++++++++++++++++++++++
 1 file changed, 23 insertions(+)

Comments

Eduard Zingerman April 19, 2024, 11 p.m. UTC | #1
On Thu, 2024-04-11 at 20:27 +0800, Xu Kuohai wrote:
> From: Xu Kuohai <xukuohai@huawei.com>
> 
> With lsm return value check, the no-alu32 version test_libbpf_get_fd_by_id_opts
> is rejected by the verifier, and the log says:
> 
>   0: R1=ctx() R10=fp0
>   ; int BPF_PROG(check_access, struct bpf_map *map, fmode_t fmode) @ test_libbpf_get_fd_by_id_opts.c:27
>   0: (b7) r0 = 0                        ; R0_w=0
>   1: (79) r2 = *(u64 *)(r1 +0)
>   func 'bpf_lsm_bpf_map' arg0 has btf_id 916 type STRUCT 'bpf_map'
>   2: R1=ctx() R2_w=trusted_ptr_bpf_map()
>   ; if (map != (struct bpf_map *)&data_input) @ test_libbpf_get_fd_by_id_opts.c:29
>   2: (18) r3 = 0xffff9742c0951a00       ; R3_w=map_ptr(map=data_input,ks=4,vs=4)
>   4: (5d) if r2 != r3 goto pc+4         ; R2_w=trusted_ptr_bpf_map() R3_w=map_ptr(map=data_input,ks=4,vs=4)
>   ; int BPF_PROG(check_access, struct bpf_map *map, fmode_t fmode) @ test_libbpf_get_fd_by_id_opts.c:27
>   5: (79) r0 = *(u64 *)(r1 +8)          ; R0_w=scalar() R1=ctx()
>   ; if (fmode & FMODE_WRITE) @ test_libbpf_get_fd_by_id_opts.c:32
>   6: (67) r0 <<= 62                     ; R0_w=scalar(smax=0x4000000000000000,umax=0xc000000000000000,smin32=0,smax32=umax32=0,var_off=(0x0; 0xc000000000000000))
>   7: (c7) r0 s>>= 63                    ; R0_w=scalar(smin=smin32=-1,smax=smax32=0)
>   ;  @ test_libbpf_get_fd_by_id_opts.c:0
>   8: (57) r0 &= -13                     ; R0_w=scalar(smax=0x7ffffffffffffff3,umax=0xfffffffffffffff3,smax32=0x7ffffff3,umax32=0xfffffff3,var_off=(0x0; 0xfffffffffffffff3))
>   ; int BPF_PROG(check_access, struct bpf_map *map, fmode_t fmode) @ test_libbpf_get_fd_by_id_opts.c:27
>   9: (95) exit
> 
> And here is the C code of the prog.
> 
> SEC("lsm/bpf_map")
> int BPF_PROG(check_access, struct bpf_map *map, fmode_t fmode)
> {
> 	if (map != (struct bpf_map *)&data_input)
> 		return 0;
> 
> 	if (fmode & FMODE_WRITE)
> 		return -EACCES;
> 
> 	return 0;
> }
> 
> It is clear that the prog can only return either 0 or -EACCESS, and both
> values are legal.
> 
> So why is it rejected by the verifier?
> 
> The verifier log shows that the second if and return value setting
> statements in the prog is optimized to bitwise operations "r0 s>>= 63"
> and "r0 &= -13". The verifier correctly deduces that the the value of
> r0 is in the range [-1, 0] after verifing instruction "r0 s>>= 63".
> But when the verifier proceeds to verify instruction "r0 &= -13", it
> fails to deduce the correct value range of r0.
> 
> 7: (c7) r0 s>>= 63                    ; R0_w=scalar(smin=smin32=-1,smax=smax32=0)
> 8: (57) r0 &= -13                     ; R0_w=scalar(smax=0x7ffffffffffffff3,umax=0xfffffffffffffff3,smax32=0x7ffffff3,umax32=0xfffffff3,var_off=(0x0; 0xfffffffffffffff3))
> 
> So why the verifier fails to deduce the result of 'r0 &= -13'?
> 
> The verifier uses tnum to track values, and the two ranges "[-1, 0]" and
> "[0, -1ULL]" are encoded to the same tnum. When verifing instruction
> "r0 &= -13", the verifier erroneously deduces the result from
> "[0, -1ULL] AND -13", which is out of the expected return range
> [-4095, 0].
> 
> To fix it, this patch simply adds a special SCALAR32 case for the
> verifier. That is, when the source operand of the AND instruction is
> a constant and the destination operand changes from negative to
> non-negative and falls in range [-256, 256], deduce the result range
> by enumerating all possible AND results.
> 
> Signed-off-by: Xu Kuohai <xukuohai@huawei.com>
> ---

Hello,

Sorry for the delay, I had to think about this issue a bit.
I found the clang transformation that generates the pattern this patch
tries to handle.
It is located in DAGCombiner::SimplifySelectCC() method (see [1]).
The transformation happens as a part of DAG to DAG rewrites
(LLVM uses several internal representations:
 - generic optimizer uses LLVM IR, most of the work is done
   using this representation;
 - before instruction selection IR is converted to Selection DAG,
   some optimizations are applied at this stage,
   all such optimizations are a set of pattern replacements;
 - Selection DAG is converted to machine code, some optimizations
   are applied at the machine code level).

Full pattern is described as follows:

  // fold (select_cc seteq (and x, y), 0, 0, A) -> (and (sra (shl x)) A)
  // where y is has a single bit set.
  // A plaintext description would be, we can turn the SELECT_CC into an AND
  // when the condition can be materialized as an all-ones register.  Any
  // single bit-test can be materialized as an all-ones register with
  // shift-left and shift-right-arith.

For this particular test case the DAG is converted as follows:

                    .---------------- lhs         The meaning of this select_cc is:
                    |        .------- rhs         `lhs == rhs ? true value : false value`
                    |        | .----- true value
                    |        | |  .-- false value
                    v        v v  v 
  (select_cc seteq (and X 2) 0 0 -13)
                          ^
->                        '---------------.
  (and (sra (sll X 62) 63)                |
       -13)                               |
                                          |
Before pattern is applied, it checks that second 'and' operand has
only one bit set, (which is true for '2').

The pattern itself generates logical shift left / arithmetic shift
right pair, that ensures that result is either all ones (-1) or all
zeros (0). Hence, applying 'and' to shifts result and false value
generates a correct result.

In my opinion the approach taken by this patch is sub-optimal:
- 512 iterations is too much;
- this does not cover all code that could be generated by the above
  mentioned LLVM transformation
  (e.g. second 'and' operand could be 1 << 16).

Instead, I suggest to make a special case for source or dst register
of '&=' operation being in range [-1,0].
Meaning that one of the '&=' operands is either:
- all ones, in which case the counterpart is the result of the operation;
- all zeros, in which case zero is the result of the operation;
- derive MIN and MAX values based on above two observations.

[1] https://github.com/llvm/llvm-project/blob/4523a267829c807f3fc8fab8e5e9613985a51565/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp#L5391

Best regards,
Eduard
Xu Kuohai April 20, 2024, 8:33 a.m. UTC | #2
On 4/20/2024 7:00 AM, Eduard Zingerman wrote:
> On Thu, 2024-04-11 at 20:27 +0800, Xu Kuohai wrote:
>> From: Xu Kuohai <xukuohai@huawei.com>
>>
>> With lsm return value check, the no-alu32 version test_libbpf_get_fd_by_id_opts
>> is rejected by the verifier, and the log says:
>>
>>    0: R1=ctx() R10=fp0
>>    ; int BPF_PROG(check_access, struct bpf_map *map, fmode_t fmode) @ test_libbpf_get_fd_by_id_opts.c:27
>>    0: (b7) r0 = 0                        ; R0_w=0
>>    1: (79) r2 = *(u64 *)(r1 +0)
>>    func 'bpf_lsm_bpf_map' arg0 has btf_id 916 type STRUCT 'bpf_map'
>>    2: R1=ctx() R2_w=trusted_ptr_bpf_map()
>>    ; if (map != (struct bpf_map *)&data_input) @ test_libbpf_get_fd_by_id_opts.c:29
>>    2: (18) r3 = 0xffff9742c0951a00       ; R3_w=map_ptr(map=data_input,ks=4,vs=4)
>>    4: (5d) if r2 != r3 goto pc+4         ; R2_w=trusted_ptr_bpf_map() R3_w=map_ptr(map=data_input,ks=4,vs=4)
>>    ; int BPF_PROG(check_access, struct bpf_map *map, fmode_t fmode) @ test_libbpf_get_fd_by_id_opts.c:27
>>    5: (79) r0 = *(u64 *)(r1 +8)          ; R0_w=scalar() R1=ctx()
>>    ; if (fmode & FMODE_WRITE) @ test_libbpf_get_fd_by_id_opts.c:32
>>    6: (67) r0 <<= 62                     ; R0_w=scalar(smax=0x4000000000000000,umax=0xc000000000000000,smin32=0,smax32=umax32=0,var_off=(0x0; 0xc000000000000000))
>>    7: (c7) r0 s>>= 63                    ; R0_w=scalar(smin=smin32=-1,smax=smax32=0)
>>    ;  @ test_libbpf_get_fd_by_id_opts.c:0
>>    8: (57) r0 &= -13                     ; R0_w=scalar(smax=0x7ffffffffffffff3,umax=0xfffffffffffffff3,smax32=0x7ffffff3,umax32=0xfffffff3,var_off=(0x0; 0xfffffffffffffff3))
>>    ; int BPF_PROG(check_access, struct bpf_map *map, fmode_t fmode) @ test_libbpf_get_fd_by_id_opts.c:27
>>    9: (95) exit
>>
>> And here is the C code of the prog.
>>
>> SEC("lsm/bpf_map")
>> int BPF_PROG(check_access, struct bpf_map *map, fmode_t fmode)
>> {
>> 	if (map != (struct bpf_map *)&data_input)
>> 		return 0;
>>
>> 	if (fmode & FMODE_WRITE)
>> 		return -EACCES;
>>
>> 	return 0;
>> }
>>
>> It is clear that the prog can only return either 0 or -EACCESS, and both
>> values are legal.
>>
>> So why is it rejected by the verifier?
>>
>> The verifier log shows that the second if and return value setting
>> statements in the prog is optimized to bitwise operations "r0 s>>= 63"
>> and "r0 &= -13". The verifier correctly deduces that the the value of
>> r0 is in the range [-1, 0] after verifing instruction "r0 s>>= 63".
>> But when the verifier proceeds to verify instruction "r0 &= -13", it
>> fails to deduce the correct value range of r0.
>>
>> 7: (c7) r0 s>>= 63                    ; R0_w=scalar(smin=smin32=-1,smax=smax32=0)
>> 8: (57) r0 &= -13                     ; R0_w=scalar(smax=0x7ffffffffffffff3,umax=0xfffffffffffffff3,smax32=0x7ffffff3,umax32=0xfffffff3,var_off=(0x0; 0xfffffffffffffff3))
>>
>> So why the verifier fails to deduce the result of 'r0 &= -13'?
>>
>> The verifier uses tnum to track values, and the two ranges "[-1, 0]" and
>> "[0, -1ULL]" are encoded to the same tnum. When verifing instruction
>> "r0 &= -13", the verifier erroneously deduces the result from
>> "[0, -1ULL] AND -13", which is out of the expected return range
>> [-4095, 0].
>>
>> To fix it, this patch simply adds a special SCALAR32 case for the
>> verifier. That is, when the source operand of the AND instruction is
>> a constant and the destination operand changes from negative to
>> non-negative and falls in range [-256, 256], deduce the result range
>> by enumerating all possible AND results.
>>
>> Signed-off-by: Xu Kuohai <xukuohai@huawei.com>
>> ---
> 
> Hello,
> 
> Sorry for the delay, I had to think about this issue a bit.
> I found the clang transformation that generates the pattern this patch
> tries to handle.
> It is located in DAGCombiner::SimplifySelectCC() method (see [1]).
> The transformation happens as a part of DAG to DAG rewrites
> (LLVM uses several internal representations:
>   - generic optimizer uses LLVM IR, most of the work is done
>     using this representation;
>   - before instruction selection IR is converted to Selection DAG,
>     some optimizations are applied at this stage,
>     all such optimizations are a set of pattern replacements;
>   - Selection DAG is converted to machine code, some optimizations
>     are applied at the machine code level).
> 
> Full pattern is described as follows:
> 
>    // fold (select_cc seteq (and x, y), 0, 0, A) -> (and (sra (shl x)) A)
>    // where y is has a single bit set.
>    // A plaintext description would be, we can turn the SELECT_CC into an AND
>    // when the condition can be materialized as an all-ones register.  Any
>    // single bit-test can be materialized as an all-ones register with
>    // shift-left and shift-right-arith.
> 
> For this particular test case the DAG is converted as follows:
> 
>                      .---------------- lhs         The meaning of this select_cc is:
>                      |        .------- rhs         `lhs == rhs ? true value : false value`
>                      |        | .----- true value
>                      |        | |  .-- false value
>                      v        v v  v
>    (select_cc seteq (and X 2) 0 0 -13)
>                            ^
> ->                        '---------------.
>    (and (sra (sll X 62) 63)                |
>         -13)                               |
>                                            |
> Before pattern is applied, it checks that second 'and' operand has
> only one bit set, (which is true for '2').
> 
> The pattern itself generates logical shift left / arithmetic shift
> right pair, that ensures that result is either all ones (-1) or all
> zeros (0). Hence, applying 'and' to shifts result and false value
> generates a correct result.
>

Thanks for your detailed and invaluable explanation!

> In my opinion the approach taken by this patch is sub-optimal:
> - 512 iterations is too much;
> - this does not cover all code that could be generated by the above
>    mentioned LLVM transformation
>    (e.g. second 'and' operand could be 1 << 16).
> 
> Instead, I suggest to make a special case for source or dst register
> of '&=' operation being in range [-1,0].
> Meaning that one of the '&=' operands is either:
> - all ones, in which case the counterpart is the result of the operation;
> - all zeros, in which case zero is the result of the operation;
> - derive MIN and MAX values based on above two observations.
>

Totally agree, I'll cook a new patch as you suggested.

> [1] https://github.com/llvm/llvm-project/blob/4523a267829c807f3fc8fab8e5e9613985a51565/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp#L5391
> 
> Best regards,
> Eduard
Yonghong Song April 23, 2024, 9:55 p.m. UTC | #3
On 4/20/24 1:33 AM, Xu Kuohai wrote:
> On 4/20/2024 7:00 AM, Eduard Zingerman wrote:
>> On Thu, 2024-04-11 at 20:27 +0800, Xu Kuohai wrote:
>>> From: Xu Kuohai <xukuohai@huawei.com>
>>>
>>> With lsm return value check, the no-alu32 version 
>>> test_libbpf_get_fd_by_id_opts
>>> is rejected by the verifier, and the log says:
>>>
>>>    0: R1=ctx() R10=fp0
>>>    ; int BPF_PROG(check_access, struct bpf_map *map, fmode_t fmode) 
>>> @ test_libbpf_get_fd_by_id_opts.c:27
>>>    0: (b7) r0 = 0                        ; R0_w=0
>>>    1: (79) r2 = *(u64 *)(r1 +0)
>>>    func 'bpf_lsm_bpf_map' arg0 has btf_id 916 type STRUCT 'bpf_map'
>>>    2: R1=ctx() R2_w=trusted_ptr_bpf_map()
>>>    ; if (map != (struct bpf_map *)&data_input) @ 
>>> test_libbpf_get_fd_by_id_opts.c:29
>>>    2: (18) r3 = 0xffff9742c0951a00       ; 
>>> R3_w=map_ptr(map=data_input,ks=4,vs=4)
>>>    4: (5d) if r2 != r3 goto pc+4         ; 
>>> R2_w=trusted_ptr_bpf_map() R3_w=map_ptr(map=data_input,ks=4,vs=4)
>>>    ; int BPF_PROG(check_access, struct bpf_map *map, fmode_t fmode) 
>>> @ test_libbpf_get_fd_by_id_opts.c:27
>>>    5: (79) r0 = *(u64 *)(r1 +8)          ; R0_w=scalar() R1=ctx()
>>>    ; if (fmode & FMODE_WRITE) @ test_libbpf_get_fd_by_id_opts.c:32
>>>    6: (67) r0 <<= 62                     ; 
>>> R0_w=scalar(smax=0x4000000000000000,umax=0xc000000000000000,smin32=0,smax32=umax32=0,var_off=(0x0; 
>>> 0xc000000000000000))
>>>    7: (c7) r0 s>>= 63                    ; 
>>> R0_w=scalar(smin=smin32=-1,smax=smax32=0)
>>>    ;  @ test_libbpf_get_fd_by_id_opts.c:0
>>>    8: (57) r0 &= -13                     ; 
>>> R0_w=scalar(smax=0x7ffffffffffffff3,umax=0xfffffffffffffff3,smax32=0x7ffffff3,umax32=0xfffffff3,var_off=(0x0; 
>>> 0xfffffffffffffff3))
>>>    ; int BPF_PROG(check_access, struct bpf_map *map, fmode_t fmode) 
>>> @ test_libbpf_get_fd_by_id_opts.c:27
>>>    9: (95) exit
>>>
>>> And here is the C code of the prog.
>>>
>>> SEC("lsm/bpf_map")
>>> int BPF_PROG(check_access, struct bpf_map *map, fmode_t fmode)
>>> {
>>>     if (map != (struct bpf_map *)&data_input)
>>>         return 0;
>>>
>>>     if (fmode & FMODE_WRITE)
>>>         return -EACCES;
>>>
>>>     return 0;
>>> }
>>>
>>> It is clear that the prog can only return either 0 or -EACCESS, and 
>>> both
>>> values are legal.
>>>
>>> So why is it rejected by the verifier?
>>>
>>> The verifier log shows that the second if and return value setting
>>> statements in the prog is optimized to bitwise operations "r0 s>>= 63"
>>> and "r0 &= -13". The verifier correctly deduces that the the value of
>>> r0 is in the range [-1, 0] after verifing instruction "r0 s>>= 63".
>>> But when the verifier proceeds to verify instruction "r0 &= -13", it
>>> fails to deduce the correct value range of r0.
>>>
>>> 7: (c7) r0 s>>= 63                    ; 
>>> R0_w=scalar(smin=smin32=-1,smax=smax32=0)
>>> 8: (57) r0 &= -13                     ; 
>>> R0_w=scalar(smax=0x7ffffffffffffff3,umax=0xfffffffffffffff3,smax32=0x7ffffff3,umax32=0xfffffff3,var_off=(0x0; 
>>> 0xfffffffffffffff3))
>>>
>>> So why the verifier fails to deduce the result of 'r0 &= -13'?
>>>
>>> The verifier uses tnum to track values, and the two ranges "[-1, 0]" 
>>> and
>>> "[0, -1ULL]" are encoded to the same tnum. When verifing instruction
>>> "r0 &= -13", the verifier erroneously deduces the result from
>>> "[0, -1ULL] AND -13", which is out of the expected return range
>>> [-4095, 0].
>>>
>>> To fix it, this patch simply adds a special SCALAR32 case for the
>>> verifier. That is, when the source operand of the AND instruction is
>>> a constant and the destination operand changes from negative to
>>> non-negative and falls in range [-256, 256], deduce the result range
>>> by enumerating all possible AND results.
>>>
>>> Signed-off-by: Xu Kuohai <xukuohai@huawei.com>
>>> ---
>>
>> Hello,
>>
>> Sorry for the delay, I had to think about this issue a bit.
>> I found the clang transformation that generates the pattern this patch
>> tries to handle.
>> It is located in DAGCombiner::SimplifySelectCC() method (see [1]).
>> The transformation happens as a part of DAG to DAG rewrites
>> (LLVM uses several internal representations:
>>   - generic optimizer uses LLVM IR, most of the work is done
>>     using this representation;
>>   - before instruction selection IR is converted to Selection DAG,
>>     some optimizations are applied at this stage,
>>     all such optimizations are a set of pattern replacements;
>>   - Selection DAG is converted to machine code, some optimizations
>>     are applied at the machine code level).
>>
>> Full pattern is described as follows:
>>
>>    // fold (select_cc seteq (and x, y), 0, 0, A) -> (and (sra (shl 
>> x)) A)
>>    // where y is has a single bit set.
>>    // A plaintext description would be, we can turn the SELECT_CC 
>> into an AND
>>    // when the condition can be materialized as an all-ones 
>> register.  Any
>>    // single bit-test can be materialized as an all-ones register with
>>    // shift-left and shift-right-arith.
>>
>> For this particular test case the DAG is converted as follows:
>>
>>                      .---------------- lhs         The meaning of 
>> this select_cc is:
>>                      |        .------- rhs         `lhs == rhs ? true 
>> value : false value`
>>                      |        | .----- true value
>>                      |        | |  .-- false value
>>                      v        v v  v
>>    (select_cc seteq (and X 2) 0 0 -13)
>>                            ^
>> ->                        '---------------.
>>    (and (sra (sll X 62) 63)                |
>>         -13)                               |
>>                                            |
>> Before pattern is applied, it checks that second 'and' operand has
>> only one bit set, (which is true for '2').
>>
>> The pattern itself generates logical shift left / arithmetic shift
>> right pair, that ensures that result is either all ones (-1) or all
>> zeros (0). Hence, applying 'and' to shifts result and false value
>> generates a correct result.
>>
>
> Thanks for your detailed and invaluable explanation!

Thanks Eduard for detailed explanation. It looks like we could
resolve this issue without adding too much complexity to verifier.
Also, this code pattern above seems generic enough to be worthwhile
with verifier change.

Kuohai, please added detailed explanation (as described by Eduard)
in the commit message.

>
>> In my opinion the approach taken by this patch is sub-optimal:
>> - 512 iterations is too much;
>> - this does not cover all code that could be generated by the above
>>    mentioned LLVM transformation
>>    (e.g. second 'and' operand could be 1 << 16).
>>
>> Instead, I suggest to make a special case for source or dst register
>> of '&=' operation being in range [-1,0].
>> Meaning that one of the '&=' operands is either:
>> - all ones, in which case the counterpart is the result of the 
>> operation;
>> - all zeros, in which case zero is the result of the operation;
>> - derive MIN and MAX values based on above two observations.
>>
>
> Totally agree, I'll cook a new patch as you suggested.
>
>> [1] 
>> https://github.com/llvm/llvm-project/blob/4523a267829c807f3fc8fab8e5e9613985a51565/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp#L5391
>>
>> Best regards,
>> Eduard
>
>
Andrii Nakryiko April 26, 2024, 8:36 p.m. UTC | #4
On Tue, Apr 23, 2024 at 7:26 PM Xu Kuohai <xukuohai@huaweicloud.com> wrote:
>
> On 4/24/2024 5:55 AM, Yonghong Song wrote:
> >
> > On 4/20/24 1:33 AM, Xu Kuohai wrote:
> >> On 4/20/2024 7:00 AM, Eduard Zingerman wrote:
> >>> On Thu, 2024-04-11 at 20:27 +0800, Xu Kuohai wrote:
> >>>> From: Xu Kuohai <xukuohai@huawei.com>
> >>>>
> >>>> With lsm return value check, the no-alu32 version test_libbpf_get_fd_by_id_opts
> >>>> is rejected by the verifier, and the log says:
> >>>>
> >>>>    0: R1=ctx() R10=fp0
> >>>>    ; int BPF_PROG(check_access, struct bpf_map *map, fmode_t fmode) @ test_libbpf_get_fd_by_id_opts.c:27
> >>>>    0: (b7) r0 = 0                        ; R0_w=0
> >>>>    1: (79) r2 = *(u64 *)(r1 +0)
> >>>>    func 'bpf_lsm_bpf_map' arg0 has btf_id 916 type STRUCT 'bpf_map'
> >>>>    2: R1=ctx() R2_w=trusted_ptr_bpf_map()
> >>>>    ; if (map != (struct bpf_map *)&data_input) @ test_libbpf_get_fd_by_id_opts.c:29
> >>>>    2: (18) r3 = 0xffff9742c0951a00       ; R3_w=map_ptr(map=data_input,ks=4,vs=4)
> >>>>    4: (5d) if r2 != r3 goto pc+4         ; R2_w=trusted_ptr_bpf_map() R3_w=map_ptr(map=data_input,ks=4,vs=4)
> >>>>    ; int BPF_PROG(check_access, struct bpf_map *map, fmode_t fmode) @ test_libbpf_get_fd_by_id_opts.c:27
> >>>>    5: (79) r0 = *(u64 *)(r1 +8)          ; R0_w=scalar() R1=ctx()
> >>>>    ; if (fmode & FMODE_WRITE) @ test_libbpf_get_fd_by_id_opts.c:32
> >>>>    6: (67) r0 <<= 62                     ; R0_w=scalar(smax=0x4000000000000000,umax=0xc000000000000000,smin32=0,smax32=umax32=0,var_off=(0x0; 0xc000000000000000))
> >>>>    7: (c7) r0 s>>= 63                    ; R0_w=scalar(smin=smin32=-1,smax=smax32=0)
> >>>>    ;  @ test_libbpf_get_fd_by_id_opts.c:0
> >>>>    8: (57) r0 &= -13                     ; R0_w=scalar(smax=0x7ffffffffffffff3,umax=0xfffffffffffffff3,smax32=0x7ffffff3,umax32=0xfffffff3,var_off=(0x0; 0xfffffffffffffff3))
> >>>>    ; int BPF_PROG(check_access, struct bpf_map *map, fmode_t fmode) @ test_libbpf_get_fd_by_id_opts.c:27
> >>>>    9: (95) exit

[...]

>
>      As suggested by Eduard, this patch makes a special case for source
>      or destination register of '&=' operation being in range [-1, 0].
>
>      Meaning that one of the '&=' operands is either:
>      - all ones, in which case the counterpart is the result of the operation;
>      - all zeros, in which case zero is the result of the operation.
>
>      And MIN and MAX values could be derived based on above two observations.
>
>      [0] https://lore.kernel.org/bpf/e62e2971301ca7f2e9eb74fc500c520285cad8f5.camel@gmail.com/
>      [1] https://github.com/llvm/llvm-project/blob/4523a267829c807f3fc8fab8e5e9613985a51565/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
>
>      Suggested-by: Eduard Zingerman <eddyz87@gmail.com>
>      Signed-off-by: Xu Kuohai <xukuohai@huawei.com>
>
> diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
> index 640747b53745..30c551d39329 100644
> --- a/kernel/bpf/verifier.c
> +++ b/kernel/bpf/verifier.c
> @@ -13374,6 +13374,24 @@ static void scalar32_min_max_and(struct bpf_reg_state *dst_reg,
>          dst_reg->u32_min_value = var32_off.value;
>          dst_reg->u32_max_value = min(dst_reg->u32_max_value, umax_val);
>
> +       /* Special case: src_reg is known and dst_reg is in range [-1, 0] */
> +       if (src_known &&
> +               dst_reg->s32_min_value == -1 && dst_reg->s32_max_value == 0 &&
> +               dst_reg->smin_value == -1 && dst_reg->smax_value == 0) {

please keep if () condition aligned across multiple lines, it's super
confusing this way

> +               dst_reg->s32_min_value = min_t(s32, src_reg->s32_min_value, 0);
> +               dst_reg->s32_max_value = max_t(s32, src_reg->s32_min_value, 0);

do we need to update tnum parts as well (or reset and re-derive, probably)?

btw, can't we support src being a range here? the idea is that dst_reg
either all ones or all zeros. For and it means that it either stays
all zero, or will be *exactly equal* to src, right? So I think the
logic would be:

a) if [s32_min, s32_max] is on the same side of zero, then resulting
range would be [min(s32_min, 0), max(s32_max, 0)], just like you have
here

b) if [s32_min, s32_max] contains zero, then resulting range will be
exactly [s32_min, s32_max]

Or did I make a mistake above?

> +               return;
> +       }
> +
> +       /* Special case: dst_reg is known and src_reg is in range [-1, 0] */
> +       if (dst_known &&
> +               src_reg->s32_min_value == -1 && src_reg->s32_max_value == 0 &&
> +               src_reg->smin_value == -1 && src_reg->smax_value == 0) {
> +               dst_reg->s32_min_value = min_t(s32, dst_reg->s32_min_value, 0);
> +               dst_reg->s32_max_value = max_t(s32, dst_reg->s32_min_value, 0);
> +               return;
> +       }
> +
>          /* Safe to set s32 bounds by casting u32 result into s32 when u32
>           * doesn't cross sign boundary. Otherwise set s32 bounds to unbounded.
>           */

[...]
Andrii Nakryiko April 29, 2024, 8:58 p.m. UTC | #5
On Sun, Apr 28, 2024 at 8:15 AM Xu Kuohai <xukuohai@huaweicloud.com> wrote:
>
> On 4/27/2024 4:36 AM, Andrii Nakryiko wrote:
> > On Tue, Apr 23, 2024 at 7:26 PM Xu Kuohai <xukuohai@huaweicloud.com> wrote:
> >>
> >> On 4/24/2024 5:55 AM, Yonghong Song wrote:
> >>>
> >>> On 4/20/24 1:33 AM, Xu Kuohai wrote:
> >>>> On 4/20/2024 7:00 AM, Eduard Zingerman wrote:
> >>>>> On Thu, 2024-04-11 at 20:27 +0800, Xu Kuohai wrote:
> >>>>>> From: Xu Kuohai <xukuohai@huawei.com>
> >>>>>>
> >>>>>> With lsm return value check, the no-alu32 version test_libbpf_get_fd_by_id_opts
> >>>>>> is rejected by the verifier, and the log says:
> >>>>>>
> >>>>>>     0: R1=ctx() R10=fp0
> >>>>>>     ; int BPF_PROG(check_access, struct bpf_map *map, fmode_t fmode) @ test_libbpf_get_fd_by_id_opts.c:27
> >>>>>>     0: (b7) r0 = 0                        ; R0_w=0
> >>>>>>     1: (79) r2 = *(u64 *)(r1 +0)
> >>>>>>     func 'bpf_lsm_bpf_map' arg0 has btf_id 916 type STRUCT 'bpf_map'
> >>>>>>     2: R1=ctx() R2_w=trusted_ptr_bpf_map()
> >>>>>>     ; if (map != (struct bpf_map *)&data_input) @ test_libbpf_get_fd_by_id_opts.c:29
> >>>>>>     2: (18) r3 = 0xffff9742c0951a00       ; R3_w=map_ptr(map=data_input,ks=4,vs=4)
> >>>>>>     4: (5d) if r2 != r3 goto pc+4         ; R2_w=trusted_ptr_bpf_map() R3_w=map_ptr(map=data_input,ks=4,vs=4)
> >>>>>>     ; int BPF_PROG(check_access, struct bpf_map *map, fmode_t fmode) @ test_libbpf_get_fd_by_id_opts.c:27
> >>>>>>     5: (79) r0 = *(u64 *)(r1 +8)          ; R0_w=scalar() R1=ctx()
> >>>>>>     ; if (fmode & FMODE_WRITE) @ test_libbpf_get_fd_by_id_opts.c:32
> >>>>>>     6: (67) r0 <<= 62                     ; R0_w=scalar(smax=0x4000000000000000,umax=0xc000000000000000,smin32=0,smax32=umax32=0,var_off=(0x0; 0xc000000000000000))
> >>>>>>     7: (c7) r0 s>>= 63                    ; R0_w=scalar(smin=smin32=-1,smax=smax32=0)
> >>>>>>     ;  @ test_libbpf_get_fd_by_id_opts.c:0
> >>>>>>     8: (57) r0 &= -13                     ; R0_w=scalar(smax=0x7ffffffffffffff3,umax=0xfffffffffffffff3,smax32=0x7ffffff3,umax32=0xfffffff3,var_off=(0x0; 0xfffffffffffffff3))
> >>>>>>     ; int BPF_PROG(check_access, struct bpf_map *map, fmode_t fmode) @ test_libbpf_get_fd_by_id_opts.c:27
> >>>>>>     9: (95) exit
> >
> > [...]
> >
> >>
> >>       As suggested by Eduard, this patch makes a special case for source
> >>       or destination register of '&=' operation being in range [-1, 0].
> >>
> >>       Meaning that one of the '&=' operands is either:
> >>       - all ones, in which case the counterpart is the result of the operation;
> >>       - all zeros, in which case zero is the result of the operation.
> >>
> >>       And MIN and MAX values could be derived based on above two observations.
> >>
> >>       [0] https://lore.kernel.org/bpf/e62e2971301ca7f2e9eb74fc500c520285cad8f5.camel@gmail.com/
> >>       [1] https://github.com/llvm/llvm-project/blob/4523a267829c807f3fc8fab8e5e9613985a51565/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
> >>
> >>       Suggested-by: Eduard Zingerman <eddyz87@gmail.com>
> >>       Signed-off-by: Xu Kuohai <xukuohai@huawei.com>
> >>
> >> diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
> >> index 640747b53745..30c551d39329 100644
> >> --- a/kernel/bpf/verifier.c
> >> +++ b/kernel/bpf/verifier.c
> >> @@ -13374,6 +13374,24 @@ static void scalar32_min_max_and(struct bpf_reg_state *dst_reg,
> >>           dst_reg->u32_min_value = var32_off.value;
> >>           dst_reg->u32_max_value = min(dst_reg->u32_max_value, umax_val);
> >>
> >> +       /* Special case: src_reg is known and dst_reg is in range [-1, 0] */
> >> +       if (src_known &&
> >> +               dst_reg->s32_min_value == -1 && dst_reg->s32_max_value == 0 &&
> >> +               dst_reg->smin_value == -1 && dst_reg->smax_value == 0) {
> >
> > please keep if () condition aligned across multiple lines, it's super
> > confusing this way
> >
>
> OK, will update the align style
>
> >> +               dst_reg->s32_min_value = min_t(s32, src_reg->s32_min_value, 0);
> >> +               dst_reg->s32_max_value = max_t(s32, src_reg->s32_min_value, 0);
> >
> > do we need to update tnum parts as well (or reset and re-derive, probably)?
> >
> > btw, can't we support src being a range here? the idea is that dst_reg
> > either all ones or all zeros. For and it means that it either stays
> > all zero, or will be *exactly equal* to src, right? So I think the
> > logic would be:
> >
> > a) if [s32_min, s32_max] is on the same side of zero, then resulting
> > range would be [min(s32_min, 0), max(s32_max, 0)], just like you have
> > here
> >
> > b) if [s32_min, s32_max] contains zero, then resulting range will be
> > exactly [s32_min, s32_max]
> >
> > Or did I make a mistake above?
> >
>
> Totally agree, the AND of any set with the range [-1,0] is equivalent
> to adding number 0 to the set!
>
> Based on this observation, I've rewritten the patch as follows.
>
> diff --git a/include/linux/tnum.h b/include/linux/tnum.h
> index 3c13240077b8..5e795d728b9f 100644
> --- a/include/linux/tnum.h
> +++ b/include/linux/tnum.h
> @@ -52,6 +52,9 @@ struct tnum tnum_mul(struct tnum a, struct tnum b);
>   /* Return a tnum representing numbers satisfying both @a and @b */
>   struct tnum tnum_intersect(struct tnum a, struct tnum b);
>
> +/* Return a tnum representing numbers satisfying either @a or @b */
> +struct tnum tnum_union(struct tnum a, struct tnum b);
> +
>   /* Return @a with all but the lowest @size bytes cleared */
>   struct tnum tnum_cast(struct tnum a, u8 size);
>
> diff --git a/kernel/bpf/tnum.c b/kernel/bpf/tnum.c
> index 9dbc31b25e3d..9d4480a683ca 100644
> --- a/kernel/bpf/tnum.c
> +++ b/kernel/bpf/tnum.c
> @@ -150,6 +150,29 @@ struct tnum tnum_intersect(struct tnum a, struct tnum b)
>          return TNUM(v & ~mu, mu);
>   }
>
> +/*
> + * Each bit has 3 states: unkown, known 0, known 1. If using x to represent
> + * unknown state, the result of the union of two bits is as follows:
> + *
> + *         | x    0    1
> + *    -----+------------
> + *     x   | x    x    x
> + *     0   | x    0    x
> + *     1   | x    x    1
> + *
> + * For tnum a and b, only the bits that are both known 0 or known 1 in a
> + * and b are known in the result of union a and b.
> + */
> +struct tnum tnum_union(struct tnum a, struct tnum b)
> +{
> +       u64 v0, v1, mu;
> +
> +       mu = a.mask | b.mask; // unkown bits either in a or b
> +       v1 = (a.value & b.value) & ~mu; // "known 1" bits in both a and b
> +       v0 = (~a.value & ~b.value) & ~mu; // "known 0" bits in both a and b

no C++-style comments, please

> +       return TNUM(v1, mu | ~(v0 | v1));
> +}
> +

I've CC'ed Edward, hopefully he can take a look as well. Please CC him
on future patches touching tnum as well.

>   struct tnum tnum_cast(struct tnum a, u8 size)
>   {
>          a.value &= (1ULL << (size * 8)) - 1;
>   {
>          a.value &= (1ULL << (size * 8)) - 1;
> diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
> index 8f0f2e21699e..b69c89bc5cfc 100644
> --- a/kernel/bpf/verifier.c
> +++ b/kernel/bpf/verifier.c
> @@ -13478,6 +13478,28 @@ static void scalar32_min_max_and(struct bpf_reg_state *dst_reg,
>                  return;
>          }
>
> +       /* Special case: dst_reg is in range [-1, 0] */
> +       if (dst_reg->s32_min_value == -1 && dst_reg->s32_max_value == 0) {
> +               var32_off = tnum_union(src_reg->var_off, tnum_const(0));
> +               dst_reg->var_off = tnum_with_subreg(dst_reg->var_off, var32_off);
> +               dst_reg->u32_min_value = var32_off.value;
> +               dst_reg->u32_max_value = min(dst_reg->u32_max_value, umax_val);

can you explain the logic behing u32 min/max updates, especially that
we use completely different values for min/max and it's not clear why
u32_min <= u32_max invariant will always hold. Same below

> +               dst_reg->s32_min_value = min_t(s32, src_reg->s32_min_value, 0);
> +               dst_reg->s32_max_value = max_t(s32, src_reg->s32_max_value, 0);
> +               return;
> +       }
> +
> +       /* Special case: src_reg is in range [-1, 0] */
> +       if (src_reg->s32_min_value == -1 && src_reg->s32_max_value == 0) {
> +               var32_off = tnum_union(dst_reg->var_off, tnum_const(0));
> +               dst_reg->var_off = tnum_with_subreg(dst_reg->var_off, var32_off);
> +               dst_reg->u32_min_value = var32_off.value;
> +               dst_reg->u32_max_value = min(dst_reg->u32_max_value, umax_val);
> +               dst_reg->s32_min_value = min_t(s32, dst_reg->s32_min_value, 0);
> +               dst_reg->s32_max_value = max_t(s32, dst_reg->s32_max_value, 0);
> +               return;
> +       }
> +
>          /* We get our minimum from the var_off, since that's inherently
>           * bitwise.  Our maximum is the minimum of the operands' maxima.
>           */
> @@ -13508,6 +13530,26 @@ static void scalar_min_max_and(struct bpf_reg_state *dst_reg,
>                  return;
>          }
>
> +       /* Special case: dst_reg is in range [-1, 0] */
> +       if (dst_reg->smin_value == -1 && dst_reg->smax_value == 0) {
> +               dst_reg->var_off = tnum_union(src_reg->var_off, tnum_const(0));
> +               dst_reg->umin_value = dst_reg->var_off.value;
> +               dst_reg->umax_value = min(dst_reg->umax_value, umax_val);
> +               dst_reg->smin_value = min_t(s64, src_reg->smin_value, 0);
> +               dst_reg->smax_value = max_t(s64, src_reg->smax_value, 0);
> +               return;
> +       }
> +
> +       /* Special case: src_reg is in range [-1, 0] */
> +       if (src_reg->smin_value == -1 && src_reg->smax_value == 0) {
> +               dst_reg->var_off = tnum_union(dst_reg->var_off, tnum_const(0));
> +               dst_reg->umin_value = dst_reg->var_off.value;
> +               dst_reg->umax_value = min(dst_reg->umax_value, umax_val);
> +               dst_reg->smin_value = min_t(s64, dst_reg->smin_value, 0);
> +               dst_reg->smax_value = max_t(s64, dst_reg->smax_value, 0);
> +               return;
> +       }
> +
>
> >> +               return;
> >> +       }
> >> +
> >> +       /* Special case: dst_reg is known and src_reg is in range [-1, 0] */
> >> +       if (dst_known &&
> >> +               src_reg->s32_min_value == -1 && src_reg->s32_max_value == 0 &&
> >> +               src_reg->smin_value == -1 && src_reg->smax_value == 0) {
> >> +               dst_reg->s32_min_value = min_t(s32, dst_reg->s32_min_value, 0);
> >> +               dst_reg->s32_max_value = max_t(s32, dst_reg->s32_min_value, 0);
> >> +               return;
> >> +       }
> >> +
> >>           /* Safe to set s32 bounds by casting u32 result into s32 when u32
> >>            * doesn't cross sign boundary. Otherwise set s32 bounds to unbounded.
> >>            */
> >
> > [...]
> >
>
Eduard Zingerman April 29, 2024, 9:56 p.m. UTC | #6
On Sun, 2024-04-28 at 23:15 +0800, Xu Kuohai wrote:

[...]

> diff --git a/kernel/bpf/tnum.c b/kernel/bpf/tnum.c
> index 9dbc31b25e3d..9d4480a683ca 100644
> --- a/kernel/bpf/tnum.c
> +++ b/kernel/bpf/tnum.c
> @@ -150,6 +150,29 @@ struct tnum tnum_intersect(struct tnum a, struct tnum b)
>          return TNUM(v & ~mu, mu);
>   }
> 
> +/*
> + * Each bit has 3 states: unkown, known 0, known 1. If using x to represent
> + * unknown state, the result of the union of two bits is as follows:
> + *
> + *         | x    0    1
> + *    -----+------------
> + *     x   | x    x    x
> + *     0   | x    0    x
> + *     1   | x    x    1
> + *
> + * For tnum a and b, only the bits that are both known 0 or known 1 in a
> + * and b are known in the result of union a and b.
> + */
> +struct tnum tnum_union(struct tnum a, struct tnum b)
> +{
> +       u64 v0, v1, mu;
> +
> +       mu = a.mask | b.mask; // unkown bits either in a or b
> +       v1 = (a.value & b.value) & ~mu; // "known 1" bits in both a and b
> +       v0 = (~a.value & ~b.value) & ~mu; // "known 0" bits in both a and b
> +       return TNUM(v1, mu | ~(v0 | v1));
> +}
> +

Zero would be represented as {.value=0,.mask=0}, suppose 'b' is zero:

1. mu = a.mask | 0;                     2. mu = a.mask;
   v1 = (a.value & 0) & ~mu;               v1 = 0;
   v0 = (~a.value & ~0) & ~mu;             v0 = ~a.value & ~mu;
   return TNUM(v1, mu | ~(v0 | v1));       return TNUM(v1, mu | ~(v0 | v1));

3. v1 = 0;                              4. v1 = 0;
   v0 = ~a.value & ~a.mask;                v0 = ~a.value & ~a.mask;
   return TNUM(v1, a.mask | ~(v0 | v1));   return TNUM(0, a.mask | ~(~a.value & ~a.mask));

5. return TNUM(0, a.mask | a.value)

So ultimately this says that for 1's that we knew
we no longer know if those are 1's.
Which seems to make sense.
Eduard Zingerman April 29, 2024, 10:18 p.m. UTC | #7
On Mon, 2024-04-29 at 13:58 -0700, Andrii Nakryiko wrote:

[...]

> > diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
> > index 8f0f2e21699e..b69c89bc5cfc 100644
> > --- a/kernel/bpf/verifier.c
> > +++ b/kernel/bpf/verifier.c
> > @@ -13478,6 +13478,28 @@ static void scalar32_min_max_and(struct bpf_reg_state *dst_reg,
> >                  return;
> >          }
> > 
> > +       /* Special case: dst_reg is in range [-1, 0] */
> > +       if (dst_reg->s32_min_value == -1 && dst_reg->s32_max_value == 0) {
> > +               var32_off = tnum_union(src_reg->var_off, tnum_const(0));
> > +               dst_reg->var_off = tnum_with_subreg(dst_reg->var_off, var32_off);
> > +               dst_reg->u32_min_value = var32_off.value;
> > +               dst_reg->u32_max_value = min(dst_reg->u32_max_value, umax_val);
> 
> can you explain the logic behing u32 min/max updates, especially that
> we use completely different values for min/max and it's not clear why
> u32_min <= u32_max invariant will always hold. Same below

I agree with Andrii here.
It appears that dst_reg.{min,max} fields should be set as
{min(src.min, 0), max(src.max, 0)} for both signed and unsigned cases.
Wdyt?

> 
> > +               dst_reg->s32_min_value = min_t(s32, src_reg->s32_min_value, 0);
> > +               dst_reg->s32_max_value = max_t(s32, src_reg->s32_max_value, 0);
> > +               return;
> > +       }
> > +
> > +       /* Special case: src_reg is in range [-1, 0] */
> > +       if (src_reg->s32_min_value == -1 && src_reg->s32_max_value == 0) {
> > +               var32_off = tnum_union(dst_reg->var_off, tnum_const(0));
> > +               dst_reg->var_off = tnum_with_subreg(dst_reg->var_off, var32_off);
> > +               dst_reg->u32_min_value = var32_off.value;
> > +               dst_reg->u32_max_value = min(dst_reg->u32_max_value, umax_val);
> > +               dst_reg->s32_min_value = min_t(s32, dst_reg->s32_min_value, 0);
> > +               dst_reg->s32_max_value = max_t(s32, dst_reg->s32_max_value, 0);
> > +               return;
> > +       }
> > +
> >          /* We get our minimum from the var_off, since that's inherently
> >           * bitwise.  Our maximum is the minimum of the operands' maxima.
> >           */

[...]
Xu Kuohai April 30, 2024, 3:54 a.m. UTC | #8
On 4/30/2024 4:58 AM, Andrii Nakryiko wrote:
> On Sun, Apr 28, 2024 at 8:15 AM Xu Kuohai <xukuohai@huaweicloud.com> wrote:
>>
>> On 4/27/2024 4:36 AM, Andrii Nakryiko wrote:
>>> On Tue, Apr 23, 2024 at 7:26 PM Xu Kuohai <xukuohai@huaweicloud.com> wrote:
>>>>
>>>> On 4/24/2024 5:55 AM, Yonghong Song wrote:
>>>>>
>>>>> On 4/20/24 1:33 AM, Xu Kuohai wrote:
>>>>>> On 4/20/2024 7:00 AM, Eduard Zingerman wrote:
>>>>>>> On Thu, 2024-04-11 at 20:27 +0800, Xu Kuohai wrote:
>>>>>>>> From: Xu Kuohai <xukuohai@huawei.com>
>>>>>>>>
>>>>>>>> With lsm return value check, the no-alu32 version test_libbpf_get_fd_by_id_opts
>>>>>>>> is rejected by the verifier, and the log says:
>>>>>>>>
>>>>>>>>      0: R1=ctx() R10=fp0
>>>>>>>>      ; int BPF_PROG(check_access, struct bpf_map *map, fmode_t fmode) @ test_libbpf_get_fd_by_id_opts.c:27
>>>>>>>>      0: (b7) r0 = 0                        ; R0_w=0
>>>>>>>>      1: (79) r2 = *(u64 *)(r1 +0)
>>>>>>>>      func 'bpf_lsm_bpf_map' arg0 has btf_id 916 type STRUCT 'bpf_map'
>>>>>>>>      2: R1=ctx() R2_w=trusted_ptr_bpf_map()
>>>>>>>>      ; if (map != (struct bpf_map *)&data_input) @ test_libbpf_get_fd_by_id_opts.c:29
>>>>>>>>      2: (18) r3 = 0xffff9742c0951a00       ; R3_w=map_ptr(map=data_input,ks=4,vs=4)
>>>>>>>>      4: (5d) if r2 != r3 goto pc+4         ; R2_w=trusted_ptr_bpf_map() R3_w=map_ptr(map=data_input,ks=4,vs=4)
>>>>>>>>      ; int BPF_PROG(check_access, struct bpf_map *map, fmode_t fmode) @ test_libbpf_get_fd_by_id_opts.c:27
>>>>>>>>      5: (79) r0 = *(u64 *)(r1 +8)          ; R0_w=scalar() R1=ctx()
>>>>>>>>      ; if (fmode & FMODE_WRITE) @ test_libbpf_get_fd_by_id_opts.c:32
>>>>>>>>      6: (67) r0 <<= 62                     ; R0_w=scalar(smax=0x4000000000000000,umax=0xc000000000000000,smin32=0,smax32=umax32=0,var_off=(0x0; 0xc000000000000000))
>>>>>>>>      7: (c7) r0 s>>= 63                    ; R0_w=scalar(smin=smin32=-1,smax=smax32=0)
>>>>>>>>      ;  @ test_libbpf_get_fd_by_id_opts.c:0
>>>>>>>>      8: (57) r0 &= -13                     ; R0_w=scalar(smax=0x7ffffffffffffff3,umax=0xfffffffffffffff3,smax32=0x7ffffff3,umax32=0xfffffff3,var_off=(0x0; 0xfffffffffffffff3))
>>>>>>>>      ; int BPF_PROG(check_access, struct bpf_map *map, fmode_t fmode) @ test_libbpf_get_fd_by_id_opts.c:27
>>>>>>>>      9: (95) exit
>>>
>>> [...]
>>>
>>>>
>>>>        As suggested by Eduard, this patch makes a special case for source
>>>>        or destination register of '&=' operation being in range [-1, 0].
>>>>
>>>>        Meaning that one of the '&=' operands is either:
>>>>        - all ones, in which case the counterpart is the result of the operation;
>>>>        - all zeros, in which case zero is the result of the operation.
>>>>
>>>>        And MIN and MAX values could be derived based on above two observations.
>>>>
>>>>        [0] https://lore.kernel.org/bpf/e62e2971301ca7f2e9eb74fc500c520285cad8f5.camel@gmail.com/
>>>>        [1] https://github.com/llvm/llvm-project/blob/4523a267829c807f3fc8fab8e5e9613985a51565/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
>>>>
>>>>        Suggested-by: Eduard Zingerman <eddyz87@gmail.com>
>>>>        Signed-off-by: Xu Kuohai <xukuohai@huawei.com>
>>>>
>>>> diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
>>>> index 640747b53745..30c551d39329 100644
>>>> --- a/kernel/bpf/verifier.c
>>>> +++ b/kernel/bpf/verifier.c
>>>> @@ -13374,6 +13374,24 @@ static void scalar32_min_max_and(struct bpf_reg_state *dst_reg,
>>>>            dst_reg->u32_min_value = var32_off.value;
>>>>            dst_reg->u32_max_value = min(dst_reg->u32_max_value, umax_val);
>>>>
>>>> +       /* Special case: src_reg is known and dst_reg is in range [-1, 0] */
>>>> +       if (src_known &&
>>>> +               dst_reg->s32_min_value == -1 && dst_reg->s32_max_value == 0 &&
>>>> +               dst_reg->smin_value == -1 && dst_reg->smax_value == 0) {
>>>
>>> please keep if () condition aligned across multiple lines, it's super
>>> confusing this way
>>>
>>
>> OK, will update the align style
>>
>>>> +               dst_reg->s32_min_value = min_t(s32, src_reg->s32_min_value, 0);
>>>> +               dst_reg->s32_max_value = max_t(s32, src_reg->s32_min_value, 0);
>>>
>>> do we need to update tnum parts as well (or reset and re-derive, probably)?
>>>
>>> btw, can't we support src being a range here? the idea is that dst_reg
>>> either all ones or all zeros. For and it means that it either stays
>>> all zero, or will be *exactly equal* to src, right? So I think the
>>> logic would be:
>>>
>>> a) if [s32_min, s32_max] is on the same side of zero, then resulting
>>> range would be [min(s32_min, 0), max(s32_max, 0)], just like you have
>>> here
>>>
>>> b) if [s32_min, s32_max] contains zero, then resulting range will be
>>> exactly [s32_min, s32_max]
>>>
>>> Or did I make a mistake above?
>>>
>>
>> Totally agree, the AND of any set with the range [-1,0] is equivalent
>> to adding number 0 to the set!
>>
>> Based on this observation, I've rewritten the patch as follows.
>>
>> diff --git a/include/linux/tnum.h b/include/linux/tnum.h
>> index 3c13240077b8..5e795d728b9f 100644
>> --- a/include/linux/tnum.h
>> +++ b/include/linux/tnum.h
>> @@ -52,6 +52,9 @@ struct tnum tnum_mul(struct tnum a, struct tnum b);
>>    /* Return a tnum representing numbers satisfying both @a and @b */
>>    struct tnum tnum_intersect(struct tnum a, struct tnum b);
>>
>> +/* Return a tnum representing numbers satisfying either @a or @b */
>> +struct tnum tnum_union(struct tnum a, struct tnum b);
>> +
>>    /* Return @a with all but the lowest @size bytes cleared */
>>    struct tnum tnum_cast(struct tnum a, u8 size);
>>
>> diff --git a/kernel/bpf/tnum.c b/kernel/bpf/tnum.c
>> index 9dbc31b25e3d..9d4480a683ca 100644
>> --- a/kernel/bpf/tnum.c
>> +++ b/kernel/bpf/tnum.c
>> @@ -150,6 +150,29 @@ struct tnum tnum_intersect(struct tnum a, struct tnum b)
>>           return TNUM(v & ~mu, mu);
>>    }
>>
>> +/*
>> + * Each bit has 3 states: unkown, known 0, known 1. If using x to represent
>> + * unknown state, the result of the union of two bits is as follows:
>> + *
>> + *         | x    0    1
>> + *    -----+------------
>> + *     x   | x    x    x
>> + *     0   | x    0    x
>> + *     1   | x    x    1
>> + *
>> + * For tnum a and b, only the bits that are both known 0 or known 1 in a
>> + * and b are known in the result of union a and b.
>> + */
>> +struct tnum tnum_union(struct tnum a, struct tnum b)
>> +{
>> +       u64 v0, v1, mu;
>> +
>> +       mu = a.mask | b.mask; // unkown bits either in a or b
>> +       v1 = (a.value & b.value) & ~mu; // "known 1" bits in both a and b
>> +       v0 = (~a.value & ~b.value) & ~mu; // "known 0" bits in both a and b
> 
> no C++-style comments, please
>

OK, will fix in the formal patch.

>> +       return TNUM(v1, mu | ~(v0 | v1));
>> +}
>> +
> 
> I've CC'ed Edward, hopefully he can take a look as well. Please CC him
> on future patches touching tnum as well.
> 

Sure

>>    struct tnum tnum_cast(struct tnum a, u8 size)
>>    {
>>           a.value &= (1ULL << (size * 8)) - 1;
>>    {
>>           a.value &= (1ULL << (size * 8)) - 1;
>> diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
>> index 8f0f2e21699e..b69c89bc5cfc 100644
>> --- a/kernel/bpf/verifier.c
>> +++ b/kernel/bpf/verifier.c
>> @@ -13478,6 +13478,28 @@ static void scalar32_min_max_and(struct bpf_reg_state *dst_reg,
>>                   return;
>>           }
>>
>> +       /* Special case: dst_reg is in range [-1, 0] */
>> +       if (dst_reg->s32_min_value == -1 && dst_reg->s32_max_value == 0) {
>> +               var32_off = tnum_union(src_reg->var_off, tnum_const(0));
>> +               dst_reg->var_off = tnum_with_subreg(dst_reg->var_off, var32_off);
>> +               dst_reg->u32_min_value = var32_off.value;
>> +               dst_reg->u32_max_value = min(dst_reg->u32_max_value, umax_val);
> 
> can you explain the logic behing u32 min/max updates, especially that
> we use completely different values for min/max and it's not clear why
> u32_min <= u32_max invariant will always hold. Same below
>

We're adding 0 to the existing range, and 0 is the smallest unsigned
number, so the resulted unsigned min can only get smaller, and the
unsigned max will not be affected. In fact, since 0 is added to the
range, var32_off.value should be 0. And since -1 is in included in
dst_reg, dst_reg->u32_max_value should be -1U, the maximum unsigned
integer. So we can just set u32_min to 0, and set u32_max to umax_val.

>> +               dst_reg->s32_min_value = min_t(s32, src_reg->s32_min_value, 0);
>> +               dst_reg->s32_max_value = max_t(s32, src_reg->s32_max_value, 0);
>> +               return;
>> +       }
>> +
>> +       /* Special case: src_reg is in range [-1, 0] */
>> +       if (src_reg->s32_min_value == -1 && src_reg->s32_max_value == 0) {
>> +               var32_off = tnum_union(dst_reg->var_off, tnum_const(0));
>> +               dst_reg->var_off = tnum_with_subreg(dst_reg->var_off, var32_off);
>> +               dst_reg->u32_min_value = var32_off.value;
>> +               dst_reg->u32_max_value = min(dst_reg->u32_max_value, umax_val);
>> +               dst_reg->s32_min_value = min_t(s32, dst_reg->s32_min_value, 0);
>> +               dst_reg->s32_max_value = max_t(s32, dst_reg->s32_max_value, 0);
>> +               return;
>> +       }
>> +
>>           /* We get our minimum from the var_off, since that's inherently
>>            * bitwise.  Our maximum is the minimum of the operands' maxima.
>>            */
>> @@ -13508,6 +13530,26 @@ static void scalar_min_max_and(struct bpf_reg_state *dst_reg,
>>                   return;
>>           }
>>
>> +       /* Special case: dst_reg is in range [-1, 0] */
>> +       if (dst_reg->smin_value == -1 && dst_reg->smax_value == 0) {
>> +               dst_reg->var_off = tnum_union(src_reg->var_off, tnum_const(0));
>> +               dst_reg->umin_value = dst_reg->var_off.value;
>> +               dst_reg->umax_value = min(dst_reg->umax_value, umax_val);
>> +               dst_reg->smin_value = min_t(s64, src_reg->smin_value, 0);
>> +               dst_reg->smax_value = max_t(s64, src_reg->smax_value, 0);
>> +               return;
>> +       }
>> +
>> +       /* Special case: src_reg is in range [-1, 0] */
>> +       if (src_reg->smin_value == -1 && src_reg->smax_value == 0) {
>> +               dst_reg->var_off = tnum_union(dst_reg->var_off, tnum_const(0));
>> +               dst_reg->umin_value = dst_reg->var_off.value;
>> +               dst_reg->umax_value = min(dst_reg->umax_value, umax_val);
>> +               dst_reg->smin_value = min_t(s64, dst_reg->smin_value, 0);
>> +               dst_reg->smax_value = max_t(s64, dst_reg->smax_value, 0);
>> +               return;
>> +       }
>> +
>>
>>>> +               return;
>>>> +       }
>>>> +
>>>> +       /* Special case: dst_reg is known and src_reg is in range [-1, 0] */
>>>> +       if (dst_known &&
>>>> +               src_reg->s32_min_value == -1 && src_reg->s32_max_value == 0 &&
>>>> +               src_reg->smin_value == -1 && src_reg->smax_value == 0) {
>>>> +               dst_reg->s32_min_value = min_t(s32, dst_reg->s32_min_value, 0);
>>>> +               dst_reg->s32_max_value = max_t(s32, dst_reg->s32_min_value, 0);
>>>> +               return;
>>>> +       }
>>>> +
>>>>            /* Safe to set s32 bounds by casting u32 result into s32 when u32
>>>>             * doesn't cross sign boundary. Otherwise set s32 bounds to unbounded.
>>>>             */
>>>
>>> [...]
>>>
>>
Xu Kuohai April 30, 2024, 3:56 a.m. UTC | #9
On 4/30/2024 6:18 AM, Eduard Zingerman wrote:
> On Mon, 2024-04-29 at 13:58 -0700, Andrii Nakryiko wrote:
> 
> [...]
> 
>>> diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
>>> index 8f0f2e21699e..b69c89bc5cfc 100644
>>> --- a/kernel/bpf/verifier.c
>>> +++ b/kernel/bpf/verifier.c
>>> @@ -13478,6 +13478,28 @@ static void scalar32_min_max_and(struct bpf_reg_state *dst_reg,
>>>                   return;
>>>           }
>>>
>>> +       /* Special case: dst_reg is in range [-1, 0] */
>>> +       if (dst_reg->s32_min_value == -1 && dst_reg->s32_max_value == 0) {
>>> +               var32_off = tnum_union(src_reg->var_off, tnum_const(0));
>>> +               dst_reg->var_off = tnum_with_subreg(dst_reg->var_off, var32_off);
>>> +               dst_reg->u32_min_value = var32_off.value;
>>> +               dst_reg->u32_max_value = min(dst_reg->u32_max_value, umax_val);
>>
>> can you explain the logic behing u32 min/max updates, especially that
>> we use completely different values for min/max and it's not clear why
>> u32_min <= u32_max invariant will always hold. Same below
> 
> I agree with Andrii here.
> It appears that dst_reg.{min,max} fields should be set as
> {min(src.min, 0), max(src.max, 0)} for both signed and unsigned cases.
> Wdyt?
>

Agree, since 0 is the minimum unsigned number, the result range is
equal to [0, src.u32_max].

>>
>>> +               dst_reg->s32_min_value = min_t(s32, src_reg->s32_min_value, 0);
>>> +               dst_reg->s32_max_value = max_t(s32, src_reg->s32_max_value, 0);
>>> +               return;
>>> +       }
>>> +
>>> +       /* Special case: src_reg is in range [-1, 0] */
>>> +       if (src_reg->s32_min_value == -1 && src_reg->s32_max_value == 0) {
>>> +               var32_off = tnum_union(dst_reg->var_off, tnum_const(0));
>>> +               dst_reg->var_off = tnum_with_subreg(dst_reg->var_off, var32_off);
>>> +               dst_reg->u32_min_value = var32_off.value;
>>> +               dst_reg->u32_max_value = min(dst_reg->u32_max_value, umax_val);
>>> +               dst_reg->s32_min_value = min_t(s32, dst_reg->s32_min_value, 0);
>>> +               dst_reg->s32_max_value = max_t(s32, dst_reg->s32_max_value, 0);
>>> +               return;
>>> +       }
>>> +
>>>           /* We get our minimum from the var_off, since that's inherently
>>>            * bitwise.  Our maximum is the minimum of the operands' maxima.
>>>            */
> 
> [...]
diff mbox series

Patch

diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
index 5393d576c76f..62e259f18f35 100644
--- a/kernel/bpf/verifier.c
+++ b/kernel/bpf/verifier.c
@@ -13369,6 +13369,29 @@  static void scalar32_min_max_and(struct bpf_reg_state *dst_reg,
 		return;
 	}
 
+	if (src_known &&
+		dst_reg->s32_min_value < 0 && dst_reg->s32_min_value >= -256 &&
+		dst_reg->s32_max_value >= 0 && dst_reg->s32_max_value <= 256 &&
+		dst_reg->s32_min_value == dst_reg->smin_value &&
+		dst_reg->s32_max_value == dst_reg->smax_value) {
+		s32 s32_min = S32_MAX;
+		s32 s32_max = S32_MIN;
+		s32 v = dst_reg->s32_min_value;
+		while (v <= dst_reg->s32_max_value) {
+			s32 w = (v & src_reg->s32_min_value);
+			if (w < s32_min)
+				s32_min = w;
+			if (w > s32_max)
+				s32_max = w;
+			v++;
+		}
+		dst_reg->s32_min_value = s32_min;
+		dst_reg->s32_max_value = s32_max;
+		dst_reg->u32_min_value = var32_off.value;
+		dst_reg->u32_max_value = min(dst_reg->u32_max_value, umax_val);
+		return;
+	}
+
 	/* We get our minimum from the var_off, since that's inherently
 	 * bitwise.  Our maximum is the minimum of the operands' maxima.
 	 */