Commit ebf7d1f5 authored by Maciej Fijalkowski's avatar Maciej Fijalkowski Committed by Alexei Starovoitov

bpf, x64: rework pro/epilogue and tailcall handling in JIT

This commit serves two things:
1) it optimizes BPF prologue/epilogue generation
2) it makes possible to have tailcalls within BPF subprogram

Both points are related to each other since without 1), 2) could not be
achieved.

In [1], Alexei says:
"The prologue will look like:
nop5
xor eax,eax  // two new bytes if bpf_tail_call() is used in this
             // function
push rbp
mov rbp, rsp
sub rsp, rounded_stack_depth
push rax // zero init tail_call counter
variable number of push rbx,r13,r14,r15

Then bpf_tail_call will pop variable number rbx,..
and final 'pop rax'
Then 'add rsp, size_of_current_stack_frame'
jmp to next function and skip over 'nop5; xor eax,eax; push rpb; mov
rbp, rsp'

This way new function will set its own stack size and will init tail
call
counter with whatever value the parent had.

If next function doesn't use bpf_tail_call it won't have 'xor eax,eax'.
Instead it would need to have 'nop2' in there."

Implement that suggestion.

Since the layout of stack is changed, tail call counter handling can not
rely anymore on popping it to rbx just like it have been handled for
constant prologue case and later overwrite of rbx with actual value of
rbx pushed to stack. Therefore, let's use one of the register (%rcx) that
is considered to be volatile/caller-saved and pop the value of tail call
counter in there in the epilogue.

Drop the BUILD_BUG_ON in emit_prologue and in
emit_bpf_tail_call_indirect where instruction layout is not constant
anymore.

Introduce new poke target, 'tailcall_bypass' to poke descriptor that is
dedicated for skipping the register pops and stack unwind that are
generated right before the actual jump to target program.
For case when the target program is not present, BPF program will skip
the pop instructions and nop5 dedicated for jmpq $target. An example of
such state when only R6 of callee saved registers is used by program:

ffffffffc0513aa1:       e9 0e 00 00 00          jmpq   0xffffffffc0513ab4
ffffffffc0513aa6:       5b                      pop    %rbx
ffffffffc0513aa7:       58                      pop    %rax
ffffffffc0513aa8:       48 81 c4 00 00 00 00    add    $0x0,%rsp
ffffffffc0513aaf:       0f 1f 44 00 00          nopl   0x0(%rax,%rax,1)
ffffffffc0513ab4:       48 89 df                mov    %rbx,%rdi

When target program is inserted, the jump that was there to skip
pops/nop5 will become the nop5, so CPU will go over pops and do the
actual tailcall.

One might ask why there simply can not be pushes after the nop5?
In the following example snippet:

ffffffffc037030c:       48 89 fb                mov    %rdi,%rbx
(...)
ffffffffc0370332:       5b                      pop    %rbx
ffffffffc0370333:       58                      pop    %rax
ffffffffc0370334:       48 81 c4 00 00 00 00    add    $0x0,%rsp
ffffffffc037033b:       0f 1f 44 00 00          nopl   0x0(%rax,%rax,1)
ffffffffc0370340:       48 81 ec 00 00 00 00    sub    $0x0,%rsp
ffffffffc0370347:       50                      push   %rax
ffffffffc0370348:       53                      push   %rbx
ffffffffc0370349:       48 89 df                mov    %rbx,%rdi
ffffffffc037034c:       e8 f7 21 00 00          callq  0xffffffffc0372548

There is the bpf2bpf call (at ffffffffc037034c) right after the tailcall
and jump target is not present. ctx is in %rbx register and BPF
subprogram that we will call into on ffffffffc037034c is relying on it,
e.g. it will pick ctx from there. Such code layout is therefore broken
as we would overwrite the content of %rbx with the value that was pushed
on the prologue. That is the reason for the 'bypass' approach.

Special care needs to be taken during the install/update/remove of
tailcall target. In case when target program is not present, the CPU
must not execute the pop instructions that precede the tailcall.

To address that, the following states can be defined:
A nop, unwind, nop
B nop, unwind, tail
C skip, unwind, nop
D skip, unwind, tail

A is forbidden (lead to incorrectness). The state transitions between
tailcall install/update/remove will work as follows:

First install tail call f: C->D->B(f)
 * poke the tailcall, after that get rid of the skip
Update tail call f to f': B(f)->B(f')
 * poke the tailcall (poke->tailcall_target) and do NOT touch the
   poke->tailcall_bypass
Remove tail call: B(f')->C(f')
 * poke->tailcall_bypass is poked back to jump, then we wait the RCU
   grace period so that other programs will finish its execution and
   after that we are safe to remove the poke->tailcall_target
Install new tail call (f''): C(f')->D(f'')->B(f'').
 * same as first step

This way CPU can never be exposed to "unwind, tail" state.

Last but not least, when tailcalls get mixed with bpf2bpf calls, it
would be possible to encounter the endless loop due to clearing the
tailcall counter if for example we would use the tailcall3-like from BPF
selftests program that would be subprogram-based, meaning the tailcall
would be present within the BPF subprogram.

This test, broken down to particular steps, would do:
entry -> set tailcall counter to 0, bump it by 1, tailcall to func0
func0 -> call subprog_tail
(we are NOT skipping the first 11 bytes of prologue and this subprogram
has a tailcall, therefore we clear the counter...)
subprog -> do the same thing as entry

and then loop forever.

To address this, the idea is to go through the call chain of bpf2bpf progs
and look for a tailcall presence throughout whole chain. If we saw a single
tail call then each node in this call chain needs to be marked as a subprog
that can reach the tailcall. We would later feed the JIT with this info
and:
- set eax to 0 only when tailcall is reachable and this is the entry prog
- if tailcall is reachable but there's no tailcall in insns of currently
  JITed prog then push rax anyway, so that it will be possible to
  propagate further down the call chain
- finally if tailcall is reachable, then we need to precede the 'call'
  insn with mov rax, [rbp - (stack_depth + 8)]

Tail call related cases from test_verifier kselftest are also working
fine. Sample BPF programs that utilize tail calls (sockex3, tracex5)
work properly as well.

[1]: https://lore.kernel.org/bpf/20200517043227.2gpq22ifoq37ogst@ast-mbp.dhcp.thefacebook.com/Suggested-by: default avatarAlexei Starovoitov <ast@kernel.org>
Signed-off-by: default avatarMaciej Fijalkowski <maciej.fijalkowski@intel.com>
Signed-off-by: default avatarAlexei Starovoitov <ast@kernel.org>
parent 7f6e4312
...@@ -221,14 +221,48 @@ struct jit_context { ...@@ -221,14 +221,48 @@ struct jit_context {
/* Number of bytes emit_patch() needs to generate instructions */ /* Number of bytes emit_patch() needs to generate instructions */
#define X86_PATCH_SIZE 5 #define X86_PATCH_SIZE 5
/* Number of bytes that will be skipped on tailcall */
#define X86_TAIL_CALL_OFFSET 11
#define PROLOGUE_SIZE 25 static void push_callee_regs(u8 **pprog, bool *callee_regs_used)
{
u8 *prog = *pprog;
int cnt = 0;
if (callee_regs_used[0])
EMIT1(0x53); /* push rbx */
if (callee_regs_used[1])
EMIT2(0x41, 0x55); /* push r13 */
if (callee_regs_used[2])
EMIT2(0x41, 0x56); /* push r14 */
if (callee_regs_used[3])
EMIT2(0x41, 0x57); /* push r15 */
*pprog = prog;
}
static void pop_callee_regs(u8 **pprog, bool *callee_regs_used)
{
u8 *prog = *pprog;
int cnt = 0;
if (callee_regs_used[3])
EMIT2(0x41, 0x5F); /* pop r15 */
if (callee_regs_used[2])
EMIT2(0x41, 0x5E); /* pop r14 */
if (callee_regs_used[1])
EMIT2(0x41, 0x5D); /* pop r13 */
if (callee_regs_used[0])
EMIT1(0x5B); /* pop rbx */
*pprog = prog;
}
/* /*
* Emit x86-64 prologue code for BPF program and check its size. * Emit x86-64 prologue code for BPF program.
* bpf_tail_call helper will skip it while jumping into another program * bpf_tail_call helper will skip the first X86_TAIL_CALL_OFFSET bytes
* while jumping to another program
*/ */
static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf) static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf,
bool tail_call_reachable, bool is_subprog)
{ {
u8 *prog = *pprog; u8 *prog = *pprog;
int cnt = X86_PATCH_SIZE; int cnt = X86_PATCH_SIZE;
...@@ -238,19 +272,18 @@ static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf) ...@@ -238,19 +272,18 @@ static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf)
*/ */
memcpy(prog, ideal_nops[NOP_ATOMIC5], cnt); memcpy(prog, ideal_nops[NOP_ATOMIC5], cnt);
prog += cnt; prog += cnt;
if (!ebpf_from_cbpf) {
if (tail_call_reachable && !is_subprog)
EMIT2(0x31, 0xC0); /* xor eax, eax */
else
EMIT2(0x66, 0x90); /* nop2 */
}
EMIT1(0x55); /* push rbp */ EMIT1(0x55); /* push rbp */
EMIT3(0x48, 0x89, 0xE5); /* mov rbp, rsp */ EMIT3(0x48, 0x89, 0xE5); /* mov rbp, rsp */
/* sub rsp, rounded_stack_depth */ /* sub rsp, rounded_stack_depth */
EMIT3_off32(0x48, 0x81, 0xEC, round_up(stack_depth, 8)); EMIT3_off32(0x48, 0x81, 0xEC, round_up(stack_depth, 8));
EMIT1(0x53); /* push rbx */ if (tail_call_reachable)
EMIT2(0x41, 0x55); /* push r13 */ EMIT1(0x50); /* push rax */
EMIT2(0x41, 0x56); /* push r14 */
EMIT2(0x41, 0x57); /* push r15 */
if (!ebpf_from_cbpf) {
/* zero init tail_call_cnt */
EMIT2(0x6a, 0x00);
BUILD_BUG_ON(cnt != PROLOGUE_SIZE);
}
*pprog = prog; *pprog = prog;
} }
...@@ -314,13 +347,14 @@ static int __bpf_arch_text_poke(void *ip, enum bpf_text_poke_type t, ...@@ -314,13 +347,14 @@ static int __bpf_arch_text_poke(void *ip, enum bpf_text_poke_type t,
mutex_lock(&text_mutex); mutex_lock(&text_mutex);
if (memcmp(ip, old_insn, X86_PATCH_SIZE)) if (memcmp(ip, old_insn, X86_PATCH_SIZE))
goto out; goto out;
ret = 1;
if (memcmp(ip, new_insn, X86_PATCH_SIZE)) { if (memcmp(ip, new_insn, X86_PATCH_SIZE)) {
if (text_live) if (text_live)
text_poke_bp(ip, new_insn, X86_PATCH_SIZE, NULL); text_poke_bp(ip, new_insn, X86_PATCH_SIZE, NULL);
else else
memcpy(ip, new_insn, X86_PATCH_SIZE); memcpy(ip, new_insn, X86_PATCH_SIZE);
ret = 0;
} }
ret = 0;
out: out:
mutex_unlock(&text_mutex); mutex_unlock(&text_mutex);
return ret; return ret;
...@@ -337,6 +371,22 @@ int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type t, ...@@ -337,6 +371,22 @@ int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type t,
return __bpf_arch_text_poke(ip, t, old_addr, new_addr, true); return __bpf_arch_text_poke(ip, t, old_addr, new_addr, true);
} }
static int get_pop_bytes(bool *callee_regs_used)
{
int bytes = 0;
if (callee_regs_used[3])
bytes += 2;
if (callee_regs_used[2])
bytes += 2;
if (callee_regs_used[1])
bytes += 2;
if (callee_regs_used[0])
bytes += 1;
return bytes;
}
/* /*
* Generate the following code: * Generate the following code:
* *
...@@ -351,12 +401,26 @@ int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type t, ...@@ -351,12 +401,26 @@ int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type t,
* goto *(prog->bpf_func + prologue_size); * goto *(prog->bpf_func + prologue_size);
* out: * out:
*/ */
static void emit_bpf_tail_call_indirect(u8 **pprog) static void emit_bpf_tail_call_indirect(u8 **pprog, bool *callee_regs_used,
u32 stack_depth)
{ {
int tcc_off = -4 - round_up(stack_depth, 8);
u8 *prog = *pprog; u8 *prog = *pprog;
int label1, label2, label3; int pop_bytes = 0;
int off1 = 49;
int off2 = 38;
int off3 = 16;
int cnt = 0; int cnt = 0;
/* count the additional bytes used for popping callee regs from stack
* that need to be taken into account for each of the offsets that
* are used for bailing out of the tail call
*/
pop_bytes = get_pop_bytes(callee_regs_used);
off1 += pop_bytes;
off2 += pop_bytes;
off3 += pop_bytes;
/* /*
* rdi - pointer to ctx * rdi - pointer to ctx
* rsi - pointer to bpf_array * rsi - pointer to bpf_array
...@@ -370,21 +434,19 @@ static void emit_bpf_tail_call_indirect(u8 **pprog) ...@@ -370,21 +434,19 @@ static void emit_bpf_tail_call_indirect(u8 **pprog)
EMIT2(0x89, 0xD2); /* mov edx, edx */ EMIT2(0x89, 0xD2); /* mov edx, edx */
EMIT3(0x39, 0x56, /* cmp dword ptr [rsi + 16], edx */ EMIT3(0x39, 0x56, /* cmp dword ptr [rsi + 16], edx */
offsetof(struct bpf_array, map.max_entries)); offsetof(struct bpf_array, map.max_entries));
#define OFFSET1 (41 + RETPOLINE_RCX_BPF_JIT_SIZE) /* Number of bytes to jump */ #define OFFSET1 (off1 + RETPOLINE_RCX_BPF_JIT_SIZE) /* Number of bytes to jump */
EMIT2(X86_JBE, OFFSET1); /* jbe out */ EMIT2(X86_JBE, OFFSET1); /* jbe out */
label1 = cnt;
/* /*
* if (tail_call_cnt > MAX_TAIL_CALL_CNT) * if (tail_call_cnt > MAX_TAIL_CALL_CNT)
* goto out; * goto out;
*/ */
EMIT2_off32(0x8B, 0x85, -36 - MAX_BPF_STACK); /* mov eax, dword ptr [rbp - 548] */ EMIT2_off32(0x8B, 0x85, tcc_off); /* mov eax, dword ptr [rbp - tcc_off] */
EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT); /* cmp eax, MAX_TAIL_CALL_CNT */ EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT); /* cmp eax, MAX_TAIL_CALL_CNT */
#define OFFSET2 (30 + RETPOLINE_RCX_BPF_JIT_SIZE) #define OFFSET2 (off2 + RETPOLINE_RCX_BPF_JIT_SIZE)
EMIT2(X86_JA, OFFSET2); /* ja out */ EMIT2(X86_JA, OFFSET2); /* ja out */
label2 = cnt;
EMIT3(0x83, 0xC0, 0x01); /* add eax, 1 */ EMIT3(0x83, 0xC0, 0x01); /* add eax, 1 */
EMIT2_off32(0x89, 0x85, -36 - MAX_BPF_STACK); /* mov dword ptr [rbp -548], eax */ EMIT2_off32(0x89, 0x85, tcc_off); /* mov dword ptr [rbp - tcc_off], eax */
/* prog = array->ptrs[index]; */ /* prog = array->ptrs[index]; */
EMIT4_off32(0x48, 0x8B, 0x8C, 0xD6, /* mov rcx, [rsi + rdx * 8 + offsetof(...)] */ EMIT4_off32(0x48, 0x8B, 0x8C, 0xD6, /* mov rcx, [rsi + rdx * 8 + offsetof(...)] */
...@@ -394,48 +456,84 @@ static void emit_bpf_tail_call_indirect(u8 **pprog) ...@@ -394,48 +456,84 @@ static void emit_bpf_tail_call_indirect(u8 **pprog)
* if (prog == NULL) * if (prog == NULL)
* goto out; * goto out;
*/ */
EMIT3(0x48, 0x85, 0xC9); /* test rcx,rcx */ EMIT3(0x48, 0x85, 0xC9); /* test rcx,rcx */
#define OFFSET3 (8 + RETPOLINE_RCX_BPF_JIT_SIZE) #define OFFSET3 (off3 + RETPOLINE_RCX_BPF_JIT_SIZE)
EMIT2(X86_JE, OFFSET3); /* je out */ EMIT2(X86_JE, OFFSET3); /* je out */
label3 = cnt;
/* goto *(prog->bpf_func + prologue_size); */ *pprog = prog;
pop_callee_regs(pprog, callee_regs_used);
prog = *pprog;
EMIT1(0x58); /* pop rax */
EMIT3_off32(0x48, 0x81, 0xC4, /* add rsp, sd */
round_up(stack_depth, 8));
/* goto *(prog->bpf_func + X86_TAIL_CALL_OFFSET); */
EMIT4(0x48, 0x8B, 0x49, /* mov rcx, qword ptr [rcx + 32] */ EMIT4(0x48, 0x8B, 0x49, /* mov rcx, qword ptr [rcx + 32] */
offsetof(struct bpf_prog, bpf_func)); offsetof(struct bpf_prog, bpf_func));
EMIT4(0x48, 0x83, 0xC1, PROLOGUE_SIZE); /* add rcx, prologue_size */ EMIT4(0x48, 0x83, 0xC1, /* add rcx, X86_TAIL_CALL_OFFSET */
X86_TAIL_CALL_OFFSET);
/* /*
* Now we're ready to jump into next BPF program * Now we're ready to jump into next BPF program
* rdi == ctx (1st arg) * rdi == ctx (1st arg)
* rcx == prog->bpf_func + prologue_size * rcx == prog->bpf_func + X86_TAIL_CALL_OFFSET
*/ */
RETPOLINE_RCX_BPF_JIT(); RETPOLINE_RCX_BPF_JIT();
/* out: */ /* out: */
BUILD_BUG_ON(cnt - label1 != OFFSET1);
BUILD_BUG_ON(cnt - label2 != OFFSET2);
BUILD_BUG_ON(cnt - label3 != OFFSET3);
*pprog = prog; *pprog = prog;
} }
static void emit_bpf_tail_call_direct(struct bpf_jit_poke_descriptor *poke, static void emit_bpf_tail_call_direct(struct bpf_jit_poke_descriptor *poke,
u8 **pprog, int addr, u8 *image) u8 **pprog, int addr, u8 *image,
bool *callee_regs_used, u32 stack_depth)
{ {
int tcc_off = -4 - round_up(stack_depth, 8);
u8 *prog = *pprog; u8 *prog = *pprog;
int pop_bytes = 0;
int off1 = 27;
int poke_off;
int cnt = 0; int cnt = 0;
/* count the additional bytes used for popping callee regs to stack
* that need to be taken into account for jump offset that is used for
* bailing out from of the tail call when limit is reached
*/
pop_bytes = get_pop_bytes(callee_regs_used);
off1 += pop_bytes;
/*
* total bytes for:
* - nop5/ jmpq $off
* - pop callee regs
* - sub rsp, $val
* - pop rax
*/
poke_off = X86_PATCH_SIZE + pop_bytes + 7 + 1;
/* /*
* if (tail_call_cnt > MAX_TAIL_CALL_CNT) * if (tail_call_cnt > MAX_TAIL_CALL_CNT)
* goto out; * goto out;
*/ */
EMIT2_off32(0x8B, 0x85, -36 - MAX_BPF_STACK); /* mov eax, dword ptr [rbp - 548] */ EMIT2_off32(0x8B, 0x85, tcc_off); /* mov eax, dword ptr [rbp - tcc_off] */
EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT); /* cmp eax, MAX_TAIL_CALL_CNT */ EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT); /* cmp eax, MAX_TAIL_CALL_CNT */
EMIT2(X86_JA, 14); /* ja out */ EMIT2(X86_JA, off1); /* ja out */
EMIT3(0x83, 0xC0, 0x01); /* add eax, 1 */ EMIT3(0x83, 0xC0, 0x01); /* add eax, 1 */
EMIT2_off32(0x89, 0x85, -36 - MAX_BPF_STACK); /* mov dword ptr [rbp -548], eax */ EMIT2_off32(0x89, 0x85, tcc_off); /* mov dword ptr [rbp - tcc_off], eax */
poke->tailcall_bypass = image + (addr - poke_off - X86_PATCH_SIZE);
poke->adj_off = X86_TAIL_CALL_OFFSET;
poke->tailcall_target = image + (addr - X86_PATCH_SIZE); poke->tailcall_target = image + (addr - X86_PATCH_SIZE);
poke->adj_off = PROLOGUE_SIZE; poke->bypass_addr = (u8 *)poke->tailcall_target + X86_PATCH_SIZE;
emit_jump(&prog, (u8 *)poke->tailcall_target + X86_PATCH_SIZE,
poke->tailcall_bypass);
*pprog = prog;
pop_callee_regs(pprog, callee_regs_used);
prog = *pprog;
EMIT1(0x58); /* pop rax */
EMIT3_off32(0x48, 0x81, 0xC4, round_up(stack_depth, 8));
memcpy(prog, ideal_nops[NOP_ATOMIC5], X86_PATCH_SIZE); memcpy(prog, ideal_nops[NOP_ATOMIC5], X86_PATCH_SIZE);
prog += X86_PATCH_SIZE; prog += X86_PATCH_SIZE;
...@@ -476,6 +574,11 @@ static void bpf_tail_call_direct_fixup(struct bpf_prog *prog) ...@@ -476,6 +574,11 @@ static void bpf_tail_call_direct_fixup(struct bpf_prog *prog)
(u8 *)target->bpf_func + (u8 *)target->bpf_func +
poke->adj_off, false); poke->adj_off, false);
BUG_ON(ret < 0); BUG_ON(ret < 0);
ret = __bpf_arch_text_poke(poke->tailcall_bypass,
BPF_MOD_JUMP,
(u8 *)poke->tailcall_target +
X86_PATCH_SIZE, NULL, false);
BUG_ON(ret < 0);
} }
WRITE_ONCE(poke->tailcall_target_stable, true); WRITE_ONCE(poke->tailcall_target_stable, true);
mutex_unlock(&array->aux->poke_mutex); mutex_unlock(&array->aux->poke_mutex);
...@@ -654,19 +757,49 @@ static bool ex_handler_bpf(const struct exception_table_entry *x, ...@@ -654,19 +757,49 @@ static bool ex_handler_bpf(const struct exception_table_entry *x,
return true; return true;
} }
static void detect_reg_usage(struct bpf_insn *insn, int insn_cnt,
bool *regs_used, bool *tail_call_seen)
{
int i;
for (i = 1; i <= insn_cnt; i++, insn++) {
if (insn->code == (BPF_JMP | BPF_TAIL_CALL))
*tail_call_seen = true;
if (insn->dst_reg == BPF_REG_6 || insn->src_reg == BPF_REG_6)
regs_used[0] = true;
if (insn->dst_reg == BPF_REG_7 || insn->src_reg == BPF_REG_7)
regs_used[1] = true;
if (insn->dst_reg == BPF_REG_8 || insn->src_reg == BPF_REG_8)
regs_used[2] = true;
if (insn->dst_reg == BPF_REG_9 || insn->src_reg == BPF_REG_9)
regs_used[3] = true;
}
}
static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image, static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
int oldproglen, struct jit_context *ctx) int oldproglen, struct jit_context *ctx)
{ {
bool tail_call_reachable = bpf_prog->aux->tail_call_reachable;
struct bpf_insn *insn = bpf_prog->insnsi; struct bpf_insn *insn = bpf_prog->insnsi;
bool callee_regs_used[4] = {};
int insn_cnt = bpf_prog->len; int insn_cnt = bpf_prog->len;
bool tail_call_seen = false;
bool seen_exit = false; bool seen_exit = false;
u8 temp[BPF_MAX_INSN_SIZE + BPF_INSN_SAFETY]; u8 temp[BPF_MAX_INSN_SIZE + BPF_INSN_SAFETY];
int i, cnt = 0, excnt = 0; int i, cnt = 0, excnt = 0;
int proglen = 0; int proglen = 0;
u8 *prog = temp; u8 *prog = temp;
detect_reg_usage(insn, insn_cnt, callee_regs_used,
&tail_call_seen);
/* tail call's presence in current prog implies it is reachable */
tail_call_reachable |= tail_call_seen;
emit_prologue(&prog, bpf_prog->aux->stack_depth, emit_prologue(&prog, bpf_prog->aux->stack_depth,
bpf_prog_was_classic(bpf_prog)); bpf_prog_was_classic(bpf_prog), tail_call_reachable,
bpf_prog->aux->func_idx != 0);
push_callee_regs(&prog, callee_regs_used);
addrs[0] = prog - temp; addrs[0] = prog - temp;
for (i = 1; i <= insn_cnt; i++, insn++) { for (i = 1; i <= insn_cnt; i++, insn++) {
...@@ -1104,16 +1237,27 @@ xadd: if (is_imm8(insn->off)) ...@@ -1104,16 +1237,27 @@ xadd: if (is_imm8(insn->off))
/* call */ /* call */
case BPF_JMP | BPF_CALL: case BPF_JMP | BPF_CALL:
func = (u8 *) __bpf_call_base + imm32; func = (u8 *) __bpf_call_base + imm32;
if (!imm32 || emit_call(&prog, func, image + addrs[i - 1])) if (tail_call_reachable) {
return -EINVAL; EMIT3_off32(0x48, 0x8B, 0x85,
-(bpf_prog->aux->stack_depth + 8));
if (!imm32 || emit_call(&prog, func, image + addrs[i - 1] + 7))
return -EINVAL;
} else {
if (!imm32 || emit_call(&prog, func, image + addrs[i - 1]))
return -EINVAL;
}
break; break;
case BPF_JMP | BPF_TAIL_CALL: case BPF_JMP | BPF_TAIL_CALL:
if (imm32) if (imm32)
emit_bpf_tail_call_direct(&bpf_prog->aux->poke_tab[imm32 - 1], emit_bpf_tail_call_direct(&bpf_prog->aux->poke_tab[imm32 - 1],
&prog, addrs[i], image); &prog, addrs[i], image,
callee_regs_used,
bpf_prog->aux->stack_depth);
else else
emit_bpf_tail_call_indirect(&prog); emit_bpf_tail_call_indirect(&prog,
callee_regs_used,
bpf_prog->aux->stack_depth);
break; break;
/* cond jump */ /* cond jump */
...@@ -1296,12 +1440,9 @@ xadd: if (is_imm8(insn->off)) ...@@ -1296,12 +1440,9 @@ xadd: if (is_imm8(insn->off))
seen_exit = true; seen_exit = true;
/* Update cleanup_addr */ /* Update cleanup_addr */
ctx->cleanup_addr = proglen; ctx->cleanup_addr = proglen;
if (!bpf_prog_was_classic(bpf_prog)) pop_callee_regs(&prog, callee_regs_used);
EMIT1(0x5B); /* get rid of tail_call_cnt */ if (tail_call_reachable)
EMIT2(0x41, 0x5F); /* pop r15 */ EMIT1(0x59); /* pop rcx, get rid of tail_call_cnt */
EMIT2(0x41, 0x5E); /* pop r14 */
EMIT2(0x41, 0x5D); /* pop r13 */
EMIT1(0x5B); /* pop rbx */
EMIT1(0xC9); /* leave */ EMIT1(0xC9); /* leave */
EMIT1(0xC3); /* ret */ EMIT1(0xC3); /* ret */
break; break;
......
...@@ -698,6 +698,8 @@ enum bpf_jit_poke_reason { ...@@ -698,6 +698,8 @@ enum bpf_jit_poke_reason {
/* Descriptor of pokes pointing /into/ the JITed image. */ /* Descriptor of pokes pointing /into/ the JITed image. */
struct bpf_jit_poke_descriptor { struct bpf_jit_poke_descriptor {
void *tailcall_target; void *tailcall_target;
void *tailcall_bypass;
void *bypass_addr;
union { union {
struct { struct {
struct bpf_map *map; struct bpf_map *map;
...@@ -738,6 +740,7 @@ struct bpf_prog_aux { ...@@ -738,6 +740,7 @@ struct bpf_prog_aux {
bool attach_btf_trace; /* true if attaching to BTF-enabled raw tp */ bool attach_btf_trace; /* true if attaching to BTF-enabled raw tp */
bool func_proto_unreliable; bool func_proto_unreliable;
bool sleepable; bool sleepable;
bool tail_call_reachable;
enum bpf_tramp_prog_type trampoline_prog_type; enum bpf_tramp_prog_type trampoline_prog_type;
struct bpf_trampoline *trampoline; struct bpf_trampoline *trampoline;
struct hlist_node tramp_hlist; struct hlist_node tramp_hlist;
......
...@@ -359,6 +359,7 @@ struct bpf_subprog_info { ...@@ -359,6 +359,7 @@ struct bpf_subprog_info {
u32 linfo_idx; /* The idx to the main_prog->aux->linfo */ u32 linfo_idx; /* The idx to the main_prog->aux->linfo */
u16 stack_depth; /* max. stack depth used by this function */ u16 stack_depth; /* max. stack depth used by this function */
bool has_tail_call; bool has_tail_call;
bool tail_call_reachable;
}; };
/* single container for all structs /* single container for all structs
......
...@@ -898,6 +898,7 @@ static void prog_array_map_poke_run(struct bpf_map *map, u32 key, ...@@ -898,6 +898,7 @@ static void prog_array_map_poke_run(struct bpf_map *map, u32 key,
struct bpf_prog *old, struct bpf_prog *old,
struct bpf_prog *new) struct bpf_prog *new)
{ {
u8 *old_addr, *new_addr, *old_bypass_addr;
struct prog_poke_elem *elem; struct prog_poke_elem *elem;
struct bpf_array_aux *aux; struct bpf_array_aux *aux;
...@@ -949,12 +950,39 @@ static void prog_array_map_poke_run(struct bpf_map *map, u32 key, ...@@ -949,12 +950,39 @@ static void prog_array_map_poke_run(struct bpf_map *map, u32 key,
poke->tail_call.key != key) poke->tail_call.key != key)
continue; continue;
ret = bpf_arch_text_poke(poke->tailcall_target, BPF_MOD_JUMP, old_bypass_addr = old ? NULL : poke->bypass_addr;
old ? (u8 *)old->bpf_func + old_addr = old ? (u8 *)old->bpf_func + poke->adj_off : NULL;
poke->adj_off : NULL, new_addr = new ? (u8 *)new->bpf_func + poke->adj_off : NULL;
new ? (u8 *)new->bpf_func +
poke->adj_off : NULL); if (new) {
BUG_ON(ret < 0 && ret != -EINVAL); ret = bpf_arch_text_poke(poke->tailcall_target,
BPF_MOD_JUMP,
old_addr, new_addr);
BUG_ON(ret < 0 && ret != -EINVAL);
if (!old) {
ret = bpf_arch_text_poke(poke->tailcall_bypass,
BPF_MOD_JUMP,
poke->bypass_addr,
NULL);
BUG_ON(ret < 0 && ret != -EINVAL);
}
} else {
ret = bpf_arch_text_poke(poke->tailcall_bypass,
BPF_MOD_JUMP,
old_bypass_addr,
poke->bypass_addr);
BUG_ON(ret < 0 && ret != -EINVAL);
/* let other CPUs finish the execution of program
* so that it will not possible to expose them
* to invalid nop, stack unwind, nop state
*/
if (!ret)
synchronize_rcu();
ret = bpf_arch_text_poke(poke->tailcall_target,
BPF_MOD_JUMP,
old_addr, NULL);
BUG_ON(ret < 0 && ret != -EINVAL);
}
} }
} }
} }
......
...@@ -776,7 +776,7 @@ int bpf_jit_add_poke_descriptor(struct bpf_prog *prog, ...@@ -776,7 +776,7 @@ int bpf_jit_add_poke_descriptor(struct bpf_prog *prog,
if (size > poke_tab_max) if (size > poke_tab_max)
return -ENOSPC; return -ENOSPC;
if (poke->tailcall_target || poke->tailcall_target_stable || if (poke->tailcall_target || poke->tailcall_target_stable ||
poke->adj_off) poke->tailcall_bypass || poke->adj_off || poke->bypass_addr)
return -EINVAL; return -EINVAL;
switch (poke->reason) { switch (poke->reason) {
......
...@@ -2983,8 +2983,10 @@ static int check_max_stack_depth(struct bpf_verifier_env *env) ...@@ -2983,8 +2983,10 @@ static int check_max_stack_depth(struct bpf_verifier_env *env)
int depth = 0, frame = 0, idx = 0, i = 0, subprog_end; int depth = 0, frame = 0, idx = 0, i = 0, subprog_end;
struct bpf_subprog_info *subprog = env->subprog_info; struct bpf_subprog_info *subprog = env->subprog_info;
struct bpf_insn *insn = env->prog->insnsi; struct bpf_insn *insn = env->prog->insnsi;
bool tail_call_reachable = false;
int ret_insn[MAX_CALL_FRAMES]; int ret_insn[MAX_CALL_FRAMES];
int ret_prog[MAX_CALL_FRAMES]; int ret_prog[MAX_CALL_FRAMES];
int j;
process_func: process_func:
/* protect against potential stack overflow that might happen when /* protect against potential stack overflow that might happen when
...@@ -3040,6 +3042,10 @@ static int check_max_stack_depth(struct bpf_verifier_env *env) ...@@ -3040,6 +3042,10 @@ static int check_max_stack_depth(struct bpf_verifier_env *env)
i); i);
return -EFAULT; return -EFAULT;
} }
if (subprog[idx].has_tail_call)
tail_call_reachable = true;
frame++; frame++;
if (frame >= MAX_CALL_FRAMES) { if (frame >= MAX_CALL_FRAMES) {
verbose(env, "the call stack of %d frames is too deep !\n", verbose(env, "the call stack of %d frames is too deep !\n",
...@@ -3048,6 +3054,15 @@ static int check_max_stack_depth(struct bpf_verifier_env *env) ...@@ -3048,6 +3054,15 @@ static int check_max_stack_depth(struct bpf_verifier_env *env)
} }
goto process_func; goto process_func;
} }
/* if tail call got detected across bpf2bpf calls then mark each of the
* currently present subprog frames as tail call reachable subprogs;
* this info will be utilized by JIT so that we will be preserving the
* tail call counter throughout bpf2bpf calls combined with tailcalls
*/
if (tail_call_reachable)
for (j = 0; j < frame; j++)
subprog[ret_prog[j]].tail_call_reachable = true;
/* end of for() loop means the last insn of the 'subprog' /* end of for() loop means the last insn of the 'subprog'
* was reached. Doesn't matter whether it was JA or EXIT * was reached. Doesn't matter whether it was JA or EXIT
*/ */
...@@ -10322,6 +10337,7 @@ static int jit_subprogs(struct bpf_verifier_env *env) ...@@ -10322,6 +10337,7 @@ static int jit_subprogs(struct bpf_verifier_env *env)
num_exentries++; num_exentries++;
} }
func[i]->aux->num_exentries = num_exentries; func[i]->aux->num_exentries = num_exentries;
func[i]->aux->tail_call_reachable = env->subprog_info[i].tail_call_reachable;
func[i] = bpf_int_jit_compile(func[i]); func[i] = bpf_int_jit_compile(func[i]);
if (!func[i]->jited) { if (!func[i]->jited) {
err = -ENOTSUPP; err = -ENOTSUPP;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment