[x86] Enable indirect tail calls with more arguments #137643

Open · wants to merge 9 commits into main
98 changes: 85 additions & 13 deletions llvm/lib/Target/X86/X86ISelDAGToDAG.cpp

@@ -618,6 +618,7 @@ namespace {
   bool onlyUsesZeroFlag(SDValue Flags) const;
   bool hasNoSignFlagUses(SDValue Flags) const;
   bool hasNoCarryFlagUses(SDValue Flags) const;
+  bool checkTCRetRegUsage(SDNode *N, LoadSDNode *Load) const;
 };
 
 class X86DAGToDAGISelLegacy : public SelectionDAGISelLegacy {
@@ -890,27 +891,52 @@ static bool isCalleeLoad(SDValue Callee, SDValue &Chain, bool HasCallSeq) {
       LD->getExtensionType() != ISD::NON_EXTLOAD)
     return false;
 
+  // If the load's outgoing chain has more than one use, we can't (currently)
+  // move the load since we'd most likely create a loop. TODO: Maybe it could
+  // work if moveBelowOrigChain() updated *all* the chain users.
+  if (!Callee.getValue(1).hasOneUse())
+    return false;
+
   // Now let's find the callseq_start.
   while (HasCallSeq && Chain.getOpcode() != ISD::CALLSEQ_START) {
     if (!Chain.hasOneUse())
       return false;
     Chain = Chain.getOperand(0);
   }
 
-  if (!Chain.getNumOperands())
-    return false;
-  // Since we are not checking for AA here, conservatively abort if the chain
-  // writes to memory. It's not safe to move the callee (a load) across a store.
-  if (isa<MemSDNode>(Chain.getNode()) &&
-      cast<MemSDNode>(Chain.getNode())->writeMem())
-    return false;
-  if (Chain.getOperand(0).getNode() == Callee.getNode())
-    return true;
-  if (Chain.getOperand(0).getOpcode() == ISD::TokenFactor &&
-      Callee.getValue(1).isOperandOf(Chain.getOperand(0).getNode()) &&
-      Callee.getValue(1).hasOneUse())
-    return true;
-  return false;
+  while (true) {
+    if (!Chain.getNumOperands())
+      return false;
+
+    // It's not safe to move the callee (a load) across e.g. a store.
+    // Conservatively abort if the chain contains a node other than the ones
+    // below.
+    switch (Chain.getNode()->getOpcode()) {
+    case ISD::CALLSEQ_START:
+    case ISD::CopyToReg:
+    case ISD::LOAD:
+      break;
+    default:
+      return false;
+    }
+
+    if (Chain.getOperand(0).getNode() == Callee.getNode())
+      return true;
+    if (Chain.getOperand(0).getOpcode() == ISD::TokenFactor &&
+        Chain.getOperand(0).getValue(0).hasOneUse() &&
+        Callee.getValue(1).isOperandOf(Chain.getOperand(0).getNode()) &&
+        Callee.getValue(1).hasOneUse())
+      return true;
+
+    // Look past CopyToRegs. We only walk one path, so the chain mustn't branch.
+    if (Chain.getOperand(0).getOpcode() == ISD::CopyToReg &&
+        Chain.getOperand(0).getValue(0).hasOneUse()) {
+      Chain = Chain.getOperand(0);
+      continue;
+    }
+
+    return false;
+  }
 }
 
 static bool isEndbrImm64(uint64_t Imm) {

Review thread (on the switch's default case):

    Collaborator: Please just allow specific nodes, and forbid anything
    unknown. Trying to list out every possible relevant node is guaranteed
    to fall out of date at some point, even if you manage to come up with a
    complete list.

    Author: Done.
@@ -1353,6 +1379,11 @@ void X86DAGToDAGISel::PreprocessISelDAG() {
          (N->getOpcode() == X86ISD::TC_RETURN &&
           (Subtarget->is64Bit() ||
            !getTargetMachine().isPositionIndependent())))) {
+
+      if (N->getOpcode() == X86ISD::TC_RETURN &&
+          !checkTCRetRegUsage(N, nullptr))
+        continue;
+
       /// Also try moving call address load from outside callseq_start to just
       /// before the call to allow it to be folded.
       ///
@@ -3489,6 +3520,47 @@ static bool mayUseCarryFlag(X86::CondCode CC) {
   return true;
 }
 
+bool X86DAGToDAGISel::checkTCRetRegUsage(SDNode *N, LoadSDNode *Load) const {
+  const X86RegisterInfo *RI = Subtarget->getRegisterInfo();
+  const TargetRegisterClass *TailCallGPRs = RI->getGPRsForTailCall(*MF);
+  unsigned MaxGPRs = TailCallGPRs->getNumRegs();
+  if (Subtarget->is64Bit()) {
+    assert(TailCallGPRs->contains(X86::RSP));
+    assert(TailCallGPRs->contains(X86::RIP));
+    MaxGPRs -= 2; // Can't use RSP or RIP for the address in general.
+  } else {
+    assert(TailCallGPRs->contains(X86::ESP));
+    MaxGPRs -= 1; // Can't use ESP for the address in general.
+  }
+
+  // The load's base and index potentially need two registers.
+  unsigned LoadGPRs = 2;
+
+  if (Load) {
+    // But not if it's loading from a frame slot or global.
+    // XXX: Couldn't we be indexing off of the global though?
+    const SDValue &BasePtr = Load->getBasePtr();
+    if (isa<FrameIndexSDNode>(BasePtr)) {
+      LoadGPRs = 0;
+    } else if (BasePtr->getNumOperands() &&
+               isa<GlobalAddressSDNode>(BasePtr->getOperand(0)))
+      LoadGPRs = 0;
+  }
+
+  unsigned TCGPRs = 0;
+  // X86tcret args: (*chain, ptr, imm, regs..., glue)
+  for (unsigned I = 3, E = N->getNumOperands(); I != E; ++I) {
+    if (const auto *RN = dyn_cast<RegisterSDNode>(N->getOperand(I))) {
+      if (!RI->isGeneralPurposeRegister(*MF, RN->getReg()))
+        continue;
+      if (++TCGPRs + LoadGPRs > MaxGPRs)
+        return false;
+    }
+  }
+
+  return true;
+}
+
 /// Check whether or not the chain ending in StoreNode is suitable for doing
 /// the {load; op; store} to modify transformation.
 static bool isFusableLoadOpStorePattern(StoreSDNode *StoreNode,
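For intuition, the budget that checkTCRetRegUsage enforces can be sketched standalone (a minimal sketch, not LLVM code; the register counts are illustrative assumptions rather than the real contents of the tail-call register class):

```cpp
#include <cassert>

// Sketch of the budget check: a folded callee load may need base + index
// GPRs, and those must still be free after the call's GPR arguments are
// accounted for. On 64-bit, RSP and RIP are excluded up front.
static bool fitsTailCallBudget(unsigned NumTailCallGPRs, unsigned NumGPRArgs,
                               unsigned LoadGPRs) {
  unsigned MaxGPRs = NumTailCallGPRs - 2; // can't use RSP or RIP
  return NumGPRArgs + LoadGPRs <= MaxGPRs;
}

int main() {
  // Assuming a hypothetical 9-register tail-call class: five GPR arguments
  // leave room for a base+index address, a sixth does not.
  assert(fitsTailCallBudget(9, 5, 2));
  assert(!fitsTailCallBudget(9, 6, 2));
  // A load from a frame slot or a global needs no extra registers, so more
  // arguments fit.
  assert(fitsTailCallBudget(9, 6, 0));
}
```

The real function additionally skips non-GPR argument registers and takes the counts from getGPRsForTailCall, as shown in the hunk above.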
19 changes: 2 additions & 17 deletions llvm/lib/Target/X86/X86InstrFragments.td

@@ -675,27 +675,12 @@ def X86lock_sub_nocf : PatFrag<(ops node:$lhs, node:$rhs),

 def X86tcret_6regs : PatFrag<(ops node:$ptr, node:$off),
                              (X86tcret node:$ptr, node:$off), [{
-  // X86tcret args: (*chain, ptr, imm, regs..., glue)
-  unsigned NumRegs = 0;
-  for (unsigned i = 3, e = N->getNumOperands(); i != e; ++i)
-    if (isa<RegisterSDNode>(N->getOperand(i)) && ++NumRegs > 6)
-      return false;
-  return true;
+  return checkTCRetRegUsage(N, nullptr);
 }]>;
 
 def X86tcret_1reg : PatFrag<(ops node:$ptr, node:$off),
                             (X86tcret node:$ptr, node:$off), [{
-  // X86tcret args: (*chain, ptr, imm, regs..., glue)
-  unsigned NumRegs = 1;
-  const SDValue& BasePtr = cast<LoadSDNode>(N->getOperand(1))->getBasePtr();
-  if (isa<FrameIndexSDNode>(BasePtr))
-    NumRegs = 3;
-  else if (BasePtr->getNumOperands() && isa<GlobalAddressSDNode>(BasePtr->getOperand(0)))
-    NumRegs = 3;
-  for (unsigned i = 3, e = N->getNumOperands(); i != e; ++i)
-    if (isa<RegisterSDNode>(N->getOperand(i)) && ( NumRegs-- == 0))
-      return false;
-  return true;
+  return checkTCRetRegUsage(N, cast<LoadSDNode>(N->getOperand(1)));
 }]>;
 
 // If this is an anyext of the remainder of an 8-bit sdivrem, use a MOVSX
3 changes: 1 addition & 2 deletions llvm/test/CodeGen/X86/cfguard-checks.ll

@@ -210,8 +210,7 @@ entry:
 ; X64-LABEL: vmptr_thunk:
 ; X64:       movq (%rcx), %rax
 ; X64-NEXT:  movq 8(%rax), %rax
-; X64-NEXT:  movq __guard_dispatch_icall_fptr(%rip), %rdx
-; X64-NEXT:  rex64 jmpq *%rdx # TAILCALL
+; X64-NEXT:  rex64 jmpq *__guard_dispatch_icall_fptr(%rip) # TAILCALL
 ; X64-NOT:   callq
 }

26 changes: 26 additions & 0 deletions llvm/test/CodeGen/X86/fold-call-4.ll

@@ -0,0 +1,26 @@
+; RUN: llc < %s -mtriple=x86_64-unknown-linux-gnu | FileCheck %s --check-prefixes=CHECK,LIN
+; RUN: llc < %s -mtriple=x86_64-pc-windows-msvc | FileCheck %s --check-prefixes=CHECK,WIN
+
+; The callee address computation should get folded into the call.
+; CHECK-LABEL: f:
+; CHECK-NOT: mov
+; LIN: jmpq *(%rdi,%rsi,8)
+; WIN: rex64 jmpq *(%rcx,%rdx,8)
+define void @f(ptr %table, i64 %idx, i64 %aux1, i64 %aux2, i64 %aux3) {
+entry:
+  %arrayidx = getelementptr inbounds ptr, ptr %table, i64 %idx
+  %funcptr = load ptr, ptr %arrayidx, align 8
+  tail call void %funcptr(ptr %table, i64 %idx, i64 %aux1, i64 %aux2, i64 %aux3)
+  ret void
+}
+
+; Check that we don't assert here. On Win64 this has a TokenFactor with
+; multiple uses, which we can't currently fold.
+define void @thunk(ptr %this, ...) {
+entry:
+  %vtable = load ptr, ptr %this, align 8
+  %vfn = getelementptr inbounds nuw i8, ptr %vtable, i64 8
+  %0 = load ptr, ptr %vfn, align 8
+  musttail call void (ptr, ...) %0(ptr %this, ...)
+  ret void
+}
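For reference, a hypothetical C++ analogue of @f (names invented for illustration, compiled at -O2) that produces IR of this shape:

```cpp
// Indirect tail call through a function-pointer table. On SysV x86-64 all
// five integer arguments travel in GPRs, which is what makes the register
// budget interesting; with this patch the table load folds into the jmp.
using Fn = void (*)(void **table, long idx, long a, long b, long c);

void f(void **table, long idx, long a, long b, long c) {
  auto fn = reinterpret_cast<Fn>(table[idx]); // load the callee from the table
  fn(table, idx, a, b, c);                    // tail call; load folds into jmp
}
```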
12 changes: 12 additions & 0 deletions llvm/test/CodeGen/X86/fold-call.ll

@@ -24,3 +24,15 @@ entry:
   tail call void %0()
   ret void
 }
+
+; Don't fold the load+call if there's inline asm in between.
+; CHECK: test3
+; CHECK: mov{{.*}}
+; CHECK: jmp{{.*}}
+define void @test3(ptr nocapture %x) {
+entry:
+  %0 = load ptr, ptr %x
+  call void asm sideeffect "", ""() ; It could do anything.
+  tail call void %0()
+  ret void
+}
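And a hypothetical C++ analogue of test3 (names invented for illustration):

```cpp
// The empty asm is declared with side effects, so the compiler must assume
// it could do anything, including modifying *x; the load therefore cannot
// be moved past it and folded into the jmp.
using Fn = void (*)();

void test3(Fn *x) {
  Fn f = *x;        // load the callee first
  asm volatile(""); // it could do anything
  f();              // tail call through the already-loaded pointer
}
```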