Skip to content

[llvm] Support save/restore point splitting in shrink-wrap #119359

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
43 changes: 36 additions & 7 deletions llvm/include/llvm/CodeGen/MIRYamlMapping.h
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,24 @@ LLVM_YAML_IS_SEQUENCE_VECTOR(llvm::yaml::MachineJumpTable::Entry)
namespace llvm {
namespace yaml {

struct SRPEntry {
StringValue Point;
std::vector<StringValue> Registers;

bool operator==(const SRPEntry &Other) const {
return Point == Other.Point && Registers == Other.Registers;
}
};

using SaveRestorePoints = std::vector<SRPEntry>;

template <> struct MappingTraits<SRPEntry> {
static void mapping(IO &YamlIO, SRPEntry &Entry) {
YamlIO.mapRequired("point", Entry.Point);
YamlIO.mapRequired("registers", Entry.Registers);
}
};

template <> struct MappingTraits<MachineJumpTable> {
static void mapping(IO &YamlIO, MachineJumpTable &JT) {
YamlIO.mapRequired("kind", JT.Kind);
Expand All @@ -618,6 +636,14 @@ template <> struct MappingTraits<MachineJumpTable> {
}
};

} // namespace yaml
} // namespace llvm

LLVM_YAML_IS_SEQUENCE_VECTOR(llvm::yaml::SRPEntry)

namespace llvm {
namespace yaml {

/// Serializable representation of MachineFrameInfo.
///
/// Doesn't serialize attributes like 'StackAlignment', 'IsStackRealignable' and
Expand Down Expand Up @@ -645,8 +671,8 @@ struct MachineFrameInfo {
bool HasTailCall = false;
bool IsCalleeSavedInfoValid = false;
unsigned LocalFrameSize = 0;
StringValue SavePoint;
StringValue RestorePoint;
SaveRestorePoints SavePoints;
SaveRestorePoints RestorePoints;

bool operator==(const MachineFrameInfo &Other) const {
return IsFrameAddressTaken == Other.IsFrameAddressTaken &&
Expand All @@ -667,7 +693,8 @@ struct MachineFrameInfo {
HasMustTailInVarArgFunc == Other.HasMustTailInVarArgFunc &&
HasTailCall == Other.HasTailCall &&
LocalFrameSize == Other.LocalFrameSize &&
SavePoint == Other.SavePoint && RestorePoint == Other.RestorePoint &&
SavePoints == Other.SavePoints &&
RestorePoints == Other.RestorePoints &&
IsCalleeSavedInfoValid == Other.IsCalleeSavedInfoValid;
}
};
Expand Down Expand Up @@ -699,10 +726,12 @@ template <> struct MappingTraits<MachineFrameInfo> {
YamlIO.mapOptional("isCalleeSavedInfoValid", MFI.IsCalleeSavedInfoValid,
false);
YamlIO.mapOptional("localFrameSize", MFI.LocalFrameSize, (unsigned)0);
YamlIO.mapOptional("savePoint", MFI.SavePoint,
StringValue()); // Don't print it out when it's empty.
YamlIO.mapOptional("restorePoint", MFI.RestorePoint,
StringValue()); // Don't print it out when it's empty.
YamlIO.mapOptional(
"savePoints", MFI.SavePoints,
SaveRestorePoints()); // Don't print it out when it's empty.
YamlIO.mapOptional(
"restorePoints", MFI.RestorePoints,
SaveRestorePoints()); // Don't print it out when it's empty.
}
};

Expand Down
5 changes: 5 additions & 0 deletions llvm/include/llvm/CodeGen/MachineDominators.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,11 @@ class MachineDominatorTree : public DomTreeBase<MachineBasicBlock> {
return Base::findNearestCommonDominator(A, B);
}

/// Returns the nearest common dominator of the given blocks.
/// If that tree node is a virtual root, a nullptr will be returned.
MachineBasicBlock *
findNearestCommonDominator(ArrayRef<MachineBasicBlock *> Blocks) const;

MachineDomTreeNode *operator[](MachineBasicBlock *BB) const {
applySplitCriticalEdges();
return Base::getNode(BB);
Expand Down
140 changes: 131 additions & 9 deletions llvm/include/llvm/CodeGen/MachineFrameInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,21 @@ class MachineBasicBlock;
class BitVector;
class AllocaInst;

using SaveRestorePoints = DenseMap<MachineBasicBlock *, std::vector<Register>>;

class CalleeSavedInfoPerBB {
DenseMap<MachineBasicBlock *, std::vector<CalleeSavedInfo>> Map;

public:
std::vector<CalleeSavedInfo> get(MachineBasicBlock *MBB) const {
return Map.lookup(MBB);
}

void set(DenseMap<MachineBasicBlock *, std::vector<CalleeSavedInfo>> CSI) {
Map = std::move(CSI);
}
};

/// The CalleeSavedInfo class tracks the information need to locate where a
/// callee saved register is in the current frame.
/// Callee saved reg can also be saved to a different register rather than
Expand All @@ -37,6 +52,8 @@ class CalleeSavedInfo {
int FrameIdx;
unsigned DstReg;
};
std::vector<MachineBasicBlock *> SpilledIn;
std::vector<MachineBasicBlock *> RestoredIn;
/// Flag indicating whether the register is actually restored in the epilog.
/// In most cases, if a register is saved, it is also restored. There are
/// some situations, though, when this is not the case. For example, the
Expand All @@ -58,9 +75,9 @@ class CalleeSavedInfo {
explicit CalleeSavedInfo(unsigned R, int FI = 0) : Reg(R), FrameIdx(FI) {}

// Accessors.
Register getReg() const { return Reg; }
int getFrameIdx() const { return FrameIdx; }
unsigned getDstReg() const { return DstReg; }
Register getReg() const { return Reg; }
int getFrameIdx() const { return FrameIdx; }
unsigned getDstReg() const { return DstReg; }
void setFrameIdx(int FI) {
FrameIdx = FI;
SpilledToReg = false;
Expand All @@ -72,6 +89,16 @@ class CalleeSavedInfo {
bool isRestored() const { return Restored; }
void setRestored(bool R) { Restored = R; }
bool isSpilledToReg() const { return SpilledToReg; }
ArrayRef<MachineBasicBlock *> spilledIn() const { return SpilledIn; }
ArrayRef<MachineBasicBlock *> restoredIn() const { return RestoredIn; }
void addSpilledIn(MachineBasicBlock *MBB) { SpilledIn.push_back(MBB); }
void addRestoredIn(MachineBasicBlock *MBB) { RestoredIn.push_back(MBB); }
void setSpilledIn(std::vector<MachineBasicBlock *> BBV) {
SpilledIn = std::move(BBV);
}
void setRestoredIn(std::vector<MachineBasicBlock *> BBV) {
RestoredIn = std::move(BBV);
}
};

/// The MachineFrameInfo class represents an abstract stack frame until
Expand Down Expand Up @@ -295,6 +322,10 @@ class MachineFrameInfo {
/// Has CSInfo been set yet?
bool CSIValid = false;

CalleeSavedInfoPerBB CSInfoPerSave;

CalleeSavedInfoPerBB CSInfoPerRestore;

/// References to frame indices which are mapped
/// into the local frame allocation block. <FrameIdx, LocalOffset>
SmallVector<std::pair<int, int64_t>, 32> LocalFrameObjects;
Expand Down Expand Up @@ -331,9 +362,16 @@ class MachineFrameInfo {
bool HasTailCall = false;

/// Not null, if shrink-wrapping found a better place for the prologue.
MachineBasicBlock *Save = nullptr;
MachineBasicBlock *Prolog = nullptr;
/// Not null, if shrink-wrapping found a better place for the epilogue.
MachineBasicBlock *Restore = nullptr;
MachineBasicBlock *Epilog = nullptr;

/// Not empty, if shrink-wrapping found a better place for saving callee
/// saves.
SaveRestorePoints SavePoints;
/// Not empty, if shrink-wrapping found a better place for restoring callee
/// saves.
SaveRestorePoints RestorePoints;

/// Size of the UnsafeStack Frame
uint64_t UnsafeStackSize = 0;
Expand Down Expand Up @@ -809,21 +847,105 @@ class MachineFrameInfo {
/// \copydoc getCalleeSavedInfo()
std::vector<CalleeSavedInfo> &getCalleeSavedInfo() { return CSInfo; }

/// Returns callee saved info vector for provided save point in
/// the current function.
std::vector<CalleeSavedInfo> getCSInfoPerSave(MachineBasicBlock *MBB) const {
return CSInfoPerSave.get(MBB);
}

/// Returns callee saved info vector for provided restore point
/// in the current function.
std::vector<CalleeSavedInfo>
getCSInfoPerRestore(MachineBasicBlock *MBB) const {
return CSInfoPerRestore.get(MBB);
}

/// Used by prolog/epilog inserter to set the function's callee saved
/// information.
void setCalleeSavedInfo(std::vector<CalleeSavedInfo> CSI) {
CSInfo = std::move(CSI);
}

/// Used by prolog/epilog inserter to set the function's callee saved
/// information for particular save point.
void setCSInfoPerSave(
DenseMap<MachineBasicBlock *, std::vector<CalleeSavedInfo>> CSI) {
CSInfoPerSave.set(CSI);
}

/// Used by prolog/epilog inserter to set the function's callee saved
/// information for particular restore point.
void setCSInfoPerRestore(
DenseMap<MachineBasicBlock *, std::vector<CalleeSavedInfo>> CSI) {
CSInfoPerRestore.set(CSI);
}

/// Has the callee saved info been calculated yet?
bool isCalleeSavedInfoValid() const { return CSIValid; }

void setCalleeSavedInfoValid(bool v) { CSIValid = v; }

MachineBasicBlock *getSavePoint() const { return Save; }
void setSavePoint(MachineBasicBlock *NewSave) { Save = NewSave; }
MachineBasicBlock *getRestorePoint() const { return Restore; }
void setRestorePoint(MachineBasicBlock *NewRestore) { Restore = NewRestore; }
const SaveRestorePoints &getRestorePoints() const { return RestorePoints; }

const SaveRestorePoints &getSavePoints() const { return SavePoints; }

std::pair<MachineBasicBlock *, std::vector<Register>>
getRestorePoint(MachineBasicBlock *MBB) const {
if (auto It = RestorePoints.find(MBB); It != RestorePoints.end())
return *It;

std::vector<Register> Regs = {};
return std::make_pair(nullptr, Regs);
}

std::pair<MachineBasicBlock *, std::vector<Register>>
getSavePoint(MachineBasicBlock *MBB) const {
if (auto It = SavePoints.find(MBB); It != SavePoints.end())
return *It;

std::vector<Register> Regs = {};
return std::make_pair(nullptr, Regs);
}

void setSavePoints(SaveRestorePoints NewSavePoints) {
SavePoints = std::move(NewSavePoints);
}

void setRestorePoints(SaveRestorePoints NewRestorePoints) {
RestorePoints = std::move(NewRestorePoints);
}

void setSavePoint(MachineBasicBlock *MBB, std::vector<Register> &Regs) {
if (SavePoints.contains(MBB))
SavePoints[MBB] = Regs;
else
SavePoints.insert(std::make_pair(MBB, Regs));
}

static const SaveRestorePoints constructSaveRestorePoints(
const SaveRestorePoints &SRP,
const DenseMap<MachineBasicBlock *, MachineBasicBlock *> &BBMap) {
SaveRestorePoints Pts{};
for (auto &Src : SRP) {
Pts.insert(std::make_pair(BBMap.find(Src.first)->second, Src.second));
}
return Pts;
}

void setRestorePoint(MachineBasicBlock *MBB, std::vector<Register> &Regs) {
if (RestorePoints.contains(MBB))
RestorePoints[MBB] = Regs;
else
RestorePoints.insert(std::make_pair(MBB, Regs));
}

MachineBasicBlock *getProlog() const { return Prolog; }
void setProlog(MachineBasicBlock *BB) { Prolog = BB; }
MachineBasicBlock *getEpilog() const { return Epilog; }
void setEpilog(MachineBasicBlock *BB) { Epilog = BB; }

void clearSavePoints() { SavePoints.clear(); }
void clearRestorePoints() { RestorePoints.clear(); }

uint64_t getUnsafeStackSize() const { return UnsafeStackSize; }
void setUnsafeStackSize(uint64_t Size) { UnsafeStackSize = Size; }
Expand Down
4 changes: 4 additions & 0 deletions llvm/include/llvm/CodeGen/TargetFrameLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,10 @@ class TargetFrameLowering {
return false;
}

/// enableCSRSaveRestorePointsSplit - Returns true if the target support
/// multiple save/restore points in shrink wrapping.
virtual bool enableCSRSaveRestorePointsSplit() const { return false; }

/// Returns true if the stack slot holes in the fixed and callee-save stack
/// area should be used when allocating other stack locations to reduce stack
/// size.
Expand Down
55 changes: 41 additions & 14 deletions llvm/lib/CodeGen/MIRParser/MIRParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,10 @@ class MIRParserImpl {
bool initializeFrameInfo(PerFunctionMIParsingState &PFS,
const yaml::MachineFunction &YamlMF);

bool initializeSaveRestorePoints(PerFunctionMIParsingState &PFS,
const yaml::SaveRestorePoints &YamlSRP,
bool IsSavePoints);

bool initializeCallSiteInfo(PerFunctionMIParsingState &PFS,
const yaml::MachineFunction &YamlMF);

Expand Down Expand Up @@ -832,18 +836,9 @@ bool MIRParserImpl::initializeFrameInfo(PerFunctionMIParsingState &PFS,
MFI.setHasTailCall(YamlMFI.HasTailCall);
MFI.setCalleeSavedInfoValid(YamlMFI.IsCalleeSavedInfoValid);
MFI.setLocalFrameSize(YamlMFI.LocalFrameSize);
if (!YamlMFI.SavePoint.Value.empty()) {
MachineBasicBlock *MBB = nullptr;
if (parseMBBReference(PFS, MBB, YamlMFI.SavePoint))
return true;
MFI.setSavePoint(MBB);
}
if (!YamlMFI.RestorePoint.Value.empty()) {
MachineBasicBlock *MBB = nullptr;
if (parseMBBReference(PFS, MBB, YamlMFI.RestorePoint))
return true;
MFI.setRestorePoint(MBB);
}
initializeSaveRestorePoints(PFS, YamlMFI.SavePoints, true /*IsSavePoints*/);
initializeSaveRestorePoints(PFS, YamlMFI.RestorePoints,
false /*IsSavePoints*/);

std::vector<CalleeSavedInfo> CSIInfo;
// Initialize the fixed frame objects.
Expand Down Expand Up @@ -1058,8 +1053,40 @@ bool MIRParserImpl::initializeConstantPool(PerFunctionMIParsingState &PFS,
return false;
}

bool MIRParserImpl::initializeJumpTableInfo(PerFunctionMIParsingState &PFS,
const yaml::MachineJumpTable &YamlJTI) {
bool MIRParserImpl::initializeSaveRestorePoints(
PerFunctionMIParsingState &PFS, const yaml::SaveRestorePoints &YamlSRP,
bool IsSavePoints) {
SMDiagnostic Error;
MachineFunction &MF = PFS.MF;
MachineFrameInfo &MFI = MF.getFrameInfo();
llvm::SaveRestorePoints SRPoints;

for (const auto &Entry : YamlSRP) {
const auto &MBBSource = Entry.Point;
MachineBasicBlock *MBB = nullptr;
if (parseMBBReference(PFS, MBB, MBBSource.Value))
return true;

std::vector<Register> Registers{};
for (auto &RegStr : Entry.Registers) {
Register Reg;
if (parseNamedRegisterReference(PFS, Reg, RegStr.Value, Error))
return error(Error, RegStr.SourceRange);

Registers.push_back(Reg);
}
SRPoints.insert(std::make_pair(MBB, Registers));
}

if (IsSavePoints)
MFI.setSavePoints(SRPoints);
else
MFI.setRestorePoints(SRPoints);
return false;
}

bool MIRParserImpl::initializeJumpTableInfo(
PerFunctionMIParsingState &PFS, const yaml::MachineJumpTable &YamlJTI) {
MachineJumpTableInfo *JTI = PFS.MF.getOrCreateJumpTableInfo(YamlJTI.Kind);
for (const auto &Entry : YamlJTI.Entries) {
std::vector<MachineBasicBlock *> Blocks;
Expand Down
Loading
Loading