Skip to content

release/20.x: [AArch64][SME] Prevent spills of ZT0 when ZA is not enabled #137683

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: release/20.x
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions llvm/lib/IR/Verifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2818,6 +2818,9 @@ void Verifier::visitFunction(const Function &F) {
Check(!Attrs.hasAttrSomewhere(Attribute::ElementType),
"Attribute 'elementtype' can only be applied to a callsite.", &F);

Check(!Attrs.hasFnAttr("aarch64_zt0_undef"),
"Attribute 'aarch64_zt0_undef' can only be applied to a callsite.");

if (Attrs.hasFnAttr(Attribute::Naked))
for (const Argument &Arg : F.args())
Check(Arg.use_empty(), "cannot use argument of naked function", &Arg);
Expand Down
16 changes: 12 additions & 4 deletions llvm/lib/Target/AArch64/SMEABIPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,22 @@ FunctionPass *llvm::createSMEABIPass() { return new SMEABI(); }
//===----------------------------------------------------------------------===//

// Utility function to emit a call to __arm_tpidr2_save and clear TPIDR2_EL0.
void emitTPIDR2Save(Module *M, IRBuilder<> &Builder) {
void emitTPIDR2Save(Module *M, IRBuilder<> &Builder, bool ZT0IsUndef = false) {
auto &Ctx = M->getContext();
auto *TPIDR2SaveTy =
FunctionType::get(Builder.getVoidTy(), {}, /*IsVarArgs=*/false);
auto Attrs = AttributeList().addFnAttribute(M->getContext(),
"aarch64_pstate_sm_compatible");
auto Attrs =
AttributeList().addFnAttribute(Ctx, "aarch64_pstate_sm_compatible");
FunctionCallee Callee =
M->getOrInsertFunction("__arm_tpidr2_save", TPIDR2SaveTy, Attrs);
CallInst *Call = Builder.CreateCall(Callee);

// If ZT0 is undefined (i.e. we're at the entry of a "new_zt0" function), mark
// that on the __arm_tpidr2_save call. This prevents an unnecessary spill of
// ZT0 that can occur before ZA is enabled.
if (ZT0IsUndef)
Call->addFnAttr(Attribute::get(Ctx, "aarch64_zt0_undef"));

Call->setCallingConv(
CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0);

Expand Down Expand Up @@ -119,7 +127,7 @@ bool SMEABI::updateNewStateFunctions(Module *M, Function *F,

// Create a call __arm_tpidr2_save, which commits the lazy save.
Builder.SetInsertPoint(&SaveBB->back());
emitTPIDR2Save(M, Builder);
emitTPIDR2Save(M, Builder, /*ZT0IsUndef=*/FnAttrs.isNewZT0());

// Enable pstate.za at the start of the function.
Builder.SetInsertPoint(&OrigBB->front());
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ SMEAttrs::SMEAttrs(const AttributeList &Attrs) {
Bitmask |= SM_Body;
if (Attrs.hasFnAttr("aarch64_za_state_agnostic"))
Bitmask |= ZA_State_Agnostic;
if (Attrs.hasFnAttr("aarch64_zt0_undef"))
Bitmask |= ZT0_Undef;
if (Attrs.hasFnAttr("aarch64_in_za"))
Bitmask |= encodeZAState(StateValue::In);
if (Attrs.hasFnAttr("aarch64_out_za"))
Expand Down
8 changes: 5 additions & 3 deletions llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,10 @@ class SMEAttrs {
SM_Body = 1 << 2, // aarch64_pstate_sm_body
SME_ABI_Routine = 1 << 3, // Used for SME ABI routines to avoid lazy saves
ZA_State_Agnostic = 1 << 4,
ZA_Shift = 5,
ZT0_Undef = 1 << 5, // Use to mark ZT0 as undef to avoid spills
ZA_Shift = 6,
ZA_Mask = 0b111 << ZA_Shift,
ZT0_Shift = 8,
ZT0_Shift = 9,
ZT0_Mask = 0b111 << ZT0_Shift
};

Expand Down Expand Up @@ -125,14 +126,15 @@ class SMEAttrs {
bool isPreservesZT0() const {
return decodeZT0State(Bitmask) == StateValue::Preserved;
}
bool isUndefZT0() const { return Bitmask & ZT0_Undef; }
bool sharesZT0() const {
StateValue State = decodeZT0State(Bitmask);
return State == StateValue::In || State == StateValue::Out ||
State == StateValue::InOut || State == StateValue::Preserved;
}
bool hasZT0State() const { return isNewZT0() || sharesZT0(); }
bool requiresPreservingZT0(const SMEAttrs &Callee) const {
return hasZT0State() && !Callee.sharesZT0() &&
return hasZT0State() && !Callee.isUndefZT0() && !Callee.sharesZT0() &&
!Callee.hasAgnosticZAInterface();
}
bool requiresDisablingZABeforeCall(const SMEAttrs &Callee) const {
Expand Down
9 changes: 2 additions & 7 deletions llvm/test/CodeGen/AArch64/sme-disable-gisel-fisel.ll
Original file line number Diff line number Diff line change
Expand Up @@ -475,16 +475,12 @@ declare double @zt0_shared_callee(double) "aarch64_inout_zt0"
define double @zt0_new_caller_to_zt0_shared_callee(double %x) nounwind noinline optnone "aarch64_new_zt0" {
; CHECK-COMMON-LABEL: zt0_new_caller_to_zt0_shared_callee:
; CHECK-COMMON: // %bb.0: // %prelude
; CHECK-COMMON-NEXT: sub sp, sp, #80
; CHECK-COMMON-NEXT: str x30, [sp, #64] // 8-byte Folded Spill
; CHECK-COMMON-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
; CHECK-COMMON-NEXT: mrs x8, TPIDR2_EL0
; CHECK-COMMON-NEXT: cbz x8, .LBB13_2
; CHECK-COMMON-NEXT: b .LBB13_1
; CHECK-COMMON-NEXT: .LBB13_1: // %save.za
; CHECK-COMMON-NEXT: mov x8, sp
; CHECK-COMMON-NEXT: str zt0, [x8]
; CHECK-COMMON-NEXT: bl __arm_tpidr2_save
; CHECK-COMMON-NEXT: ldr zt0, [x8]
; CHECK-COMMON-NEXT: msr TPIDR2_EL0, xzr
; CHECK-COMMON-NEXT: b .LBB13_2
; CHECK-COMMON-NEXT: .LBB13_2: // %entry
Expand All @@ -495,8 +491,7 @@ define double @zt0_new_caller_to_zt0_shared_callee(double %x) nounwind noinline
; CHECK-COMMON-NEXT: fmov d1, x8
; CHECK-COMMON-NEXT: fadd d0, d0, d1
; CHECK-COMMON-NEXT: smstop za
; CHECK-COMMON-NEXT: ldr x30, [sp, #64] // 8-byte Folded Reload
; CHECK-COMMON-NEXT: add sp, sp, #80
; CHECK-COMMON-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
; CHECK-COMMON-NEXT: ret
entry:
%call = call double @zt0_shared_callee(double %x)
Expand Down
14 changes: 14 additions & 0 deletions llvm/test/CodeGen/AArch64/sme-new-zt0-function.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
; RUN: opt -S -mtriple=aarch64-linux-gnu -aarch64-sme-abi %s | FileCheck %s

declare void @callee();

define void @private_za() "aarch64_new_zt0" {
call void @callee()
ret void
}

; CHECK: call aarch64_sme_preservemost_from_x0 void @__arm_tpidr2_save() #[[TPIDR2_SAVE_CALL_ATTR:[0-9]+]]
; CHECK: declare void @__arm_tpidr2_save() #[[TPIDR2_SAVE_DECL_ATTR:[0-9]+]]

; CHECK: attributes #[[TPIDR2_SAVE_DECL_ATTR]] = { "aarch64_pstate_sm_compatible" }
; CHECK: attributes #[[TPIDR2_SAVE_CALL_ATTR]] = { "aarch64_zt0_undef" }
94 changes: 75 additions & 19 deletions llvm/test/CodeGen/AArch64/sme-zt0-state.ll
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ define void @za_zt0_shared_caller_za_zt0_shared_callee() "aarch64_inout_za" "aar
ret void;
}

; New-ZA Callee
; New-ZT0 Callee

; Expect spill & fill of ZT0 around call
; Expect smstop/smstart za around call
Expand All @@ -134,6 +134,72 @@ define void @zt0_in_caller_zt0_new_callee() "aarch64_in_zt0" nounwind {
ret void;
}

; New-ZT0 Callee

; Expect commit of lazy-save if ZA is dormant
; Expect smstart ZA & clear ZT0
; Expect spill & fill of ZT0 around call
; Before return, expect smstop ZA
define void @zt0_new_caller_zt0_new_callee() "aarch64_new_zt0" nounwind {
; CHECK-LABEL: zt0_new_caller_zt0_new_callee:
; CHECK: // %bb.0: // %prelude
; CHECK-NEXT: sub sp, sp, #80
; CHECK-NEXT: stp x30, x19, [sp, #64] // 16-byte Folded Spill
; CHECK-NEXT: mrs x8, TPIDR2_EL0
; CHECK-NEXT: cbz x8, .LBB6_2
; CHECK-NEXT: // %bb.1: // %save.za
; CHECK-NEXT: bl __arm_tpidr2_save
; CHECK-NEXT: msr TPIDR2_EL0, xzr
; CHECK-NEXT: .LBB6_2:
; CHECK-NEXT: smstart za
; CHECK-NEXT: zero { zt0 }
; CHECK-NEXT: mov x19, sp
; CHECK-NEXT: str zt0, [x19]
; CHECK-NEXT: smstop za
; CHECK-NEXT: bl callee
; CHECK-NEXT: smstart za
; CHECK-NEXT: ldr zt0, [x19]
; CHECK-NEXT: smstop za
; CHECK-NEXT: ldp x30, x19, [sp, #64] // 16-byte Folded Reload
; CHECK-NEXT: add sp, sp, #80
; CHECK-NEXT: ret
call void @callee() "aarch64_new_zt0";
ret void;
}

; Expect commit of lazy-save if ZA is dormant
; Expect smstart ZA & clear ZT0
; No spill & fill of ZT0 around __arm_tpidr2_save
; Expect spill & fill of ZT0 around __arm_sme_state call
; Before return, expect smstop ZA
define i64 @zt0_new_caller_abi_routine_callee() "aarch64_new_zt0" nounwind {
; CHECK-LABEL: zt0_new_caller_abi_routine_callee:
; CHECK: // %bb.0: // %prelude
; CHECK-NEXT: sub sp, sp, #80
; CHECK-NEXT: stp x30, x19, [sp, #64] // 16-byte Folded Spill
; CHECK-NEXT: mrs x8, TPIDR2_EL0
; CHECK-NEXT: cbz x8, .LBB7_2
; CHECK-NEXT: // %bb.1: // %save.za
; CHECK-NEXT: bl __arm_tpidr2_save
; CHECK-NEXT: msr TPIDR2_EL0, xzr
; CHECK-NEXT: .LBB7_2:
; CHECK-NEXT: smstart za
; CHECK-NEXT: zero { zt0 }
; CHECK-NEXT: mov x19, sp
; CHECK-NEXT: str zt0, [x19]
; CHECK-NEXT: bl __arm_sme_state
; CHECK-NEXT: ldr zt0, [x19]
; CHECK-NEXT: smstop za
; CHECK-NEXT: ldp x30, x19, [sp, #64] // 16-byte Folded Reload
; CHECK-NEXT: add sp, sp, #80
; CHECK-NEXT: ret
%res = call {i64, i64} @__arm_sme_state()
%res.0 = extractvalue {i64, i64} %res, 0
ret i64 %res.0
}

declare {i64, i64} @__arm_sme_state()

;
; New-ZA Caller
;
Expand All @@ -144,23 +210,18 @@ define void @zt0_in_caller_zt0_new_callee() "aarch64_in_zt0" nounwind {
define void @zt0_new_caller() "aarch64_new_zt0" nounwind {
; CHECK-LABEL: zt0_new_caller:
; CHECK: // %bb.0: // %prelude
; CHECK-NEXT: sub sp, sp, #80
; CHECK-NEXT: str x30, [sp, #64] // 8-byte Folded Spill
; CHECK-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
; CHECK-NEXT: mrs x8, TPIDR2_EL0
; CHECK-NEXT: cbz x8, .LBB6_2
; CHECK-NEXT: cbz x8, .LBB8_2
; CHECK-NEXT: // %bb.1: // %save.za
; CHECK-NEXT: mov x8, sp
; CHECK-NEXT: str zt0, [x8]
; CHECK-NEXT: bl __arm_tpidr2_save
; CHECK-NEXT: ldr zt0, [x8]
; CHECK-NEXT: msr TPIDR2_EL0, xzr
; CHECK-NEXT: .LBB6_2:
; CHECK-NEXT: .LBB8_2:
; CHECK-NEXT: smstart za
; CHECK-NEXT: zero { zt0 }
; CHECK-NEXT: bl callee
; CHECK-NEXT: smstop za
; CHECK-NEXT: ldr x30, [sp, #64] // 8-byte Folded Reload
; CHECK-NEXT: add sp, sp, #80
; CHECK-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
; CHECK-NEXT: ret
call void @callee() "aarch64_in_zt0";
ret void;
Expand All @@ -172,24 +233,19 @@ define void @zt0_new_caller() "aarch64_new_zt0" nounwind {
define void @new_za_zt0_caller() "aarch64_new_za" "aarch64_new_zt0" nounwind {
; CHECK-LABEL: new_za_zt0_caller:
; CHECK: // %bb.0: // %prelude
; CHECK-NEXT: sub sp, sp, #80
; CHECK-NEXT: str x30, [sp, #64] // 8-byte Folded Spill
; CHECK-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
; CHECK-NEXT: mrs x8, TPIDR2_EL0
; CHECK-NEXT: cbz x8, .LBB7_2
; CHECK-NEXT: cbz x8, .LBB9_2
; CHECK-NEXT: // %bb.1: // %save.za
; CHECK-NEXT: mov x8, sp
; CHECK-NEXT: str zt0, [x8]
; CHECK-NEXT: bl __arm_tpidr2_save
; CHECK-NEXT: ldr zt0, [x8]
; CHECK-NEXT: msr TPIDR2_EL0, xzr
; CHECK-NEXT: .LBB7_2:
; CHECK-NEXT: .LBB9_2:
; CHECK-NEXT: smstart za
; CHECK-NEXT: zero {za}
; CHECK-NEXT: zero { zt0 }
; CHECK-NEXT: bl callee
; CHECK-NEXT: smstop za
; CHECK-NEXT: ldr x30, [sp, #64] // 8-byte Folded Reload
; CHECK-NEXT: add sp, sp, #80
; CHECK-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
; CHECK-NEXT: ret
call void @callee() "aarch64_inout_za" "aarch64_in_zt0";
ret void;
Expand Down
3 changes: 3 additions & 0 deletions llvm/test/Verifier/sme-attributes.ll
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,6 @@ declare void @zt0_inout_out() "aarch64_inout_zt0" "aarch64_out_zt0";

declare void @zt0_inout_agnostic() "aarch64_inout_zt0" "aarch64_za_state_agnostic";
; CHECK: Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', 'aarch64_inout_zt0', 'aarch64_preserves_zt0' and 'aarch64_za_state_agnostic' are mutually exclusive

declare void @zt0_undef_function() "aarch64_zt0_undef";
; CHECK: Attribute 'aarch64_zt0_undef' can only be applied to a callsite.
30 changes: 30 additions & 0 deletions llvm/unittests/Target/AArch64/SMEAttributesTest.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "Utils/AArch64SMEAttributes.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/SourceMgr.h"

Expand Down Expand Up @@ -69,6 +70,15 @@ TEST(SMEAttributes, Constructors) {
ASSERT_TRUE(SA(*parseIR("declare void @foo() \"aarch64_new_zt0\"")
->getFunction("foo"))
.isNewZT0());
ASSERT_TRUE(
SA(cast<CallBase>((parseIR("declare void @callee()\n"
"define void @foo() {"
"call void @callee() \"aarch64_zt0_undef\"\n"
"ret void\n}")
->getFunction("foo")
->begin()
->front())))
.isUndefZT0());

// Invalid combinations.
EXPECT_DEBUG_DEATH(SA(SA::SM_Enabled | SA::SM_Compatible),
Expand Down Expand Up @@ -215,6 +225,18 @@ TEST(SMEAttributes, Basics) {
ASSERT_FALSE(ZT0_New.hasSharedZAInterface());
ASSERT_TRUE(ZT0_New.hasPrivateZAInterface());

SA ZT0_Undef = SA(SA::ZT0_Undef | SA::encodeZT0State(SA::StateValue::New));
ASSERT_TRUE(ZT0_Undef.isNewZT0());
ASSERT_FALSE(ZT0_Undef.isInZT0());
ASSERT_FALSE(ZT0_Undef.isOutZT0());
ASSERT_FALSE(ZT0_Undef.isInOutZT0());
ASSERT_FALSE(ZT0_Undef.isPreservesZT0());
ASSERT_FALSE(ZT0_Undef.sharesZT0());
ASSERT_TRUE(ZT0_Undef.hasZT0State());
ASSERT_FALSE(ZT0_Undef.hasSharedZAInterface());
ASSERT_TRUE(ZT0_Undef.hasPrivateZAInterface());
ASSERT_TRUE(ZT0_Undef.isUndefZT0());

ASSERT_FALSE(SA(SA::Normal).isInZT0());
ASSERT_FALSE(SA(SA::Normal).isOutZT0());
ASSERT_FALSE(SA(SA::Normal).isInOutZT0());
Expand Down Expand Up @@ -285,6 +307,7 @@ TEST(SMEAttributes, Transitions) {
SA ZT0_Shared = SA(SA::encodeZT0State(SA::StateValue::In));
SA ZA_ZT0_Shared = SA(SA::encodeZAState(SA::StateValue::In) |
SA::encodeZT0State(SA::StateValue::In));
SA Undef_ZT0 = SA(SA::ZT0_Undef);

// Shared ZA -> Private ZA Interface
ASSERT_FALSE(ZA_Shared.requiresDisablingZABeforeCall(Private_ZA));
Expand All @@ -295,6 +318,13 @@ TEST(SMEAttributes, Transitions) {
ASSERT_TRUE(ZT0_Shared.requiresPreservingZT0(Private_ZA));
ASSERT_TRUE(ZT0_Shared.requiresEnablingZAAfterCall(Private_ZA));

// Shared Undef ZT0 -> Private ZA Interface
// Note: "Undef ZT0" is a callsite attribute that means ZT0 is undefined at
// point the of the call.
ASSERT_TRUE(ZT0_Shared.requiresDisablingZABeforeCall(Undef_ZT0));
ASSERT_FALSE(ZT0_Shared.requiresPreservingZT0(Undef_ZT0));
ASSERT_TRUE(ZT0_Shared.requiresEnablingZAAfterCall(Undef_ZT0));

// Shared ZA & ZT0 -> Private ZA Interface
ASSERT_FALSE(ZA_ZT0_Shared.requiresDisablingZABeforeCall(Private_ZA));
ASSERT_TRUE(ZA_ZT0_Shared.requiresPreservingZT0(Private_ZA));
Expand Down
Loading