Commit 93c5aecc authored by Gary Lin's avatar Gary Lin Committed by Alexei Starovoitov

bpf,x64: Pad NOPs to make images converge more easily

The x64 bpf jit expects bpf images converge within the given passes, but
it could fail to do so with some corner cases. For example:

      l0:     ja 40
      l1:     ja 40

        [... repeated ja 40 ]

      l39:    ja 40
      l40:    ret #0

This bpf program contains 40 "ja 40" instructions which are effectively
NOPs and designed to be replaced with valid code dynamically. Ideally,
bpf jit should optimize those "ja 40" instructions out when translating
the bpf instructions into x64 machine code. However, do_jit() can only
remove one "ja 40" for offset==0 on each pass, so it requires at least
40 runs to eliminate those JMPs and exceeds the current limit of
passes(20). In the end, the program got rejected when BPF_JIT_ALWAYS_ON
is set even though it's legit as a classic socket filter.

To make bpf images more likely converge within 20 passes, this commit
pads some instructions with NOPs in the last 5 passes:

1. conditional jumps
  A possible size variance comes from the adoption of imm8 JMP. If the
  offset is imm8, we calculate the size difference of this BPF instruction
  between the previous and the current pass and fill the gap with NOPs.
  To avoid the recalculation of jump offset, those NOPs are inserted before
  the JMP code, so we have to subtract the 2 bytes of imm8 JMP when
  calculating the NOP number.

2. BPF_JA
  There are two conditions for BPF_JA.
  a.) nop jumps
    If this instruction is not optimized out in the previous pass,
    instead of removing it, we insert the equivalent size of NOPs.
  b.) label jumps
    Similar to condition jumps, we prepend NOPs right before the JMP
    code.

To make the code concise, emit_nops() is modified to use the signed len and
return the number of inserted NOPs.

For bpf-to-bpf, we always enable padding for the extra pass since there
is only one extra run and the jump padding doesn't affected the images
that converge without padding.

After applying this patch, the corner case was loaded with the following
jit code:

    flen=45 proglen=77 pass=17 image=ffffffffc03367d4 from=jump pid=10097
    JIT code: 00000000: 0f 1f 44 00 00 55 48 89 e5 53 41 55 31 c0 45 31
    JIT code: 00000010: ed 48 89 fb eb 30 eb 2e eb 2c eb 2a eb 28 eb 26
    JIT code: 00000020: eb 24 eb 22 eb 20 eb 1e eb 1c eb 1a eb 18 eb 16
    JIT code: 00000030: eb 14 eb 12 eb 10 eb 0e eb 0c eb 0a eb 08 eb 06
    JIT code: 00000040: eb 04 eb 02 66 90 31 c0 41 5d 5b c9 c3

     0: 0f 1f 44 00 00          nop    DWORD PTR [rax+rax*1+0x0]
     5: 55                      push   rbp
     6: 48 89 e5                mov    rbp,rsp
     9: 53                      push   rbx
     a: 41 55                   push   r13
     c: 31 c0                   xor    eax,eax
     e: 45 31 ed                xor    r13d,r13d
    11: 48 89 fb                mov    rbx,rdi
    14: eb 30                   jmp    0x46
    16: eb 2e                   jmp    0x46
        ...
    3e: eb 06                   jmp    0x46
    40: eb 04                   jmp    0x46
    42: eb 02                   jmp    0x46
    44: 66 90                   xchg   ax,ax
    46: 31 c0                   xor    eax,eax
    48: 41 5d                   pop    r13
    4a: 5b                      pop    rbx
    4b: c9                      leave
    4c: c3                      ret

At the 16th pass, 15 jumps were already optimized out, and one jump was
replaced with NOPs at 44 and the image converged at the 17th pass.

v4:
  - Add the detailed comments about the possible padding bytes

v3:
  - Copy the instructions of prologue separately or the size calculation
    of the first BPF instruction would include the prologue.
  - Replace WARN_ONCE() with pr_err() and EFAULT
  - Use MAX_PASSES in the for loop condition check
  - Remove the "padded" flag from x64_jit_data. For the extra pass of
    subprogs, padding is always enabled since it won't hurt the images
    that converge without padding.

v2:
  - Simplify the sample code in the description and provide the jit code
  - Check the expected padding bytes with WARN_ONCE
  - Move the 'padded' flag to 'struct x64_jit_data'
Signed-off-by: default avatarGary Lin <glin@suse.com>
Signed-off-by: default avatarAlexei Starovoitov <ast@kernel.org>
Link: https://lore.kernel.org/bpf/20210119102501.511-2-glin@suse.com
parent d2e04b9d
...@@ -869,8 +869,31 @@ static void detect_reg_usage(struct bpf_insn *insn, int insn_cnt, ...@@ -869,8 +869,31 @@ static void detect_reg_usage(struct bpf_insn *insn, int insn_cnt,
} }
} }
static int emit_nops(u8 **pprog, int len)
{
u8 *prog = *pprog;
int i, noplen, cnt = 0;
while (len > 0) {
noplen = len;
if (noplen > ASM_NOP_MAX)
noplen = ASM_NOP_MAX;
for (i = 0; i < noplen; i++)
EMIT1(ideal_nops[noplen][i]);
len -= noplen;
}
*pprog = prog;
return cnt;
}
#define INSN_SZ_DIFF (((addrs[i] - addrs[i - 1]) - (prog - temp)))
static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image, static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
int oldproglen, struct jit_context *ctx) int oldproglen, struct jit_context *ctx, bool jmp_padding)
{ {
bool tail_call_reachable = bpf_prog->aux->tail_call_reachable; bool tail_call_reachable = bpf_prog->aux->tail_call_reachable;
struct bpf_insn *insn = bpf_prog->insnsi; struct bpf_insn *insn = bpf_prog->insnsi;
...@@ -880,7 +903,7 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image, ...@@ -880,7 +903,7 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
bool seen_exit = false; bool seen_exit = false;
u8 temp[BPF_MAX_INSN_SIZE + BPF_INSN_SAFETY]; u8 temp[BPF_MAX_INSN_SIZE + BPF_INSN_SAFETY];
int i, cnt = 0, excnt = 0; int i, cnt = 0, excnt = 0;
int proglen = 0; int ilen, proglen = 0;
u8 *prog = temp; u8 *prog = temp;
int err; int err;
...@@ -894,7 +917,13 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image, ...@@ -894,7 +917,13 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
bpf_prog_was_classic(bpf_prog), tail_call_reachable, bpf_prog_was_classic(bpf_prog), tail_call_reachable,
bpf_prog->aux->func_idx != 0); bpf_prog->aux->func_idx != 0);
push_callee_regs(&prog, callee_regs_used); push_callee_regs(&prog, callee_regs_used);
addrs[0] = prog - temp;
ilen = prog - temp;
if (image)
memcpy(image + proglen, temp, ilen);
proglen += ilen;
addrs[0] = proglen;
prog = temp;
for (i = 1; i <= insn_cnt; i++, insn++) { for (i = 1; i <= insn_cnt; i++, insn++) {
const s32 imm32 = insn->imm; const s32 imm32 = insn->imm;
...@@ -903,8 +932,8 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image, ...@@ -903,8 +932,8 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
u8 b2 = 0, b3 = 0; u8 b2 = 0, b3 = 0;
s64 jmp_offset; s64 jmp_offset;
u8 jmp_cond; u8 jmp_cond;
int ilen;
u8 *func; u8 *func;
int nops;
switch (insn->code) { switch (insn->code) {
/* ALU */ /* ALU */
...@@ -1502,6 +1531,30 @@ st: if (is_imm8(insn->off)) ...@@ -1502,6 +1531,30 @@ st: if (is_imm8(insn->off))
} }
jmp_offset = addrs[i + insn->off] - addrs[i]; jmp_offset = addrs[i + insn->off] - addrs[i];
if (is_imm8(jmp_offset)) { if (is_imm8(jmp_offset)) {
if (jmp_padding) {
/* To keep the jmp_offset valid, the extra bytes are
* padded before the jump insn, so we substract the
* 2 bytes of jmp_cond insn from INSN_SZ_DIFF.
*
* If the previous pass already emits an imm8
* jmp_cond, then this BPF insn won't shrink, so
* "nops" is 0.
*
* On the other hand, if the previous pass emits an
* imm32 jmp_cond, the extra 4 bytes(*) is padded to
* keep the image from shrinking further.
*
* (*) imm32 jmp_cond is 6 bytes, and imm8 jmp_cond
* is 2 bytes, so the size difference is 4 bytes.
*/
nops = INSN_SZ_DIFF - 2;
if (nops != 0 && nops != 4) {
pr_err("unexpected jmp_cond padding: %d bytes\n",
nops);
return -EFAULT;
}
cnt += emit_nops(&prog, nops);
}
EMIT2(jmp_cond, jmp_offset); EMIT2(jmp_cond, jmp_offset);
} else if (is_simm32(jmp_offset)) { } else if (is_simm32(jmp_offset)) {
EMIT2_off32(0x0F, jmp_cond + 0x10, jmp_offset); EMIT2_off32(0x0F, jmp_cond + 0x10, jmp_offset);
...@@ -1524,11 +1577,55 @@ st: if (is_imm8(insn->off)) ...@@ -1524,11 +1577,55 @@ st: if (is_imm8(insn->off))
else else
jmp_offset = addrs[i + insn->off] - addrs[i]; jmp_offset = addrs[i + insn->off] - addrs[i];
if (!jmp_offset) if (!jmp_offset) {
/* Optimize out nop jumps */ /*
* If jmp_padding is enabled, the extra nops will
* be inserted. Otherwise, optimize out nop jumps.
*/
if (jmp_padding) {
/* There are 3 possible conditions.
* (1) This BPF_JA is already optimized out in
* the previous run, so there is no need
* to pad any extra byte (0 byte).
* (2) The previous pass emits an imm8 jmp,
* so we pad 2 bytes to match the previous
* insn size.
* (3) Similarly, the previous pass emits an
* imm32 jmp, and 5 bytes is padded.
*/
nops = INSN_SZ_DIFF;
if (nops != 0 && nops != 2 && nops != 5) {
pr_err("unexpected nop jump padding: %d bytes\n",
nops);
return -EFAULT;
}
cnt += emit_nops(&prog, nops);
}
break; break;
}
emit_jmp: emit_jmp:
if (is_imm8(jmp_offset)) { if (is_imm8(jmp_offset)) {
if (jmp_padding) {
/* To avoid breaking jmp_offset, the extra bytes
* are padded before the actual jmp insn, so
* 2 bytes is substracted from INSN_SZ_DIFF.
*
* If the previous pass already emits an imm8
* jmp, there is nothing to pad (0 byte).
*
* If it emits an imm32 jmp (5 bytes) previously
* and now an imm8 jmp (2 bytes), then we pad
* (5 - 2 = 3) bytes to stop the image from
* shrinking further.
*/
nops = INSN_SZ_DIFF - 2;
if (nops != 0 && nops != 3) {
pr_err("unexpected jump padding: %d bytes\n",
nops);
return -EFAULT;
}
cnt += emit_nops(&prog, INSN_SZ_DIFF - 2);
}
EMIT2(0xEB, jmp_offset); EMIT2(0xEB, jmp_offset);
} else if (is_simm32(jmp_offset)) { } else if (is_simm32(jmp_offset)) {
EMIT1_off32(0xE9, jmp_offset); EMIT1_off32(0xE9, jmp_offset);
...@@ -1671,26 +1768,6 @@ static int invoke_bpf_prog(const struct btf_func_model *m, u8 **pprog, ...@@ -1671,26 +1768,6 @@ static int invoke_bpf_prog(const struct btf_func_model *m, u8 **pprog,
return 0; return 0;
} }
static void emit_nops(u8 **pprog, unsigned int len)
{
unsigned int i, noplen;
u8 *prog = *pprog;
int cnt = 0;
while (len > 0) {
noplen = len;
if (noplen > ASM_NOP_MAX)
noplen = ASM_NOP_MAX;
for (i = 0; i < noplen; i++)
EMIT1(ideal_nops[noplen][i]);
len -= noplen;
}
*pprog = prog;
}
static void emit_align(u8 **pprog, u32 align) static void emit_align(u8 **pprog, u32 align)
{ {
u8 *target, *prog = *pprog; u8 *target, *prog = *pprog;
...@@ -2065,6 +2142,9 @@ struct x64_jit_data { ...@@ -2065,6 +2142,9 @@ struct x64_jit_data {
struct jit_context ctx; struct jit_context ctx;
}; };
#define MAX_PASSES 20
#define PADDING_PASSES (MAX_PASSES - 5)
struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog) struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
{ {
struct bpf_binary_header *header = NULL; struct bpf_binary_header *header = NULL;
...@@ -2074,6 +2154,7 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog) ...@@ -2074,6 +2154,7 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
struct jit_context ctx = {}; struct jit_context ctx = {};
bool tmp_blinded = false; bool tmp_blinded = false;
bool extra_pass = false; bool extra_pass = false;
bool padding = false;
u8 *image = NULL; u8 *image = NULL;
int *addrs; int *addrs;
int pass; int pass;
...@@ -2110,6 +2191,7 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog) ...@@ -2110,6 +2191,7 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
image = jit_data->image; image = jit_data->image;
header = jit_data->header; header = jit_data->header;
extra_pass = true; extra_pass = true;
padding = true;
goto skip_init_addrs; goto skip_init_addrs;
} }
addrs = kmalloc_array(prog->len + 1, sizeof(*addrs), GFP_KERNEL); addrs = kmalloc_array(prog->len + 1, sizeof(*addrs), GFP_KERNEL);
...@@ -2135,8 +2217,10 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog) ...@@ -2135,8 +2217,10 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
* may converge on the last pass. In such case do one more * may converge on the last pass. In such case do one more
* pass to emit the final image. * pass to emit the final image.
*/ */
for (pass = 0; pass < 20 || image; pass++) { for (pass = 0; pass < MAX_PASSES || image; pass++) {
proglen = do_jit(prog, addrs, image, oldproglen, &ctx); if (!padding && pass >= PADDING_PASSES)
padding = true;
proglen = do_jit(prog, addrs, image, oldproglen, &ctx, padding);
if (proglen <= 0) { if (proglen <= 0) {
out_image: out_image:
image = NULL; image = NULL;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment